mirror of
https://github.com/ollama/ollama.git
synced 2026-02-23 10:45:08 -05:00
Compare commits
5 Commits
jessegross
...
main
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
0ade9205cc | ||
|
|
06edabdde1 | ||
|
|
8b4e5a82a8 | ||
|
|
3445223311 | ||
|
|
fa6c0127e6 |
@@ -41,6 +41,11 @@ type InferenceCompute struct {
|
||||
VRAM string
|
||||
}
|
||||
|
||||
type InferenceInfo struct {
|
||||
Computes []InferenceCompute
|
||||
DefaultContextLength int
|
||||
}
|
||||
|
||||
func New(s *store.Store, devMode bool) *Server {
|
||||
p := resolvePath("ollama")
|
||||
return &Server{store: s, bin: p, dev: devMode}
|
||||
@@ -272,9 +277,12 @@ func openRotatingLog() (io.WriteCloser, error) {
|
||||
|
||||
// Attempt to retrieve inference compute information from the server
|
||||
// log. Set ctx to timeout to control how long to wait for the logs to appear
|
||||
func GetInferenceComputer(ctx context.Context) ([]InferenceCompute, error) {
|
||||
inference := []InferenceCompute{}
|
||||
marker := regexp.MustCompile(`inference compute.*library=`)
|
||||
func GetInferenceInfo(ctx context.Context) (*InferenceInfo, error) {
|
||||
info := &InferenceInfo{}
|
||||
computeMarker := regexp.MustCompile(`inference compute.*library=`)
|
||||
defaultCtxMarker := regexp.MustCompile(`vram-based default context`)
|
||||
defaultCtxRegex := regexp.MustCompile(`default_num_ctx=(\d+)`)
|
||||
|
||||
q := `inference compute.*%s=["]([^"]*)["]`
|
||||
nq := `inference compute.*%s=(\S+)\s`
|
||||
type regex struct {
|
||||
@@ -340,8 +348,8 @@ func GetInferenceComputer(ctx context.Context) ([]InferenceCompute, error) {
|
||||
scanner := bufio.NewScanner(file)
|
||||
for scanner.Scan() {
|
||||
line := scanner.Text()
|
||||
match := marker.FindStringSubmatch(line)
|
||||
if len(match) > 0 {
|
||||
// Check for inference compute lines
|
||||
if computeMarker.MatchString(line) {
|
||||
ic := InferenceCompute{
|
||||
Library: get("library", line),
|
||||
Variant: get("variant", line),
|
||||
@@ -352,12 +360,25 @@ func GetInferenceComputer(ctx context.Context) ([]InferenceCompute, error) {
|
||||
}
|
||||
|
||||
slog.Info("Matched", "inference compute", ic)
|
||||
inference = append(inference, ic)
|
||||
} else {
|
||||
// Break out on first non matching line after we start matching
|
||||
if len(inference) > 0 {
|
||||
return inference, nil
|
||||
info.Computes = append(info.Computes, ic)
|
||||
continue
|
||||
}
|
||||
// Check for default context length line
|
||||
if defaultCtxMarker.MatchString(line) {
|
||||
match := defaultCtxRegex.FindStringSubmatch(line)
|
||||
if len(match) > 1 {
|
||||
numCtx, err := strconv.Atoi(match[1])
|
||||
if err == nil {
|
||||
info.DefaultContextLength = numCtx
|
||||
slog.Info("Matched default context length", "default_num_ctx", numCtx)
|
||||
}
|
||||
}
|
||||
return info, nil
|
||||
}
|
||||
// If we've found compute info but hit a non-matching line, return what we have
|
||||
// This handles older server versions that don't log the default context line
|
||||
if len(info.Computes) > 0 {
|
||||
return info, nil
|
||||
}
|
||||
}
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
@@ -205,44 +205,50 @@ func TestServerCmdCloudSettingEnv(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetInferenceComputer(t *testing.T) {
|
||||
func TestGetInferenceInfo(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
log string
|
||||
exp []InferenceCompute
|
||||
name string
|
||||
log string
|
||||
expComputes []InferenceCompute
|
||||
expDefaultCtxLen int
|
||||
}{
|
||||
{
|
||||
name: "metal",
|
||||
log: `time=2025-06-30T09:23:07.374-07:00 level=DEBUG source=sched.go:108 msg="starting llm scheduler"
|
||||
time=2025-06-30T09:23:07.416-07:00 level=INFO source=types.go:130 msg="inference compute" id=0 library=metal variant="" compute="" driver=0.0 name="" total="96.0 GiB" available="96.0 GiB"
|
||||
time=2025-06-30T09:23:07.417-07:00 level=INFO source=routes.go:1721 msg="vram-based default context" total_vram="96.0 GiB" default_num_ctx=262144
|
||||
time=2025-06-30T09:25:56.197-07:00 level=DEBUG source=ggml.go:155 msg="key not found" key=general.alignment default=32
|
||||
`,
|
||||
exp: []InferenceCompute{{
|
||||
expComputes: []InferenceCompute{{
|
||||
Library: "metal",
|
||||
Driver: "0.0",
|
||||
VRAM: "96.0 GiB",
|
||||
}},
|
||||
expDefaultCtxLen: 262144,
|
||||
},
|
||||
{
|
||||
name: "cpu",
|
||||
log: `time=2025-07-01T17:59:51.470Z level=INFO source=gpu.go:377 msg="no compatible GPUs were discovered"
|
||||
time=2025-07-01T17:59:51.470Z level=INFO source=types.go:130 msg="inference compute" id=0 library=cpu variant="" compute="" driver=0.0 name="" total="31.3 GiB" available="30.4 GiB"
|
||||
time=2025-07-01T17:59:51.471Z level=INFO source=routes.go:1721 msg="vram-based default context" total_vram="31.3 GiB" default_num_ctx=32768
|
||||
[GIN] 2025/07/01 - 18:00:09 | 200 | 50.263µs | 100.126.204.152 | HEAD "/"
|
||||
`,
|
||||
exp: []InferenceCompute{{
|
||||
expComputes: []InferenceCompute{{
|
||||
Library: "cpu",
|
||||
Driver: "0.0",
|
||||
VRAM: "31.3 GiB",
|
||||
}},
|
||||
expDefaultCtxLen: 32768,
|
||||
},
|
||||
{
|
||||
name: "cuda1",
|
||||
log: `time=2025-07-01T19:33:43.162Z level=DEBUG source=amd_linux.go:419 msg="amdgpu driver not detected /sys/module/amdgpu"
|
||||
releasing cuda driver library
|
||||
time=2025-07-01T19:33:43.162Z level=INFO source=types.go:130 msg="inference compute" id=GPU-452cac9f-6960-839c-4fb3-0cec83699196 library=cuda variant=v12 compute=6.1 driver=12.7 name="NVIDIA GeForce GT 1030" total="3.9 GiB" available="3.9 GiB"
|
||||
time=2025-07-01T19:33:43.163Z level=INFO source=routes.go:1721 msg="vram-based default context" total_vram="3.9 GiB" default_num_ctx=4096
|
||||
[GIN] 2025/07/01 - 18:00:09 | 200 | 50.263µs | 100.126.204.152 | HEAD "/"
|
||||
`,
|
||||
exp: []InferenceCompute{{
|
||||
expComputes: []InferenceCompute{{
|
||||
Library: "cuda",
|
||||
Variant: "v12",
|
||||
Compute: "6.1",
|
||||
@@ -250,6 +256,7 @@ time=2025-07-01T19:33:43.162Z level=INFO source=types.go:130 msg="inference comp
|
||||
Name: "NVIDIA GeForce GT 1030",
|
||||
VRAM: "3.9 GiB",
|
||||
}},
|
||||
expDefaultCtxLen: 4096,
|
||||
},
|
||||
{
|
||||
name: "frank",
|
||||
@@ -257,9 +264,10 @@ time=2025-07-01T19:33:43.162Z level=INFO source=types.go:130 msg="inference comp
|
||||
releasing cuda driver library
|
||||
time=2025-07-01T19:36:13.315Z level=INFO source=types.go:130 msg="inference compute" id=GPU-d6de3398-9932-6902-11ec-fee8e424c8a2 library=cuda variant=v12 compute=7.5 driver=12.8 name="NVIDIA GeForce RTX 2080 Ti" total="10.6 GiB" available="10.4 GiB"
|
||||
time=2025-07-01T19:36:13.315Z level=INFO source=types.go:130 msg="inference compute" id=GPU-9abb57639fa80c50 library=rocm variant="" compute=gfx1030 driver=6.3 name=1002:73bf total="16.0 GiB" available="1.3 GiB"
|
||||
time=2025-07-01T19:36:13.316Z level=INFO source=routes.go:1721 msg="vram-based default context" total_vram="26.6 GiB" default_num_ctx=32768
|
||||
[GIN] 2025/07/01 - 18:00:09 | 200 | 50.263µs | 100.126.204.152 | HEAD "/"
|
||||
`,
|
||||
exp: []InferenceCompute{
|
||||
expComputes: []InferenceCompute{
|
||||
{
|
||||
Library: "cuda",
|
||||
Variant: "v12",
|
||||
@@ -276,6 +284,20 @@ time=2025-07-01T19:33:43.162Z level=INFO source=types.go:130 msg="inference comp
|
||||
VRAM: "16.0 GiB",
|
||||
},
|
||||
},
|
||||
expDefaultCtxLen: 32768,
|
||||
},
|
||||
{
|
||||
name: "missing_default_context",
|
||||
log: `time=2025-06-30T09:23:07.374-07:00 level=DEBUG source=sched.go:108 msg="starting llm scheduler"
|
||||
time=2025-06-30T09:23:07.416-07:00 level=INFO source=types.go:130 msg="inference compute" id=0 library=metal variant="" compute="" driver=0.0 name="" total="96.0 GiB" available="96.0 GiB"
|
||||
time=2025-06-30T09:25:56.197-07:00 level=DEBUG source=ggml.go:155 msg="key not found" key=general.alignment default=32
|
||||
`,
|
||||
expComputes: []InferenceCompute{{
|
||||
Library: "metal",
|
||||
Driver: "0.0",
|
||||
VRAM: "96.0 GiB",
|
||||
}},
|
||||
expDefaultCtxLen: 0, // No default context line, should return 0
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
@@ -288,18 +310,21 @@ time=2025-07-01T19:33:43.162Z level=INFO source=types.go:130 msg="inference comp
|
||||
}
|
||||
ctx, cancel := context.WithTimeout(t.Context(), 10*time.Millisecond)
|
||||
defer cancel()
|
||||
ics, err := GetInferenceComputer(ctx)
|
||||
info, err := GetInferenceInfo(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf(" failed to get inference compute: %v", err)
|
||||
t.Fatalf("failed to get inference info: %v", err)
|
||||
}
|
||||
if !reflect.DeepEqual(ics, tt.exp) {
|
||||
t.Fatalf("got:\n%#v\nwant:\n%#v", ics, tt.exp)
|
||||
if !reflect.DeepEqual(info.Computes, tt.expComputes) {
|
||||
t.Fatalf("computes mismatch\ngot:\n%#v\nwant:\n%#v", info.Computes, tt.expComputes)
|
||||
}
|
||||
if info.DefaultContextLength != tt.expDefaultCtxLen {
|
||||
t.Fatalf("default context length mismatch: got %d, want %d", info.DefaultContextLength, tt.expDefaultCtxLen)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetInferenceComputerTimeout(t *testing.T) {
|
||||
func TestGetInferenceInfoTimeout(t *testing.T) {
|
||||
ctx, cancel := context.WithTimeout(t.Context(), 10*time.Millisecond)
|
||||
defer cancel()
|
||||
tmpDir := t.TempDir()
|
||||
@@ -308,7 +333,7 @@ func TestGetInferenceComputerTimeout(t *testing.T) {
|
||||
if err != nil {
|
||||
t.Fatalf("failed to write log file %s: %s", serverLogPath, err)
|
||||
}
|
||||
_, err = GetInferenceComputer(ctx)
|
||||
_, err = GetInferenceInfo(ctx)
|
||||
if err == nil {
|
||||
t.Fatal("expected timeout")
|
||||
}
|
||||
|
||||
@@ -14,7 +14,7 @@ import (
|
||||
|
||||
// currentSchemaVersion defines the current database schema version.
|
||||
// Increment this when making schema changes that require migrations.
|
||||
const currentSchemaVersion = 13
|
||||
const currentSchemaVersion = 14
|
||||
|
||||
// database wraps the SQLite connection.
|
||||
// SQLite handles its own locking for concurrent access:
|
||||
@@ -73,7 +73,7 @@ func (db *database) init() error {
|
||||
agent BOOLEAN NOT NULL DEFAULT 0,
|
||||
tools BOOLEAN NOT NULL DEFAULT 0,
|
||||
working_dir TEXT NOT NULL DEFAULT '',
|
||||
context_length INTEGER NOT NULL DEFAULT 4096,
|
||||
context_length INTEGER NOT NULL DEFAULT 0,
|
||||
window_width INTEGER NOT NULL DEFAULT 0,
|
||||
window_height INTEGER NOT NULL DEFAULT 0,
|
||||
config_migrated BOOLEAN NOT NULL DEFAULT 0,
|
||||
@@ -251,6 +251,12 @@ func (db *database) migrate() error {
|
||||
return fmt.Errorf("migrate v12 to v13: %w", err)
|
||||
}
|
||||
version = 13
|
||||
case 13:
|
||||
// change default context_length from 4096 to 0 (VRAM-based tiered defaults)
|
||||
if err := db.migrateV13ToV14(); err != nil {
|
||||
return fmt.Errorf("migrate v13 to v14: %w", err)
|
||||
}
|
||||
version = 14
|
||||
default:
|
||||
// If we have a version we don't recognize, just set it to current
|
||||
// This might happen during development
|
||||
@@ -474,6 +480,22 @@ func (db *database) migrateV12ToV13() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// migrateV13ToV14 changes the default context_length from 4096 to 0.
|
||||
// When context_length is 0, the ollama server uses VRAM-based tiered defaults.
|
||||
func (db *database) migrateV13ToV14() error {
|
||||
_, err := db.conn.Exec(`UPDATE settings SET context_length = 0 WHERE context_length = 4096`)
|
||||
if err != nil {
|
||||
return fmt.Errorf("update context_length default: %w", err)
|
||||
}
|
||||
|
||||
_, err = db.conn.Exec(`UPDATE settings SET schema_version = 14`)
|
||||
if err != nil {
|
||||
return fmt.Errorf("update schema version: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// cleanupOrphanedData removes orphaned records that may exist due to the foreign key bug
|
||||
func (db *database) cleanupOrphanedData() error {
|
||||
_, err := db.conn.Exec(`
|
||||
|
||||
@@ -98,6 +98,43 @@ func TestSchemaMigrations(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
func TestMigrationV13ToV14ContextLength(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
dbPath := filepath.Join(tmpDir, "test.db")
|
||||
|
||||
db, err := newDatabase(dbPath)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create database: %v", err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
_, err = db.conn.Exec("UPDATE settings SET context_length = 4096, schema_version = 13")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to seed v13 settings row: %v", err)
|
||||
}
|
||||
|
||||
if err := db.migrate(); err != nil {
|
||||
t.Fatalf("migration from v13 to v14 failed: %v", err)
|
||||
}
|
||||
|
||||
var contextLength int
|
||||
if err := db.conn.QueryRow("SELECT context_length FROM settings").Scan(&contextLength); err != nil {
|
||||
t.Fatalf("failed to read context_length: %v", err)
|
||||
}
|
||||
|
||||
if contextLength != 0 {
|
||||
t.Fatalf("expected context_length to migrate to 0, got %d", contextLength)
|
||||
}
|
||||
|
||||
version, err := db.getSchemaVersion()
|
||||
if err != nil {
|
||||
t.Fatalf("failed to get schema version: %v", err)
|
||||
}
|
||||
if version != currentSchemaVersion {
|
||||
t.Fatalf("expected schema version %d, got %d", currentSchemaVersion, version)
|
||||
}
|
||||
}
|
||||
|
||||
func TestChatDeletionWithCascade(t *testing.T) {
|
||||
t.Run("chat deletion cascades to related messages", func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
|
||||
2
app/store/testdata/schema.sql
vendored
2
app/store/testdata/schema.sql
vendored
@@ -13,7 +13,7 @@ CREATE TABLE IF NOT EXISTS settings (
|
||||
agent BOOLEAN NOT NULL DEFAULT 0,
|
||||
tools BOOLEAN NOT NULL DEFAULT 0,
|
||||
working_dir TEXT NOT NULL DEFAULT '',
|
||||
context_length INTEGER NOT NULL DEFAULT 4096,
|
||||
context_length INTEGER NOT NULL DEFAULT 0,
|
||||
window_width INTEGER NOT NULL DEFAULT 0,
|
||||
window_height INTEGER NOT NULL DEFAULT 0,
|
||||
config_migrated BOOLEAN NOT NULL DEFAULT 0,
|
||||
|
||||
@@ -289,10 +289,12 @@ export class InferenceCompute {
|
||||
}
|
||||
export class InferenceComputeResponse {
|
||||
inferenceComputes: InferenceCompute[];
|
||||
defaultContextLength: number;
|
||||
|
||||
constructor(source: any = {}) {
|
||||
if ('string' === typeof source) source = JSON.parse(source);
|
||||
this.inferenceComputes = this.convertValues(source["inferenceComputes"], InferenceCompute);
|
||||
this.defaultContextLength = source["defaultContextLength"];
|
||||
}
|
||||
|
||||
convertValues(a: any, classs: any, asMap: boolean = false): any {
|
||||
|
||||
@@ -4,7 +4,6 @@ import {
|
||||
ChatEvent,
|
||||
DownloadEvent,
|
||||
ErrorEvent,
|
||||
InferenceCompute,
|
||||
InferenceComputeResponse,
|
||||
ModelCapabilitiesResponse,
|
||||
Model,
|
||||
@@ -407,7 +406,7 @@ export async function* pullModel(
|
||||
}
|
||||
}
|
||||
|
||||
export async function getInferenceCompute(): Promise<InferenceCompute[]> {
|
||||
export async function getInferenceCompute(): Promise<InferenceComputeResponse> {
|
||||
const response = await fetch(`${API_BASE}/api/v1/inference-compute`);
|
||||
if (!response.ok) {
|
||||
throw new Error(
|
||||
@@ -416,8 +415,7 @@ export async function getInferenceCompute(): Promise<InferenceCompute[]> {
|
||||
}
|
||||
|
||||
const data = await response.json();
|
||||
const inferenceComputeResponse = new InferenceComputeResponse(data);
|
||||
return inferenceComputeResponse.inferenceComputes || [];
|
||||
return new InferenceComputeResponse(data);
|
||||
}
|
||||
|
||||
export async function fetchHealth(): Promise<boolean> {
|
||||
|
||||
@@ -26,6 +26,7 @@ import {
|
||||
type CloudStatusResponse,
|
||||
updateCloudSetting,
|
||||
updateSettings,
|
||||
getInferenceCompute,
|
||||
} from "@/api";
|
||||
|
||||
function AnimatedDots() {
|
||||
@@ -77,6 +78,13 @@ export default function Settings() {
|
||||
|
||||
const settings = settingsData?.settings || null;
|
||||
|
||||
const { data: inferenceComputeResponse } = useQuery({
|
||||
queryKey: ["inferenceCompute"],
|
||||
queryFn: getInferenceCompute,
|
||||
});
|
||||
|
||||
const defaultContextLength = inferenceComputeResponse?.defaultContextLength;
|
||||
|
||||
const updateSettingsMutation = useMutation({
|
||||
mutationFn: updateSettings,
|
||||
onSuccess: () => {
|
||||
@@ -204,7 +212,7 @@ export default function Settings() {
|
||||
Models: "",
|
||||
Agent: false,
|
||||
Tools: false,
|
||||
ContextLength: 4096,
|
||||
ContextLength: 0,
|
||||
});
|
||||
updateSettingsMutation.mutate(defaultSettings);
|
||||
}
|
||||
@@ -507,13 +515,11 @@ export default function Settings() {
|
||||
</Description>
|
||||
<div className="mt-3">
|
||||
<Slider
|
||||
value={(() => {
|
||||
// Otherwise use the settings value
|
||||
return settings.ContextLength || 4096;
|
||||
})()}
|
||||
value={settings.ContextLength || defaultContextLength || 0}
|
||||
onChange={(value) => {
|
||||
handleChange("ContextLength", value);
|
||||
}}
|
||||
disabled={!defaultContextLength}
|
||||
options={[
|
||||
{ value: 4096, label: "4k" },
|
||||
{ value: 8192, label: "8k" },
|
||||
|
||||
@@ -6,10 +6,11 @@ export interface SliderProps {
|
||||
value?: number;
|
||||
onChange?: (value: number) => void;
|
||||
className?: string;
|
||||
disabled?: boolean;
|
||||
}
|
||||
|
||||
const Slider = React.forwardRef<HTMLDivElement, SliderProps>(
|
||||
({ label, options, value = 0, onChange }, ref) => {
|
||||
({ label, options, value = 0, onChange, disabled = false }, ref) => {
|
||||
const [selectedValue, setSelectedValue] = React.useState(value);
|
||||
const [isDragging, setIsDragging] = React.useState(false);
|
||||
const containerRef = React.useRef<HTMLDivElement>(null);
|
||||
@@ -20,6 +21,7 @@ const Slider = React.forwardRef<HTMLDivElement, SliderProps>(
|
||||
}, [value]);
|
||||
|
||||
const handleClick = (optionValue: number) => {
|
||||
if (disabled) return;
|
||||
setSelectedValue(optionValue);
|
||||
onChange?.(optionValue);
|
||||
};
|
||||
@@ -39,6 +41,7 @@ const Slider = React.forwardRef<HTMLDivElement, SliderProps>(
|
||||
};
|
||||
|
||||
const handleMouseDown = (e: React.MouseEvent) => {
|
||||
if (disabled) return;
|
||||
setIsDragging(true);
|
||||
e.preventDefault();
|
||||
};
|
||||
@@ -77,7 +80,7 @@ const Slider = React.forwardRef<HTMLDivElement, SliderProps>(
|
||||
}
|
||||
|
||||
return (
|
||||
<div className="space-y-2" ref={ref}>
|
||||
<div className={`space-y-2 ${disabled ? "opacity-50" : ""}`} ref={ref}>
|
||||
{label && <label className="text-sm font-medium">{label}</label>}
|
||||
<div className="relative">
|
||||
<div className="absolute top-[9px] left-2 right-2 h-1 bg-neutral-200 dark:bg-neutral-700 pointer-events-none rounded-full" />
|
||||
@@ -88,10 +91,11 @@ const Slider = React.forwardRef<HTMLDivElement, SliderProps>(
|
||||
<button
|
||||
onClick={() => handleClick(option.value)}
|
||||
onMouseDown={handleMouseDown}
|
||||
className="relative px-3 py-6 -mx-3 -my-6 z-10 cursor-pointer"
|
||||
disabled={disabled}
|
||||
className={`relative px-3 py-6 -mx-3 -my-6 z-10 ${disabled ? "cursor-not-allowed" : "cursor-pointer"}`}
|
||||
>
|
||||
<div className="relative w-5 h-5 flex items-center justify-center">
|
||||
{selectedValue === option.value && (
|
||||
{selectedValue === option.value && !disabled && (
|
||||
<div className="w-4 h-4 bg-white dark:bg-white border border-neutral-400 dark:border-neutral-500 rounded-full cursor-grab active:cursor-grabbing" />
|
||||
)}
|
||||
</div>
|
||||
|
||||
@@ -28,12 +28,14 @@ export function useSelectedModel(currentChatId?: string, searchQuery?: string) {
|
||||
currentChatId && currentChatId !== "new" ? currentChatId : "",
|
||||
);
|
||||
|
||||
const { data: inferenceComputes = [] } = useQuery({
|
||||
queryKey: ["inference-compute"],
|
||||
const { data: inferenceComputeResponse } = useQuery({
|
||||
queryKey: ["inferenceCompute"],
|
||||
queryFn: getInferenceCompute,
|
||||
enabled: !settings.selectedModel, // Only fetch if no model is selected
|
||||
});
|
||||
|
||||
const inferenceComputes = inferenceComputeResponse?.inferenceComputes || [];
|
||||
|
||||
const totalVRAM = useMemo(
|
||||
() => getTotalVRAM(inferenceComputes),
|
||||
[inferenceComputes],
|
||||
|
||||
@@ -45,7 +45,8 @@ type InferenceCompute struct {
|
||||
}
|
||||
|
||||
type InferenceComputeResponse struct {
|
||||
InferenceComputes []InferenceCompute `json:"inferenceComputes"`
|
||||
InferenceComputes []InferenceCompute `json:"inferenceComputes"`
|
||||
DefaultContextLength int `json:"defaultContextLength"`
|
||||
}
|
||||
|
||||
type ModelCapabilitiesResponse struct {
|
||||
|
||||
18
app/ui/ui.go
18
app/ui/ui.go
@@ -1420,11 +1420,6 @@ func (s *Server) getSettings(w http.ResponseWriter, r *http.Request) error {
|
||||
settings.Models = envconfig.Models()
|
||||
}
|
||||
|
||||
// set default context length if not set
|
||||
if settings.ContextLength == 0 {
|
||||
settings.ContextLength = 4096
|
||||
}
|
||||
|
||||
// Include current runtime settings
|
||||
settings.Agent = s.Agent
|
||||
settings.Tools = s.Tools
|
||||
@@ -1500,14 +1495,14 @@ func (s *Server) writeCloudStatus(w http.ResponseWriter) error {
|
||||
func (s *Server) getInferenceCompute(w http.ResponseWriter, r *http.Request) error {
|
||||
ctx, cancel := context.WithTimeout(r.Context(), 500*time.Millisecond)
|
||||
defer cancel()
|
||||
serverInferenceComputes, err := server.GetInferenceComputer(ctx)
|
||||
info, err := server.GetInferenceInfo(ctx)
|
||||
if err != nil {
|
||||
s.log().Error("failed to get inference compute", "error", err)
|
||||
return fmt.Errorf("failed to get inference compute: %w", err)
|
||||
s.log().Error("failed to get inference info", "error", err)
|
||||
return fmt.Errorf("failed to get inference info: %w", err)
|
||||
}
|
||||
|
||||
inferenceComputes := make([]responses.InferenceCompute, len(serverInferenceComputes))
|
||||
for i, ic := range serverInferenceComputes {
|
||||
inferenceComputes := make([]responses.InferenceCompute, len(info.Computes))
|
||||
for i, ic := range info.Computes {
|
||||
inferenceComputes[i] = responses.InferenceCompute{
|
||||
Library: ic.Library,
|
||||
Variant: ic.Variant,
|
||||
@@ -1519,7 +1514,8 @@ func (s *Server) getInferenceCompute(w http.ResponseWriter, r *http.Request) err
|
||||
}
|
||||
|
||||
response := responses.InferenceComputeResponse{
|
||||
InferenceComputes: inferenceComputes,
|
||||
InferenceComputes: inferenceComputes,
|
||||
DefaultContextLength: info.DefaultContextLength,
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
|
||||
@@ -1956,6 +1956,10 @@ func runInteractiveTUI(cmd *cobra.Command) {
|
||||
}
|
||||
|
||||
launchIntegration := func(name string) bool {
|
||||
if err := config.EnsureInstalled(name); err != nil {
|
||||
fmt.Fprintf(os.Stderr, "Error: %v\n", err)
|
||||
return true
|
||||
}
|
||||
// If not configured or model no longer exists, prompt for model selection
|
||||
configuredModel := config.IntegrationModel(name)
|
||||
if configuredModel == "" || !config.ModelExists(cmd.Context(), configuredModel) || config.IsCloudModelDisabled(cmd.Context(), configuredModel) {
|
||||
|
||||
@@ -15,8 +15,9 @@ import (
|
||||
)
|
||||
|
||||
type integration struct {
|
||||
Models []string `json:"models"`
|
||||
Aliases map[string]string `json:"aliases,omitempty"`
|
||||
Models []string `json:"models"`
|
||||
Aliases map[string]string `json:"aliases,omitempty"`
|
||||
Onboarded bool `json:"onboarded,omitempty"`
|
||||
}
|
||||
|
||||
type config struct {
|
||||
@@ -139,34 +140,54 @@ func SaveIntegration(appName string, models []string) error {
|
||||
key := strings.ToLower(appName)
|
||||
existing := cfg.Integrations[key]
|
||||
var aliases map[string]string
|
||||
if existing != nil && existing.Aliases != nil {
|
||||
var onboarded bool
|
||||
if existing != nil {
|
||||
aliases = existing.Aliases
|
||||
onboarded = existing.Onboarded
|
||||
}
|
||||
|
||||
cfg.Integrations[key] = &integration{
|
||||
Models: models,
|
||||
Aliases: aliases,
|
||||
Models: models,
|
||||
Aliases: aliases,
|
||||
Onboarded: onboarded,
|
||||
}
|
||||
|
||||
return save(cfg)
|
||||
}
|
||||
|
||||
// integrationOnboarded marks an integration as onboarded in ollama's config.
|
||||
func integrationOnboarded(appName string) error {
|
||||
cfg, err := load()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
key := strings.ToLower(appName)
|
||||
existing := cfg.Integrations[key]
|
||||
if existing == nil {
|
||||
existing = &integration{}
|
||||
}
|
||||
existing.Onboarded = true
|
||||
cfg.Integrations[key] = existing
|
||||
return save(cfg)
|
||||
}
|
||||
|
||||
// IntegrationModel returns the first configured model for an integration, or empty string if not configured.
|
||||
func IntegrationModel(appName string) string {
|
||||
ic, err := loadIntegration(appName)
|
||||
if err != nil || len(ic.Models) == 0 {
|
||||
integrationConfig, err := loadIntegration(appName)
|
||||
if err != nil || len(integrationConfig.Models) == 0 {
|
||||
return ""
|
||||
}
|
||||
return ic.Models[0]
|
||||
return integrationConfig.Models[0]
|
||||
}
|
||||
|
||||
// IntegrationModels returns all configured models for an integration, or nil.
|
||||
func IntegrationModels(appName string) []string {
|
||||
ic, err := loadIntegration(appName)
|
||||
if err != nil || len(ic.Models) == 0 {
|
||||
integrationConfig, err := loadIntegration(appName)
|
||||
if err != nil || len(integrationConfig.Models) == 0 {
|
||||
return nil
|
||||
}
|
||||
return ic.Models
|
||||
return integrationConfig.Models
|
||||
}
|
||||
|
||||
// LastModel returns the last model that was run, or empty string if none.
|
||||
@@ -234,12 +255,12 @@ func loadIntegration(appName string) (*integration, error) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
ic, ok := cfg.Integrations[strings.ToLower(appName)]
|
||||
integrationConfig, ok := cfg.Integrations[strings.ToLower(appName)]
|
||||
if !ok {
|
||||
return nil, os.ErrNotExist
|
||||
}
|
||||
|
||||
return ic, nil
|
||||
return integrationConfig, nil
|
||||
}
|
||||
|
||||
func saveAliases(appName string, aliases map[string]string) error {
|
||||
@@ -272,8 +293,8 @@ func listIntegrations() ([]integration, error) {
|
||||
}
|
||||
|
||||
result := make([]integration, 0, len(cfg.Integrations))
|
||||
for _, ic := range cfg.Integrations {
|
||||
result = append(result, *ic)
|
||||
for _, integrationConfig := range cfg.Integrations {
|
||||
result = append(result, *integrationConfig)
|
||||
}
|
||||
|
||||
return result, nil
|
||||
|
||||
@@ -228,6 +228,31 @@ func IsIntegrationInstalled(name string) bool {
|
||||
}
|
||||
}
|
||||
|
||||
// AutoInstallable returns true if the integration can be automatically
|
||||
// installed when not found (e.g. via npm).
|
||||
func AutoInstallable(name string) bool {
|
||||
switch strings.ToLower(name) {
|
||||
case "openclaw", "clawdbot", "moltbot":
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// EnsureInstalled checks if an auto-installable integration is present and
|
||||
// offers to install it if missing. Returns nil for non-auto-installable
|
||||
// integrations or when the binary is already on PATH.
|
||||
func EnsureInstalled(name string) error {
|
||||
if !AutoInstallable(name) {
|
||||
return nil
|
||||
}
|
||||
if IsIntegrationInstalled(name) {
|
||||
return nil
|
||||
}
|
||||
_, err := ensureOpenclawInstalled()
|
||||
return err
|
||||
}
|
||||
|
||||
// IsEditorIntegration returns true if the named integration uses multi-model
|
||||
// selection (implements the Editor interface).
|
||||
func IsEditorIntegration(name string) bool {
|
||||
@@ -926,6 +951,10 @@ Examples:
|
||||
return fmt.Errorf("unknown integration: %s", name)
|
||||
}
|
||||
|
||||
if err := EnsureInstalled(name); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if modelFlag != "" && IsCloudModelDisabled(cmd.Context(), modelFlag) {
|
||||
modelFlag = ""
|
||||
}
|
||||
|
||||
@@ -1,81 +1,287 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/url"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"slices"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/envconfig"
|
||||
"github.com/ollama/ollama/types/model"
|
||||
)
|
||||
|
||||
const defaultGatewayPort = 18789
|
||||
|
||||
// Bound model capability probing so launch/config cannot hang on slow/unreachable API calls.
|
||||
var openclawModelShowTimeout = 5 * time.Second
|
||||
|
||||
type Openclaw struct{}
|
||||
|
||||
func (c *Openclaw) String() string { return "OpenClaw" }
|
||||
|
||||
func (c *Openclaw) Run(model string, args []string) error {
|
||||
bin := "openclaw"
|
||||
if _, err := exec.LookPath(bin); err != nil {
|
||||
bin = "clawdbot"
|
||||
if _, err := exec.LookPath(bin); err != nil {
|
||||
return fmt.Errorf("openclaw is not installed, install from https://docs.openclaw.ai")
|
||||
}
|
||||
}
|
||||
|
||||
models := []string{model}
|
||||
if config, err := loadIntegration("openclaw"); err == nil && len(config.Models) > 0 {
|
||||
models = config.Models
|
||||
} else if config, err := loadIntegration("clawdbot"); err == nil && len(config.Models) > 0 {
|
||||
models = config.Models
|
||||
}
|
||||
var err error
|
||||
models, err = resolveEditorModels("openclaw", models, func() ([]string, error) {
|
||||
return selectModels(context.Background(), "openclaw", "")
|
||||
})
|
||||
if errors.Is(err, errCancelled) {
|
||||
return nil
|
||||
}
|
||||
bin, err := ensureOpenclawInstalled()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := c.Edit(models); err != nil {
|
||||
return fmt.Errorf("setup failed: %w", err)
|
||||
|
||||
firstLaunch := true
|
||||
if integrationConfig, err := loadIntegration("openclaw"); err == nil {
|
||||
firstLaunch = !integrationConfig.Onboarded
|
||||
}
|
||||
|
||||
if firstLaunch {
|
||||
fmt.Fprintf(os.Stderr, "\n%sSecurity%s\n\n", ansiBold, ansiReset)
|
||||
fmt.Fprintf(os.Stderr, " OpenClaw can read files and run actions when tools are enabled.\n")
|
||||
fmt.Fprintf(os.Stderr, " A bad prompt can trick it into doing unsafe things.\n\n")
|
||||
fmt.Fprintf(os.Stderr, "%s Learn more: https://docs.openclaw.ai/gateway/security%s\n\n", ansiGray, ansiReset)
|
||||
|
||||
ok, err := confirmPrompt("I understand the risks. Continue?")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
if !c.onboarded() {
|
||||
// Onboarding not completed: run it (model already set via Edit)
|
||||
// Use "ollama" as gateway token for simple local access
|
||||
fmt.Fprintf(os.Stderr, "\n%sSetting up OpenClaw with Ollama...%s\n", ansiGreen, ansiReset)
|
||||
fmt.Fprintf(os.Stderr, "%s Model: %s%s\n\n", ansiGray, model, ansiReset)
|
||||
|
||||
cmd := exec.Command(bin, "onboard",
|
||||
"--non-interactive",
|
||||
"--accept-risk",
|
||||
"--auth-choice", "skip",
|
||||
"--gateway-token", "ollama",
|
||||
"--install-daemon",
|
||||
"--skip-channels",
|
||||
"--skip-skills",
|
||||
)
|
||||
cmd.Stdin = os.Stdin
|
||||
cmd.Stdout = os.Stdout
|
||||
cmd.Stderr = os.Stderr
|
||||
return cmd.Run()
|
||||
if err := cmd.Run(); err != nil {
|
||||
return windowsHint(fmt.Errorf("openclaw onboarding failed: %w\n\nTry running: openclaw onboard", err))
|
||||
}
|
||||
|
||||
patchDeviceScopes()
|
||||
|
||||
// Onboarding overwrites openclaw.json, so re-apply the model config
|
||||
// that Edit() wrote before Run() was called.
|
||||
if err := c.Edit([]string{model}); err != nil {
|
||||
fmt.Fprintf(os.Stderr, "%s Warning: could not re-apply model config: %v%s\n", ansiYellow, err, ansiReset)
|
||||
}
|
||||
}
|
||||
|
||||
// Onboarding completed: run gateway
|
||||
cmd := exec.Command(bin, append([]string{"gateway"}, args...)...)
|
||||
cmd.Stdin = os.Stdin
|
||||
if strings.HasSuffix(model, ":cloud") || strings.HasSuffix(model, "-cloud") {
|
||||
if ensureWebSearchPlugin() {
|
||||
registerWebSearchPlugin()
|
||||
}
|
||||
}
|
||||
|
||||
// Capture output to detect "already running" message
|
||||
var outputBuf bytes.Buffer
|
||||
cmd.Stdout = io.MultiWriter(os.Stdout, &outputBuf)
|
||||
cmd.Stderr = io.MultiWriter(os.Stderr, &outputBuf)
|
||||
if firstLaunch {
|
||||
fmt.Fprintf(os.Stderr, "\n%sPreparing your assistant — this may take a moment...%s\n\n", ansiGray, ansiReset)
|
||||
} else {
|
||||
fmt.Fprintf(os.Stderr, "\n%sStarting your assistant — this may take a moment...%s\n\n", ansiGray, ansiReset)
|
||||
}
|
||||
|
||||
err = cmd.Run()
|
||||
if err != nil && strings.Contains(outputBuf.String(), "Gateway already running") {
|
||||
fmt.Fprintf(os.Stderr, "%sOpenClaw has been configured with Ollama. Gateway is already running.%s\n", ansiGreen, ansiReset)
|
||||
// When extra args are passed through, run exactly what the user asked for
|
||||
// after setup and skip the built-in gateway+TUI convenience flow.
|
||||
if len(args) > 0 {
|
||||
cmd := exec.Command(bin, args...)
|
||||
cmd.Env = openclawEnv()
|
||||
cmd.Stdin = os.Stdin
|
||||
cmd.Stdout = os.Stdout
|
||||
cmd.Stderr = os.Stderr
|
||||
if err := cmd.Run(); err != nil {
|
||||
return windowsHint(err)
|
||||
}
|
||||
if firstLaunch {
|
||||
if err := integrationOnboarded("openclaw"); err != nil {
|
||||
return fmt.Errorf("failed to save onboarding state: %w", err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
return err
|
||||
|
||||
token, port := c.gatewayInfo()
|
||||
addr := fmt.Sprintf("localhost:%d", port)
|
||||
|
||||
// If the gateway is already running (e.g. via the daemon), restart it
|
||||
// so it picks up any config changes from Edit() above (model, provider, etc.).
|
||||
if portOpen(addr) {
|
||||
restart := exec.Command(bin, "daemon", "restart")
|
||||
restart.Env = openclawEnv()
|
||||
if err := restart.Run(); err != nil {
|
||||
fmt.Fprintf(os.Stderr, "%s Warning: daemon restart failed: %v%s\n", ansiYellow, err, ansiReset)
|
||||
}
|
||||
if !waitForPort(addr, 10*time.Second) {
|
||||
fmt.Fprintf(os.Stderr, "%s Warning: gateway did not come back after restart%s\n", ansiYellow, ansiReset)
|
||||
}
|
||||
}
|
||||
|
||||
// If the gateway isn't running, start it as a background child process.
|
||||
if !portOpen(addr) {
|
||||
gw := exec.Command(bin, "gateway", "run", "--force")
|
||||
gw.Env = openclawEnv()
|
||||
if err := gw.Start(); err != nil {
|
||||
return windowsHint(fmt.Errorf("failed to start gateway: %w", err))
|
||||
}
|
||||
defer func() {
|
||||
if gw.Process != nil {
|
||||
_ = gw.Process.Kill()
|
||||
_ = gw.Wait()
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
fmt.Fprintf(os.Stderr, "%sStarting gateway...%s\n", ansiGray, ansiReset)
|
||||
if !waitForPort(addr, 30*time.Second) {
|
||||
return windowsHint(fmt.Errorf("gateway did not start on %s", addr))
|
||||
}
|
||||
|
||||
printOpenclawReady(bin, token, port, firstLaunch)
|
||||
|
||||
tuiArgs := []string{"tui"}
|
||||
if firstLaunch {
|
||||
tuiArgs = append(tuiArgs, "--message", "Wake up, my friend!")
|
||||
}
|
||||
tui := exec.Command(bin, tuiArgs...)
|
||||
tui.Env = openclawEnv()
|
||||
tui.Stdin = os.Stdin
|
||||
tui.Stdout = os.Stdout
|
||||
tui.Stderr = os.Stderr
|
||||
if err := tui.Run(); err != nil {
|
||||
return windowsHint(err)
|
||||
}
|
||||
|
||||
if firstLaunch {
|
||||
if err := integrationOnboarded("openclaw"); err != nil {
|
||||
return fmt.Errorf("failed to save onboarding state: %w", err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// gatewayInfo reads the gateway auth token and port from the OpenClaw config.
|
||||
func (c *Openclaw) gatewayInfo() (token string, port int) {
|
||||
port = defaultGatewayPort
|
||||
home, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
return "", port
|
||||
}
|
||||
|
||||
for _, path := range []string{
|
||||
filepath.Join(home, ".openclaw", "openclaw.json"),
|
||||
filepath.Join(home, ".clawdbot", "clawdbot.json"),
|
||||
} {
|
||||
data, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
var config map[string]any
|
||||
if json.Unmarshal(data, &config) != nil {
|
||||
continue
|
||||
}
|
||||
gw, _ := config["gateway"].(map[string]any)
|
||||
if p, ok := gw["port"].(float64); ok && p > 0 {
|
||||
port = int(p)
|
||||
}
|
||||
auth, _ := gw["auth"].(map[string]any)
|
||||
if t, _ := auth["token"].(string); t != "" {
|
||||
token = t
|
||||
}
|
||||
return token, port
|
||||
}
|
||||
return "", port
|
||||
}
|
||||
|
||||
func printOpenclawReady(bin, token string, port int, firstLaunch bool) {
|
||||
u := fmt.Sprintf("http://localhost:%d", port)
|
||||
if token != "" {
|
||||
u += "/#token=" + url.QueryEscape(token)
|
||||
}
|
||||
|
||||
fmt.Fprintf(os.Stderr, "\n%s✓ OpenClaw is running%s\n\n", ansiGreen, ansiReset)
|
||||
fmt.Fprintf(os.Stderr, " Open the Web UI:\n")
|
||||
fmt.Fprintf(os.Stderr, " %s\n\n", hyperlink(u, u))
|
||||
|
||||
if firstLaunch {
|
||||
fmt.Fprintf(os.Stderr, "%s Quick start:%s\n", ansiBold, ansiReset)
|
||||
fmt.Fprintf(os.Stderr, "%s /help see all commands%s\n", ansiGray, ansiReset)
|
||||
fmt.Fprintf(os.Stderr, "%s %s configure --section channels connect WhatsApp, Telegram, etc.%s\n", ansiGray, bin, ansiReset)
|
||||
fmt.Fprintf(os.Stderr, "%s %s skills browse and install skills%s\n\n", ansiGray, bin, ansiReset)
|
||||
fmt.Fprintf(os.Stderr, "%s The OpenClaw gateway is running in the background.%s\n", ansiYellow, ansiReset)
|
||||
fmt.Fprintf(os.Stderr, "%s Stop it with: %s gateway stop%s\n\n", ansiYellow, bin, ansiReset)
|
||||
} else {
|
||||
fmt.Fprintf(os.Stderr, "%sTip: connect WhatsApp, Telegram, and more with: %s configure --section channels%s\n", ansiGray, bin, ansiReset)
|
||||
}
|
||||
}
|
||||
|
||||
// openclawEnv returns the current environment with provider API keys cleared
|
||||
// so openclaw only uses the Ollama gateway, not keys from the user's shell.
|
||||
func openclawEnv() []string {
|
||||
clear := map[string]bool{
|
||||
"ANTHROPIC_API_KEY": true,
|
||||
"ANTHROPIC_OAUTH_TOKEN": true,
|
||||
"OPENAI_API_KEY": true,
|
||||
"GEMINI_API_KEY": true,
|
||||
"MISTRAL_API_KEY": true,
|
||||
"GROQ_API_KEY": true,
|
||||
"XAI_API_KEY": true,
|
||||
"OPENROUTER_API_KEY": true,
|
||||
}
|
||||
var env []string
|
||||
for _, e := range os.Environ() {
|
||||
key, _, _ := strings.Cut(e, "=")
|
||||
if !clear[key] {
|
||||
env = append(env, e)
|
||||
}
|
||||
}
|
||||
return env
|
||||
}
|
||||
|
||||
// portOpen checks if a TCP port is currently accepting connections.
|
||||
func portOpen(addr string) bool {
|
||||
conn, err := net.DialTimeout("tcp", addr, 500*time.Millisecond)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
conn.Close()
|
||||
return true
|
||||
}
|
||||
|
||||
func waitForPort(addr string, timeout time.Duration) bool {
|
||||
deadline := time.Now().Add(timeout)
|
||||
for time.Now().Before(deadline) {
|
||||
conn, err := net.DialTimeout("tcp", addr, 500*time.Millisecond)
|
||||
if err == nil {
|
||||
conn.Close()
|
||||
return true
|
||||
}
|
||||
time.Sleep(250 * time.Millisecond)
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func windowsHint(err error) error {
|
||||
if runtime.GOOS != "windows" {
|
||||
return err
|
||||
}
|
||||
return fmt.Errorf("%w\n\n"+
|
||||
"OpenClaw runs best on WSL2.\n"+
|
||||
"Quick setup: wsl --install\n"+
|
||||
"Guide: https://docs.openclaw.ai/windows", err)
|
||||
}
|
||||
|
||||
// onboarded checks if OpenClaw onboarding wizard was completed
|
||||
@@ -107,6 +313,144 @@ func (c *Openclaw) onboarded() bool {
|
||||
return lastRunAt != ""
|
||||
}
|
||||
|
||||
// patchDeviceScopes upgrades the local CLI device's paired scopes to include
|
||||
// operator.admin. Only patches the local device, not remote ones.
|
||||
// Best-effort: silently returns on any error.
|
||||
func patchDeviceScopes() {
|
||||
home, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
deviceID := readLocalDeviceID(home)
|
||||
if deviceID == "" {
|
||||
return
|
||||
}
|
||||
|
||||
path := filepath.Join(home, ".openclaw", "devices", "paired.json")
|
||||
data, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
var devices map[string]map[string]any
|
||||
if err := json.Unmarshal(data, &devices); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
dev, ok := devices[deviceID]
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
required := []string{
|
||||
"operator.read",
|
||||
"operator.admin",
|
||||
"operator.approvals",
|
||||
"operator.pairing",
|
||||
}
|
||||
|
||||
changed := patchScopes(dev, "scopes", required)
|
||||
if tokens, ok := dev["tokens"].(map[string]any); ok {
|
||||
for _, tok := range tokens {
|
||||
if tokenMap, ok := tok.(map[string]any); ok {
|
||||
if patchScopes(tokenMap, "scopes", required) {
|
||||
changed = true
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if !changed {
|
||||
return
|
||||
}
|
||||
|
||||
out, err := json.MarshalIndent(devices, "", " ")
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
_ = os.WriteFile(path, out, 0o600)
|
||||
}
|
||||
|
||||
// readLocalDeviceID reads the local device ID from openclaw's identity file.
|
||||
func readLocalDeviceID(home string) string {
|
||||
data, err := os.ReadFile(filepath.Join(home, ".openclaw", "identity", "device-auth.json"))
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
var auth map[string]any
|
||||
if err := json.Unmarshal(data, &auth); err != nil {
|
||||
return ""
|
||||
}
|
||||
id, _ := auth["deviceId"].(string)
|
||||
return id
|
||||
}
|
||||
|
||||
// patchScopes ensures obj[key] contains all required scopes. Returns true if
|
||||
// any scopes were added.
|
||||
func patchScopes(obj map[string]any, key string, required []string) bool {
|
||||
existing, _ := obj[key].([]any)
|
||||
have := make(map[string]bool, len(existing))
|
||||
for _, s := range existing {
|
||||
if str, ok := s.(string); ok {
|
||||
have[str] = true
|
||||
}
|
||||
}
|
||||
added := false
|
||||
for _, s := range required {
|
||||
if !have[s] {
|
||||
existing = append(existing, s)
|
||||
added = true
|
||||
}
|
||||
}
|
||||
if added {
|
||||
obj[key] = existing
|
||||
}
|
||||
return added
|
||||
}
|
||||
|
||||
func ensureOpenclawInstalled() (string, error) {
|
||||
if _, err := exec.LookPath("openclaw"); err == nil {
|
||||
return "openclaw", nil
|
||||
}
|
||||
if _, err := exec.LookPath("clawdbot"); err == nil {
|
||||
return "clawdbot", nil
|
||||
}
|
||||
|
||||
if _, err := exec.LookPath("npm"); err != nil {
|
||||
return "", fmt.Errorf("openclaw is not installed and npm was not found\n\n" +
|
||||
"Install Node.js first:\n" +
|
||||
" https://nodejs.org/\n\n" +
|
||||
"Then rerun:\n" +
|
||||
" ollama launch\n" +
|
||||
"and select OpenClaw")
|
||||
}
|
||||
|
||||
ok, err := confirmPrompt("OpenClaw is not installed. Install with npm?")
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if !ok {
|
||||
return "", fmt.Errorf("openclaw installation cancelled")
|
||||
}
|
||||
|
||||
fmt.Fprintf(os.Stderr, "\nInstalling OpenClaw...\n")
|
||||
cmd := exec.Command("npm", "install", "-g", "openclaw@latest")
|
||||
cmd.Stdin = os.Stdin
|
||||
cmd.Stdout = os.Stdout
|
||||
cmd.Stderr = os.Stderr
|
||||
if err := cmd.Run(); err != nil {
|
||||
return "", fmt.Errorf("failed to install openclaw: %w", err)
|
||||
}
|
||||
|
||||
if _, err := exec.LookPath("openclaw"); err != nil {
|
||||
return "", fmt.Errorf("openclaw was installed but the binary was not found on PATH\n\nYou may need to restart your shell")
|
||||
}
|
||||
|
||||
fmt.Fprintf(os.Stderr, "%sOpenClaw installed successfully%s\n\n", ansiGreen, ansiReset)
|
||||
return "openclaw", nil
|
||||
}
|
||||
|
||||
func (c *Openclaw) Paths() []string {
|
||||
home, _ := os.UserHomeDir()
|
||||
p := filepath.Join(home, ".openclaw", "openclaw.json")
|
||||
@@ -161,8 +505,7 @@ func (c *Openclaw) Edit(models []string) error {
|
||||
ollama["baseUrl"] = envconfig.Host().String() + "/v1"
|
||||
// needed to register provider
|
||||
ollama["apiKey"] = "ollama-local"
|
||||
// TODO(parthsareen): potentially move to responses
|
||||
ollama["api"] = "openai-completions"
|
||||
ollama["api"] = "ollama"
|
||||
|
||||
// Build map of existing models to preserve user customizations
|
||||
existingModels, _ := ollama["models"].([]any)
|
||||
@@ -175,25 +518,13 @@ func (c *Openclaw) Edit(models []string) error {
|
||||
}
|
||||
}
|
||||
|
||||
client, _ := api.ClientFromEnvironment()
|
||||
|
||||
var newModels []any
|
||||
for _, model := range models {
|
||||
entry := map[string]any{
|
||||
"id": model,
|
||||
"name": model,
|
||||
"reasoning": false,
|
||||
"input": []any{"text"},
|
||||
"cost": map[string]any{
|
||||
"input": 0,
|
||||
"output": 0,
|
||||
"cacheRead": 0,
|
||||
"cacheWrite": 0,
|
||||
},
|
||||
// TODO(parthsareen): get these values from API
|
||||
"contextWindow": 131072,
|
||||
"maxTokens": 16384,
|
||||
}
|
||||
for _, m := range models {
|
||||
entry, _ := openclawModelConfig(context.Background(), client, m)
|
||||
// Merge existing fields (user customizations)
|
||||
if existing, ok := existingByID[model]; ok {
|
||||
if existing, ok := existingByID[m]; ok {
|
||||
for k, v := range existing {
|
||||
if _, isNew := entry[k]; !isNew {
|
||||
entry[k] = v
|
||||
@@ -230,7 +561,213 @@ func (c *Openclaw) Edit(models []string) error {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return writeWithBackup(configPath, data)
|
||||
if err := writeWithBackup(configPath, data); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Clear any per-session model overrides so the new primary takes effect
|
||||
// immediately rather than being shadowed by a cached modelOverride.
|
||||
clearSessionModelOverride(models[0])
|
||||
return nil
|
||||
}
|
||||
|
||||
// clearSessionModelOverride removes per-session model overrides from the main
|
||||
// agent session so the global primary model takes effect on the next TUI launch.
|
||||
func clearSessionModelOverride(primary string) {
|
||||
home, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
path := filepath.Join(home, ".openclaw", "agents", "main", "sessions", "sessions.json")
|
||||
data, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
var sessions map[string]map[string]any
|
||||
if json.Unmarshal(data, &sessions) != nil {
|
||||
return
|
||||
}
|
||||
changed := false
|
||||
for _, sess := range sessions {
|
||||
if override, _ := sess["modelOverride"].(string); override != "" && override != primary {
|
||||
delete(sess, "modelOverride")
|
||||
delete(sess, "providerOverride")
|
||||
sess["model"] = primary
|
||||
changed = true
|
||||
}
|
||||
}
|
||||
if !changed {
|
||||
return
|
||||
}
|
||||
out, err := json.MarshalIndent(sessions, "", " ")
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
_ = os.WriteFile(path, out, 0o600)
|
||||
}
|
||||
|
||||
const webSearchNpmPackage = "@ollama/openclaw-web-search"
|
||||
|
||||
// ensureWebSearchPlugin installs the openclaw-web-search extension into the
|
||||
// user-level extensions directory (~/.openclaw/extensions/) if it isn't already
|
||||
// present. Returns true if the extension is available.
|
||||
func ensureWebSearchPlugin() bool {
|
||||
home, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
pluginDir := filepath.Join(home, ".openclaw", "extensions", "openclaw-web-search")
|
||||
if _, err := os.Stat(filepath.Join(pluginDir, "index.ts")); err == nil {
|
||||
return true // already installed
|
||||
}
|
||||
|
||||
npmBin, err := exec.LookPath("npm")
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
if err := os.MkdirAll(pluginDir, 0o755); err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
// Download the tarball via `npm pack`, extract it flat into the plugin dir.
|
||||
pack := exec.Command(npmBin, "pack", webSearchNpmPackage, "--pack-destination", pluginDir)
|
||||
out, err := pack.Output()
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "%s Warning: could not download web search plugin: %v%s\n", ansiYellow, err, ansiReset)
|
||||
return false
|
||||
}
|
||||
|
||||
tgzName := strings.TrimSpace(string(out))
|
||||
tgzPath := filepath.Join(pluginDir, tgzName)
|
||||
defer os.Remove(tgzPath)
|
||||
|
||||
tar := exec.Command("tar", "xzf", tgzPath, "--strip-components=1", "-C", pluginDir)
|
||||
if err := tar.Run(); err != nil {
|
||||
fmt.Fprintf(os.Stderr, "%s Warning: could not extract web search plugin: %v%s\n", ansiYellow, err, ansiReset)
|
||||
return false
|
||||
}
|
||||
|
||||
fmt.Fprintf(os.Stderr, "%s ✓ Installed web search plugin%s\n", ansiGreen, ansiReset)
|
||||
return true
|
||||
}
|
||||
|
||||
// registerWebSearchPlugin adds plugins.entries.openclaw-web-search to the OpenClaw
|
||||
// config so the gateway activates it on next start. Best-effort; silently returns
|
||||
// on any error.
|
||||
func registerWebSearchPlugin() {
|
||||
home, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
configPath := filepath.Join(home, ".openclaw", "openclaw.json")
|
||||
data, err := os.ReadFile(configPath)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
var config map[string]any
|
||||
if json.Unmarshal(data, &config) != nil {
|
||||
return
|
||||
}
|
||||
|
||||
plugins, _ := config["plugins"].(map[string]any)
|
||||
if plugins == nil {
|
||||
plugins = make(map[string]any)
|
||||
}
|
||||
entries, _ := plugins["entries"].(map[string]any)
|
||||
if entries == nil {
|
||||
entries = make(map[string]any)
|
||||
}
|
||||
if _, ok := entries["openclaw-web-search"]; ok {
|
||||
return // already registered
|
||||
}
|
||||
entries["openclaw-web-search"] = map[string]any{"enabled": true}
|
||||
plugins["entries"] = entries
|
||||
config["plugins"] = plugins
|
||||
|
||||
// Disable the built-in web search since our plugin replaces it.
|
||||
tools, _ := config["tools"].(map[string]any)
|
||||
if tools == nil {
|
||||
tools = make(map[string]any)
|
||||
}
|
||||
web, _ := tools["web"].(map[string]any)
|
||||
if web == nil {
|
||||
web = make(map[string]any)
|
||||
}
|
||||
web["search"] = map[string]any{"enabled": false}
|
||||
tools["web"] = web
|
||||
config["tools"] = tools
|
||||
|
||||
out, err := json.MarshalIndent(config, "", " ")
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
_ = os.WriteFile(configPath, out, 0o600)
|
||||
}
|
||||
|
||||
// openclawModelConfig builds an OpenClaw model config entry with capability detection.
|
||||
// The second return value indicates whether the model is a cloud (remote) model.
|
||||
func openclawModelConfig(ctx context.Context, client *api.Client, modelID string) (map[string]any, bool) {
|
||||
entry := map[string]any{
|
||||
"id": modelID,
|
||||
"name": modelID,
|
||||
"input": []any{"text"},
|
||||
"cost": map[string]any{
|
||||
"input": 0,
|
||||
"output": 0,
|
||||
"cacheRead": 0,
|
||||
"cacheWrite": 0,
|
||||
},
|
||||
}
|
||||
|
||||
if client == nil {
|
||||
return entry, false
|
||||
}
|
||||
|
||||
showCtx := ctx
|
||||
if _, hasDeadline := ctx.Deadline(); !hasDeadline {
|
||||
var cancel context.CancelFunc
|
||||
showCtx, cancel = context.WithTimeout(ctx, openclawModelShowTimeout)
|
||||
defer cancel()
|
||||
}
|
||||
|
||||
resp, err := client.Show(showCtx, &api.ShowRequest{Model: modelID})
|
||||
if err != nil {
|
||||
return entry, false
|
||||
}
|
||||
|
||||
// Set input types based on vision capability
|
||||
if slices.Contains(resp.Capabilities, model.CapabilityVision) {
|
||||
entry["input"] = []any{"text", "image"}
|
||||
}
|
||||
|
||||
// Set reasoning based on thinking capability
|
||||
if slices.Contains(resp.Capabilities, model.CapabilityThinking) {
|
||||
entry["reasoning"] = true
|
||||
}
|
||||
|
||||
// Cloud models: use hardcoded limits for context/output tokens.
|
||||
// Capability detection above still applies (vision, thinking).
|
||||
if resp.RemoteModel != "" {
|
||||
if l, ok := lookupCloudModelLimit(modelID); ok {
|
||||
entry["contextWindow"] = l.Context
|
||||
entry["maxTokens"] = l.Output
|
||||
}
|
||||
return entry, true
|
||||
}
|
||||
|
||||
// Extract context window from ModelInfo (local models only)
|
||||
for key, val := range resp.ModelInfo {
|
||||
if strings.HasSuffix(key, ".context_length") {
|
||||
if ctxLen, ok := val.(float64); ok && ctxLen > 0 {
|
||||
entry["contextWindow"] = int(ctxLen)
|
||||
}
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
return entry, false
|
||||
}
|
||||
|
||||
func (c *Openclaw) Models() []string {
|
||||
|
||||
@@ -1,11 +1,21 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
)
|
||||
|
||||
func TestOpenclawIntegration(t *testing.T) {
|
||||
@@ -26,6 +36,124 @@ func TestOpenclawIntegration(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
func TestOpenclawRunPassthroughArgs(t *testing.T) {
|
||||
if runtime.GOOS == "windows" {
|
||||
t.Skip("uses a POSIX shell test binary")
|
||||
}
|
||||
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
t.Setenv("PATH", tmpDir)
|
||||
|
||||
if err := integrationOnboarded("openclaw"); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
configDir := filepath.Join(tmpDir, ".openclaw")
|
||||
if err := os.MkdirAll(configDir, 0o755); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := os.WriteFile(filepath.Join(configDir, "openclaw.json"), []byte(`{
|
||||
"wizard": {"lastRunAt": "2026-01-01T00:00:00Z"}
|
||||
}`), 0o644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
bin := filepath.Join(tmpDir, "openclaw")
|
||||
if err := os.WriteFile(bin, []byte("#!/bin/sh\nprintf '%s\\n' \"$*\" >> \"$HOME/invocations.log\"\n"), 0o755); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
c := &Openclaw{}
|
||||
if err := c.Run("llama3.2", []string{"gateway", "--someflag"}); err != nil {
|
||||
t.Fatalf("Run() error = %v", err)
|
||||
}
|
||||
|
||||
data, err := os.ReadFile(filepath.Join(tmpDir, "invocations.log"))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
lines := strings.Split(strings.TrimSpace(string(data)), "\n")
|
||||
if len(lines) != 1 {
|
||||
t.Fatalf("expected exactly 1 invocation, got %d: %v", len(lines), lines)
|
||||
}
|
||||
if lines[0] != "gateway --someflag" {
|
||||
t.Fatalf("invocation = %q, want %q", lines[0], "gateway --someflag")
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpenclawRunFirstLaunchPersistence(t *testing.T) {
|
||||
if runtime.GOOS == "windows" {
|
||||
t.Skip("uses a POSIX shell test binary")
|
||||
}
|
||||
|
||||
oldHook := DefaultConfirmPrompt
|
||||
DefaultConfirmPrompt = func(prompt string) (bool, error) {
|
||||
return true, nil
|
||||
}
|
||||
defer func() { DefaultConfirmPrompt = oldHook }()
|
||||
|
||||
t.Run("success persists onboarding flag", func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
t.Setenv("PATH", tmpDir)
|
||||
|
||||
configDir := filepath.Join(tmpDir, ".openclaw")
|
||||
if err := os.MkdirAll(configDir, 0o755); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
// Mark OpenClaw onboarding complete so Run takes passthrough path directly.
|
||||
if err := os.WriteFile(filepath.Join(configDir, "openclaw.json"), []byte(`{
|
||||
"wizard": {"lastRunAt": "2026-01-01T00:00:00Z"}
|
||||
}`), 0o644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := os.WriteFile(filepath.Join(tmpDir, "openclaw"), []byte("#!/bin/sh\nexit 0\n"), 0o755); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
c := &Openclaw{}
|
||||
if err := c.Run("llama3.2", []string{"gateway", "--status"}); err != nil {
|
||||
t.Fatalf("Run() error = %v", err)
|
||||
}
|
||||
integrationConfig, err := loadIntegration("openclaw")
|
||||
if err != nil {
|
||||
t.Fatalf("loadIntegration() error = %v", err)
|
||||
}
|
||||
if !integrationConfig.Onboarded {
|
||||
t.Fatal("expected onboarding flag to be persisted after successful run")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("failure does not persist onboarding flag", func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
t.Setenv("PATH", tmpDir)
|
||||
|
||||
configDir := filepath.Join(tmpDir, ".openclaw")
|
||||
if err := os.MkdirAll(configDir, 0o755); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := os.WriteFile(filepath.Join(configDir, "openclaw.json"), []byte(`{
|
||||
"wizard": {"lastRunAt": "2026-01-01T00:00:00Z"}
|
||||
}`), 0o644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := os.WriteFile(filepath.Join(tmpDir, "openclaw"), []byte("#!/bin/sh\nexit 1\n"), 0o755); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
c := &Openclaw{}
|
||||
if err := c.Run("llama3.2", []string{"gateway", "--status"}); err == nil {
|
||||
t.Fatal("expected run failure")
|
||||
}
|
||||
integrationConfig, err := loadIntegration("openclaw")
|
||||
if err == nil && integrationConfig.Onboarded {
|
||||
t.Fatal("expected onboarding flag to remain unset after failed run")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestOpenclawEdit(t *testing.T) {
|
||||
c := &Openclaw{}
|
||||
tmpDir := t.TempDir()
|
||||
@@ -359,19 +487,16 @@ func TestOpenclawEditSchemaFields(t *testing.T) {
|
||||
modelList := ollama["models"].([]any)
|
||||
entry := modelList[0].(map[string]any)
|
||||
|
||||
// Verify required schema fields
|
||||
if entry["reasoning"] != false {
|
||||
t.Error("reasoning should be false")
|
||||
// Verify base schema fields (always set regardless of API availability)
|
||||
if entry["id"] != "llama3.2" {
|
||||
t.Errorf("id = %v, want llama3.2", entry["id"])
|
||||
}
|
||||
if entry["name"] != "llama3.2" {
|
||||
t.Errorf("name = %v, want llama3.2", entry["name"])
|
||||
}
|
||||
if entry["input"] == nil {
|
||||
t.Error("input should be set")
|
||||
}
|
||||
if entry["contextWindow"] == nil {
|
||||
t.Error("contextWindow should be set")
|
||||
}
|
||||
if entry["maxTokens"] == nil {
|
||||
t.Error("maxTokens should be set")
|
||||
}
|
||||
cost := entry["cost"].(map[string]any)
|
||||
if cost["cacheRead"] == nil {
|
||||
t.Error("cost.cacheRead should be set")
|
||||
@@ -876,3 +1001,589 @@ func TestOpenclawOnboarded(t *testing.T) {
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestOpenclawGatewayInfo(t *testing.T) {
|
||||
c := &Openclaw{}
|
||||
|
||||
t.Run("returns defaults when no config exists", func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
token, port := c.gatewayInfo()
|
||||
if token != "" {
|
||||
t.Errorf("expected empty token, got %q", token)
|
||||
}
|
||||
if port != defaultGatewayPort {
|
||||
t.Errorf("expected default port %d, got %d", defaultGatewayPort, port)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("reads token and port from config", func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
configDir := filepath.Join(tmpDir, ".openclaw")
|
||||
os.MkdirAll(configDir, 0o755)
|
||||
os.WriteFile(filepath.Join(configDir, "openclaw.json"), []byte(`{
|
||||
"gateway": {
|
||||
"port": 9999,
|
||||
"auth": {"mode": "token", "token": "my-secret"}
|
||||
}
|
||||
}`), 0o644)
|
||||
|
||||
token, port := c.gatewayInfo()
|
||||
if token != "my-secret" {
|
||||
t.Errorf("expected token %q, got %q", "my-secret", token)
|
||||
}
|
||||
if port != 9999 {
|
||||
t.Errorf("expected port 9999, got %d", port)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("uses default port when not in config", func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
configDir := filepath.Join(tmpDir, ".openclaw")
|
||||
os.MkdirAll(configDir, 0o755)
|
||||
os.WriteFile(filepath.Join(configDir, "openclaw.json"), []byte(`{
|
||||
"gateway": {"auth": {"token": "tok"}}
|
||||
}`), 0o644)
|
||||
|
||||
token, port := c.gatewayInfo()
|
||||
if token != "tok" {
|
||||
t.Errorf("expected token %q, got %q", "tok", token)
|
||||
}
|
||||
if port != defaultGatewayPort {
|
||||
t.Errorf("expected default port %d, got %d", defaultGatewayPort, port)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("falls back to legacy clawdbot config", func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
legacyDir := filepath.Join(tmpDir, ".clawdbot")
|
||||
os.MkdirAll(legacyDir, 0o755)
|
||||
os.WriteFile(filepath.Join(legacyDir, "clawdbot.json"), []byte(`{
|
||||
"gateway": {"port": 12345, "auth": {"token": "legacy-token"}}
|
||||
}`), 0o644)
|
||||
|
||||
token, port := c.gatewayInfo()
|
||||
if token != "legacy-token" {
|
||||
t.Errorf("expected token %q, got %q", "legacy-token", token)
|
||||
}
|
||||
if port != 12345 {
|
||||
t.Errorf("expected port 12345, got %d", port)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("handles corrupted JSON gracefully", func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
configDir := filepath.Join(tmpDir, ".openclaw")
|
||||
os.MkdirAll(configDir, 0o755)
|
||||
os.WriteFile(filepath.Join(configDir, "openclaw.json"), []byte(`{corrupted`), 0o644)
|
||||
|
||||
token, port := c.gatewayInfo()
|
||||
if token != "" {
|
||||
t.Errorf("expected empty token, got %q", token)
|
||||
}
|
||||
if port != defaultGatewayPort {
|
||||
t.Errorf("expected default port, got %d", port)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("handles missing gateway section", func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
configDir := filepath.Join(tmpDir, ".openclaw")
|
||||
os.MkdirAll(configDir, 0o755)
|
||||
os.WriteFile(filepath.Join(configDir, "openclaw.json"), []byte(`{"theme":"dark"}`), 0o644)
|
||||
|
||||
token, port := c.gatewayInfo()
|
||||
if token != "" {
|
||||
t.Errorf("expected empty token, got %q", token)
|
||||
}
|
||||
if port != defaultGatewayPort {
|
||||
t.Errorf("expected default port, got %d", port)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestPrintOpenclawReady(t *testing.T) {
|
||||
t.Run("includes port in URL", func(t *testing.T) {
|
||||
var buf bytes.Buffer
|
||||
old := os.Stderr
|
||||
r, w, _ := os.Pipe()
|
||||
os.Stderr = w
|
||||
|
||||
printOpenclawReady("openclaw", "", 9999, false)
|
||||
|
||||
w.Close()
|
||||
os.Stderr = old
|
||||
buf.ReadFrom(r)
|
||||
|
||||
output := buf.String()
|
||||
if !strings.Contains(output, "localhost:9999") {
|
||||
t.Errorf("expected port 9999 in output, got:\n%s", output)
|
||||
}
|
||||
if strings.Contains(output, "#token=") {
|
||||
t.Error("should not include token fragment when token is empty")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("URL-escapes token", func(t *testing.T) {
|
||||
var buf bytes.Buffer
|
||||
old := os.Stderr
|
||||
r, w, _ := os.Pipe()
|
||||
os.Stderr = w
|
||||
|
||||
printOpenclawReady("openclaw", "my token&special=chars", defaultGatewayPort, false)
|
||||
|
||||
w.Close()
|
||||
os.Stderr = old
|
||||
buf.ReadFrom(r)
|
||||
|
||||
output := buf.String()
|
||||
escaped := url.QueryEscape("my token&special=chars")
|
||||
if !strings.Contains(output, "#token="+escaped) {
|
||||
t.Errorf("expected URL-escaped token %q in output, got:\n%s", escaped, output)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("simple token is not mangled", func(t *testing.T) {
|
||||
var buf bytes.Buffer
|
||||
old := os.Stderr
|
||||
r, w, _ := os.Pipe()
|
||||
os.Stderr = w
|
||||
|
||||
printOpenclawReady("openclaw", "ollama", defaultGatewayPort, false)
|
||||
|
||||
w.Close()
|
||||
os.Stderr = old
|
||||
buf.ReadFrom(r)
|
||||
|
||||
output := buf.String()
|
||||
if !strings.Contains(output, "#token=ollama") {
|
||||
t.Errorf("expected #token=ollama in output, got:\n%s", output)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("includes web UI hint", func(t *testing.T) {
|
||||
var buf bytes.Buffer
|
||||
old := os.Stderr
|
||||
r, w, _ := os.Pipe()
|
||||
os.Stderr = w
|
||||
|
||||
printOpenclawReady("openclaw", "", defaultGatewayPort, false)
|
||||
|
||||
w.Close()
|
||||
os.Stderr = old
|
||||
buf.ReadFrom(r)
|
||||
|
||||
output := buf.String()
|
||||
if !strings.Contains(output, "Open the Web UI") {
|
||||
t.Errorf("expected web UI hint in output, got:\n%s", output)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("first launch shows quick start tips", func(t *testing.T) {
|
||||
var buf bytes.Buffer
|
||||
old := os.Stderr
|
||||
r, w, _ := os.Pipe()
|
||||
os.Stderr = w
|
||||
|
||||
printOpenclawReady("openclaw", "ollama", defaultGatewayPort, true)
|
||||
|
||||
w.Close()
|
||||
os.Stderr = old
|
||||
buf.ReadFrom(r)
|
||||
|
||||
output := buf.String()
|
||||
for _, want := range []string{"/help", "channels", "skills", "gateway"} {
|
||||
if !strings.Contains(output, want) {
|
||||
t.Errorf("expected %q in first-launch output, got:\n%s", want, output)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("subsequent launch shows single tip", func(t *testing.T) {
|
||||
var buf bytes.Buffer
|
||||
old := os.Stderr
|
||||
r, w, _ := os.Pipe()
|
||||
os.Stderr = w
|
||||
|
||||
printOpenclawReady("openclaw", "ollama", defaultGatewayPort, false)
|
||||
|
||||
w.Close()
|
||||
os.Stderr = old
|
||||
buf.ReadFrom(r)
|
||||
|
||||
output := buf.String()
|
||||
if !strings.Contains(output, "Tip:") {
|
||||
t.Errorf("expected single tip line, got:\n%s", output)
|
||||
}
|
||||
if strings.Contains(output, "Quick start") {
|
||||
t.Errorf("should not show quick start on subsequent launch")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestOpenclawModelConfig(t *testing.T) {
|
||||
t.Run("nil client returns base config", func(t *testing.T) {
|
||||
cfg, _ := openclawModelConfig(context.Background(), nil, "llama3.2")
|
||||
|
||||
if cfg["id"] != "llama3.2" {
|
||||
t.Errorf("id = %v, want llama3.2", cfg["id"])
|
||||
}
|
||||
if cfg["name"] != "llama3.2" {
|
||||
t.Errorf("name = %v, want llama3.2", cfg["name"])
|
||||
}
|
||||
if cfg["cost"] == nil {
|
||||
t.Error("cost should be set")
|
||||
}
|
||||
// Should not have capability fields without API
|
||||
if _, ok := cfg["reasoning"]; ok {
|
||||
t.Error("reasoning should not be set without API")
|
||||
}
|
||||
if _, ok := cfg["contextWindow"]; ok {
|
||||
t.Error("contextWindow should not be set without API")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("sets vision input when model has vision capability", func(t *testing.T) {
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path == "/api/show" {
|
||||
fmt.Fprintf(w, `{"capabilities":["vision"],"model_info":{"llama.context_length":4096}}`)
|
||||
return
|
||||
}
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
u, _ := url.Parse(srv.URL)
|
||||
client := api.NewClient(u, srv.Client())
|
||||
|
||||
cfg, _ := openclawModelConfig(context.Background(), client, "llava:7b")
|
||||
|
||||
input, ok := cfg["input"].([]any)
|
||||
if !ok || len(input) != 2 {
|
||||
t.Errorf("input = %v, want [text image]", cfg["input"])
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("sets text-only input when model lacks vision", func(t *testing.T) {
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path == "/api/show" {
|
||||
fmt.Fprintf(w, `{"capabilities":["completion"],"model_info":{}}`)
|
||||
return
|
||||
}
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
u, _ := url.Parse(srv.URL)
|
||||
client := api.NewClient(u, srv.Client())
|
||||
|
||||
cfg, _ := openclawModelConfig(context.Background(), client, "llama3.2")
|
||||
|
||||
input, ok := cfg["input"].([]any)
|
||||
if !ok || len(input) != 1 {
|
||||
t.Errorf("input = %v, want [text]", cfg["input"])
|
||||
}
|
||||
if _, ok := cfg["reasoning"]; ok {
|
||||
t.Error("reasoning should not be set for non-thinking model")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("sets reasoning when model has thinking capability", func(t *testing.T) {
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path == "/api/show" {
|
||||
fmt.Fprintf(w, `{"capabilities":["thinking"],"model_info":{}}`)
|
||||
return
|
||||
}
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
u, _ := url.Parse(srv.URL)
|
||||
client := api.NewClient(u, srv.Client())
|
||||
|
||||
cfg, _ := openclawModelConfig(context.Background(), client, "qwq")
|
||||
|
||||
if cfg["reasoning"] != true {
|
||||
t.Error("expected reasoning = true for thinking model")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("extracts context window from model info", func(t *testing.T) {
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path == "/api/show" {
|
||||
fmt.Fprintf(w, `{"capabilities":[],"model_info":{"llama.context_length":131072}}`)
|
||||
return
|
||||
}
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
u, _ := url.Parse(srv.URL)
|
||||
client := api.NewClient(u, srv.Client())
|
||||
|
||||
cfg, _ := openclawModelConfig(context.Background(), client, "llama3.2")
|
||||
|
||||
if cfg["contextWindow"] != 131072 {
|
||||
t.Errorf("contextWindow = %v, want 131072", cfg["contextWindow"])
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("handles all capabilities together", func(t *testing.T) {
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path == "/api/show" {
|
||||
fmt.Fprintf(w, `{"capabilities":["vision","thinking"],"model_info":{"qwen3.context_length":32768}}`)
|
||||
return
|
||||
}
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
u, _ := url.Parse(srv.URL)
|
||||
client := api.NewClient(u, srv.Client())
|
||||
|
||||
cfg, _ := openclawModelConfig(context.Background(), client, "qwen3-vision")
|
||||
|
||||
input, ok := cfg["input"].([]any)
|
||||
if !ok || len(input) != 2 {
|
||||
t.Errorf("input = %v, want [text image]", cfg["input"])
|
||||
}
|
||||
if cfg["reasoning"] != true {
|
||||
t.Error("expected reasoning = true")
|
||||
}
|
||||
if cfg["contextWindow"] != 32768 {
|
||||
t.Errorf("contextWindow = %v, want 32768", cfg["contextWindow"])
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("returns base config when show fails", func(t *testing.T) {
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
fmt.Fprintf(w, `{"error":"model not found"}`)
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
u, _ := url.Parse(srv.URL)
|
||||
client := api.NewClient(u, srv.Client())
|
||||
|
||||
cfg, _ := openclawModelConfig(context.Background(), client, "missing-model")
|
||||
|
||||
if cfg["id"] != "missing-model" {
|
||||
t.Errorf("id = %v, want missing-model", cfg["id"])
|
||||
}
|
||||
// Should still have input (default)
|
||||
if cfg["input"] == nil {
|
||||
t.Error("input should always be set")
|
||||
}
|
||||
if _, ok := cfg["reasoning"]; ok {
|
||||
t.Error("reasoning should not be set when show fails")
|
||||
}
|
||||
if _, ok := cfg["contextWindow"]; ok {
|
||||
t.Error("contextWindow should not be set when show fails")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("times out slow show and returns base config", func(t *testing.T) {
|
||||
oldTimeout := openclawModelShowTimeout
|
||||
openclawModelShowTimeout = 50 * time.Millisecond
|
||||
t.Cleanup(func() { openclawModelShowTimeout = oldTimeout })
|
||||
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path == "/api/show" {
|
||||
time.Sleep(300 * time.Millisecond)
|
||||
fmt.Fprintf(w, `{"capabilities":["thinking"],"model_info":{"llama.context_length":4096}}`)
|
||||
return
|
||||
}
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
u, _ := url.Parse(srv.URL)
|
||||
client := api.NewClient(u, srv.Client())
|
||||
|
||||
start := time.Now()
|
||||
cfg, _ := openclawModelConfig(context.Background(), client, "slow-model")
|
||||
elapsed := time.Since(start)
|
||||
if elapsed >= 250*time.Millisecond {
|
||||
t.Fatalf("openclawModelConfig took too long: %v", elapsed)
|
||||
}
|
||||
if cfg["id"] != "slow-model" {
|
||||
t.Errorf("id = %v, want slow-model", cfg["id"])
|
||||
}
|
||||
if _, ok := cfg["reasoning"]; ok {
|
||||
t.Error("reasoning should not be set on timeout")
|
||||
}
|
||||
if _, ok := cfg["contextWindow"]; ok {
|
||||
t.Error("contextWindow should not be set on timeout")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("skips zero context length", func(t *testing.T) {
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path == "/api/show" {
|
||||
fmt.Fprintf(w, `{"capabilities":[],"model_info":{"llama.context_length":0}}`)
|
||||
return
|
||||
}
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
u, _ := url.Parse(srv.URL)
|
||||
client := api.NewClient(u, srv.Client())
|
||||
|
||||
cfg, _ := openclawModelConfig(context.Background(), client, "test-model")
|
||||
|
||||
if _, ok := cfg["contextWindow"]; ok {
|
||||
t.Error("contextWindow should not be set for zero value")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("cloud model uses hardcoded limits", func(t *testing.T) {
|
||||
// Use a model name that's in cloudModelLimits and make the server
|
||||
// report it as a remote/cloud model
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path == "/api/show" {
|
||||
fmt.Fprintf(w, `{"capabilities":[],"model_info":{},"remote_model":"minimax-m2.5"}`)
|
||||
return
|
||||
}
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
u, _ := url.Parse(srv.URL)
|
||||
client := api.NewClient(u, srv.Client())
|
||||
|
||||
cfg, isCloud := openclawModelConfig(context.Background(), client, "minimax-m2.5:cloud")
|
||||
|
||||
if !isCloud {
|
||||
t.Error("expected isCloud = true for cloud model")
|
||||
}
|
||||
if cfg["contextWindow"] != 204_800 {
|
||||
t.Errorf("contextWindow = %v, want 204800", cfg["contextWindow"])
|
||||
}
|
||||
if cfg["maxTokens"] != 128_000 {
|
||||
t.Errorf("maxTokens = %v, want 128000", cfg["maxTokens"])
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("cloud model with vision capability gets image input", func(t *testing.T) {
|
||||
// Regression test: cloud models must not skip capability detection.
|
||||
// A cloud model that reports vision capability should have input: [text, image].
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path == "/api/show" {
|
||||
fmt.Fprintf(w, `{"capabilities":["vision"],"model_info":{},"remote_model":"qwen3-vl"}`)
|
||||
return
|
||||
}
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
u, _ := url.Parse(srv.URL)
|
||||
client := api.NewClient(u, srv.Client())
|
||||
|
||||
cfg, isCloud := openclawModelConfig(context.Background(), client, "qwen3-vl:235b-cloud")
|
||||
|
||||
if !isCloud {
|
||||
t.Error("expected isCloud = true for cloud vision model")
|
||||
}
|
||||
input, ok := cfg["input"].([]any)
|
||||
if !ok || len(input) != 2 {
|
||||
t.Errorf("input = %v, want [text image] for cloud vision model", cfg["input"])
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("cloud model with thinking capability gets reasoning flag", func(t *testing.T) {
|
||||
// Regression test: cloud models must not skip capability detection.
|
||||
// A cloud model that reports thinking capability should have reasoning: true.
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path == "/api/show" {
|
||||
fmt.Fprintf(w, `{"capabilities":["thinking"],"model_info":{},"remote_model":"qwq-cloud"}`)
|
||||
return
|
||||
}
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
u, _ := url.Parse(srv.URL)
|
||||
client := api.NewClient(u, srv.Client())
|
||||
|
||||
cfg, isCloud := openclawModelConfig(context.Background(), client, "qwq:cloud")
|
||||
|
||||
if !isCloud {
|
||||
t.Error("expected isCloud = true for cloud thinking model")
|
||||
}
|
||||
if cfg["reasoning"] != true {
|
||||
t.Error("expected reasoning = true for cloud thinking model")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestIntegrationOnboarded(t *testing.T) {
|
||||
t.Run("returns false when not set", func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
integrationConfig, err := loadIntegration("openclaw")
|
||||
if err == nil && integrationConfig.Onboarded {
|
||||
t.Error("expected false for fresh config")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("returns true after integrationOnboarded", func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
os.MkdirAll(filepath.Join(tmpDir, ".ollama"), 0o755)
|
||||
|
||||
if err := integrationOnboarded("openclaw"); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
integrationConfig, err := loadIntegration("openclaw")
|
||||
if err != nil || !integrationConfig.Onboarded {
|
||||
t.Error("expected true after integrationOnboarded")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("is case insensitive", func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
os.MkdirAll(filepath.Join(tmpDir, ".ollama"), 0o755)
|
||||
|
||||
if err := integrationOnboarded("OpenClaw"); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
integrationConfig, err := loadIntegration("openclaw")
|
||||
if err != nil || !integrationConfig.Onboarded {
|
||||
t.Error("expected true when set with different case")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("preserves existing integration data", func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
os.MkdirAll(filepath.Join(tmpDir, ".ollama"), 0o755)
|
||||
|
||||
if err := SaveIntegration("openclaw", []string{"llama3.2", "mistral"}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := integrationOnboarded("openclaw"); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Verify onboarded is set
|
||||
integrationConfig, err := loadIntegration("openclaw")
|
||||
if err != nil || !integrationConfig.Onboarded {
|
||||
t.Error("expected true after integrationOnboarded")
|
||||
}
|
||||
|
||||
// Verify models are preserved
|
||||
model := IntegrationModel("openclaw")
|
||||
if model != "llama3.2" {
|
||||
t.Errorf("expected first model llama3.2, got %q", model)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
@@ -10,10 +10,11 @@ import (
|
||||
|
||||
// ANSI escape sequences for terminal formatting.
|
||||
const (
|
||||
ansiBold = "\033[1m"
|
||||
ansiReset = "\033[0m"
|
||||
ansiGray = "\033[37m"
|
||||
ansiGreen = "\033[32m"
|
||||
ansiBold = "\033[1m"
|
||||
ansiReset = "\033[0m"
|
||||
ansiGray = "\033[37m"
|
||||
ansiGreen = "\033[32m"
|
||||
ansiYellow = "\033[33m"
|
||||
)
|
||||
|
||||
// ErrCancelled is returned when the user cancels a selection.
|
||||
|
||||
@@ -524,7 +524,7 @@ func (m model) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
|
||||
case "enter", " ":
|
||||
item := m.items[m.cursor]
|
||||
|
||||
if item.integration != "" && !config.IsIntegrationInstalled(item.integration) {
|
||||
if item.integration != "" && !config.IsIntegrationInstalled(item.integration) && !config.AutoInstallable(item.integration) {
|
||||
return m, nil
|
||||
}
|
||||
|
||||
@@ -555,6 +555,12 @@ func (m model) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
|
||||
item := m.items[m.cursor]
|
||||
if item.integration != "" || item.isRunModel {
|
||||
if item.integration != "" && !config.IsIntegrationInstalled(item.integration) {
|
||||
if config.AutoInstallable(item.integration) {
|
||||
// Auto-installable: select to trigger install flow
|
||||
m.selected = true
|
||||
m.quitting = true
|
||||
return m, tea.Quit
|
||||
}
|
||||
return m, nil
|
||||
}
|
||||
if item.integration != "" && config.IsEditorIntegration(item.integration) {
|
||||
@@ -618,7 +624,11 @@ func (m model) View() string {
|
||||
var modelSuffix string
|
||||
if item.integration != "" {
|
||||
if !isInstalled {
|
||||
title += " " + notInstalledStyle.Render("(not installed)")
|
||||
if config.AutoInstallable(item.integration) {
|
||||
title += " " + notInstalledStyle.Render("(install)")
|
||||
} else {
|
||||
title += " " + notInstalledStyle.Render("(not installed)")
|
||||
}
|
||||
} else if m.cursor == i {
|
||||
if mdl := config.IntegrationModel(item.integration); mdl != "" && m.modelExists(mdl) {
|
||||
modelSuffix = " " + modelStyle.Render("("+mdl+")")
|
||||
@@ -634,7 +644,9 @@ func (m model) View() string {
|
||||
|
||||
desc := item.description
|
||||
if !isInstalled && item.integration != "" && m.cursor == i {
|
||||
if hint := config.IntegrationInstallHint(item.integration); hint != "" {
|
||||
if config.AutoInstallable(item.integration) {
|
||||
desc = "Press enter to install"
|
||||
} else if hint := config.IntegrationInstallHint(item.integration); hint != "" {
|
||||
desc = hint
|
||||
} else {
|
||||
desc = "not installed"
|
||||
|
||||
@@ -257,10 +257,11 @@ func LoadModelMetadata(fsys fs.FS) (ModelKV, *Tokenizer, error) {
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
bts = sanitizeNonFiniteJSON(bts)
|
||||
|
||||
var p ModelParameters
|
||||
if err := json.Unmarshal(bts, &p); err != nil {
|
||||
return nil, nil, err
|
||||
return nil, nil, fmt.Errorf("parse config.json: %w", err)
|
||||
}
|
||||
|
||||
if len(p.Architectures) < 1 {
|
||||
@@ -319,12 +320,14 @@ func LoadModelMetadata(fsys fs.FS) (ModelKV, *Tokenizer, error) {
|
||||
conv = &lfm2Model{}
|
||||
case "Qwen3NextForCausalLM":
|
||||
conv = &qwen3NextModel{}
|
||||
case "NemotronHForCausalLM":
|
||||
conv = &nemotronHModel{}
|
||||
default:
|
||||
return nil, nil, fmt.Errorf("unsupported architecture %q", p.Architectures[0])
|
||||
}
|
||||
|
||||
if err := json.Unmarshal(bts, conv); err != nil {
|
||||
return nil, nil, err
|
||||
return nil, nil, fmt.Errorf("parse config.json for %q: %w", p.Architectures[0], err)
|
||||
}
|
||||
|
||||
if t, ok := conv.(moreParser); ok {
|
||||
|
||||
385
convert/convert_nemotron_h.go
Normal file
385
convert/convert_nemotron_h.go
Normal file
@@ -0,0 +1,385 @@
|
||||
package convert
|
||||
|
||||
import (
|
||||
"cmp"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io/fs"
|
||||
"math"
|
||||
"slices"
|
||||
"strings"
|
||||
|
||||
"github.com/ollama/ollama/fs/ggml"
|
||||
)
|
||||
|
||||
type hybridPattern string
|
||||
|
||||
func (p *hybridPattern) UnmarshalJSON(data []byte) error {
|
||||
if string(data) == "null" {
|
||||
*p = ""
|
||||
return nil
|
||||
}
|
||||
|
||||
var single string
|
||||
if err := json.Unmarshal(data, &single); err == nil {
|
||||
*p = hybridPattern(strings.TrimSpace(single))
|
||||
return nil
|
||||
}
|
||||
|
||||
var parts []string
|
||||
if err := json.Unmarshal(data, &parts); err == nil {
|
||||
*p = hybridPattern(strings.Join(parts, ""))
|
||||
return nil
|
||||
}
|
||||
|
||||
return fmt.Errorf("hybrid_override_pattern must be a string or string array")
|
||||
}
|
||||
|
||||
type nemotronHModel struct {
|
||||
ModelParameters
|
||||
MaxPositionEmbeddings uint32 `json:"max_position_embeddings"`
|
||||
HiddenSize uint32 `json:"hidden_size"`
|
||||
NumHiddenLayers uint32 `json:"num_hidden_layers"`
|
||||
NumAttentionHeads uint32 `json:"num_attention_heads"`
|
||||
NumKeyValueHeads uint32 `json:"num_key_value_heads"`
|
||||
HeadDim uint32 `json:"head_dim"`
|
||||
LayerNormEpsilon float32 `json:"layer_norm_epsilon"`
|
||||
NormEpsilon float32 `json:"norm_eps"`
|
||||
RopeTheta float32 `json:"rope_theta"`
|
||||
PartialRotaryFactor float32 `json:"partial_rotary_factor"`
|
||||
ConvKernel uint32 `json:"conv_kernel"`
|
||||
SSMStateSize uint32 `json:"ssm_state_size"`
|
||||
MambaNumHeads uint32 `json:"mamba_num_heads"`
|
||||
MambaHeadDim uint32 `json:"mamba_head_dim"`
|
||||
NGroups uint32 `json:"n_groups"`
|
||||
IntermediateSize uint32 `json:"intermediate_size"`
|
||||
HybridOverridePattern hybridPattern `json:"hybrid_override_pattern"`
|
||||
|
||||
// MoE
|
||||
NumExperts uint32 `json:"num_experts"`
|
||||
NumSharedExperts uint32 `json:"num_shared_experts"`
|
||||
NRoutedExperts uint32 `json:"n_routed_experts"`
|
||||
NSharedExperts uint32 `json:"n_shared_experts"`
|
||||
NumExpertsPerTok uint32 `json:"num_experts_per_tok"`
|
||||
MoEIntermediateSize uint32 `json:"moe_intermediate_size"`
|
||||
MoESharedExpertIntermediate uint32 `json:"moe_shared_expert_intermediate_size"`
|
||||
NormTopKProb bool `json:"norm_topk_prob"`
|
||||
RoutedScalingFactor float32 `json:"routed_scaling_factor"`
|
||||
ExpertGroupCount uint32 `json:"n_group"`
|
||||
ExpertGroupUsedCount uint32 `json:"topk_group"`
|
||||
}
|
||||
|
||||
var _ ModelConverter = (*nemotronHModel)(nil)
|
||||
|
||||
func (n *nemotronHModel) parseMore(_ fs.FS) error {
|
||||
if n.NumHiddenLayers == 0 {
|
||||
return fmt.Errorf("nemotron_h: num_hidden_layers must be set")
|
||||
}
|
||||
if n.HiddenSize == 0 {
|
||||
return fmt.Errorf("nemotron_h: hidden_size must be set")
|
||||
}
|
||||
if n.NumAttentionHeads == 0 {
|
||||
return fmt.Errorf("nemotron_h: num_attention_heads must be set")
|
||||
}
|
||||
if n.HeadDim == 0 {
|
||||
if n.HiddenSize%n.NumAttentionHeads != 0 {
|
||||
return fmt.Errorf("nemotron_h: hidden_size (%d) must be divisible by num_attention_heads (%d)", n.HiddenSize, n.NumAttentionHeads)
|
||||
}
|
||||
n.HeadDim = n.HiddenSize / n.NumAttentionHeads
|
||||
}
|
||||
if n.NumKeyValueHeads == 0 {
|
||||
n.NumKeyValueHeads = n.NumAttentionHeads
|
||||
}
|
||||
if n.ConvKernel == 0 {
|
||||
return fmt.Errorf("nemotron_h: conv_kernel must be set")
|
||||
}
|
||||
if n.SSMStateSize == 0 {
|
||||
return fmt.Errorf("nemotron_h: ssm_state_size must be set")
|
||||
}
|
||||
if n.ssmHeadCount() == 0 {
|
||||
return fmt.Errorf("nemotron_h: mamba_num_heads must be set")
|
||||
}
|
||||
if n.MambaHeadDim == 0 {
|
||||
return fmt.Errorf("nemotron_h: mamba_head_dim must be set")
|
||||
}
|
||||
if n.NGroups == 0 {
|
||||
n.NGroups = 1
|
||||
}
|
||||
|
||||
if _, _, err := n.layerArrays(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if n.isMoE() {
|
||||
if n.routedExpertCount() == 0 {
|
||||
return fmt.Errorf("nemotron_h: routed expert count must be set for MoE models")
|
||||
}
|
||||
if n.NumExpertsPerTok == 0 {
|
||||
return fmt.Errorf("nemotron_h: num_experts_per_tok must be set for MoE models")
|
||||
}
|
||||
if n.NumExpertsPerTok > n.routedExpertCount() {
|
||||
return fmt.Errorf("nemotron_h: num_experts_per_tok (%d) cannot exceed expert_count (%d)", n.NumExpertsPerTok, n.routedExpertCount())
|
||||
}
|
||||
if n.moeIntermediateSize() == 0 {
|
||||
return fmt.Errorf("nemotron_h: moe_intermediate_size must be set for MoE models")
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (n *nemotronHModel) isMoE() bool {
|
||||
return cmp.Or(n.routedExpertCount(), n.NumExpertsPerTok, n.MoEIntermediateSize) > 0
|
||||
}
|
||||
|
||||
func (n *nemotronHModel) routedExpertCount() uint32 {
|
||||
return cmp.Or(n.NRoutedExperts, n.NumExperts)
|
||||
}
|
||||
|
||||
func (n *nemotronHModel) sharedExpertCount() uint32 {
|
||||
return cmp.Or(n.NSharedExperts, n.NumSharedExperts)
|
||||
}
|
||||
|
||||
func (n *nemotronHModel) ssmHeadCount() uint32 {
|
||||
return n.MambaNumHeads
|
||||
}
|
||||
|
||||
func (n *nemotronHModel) ssmInnerSize() uint32 {
|
||||
return n.MambaHeadDim * n.ssmHeadCount()
|
||||
}
|
||||
|
||||
func (n *nemotronHModel) epsilon() float32 {
|
||||
return cmp.Or(n.NormEpsilon, n.LayerNormEpsilon, float32(1e-5))
|
||||
}
|
||||
|
||||
func (n *nemotronHModel) moeIntermediateSize() uint32 {
|
||||
return cmp.Or(n.MoEIntermediateSize, n.IntermediateSize)
|
||||
}
|
||||
|
||||
func (n *nemotronHModel) denseIntermediateSize() uint32 {
|
||||
return cmp.Or(n.IntermediateSize, n.MoEIntermediateSize)
|
||||
}
|
||||
|
||||
func (n *nemotronHModel) layerArrays() (headCountKV []uint32, ffnLengths []uint32, err error) {
|
||||
pattern := strings.TrimSpace(string(n.HybridOverridePattern))
|
||||
if pattern == "" {
|
||||
return nil, nil, fmt.Errorf("nemotron_h: hybrid_override_pattern must be set")
|
||||
}
|
||||
|
||||
runes := []rune(pattern)
|
||||
if len(runes) != int(n.NumHiddenLayers) {
|
||||
return nil, nil, fmt.Errorf("nemotron_h: hybrid_override_pattern length (%d) must match num_hidden_layers (%d)", len(runes), n.NumHiddenLayers)
|
||||
}
|
||||
|
||||
headCountKV = make([]uint32, n.NumHiddenLayers)
|
||||
ffnLengths = make([]uint32, n.NumHiddenLayers)
|
||||
|
||||
attnKVHeads := cmp.Or(n.NumKeyValueHeads, n.NumAttentionHeads)
|
||||
moeFFN := n.moeIntermediateSize()
|
||||
denseFFN := n.denseIntermediateSize()
|
||||
|
||||
for i, layerType := range runes {
|
||||
switch layerType {
|
||||
case 'M':
|
||||
// Recurrent layer: no KV heads and no FFN.
|
||||
case '*', 'A':
|
||||
// Attention-only layer.
|
||||
headCountKV[i] = attnKVHeads
|
||||
case 'E':
|
||||
// MoE layer.
|
||||
if moeFFN == 0 {
|
||||
return nil, nil, fmt.Errorf("nemotron_h: moe layer at index %d but moe_intermediate_size is zero", i)
|
||||
}
|
||||
ffnLengths[i] = moeFFN
|
||||
case '-':
|
||||
// Dense FFN layer.
|
||||
if denseFFN == 0 {
|
||||
return nil, nil, fmt.Errorf("nemotron_h: dense FFN layer at index %d but intermediate_size is zero", i)
|
||||
}
|
||||
ffnLengths[i] = denseFFN
|
||||
default:
|
||||
return nil, nil, fmt.Errorf("nemotron_h: unsupported layer type %q in hybrid_override_pattern at index %d", layerType, i)
|
||||
}
|
||||
}
|
||||
|
||||
return headCountKV, ffnLengths, nil
|
||||
}
|
||||
|
||||
func (n *nemotronHModel) KV(t *Tokenizer) KV {
|
||||
kv := n.ModelParameters.KV(t)
|
||||
|
||||
arch := "nemotron_h"
|
||||
if n.isMoE() {
|
||||
arch = "nemotron_h_moe"
|
||||
}
|
||||
kv["general.architecture"] = arch
|
||||
kv["block_count"] = n.NumHiddenLayers
|
||||
kv["context_length"] = n.MaxPositionEmbeddings
|
||||
kv["embedding_length"] = n.HiddenSize
|
||||
kv["attention.head_count"] = n.NumAttentionHeads
|
||||
kv["attention.key_length"] = n.HeadDim
|
||||
kv["attention.value_length"] = n.HeadDim
|
||||
kv["attention.layer_norm_epsilon"] = n.epsilon()
|
||||
kv["attention.layer_norm_rms_epsilon"] = n.epsilon()
|
||||
kv["rope.freq_base"] = cmp.Or(n.RopeTheta, float32(10000))
|
||||
if n.PartialRotaryFactor > 0 && n.PartialRotaryFactor <= 1 {
|
||||
kv["rope.dimension_count"] = uint32(float32(n.HeadDim) * n.PartialRotaryFactor)
|
||||
}
|
||||
|
||||
if headCountKV, ffnLengths, err := n.layerArrays(); err == nil {
|
||||
kv["attention.head_count_kv"] = headCountKV
|
||||
kv["feed_forward_length"] = ffnLengths
|
||||
}
|
||||
|
||||
kv["ssm.conv_kernel"] = n.ConvKernel
|
||||
kv["ssm.inner_size"] = n.ssmInnerSize()
|
||||
kv["ssm.state_size"] = n.SSMStateSize
|
||||
kv["ssm.group_count"] = n.NGroups
|
||||
kv["ssm.time_step_rank"] = n.ssmHeadCount()
|
||||
|
||||
if n.isMoE() {
|
||||
kv["expert_count"] = n.routedExpertCount()
|
||||
kv["expert_used_count"] = n.NumExpertsPerTok
|
||||
kv["expert_feed_forward_length"] = n.moeIntermediateSize()
|
||||
if n.sharedExpertCount() > 0 {
|
||||
kv["expert_shared_count"] = n.sharedExpertCount()
|
||||
}
|
||||
if n.MoESharedExpertIntermediate > 0 {
|
||||
kv["expert_shared_feed_forward_length"] = n.MoESharedExpertIntermediate
|
||||
}
|
||||
kv["expert_weights_norm"] = n.NormTopKProb
|
||||
kv["expert_weights_scale"] = n.RoutedScalingFactor
|
||||
if n.ExpertGroupCount > 0 {
|
||||
kv["expert_group_count"] = n.ExpertGroupCount
|
||||
}
|
||||
if n.ExpertGroupUsedCount > 0 {
|
||||
kv["expert_group_used_count"] = n.ExpertGroupUsedCount
|
||||
}
|
||||
}
|
||||
|
||||
return kv
|
||||
}
|
||||
|
||||
func normalizeVectorShapeToColumn(shape []uint64) []uint64 {
|
||||
switch len(shape) {
|
||||
case 1:
|
||||
return []uint64{shape[0], 1}
|
||||
case 2:
|
||||
if shape[0] == 1 && shape[1] > 1 {
|
||||
return []uint64{shape[1], 1}
|
||||
}
|
||||
if shape[1] == 1 && shape[0] > 1 {
|
||||
return []uint64{shape[0], 1}
|
||||
}
|
||||
}
|
||||
|
||||
return slices.Clone(shape)
|
||||
}
|
||||
|
||||
func (n *nemotronHModel) Tensors(ts []Tensor) []*ggml.Tensor {
|
||||
var out []*ggml.Tensor
|
||||
|
||||
remaining := ts
|
||||
if n.isMoE() {
|
||||
merges := make([]merge, 0, n.NumHiddenLayers*2)
|
||||
for i := range n.NumHiddenLayers {
|
||||
merges = append(merges, merge{
|
||||
fmt.Sprintf("blk.%d.mixer.experts.*.up_proj.weight", i),
|
||||
fmt.Sprintf("blk.%d.ffn_up_exps.weight", i),
|
||||
}, merge{
|
||||
fmt.Sprintf("blk.%d.mixer.experts.*.down_proj.weight", i),
|
||||
fmt.Sprintf("blk.%d.ffn_down_exps.weight", i),
|
||||
})
|
||||
}
|
||||
|
||||
merged, rest := mergeTensors(ts, merges...)
|
||||
out = append(out, merged...)
|
||||
remaining = rest
|
||||
}
|
||||
|
||||
nGroups := uint64(cmp.Or(n.NGroups, uint32(1)))
|
||||
for _, t := range remaining {
|
||||
name := t.Name()
|
||||
shape := slices.Clone(t.Shape())
|
||||
|
||||
switch {
|
||||
case strings.HasSuffix(name, ".ssm_a"):
|
||||
shape = normalizeVectorShapeToColumn(shape)
|
||||
t.SetRepacker(func(_ string, data []float32, _ []uint64) ([]float32, error) {
|
||||
out := make([]float32, len(data))
|
||||
for i, v := range data {
|
||||
out[i] = -float32(math.Exp(float64(v)))
|
||||
}
|
||||
return out, nil
|
||||
})
|
||||
case strings.HasSuffix(name, ".ssm_d"):
|
||||
shape = normalizeVectorShapeToColumn(shape)
|
||||
case strings.HasSuffix(name, ".ssm_norm.weight"):
|
||||
switch len(shape) {
|
||||
case 1:
|
||||
if nGroups > 0 && shape[0]%nGroups == 0 {
|
||||
shape = []uint64{nGroups, shape[0] / nGroups}
|
||||
}
|
||||
case 2:
|
||||
if shape[0] == 1 && nGroups > 0 && shape[1]%nGroups == 0 {
|
||||
shape = []uint64{nGroups, shape[1] / nGroups}
|
||||
}
|
||||
}
|
||||
case strings.HasSuffix(name, ".ssm_conv1d.weight"):
|
||||
if len(shape) == 3 {
|
||||
if shape[0] == 1 {
|
||||
shape = []uint64{shape[1], shape[2]}
|
||||
} else if shape[1] == 1 {
|
||||
shape = []uint64{shape[0], shape[2]}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
out = append(out, &ggml.Tensor{
|
||||
Name: name,
|
||||
Kind: t.Kind(),
|
||||
Shape: shape,
|
||||
WriterTo: t,
|
||||
})
|
||||
}
|
||||
|
||||
return out
|
||||
}
|
||||
|
||||
func (n *nemotronHModel) Replacements() []string {
|
||||
return []string{
|
||||
// Embedding and output
|
||||
"lm_head", "output",
|
||||
"backbone.embeddings", "token_embd",
|
||||
"backbone.norm_f", "output_norm",
|
||||
"backbone.layers", "blk",
|
||||
|
||||
// Recurrent (Mamba2) tensors
|
||||
"mixer.in_proj", "ssm_in",
|
||||
"mixer.out_proj", "ssm_out",
|
||||
"mixer.dt_bias", "ssm_dt.bias",
|
||||
"mixer.A_log", "ssm_a",
|
||||
"mixer.D", "ssm_d",
|
||||
"mixer.conv1d", "ssm_conv1d",
|
||||
"mixer.norm.weight", "ssm_norm.weight",
|
||||
|
||||
// Attention tensors
|
||||
"mixer.q_proj", "attn_q",
|
||||
"mixer.k_proj", "attn_k",
|
||||
"mixer.v_proj", "attn_v",
|
||||
"mixer.o_proj", "attn_output",
|
||||
|
||||
// FFN / MoE tensors
|
||||
"mixer.gate.e_score_correction_bias", "exp_probs_b.bias",
|
||||
"mixer.gate", "ffn_gate_inp",
|
||||
"mixer.fc1_latent_proj", "ffn_latent_in",
|
||||
"mixer.fc2_latent_proj", "ffn_latent_out",
|
||||
"mixer.shared_experts.up_proj", "ffn_up_shexp",
|
||||
"mixer.shared_experts.down_proj", "ffn_down_shexp",
|
||||
"mixer.up_proj", "ffn_up",
|
||||
"mixer.down_proj", "ffn_down",
|
||||
|
||||
// Per-layer pre-norm
|
||||
".norm.weight", ".attn_norm.weight",
|
||||
}
|
||||
}
|
||||
230
convert/convert_nemotron_h_test.go
Normal file
230
convert/convert_nemotron_h_test.go
Normal file
@@ -0,0 +1,230 @@
|
||||
package convert
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"slices"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestHybridPatternUnmarshal(t *testing.T) {
|
||||
t.Run("string", func(t *testing.T) {
|
||||
var p hybridPattern
|
||||
if err := json.Unmarshal([]byte(`"MEM*"`), &p); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if got, want := string(p), "MEM*"; got != want {
|
||||
t.Fatalf("unexpected pattern: got %q want %q", got, want)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("array", func(t *testing.T) {
|
||||
var p hybridPattern
|
||||
if err := json.Unmarshal([]byte(`["M","E","M","*"]`), &p); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if got, want := string(p), "MEM*"; got != want {
|
||||
t.Fatalf("unexpected pattern: got %q want %q", got, want)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestNemotronHLayerArrays(t *testing.T) {
|
||||
m := &nemotronHModel{
|
||||
NumHiddenLayers: 5,
|
||||
NumAttentionHeads: 32,
|
||||
NumKeyValueHeads: 8,
|
||||
HybridOverridePattern: "MEM*E",
|
||||
NRoutedExperts: 128,
|
||||
NumExpertsPerTok: 6,
|
||||
MoEIntermediateSize: 1856,
|
||||
}
|
||||
|
||||
headsKV, ffn, err := m.layerArrays()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if got, want := headsKV, []uint32{0, 0, 0, 8, 0}; !slices.Equal(got, want) {
|
||||
t.Fatalf("unexpected head_count_kv: got %v want %v", got, want)
|
||||
}
|
||||
if got, want := ffn, []uint32{0, 1856, 0, 0, 1856}; !slices.Equal(got, want) {
|
||||
t.Fatalf("unexpected feed_forward_length: got %v want %v", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNemotronHKV(t *testing.T) {
|
||||
m := &nemotronHModel{
|
||||
MaxPositionEmbeddings: 1048576,
|
||||
HiddenSize: 2688,
|
||||
NumHiddenLayers: 5,
|
||||
NumAttentionHeads: 32,
|
||||
NumKeyValueHeads: 2,
|
||||
HeadDim: 128,
|
||||
LayerNormEpsilon: 1e-5,
|
||||
RopeTheta: 10000,
|
||||
PartialRotaryFactor: 0.5,
|
||||
ConvKernel: 4,
|
||||
SSMStateSize: 128,
|
||||
MambaNumHeads: 64,
|
||||
MambaHeadDim: 64,
|
||||
NGroups: 8,
|
||||
HybridOverridePattern: "MEM*E",
|
||||
NRoutedExperts: 128,
|
||||
NSharedExperts: 1,
|
||||
NumExpertsPerTok: 6,
|
||||
MoEIntermediateSize: 1856,
|
||||
MoESharedExpertIntermediate: 3712,
|
||||
NormTopKProb: true,
|
||||
RoutedScalingFactor: 2.5,
|
||||
}
|
||||
if err := m.parseMore(nil); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
kv := m.KV(&Tokenizer{Vocabulary: &Vocabulary{}})
|
||||
if got, want := kv["general.architecture"], "nemotron_h_moe"; got != want {
|
||||
t.Fatalf("unexpected architecture: got %v want %v", got, want)
|
||||
}
|
||||
|
||||
headCountKV, ok := kv["attention.head_count_kv"].([]uint32)
|
||||
if !ok {
|
||||
t.Fatalf("attention.head_count_kv has unexpected type: %T", kv["attention.head_count_kv"])
|
||||
}
|
||||
if got, want := headCountKV, []uint32{0, 0, 0, 2, 0}; !slices.Equal(got, want) {
|
||||
t.Fatalf("unexpected attention.head_count_kv: got %v want %v", got, want)
|
||||
}
|
||||
|
||||
ffnLength, ok := kv["feed_forward_length"].([]uint32)
|
||||
if !ok {
|
||||
t.Fatalf("feed_forward_length has unexpected type: %T", kv["feed_forward_length"])
|
||||
}
|
||||
if got, want := ffnLength, []uint32{0, 1856, 0, 0, 1856}; !slices.Equal(got, want) {
|
||||
t.Fatalf("unexpected feed_forward_length: got %v want %v", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNemotronHTensorsTransforms(t *testing.T) {
|
||||
m := &nemotronHModel{NGroups: 8}
|
||||
in := []Tensor{
|
||||
&fakeTensor{
|
||||
name: "blk.0.ssm_a",
|
||||
shape: []uint64{4},
|
||||
data: []float32{0, 1, 2, 3},
|
||||
},
|
||||
&fakeTensor{
|
||||
name: "blk.0.ssm_d",
|
||||
shape: []uint64{4},
|
||||
data: []float32{0, 1, 2, 3},
|
||||
},
|
||||
&fakeTensor{
|
||||
name: "blk.0.ssm_norm.weight",
|
||||
shape: []uint64{16},
|
||||
data: make([]float32, 16),
|
||||
},
|
||||
&fakeTensor{
|
||||
name: "blk.0.ssm_conv1d.weight",
|
||||
shape: []uint64{10, 1, 4},
|
||||
data: make([]float32, 40),
|
||||
},
|
||||
}
|
||||
|
||||
out := m.Tensors(in)
|
||||
if len(out) != len(in) {
|
||||
t.Fatalf("unexpected output tensor count: got %d want %d", len(out), len(in))
|
||||
}
|
||||
|
||||
got := map[string]struct {
|
||||
shape []uint64
|
||||
writer io.WriterTo
|
||||
}{}
|
||||
for _, t := range out {
|
||||
got[t.Name] = struct {
|
||||
shape []uint64
|
||||
writer io.WriterTo
|
||||
}{shape: t.Shape, writer: t.WriterTo}
|
||||
}
|
||||
|
||||
if shape := got["blk.0.ssm_a"].shape; !slices.Equal(shape, []uint64{4, 1}) {
|
||||
t.Fatalf("unexpected ssm_a shape: %v", shape)
|
||||
}
|
||||
if shape := got["blk.0.ssm_d"].shape; !slices.Equal(shape, []uint64{4, 1}) {
|
||||
t.Fatalf("unexpected ssm_d shape: %v", shape)
|
||||
}
|
||||
if shape := got["blk.0.ssm_norm.weight"].shape; !slices.Equal(shape, []uint64{8, 2}) {
|
||||
t.Fatalf("unexpected ssm_norm shape: %v", shape)
|
||||
}
|
||||
if shape := got["blk.0.ssm_conv1d.weight"].shape; !slices.Equal(shape, []uint64{10, 4}) {
|
||||
t.Fatalf("unexpected ssm_conv1d shape: %v", shape)
|
||||
}
|
||||
|
||||
var b bytes.Buffer
|
||||
if _, err := got["blk.0.ssm_a"].writer.WriteTo(&b); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
values := make([]float32, 4)
|
||||
if err := binary.Read(&b, binary.LittleEndian, &values); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
// 0 -> -exp(0) == -1
|
||||
if values[0] != -1 {
|
||||
t.Fatalf("unexpected transformed ssm_a[0]: got %v want -1", values[0])
|
||||
}
|
||||
}
|
||||
|
||||
func TestNemotronHLoadModelMetadata(t *testing.T) {
|
||||
tempDir := t.TempDir()
|
||||
|
||||
config := `{
|
||||
"architectures": ["NemotronHForCausalLM"],
|
||||
"model_type": "nemotron_h",
|
||||
"num_hidden_layers": 4,
|
||||
"hidden_size": 512,
|
||||
"max_position_embeddings": 32768,
|
||||
"num_attention_heads": 8,
|
||||
"num_key_value_heads": 2,
|
||||
"head_dim": 64,
|
||||
"layer_norm_epsilon": 1e-5,
|
||||
"conv_kernel": 4,
|
||||
"ssm_state_size": 128,
|
||||
"mamba_num_heads": 16,
|
||||
"mamba_head_dim": 32,
|
||||
"n_groups": 8,
|
||||
"hybrid_override_pattern": "ME*M",
|
||||
"n_routed_experts": 16,
|
||||
"num_experts_per_tok": 4,
|
||||
"moe_intermediate_size": 256
|
||||
}`
|
||||
|
||||
if err := os.WriteFile(filepath.Join(tempDir, "config.json"), []byte(config), 0o644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := os.WriteFile(filepath.Join(tempDir, "tokenizer.json"), []byte(`{}`), 0o644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
kv, _, err := LoadModelMetadata(os.DirFS(tempDir))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if _, ok := kv.(*nemotronHModel); !ok {
|
||||
t.Fatalf("unexpected converter type: %T", kv)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNemotronHReplacementsLatentProjections(t *testing.T) {
|
||||
m := &nemotronHModel{}
|
||||
r := strings.NewReplacer(m.Replacements()...)
|
||||
|
||||
if got, want := r.Replace("backbone.layers.1.mixer.fc1_latent_proj.weight"), "blk.1.ffn_latent_in.weight"; got != want {
|
||||
t.Fatalf("unexpected fc1 replacement: got %q want %q", got, want)
|
||||
}
|
||||
if got, want := r.Replace("backbone.layers.1.mixer.fc2_latent_proj.weight"), "blk.1.ffn_latent_out.weight"; got != want {
|
||||
t.Fatalf("unexpected fc2 replacement: got %q want %q", got, want)
|
||||
}
|
||||
}
|
||||
97
convert/json_compat.go
Normal file
97
convert/json_compat.go
Normal file
@@ -0,0 +1,97 @@
|
||||
package convert
|
||||
|
||||
// sanitizeNonFiniteJSON rewrites non-standard JSON numeric tokens that some
|
||||
// HF configs emit (Infinity, -Infinity, NaN) into standard JSON numbers.
|
||||
//
|
||||
// This is intentionally conservative:
|
||||
// - only runs outside quoted strings
|
||||
// - only rewrites full tokens
|
||||
//
|
||||
// We map these values to 0 because encoding/json rejects non-finite values,
|
||||
// and these fields are typically model-side metadata not consumed by the
|
||||
// converter.
|
||||
func sanitizeNonFiniteJSON(in []byte) []byte {
|
||||
if len(in) == 0 {
|
||||
return in
|
||||
}
|
||||
|
||||
out := make([]byte, 0, len(in))
|
||||
inString := false
|
||||
escape := false
|
||||
|
||||
for i := 0; i < len(in); {
|
||||
c := in[i]
|
||||
|
||||
if inString {
|
||||
out = append(out, c)
|
||||
if escape {
|
||||
escape = false
|
||||
} else if c == '\\' {
|
||||
escape = true
|
||||
} else if c == '"' {
|
||||
inString = false
|
||||
}
|
||||
i++
|
||||
continue
|
||||
}
|
||||
|
||||
if c == '"' {
|
||||
inString = true
|
||||
out = append(out, c)
|
||||
i++
|
||||
continue
|
||||
}
|
||||
|
||||
if hasToken(in, i, "-Infinity") {
|
||||
out = append(out, '0')
|
||||
i += len("-Infinity")
|
||||
continue
|
||||
}
|
||||
|
||||
if hasToken(in, i, "Infinity") {
|
||||
out = append(out, '0')
|
||||
i += len("Infinity")
|
||||
continue
|
||||
}
|
||||
|
||||
if hasToken(in, i, "NaN") {
|
||||
out = append(out, '0')
|
||||
i += len("NaN")
|
||||
continue
|
||||
}
|
||||
|
||||
out = append(out, c)
|
||||
i++
|
||||
}
|
||||
|
||||
return out
|
||||
}
|
||||
|
||||
func hasToken(in []byte, at int, tok string) bool {
|
||||
end := at + len(tok)
|
||||
if at < 0 || end > len(in) {
|
||||
return false
|
||||
}
|
||||
if string(in[at:end]) != tok {
|
||||
return false
|
||||
}
|
||||
if at > 0 && !isJSONValuePrefixBoundary(in[at-1]) {
|
||||
return false
|
||||
}
|
||||
if end < len(in) && !isJSONValueSuffixBoundary(in[end]) {
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func isJSONWhitespace(b byte) bool {
|
||||
return b == ' ' || b == '\t' || b == '\n' || b == '\r'
|
||||
}
|
||||
|
||||
func isJSONValuePrefixBoundary(b byte) bool {
|
||||
return isJSONWhitespace(b) || b == ':' || b == ',' || b == '['
|
||||
}
|
||||
|
||||
func isJSONValueSuffixBoundary(b byte) bool {
|
||||
return isJSONWhitespace(b) || b == ',' || b == ']' || b == '}'
|
||||
}
|
||||
46
convert/json_compat_test.go
Normal file
46
convert/json_compat_test.go
Normal file
@@ -0,0 +1,46 @@
|
||||
package convert
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestSanitizeNonFiniteJSON(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
in string
|
||||
want string
|
||||
}{
|
||||
{
|
||||
name: "infinity token",
|
||||
in: `{"a":[0,Infinity,1]}`,
|
||||
want: `{"a":[0,0,1]}`,
|
||||
},
|
||||
{
|
||||
name: "negative infinity token",
|
||||
in: `{"a":-Infinity}`,
|
||||
want: `{"a":0}`,
|
||||
},
|
||||
{
|
||||
name: "nan token",
|
||||
in: `{"a":NaN}`,
|
||||
want: `{"a":0}`,
|
||||
},
|
||||
{
|
||||
name: "tokens inside strings untouched",
|
||||
in: `{"a":"Infinity -Infinity NaN","b":Infinity}`,
|
||||
want: `{"a":"Infinity -Infinity NaN","b":0}`,
|
||||
},
|
||||
{
|
||||
name: "identifier-like token untouched",
|
||||
in: `{"a":InfinityValue}`,
|
||||
want: `{"a":InfinityValue}`,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := string(sanitizeNonFiniteJSON([]byte(tt.in)))
|
||||
if got != tt.want {
|
||||
t.Fatalf("sanitizeNonFiniteJSON() = %q, want %q", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -4,47 +4,65 @@ title: OpenClaw
|
||||
|
||||
OpenClaw is a personal AI assistant that runs on your own devices. It bridges messaging services (WhatsApp, Telegram, Slack, Discord, iMessage, and more) to AI coding agents through a centralized gateway.
|
||||
|
||||
## Install
|
||||
|
||||
Install [OpenClaw](https://openclaw.ai/)
|
||||
|
||||
```bash
|
||||
npm install -g openclaw@latest
|
||||
```
|
||||
|
||||
Then run the onboarding wizard:
|
||||
|
||||
```bash
|
||||
openclaw onboard --install-daemon
|
||||
```
|
||||
|
||||
<Note>OpenClaw requires a larger context window. It is recommended to use a context window of at least 64k tokens. See [Context length](/context-length) for more information.</Note>
|
||||
|
||||
## Usage with Ollama
|
||||
|
||||
### Quick setup
|
||||
## Quick start
|
||||
|
||||
```bash
|
||||
ollama launch openclaw
|
||||
```
|
||||
|
||||
Ollama handles everything automatically:
|
||||
|
||||
1. **Install** — If OpenClaw isn't installed, Ollama prompts to install it via npm
|
||||
2. **Security** — On the first launch, a security notice explains the risks of tool access
|
||||
3. **Model** — Pick a model from the selector (local or cloud)
|
||||
4. **Onboarding** — Ollama configures the provider, installs the gateway daemon, and sets your model as the primary
|
||||
5. **Gateway** — Starts in the background and opens the OpenClaw TUI
|
||||
|
||||
<Note>OpenClaw requires a larger context window. It is recommended to use a context window of at least 64k tokens if using local models. See [Context length](/context-length) for more information.</Note>
|
||||
|
||||
<Note>Previously known as Clawdbot. `ollama launch clawdbot` still works as an alias.</Note>
|
||||
|
||||
This configures OpenClaw to use Ollama and starts the gateway.
|
||||
If the gateway is already running, no changes need to be made as the gateway will auto-reload the changes.
|
||||
## Configure without launching
|
||||
|
||||
To change the model without starting the gateway and TUI:
|
||||
|
||||
To configure without launching:
|
||||
|
||||
```shell
|
||||
```bash
|
||||
ollama launch openclaw --config
|
||||
```
|
||||
|
||||
## Recommended Models
|
||||
To use a specific model directly:
|
||||
|
||||
- `qwen3-coder`
|
||||
- `glm-4.7`
|
||||
- `gpt-oss:20b`
|
||||
- `gpt-oss:120b`
|
||||
```bash
|
||||
ollama launch openclaw --model kimi-k2.5:cloud
|
||||
```
|
||||
|
||||
If the gateway is already running, it restarts automatically to pick up the new model.
|
||||
|
||||
## Recommended models
|
||||
|
||||
**Cloud models**:
|
||||
|
||||
- `kimi-k2.5:cloud` — Multimodal reasoning with subagents
|
||||
- `minimax-m2.5:cloud` — Fast, efficient coding and real-world productivity
|
||||
- `glm-5:cloud` — Reasoning and code generation
|
||||
|
||||
**Local models:**
|
||||
|
||||
- `glm-4.7-flash` — Reasoning and code generation locally (~25 GB VRAM)
|
||||
|
||||
More models at [ollama.com/search](https://ollama.com/search?c=cloud).
|
||||
|
||||
## Connect messaging apps
|
||||
|
||||
```bash
|
||||
openclaw configure --section channels
|
||||
```
|
||||
|
||||
Link WhatsApp, Telegram, Slack, Discord, or iMessage to chat with your local models from anywhere.
|
||||
|
||||
## Stopping the gateway
|
||||
|
||||
```bash
|
||||
openclaw gateway stop
|
||||
```
|
||||
|
||||
Cloud models are also available at [ollama.com/search?c=cloud](https://ollama.com/search?c=cloud).
|
||||
|
||||
@@ -160,6 +160,27 @@ func (kv KV) SSMGroupCount() uint64 {
|
||||
return uint64(kv.Uint("ssm.group_count"))
|
||||
}
|
||||
|
||||
func (kv KV) FFNLength() []uint64 {
|
||||
ffnLengthDefault := uint32(0)
|
||||
ffnLength := kv.UintOrArrayValueAsArray("feed_forward_length", ffnLengthDefault)
|
||||
if len(ffnLength) == 1 {
|
||||
ffnLengthDefault = ffnLength[0]
|
||||
}
|
||||
nLayers := int(kv.BlockCount())
|
||||
if len(ffnLength) > nLayers {
|
||||
slog.Warn("got more elements of feed_forward_length than layers", "len(ffnLength)", len(ffnLength), "layers", nLayers)
|
||||
}
|
||||
out := make([]uint64, nLayers)
|
||||
for i := range nLayers {
|
||||
if i >= len(ffnLength) {
|
||||
out[i] = uint64(ffnLengthDefault)
|
||||
} else {
|
||||
out[i] = uint64(ffnLength[i])
|
||||
}
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// general types
|
||||
|
||||
func (kv KV) String(key string, defaultValue ...string) string {
|
||||
@@ -264,6 +285,7 @@ func (kv KV) OllamaEngineRequired() bool {
|
||||
"llama4",
|
||||
"mistral3",
|
||||
"mllama",
|
||||
"nemotron_h", "nemotron_h_moe",
|
||||
"nomic-bert",
|
||||
"olmo3",
|
||||
"qwen25vl",
|
||||
@@ -865,6 +887,7 @@ func (f GGML) FlashAttention() bool {
|
||||
"gptoss", "gpt-oss",
|
||||
"lfm2",
|
||||
"mistral3",
|
||||
"nemotron_h", "nemotron_h_moe",
|
||||
"olmo3",
|
||||
"qwen3", "qwen3moe",
|
||||
"qwen3next",
|
||||
|
||||
752
kvcache/recurrent.go
Normal file
752
kvcache/recurrent.go
Normal file
@@ -0,0 +1,752 @@
|
||||
package kvcache
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"math"
|
||||
"slices"
|
||||
|
||||
"github.com/ollama/ollama/ml"
|
||||
"github.com/ollama/ollama/model/input"
|
||||
)
|
||||
|
||||
const (
|
||||
DefaultCheckpointCount = 32
|
||||
DefaultCheckpointMinPos = int32(16)
|
||||
DefaultCheckpointInterval = int32(1280)
|
||||
)
|
||||
|
||||
var ErrInvalidRecurrentShape = errors.New("kvcache: invalid recurrent state shape")
|
||||
|
||||
// Config configures a shared hybrid recurrent cache.
|
||||
type RecurrentConfig struct {
|
||||
Shift func(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error)
|
||||
ConvDim int
|
||||
ConvChannels int
|
||||
RecurrentStateSize int
|
||||
CheckpointLogPrefix string
|
||||
}
|
||||
|
||||
var (
|
||||
_ Cache = (*Recurrent)(nil)
|
||||
_ CheckpointCache = (*Recurrent)(nil)
|
||||
)
|
||||
|
||||
// Cache stores:
|
||||
// - a standard causal KV cache
|
||||
// - per-sequence conv state for recurrent operators
|
||||
// - per-sequence recurrent state for recurrent operators
|
||||
//
|
||||
// Conv state shape (per layer, per sequence): [convDim, convChannels]
|
||||
// Recurrent state shape (per layer, per sequence): [recurrentStateSize]
|
||||
type Recurrent struct {
|
||||
kv *Causal
|
||||
|
||||
backend ml.Backend
|
||||
dtype ml.DType
|
||||
maxSequences int
|
||||
|
||||
// Conv state dimensions
|
||||
convDim int
|
||||
convChannels int
|
||||
|
||||
// Recurrent state dimensions
|
||||
recurrentStateSize int
|
||||
|
||||
logPrefix string
|
||||
|
||||
// slot mapping for recurrent state (copy-on-write)
|
||||
slotForSeq map[int]int
|
||||
refCount []int
|
||||
freeSlots []int
|
||||
seqCounts map[int]int
|
||||
slotScratch [1]int32
|
||||
|
||||
// per-layer conv state buffers (allocated lazily)
|
||||
convCtxs map[int]ml.Context
|
||||
convStates map[int]ml.Tensor // [convDim*convChannels, maxSlots]
|
||||
|
||||
// per-layer recurrent state buffers (allocated lazily)
|
||||
recurrentCtxs map[int]ml.Context
|
||||
recurrentStates map[int]ml.Tensor // [recurrentStateSize, maxSlots]
|
||||
|
||||
// recurrent checkpoints (per slot)
|
||||
checkpointCount int
|
||||
checkpointMinPos int32
|
||||
checkpointInterval int32
|
||||
checkpointCtxSize int
|
||||
checkpoints map[int]*slotCheckpointStore
|
||||
pendingRestore map[int]checkpointRestore
|
||||
curCheckpointPos []int32
|
||||
curCheckpointSlots map[int]int
|
||||
reserveCheckpoints bool
|
||||
checkpointConvCtxs map[int]ml.Context
|
||||
checkpointRecurCtxs map[int]ml.Context
|
||||
checkpointReserved map[int]struct{}
|
||||
|
||||
// current forward batch (derived in StartForward)
|
||||
curSeqs []int
|
||||
curSlots []int
|
||||
curSlotsInput ml.Tensor
|
||||
curSeqTokens int
|
||||
|
||||
// track if EnsureWritable has been called for this forward pass
|
||||
writableEnsured bool
|
||||
writableError error
|
||||
}
|
||||
|
||||
func NewRecurrentCache(config RecurrentConfig) *Recurrent {
|
||||
return &Recurrent{
|
||||
kv: NewCausalCache(config.Shift),
|
||||
convDim: config.ConvDim,
|
||||
convChannels: config.ConvChannels,
|
||||
recurrentStateSize: config.RecurrentStateSize,
|
||||
logPrefix: config.CheckpointLogPrefix,
|
||||
slotForSeq: make(map[int]int),
|
||||
seqCounts: make(map[int]int),
|
||||
convCtxs: make(map[int]ml.Context),
|
||||
convStates: make(map[int]ml.Tensor),
|
||||
recurrentCtxs: make(map[int]ml.Context),
|
||||
recurrentStates: make(map[int]ml.Tensor),
|
||||
checkpointCount: DefaultCheckpointCount,
|
||||
checkpointMinPos: DefaultCheckpointMinPos,
|
||||
checkpointInterval: DefaultCheckpointInterval,
|
||||
checkpoints: make(map[int]*slotCheckpointStore),
|
||||
pendingRestore: make(map[int]checkpointRestore),
|
||||
curCheckpointSlots: make(map[int]int),
|
||||
checkpointConvCtxs: make(map[int]ml.Context),
|
||||
checkpointRecurCtxs: make(map[int]ml.Context),
|
||||
checkpointReserved: make(map[int]struct{}),
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Recurrent) Init(backend ml.Backend, dtype ml.DType, maxSequences, capacity, maxBatch int) {
|
||||
c.backend = backend
|
||||
c.dtype = dtype
|
||||
c.maxSequences = maxSequences
|
||||
c.checkpoints = make(map[int]*slotCheckpointStore)
|
||||
c.pendingRestore = make(map[int]checkpointRestore)
|
||||
c.curCheckpointPos = c.curCheckpointPos[:0]
|
||||
c.curCheckpointSlots = make(map[int]int)
|
||||
c.checkpointReserved = make(map[int]struct{})
|
||||
c.checkpointCtxSize = c.checkpointCount * c.maxSequences
|
||||
if c.checkpointCtxSize < 8 {
|
||||
c.checkpointCtxSize = 8
|
||||
}
|
||||
|
||||
// initialize slot allocator
|
||||
c.refCount = make([]int, maxSequences)
|
||||
c.freeSlots = c.freeSlots[:0]
|
||||
for i := maxSequences - 1; i >= 0; i-- {
|
||||
c.freeSlots = append(c.freeSlots, i)
|
||||
}
|
||||
|
||||
c.kv.Init(backend, dtype, maxSequences, capacity, maxBatch)
|
||||
}
|
||||
|
||||
func (c *Recurrent) Close() {
|
||||
for _, ctx := range c.convCtxs {
|
||||
ctx.Close()
|
||||
}
|
||||
for _, ctx := range c.recurrentCtxs {
|
||||
ctx.Close()
|
||||
}
|
||||
for _, ctx := range c.checkpointConvCtxs {
|
||||
ctx.Close()
|
||||
}
|
||||
for _, ctx := range c.checkpointRecurCtxs {
|
||||
ctx.Close()
|
||||
}
|
||||
c.kv.Close()
|
||||
}
|
||||
|
||||
func (c *Recurrent) SetConfig(config ml.CacheConfig) {
|
||||
c.kv.SetConfig(config)
|
||||
}
|
||||
|
||||
func (c *Recurrent) SetLayer(layer int) {
|
||||
c.kv.SetLayer(layer)
|
||||
}
|
||||
|
||||
func (c *Recurrent) Get(ctx ml.Context) (ml.Tensor, ml.Tensor, ml.Tensor) {
|
||||
return c.kv.Get(ctx)
|
||||
}
|
||||
|
||||
func (c *Recurrent) Put(ctx ml.Context, key, value ml.Tensor) {
|
||||
c.kv.Put(ctx, key, value)
|
||||
}
|
||||
|
||||
func (c *Recurrent) StartForward(ctx ml.Context, batch input.Batch, reserve bool) error {
|
||||
if err := c.kv.StartForward(ctx, batch, reserve); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
nTokens := len(batch.Sequences)
|
||||
if nTokens == 0 {
|
||||
c.curSeqs = c.curSeqs[:0]
|
||||
c.curSlots = c.curSlots[:0]
|
||||
c.curSlotsInput = nil
|
||||
c.curSeqTokens = 0
|
||||
c.reserveCheckpoints = false
|
||||
c.writableEnsured = false
|
||||
c.writableError = nil
|
||||
return nil
|
||||
}
|
||||
|
||||
// Fast path for single-sequence batches (common during decode and prefill).
|
||||
firstSeq := batch.Sequences[0]
|
||||
singleSeq := true
|
||||
for _, s := range batch.Sequences[1:] {
|
||||
if s != firstSeq {
|
||||
singleSeq = false
|
||||
break
|
||||
}
|
||||
}
|
||||
if singleSeq {
|
||||
return c.startForwardSingleSeq(ctx, firstSeq, nTokens, batch, reserve)
|
||||
}
|
||||
|
||||
// Derive equal-length sequence layout for recurrent layers.
|
||||
seqCounts := c.seqCounts
|
||||
for s := range seqCounts {
|
||||
delete(seqCounts, s)
|
||||
}
|
||||
|
||||
c.curSeqs = c.curSeqs[:0]
|
||||
for _, s := range batch.Sequences {
|
||||
if seqCounts[s] == 0 {
|
||||
c.curSeqs = append(c.curSeqs, s)
|
||||
}
|
||||
seqCounts[s]++
|
||||
}
|
||||
|
||||
nSeqs := len(c.curSeqs)
|
||||
want := nTokens / nSeqs
|
||||
for _, s := range c.curSeqs {
|
||||
if seqCounts[s] != want {
|
||||
return ErrNotSupported
|
||||
}
|
||||
}
|
||||
|
||||
c.curSeqTokens = want
|
||||
|
||||
if reserve {
|
||||
c.curSlots = c.curSlots[:0]
|
||||
for i := range nSeqs {
|
||||
c.curSlots = append(c.curSlots, i)
|
||||
}
|
||||
c.finalizeStartForward(ctx, batch, true)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Ensure slots exist for sequences in this batch.
|
||||
c.curSlots = c.curSlots[:0]
|
||||
var newSlots []int
|
||||
for _, s := range c.curSeqs {
|
||||
slot, ok := c.slotForSeq[s]
|
||||
if !ok {
|
||||
var err error
|
||||
slot, err = c.allocSlot()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
c.slotForSeq[s] = slot
|
||||
c.refCount[slot] = 1
|
||||
newSlots = append(newSlots, slot)
|
||||
}
|
||||
c.curSlots = append(c.curSlots, slot)
|
||||
}
|
||||
|
||||
if len(newSlots) > 0 {
|
||||
c.zeroSlots(ctx, newSlots)
|
||||
}
|
||||
|
||||
c.finalizeStartForward(ctx, batch, false)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Recurrent) startForwardSingleSeq(ctx ml.Context, seq, seqTokens int, batch input.Batch, reserve bool) error {
|
||||
c.curSeqs = append(c.curSeqs[:0], seq)
|
||||
c.curSeqTokens = seqTokens
|
||||
|
||||
if reserve {
|
||||
c.curSlots = append(c.curSlots[:0], 0)
|
||||
c.finalizeStartForward(ctx, batch, true)
|
||||
return nil
|
||||
}
|
||||
|
||||
slot, ok := c.slotForSeq[seq]
|
||||
if !ok {
|
||||
var err error
|
||||
slot, err = c.allocSlot()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
c.slotForSeq[seq] = slot
|
||||
c.refCount[slot] = 1
|
||||
slotList := [1]int{slot}
|
||||
c.zeroSlots(ctx, slotList[:])
|
||||
}
|
||||
|
||||
c.curSlots = append(c.curSlots[:0], slot)
|
||||
c.finalizeStartForward(ctx, batch, false)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Recurrent) finalizeStartForward(ctx ml.Context, batch input.Batch, reserve bool) {
|
||||
c.setCurSlotsInput(ctx)
|
||||
c.writableEnsured = false
|
||||
c.writableError = nil
|
||||
c.reserveCheckpoints = reserve
|
||||
c.planCheckpoints(batch)
|
||||
}
|
||||
|
||||
func (c *Recurrent) setCurSlotsInput(ctx ml.Context) {
|
||||
c.curSlotsInput = c.slotsInput(ctx, c.curSlots)
|
||||
}
|
||||
|
||||
func (c *Recurrent) slotsInput(ctx ml.Context, slots []int) ml.Tensor {
|
||||
switch len(slots) {
|
||||
case 0:
|
||||
return nil
|
||||
case 1:
|
||||
c.slotScratch[0] = int32(slots[0])
|
||||
return ctx.Input().FromInts(c.slotScratch[:], 1)
|
||||
default:
|
||||
slotIndices := make([]int32, len(slots))
|
||||
for i, v := range slots {
|
||||
slotIndices[i] = int32(v)
|
||||
}
|
||||
return ctx.Input().FromInts(slotIndices, len(slotIndices))
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Recurrent) allocSlot() (int, error) {
|
||||
if len(c.freeSlots) == 0 {
|
||||
return 0, ErrKvCacheFull
|
||||
}
|
||||
slot := c.freeSlots[len(c.freeSlots)-1]
|
||||
c.freeSlots = c.freeSlots[:len(c.freeSlots)-1]
|
||||
return slot, nil
|
||||
}
|
||||
|
||||
func (c *Recurrent) freeSlot(slot int) {
|
||||
if slot >= 0 && slot < c.maxSequences {
|
||||
c.freeSlots = append(c.freeSlots, slot)
|
||||
}
|
||||
}
|
||||
|
||||
// zeroSlots zeros recurrent state for the given slots across all cached layers.
|
||||
func (c *Recurrent) zeroSlots(ctx ml.Context, slots []int) {
|
||||
if len(slots) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
inputCtx := ctx.Input()
|
||||
slotsTensor := c.slotsInput(ctx, slots)
|
||||
|
||||
if len(c.convStates) > 0 {
|
||||
zeros := inputCtx.Zeros(ml.DTypeF32, c.convDim*c.convChannels, len(slots))
|
||||
for _, buf := range c.convStates {
|
||||
ctx.Forward(buf.SetRows(ctx, zeros, slotsTensor))
|
||||
}
|
||||
}
|
||||
|
||||
if len(c.recurrentStates) > 0 {
|
||||
zeros := inputCtx.Zeros(ml.DTypeF32, c.recurrentStateSize, len(slots))
|
||||
for _, buf := range c.recurrentStates {
|
||||
ctx.Forward(buf.SetRows(ctx, zeros, slotsTensor))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// EnsureWritable ensures sequences have private slots (copy-on-write).
|
||||
func (c *Recurrent) EnsureWritable(ctx ml.Context) error {
|
||||
for i, seq := range c.curSeqs {
|
||||
slot, ok := c.slotForSeq[seq]
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
if slot < 0 || slot >= len(c.refCount) {
|
||||
continue
|
||||
}
|
||||
|
||||
if c.refCount[slot] <= 1 {
|
||||
continue
|
||||
}
|
||||
|
||||
newSlot, err := c.allocSlot()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
c.refCount[slot]--
|
||||
c.refCount[newSlot] = 1
|
||||
c.slotForSeq[seq] = newSlot
|
||||
c.curSlots[i] = newSlot
|
||||
|
||||
c.copyRecurrentState(ctx, slot, newSlot)
|
||||
c.copyCheckpoints(ctx, slot, newSlot)
|
||||
}
|
||||
|
||||
c.setCurSlotsInput(ctx)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Recurrent) copyRecurrentState(ctx ml.Context, srcSlot, dstSlot int) {
|
||||
src := ctx.Input().FromInts([]int32{int32(srcSlot)}, 1)
|
||||
dst := ctx.Input().FromInts([]int32{int32(dstSlot)}, 1)
|
||||
|
||||
for _, buf := range c.convStates {
|
||||
rows := buf.Rows(ctx, src)
|
||||
if rows.DType() != ml.DTypeF32 {
|
||||
rows = rows.Cast(ctx, ml.DTypeF32)
|
||||
}
|
||||
ctx.Forward(buf.SetRows(ctx, rows, dst))
|
||||
}
|
||||
|
||||
for _, buf := range c.recurrentStates {
|
||||
rows := buf.Rows(ctx, src)
|
||||
if rows.DType() != ml.DTypeF32 {
|
||||
rows = rows.Cast(ctx, ml.DTypeF32)
|
||||
}
|
||||
ctx.Forward(buf.SetRows(ctx, rows, dst))
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Recurrent) CopyPrefix(srcSeq, dstSeq int, prefixLen int32) {
|
||||
c.kv.CopyPrefix(srcSeq, dstSeq, prefixLen)
|
||||
|
||||
if dstSlot, ok := c.slotForSeq[dstSeq]; ok {
|
||||
if c.validSlot(dstSlot) {
|
||||
c.refCount[dstSlot]--
|
||||
if c.refCount[dstSlot] <= 0 {
|
||||
c.refCount[dstSlot] = 0
|
||||
c.freeSlot(dstSlot)
|
||||
}
|
||||
}
|
||||
delete(c.slotForSeq, dstSeq)
|
||||
}
|
||||
|
||||
srcSlot, ok := c.slotForSeq[srcSeq]
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
if c.validSlot(srcSlot) {
|
||||
c.slotForSeq[dstSeq] = srcSlot
|
||||
c.refCount[srcSlot]++
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Recurrent) CanResume(seq int, pos int32) bool {
|
||||
if !c.kv.CanResume(seq, pos) {
|
||||
return false
|
||||
}
|
||||
if pos == 0 {
|
||||
return true
|
||||
}
|
||||
return c.hasCheckpoint(seq, pos)
|
||||
}
|
||||
|
||||
func (c *Recurrent) Remove(seq int, beginIndex, endIndex int32) error {
|
||||
if beginIndex > 0 && endIndex != math.MaxInt32 {
|
||||
if err := c.kv.Remove(seq, beginIndex, endIndex); err != nil {
|
||||
return err
|
||||
}
|
||||
delete(c.pendingRestore, seq)
|
||||
|
||||
slot, ok := c.slotForSeq[seq]
|
||||
if !ok || !c.validSlot(slot) {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Detach shared recurrent state/checkpoints before mutating checkpoint positions.
|
||||
if c.refCount[slot] > 1 {
|
||||
newSlot, err := c.allocSlot()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
ctx := c.backend.NewContext()
|
||||
c.copyRecurrentState(ctx, slot, newSlot)
|
||||
c.copyCheckpoints(ctx, slot, newSlot)
|
||||
if len(c.convStates) > 0 || len(c.recurrentStates) > 0 {
|
||||
ctx.Compute()
|
||||
}
|
||||
ctx.Close()
|
||||
|
||||
c.refCount[slot]--
|
||||
c.refCount[newSlot] = 1
|
||||
c.slotForSeq[seq] = newSlot
|
||||
slot = newSlot
|
||||
}
|
||||
|
||||
c.shiftCheckpoints(slot, beginIndex, endIndex)
|
||||
return nil
|
||||
}
|
||||
|
||||
if beginIndex > 0 {
|
||||
restore, ok := c.pendingRestore[seq]
|
||||
if !ok || restore.pos+1 != beginIndex {
|
||||
return ErrNotSupported
|
||||
}
|
||||
if !c.restoreComplete(restore) {
|
||||
return ErrNotSupported
|
||||
}
|
||||
if slot, ok := c.slotForSeq[seq]; ok && c.validSlot(slot) && c.refCount[slot] > 1 {
|
||||
newSlot, err := c.allocSlot()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
ctx := c.backend.NewContext()
|
||||
c.copyRecurrentState(ctx, slot, newSlot)
|
||||
c.copyCheckpoints(ctx, slot, newSlot)
|
||||
if len(c.convStates) > 0 || len(c.recurrentStates) > 0 {
|
||||
ctx.Compute()
|
||||
}
|
||||
ctx.Close()
|
||||
|
||||
c.refCount[slot]--
|
||||
c.refCount[newSlot] = 1
|
||||
c.slotForSeq[seq] = newSlot
|
||||
|
||||
restore.slot = newSlot
|
||||
c.pendingRestore[seq] = restore
|
||||
}
|
||||
}
|
||||
|
||||
if err := c.kv.Remove(seq, beginIndex, endIndex); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if beginIndex > 0 {
|
||||
restore := c.pendingRestore[seq]
|
||||
delete(c.pendingRestore, seq)
|
||||
return c.applyCheckpointRestore(restore)
|
||||
}
|
||||
|
||||
slot, ok := c.slotForSeq[seq]
|
||||
delete(c.pendingRestore, seq)
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
|
||||
if !c.validSlot(slot) {
|
||||
delete(c.slotForSeq, seq)
|
||||
return nil
|
||||
}
|
||||
|
||||
c.refCount[slot]--
|
||||
if c.refCount[slot] <= 0 {
|
||||
c.refCount[slot] = 0
|
||||
c.clearCheckpoints(slot)
|
||||
c.freeSlot(slot)
|
||||
}
|
||||
delete(c.slotForSeq, seq)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Recurrent) validSlot(slot int) bool {
|
||||
return slot >= 0 && slot < len(c.refCount)
|
||||
}
|
||||
|
||||
func (c *Recurrent) SlotsTensor() ml.Tensor {
|
||||
return c.curSlotsInput
|
||||
}
|
||||
|
||||
// contiguousSlots returns the starting slot if current slots are contiguous and ordered.
|
||||
func (c *Recurrent) contiguousSlots() (int, bool) {
|
||||
if len(c.curSlots) == 0 {
|
||||
return 0, false
|
||||
}
|
||||
start := c.curSlots[0]
|
||||
for i, s := range c.curSlots {
|
||||
if s != start+i {
|
||||
return 0, false
|
||||
}
|
||||
}
|
||||
return start, true
|
||||
}
|
||||
|
||||
func (c *Recurrent) SeqTokens() int {
|
||||
return c.curSeqTokens
|
||||
}
|
||||
|
||||
func (c *Recurrent) NumSeqs() int {
|
||||
return len(c.curSeqs)
|
||||
}
|
||||
|
||||
func (c *Recurrent) convBuffer(layer int) ml.Tensor {
|
||||
if buf, ok := c.convStates[layer]; ok {
|
||||
return buf
|
||||
}
|
||||
|
||||
if _, ok := c.convCtxs[layer]; !ok {
|
||||
c.convCtxs[layer] = c.backend.NewContextSize(1).Layer(layer)
|
||||
}
|
||||
|
||||
buf := c.convCtxs[layer].Zeros(ml.DTypeF32, c.convDim*c.convChannels, c.maxSequences)
|
||||
c.convStates[layer] = buf
|
||||
return buf
|
||||
}
|
||||
|
||||
func (c *Recurrent) recurrentBuffer(layer int) ml.Tensor {
|
||||
if buf, ok := c.recurrentStates[layer]; ok {
|
||||
return buf
|
||||
}
|
||||
|
||||
if _, ok := c.recurrentCtxs[layer]; !ok {
|
||||
c.recurrentCtxs[layer] = c.backend.NewContextSize(1).Layer(layer)
|
||||
}
|
||||
|
||||
buf := c.recurrentCtxs[layer].Zeros(ml.DTypeF32, c.recurrentStateSize, c.maxSequences)
|
||||
c.recurrentStates[layer] = buf
|
||||
return buf
|
||||
}
|
||||
|
||||
func (c *Recurrent) ensureWritable(ctx ml.Context) error {
|
||||
c.ensureWritableOnce(ctx)
|
||||
return c.writableError
|
||||
}
|
||||
|
||||
func (c *Recurrent) currentSlotRows(ctx ml.Context, buf ml.Tensor, rowSize int) ml.Tensor {
|
||||
if start, ok := c.contiguousSlots(); ok {
|
||||
offset := start * buf.Stride(1)
|
||||
return buf.View(ctx, offset, rowSize, buf.Stride(1), c.NumSeqs())
|
||||
}
|
||||
|
||||
return buf.Rows(ctx, c.SlotsTensor())
|
||||
}
|
||||
|
||||
func (c *Recurrent) writeCurrentSlotRows(ctx ml.Context, buf ml.Tensor, rowSize int, src ml.Tensor) {
|
||||
if start, ok := c.contiguousSlots(); ok {
|
||||
offset := start * buf.Stride(1)
|
||||
view := buf.View(ctx, offset, rowSize, buf.Stride(1), c.NumSeqs())
|
||||
ctx.Forward(src.Copy(ctx, view))
|
||||
return
|
||||
}
|
||||
|
||||
ctx.Forward(buf.SetRows(ctx, src, c.SlotsTensor()))
|
||||
}
|
||||
|
||||
func (c *Recurrent) ensureWritableOnce(ctx ml.Context) {
|
||||
if !c.writableEnsured {
|
||||
needsWritable := false
|
||||
for _, seq := range c.curSeqs {
|
||||
slot, ok := c.slotForSeq[seq]
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
if slot >= 0 && slot < len(c.refCount) && c.refCount[slot] > 1 {
|
||||
needsWritable = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if needsWritable {
|
||||
if err := c.EnsureWritable(ctx); err != nil {
|
||||
c.writableError = err
|
||||
}
|
||||
}
|
||||
c.writableEnsured = true
|
||||
}
|
||||
}
|
||||
|
||||
// ConvState returns conv state for current batch sequences as [convDim, convChannels, nSeqs].
|
||||
func (c *Recurrent) ConvState(ctx ml.Context, layer int) (ml.Tensor, error) {
|
||||
if err := c.ensureWritable(ctx); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
buf := c.convBuffer(layer)
|
||||
cur := c.currentSlotRows(ctx, buf, c.convDim*c.convChannels)
|
||||
return cur.Reshape(ctx, c.convDim, c.convChannels, c.NumSeqs()), nil
|
||||
}
|
||||
|
||||
// UpdateConvState writes new conv state for current batch sequences.
|
||||
func (c *Recurrent) UpdateConvState(ctx ml.Context, layer int, newState ml.Tensor) {
|
||||
buf := c.convBuffer(layer)
|
||||
src := newState.Reshape(ctx, c.convDim*c.convChannels, c.NumSeqs())
|
||||
srcF32 := src
|
||||
if src.DType() != ml.DTypeF32 {
|
||||
srcF32 = src.Cast(ctx, ml.DTypeF32)
|
||||
}
|
||||
c.writeCurrentSlotRows(ctx, buf, c.convDim*c.convChannels, srcF32)
|
||||
|
||||
c.captureConvCheckpoint(ctx, layer, srcF32)
|
||||
}
|
||||
|
||||
// RecurrentState returns recurrent state for current batch sequences with shape [dims..., nSeqs].
|
||||
func (c *Recurrent) RecurrentState(ctx ml.Context, layer int, dims ...int) (ml.Tensor, error) {
|
||||
if err := c.ensureWritable(ctx); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(dims) == 0 {
|
||||
return nil, ErrInvalidRecurrentShape
|
||||
}
|
||||
|
||||
size := 1
|
||||
for _, d := range dims {
|
||||
if d <= 0 {
|
||||
return nil, ErrInvalidRecurrentShape
|
||||
}
|
||||
size *= d
|
||||
}
|
||||
if size != c.recurrentStateSize {
|
||||
return nil, fmt.Errorf("%w: got %v (size %d), want size %d", ErrInvalidRecurrentShape, dims, size, c.recurrentStateSize)
|
||||
}
|
||||
|
||||
buf := c.recurrentBuffer(layer)
|
||||
cur := c.currentSlotRows(ctx, buf, c.recurrentStateSize)
|
||||
shape := make([]int, 0, len(dims)+1)
|
||||
shape = append(shape, dims...)
|
||||
shape = append(shape, c.NumSeqs())
|
||||
return cur.Reshape(ctx, shape...), nil
|
||||
}
|
||||
|
||||
// RecurrentState4D returns recurrent state as [dim0, dim1, dim2, nSeqs].
|
||||
func (c *Recurrent) RecurrentState4D(ctx ml.Context, layer int, dim0, dim1, dim2 int) (ml.Tensor, error) {
|
||||
if err := c.ensureWritable(ctx); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if dim0 <= 0 || dim1 <= 0 || dim2 <= 0 {
|
||||
return nil, ErrInvalidRecurrentShape
|
||||
}
|
||||
|
||||
size := dim0 * dim1 * dim2
|
||||
if size != c.recurrentStateSize {
|
||||
return nil, fmt.Errorf("%w: got [%d %d %d] (size %d), want size %d", ErrInvalidRecurrentShape, dim0, dim1, dim2, size, c.recurrentStateSize)
|
||||
}
|
||||
|
||||
buf := c.recurrentBuffer(layer)
|
||||
cur := c.currentSlotRows(ctx, buf, c.recurrentStateSize)
|
||||
return cur.Reshape(ctx, dim0, dim1, dim2, c.NumSeqs()), nil
|
||||
}
|
||||
|
||||
// UpdateRecurrentState writes new recurrent state for current batch sequences.
|
||||
func (c *Recurrent) UpdateRecurrentState(ctx ml.Context, layer int, newState ml.Tensor) {
|
||||
buf := c.recurrentBuffer(layer)
|
||||
src := newState.Reshape(ctx, c.recurrentStateSize, c.NumSeqs())
|
||||
srcF32 := src
|
||||
if src.DType() != ml.DTypeF32 {
|
||||
srcF32 = src.Cast(ctx, ml.DTypeF32)
|
||||
}
|
||||
c.writeCurrentSlotRows(ctx, buf, c.recurrentStateSize, srcF32)
|
||||
|
||||
c.captureRecurrentCheckpoint(ctx, layer, srcF32)
|
||||
}
|
||||
|
||||
// IsSupportedForBatch returns true if the current batch layout supports recurrent layers.
|
||||
func (c *Recurrent) IsSupportedForBatch() bool {
|
||||
return c.curSeqTokens > 0 && len(c.curSeqs) > 0
|
||||
}
|
||||
|
||||
// Seqs returns the ordered unique sequences for the current forward pass.
|
||||
func (c *Recurrent) Seqs() []int {
|
||||
return slices.Clone(c.curSeqs)
|
||||
}
|
||||
561
kvcache/recurrent_checkpoints.go
Normal file
561
kvcache/recurrent_checkpoints.go
Normal file
@@ -0,0 +1,561 @@
|
||||
package kvcache
|
||||
|
||||
import (
|
||||
"log/slog"
|
||||
"math"
|
||||
|
||||
"github.com/ollama/ollama/ml"
|
||||
"github.com/ollama/ollama/model/input"
|
||||
)
|
||||
|
||||
// TODO(jmorganca): Add byte-serialized host-RAM checkpoints to reduce GPU
|
||||
// memory usage while preserving prefix reuse for recurrent state.
|
||||
|
||||
type checkpointEntry struct {
|
||||
pos int32
|
||||
conv map[int]ml.Tensor
|
||||
recurrent map[int]ml.Tensor
|
||||
}
|
||||
|
||||
type slotCheckpointStore struct {
|
||||
entries []checkpointEntry
|
||||
size int
|
||||
next int
|
||||
lastPos int32
|
||||
}
|
||||
|
||||
type checkpointRestore struct {
|
||||
slot int
|
||||
idx int
|
||||
pos int32
|
||||
}
|
||||
|
||||
func newSlotCheckpointStore(n int) *slotCheckpointStore {
|
||||
entries := make([]checkpointEntry, n)
|
||||
for i := range entries {
|
||||
entries[i].pos = -1
|
||||
}
|
||||
return &slotCheckpointStore{
|
||||
entries: entries,
|
||||
lastPos: -1,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *slotCheckpointStore) reset() {
|
||||
s.size = 0
|
||||
s.next = 0
|
||||
s.lastPos = -1
|
||||
for i := range s.entries {
|
||||
s.entries[i].pos = -1
|
||||
}
|
||||
}
|
||||
|
||||
func (s *slotCheckpointStore) record(pos int32) int {
|
||||
if len(s.entries) == 0 {
|
||||
return -1
|
||||
}
|
||||
idx := s.next
|
||||
s.next = (s.next + 1) % len(s.entries)
|
||||
if s.size < len(s.entries) {
|
||||
s.size++
|
||||
}
|
||||
s.entries[idx].pos = pos
|
||||
s.lastPos = pos
|
||||
return idx
|
||||
}
|
||||
|
||||
func (s *slotCheckpointStore) bestIndex(targetPos int32) (int, int32, bool) {
|
||||
bestIdx := -1
|
||||
bestPos := int32(-1)
|
||||
for i := range s.entries {
|
||||
pos := s.entries[i].pos
|
||||
if pos < 0 || pos >= targetPos {
|
||||
continue
|
||||
}
|
||||
if pos > bestPos {
|
||||
bestPos = pos
|
||||
bestIdx = i
|
||||
}
|
||||
}
|
||||
if bestIdx < 0 {
|
||||
return -1, -1, false
|
||||
}
|
||||
return bestIdx, bestPos, true
|
||||
}
|
||||
|
||||
func (s *slotCheckpointStore) pruneAfter(pos int32) {
|
||||
if len(s.entries) == 0 {
|
||||
s.size = 0
|
||||
s.next = 0
|
||||
s.lastPos = -1
|
||||
return
|
||||
}
|
||||
|
||||
size := 0
|
||||
next := -1
|
||||
minPos := int32(math.MaxInt32)
|
||||
minIdx := 0
|
||||
for i := range s.entries {
|
||||
if s.entries[i].pos > pos {
|
||||
s.entries[i].pos = -1
|
||||
}
|
||||
if s.entries[i].pos >= 0 {
|
||||
size++
|
||||
if s.entries[i].pos < minPos {
|
||||
minPos = s.entries[i].pos
|
||||
minIdx = i
|
||||
}
|
||||
} else if next == -1 {
|
||||
next = i
|
||||
}
|
||||
}
|
||||
|
||||
s.size = size
|
||||
if size == 0 {
|
||||
s.next = 0
|
||||
s.lastPos = -1
|
||||
return
|
||||
}
|
||||
if next != -1 {
|
||||
s.next = next
|
||||
} else {
|
||||
// Full ring: overwrite the oldest checkpoint next.
|
||||
s.next = minIdx
|
||||
}
|
||||
s.lastPos = pos
|
||||
}
|
||||
|
||||
func (s *slotCheckpointStore) shiftRange(beginIndex, endIndex int32) {
|
||||
if len(s.entries) == 0 {
|
||||
s.size = 0
|
||||
s.next = 0
|
||||
s.lastPos = -1
|
||||
return
|
||||
}
|
||||
|
||||
offset := beginIndex - endIndex
|
||||
|
||||
size := 0
|
||||
next := -1
|
||||
minPos := int32(math.MaxInt32)
|
||||
maxPos := int32(-1)
|
||||
minIdx := 0
|
||||
|
||||
for i := range s.entries {
|
||||
pos := s.entries[i].pos
|
||||
if pos >= 0 {
|
||||
if pos >= beginIndex && pos < endIndex {
|
||||
s.entries[i].pos = -1
|
||||
} else if pos >= endIndex {
|
||||
s.entries[i].pos = pos + offset
|
||||
}
|
||||
}
|
||||
|
||||
pos = s.entries[i].pos
|
||||
if pos >= 0 {
|
||||
size++
|
||||
if pos < minPos {
|
||||
minPos = pos
|
||||
minIdx = i
|
||||
}
|
||||
if pos > maxPos {
|
||||
maxPos = pos
|
||||
}
|
||||
} else if next == -1 {
|
||||
next = i
|
||||
}
|
||||
}
|
||||
|
||||
s.size = size
|
||||
if size == 0 {
|
||||
s.next = 0
|
||||
s.lastPos = -1
|
||||
return
|
||||
}
|
||||
|
||||
if next != -1 {
|
||||
s.next = next
|
||||
} else {
|
||||
// Full ring: overwrite the oldest checkpoint next.
|
||||
s.next = minIdx
|
||||
}
|
||||
s.lastPos = maxPos
|
||||
}
|
||||
|
||||
func (s *slotCheckpointStore) window() (size int, minPos, maxPos, lastPos int32) {
|
||||
minPos = int32(math.MaxInt32)
|
||||
maxPos = int32(-1)
|
||||
for i := range s.entries {
|
||||
pos := s.entries[i].pos
|
||||
if pos < 0 {
|
||||
continue
|
||||
}
|
||||
size++
|
||||
if pos < minPos {
|
||||
minPos = pos
|
||||
}
|
||||
if pos > maxPos {
|
||||
maxPos = pos
|
||||
}
|
||||
}
|
||||
if size == 0 {
|
||||
minPos = -1
|
||||
maxPos = -1
|
||||
}
|
||||
return size, minPos, maxPos, s.lastPos
|
||||
}
|
||||
|
||||
func (c *Recurrent) checkpointTag() string {
|
||||
if c.logPrefix == "" {
|
||||
return "kvcache.recurrent"
|
||||
}
|
||||
return c.logPrefix
|
||||
}
|
||||
|
||||
func (c *Recurrent) planCheckpoints(batch input.Batch) {
|
||||
if c.checkpointCount == 0 || len(c.curSeqs) == 0 {
|
||||
c.curCheckpointPos = c.curCheckpointPos[:0]
|
||||
for k := range c.curCheckpointSlots {
|
||||
delete(c.curCheckpointSlots, k)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if cap(c.curCheckpointPos) < len(c.curSeqs) {
|
||||
c.curCheckpointPos = make([]int32, len(c.curSeqs))
|
||||
} else {
|
||||
c.curCheckpointPos = c.curCheckpointPos[:len(c.curSeqs)]
|
||||
}
|
||||
for i := range c.curCheckpointPos {
|
||||
c.curCheckpointPos[i] = -1
|
||||
}
|
||||
for k := range c.curCheckpointSlots {
|
||||
delete(c.curCheckpointSlots, k)
|
||||
}
|
||||
|
||||
posMax := make(map[int]int32, len(c.curSeqs))
|
||||
for i, seq := range batch.Sequences {
|
||||
pos := batch.Positions[i]
|
||||
if cur, ok := posMax[seq]; !ok || pos > cur {
|
||||
posMax[seq] = pos
|
||||
}
|
||||
}
|
||||
|
||||
for i, seq := range c.curSeqs {
|
||||
pos, ok := posMax[seq]
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
if pos < c.checkpointMinPos {
|
||||
continue
|
||||
}
|
||||
slot := c.curSlots[i]
|
||||
store := c.checkpointStore(slot)
|
||||
lastPos := store.lastPos
|
||||
if lastPos < 0 || pos-lastPos >= c.checkpointInterval {
|
||||
c.curCheckpointPos[i] = pos
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Recurrent) checkpointStore(slot int) *slotCheckpointStore {
|
||||
store, ok := c.checkpoints[slot]
|
||||
if ok {
|
||||
return store
|
||||
}
|
||||
store = newSlotCheckpointStore(c.checkpointCount)
|
||||
c.checkpoints[slot] = store
|
||||
return store
|
||||
}
|
||||
|
||||
func (c *Recurrent) checkpointIndexForSlot(slot int, pos int32) int {
|
||||
if c.checkpointCount == 0 {
|
||||
return -1
|
||||
}
|
||||
if idx, ok := c.curCheckpointSlots[slot]; ok {
|
||||
return idx
|
||||
}
|
||||
store := c.checkpointStore(slot)
|
||||
idx := store.record(pos)
|
||||
if idx >= 0 {
|
||||
c.curCheckpointSlots[slot] = idx
|
||||
}
|
||||
return idx
|
||||
}
|
||||
|
||||
func (c *Recurrent) hasCheckpoint(seq int, pos int32) bool {
|
||||
if pos <= 0 {
|
||||
return false
|
||||
}
|
||||
slot, ok := c.slotForSeq[seq]
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
store, ok := c.checkpoints[slot]
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
_, _, ok = store.bestIndex(pos)
|
||||
return ok
|
||||
}
|
||||
|
||||
func (c *Recurrent) PrepareRestore(seq int, targetPos int32) (int32, bool) {
|
||||
if targetPos <= 0 {
|
||||
return 0, false
|
||||
}
|
||||
slot, ok := c.slotForSeq[seq]
|
||||
if !ok {
|
||||
return 0, false
|
||||
}
|
||||
store, ok := c.checkpoints[slot]
|
||||
if !ok {
|
||||
slog.Debug(c.checkpointTag()+": checkpoint miss", "seq", seq, "slot", slot, "target", targetPos, "size", 0)
|
||||
return 0, false
|
||||
}
|
||||
idx, pos, ok := store.bestIndex(targetPos)
|
||||
if !ok {
|
||||
size, minPos, maxPos, lastPos := store.window()
|
||||
slog.Debug(c.checkpointTag()+": checkpoint miss", "seq", seq, "slot", slot, "target", targetPos, "size", size,
|
||||
"min", minPos, "max", maxPos, "last", lastPos)
|
||||
return 0, false
|
||||
}
|
||||
c.pendingRestore[seq] = checkpointRestore{
|
||||
slot: slot,
|
||||
idx: idx,
|
||||
pos: pos,
|
||||
}
|
||||
return pos + 1, true
|
||||
}
|
||||
|
||||
func (c *Recurrent) applyCheckpointRestore(restore checkpointRestore) error {
|
||||
entry, ok := c.restoreEntry(restore)
|
||||
if !ok {
|
||||
return ErrNotSupported
|
||||
}
|
||||
|
||||
ctx := c.backend.NewContext()
|
||||
defer ctx.Close()
|
||||
|
||||
slotIdx := ctx.Input().FromInts([]int32{int32(restore.slot)}, 1)
|
||||
for layer, src := range entry.conv {
|
||||
buf := c.convBuffer(layer)
|
||||
ctx.Forward(buf.SetRows(ctx, src, slotIdx))
|
||||
}
|
||||
for layer, src := range entry.recurrent {
|
||||
buf := c.recurrentBuffer(layer)
|
||||
ctx.Forward(buf.SetRows(ctx, src, slotIdx))
|
||||
}
|
||||
|
||||
if len(entry.conv) > 0 || len(entry.recurrent) > 0 {
|
||||
ctx.Compute()
|
||||
}
|
||||
store := c.checkpoints[restore.slot]
|
||||
store.pruneAfter(restore.pos)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Recurrent) restoreComplete(restore checkpointRestore) bool {
|
||||
_, ok := c.restoreEntry(restore)
|
||||
return ok
|
||||
}
|
||||
|
||||
func (c *Recurrent) restoreEntry(restore checkpointRestore) (*checkpointEntry, bool) {
|
||||
store, ok := c.checkpoints[restore.slot]
|
||||
if !ok || restore.idx < 0 || restore.idx >= len(store.entries) {
|
||||
return nil, false
|
||||
}
|
||||
entry := &store.entries[restore.idx]
|
||||
if entry.pos < 0 {
|
||||
return nil, false
|
||||
}
|
||||
if !c.entryComplete(entry) {
|
||||
return nil, false
|
||||
}
|
||||
return entry, true
|
||||
}
|
||||
|
||||
func (c *Recurrent) entryComplete(entry *checkpointEntry) bool {
|
||||
for layer := range c.convStates {
|
||||
if entry.conv == nil || entry.conv[layer] == nil {
|
||||
return false
|
||||
}
|
||||
}
|
||||
for layer := range c.recurrentStates {
|
||||
if entry.recurrent == nil || entry.recurrent[layer] == nil {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func (c *Recurrent) clearCheckpoints(slot int) {
|
||||
if store, ok := c.checkpoints[slot]; ok {
|
||||
store.reset()
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Recurrent) shiftCheckpoints(slot int, beginIndex, endIndex int32) {
|
||||
if store, ok := c.checkpoints[slot]; ok {
|
||||
store.shiftRange(beginIndex, endIndex)
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Recurrent) copyCheckpoints(ctx ml.Context, srcSlot, dstSlot int) {
|
||||
if c.checkpointCount == 0 {
|
||||
return
|
||||
}
|
||||
srcStore, ok := c.checkpoints[srcSlot]
|
||||
if !ok || srcStore.size == 0 {
|
||||
return
|
||||
}
|
||||
dstStore := c.checkpointStore(dstSlot)
|
||||
dstStore.size = srcStore.size
|
||||
dstStore.next = srcStore.next
|
||||
dstStore.lastPos = srcStore.lastPos
|
||||
|
||||
for i := range srcStore.entries {
|
||||
srcEntry := &srcStore.entries[i]
|
||||
dstEntry := &dstStore.entries[i]
|
||||
dstEntry.pos = srcEntry.pos
|
||||
if srcEntry.conv != nil {
|
||||
if dstEntry.conv == nil {
|
||||
dstEntry.conv = make(map[int]ml.Tensor)
|
||||
}
|
||||
for layer, src := range srcEntry.conv {
|
||||
dst := c.ensureCheckpointConv(layer, dstEntry)
|
||||
ctx.Forward(src.Copy(ctx, dst))
|
||||
}
|
||||
}
|
||||
if srcEntry.recurrent != nil {
|
||||
if dstEntry.recurrent == nil {
|
||||
dstEntry.recurrent = make(map[int]ml.Tensor)
|
||||
}
|
||||
for layer, src := range srcEntry.recurrent {
|
||||
dst := c.ensureCheckpointRecurrent(layer, dstEntry)
|
||||
ctx.Forward(src.Copy(ctx, dst))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Recurrent) captureConvCheckpoint(ctx ml.Context, layer int, src ml.Tensor) {
|
||||
if c.checkpointCount == 0 {
|
||||
return
|
||||
}
|
||||
if c.reserveCheckpoints {
|
||||
c.reserveCheckpointConv(layer)
|
||||
return
|
||||
}
|
||||
if len(c.curCheckpointPos) == 0 {
|
||||
return
|
||||
}
|
||||
for i, pos := range c.curCheckpointPos {
|
||||
if pos < 0 {
|
||||
continue
|
||||
}
|
||||
slot := c.curSlots[i]
|
||||
idx := c.checkpointIndexForSlot(slot, pos)
|
||||
if idx < 0 {
|
||||
continue
|
||||
}
|
||||
entry := &c.checkpoints[slot].entries[idx]
|
||||
dst := c.ensureCheckpointConv(layer, entry)
|
||||
seqSlice := src.Slice(ctx, 1, i, i+1, 1)
|
||||
ctx.Forward(seqSlice.Copy(ctx, dst))
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Recurrent) captureRecurrentCheckpoint(ctx ml.Context, layer int, src ml.Tensor) {
|
||||
if c.checkpointCount == 0 {
|
||||
return
|
||||
}
|
||||
if c.reserveCheckpoints {
|
||||
c.reserveCheckpointRecurrent(layer)
|
||||
return
|
||||
}
|
||||
if len(c.curCheckpointPos) == 0 {
|
||||
return
|
||||
}
|
||||
for i, pos := range c.curCheckpointPos {
|
||||
if pos < 0 {
|
||||
continue
|
||||
}
|
||||
slot := c.curSlots[i]
|
||||
idx := c.checkpointIndexForSlot(slot, pos)
|
||||
if idx < 0 {
|
||||
continue
|
||||
}
|
||||
entry := &c.checkpoints[slot].entries[idx]
|
||||
dst := c.ensureCheckpointRecurrent(layer, entry)
|
||||
seqSlice := src.Slice(ctx, 1, i, i+1, 1)
|
||||
ctx.Forward(seqSlice.Copy(ctx, dst))
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Recurrent) ensureCheckpointConv(layer int, entry *checkpointEntry) ml.Tensor {
|
||||
if entry.conv == nil {
|
||||
entry.conv = make(map[int]ml.Tensor)
|
||||
}
|
||||
if t, ok := entry.conv[layer]; ok {
|
||||
return t
|
||||
}
|
||||
ctx, ok := c.checkpointConvCtxs[layer]
|
||||
if !ok {
|
||||
ctx = c.backend.NewContextSize(c.checkpointCtxSize).Layer(layer)
|
||||
c.checkpointConvCtxs[layer] = ctx
|
||||
}
|
||||
t := ctx.Zeros(ml.DTypeF32, c.convDim*c.convChannels, 1)
|
||||
entry.conv[layer] = t
|
||||
return t
|
||||
}
|
||||
|
||||
func (c *Recurrent) ensureCheckpointRecurrent(layer int, entry *checkpointEntry) ml.Tensor {
|
||||
if entry.recurrent == nil {
|
||||
entry.recurrent = make(map[int]ml.Tensor)
|
||||
}
|
||||
if t, ok := entry.recurrent[layer]; ok {
|
||||
return t
|
||||
}
|
||||
ctx, ok := c.checkpointRecurCtxs[layer]
|
||||
if !ok {
|
||||
ctx = c.backend.NewContextSize(c.checkpointCtxSize).Layer(layer)
|
||||
c.checkpointRecurCtxs[layer] = ctx
|
||||
}
|
||||
t := ctx.Zeros(ml.DTypeF32, c.recurrentStateSize, 1)
|
||||
entry.recurrent[layer] = t
|
||||
return t
|
||||
}
|
||||
|
||||
func (c *Recurrent) reserveCheckpointConv(layer int) {
|
||||
key := checkpointReserveKey(layer, 0)
|
||||
if _, ok := c.checkpointReserved[key]; ok {
|
||||
return
|
||||
}
|
||||
for slot := range c.maxSequences {
|
||||
store := c.checkpointStore(slot)
|
||||
for i := range store.entries {
|
||||
entry := &store.entries[i]
|
||||
_ = c.ensureCheckpointConv(layer, entry)
|
||||
}
|
||||
}
|
||||
c.checkpointReserved[key] = struct{}{}
|
||||
}
|
||||
|
||||
func (c *Recurrent) reserveCheckpointRecurrent(layer int) {
|
||||
key := checkpointReserveKey(layer, 1)
|
||||
if _, ok := c.checkpointReserved[key]; ok {
|
||||
return
|
||||
}
|
||||
for slot := range c.maxSequences {
|
||||
store := c.checkpointStore(slot)
|
||||
for i := range store.entries {
|
||||
entry := &store.entries[i]
|
||||
_ = c.ensureCheckpointRecurrent(layer, entry)
|
||||
}
|
||||
}
|
||||
c.checkpointReserved[key] = struct{}{}
|
||||
}
|
||||
|
||||
func checkpointReserveKey(layer int, kind int) int {
|
||||
return layer*2 + kind
|
||||
}
|
||||
288
kvcache/recurrent_checkpoints_test.go
Normal file
288
kvcache/recurrent_checkpoints_test.go
Normal file
@@ -0,0 +1,288 @@
|
||||
package kvcache
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"math"
|
||||
"slices"
|
||||
"testing"
|
||||
|
||||
"github.com/ollama/ollama/ml"
|
||||
)
|
||||
|
||||
func newTestCache() *Recurrent {
|
||||
return NewRecurrentCache(RecurrentConfig{ConvDim: 1, ConvChannels: 2, RecurrentStateSize: 2})
|
||||
}
|
||||
|
||||
func TestSlotCheckpointStoreBestIndex(t *testing.T) {
|
||||
store := newSlotCheckpointStore(2)
|
||||
store.record(10)
|
||||
store.record(20)
|
||||
|
||||
_, pos, ok := store.bestIndex(15)
|
||||
if !ok || pos != 10 {
|
||||
t.Fatalf("expected best pos 10, got pos=%d ok=%v", pos, ok)
|
||||
}
|
||||
|
||||
store.record(30) // overwrite oldest (10)
|
||||
|
||||
if _, _, ok := store.bestIndex(15); ok {
|
||||
t.Fatalf("expected no checkpoint for targetPos=15 after overwrite")
|
||||
}
|
||||
|
||||
_, pos, ok = store.bestIndex(40)
|
||||
if !ok || pos != 30 {
|
||||
t.Fatalf("expected best pos 30, got pos=%d ok=%v", pos, ok)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCachePrepareRestore(t *testing.T) {
|
||||
cache := newTestCache()
|
||||
cache.checkpointCount = 3
|
||||
cache.checkpoints = make(map[int]*slotCheckpointStore)
|
||||
cache.pendingRestore = make(map[int]checkpointRestore)
|
||||
|
||||
cache.slotForSeq[1] = 0
|
||||
store := cache.checkpointStore(0)
|
||||
store.record(5)
|
||||
store.record(9)
|
||||
store.record(15)
|
||||
|
||||
restorePos, ok := cache.PrepareRestore(1, 12)
|
||||
if !ok {
|
||||
t.Fatalf("expected restore ok")
|
||||
}
|
||||
if restorePos != 10 {
|
||||
t.Fatalf("expected restorePos 10, got %d", restorePos)
|
||||
}
|
||||
rest, ok := cache.pendingRestore[1]
|
||||
if !ok {
|
||||
t.Fatalf("expected pending restore entry")
|
||||
}
|
||||
if rest.pos != 9 {
|
||||
t.Fatalf("expected pending restore pos 9, got %d", rest.pos)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSlotCheckpointStorePruneAfter(t *testing.T) {
|
||||
store := newSlotCheckpointStore(3)
|
||||
store.record(10)
|
||||
store.record(20)
|
||||
store.record(30)
|
||||
|
||||
store.pruneAfter(20)
|
||||
|
||||
if store.lastPos != 20 {
|
||||
t.Fatalf("expected lastPos 20, got %d", store.lastPos)
|
||||
}
|
||||
|
||||
_, pos, ok := store.bestIndex(25)
|
||||
if !ok || pos != 20 {
|
||||
t.Fatalf("expected best pos 20 after prune, got pos=%d ok=%v", pos, ok)
|
||||
}
|
||||
|
||||
_, pos, ok = store.bestIndex(35)
|
||||
if !ok || pos != 20 {
|
||||
t.Fatalf("expected pruned best pos 20 for targetPos=35, got pos=%d ok=%v", pos, ok)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCacheRestoreRejectsIncompleteCheckpoint(t *testing.T) {
|
||||
cache := newTestCache()
|
||||
cache.checkpointCount = 3
|
||||
cache.checkpoints = make(map[int]*slotCheckpointStore)
|
||||
cache.pendingRestore = make(map[int]checkpointRestore)
|
||||
|
||||
cache.slotForSeq[1] = 0
|
||||
cache.refCount = []int{1}
|
||||
cache.freeSlots = nil
|
||||
|
||||
// Simulate layer 0 requires both conv and recurrent checkpoints.
|
||||
cache.convStates[0] = nil
|
||||
cache.recurrentStates[0] = nil
|
||||
|
||||
store := cache.checkpointStore(0)
|
||||
idx := store.record(9)
|
||||
entry := &store.entries[idx]
|
||||
entry.conv = map[int]ml.Tensor{0: nil}
|
||||
// entry.recurrent intentionally missing
|
||||
|
||||
cache.pendingRestore[1] = checkpointRestore{slot: 0, idx: idx, pos: 9}
|
||||
|
||||
err := cache.Remove(1, 10, math.MaxInt32)
|
||||
if !errors.Is(err, ErrNotSupported) {
|
||||
t.Fatalf("expected ErrNotSupported for incomplete checkpoint, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCacheRestoreAcceptsCompleteCheckpoint(t *testing.T) {
|
||||
cache := newTestCache()
|
||||
cache.checkpointCount = 3
|
||||
cache.checkpoints = make(map[int]*slotCheckpointStore)
|
||||
cache.pendingRestore = make(map[int]checkpointRestore)
|
||||
|
||||
cache.slotForSeq[1] = 0
|
||||
cache.refCount = []int{1}
|
||||
cache.freeSlots = nil
|
||||
|
||||
store := cache.checkpointStore(0)
|
||||
idx := store.record(9)
|
||||
|
||||
cache.pendingRestore[1] = checkpointRestore{slot: 0, idx: idx, pos: 9}
|
||||
|
||||
restore := cache.pendingRestore[1]
|
||||
if !cache.restoreComplete(restore) {
|
||||
t.Fatalf("expected restoreComplete to return true for complete checkpoint")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCacheRecurrentStateShapeValidation(t *testing.T) {
|
||||
cache := newTestCache()
|
||||
_, err := cache.RecurrentState(nil, 0, 3)
|
||||
if !errors.Is(err, ErrInvalidRecurrentShape) {
|
||||
t.Fatalf("expected ErrInvalidRecurrentShape, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSlotCheckpointStoreShiftRange(t *testing.T) {
|
||||
store := newSlotCheckpointStore(5)
|
||||
store.record(1)
|
||||
store.record(4)
|
||||
store.record(7)
|
||||
store.record(10)
|
||||
|
||||
store.shiftRange(2, 6)
|
||||
|
||||
var positions []int32
|
||||
for i := range store.entries {
|
||||
if store.entries[i].pos >= 0 {
|
||||
positions = append(positions, store.entries[i].pos)
|
||||
}
|
||||
}
|
||||
slices.Sort(positions)
|
||||
|
||||
want := []int32{1, 3, 6}
|
||||
if !slices.Equal(positions, want) {
|
||||
t.Fatalf("unexpected shifted positions: got=%v want=%v", positions, want)
|
||||
}
|
||||
if store.lastPos != 6 {
|
||||
t.Fatalf("expected lastPos 6, got %d", store.lastPos)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCacheRemoveMiddleShiftsCheckpoints(t *testing.T) {
|
||||
cache := newTestCache()
|
||||
cache.slotForSeq[1] = 0
|
||||
cache.refCount = []int{1}
|
||||
cache.pendingRestore[1] = checkpointRestore{slot: 0, idx: 0, pos: 1}
|
||||
|
||||
store := cache.checkpointStore(0)
|
||||
store.record(1)
|
||||
store.record(4)
|
||||
store.record(7)
|
||||
store.record(10)
|
||||
|
||||
if err := cache.Remove(1, 2, 6); err != nil {
|
||||
t.Fatalf("expected middle remove to succeed, got %v", err)
|
||||
}
|
||||
|
||||
if _, ok := cache.pendingRestore[1]; ok {
|
||||
t.Fatalf("expected pending restore to be cleared after middle remove")
|
||||
}
|
||||
|
||||
var positions []int32
|
||||
for i := range store.entries {
|
||||
if store.entries[i].pos >= 0 {
|
||||
positions = append(positions, store.entries[i].pos)
|
||||
}
|
||||
}
|
||||
slices.Sort(positions)
|
||||
|
||||
want := []int32{1, 3, 6}
|
||||
if !slices.Equal(positions, want) {
|
||||
t.Fatalf("unexpected checkpoint positions after remove: got=%v want=%v", positions, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSlotCheckpointStoreRingBufferWrapAround(t *testing.T) {
|
||||
store := newSlotCheckpointStore(3)
|
||||
|
||||
store.record(10)
|
||||
store.record(20)
|
||||
store.record(30)
|
||||
|
||||
store.entries[0].conv = make(map[int]ml.Tensor)
|
||||
store.entries[0].conv[0] = nil
|
||||
store.entries[0].recurrent = make(map[int]ml.Tensor)
|
||||
store.entries[0].recurrent[0] = nil
|
||||
|
||||
store.record(40)
|
||||
|
||||
if store.entries[0].conv == nil {
|
||||
t.Fatalf("expected conv map to be preserved on reuse")
|
||||
}
|
||||
if store.entries[0].recurrent == nil {
|
||||
t.Fatalf("expected recurrent map to be preserved on reuse")
|
||||
}
|
||||
if store.entries[0].pos != 40 {
|
||||
t.Fatalf("expected entry 0 pos to be 40, got %d", store.entries[0].pos)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSlotCheckpointStoreFullCapacity(t *testing.T) {
|
||||
store := newSlotCheckpointStore(2)
|
||||
|
||||
idx1 := store.record(10)
|
||||
idx2 := store.record(20)
|
||||
|
||||
if idx1 != 0 || idx2 != 1 {
|
||||
t.Fatalf("expected indices 0, 1, got %d, %d", idx1, idx2)
|
||||
}
|
||||
if store.size != 2 {
|
||||
t.Fatalf("expected size 2, got %d", store.size)
|
||||
}
|
||||
|
||||
_, pos1, ok1 := store.bestIndex(15)
|
||||
_, pos2, ok2 := store.bestIndex(25)
|
||||
|
||||
if !ok1 || pos1 != 10 {
|
||||
t.Fatalf("expected best pos 10 for target 15, got pos=%d ok=%v", pos1, ok1)
|
||||
}
|
||||
if !ok2 || pos2 != 20 {
|
||||
t.Fatalf("expected best pos 20 for target 25, got pos=%d ok=%v", pos2, ok2)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSlotCheckpointStoreEmptyBuffer(t *testing.T) {
|
||||
store := newSlotCheckpointStore(0)
|
||||
|
||||
idx := store.record(10)
|
||||
if idx != -1 {
|
||||
t.Fatalf("expected record to return -1 for empty buffer, got %d", idx)
|
||||
}
|
||||
|
||||
_, _, ok := store.bestIndex(15)
|
||||
if ok {
|
||||
t.Fatalf("expected no checkpoint for empty buffer")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSlotCheckpointStorePruneAfterAll(t *testing.T) {
|
||||
store := newSlotCheckpointStore(3)
|
||||
store.record(10)
|
||||
store.record(20)
|
||||
store.record(30)
|
||||
|
||||
store.pruneAfter(5)
|
||||
|
||||
if store.size != 0 {
|
||||
t.Fatalf("expected size 0 after pruning all, got %d", store.size)
|
||||
}
|
||||
if store.lastPos != -1 {
|
||||
t.Fatalf("expected lastPos -1 after pruning all, got %d", store.lastPos)
|
||||
}
|
||||
|
||||
_, _, ok := store.bestIndex(100)
|
||||
if ok {
|
||||
t.Fatalf("expected no checkpoint after pruning all")
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,37 @@
|
||||
From 0000000000000000000000000000000000000000 Mon Sep 17 00:00:00 2001
|
||||
From: jmorganca <jmorganca@gmail.com>
|
||||
Date: Sun, 22 Feb 2026 14:12:30 -0800
|
||||
Subject: [PATCH] ggml-metal: guard mul_mat_id map0 and add ne20=22
|
||||
specialization
|
||||
|
||||
---
|
||||
ggml/src/ggml-metal/ggml-metal-ops.cpp | 3 ++-
|
||||
ggml/src/ggml-metal/ggml-metal.metal | 1 +
|
||||
2 files changed, 3 insertions(+), 1 deletion(-)
|
||||
|
||||
diff --git a/ggml/src/ggml-metal/ggml-metal-ops.cpp b/ggml/src/ggml-metal/ggml-metal-ops.cpp
|
||||
index 4ac135603..ac5ad53db 100644
|
||||
--- a/ggml/src/ggml-metal/ggml-metal-ops.cpp
|
||||
+++ b/ggml/src/ggml-metal/ggml-metal-ops.cpp
|
||||
@@ -1961,7 +1961,8 @@ int ggml_metal_op_mul_mat_id(ggml_metal_op_t ctx, int idx) {
|
||||
// ne21 = n_rows (batch size)
|
||||
const int ne21_mm_id_min = 32;
|
||||
|
||||
- if (props_dev->has_simdgroup_mm && ne00 >= 64 && (ne21 >= ne21_mm_id_min)) {
|
||||
+ if (props_dev->has_simdgroup_mm && ne00 >= 64 && (ne21 >= ne21_mm_id_min) &&
|
||||
+ (ne20 == 1 || ne20 == 2 || ne20 == 4 || ne20 == 6 || ne20 == 8 || ne20 == 10 || ne20 == 16 || ne20 == 22)) {
|
||||
// some Metal matrix data types require aligned pointers
|
||||
// ref: https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf (Table 2.5)
|
||||
//switch (op->src[0]->type) {
|
||||
diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal
|
||||
index c37447a10..4f338aa13 100644
|
||||
--- a/ggml/src/ggml-metal/ggml-metal.metal
|
||||
+++ b/ggml/src/ggml-metal/ggml-metal.metal
|
||||
@@ -9427,6 +9427,7 @@ template [[host_name("kernel_mul_mm_id_map0_ne20_6" )]] kernel kernel_mul_mm_id_
|
||||
template [[host_name("kernel_mul_mm_id_map0_ne20_8" )]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<8>;
|
||||
template [[host_name("kernel_mul_mm_id_map0_ne20_10")]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<10>;
|
||||
template [[host_name("kernel_mul_mm_id_map0_ne20_16")]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<16>;
|
||||
+template [[host_name("kernel_mul_mm_id_map0_ne20_22")]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<22>;
|
||||
|
||||
template<typename S0, typename S0_4x4, typename S0_8x8, typename S1, typename S1_2x4, typename S1_8x8, typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread S0_4x4 &), typename T0, typename T0_4x4, typename T1, typename T1_2x4>
|
||||
kernel void kernel_mul_mm_id(
|
||||
@@ -163,6 +163,7 @@ type Tensor interface {
|
||||
Conv2D(ctx Context, weight Tensor, s0, s1, p0, p1, d0, d1 int) Tensor
|
||||
Conv3D(ctx Context, weight Tensor, c, s0, s1, s2, p0, p1, p2, d0, d1, d2 int) Tensor
|
||||
SSMConv(ctx Context, kernel Tensor) Tensor
|
||||
SSMScan(ctx Context, x, dt, A, B, C, ids Tensor) Tensor
|
||||
|
||||
IM2Col(ctx Context, weight Tensor, s0, s1, p0, p1, d0, d1 int) Tensor
|
||||
|
||||
|
||||
@@ -1662,6 +1662,13 @@ func (t *Tensor) SSMConv(ctx ml.Context, kernel ml.Tensor) ml.Tensor {
|
||||
}
|
||||
}
|
||||
|
||||
func (t *Tensor) SSMScan(ctx ml.Context, x, dt, A, B, C, ids ml.Tensor) ml.Tensor {
|
||||
return &Tensor{
|
||||
b: t.b,
|
||||
t: C.ggml_ssm_scan(ctx.(*Context).ctx, t.t, x.(*Tensor).t, dt.(*Tensor).t, A.(*Tensor).t, B.(*Tensor).t, C.(*Tensor).t, ids.(*Tensor).t),
|
||||
}
|
||||
}
|
||||
|
||||
func (t *Tensor) AvgPool2D(ctx ml.Context, k, s int, p float32) ml.Tensor {
|
||||
return &Tensor{
|
||||
b: t.b,
|
||||
|
||||
@@ -12249,6 +12249,7 @@ template [[host_name("kernel_mul_mm_id_map0_ne20_6" )]] kernel kernel_mul_mm_id_
|
||||
template [[host_name("kernel_mul_mm_id_map0_ne20_8" )]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<8>;
|
||||
template [[host_name("kernel_mul_mm_id_map0_ne20_10")]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<10>;
|
||||
template [[host_name("kernel_mul_mm_id_map0_ne20_16")]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<16>;
|
||||
template [[host_name("kernel_mul_mm_id_map0_ne20_22")]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<22>;
|
||||
|
||||
template<typename S0, typename S0_4x4, typename S0_8x8, typename S1, typename S1_2x4, typename S1_8x8, typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread S0_4x4 &), typename T0, typename T0_4x4, typename T1, typename T1_2x4>
|
||||
kernel void kernel_mul_mm_id(
|
||||
|
||||
@@ -1961,7 +1961,8 @@ int ggml_metal_op_mul_mat_id(ggml_metal_op_t ctx, int idx) {
|
||||
// ne21 = n_rows (batch size)
|
||||
const int ne21_mm_id_min = 32;
|
||||
|
||||
if (props_dev->has_simdgroup_mm && ne00 >= 64 && (ne21 >= ne21_mm_id_min)) {
|
||||
if (props_dev->has_simdgroup_mm && ne00 >= 64 && (ne21 >= ne21_mm_id_min) &&
|
||||
(ne20 == 1 || ne20 == 2 || ne20 == 4 || ne20 == 6 || ne20 == 8 || ne20 == 10 || ne20 == 16 || ne20 == 22)) {
|
||||
// some Metal matrix data types require aligned pointers
|
||||
// ref: https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf (Table 2.5)
|
||||
//switch (op->src[0]->type) {
|
||||
|
||||
@@ -9427,6 +9427,7 @@ template [[host_name("kernel_mul_mm_id_map0_ne20_6" )]] kernel kernel_mul_mm_id_
|
||||
template [[host_name("kernel_mul_mm_id_map0_ne20_8" )]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<8>;
|
||||
template [[host_name("kernel_mul_mm_id_map0_ne20_10")]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<10>;
|
||||
template [[host_name("kernel_mul_mm_id_map0_ne20_16")]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<16>;
|
||||
template [[host_name("kernel_mul_mm_id_map0_ne20_22")]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<22>;
|
||||
|
||||
template<typename S0, typename S0_4x4, typename S0_8x8, typename S1, typename S1_2x4, typename S1_8x8, typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread S0_4x4 &), typename T0, typename T0_4x4, typename T1, typename T1_2x4>
|
||||
kernel void kernel_mul_mm_id(
|
||||
|
||||
@@ -67,6 +67,7 @@ func (f *fakeTensor) Tri(ctx ml.Context, _ int) ml.Tensor
|
||||
func (f *fakeTensor) Fill(ctx ml.Context, _ float32) ml.Tensor { return f }
|
||||
func (f *fakeTensor) Repeat4D(ctx ml.Context, _, _, _, _ int) ml.Tensor { return f }
|
||||
func (f *fakeTensor) SolveTri(ctx ml.Context, _ ml.Tensor, _, _, _ bool) ml.Tensor { return f }
|
||||
func (f *fakeTensor) SSMScan(ctx ml.Context, _, _, _, _, _, _ ml.Tensor) ml.Tensor { return f }
|
||||
|
||||
func (m *fakeBackend) Get(name string) ml.Tensor {
|
||||
if slices.Contains(m.names, name) {
|
||||
|
||||
@@ -15,6 +15,7 @@ import (
|
||||
_ "github.com/ollama/ollama/model/models/llama4"
|
||||
_ "github.com/ollama/ollama/model/models/mistral3"
|
||||
_ "github.com/ollama/ollama/model/models/mllama"
|
||||
_ "github.com/ollama/ollama/model/models/nemotronh"
|
||||
_ "github.com/ollama/ollama/model/models/nomicbert"
|
||||
_ "github.com/ollama/ollama/model/models/olmo3"
|
||||
_ "github.com/ollama/ollama/model/models/qwen2"
|
||||
|
||||
88
model/models/nemotronh/attention.go
Normal file
88
model/models/nemotronh/attention.go
Normal file
@@ -0,0 +1,88 @@
|
||||
package nemotronh
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"math"
|
||||
|
||||
"github.com/ollama/ollama/ml"
|
||||
"github.com/ollama/ollama/ml/nn"
|
||||
)
|
||||
|
||||
// Attention implements simple attention without RoPE for Nemotron-H.
|
||||
// Unlike Qwen3Next, Nemotron-H attention has:
|
||||
// - No RoPE (position info comes from Mamba2 layers)
|
||||
// - Standard scaled dot-product attention
|
||||
type Attention struct {
|
||||
Query *nn.Linear `gguf:"attn_q"`
|
||||
Key *nn.Linear `gguf:"attn_k"`
|
||||
Value *nn.Linear `gguf:"attn_v"`
|
||||
Output *nn.Linear `gguf:"attn_output"`
|
||||
}
|
||||
|
||||
func (a *Attention) Forward(ctx ml.Context, hiddenStates ml.Tensor, cache *HybridCache, opts *Options) (ml.Tensor, error) {
|
||||
hiddenDim := hiddenStates.Dim(0)
|
||||
nSeqTokens := hiddenStates.Dim(1)
|
||||
switch hiddenStates.Dim(2) {
|
||||
case 0:
|
||||
hiddenStates = hiddenStates.Reshape(ctx, hiddenDim, nSeqTokens, 1)
|
||||
case 1:
|
||||
default:
|
||||
return nil, ErrUnsupportedBatchLayout
|
||||
}
|
||||
|
||||
// Nemotron-H is currently clamped to num_parallel=1.
|
||||
if cache != nil && cache.IsSupportedForBatch() {
|
||||
if cache.numSeqs() != 1 {
|
||||
return nil, ErrUnsupportedBatchLayout
|
||||
}
|
||||
if seqTokens := cache.seqTokens(); seqTokens > 0 && nSeqTokens != seqTokens {
|
||||
return nil, ErrUnsupportedBatchLayout
|
||||
}
|
||||
}
|
||||
batchSize := nSeqTokens
|
||||
hiddenStates = hiddenStates.Reshape(ctx, hiddenDim, batchSize)
|
||||
|
||||
headDim := opts.getHeadDim()
|
||||
if headDim <= 0 {
|
||||
return nil, fmt.Errorf("nemotronh: invalid attention head dimension %d", headDim)
|
||||
}
|
||||
|
||||
// Q projection
|
||||
query := a.Query.Forward(ctx, hiddenStates)
|
||||
if query.Dim(0)%headDim != 0 {
|
||||
return nil, fmt.Errorf("nemotronh: query dim %d not divisible by head dim %d", query.Dim(0), headDim)
|
||||
}
|
||||
numHeads := query.Dim(0) / headDim
|
||||
query = query.Reshape(ctx, headDim, numHeads, batchSize)
|
||||
|
||||
// K projection
|
||||
key := a.Key.Forward(ctx, hiddenStates)
|
||||
if key.Dim(0)%headDim != 0 {
|
||||
return nil, fmt.Errorf("nemotronh: key dim %d not divisible by head dim %d", key.Dim(0), headDim)
|
||||
}
|
||||
numKVHeads := key.Dim(0) / headDim
|
||||
key = key.Reshape(ctx, headDim, numKVHeads, batchSize)
|
||||
|
||||
// V projection
|
||||
value := a.Value.Forward(ctx, hiddenStates)
|
||||
if value.Dim(0)%headDim != 0 {
|
||||
return nil, fmt.Errorf("nemotronh: value dim %d not divisible by head dim %d", value.Dim(0), headDim)
|
||||
}
|
||||
if value.Dim(0)/headDim != numKVHeads {
|
||||
return nil, fmt.Errorf("nemotronh: key heads %d and value heads %d do not match", numKVHeads, value.Dim(0)/headDim)
|
||||
}
|
||||
value = value.Reshape(ctx, headDim, numKVHeads, batchSize)
|
||||
|
||||
// Standard attention computation (no RoPE)
|
||||
scale := opts.attentionScale
|
||||
if scale == 0 {
|
||||
scale = 1.0 / math.Sqrt(float64(headDim))
|
||||
}
|
||||
attention := nn.Attention(ctx, query, key, value, scale, cache)
|
||||
|
||||
// Flatten heads
|
||||
attention = attention.Reshape(ctx, headDim*numHeads, batchSize)
|
||||
|
||||
// Output projection
|
||||
return a.Output.Forward(ctx, attention), nil
|
||||
}
|
||||
55
model/models/nemotronh/cache.go
Normal file
55
model/models/nemotronh/cache.go
Normal file
@@ -0,0 +1,55 @@
|
||||
package nemotronh
|
||||
|
||||
import (
|
||||
"errors"
|
||||
|
||||
"github.com/ollama/ollama/kvcache"
|
||||
"github.com/ollama/ollama/ml"
|
||||
)
|
||||
|
||||
// ErrUnsupportedBatchLayout is returned when the batch layout is incompatible
|
||||
// with the layer requirements.
|
||||
var ErrUnsupportedBatchLayout = errors.New("nemotronh: unsupported batch layout")
|
||||
|
||||
var (
|
||||
_ kvcache.Cache = (*HybridCache)(nil)
|
||||
_ kvcache.CheckpointCache = (*HybridCache)(nil)
|
||||
)
|
||||
|
||||
// HybridCache adapts the shared recurrent cache base for Nemotron-H naming.
|
||||
type HybridCache struct {
|
||||
*kvcache.Recurrent
|
||||
}
|
||||
|
||||
func NewHybridCache(convDim, convChannels, ssmStateSize int) *HybridCache {
|
||||
base := kvcache.NewRecurrentCache(kvcache.RecurrentConfig{
|
||||
Shift: Shift,
|
||||
ConvDim: convDim,
|
||||
ConvChannels: convChannels,
|
||||
RecurrentStateSize: ssmStateSize,
|
||||
CheckpointLogPrefix: "nemotronh",
|
||||
})
|
||||
return &HybridCache{Recurrent: base}
|
||||
}
|
||||
|
||||
// SSMState returns the SSM state for current batch sequences.
|
||||
func (c *HybridCache) SSMState(ctx ml.Context, layer int, dState, headDim, nHead int) (ml.Tensor, error) {
|
||||
return c.RecurrentState4D(ctx, layer, dState, headDim, nHead)
|
||||
}
|
||||
|
||||
// UpdateSSMState writes a new SSM state for current batch sequences.
|
||||
func (c *HybridCache) UpdateSSMState(ctx ml.Context, layer int, newState ml.Tensor) {
|
||||
c.UpdateRecurrentState(ctx, layer, newState)
|
||||
}
|
||||
|
||||
func (c *HybridCache) slotsTensor() ml.Tensor {
|
||||
return c.SlotsTensor()
|
||||
}
|
||||
|
||||
func (c *HybridCache) seqTokens() int {
|
||||
return c.SeqTokens()
|
||||
}
|
||||
|
||||
func (c *HybridCache) numSeqs() int {
|
||||
return c.NumSeqs()
|
||||
}
|
||||
197
model/models/nemotronh/mamba2.go
Normal file
197
model/models/nemotronh/mamba2.go
Normal file
@@ -0,0 +1,197 @@
|
||||
package nemotronh
|
||||
|
||||
import (
|
||||
"log/slog"
|
||||
|
||||
"github.com/ollama/ollama/ml"
|
||||
"github.com/ollama/ollama/ml/nn"
|
||||
)
|
||||
|
||||
// convKernel wraps the 1D convolution kernel tensor
|
||||
type convKernel struct {
|
||||
Weight ml.Tensor `gguf:"weight"`
|
||||
}
|
||||
|
||||
// Mamba2 implements the Mamba2 SSM layer for Nemotron-H.
|
||||
// The forward pass follows llama.cpp's build_mamba2_layer:
|
||||
// 1. Input projection: zxBCdt = SSMIn @ hidden
|
||||
// 2. Split: z, xBC, dt
|
||||
// 3. Concat with conv state, apply SSMConv, save new conv state
|
||||
// 4. Apply SiLU to convolved xBC
|
||||
// 5. Split: x, B, C
|
||||
// 6. Add dt bias
|
||||
// 7. SSMScan: y = SSMScan(state, x, dt, A, B, C, ids)
|
||||
// 8. D skip: y = y + x * D
|
||||
// 9. Swiglu with z: y = z * silu(y)
|
||||
// 10. Group RMSNorm
|
||||
// 11. Output projection
|
||||
type Mamba2 struct {
|
||||
SSMIn *nn.Linear `gguf:"ssm_in"` // n_embd → d_in_proj (2*d_inner + 2*n_group*d_state + n_head)
|
||||
SSMConv1D *convKernel `gguf:"ssm_conv1d"` // conv kernel
|
||||
SSMConv1DB ml.Tensor `gguf:"ssm_conv1d.bias"`
|
||||
SSMDtB ml.Tensor `gguf:"ssm_dt.bias"` // dt bias [n_head]
|
||||
SSMA ml.Tensor `gguf:"ssm_a"` // A parameter [1, n_head]
|
||||
SSMD ml.Tensor `gguf:"ssm_d"` // D skip connection [1, n_head]
|
||||
SSMNorm *nn.RMSNorm `gguf:"ssm_norm"` // group norm
|
||||
SSMOut *nn.Linear `gguf:"ssm_out"` // output projection
|
||||
Layer int
|
||||
}
|
||||
|
||||
func (m *Mamba2) Forward(ctx ml.Context, hiddenStates ml.Tensor, cache *HybridCache, opts *Options) (ml.Tensor, error) {
|
||||
layer := m.Layer
|
||||
hiddenDim := hiddenStates.Dim(0)
|
||||
nSeqTokens := hiddenStates.Dim(1)
|
||||
switch hiddenStates.Dim(2) {
|
||||
case 0:
|
||||
hiddenStates = hiddenStates.Reshape(ctx, hiddenDim, nSeqTokens, 1)
|
||||
case 1:
|
||||
default:
|
||||
return nil, ErrUnsupportedBatchLayout
|
||||
}
|
||||
|
||||
// Nemotron-H is currently clamped to num_parallel=1.
|
||||
if cache != nil && cache.IsSupportedForBatch() {
|
||||
if cache.numSeqs() != 1 {
|
||||
return nil, ErrUnsupportedBatchLayout
|
||||
}
|
||||
if seqTokens := cache.seqTokens(); seqTokens > 0 && nSeqTokens != seqTokens {
|
||||
return nil, ErrUnsupportedBatchLayout
|
||||
}
|
||||
}
|
||||
nSeqs := 1
|
||||
|
||||
dConv := opts.ssmDConv
|
||||
dInner := opts.ssmDInner
|
||||
dState := opts.ssmDState
|
||||
nHead := opts.ssmNHead
|
||||
headDim := dInner / nHead
|
||||
nGroup := opts.ssmNGroup
|
||||
|
||||
// {n_embd, n_seq_tokens, n_seqs} => {d_in_proj, n_seq_tokens, n_seqs}
|
||||
// d_in_proj = 2*d_inner + 2*n_group*d_state + n_head
|
||||
zxBCdt := m.SSMIn.Forward(ctx, hiddenStates)
|
||||
|
||||
// Split into z, xBC, dt
|
||||
// z: [head_dim, n_head, n_seq_tokens, n_seqs]
|
||||
z := zxBCdt.Slice(ctx, 0, 0, dInner, 1)
|
||||
z = z.Reshape(ctx, headDim, nHead, nSeqTokens, nSeqs)
|
||||
|
||||
// xBC: [d_inner + 2*n_group*d_state, n_seq_tokens, n_seqs]
|
||||
xBCSize := dInner + 2*nGroup*dState
|
||||
xBC := zxBCdt.Slice(ctx, 0, dInner, dInner+xBCSize, 1)
|
||||
if nSeqTokens == 1 {
|
||||
xBC = xBC.Reshape(ctx, xBCSize, 1, nSeqs)
|
||||
}
|
||||
|
||||
// dt: [n_head, n_seq_tokens, n_seqs]
|
||||
dt := zxBCdt.Slice(ctx, 0, 2*dInner+2*nGroup*dState, 2*dInner+2*nGroup*dState+nHead, 1)
|
||||
if nSeqTokens == 1 {
|
||||
dt = dt.Reshape(ctx, nHead, 1, nSeqs)
|
||||
} else {
|
||||
dt = dt.Contiguous(ctx, nHead, nSeqTokens, nSeqs)
|
||||
}
|
||||
|
||||
// Get conv state from cache
|
||||
convStates, err := cache.ConvState(ctx, layer)
|
||||
if err != nil {
|
||||
slog.Warn("nemotronh: failed to get conv state, using zeros", "layer", layer, "error", err)
|
||||
convStates = ctx.Input().Zeros(ml.DTypeF32, dConv-1, xBCSize, nSeqs)
|
||||
}
|
||||
|
||||
// Reshape conv states: [d_conv-1, xBCSize, n_seqs]
|
||||
convStates = convStates.Reshape(ctx, dConv-1, xBCSize, nSeqs)
|
||||
|
||||
// For decode (n_seq_tokens == 1), reshape avoids a transpose/contiguous pair.
|
||||
var xBCT ml.Tensor
|
||||
if nSeqTokens == 1 {
|
||||
xBCT = xBC.Reshape(ctx, 1, xBCSize, nSeqs)
|
||||
} else {
|
||||
// Prefill path: [xBCSize, n_seq_tokens, n_seqs] -> [n_seq_tokens, xBCSize, n_seqs]
|
||||
xBCT = xBC.Permute(ctx, 1, 0, 2, 3)
|
||||
}
|
||||
|
||||
// Concatenate with conv state: [d_conv-1 + n_seq_tokens, xBCSize, n_seqs]
|
||||
convInput := convStates.Concat(ctx, xBCT, 0)
|
||||
|
||||
// Save new conv state (last d_conv-1 columns)
|
||||
lastConvStates := convInput.Slice(ctx, 0, nSeqTokens, nSeqTokens+dConv-1, 1)
|
||||
cache.UpdateConvState(ctx, layer, lastConvStates)
|
||||
|
||||
// Apply SSM convolution
|
||||
xBC = convInput.SSMConv(ctx, m.SSMConv1D.Weight)
|
||||
|
||||
// Add conv bias
|
||||
if m.SSMConv1DB != nil {
|
||||
xBC = xBC.Add(ctx, m.SSMConv1DB)
|
||||
}
|
||||
|
||||
// Apply SiLU
|
||||
xBC = xBC.SILU(ctx)
|
||||
|
||||
// Split xBC into x, B, C
|
||||
// x: [head_dim, n_head, n_seq_tokens, n_seqs]
|
||||
x := xBC.Slice(ctx, 0, 0, dInner, 1)
|
||||
x = x.Reshape(ctx, headDim, nHead, nSeqTokens, nSeqs)
|
||||
|
||||
// B: [d_state, n_group, n_seq_tokens, n_seqs]
|
||||
B := xBC.Slice(ctx, 0, dInner, dInner+nGroup*dState, 1)
|
||||
B = B.Reshape(ctx, dState, nGroup, nSeqTokens, nSeqs)
|
||||
|
||||
// C: [d_state, n_group, n_seq_tokens, n_seqs]
|
||||
C := xBC.Slice(ctx, 0, dInner+nGroup*dState, dInner+2*nGroup*dState, 1)
|
||||
C = C.Reshape(ctx, dState, nGroup, nSeqTokens, nSeqs)
|
||||
|
||||
// Add dt bias
|
||||
dt = dt.Add(ctx, m.SSMDtB)
|
||||
|
||||
// Get SSM state from cache
|
||||
state, err := cache.SSMState(ctx, layer, dState, headDim, nHead)
|
||||
if err != nil {
|
||||
slog.Warn("nemotronh: failed to get SSM state, using zeros", "layer", layer, "error", err)
|
||||
state = ctx.Input().Zeros(ml.DTypeF32, dState, headDim, nHead, nSeqs)
|
||||
}
|
||||
|
||||
// SSMScan
|
||||
// state: [d_state, head_dim, n_head, n_seqs]
|
||||
// returns: [head_dim, n_head, n_seq_tokens, n_seqs] concatenated with new state
|
||||
ySsm := state.SSMScan(ctx, x, dt, m.SSMA, B, C, cache.slotsTensor())
|
||||
|
||||
// ySsm is a packed 1D buffer: [y (nSeqTokens*headDim*nHead*nSeqs), newState]
|
||||
yElems := headDim * nHead * nSeqTokens * nSeqs
|
||||
y := ySsm.View(ctx, 0, yElems).Reshape(ctx, headDim, nHead, nSeqTokens, nSeqs)
|
||||
|
||||
stateOffsetBytes := yElems * x.Stride(0)
|
||||
stateElems := dState * headDim * nHead * nSeqs
|
||||
newState := ySsm.View(ctx, stateOffsetBytes, stateElems)
|
||||
newState = newState.Reshape(ctx, dState, headDim, nHead, nSeqs)
|
||||
|
||||
// Update SSM state in cache
|
||||
cache.UpdateSSMState(ctx, layer, newState)
|
||||
|
||||
// D skip connection: y = y + x * D
|
||||
if m.SSMD != nil {
|
||||
// SSMD shape: [1, n_head] -> broadcast to [head_dim, n_head, n_seq_tokens, n_seqs]
|
||||
xD := x.Mul(ctx, m.SSMD)
|
||||
y = y.Add(ctx, xD)
|
||||
}
|
||||
|
||||
// Swiglu with z: y = z * silu(y)
|
||||
y = z.SILU(ctx, y)
|
||||
|
||||
// Group RMSNorm
|
||||
if m.SSMNorm != nil {
|
||||
// Reshape for group norm: [d_inner/n_group, n_group, n_seq_tokens, n_seqs]
|
||||
innerPerGroup := dInner / nGroup
|
||||
y = y.Reshape(ctx, innerPerGroup, nGroup, nSeqTokens, nSeqs)
|
||||
y = m.SSMNorm.Forward(ctx, y, opts.eps)
|
||||
}
|
||||
|
||||
// Reshape back to [d_inner, n_seq_tokens, n_seqs]
|
||||
y = y.Reshape(ctx, dInner, nSeqTokens, nSeqs)
|
||||
|
||||
// Output projection
|
||||
out := m.SSMOut.Forward(ctx, y)
|
||||
|
||||
// Reshape to 2D for consistency with attention output
|
||||
return out.Reshape(ctx, out.Dim(0), nSeqTokens*nSeqs), nil
|
||||
}
|
||||
417
model/models/nemotronh/model.go
Normal file
417
model/models/nemotronh/model.go
Normal file
@@ -0,0 +1,417 @@
|
||||
package nemotronh
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"math"
|
||||
|
||||
"github.com/ollama/ollama/fs"
|
||||
"github.com/ollama/ollama/ml"
|
||||
"github.com/ollama/ollama/ml/nn"
|
||||
"github.com/ollama/ollama/model"
|
||||
"github.com/ollama/ollama/model/input"
|
||||
"github.com/ollama/ollama/tokenizer"
|
||||
)
|
||||
|
||||
// Options contains model configuration
|
||||
type Options struct {
|
||||
hiddenSize int
|
||||
numHeads int // attention heads
|
||||
numKVHeads int // KV heads for attention layers
|
||||
headDim int
|
||||
eps float32
|
||||
|
||||
// Mamba2 SSM config
|
||||
ssmDConv int // conv kernel size
|
||||
ssmDInner int // inner dimension (d_inner)
|
||||
ssmDState int // state dimension
|
||||
ssmNHead int // number of SSM heads (dt_rank)
|
||||
ssmNGroup int // number of groups for B, C
|
||||
|
||||
// Per-layer configuration
|
||||
isRecurrent []bool // true = Mamba2, false = attention or FFN
|
||||
nFF []int // n_ff per layer (0 = attention-only)
|
||||
|
||||
// Attention scale
|
||||
attentionScale float64
|
||||
|
||||
// MoE config
|
||||
numExperts int
|
||||
numExpertsUsed int
|
||||
expertWeightsNorm bool
|
||||
expertWeightsScale float32
|
||||
expertWeightsNormClip float32
|
||||
}
|
||||
|
||||
func (o Options) getHeadDim() int {
|
||||
if o.headDim > 0 {
|
||||
return o.headDim
|
||||
}
|
||||
if o.numHeads <= 0 {
|
||||
return 0
|
||||
}
|
||||
return o.hiddenSize / o.numHeads
|
||||
}
|
||||
|
||||
// Operator is the interface for layer operators (Mamba2 or Attention)
|
||||
type Operator interface {
|
||||
Forward(ctx ml.Context, hiddenStates ml.Tensor, cache *HybridCache, opts *Options) (ml.Tensor, error)
|
||||
}
|
||||
|
||||
// MLP is the interface for feedforward networks
|
||||
type MLP interface {
|
||||
Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *Options) ml.Tensor
|
||||
}
|
||||
|
||||
// Layer represents a single transformer layer
|
||||
type Layer struct {
|
||||
AttentionNorm *nn.RMSNorm `gguf:"attn_norm"`
|
||||
Operator Operator // Mamba2, Attention, or nil (for FFN-only layers)
|
||||
MLP MLP // Dense or MoE FFN, or nil
|
||||
}
|
||||
|
||||
func (l *Layer) Forward(ctx ml.Context, layer int, hiddenStates, outputs ml.Tensor, cache *HybridCache, opts *Options) (ml.Tensor, error) {
|
||||
residual := hiddenStates
|
||||
|
||||
// Pre-layer norm
|
||||
hiddenStates = l.AttentionNorm.Forward(ctx, hiddenStates, opts.eps)
|
||||
|
||||
// Layer operator (Mamba2, Attention, or FFN)
|
||||
if l.Operator != nil {
|
||||
var err error
|
||||
hiddenStates, err = l.Operator.Forward(ctx, hiddenStates, cache, opts)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
} else if l.MLP != nil {
|
||||
// FFN-only layer
|
||||
hiddenStates = l.MLP.Forward(ctx, hiddenStates, opts)
|
||||
}
|
||||
|
||||
// Output projection for last layer
|
||||
if outputs != nil {
|
||||
hiddenStates = hiddenStates.Rows(ctx, outputs)
|
||||
residual = residual.Rows(ctx, outputs)
|
||||
}
|
||||
|
||||
// Residual connection
|
||||
return hiddenStates.Add(ctx, residual), nil
|
||||
}
|
||||
|
||||
// Model is the main Nemotron-H model
|
||||
type Model struct {
|
||||
model.Base
|
||||
tokenizer.Tokenizer
|
||||
|
||||
TokenEmbedding *nn.Embedding `gguf:"token_embd"`
|
||||
OutputNorm *nn.RMSNorm `gguf:"output_norm"`
|
||||
Output *nn.Linear `gguf:"output,alt:token_embd"`
|
||||
|
||||
Layers []Layer `gguf:"blk"`
|
||||
|
||||
*Options
|
||||
}
|
||||
|
||||
// Shift is used for KV cache position shifting.
|
||||
// Nemotron-H attention does not apply RoPE, so keys do not need to be transformed.
|
||||
func Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
|
||||
return key, nil
|
||||
}
|
||||
|
||||
func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
|
||||
hiddenStates := m.TokenEmbedding.Forward(ctx, batch.Inputs)
|
||||
|
||||
cache := m.Cache.(*HybridCache)
|
||||
|
||||
for i, layer := range m.Layers {
|
||||
cache.SetLayer(i)
|
||||
|
||||
var outputs ml.Tensor
|
||||
if i == len(m.Layers)-1 {
|
||||
outputs = batch.Outputs
|
||||
}
|
||||
|
||||
var err error
|
||||
hiddenStates, err = layer.Forward(ctx, i, hiddenStates, outputs, cache, m.Options)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
hiddenStates = m.OutputNorm.Forward(ctx, hiddenStates, m.eps)
|
||||
return m.Output.Forward(ctx, hiddenStates), nil
|
||||
}
|
||||
|
||||
func New(c fs.Config) (model.Model, error) {
|
||||
numLayers := int(c.Uint("block_count"))
|
||||
layers := make([]Layer, numLayers)
|
||||
|
||||
// Get per-layer configuration from GGUF metadata
|
||||
// Use the same interface pattern as qwen3next
|
||||
type perLayerConfig interface {
|
||||
HeadCount() []uint64
|
||||
HeadCountKV() []uint64
|
||||
FFNLength() []uint64
|
||||
}
|
||||
|
||||
var headCount []uint64
|
||||
var headCountKV []uint64
|
||||
var ffnLength []uint64
|
||||
|
||||
if plc, ok := c.(perLayerConfig); ok {
|
||||
headCount = plc.HeadCount()
|
||||
headCountKV = plc.HeadCountKV()
|
||||
ffnLength = plc.FFNLength()
|
||||
}
|
||||
|
||||
// Build per-layer arrays with defaults
|
||||
isRecurrent := make([]bool, numLayers)
|
||||
nFF := make([]int, numLayers)
|
||||
|
||||
for i := range numLayers {
|
||||
// Get per-layer values
|
||||
kvHeads := uint64(1) // Default non-zero
|
||||
if i < len(headCountKV) {
|
||||
kvHeads = headCountKV[i]
|
||||
}
|
||||
ff := uint64(0)
|
||||
if i < len(ffnLength) {
|
||||
ff = ffnLength[i]
|
||||
}
|
||||
nFF[i] = int(ff)
|
||||
|
||||
// A layer is recurrent IFF n_head_kv == 0 AND n_ff == 0
|
||||
// This matches llama.cpp behavior for Nemotron-H
|
||||
isRecurrent[i] = kvHeads == 0 && ff == 0
|
||||
}
|
||||
|
||||
// Determine if MoE
|
||||
isMoE := c.Uint("expert_count") > 0
|
||||
|
||||
for i := range layers {
|
||||
if isRecurrent[i] {
|
||||
// Mamba2 layer
|
||||
layers[i].Operator = &Mamba2{Layer: i}
|
||||
} else if nFF[i] == 0 {
|
||||
// Attention-only layer (n_head_kv > 0, n_ff == 0)
|
||||
layers[i].Operator = &Attention{}
|
||||
} else {
|
||||
// FFN layer (n_ff > 0)
|
||||
if isMoE {
|
||||
layers[i].MLP = &MoESparse{}
|
||||
} else {
|
||||
layers[i].MLP = &Dense{}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Get attention head configuration
|
||||
numHeads := int(c.Uint("attention.head_count"))
|
||||
if numHeads == 0 {
|
||||
for i := range numLayers {
|
||||
if i < len(headCount) && i < len(headCountKV) && headCount[i] > 0 && headCountKV[i] > 0 {
|
||||
numHeads = int(headCount[i])
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
numKVHeads := int(c.Uint("attention.head_count_kv"))
|
||||
if numKVHeads == 0 {
|
||||
for i := range numLayers {
|
||||
if i < len(headCountKV) && i < len(ffnLength) && headCountKV[i] > 0 && ffnLength[i] == 0 {
|
||||
numKVHeads = int(headCountKV[i])
|
||||
break
|
||||
}
|
||||
}
|
||||
if numKVHeads == 0 {
|
||||
numKVHeads = numHeads
|
||||
}
|
||||
}
|
||||
|
||||
headDim := int(c.Uint("attention.head_dim"))
|
||||
if headDim == 0 {
|
||||
if keyLength := int(c.Uint("attention.key_length")); keyLength > 0 {
|
||||
headDim = keyLength
|
||||
} else if numHeads > 0 {
|
||||
headDim = int(c.Uint("embedding_length")) / numHeads
|
||||
}
|
||||
}
|
||||
if headDim <= 0 {
|
||||
return nil, fmt.Errorf("nemotronh: invalid attention head dimension")
|
||||
}
|
||||
if numHeads <= 0 {
|
||||
// Attention layers derive per-layer head counts from projection weights.
|
||||
// Keep a non-zero default to avoid invalid option math.
|
||||
numHeads = 1
|
||||
}
|
||||
|
||||
numExperts := int(c.Uint("expert_count"))
|
||||
numExpertsUsed := int(c.Uint("expert_used_count"))
|
||||
if numExperts > 0 {
|
||||
if numExpertsUsed <= 0 || numExpertsUsed > numExperts {
|
||||
return nil, fmt.Errorf("nemotronh: invalid expert_used_count=%d for expert_count=%d", numExpertsUsed, numExperts)
|
||||
}
|
||||
}
|
||||
|
||||
opts := &Options{
|
||||
hiddenSize: int(c.Uint("embedding_length")),
|
||||
numHeads: numHeads,
|
||||
numKVHeads: numKVHeads,
|
||||
headDim: headDim,
|
||||
eps: c.Float("attention.layer_norm_rms_epsilon"),
|
||||
ssmDConv: int(c.Uint("ssm.conv_kernel")),
|
||||
ssmDInner: int(c.Uint("ssm.inner_size")),
|
||||
ssmDState: int(c.Uint("ssm.state_size")),
|
||||
ssmNHead: int(c.Uint("ssm.time_step_rank")),
|
||||
ssmNGroup: int(c.Uint("ssm.group_count")),
|
||||
isRecurrent: isRecurrent,
|
||||
nFF: nFF,
|
||||
attentionScale: float64(c.Float("attention.scale")),
|
||||
numExperts: numExperts,
|
||||
numExpertsUsed: numExpertsUsed,
|
||||
expertWeightsNorm: c.Bool("expert_weights_norm", false),
|
||||
expertWeightsScale: c.Float("expert_weights_scale", 1.0),
|
||||
expertWeightsNormClip: c.Float("expert_weights_norm_clip", 0),
|
||||
}
|
||||
|
||||
// Calculate cache dimensions
|
||||
convDim := max(0, opts.ssmDConv-1)
|
||||
convChannels := opts.ssmDInner + 2*opts.ssmNGroup*opts.ssmDState
|
||||
ssmHeadDim := 0
|
||||
if opts.ssmNHead > 0 {
|
||||
ssmHeadDim = opts.ssmDInner / opts.ssmNHead
|
||||
}
|
||||
ssmStateSize := opts.ssmDState * ssmHeadDim * opts.ssmNHead
|
||||
|
||||
m := Model{
|
||||
Tokenizer: tokenizer.NewBytePairEncoding(
|
||||
&tokenizer.Vocabulary{
|
||||
Values: c.Strings("tokenizer.ggml.tokens"),
|
||||
Types: c.Ints("tokenizer.ggml.token_type"),
|
||||
Merges: c.Strings("tokenizer.ggml.merges"),
|
||||
AddBOS: c.Bool("tokenizer.ggml.add_bos_token", false),
|
||||
BOS: []int32{int32(c.Uint("tokenizer.ggml.bos_token_id"))},
|
||||
AddEOS: c.Bool("tokenizer.ggml.add_eos_token", false),
|
||||
EOS: append(
|
||||
[]int32{int32(c.Uint("tokenizer.ggml.eos_token_id"))},
|
||||
c.Ints("tokenizer.ggml.eos_token_ids")...,
|
||||
),
|
||||
},
|
||||
`(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`,
|
||||
),
|
||||
Layers: layers,
|
||||
Options: opts,
|
||||
}
|
||||
|
||||
m.Cache = NewHybridCache(convDim, convChannels, ssmStateSize)
|
||||
return &m, nil
|
||||
}
|
||||
|
||||
func init() {
|
||||
model.Register("nemotron_h", New)
|
||||
model.Register("nemotron_h_moe", New)
|
||||
}
|
||||
|
||||
// Ensure Model implements model.Model
|
||||
var _ model.Model = (*Model)(nil)
|
||||
|
||||
// Dense implements standard feedforward with ReLU-squared activation
|
||||
type Dense struct {
|
||||
Up *nn.Linear `gguf:"ffn_up"`
|
||||
Down *nn.Linear `gguf:"ffn_down"`
|
||||
}
|
||||
|
||||
func (d *Dense) Forward(ctx ml.Context, x ml.Tensor, opts *Options) ml.Tensor {
|
||||
// up -> ReLU-squared -> down
|
||||
up := d.Up.Forward(ctx, x)
|
||||
up = up.RELU(ctx)
|
||||
up = up.Mul(ctx, up) // Square
|
||||
return d.Down.Forward(ctx, up)
|
||||
}
|
||||
|
||||
// MoESparse implements MoE with shared experts for Nemotron-H-MoE
|
||||
type MoESparse struct {
|
||||
Router *nn.Linear `gguf:"ffn_gate_inp"`
|
||||
Up *nn.LinearBatch `gguf:"ffn_up_exps"`
|
||||
Down *nn.LinearBatch `gguf:"ffn_down_exps"`
|
||||
Bias ml.Tensor `gguf:"exp_probs_b.bias,alt:exp_probs_b"`
|
||||
|
||||
LatentIn *nn.Linear `gguf:"ffn_latent_in"`
|
||||
LatentOut *nn.Linear `gguf:"ffn_latent_out"`
|
||||
|
||||
// Shared experts
|
||||
SharedUp *nn.Linear `gguf:"ffn_up_shexp"`
|
||||
SharedDown *nn.Linear `gguf:"ffn_down_shexp"`
|
||||
}
|
||||
|
||||
func (m *MoESparse) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *Options) ml.Tensor {
|
||||
hiddenDim := hiddenStates.Dim(0)
|
||||
seqLen := hiddenStates.Dim(1)
|
||||
batchSize := hiddenStates.Dim(2)
|
||||
if batchSize == 0 {
|
||||
batchSize = 1
|
||||
}
|
||||
hiddenStates2D := hiddenStates.Reshape(ctx, hiddenDim, seqLen*batchSize)
|
||||
|
||||
// Router logits with sigmoid gating
|
||||
routerLogits := m.Router.Forward(ctx, hiddenStates2D)
|
||||
|
||||
// Weights come from unbiased sigmoid probabilities.
|
||||
probs := routerLogits.Sigmoid(ctx)
|
||||
|
||||
// Selection uses optional bias.
|
||||
selectionProbs := probs
|
||||
if m.Bias != nil {
|
||||
selectionProbs = selectionProbs.Add(ctx, m.Bias)
|
||||
}
|
||||
|
||||
// Select top-k experts
|
||||
selectedExperts := selectionProbs.TopK(ctx, opts.numExpertsUsed)
|
||||
routingWeights := probs.Reshape(ctx, 1, opts.numExperts, hiddenStates2D.Dim(1)).Rows(ctx, selectedExperts)
|
||||
|
||||
// Normalize routing weights
|
||||
if opts.expertWeightsNorm {
|
||||
routingWeights = routingWeights.Reshape(ctx, opts.numExpertsUsed, hiddenStates2D.Dim(1))
|
||||
weightsSum := routingWeights.SumRows(ctx)
|
||||
weightsSum = weightsSum.Clamp(ctx, 6.103515625e-5, float32(math.MaxFloat32))
|
||||
routingWeights = routingWeights.Div(ctx, weightsSum)
|
||||
routingWeights = routingWeights.Reshape(ctx, 1, opts.numExpertsUsed, hiddenStates2D.Dim(1))
|
||||
}
|
||||
|
||||
// Scale routing weights
|
||||
if opts.expertWeightsScale != 1.0 {
|
||||
routingWeights = routingWeights.Scale(ctx, float64(opts.expertWeightsScale))
|
||||
}
|
||||
|
||||
routedInput := hiddenStates2D
|
||||
if m.LatentIn != nil {
|
||||
routedInput = m.LatentIn.Forward(ctx, routedInput)
|
||||
}
|
||||
hiddenStates3D := routedInput.Reshape(ctx, routedInput.Dim(0), 1, routedInput.Dim(1))
|
||||
|
||||
// Expert computation with ReLU-squared activation
|
||||
upOut := m.Up.Forward(ctx, hiddenStates3D, selectedExperts)
|
||||
upOut = upOut.RELU(ctx)
|
||||
upOut = upOut.Mul(ctx, upOut) // Square
|
||||
experts := m.Down.Forward(ctx, upOut, selectedExperts)
|
||||
experts = experts.Mul(ctx, routingWeights)
|
||||
|
||||
// Sum over experts
|
||||
moeOut := experts.View(ctx, 0, experts.Dim(0), experts.Stride(2), experts.Dim(2))
|
||||
for i := 1; i < opts.numExpertsUsed; i++ {
|
||||
moeOut = moeOut.Add(ctx, experts.View(ctx, i*experts.Stride(1), experts.Dim(0), experts.Stride(2), experts.Dim(2)))
|
||||
}
|
||||
if m.LatentOut != nil {
|
||||
moeOut = m.LatentOut.Forward(ctx, moeOut)
|
||||
}
|
||||
|
||||
// Add shared experts if present
|
||||
if m.SharedUp != nil {
|
||||
sharedUp := m.SharedUp.Forward(ctx, hiddenStates2D)
|
||||
sharedUp = sharedUp.RELU(ctx)
|
||||
sharedUp = sharedUp.Mul(ctx, sharedUp) // Square
|
||||
sharedOut := m.SharedDown.Forward(ctx, sharedUp)
|
||||
moeOut = moeOut.Add(ctx, sharedOut)
|
||||
}
|
||||
|
||||
return moeOut
|
||||
}
|
||||
@@ -447,7 +447,7 @@ func (s *Scheduler) load(req *LlmRequest, f *ggml.GGML, systemInfo ml.SystemInfo
|
||||
|
||||
// Some architectures are not safe with num_parallel > 1.
|
||||
// ref: https://github.com/ollama/ollama/issues/4165
|
||||
if slices.Contains([]string{"mllama", "qwen3vl", "qwen3vlmoe", "qwen3next", "lfm2", "lfm2moe"}, req.model.Config.ModelFamily) && numParallel != 1 {
|
||||
if slices.Contains([]string{"mllama", "qwen3vl", "qwen3vlmoe", "qwen3next", "lfm2", "lfm2moe", "nemotron_h", "nemotron_h_moe"}, req.model.Config.ModelFamily) && numParallel != 1 {
|
||||
numParallel = 1
|
||||
slog.Warn("model architecture does not currently support parallel requests", "architecture", req.model.Config.ModelFamily)
|
||||
}
|
||||
|
||||
@@ -18,7 +18,9 @@
|
||||
|
||||
static int mlx_dynamic_open(mlx_dynamic_handle* handle, const char* path) {
|
||||
handle->ctx = (void*) DLOPEN(path);
|
||||
CHECK(handle->ctx != NULL);
|
||||
if (handle->ctx == NULL) {
|
||||
return 1;
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user