mirror of
https://github.com/mudler/LocalAI.git
synced 2026-05-31 12:07:45 -04:00
feat: prefix-cache-aware routing for distributed mode (#10071)
* feat(radixtree): generic prefix tree skeleton with longest-match Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * feat(radixtree): Insert with path recency refresh and entry cap Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * feat(radixtree): TTL idle-expiry and Evict sweep with branch pruning Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * feat(radixtree): recency-weighted per-value Weight Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * feat(radixtree): Remove all entries for a value Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * test(radixtree): race-free concurrency smoke test Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * fix(radixtree): reclaim empty branches, RWMutex reads, TTL boundary, empty-key guard Address review findings on the generic prefix tree: - Extract a shared pruneWalk helper parameterized by a shouldClear predicate and use it from Evict, Remove, and the MaxEntries path. Previously evictOldestLocked cleared a victim's value but never removed the now value-less node or its childless ancestors, so internal nodes accumulated under sustained churn at the cap. The MaxEntries path now prunes the victim and its empty ancestors. - DRY: pruneWalk replaces the duplicated logic in the former pruneLocked and Remove's inner closure. - Switch Tree.mu to sync.RWMutex; LongestMatch, Weight and Len take the read lock (RLock) while Insert, Evict and Remove keep the write lock. Confirmed race-clean under go test -race. - Document the strict greater-than TTL boundary on Options.TTL and expired: age exactly equal to TTL is still live. - Guard Insert against an empty key (no-op): the root never holds a value. Adds Ginkgo specs covering MaxEntries eviction, ancestor reclamation, the no-growth-past-cap invariant, the TTL boundary, and empty-key behavior for both Insert and LongestMatch. Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * feat(prefixcache): RoutePolicy enum with parse/resolve Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * feat(prefixcache): Config with defaults and validation Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * feat(prefixcache): deterministic xxhash prefix-chain extractor Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * feat(prefixcache): pure filter-then-score replica selection Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * feat(prefixcache): Provider interface and radix-tree-backed Index Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * style(prefixcache): gofmt policy enum comment alignment Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * fix(prefixcache): head-first prefix chunking and hoist Weight out of sort Address code-quality review findings in the prefixcache package. Correctness: ExtractChain now chunks from absolute offset 0 with fixed [0,W),[W,2W),... boundaries and caps the chain to the FIRST MaxDepth head blocks. The previous tail-keeping logic shifted the byte offset by a non-window amount once a conversation grew past MaxDepth*WindowBytes, changing every hash each turn and silently breaking cross-turn longest-prefix matching. The reusable KV/prefix cache lives at the head of the prompt, so anchoring at offset 0 makes the chain a true prefix-chain: P and P+suffix share their full leading overlap. Add a regression spec proving cross-turn stability past the cap. Performance: Index.Decide precomputes each candidate's Weight once (decorate-sort-undecorate) instead of calling the O(tree size) Weight inside the O(n log n) sort comparator. Behavior is unchanged. Lint: encode prev with binary.LittleEndian.PutUint64 instead of a manual byte loop, clearing the modernize rangeint finding. Also add a concurrent Decide/Observe/Invalidate spec to exercise Index's documented concurrency safety under go test -race. Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * feat(messaging): prefixcache observe/invalidate subjects and payloads Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * feat(prefixcache): NATS sync publish/apply for observe and invalidate Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * feat(distributedhdr): ctx carrier for prefix-hash chain Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * feat(distributedhdr): PrefixChainHook indirection for backend-side chain build Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * feat(backend): stash prompt prefix chain on ctx before distributed routing Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * fix(backend): mirror modelID fallback for prefix-chain salt parity Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * feat(nodes): scheduling config columns for prefix-cache routing Add RoutePolicy and per-model balance/prefix-match override columns to ModelSchedulingConfig and include them in the SetModelScheduling upsert DoUpdates list so updates are not dropped on conflict. Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * feat(nodes): optional route preference in FindAndLockNodeWithModel Add a RoutePreference type and a new pref parameter so the atomic pick+lock+increment can be biased toward a preferred node without weakening atomicity. A nil preference reproduces the previous ORDER BY behavior exactly. Update the ModelRouter interface, both router.go call sites (pass nil for now; Phase 5 builds the real preference), the test doubles, and the distributed e2e caller. Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * feat(prefixcache): make Sync satisfy Provider with Evict Sync.Observe now returns whether the local index treated the assignment as new or extended, and Sync gains an Evict method that delegates to the wrapped index. Together these let SmartRouter hold a single prefixcache.Provider that broadcasts via NATS. Adds a compile-time Provider assertion and an Evict-delegates behavioral test. Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * feat(nodes): prefix-cache-aware preference and observe in SmartRouter.Route Add a PrefixProvider + PrefixConfig to SmartRouterOptions/SmartRouter (nil keeps routing byte-for-byte the round-robin floor). On each request Route now calls buildPreference: it reads the prompt prefix chain from ctx (distributedhdr.PrefixChain), resolves the per-model policy/thresholds over the global config, loads candidate replica in-flight via a new registry read LoadedReplicaStats (deduped to one entry per node using the MIN in-flight across that node's replicas), asks the provider to Decide, and runs prefixcache.Select. The chosen node is passed as the RoutePreference to FindAndLockNodeWithModel on all three pick paths (cache hit, locked re-pick, cold scheduleAndLoad), and the served node is recorded via Observe only when the resolved policy is prefix_cache so round-robin models never pollute the tree. Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * feat(nodes): invalidate prefix-cache entries on unload and stale removal UnloadModel and both staleness fall-through paths in Route (after a failed gRPC probe and RemoveNodeModel) now call prefixProvider.Invalidate(model, nodeID), guarded by a nil-provider check so the round-robin floor is unchanged. At runtime the provider is the *prefixcache.Sync, so invalidations also broadcast to peer frontends. Adds a test that a previously hot prefix no longer Decides to a node after UnloadModel. Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * feat(prefixcache): rolling forced-disturb pressure counter Add a concurrency-safe per-model rolling counter that tracks how many times a request had a usable hot prefix match but the load guard forced it off the warm node. Entries outside the window are dropped lazily on Count so the backing slice stays bounded. Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * feat(nodes): autoscale on prefix-cache forced-disturb pressure Wire the rolling forced-disturb counter into the SmartRouter and the ReplicaReconciler. Router: in buildPreference, after Decide + Select, record a forced-disturb when a usable hot prefix match existed (d.HotNodeID != "" and d.MatchRatio >= cfg.MinPrefixMatch) but Select chose a different node (or nothing) because the load guard ruled the warm node out. This is the scale-worthy signal: the cache-warm replica is saturated. It deliberately does not fire for all-unique workloads (no hot match), avoiding false-positive scale-ups. Pressure is optional on SmartRouterOptions; nil keeps the path a no-op. Reconciler: read the same Pressure instance in reconcileModel as an extra scale-up reason, reusing the existing MaxReplicas + ClusterCapacityForModel guards and the UnsatisfiableUntil cooldown that gates the whole method. Pressure never overrides MaxReplicas and never force-evicts; a no-capacity model does not spin. Window and threshold come from prefixcache.Config (PressureWindow default 1m, PressureScaleThreshold default 1) and are configurable via ReplicaReconcilerOptions. Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * fix(prefixcache): bound Pressure slice in Record; drop dead reconciler pressureWindow Record now prunes entries older than the rolling window (the same prune Count does), via a shared pruneLocked helper, so a model that takes forced-disturb records but is never Counted (e.g. one with zero loaded replicas the reconciler skips) no longer grows its backing slice unbounded. Also removes the dead pressureWindow struct field and the ReplicaReconcilerOptions.PressureWindow option from the reconciler: they were stored but never read (the window lives inside the *prefixcache.Pressure instance). The scale block now reads pressure.Count once into a local. Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * feat(api): prefix-cache fields in scheduling endpoint DTO with validation Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * feat(ui): prefix-cache routing controls in node scheduling form Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * feat(distributed): wire prefix-cache index, NATS sync, and config Activates prefix-cache-aware routing in distributed mode. Builds the prefixcache Index + NATS-backed Sync + Pressure counter, installs the distributedhdr.PrefixChainHook so core/backend/llm.go attaches a prefix chain per request, subscribes to prefixcache.observe/prefixcache.invalidate to apply peers' events to the local index (no re-broadcast), threads PrefixProvider/PrefixConfig/Pressure into the SmartRouter and Pressure/PressureThreshold into the ReplicaReconciler, and runs a background eviction ticker (every TTL/2) bound to the app context. Enabled by default; --distributed-prefix-cache=false (LOCALAI_DISTRIBUTED_PREFIX_CACHE) opts out and leaves the provider/pressure nil so routing stays round-robin. --distributed-prefix-cache-ttl (LOCALAI_DISTRIBUTED_PREFIX_CACHE_TTL, default 5m) controls entry idle-timeout and eviction cadence. Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * test(nodes): round-robin-floor invariant for prefix-cache routing Drives Select directly: a saturated hot node (in_flight 50 vs 0) is never picked even with a perfect prefix match (round-robin floor holds), while a balanced hot node within the load slack is reused. Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * chore(prefixcache): clear branch lint findings and em dashes Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * feat(distributed): validate prefix-cache config at startup wiring Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * perf(radixtree): single-walk WeightsFor for batch value weights Add Tree.WeightsFor(values, now) which computes the recency-weighted weight for many values in a single O(N + len(values)) tree traversal, versus calling Weight once per value (O(len(values) * N)). Consumers that score K candidates against the tree under the read lock no longer pay K full walks. Extract the per-entry contribution math into an unexported helper shared by both Weight and WeightsFor so the metric stays identical (DRY). Weight's public behavior is unchanged. Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * refactor(config): add ModelConfig.ModelID() single source of truth The c.Name fallback to c.Model was duplicated in core/backend/options.go (feeding model.WithModelID) and hand-copied into core/backend/llm.go (the prefix-chain salt). These MUST agree or the prefix-cache salt diverges silently from the id the model loader tracks. Consolidate both into a new config.ModelConfig.ModelID() helper and call it from both sites. Behavior is identical. Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * perf(prefixcache): reuse one xxhash.Digest in ExtractChain ExtractChain allocated a fresh xxhash.New() Digest per block (up to MaxDepth per call) and grew the chain slice without preallocation. Reuse a single Digest via Reset() before each block and preallocate the chain to min(nBlocks, MaxDepth). xxhash seed 0 is stateless, so Reset()+Write produces the byte-identical value to a fresh New()+Write. Output hashes are unchanged, preserving the cross-process determinism that peers rely on over NATS. Verified by capturing ExtractChain output for the existing test inputs before and after the refactor: identical. Existing extractor tests pass unchanged. Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * fix(prefixcache): drop hot match when matched node is not a candidate; weigh cold candidates in one walk Index.Decide called radixtree.LongestMatch over the whole tree, so the deepest match could be a node that is offline, unloaded, or simply not in the passed candidate set. Honoring that as HotNodeID produced a false forced-disturb signal upstream (buildPreference records pressure when chosen != HotNodeID), making it look like a warm replica was load saturated when it was actually absent. Build the candidate set once and only set HotNodeID/MatchRatio when the matched node is an actual candidate; otherwise fall back to cold placement. A future refinement could ask the tree for the longest match restricted to the candidate nodes (shallower-but-valid) instead of dropping it. Also replace the per-candidate tree.Weight call in the cold-order sort with a single tree.WeightsFor walk, turning O(K*N) under the read lock into O(N + K). Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * refactor(prefixcache): remove Select's unreachable deterministic fallback buildPreference always passes ColdOrder as a permutation of the full candidate set, so the cold-order loop hits every eligible candidate. The trailing best/bestIF scan was dead. Replace it with a plain "return """ and document that ColdOrder is guaranteed to cover all candidates, so "" means none were eligible. Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * refactor(nodes): fetch model scheduling config once per Route GetModelScheduling was read three times per request - in resolveSelectorCandidates, buildPreference, and nodeMatchesScheduling - three DB round-trips for one row that is immutable for the life of the request, and not a consistent snapshot. Fetch it once near the top of Route and thread the *ModelSchedulingConfig (may be nil) into all three helpers. scheduleNewModel keeps its own fetch since it runs outside the Route snapshot. Behavior is identical for nil sched. Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * fix(autoscale): add Pressure.Reset to consume forced-disturb signal Pressure.Count is non-draining (it prunes only by age), so a single burst of forced-disturbs stays within the rolling window for the whole window and keeps Count >= threshold on every reconciler tick. The reconciler will use Reset to clear a model's events after acting on the signal so a fresh scale-up requires fresh forced-disturbs to accumulate, rather than one burst driving the model toward MaxReplicas. Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * fix(autoscale): at most one scale-up per reconcile tick, consume pressure Two autoscale bugs: 1. Over-scaling: the pressure scale-up block read Pressure.Count but never consumed it. With a non-draining counter a single forced-disturb burst kept Count >= threshold across the whole window, firing scaleUp on every tick and pushing the model toward MaxReplicas off one transient burst. After a successful pressure-triggered scale-up the reconciler now calls Pressure.Reset to consume the signal. 2. Double scale-up in one tick: the all-replicas-busy block and the pressure block could both fire in the same reconcileModel pass, each calling scaleUp(+1) against the same `current` read once at the top, so a model that was both busy and over threshold scaled +2 and could overshoot MaxReplicas by one. A scaledUp flag now enforces at most one scaleUp(+1) per tick: the pressure block is skipped if the busy block already scaled, and scale-down is skipped in any tick that scaled up. MinReplicas enforcement, UnsatisfiableUntil backoff, and capacity guards are unchanged. Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * feat(nodes): replica-removed chokepoint hook for prefix-cache invalidation Add SetReplicaRemovedHook to NodeRegistry and fire it from both RemoveNodeModel and RemoveAllNodeModelReplicas after a successful delete. This is the single chokepoint every replica-removal path funnels through (router eviction, reconciler scale-down, probe reaper, health-monitor node-down reap, RemoteUnloaderAdapter), so the prefix-cache index can be invalidated by construction rather than wiring each call site individually. The hook is stored in an atomic.Pointer so the startup wiring (setter) and the request/reconcile-time fire are race-free; it is nil-safe when unset. GORM Delete reports no error for a no-op delete, so the hook also fires when nothing was removed; the consumer's Invalidate(model, node) is idempotent so this is harmless. Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * feat(distributed): invalidate prefix-cache on any replica removal via registry hook Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * refactor(prefixcache): single source of truth for threshold bounds Extract ValidateThresholds into prefixcache/config.go so the per-model override validation (nodes.go endpoint) and Config.Validate share one implementation of the numeric bounds (min_prefix_match in [0,1], balance_abs_threshold >= 0, balance_rel_threshold == 0-or->= 1) instead of hard-coding them in two places. The route_policy allow-list stays explicit (not ParsePolicy, which maps typos to Default). Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * fix(nodes): preserve prefix-cache settings on partial scheduling update A scheduling POST that omitted route_policy/thresholds (e.g. a min_replicas-only update) full-replaced every column and silently reset the model's previously-configured prefix-cache settings to empty/zero. Make the four prefix-cache request fields pointers so omitted is distinguishable from explicit zero, and merge PATCH-style in SetSchedulingEndpoint: a provided pointer wins, an omitted one preserves the existing config value (zero default when none). Non-prefix fields keep their full-replace PUT semantics. Validation now runs on the resolved values via prefixcache.ValidateThresholds. Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * fix(prefixcache): make Invalidate a no-op for uncached models and skip empty broadcasts A registry chokepoint fires Sync.Invalidate(model, nodeID) for every replica removal of every model, including round-robin models that never used the prefix cache. Index.Invalidate previously called tree(model), which lazily created and permanently retained an empty radix tree for any model that ever lost a replica, growing the trees map without bound. Sync.Invalidate also published a NATS PrefixCacheInvalidateEvent on every call, amplifying no-op removals across the cluster. Index.Invalidate now looks the tree up read-only via existingTree and returns without allocating when none exists. The Provider interface is unchanged; Sync gates the broadcast through an optional invalidateExisting(bool) capability type-asserted from the wrapped Index, falling back to the prior always-broadcast behavior for other Provider implementations. Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * perf(prefixcache): derive Decide candidacy from WeightsFor and skip trivial sort WeightsFor already returns a map keyed by every requested candidate, so the separate candidates set built to validate the hot match was redundant: a node is a candidate iff it is a key in the weights map. Drop the extra map and gate the hot-match check on weights membership. Also skip the sort when there is at most one candidate, since the input order is already the cold order. Behavior is unchanged. Deferred follow-up: skipping the WeightsFor walk entirely when a hot match wins would need lazy cross-file changes and is out of scope here. Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * fix(nodes): fire replica-removed hook on bulk node_models deletes; trim LoadedReplicaStats columns Bulk node-scoped node_models deletes (Register re-register cleanup, MarkOffline, MarkDraining, Deregister) removed rows directly without firing the replica-removed hook, so the prefix-cache index kept pointing at nodes whose models were gone. Capture the DISTINCT model names before each bulk delete and fire fireReplicaRemoved once per model after a successful delete, restoring the single-chokepoint invariant for all removal paths. The pre-query is skipped when no hook is set so the no-hook path stays cheap. Also narrow LoadedReplicaStats to SELECT only node_id and in_flight (the only fields the router consumer reads), dropping the JOIN-side available_vram fetch and unused columns while keeping the []ReplicaCandidate return type unchanged. Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * fix(reconciler): consume autoscale signals only on a real scale-up scaleUp was fire-and-forget (void) yet its callers unconditionally consumed the pressure signal (Pressure.Reset) and the MinReplicas hysteresis (ClearUnsatisfiable) right after calling it. If scaleUp added nothing (ScheduleAndLoadModel errored, or no node could be loaded) the saturated warm replica got no new replica AND its accumulated forced-disturb history was wiped, forcing the signal to re-accumulate over a full PressureWindow before the next attempt. Make scaleUp return whether at least one replica was actually scheduled, and gate the side effects on it: - pressure block (2b): set scaledUp and call Pressure.Reset only on success; on failure preserve the signal so the next tick retries off the same accumulated pressure. - busy-burst block (2): set scaledUp from the return value so a failed attempt does not suppress the pressure path or scale-down. - MinReplicas block: call ClearUnsatisfiable only on success so a failed attempt does not reset the unsatisfiable counter. All existing invariants (MaxReplicas, capacity gating, UnsatisfiableUntil cooldown, at-most-one-scale-up-per-tick) are preserved. Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * refactor(nodes): drop router's redundant prefix-cache Invalidate calls The NodeRegistry removal chokepoint (RemoveNodeModel / RemoveAllNodeModelReplicas) now fires SetReplicaRemovedHook, which invalidates the prefix-cache index. The router was also calling prefixProvider.Invalidate explicitly right after each registry removal on the two stale-replica health-probe fall-throughs in Route and in UnloadModel, so every router-side eviction invalidated twice (double tree-prune + double NATS broadcast). Remove the three redundant explicit Invalidate calls and their empty nil-guards. Each removed call sat immediately after a registry removal that fires the hook, so invalidation is preserved via the chokepoint. Decide/Observe usage is untouched. Re-point the unit test (fake registry fires no hook) to assert the removal chokepoint is exercised on unload instead of the router's direct invalidation. Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * fix(prefixcache): broadcast invalidations unconditionally for cross-frontend coherence Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * fix(prefixcache): reject TTL<=0 in Config.Validate (eviction ticker would panic) Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * fix(nodes): make capture+delete atomic in bulk node_models removal paths MarkOffline, MarkDraining, and the Register re-register cleanup ran the nodeModelNames SELECT and the bulk node_models DELETE as two separate statements on r.db with no transaction. A SetNodeModel landing between the two was deleted but its replica-removed hook never fired, leaving the prefix-cache index pointing at a removed replica until TTL or candidacy self-heal. Wrap the capture and the delete in a single db.Transaction in each path (mirroring how Deregister already does it). The captured model names are collected into a slice declared outside the closure; the replica-removed hook fires for each only after the transaction commits, so a rollback never invalidates the index for a removal that did not persist. The set of fired hooks now equals exactly the set of node_models rows actually deleted, with no interleaving gap. The status flip in MarkOffline/MarkDraining (setStatus) is a separate, pre-existing operation and routing already filters non-healthy nodes, so it stays outside the transaction; return contracts are unchanged. Deregister was already correct and is untouched. The cheap-path skip (no hook -> skip the SELECT) is preserved. Adds a spec asserting MarkOffline fires hooks for exactly the rows it deletes and leaves no node_models row behind (consistent snapshot). Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * chore(nodes): debug logging for prefix-cache routing decisions and observations Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * fix(radixtree): match shared prefixes by valuing every node on insert Insert recorded the value (node id) only on the final node of the key chain, leaving every intermediate prefix node valueless. LongestMatch returns the deepest node that hasValue, so two chains that share a leading block but diverge in the tail never matched: only exact-repeat queries hit. That broke the prefix-cache routing core use cases (shared system prompt, multi-turn extension, volatile tail), all of which rely on prefix matching rather than exact-repeat. Set value/hasValue/lastSeen at every node along the chain so each prefix-block node remembers the node id that served that prefix (SGLang/vLLM-style). The deepest match wins, and the last writer owns a shared prefix node (a recency heuristic: the most recent chain through a block is the one most likely still warm). size now counts valued nodes, which is the intended meaning. Updated radixtree tests to the new semantics: deepest-prefix test uses non-overlapping chains, a new test asserts last-writer-owns-shared-node, Evict/Remove/MaxEntries expectations recomputed for per-prefix-node counting, and a shared-prefix LongestMatch red test added. Added a prefixcache Decide test proving a prefix-only query routes to the warm node. No prefixcache .go logic changed. Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * test(distributed): lock in prefix-cache routing behavior end to end Add a DB-backed e2e spec that drives SmartRouter against a real NodeRegistry (Postgres testcontainer) and the real prefixcache.Index radix-tree provider, using a fake gRPC backend factory so no real inference runs. Covers the five behaviors validated by hand: 1. Cold miss + observe: an unseen prefix chain cold-places and is recorded. 2. Hot-match affinity: the same chain returns to its warm node X. 3. Shared-prefix match: a divergent chain sharing X's leading prefix still routes to X (the radix-tree regression we fixed). 4. Negative control: an unrelated chain is a cold miss, not a false hot match on X. 5. Failover + invalidation: removing X's replica fires the registry chokepoint hook to invalidate the prefix entry, and the chain fails over to surviving node Y and re-homes there. Replaces the need for manual docker-compose re-runs. Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * refactor(prefixcache): make prefix-cache affinity replica-granular Track prefix-cache affinity per loaded replica (a backend process with its own KV cache) instead of per node, so multiple replicas of the same model on one node each keep distinct affinity and a hot prefix routes back to the exact replica that served it. - radixtree: add RemoveFunc(pred) and reimplement Remove on top of it. - prefixcache: introduce ReplicaKey{NodeID, Replica}; Index/Candidate/ PrefixDecision/Select/Provider now key on ReplicaKey. Add InvalidateNode to drop every replica of a node; Invalidate drops one replica. Select returns (ReplicaKey, bool) and gains a deterministic least-in-flight eligible fallback (tiebreak NodeID then Replica). - messaging: carry Replica on PrefixCacheObserveEvent and PrefixCacheInvalidateEvent (Replica < 0 means all replicas of the node). - Sync delegates + broadcasts with replica; InvalidateNode broadcasts Replica=-1; ApplyInvalidate routes negative replica to InvalidateNode. This is part 1 of 2; the registry/router/wiring consumers are updated separately. Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * feat(distributed): make prefix-cache routing replica-granular Wire the SmartRouter, NodeRegistry, and distributed startup to the replica-keyed prefixcache API. Affinity is now tracked per replica (each replica is a separate process with its own KV cache), so a prefix served by (node,0) no longer leaks onto the same-node sibling (node,1). - RoutePreference gains PreferredReplica; FindAndLockNodeWithModel locks the EXACT (node_id, replica_index) row, falling through to the default ORDER BY when that replica is not loaded. - SetReplicaRemovedHook now carries replicaIndex; RemoveNodeModel fires the specific replica, RemoveAllNodeModelReplicas and the four bulk node-scoped deletes fire replica<0 (all replicas of the node). - buildPreference builds one Candidate per loaded replica and locks the exact replica the policy chose; observePrefix records the served ReplicaKey at every call site. - distributed.go routes the hook to InvalidateNode (replica<0) or Invalidate(key). - Tests updated to the replica-keyed API plus new coverage: a hot prefix on (node,0) prefers replica 0 over the same-node sibling (router unit + e2e), and FindAndLock locks the exact preferred replica. Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * fix(distributed): derive prefix chain from messages for tokenizer-template models Prefix-cache-aware routing built its prompt-prefix chain from the rendered prompt string `s` in ModelInference. For models with TemplateConfig.UseTokenizerTemplate the frontend never renders a prompt - the backend tokenizes the structured messages itself - so `s` is empty, the chain is empty, and routing silently falls back to round-robin. That covers the bulk of modern chat models (qwen3, llama3, ...), so the feature effectively never engaged for them. Fall back to messagesPrefixSource(messages): a deterministic, prefix-stable head-first serialization of the conversation (role + content per turn). Two requests sharing a leading system prompt and early turns share a leading byte prefix, which ExtractChain maps to a shared chain prefix - landing both on the same cache-warm replica. The rendered `s` is still preferred when present (higher fidelity for non-template models). Found via the multi-replica-per-node e2e: zero "prefix-cache routing decision" logs despite per-request Route calls, traced to the empty-chain guard. Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * docs(distributed): document prefix-cache routing roadmap Add a routing-and-caching roadmap section to the distributed-mode guide, linking the epic (#10063) and the follow-up issues (#10064-#10070) surfaced from a survey of SGLang, vLLM production-stack, Ray Serve, llm-d, AIBrix, and NVIDIA Dynamo. Signed-off-by: Ettore Di Giacinto <mudler@localai.io> --------- Signed-off-by: Ettore Di Giacinto <mudler@localai.io> Co-authored-by: Ettore Di Giacinto <mudler@localai.io>
This commit is contained in:
@@ -16,7 +16,9 @@ import (
|
||||
"github.com/mudler/LocalAI/core/services/jobs"
|
||||
"github.com/mudler/LocalAI/core/services/messaging"
|
||||
"github.com/mudler/LocalAI/core/services/nodes"
|
||||
"github.com/mudler/LocalAI/core/services/nodes/prefixcache"
|
||||
"github.com/mudler/LocalAI/core/services/storage"
|
||||
"github.com/mudler/LocalAI/pkg/distributedhdr"
|
||||
"github.com/mudler/LocalAI/pkg/sanitize"
|
||||
"github.com/mudler/xlog"
|
||||
"gorm.io/gorm"
|
||||
@@ -240,6 +242,84 @@ func initDistributed(cfg *config.ApplicationConfig, authDB *gorm.DB, configLoade
|
||||
cfg.Distributed.BackendUpgradeTimeoutOrDefault(),
|
||||
)
|
||||
|
||||
// Prefix-cache-aware routing. Enabled by default; an operator can opt out
|
||||
// with --distributed-prefix-cache=false, which leaves prefixProvider and
|
||||
// pressure nil so the SmartRouter and reconciler behave exactly as the
|
||||
// round-robin floor (true no-op). When enabled we build the local index,
|
||||
// wrap it in a NATS-backed Sync (publishes our observations, applies peers'
|
||||
// via the subscriptions below), install the extraction hook used by
|
||||
// core/backend/llm.go, and run a background eviction ticker on the app ctx.
|
||||
var prefixProvider prefixcache.Provider
|
||||
var pressure *prefixcache.Pressure
|
||||
var prefixCfg prefixcache.Config
|
||||
if !cfg.Distributed.PrefixCacheDisabled {
|
||||
prefixCfg = prefixcache.DefaultConfig()
|
||||
if cfg.Distributed.PrefixCacheTTL > 0 {
|
||||
prefixCfg.TTL = cfg.Distributed.PrefixCacheTTL
|
||||
}
|
||||
if err := prefixCfg.Validate(); err != nil {
|
||||
return nil, fmt.Errorf("invalid prefix-cache configuration: %w", err)
|
||||
}
|
||||
idx := prefixcache.NewIndex(prefixCfg)
|
||||
prefixSync := prefixcache.NewSync(idx, natsClient)
|
||||
pressure = prefixcache.NewPressure(prefixCfg.PressureWindow)
|
||||
prefixProvider = prefixSync
|
||||
|
||||
// Invalidate the prefix-cache index whenever a replica row is removed.
|
||||
// SetReplicaRemovedHook fires from the single chokepoint all removal paths
|
||||
// funnel through (RemoveNodeModel / RemoveAllNodeModelReplicas), so this
|
||||
// one hook covers every path: reconciler scale-down, probe reaper,
|
||||
// health-monitor reap, RemoteUnloaderAdapter, and the router. Registering
|
||||
// it only inside this enabled block keeps the disabled path a true no-op
|
||||
// (the registry stays hook-less).
|
||||
registry.SetReplicaRemovedHook(func(model, node string, replica int) {
|
||||
if replica < 0 {
|
||||
prefixSync.InvalidateNode(model, node)
|
||||
} else {
|
||||
prefixSync.Invalidate(model, prefixcache.ReplicaKey{NodeID: node, Replica: replica})
|
||||
}
|
||||
})
|
||||
|
||||
distributedhdr.PrefixChainHook = func(model, prompt string) []uint64 {
|
||||
return prefixcache.ExtractChain(model, prompt, prefixCfg)
|
||||
}
|
||||
|
||||
// Apply peers' observations/invalidations to the same Sync. ApplyObserve
|
||||
// and ApplyInvalidate update only the local index and do not re-publish,
|
||||
// so there is no broadcast loop.
|
||||
if _, err := messaging.SubscribeJSON(natsClient, messaging.SubjectPrefixCacheObserve, func(ev messaging.PrefixCacheObserveEvent) {
|
||||
prefixSync.ApplyObserve(ev, time.Now())
|
||||
}); err != nil {
|
||||
return nil, fmt.Errorf("subscribing to %s: %w", messaging.SubjectPrefixCacheObserve, err)
|
||||
}
|
||||
if _, err := messaging.SubscribeJSON(natsClient, messaging.SubjectPrefixCacheInvalidate, func(ev messaging.PrefixCacheInvalidateEvent) {
|
||||
prefixSync.ApplyInvalidate(ev)
|
||||
}); err != nil {
|
||||
return nil, fmt.Errorf("subscribing to %s: %w", messaging.SubjectPrefixCacheInvalidate, err)
|
||||
}
|
||||
|
||||
// Background eviction: sweep idle entries on the app context. Stopped
|
||||
// when the app context is cancelled (mirrors the reconciler loop which
|
||||
// also runs on options.Context). TTL/2 keeps stale entries from
|
||||
// outliving their idle window by more than half a TTL.
|
||||
evictInterval := prefixCfg.TTL / 2
|
||||
go func() {
|
||||
ticker := time.NewTicker(evictInterval)
|
||||
defer ticker.Stop()
|
||||
for {
|
||||
select {
|
||||
case <-cfg.Context.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
prefixSync.Evict(time.Now())
|
||||
}
|
||||
}
|
||||
}()
|
||||
xlog.Info("Prefix-cache-aware routing enabled", "ttl", prefixCfg.TTL, "evictInterval", evictInterval)
|
||||
} else {
|
||||
xlog.Info("Prefix-cache-aware routing disabled: using round-robin routing")
|
||||
}
|
||||
|
||||
// All dependencies ready — build SmartRouter with all options at once
|
||||
var conflictResolver nodes.ConcurrencyConflictResolver
|
||||
if configLoader != nil {
|
||||
@@ -252,6 +332,9 @@ func initDistributed(cfg *config.ApplicationConfig, authDB *gorm.DB, configLoade
|
||||
AuthToken: routerAuthToken,
|
||||
DB: authDB,
|
||||
ConflictResolver: conflictResolver,
|
||||
PrefixProvider: prefixProvider,
|
||||
PrefixConfig: prefixCfg,
|
||||
Pressure: pressure,
|
||||
})
|
||||
|
||||
// Create ReplicaReconciler for auto-scaling model replicas. Adapter +
|
||||
@@ -268,6 +351,8 @@ func initDistributed(cfg *config.ApplicationConfig, authDB *gorm.DB, configLoade
|
||||
Interval: 30 * time.Second,
|
||||
ScaleDownDelay: 5 * time.Minute,
|
||||
ProbeStaleAfter: 2 * time.Minute,
|
||||
Pressure: pressure,
|
||||
PressureThreshold: prefixCfg.PressureScaleThreshold,
|
||||
})
|
||||
|
||||
// Create ModelRouterAdapter to wire into ModelLoader
|
||||
|
||||
@@ -19,6 +19,7 @@ import (
|
||||
"github.com/mudler/LocalAI/core/trace"
|
||||
|
||||
"github.com/mudler/LocalAI/core/gallery"
|
||||
"github.com/mudler/LocalAI/pkg/distributedhdr"
|
||||
"github.com/mudler/LocalAI/pkg/grpc/proto"
|
||||
model "github.com/mudler/LocalAI/pkg/model"
|
||||
"github.com/mudler/LocalAI/pkg/utils"
|
||||
@@ -94,6 +95,22 @@ func ModelInference(ctx context.Context, s string, messages schema.Messages, ima
|
||||
}
|
||||
}
|
||||
|
||||
// Make the rendered prompt's prefix chain available to the distributed router
|
||||
// for prefix-cache-aware node selection. No-op in single-process mode. The
|
||||
// model id MUST match the id ModelOptions feeds to model.WithModelID, so both
|
||||
// use the shared config.ModelConfig.ModelID() helper (Name with a fallback to
|
||||
// Model) or the chain salt and the tracking key would diverge.
|
||||
//
|
||||
// s is empty for UseTokenizerTemplate models (the backend tokenizes the
|
||||
// structured messages itself), so fall back to a prefix-stable serialization
|
||||
// of the messages - otherwise prefix routing would silently degrade to
|
||||
// round-robin for the bulk of modern chat models.
|
||||
chainSource := s
|
||||
if chainSource == "" {
|
||||
chainSource = messagesPrefixSource(messages)
|
||||
}
|
||||
ctx = distributedhdr.MaybeWithPrefixChain(ctx, c.ModelID(), chainSource)
|
||||
|
||||
opts := ModelOptions(*c, o, model.WithContext(ctx))
|
||||
inferenceModel, err := loader.Load(opts...)
|
||||
if err != nil {
|
||||
|
||||
@@ -34,16 +34,11 @@ func recordModelLoadFailure(appConfig *config.ApplicationConfig, modelName, back
|
||||
}
|
||||
|
||||
func ModelOptions(c config.ModelConfig, so *config.ApplicationConfig, opts ...model.Option) []model.Option {
|
||||
name := c.Name
|
||||
if name == "" {
|
||||
name = c.Model
|
||||
}
|
||||
|
||||
defOpts := []model.Option{
|
||||
model.WithBackendString(c.Backend),
|
||||
model.WithModel(c.Model),
|
||||
model.WithContext(so.Context),
|
||||
model.WithModelID(name),
|
||||
model.WithModelID(c.ModelID()),
|
||||
}
|
||||
|
||||
threads := 1
|
||||
|
||||
36
core/backend/prefix_source.go
Normal file
36
core/backend/prefix_source.go
Normal file
@@ -0,0 +1,36 @@
|
||||
package backend
|
||||
|
||||
import (
|
||||
"strings"
|
||||
|
||||
"github.com/mudler/LocalAI/core/schema"
|
||||
)
|
||||
|
||||
// messagesPrefixSource builds a deterministic, prefix-stable serialization of a
|
||||
// chat conversation for prefix-cache-aware routing. It is the fallback used when
|
||||
// the frontend did not render a prompt string: models with
|
||||
// config.TemplateConfig.UseTokenizerTemplate tokenize the structured messages
|
||||
// backend-side, so the frontend's rendered prompt is empty and a chain built
|
||||
// from it would always be empty - silently degrading prefix routing to
|
||||
// round-robin for the bulk of modern chat models.
|
||||
//
|
||||
// Messages are emitted head-first in turn order (role line + content line per
|
||||
// message), so two conversations sharing a leading system prompt and early turns
|
||||
// share a leading byte prefix. That is exactly what ExtractChain hashes into a
|
||||
// shared chain prefix, landing both requests on the same cache-warm replica.
|
||||
func messagesPrefixSource(messages schema.Messages) string {
|
||||
var b strings.Builder
|
||||
for _, m := range messages {
|
||||
b.WriteString(m.Role)
|
||||
b.WriteByte('\n')
|
||||
content := m.StringContent
|
||||
if content == "" {
|
||||
if s, ok := m.Content.(string); ok {
|
||||
content = s
|
||||
}
|
||||
}
|
||||
b.WriteString(content)
|
||||
b.WriteByte('\n')
|
||||
}
|
||||
return b.String()
|
||||
}
|
||||
53
core/backend/prefix_source_internal_test.go
Normal file
53
core/backend/prefix_source_internal_test.go
Normal file
@@ -0,0 +1,53 @@
|
||||
package backend
|
||||
|
||||
import (
|
||||
"strings"
|
||||
|
||||
"github.com/mudler/LocalAI/core/schema"
|
||||
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
var _ = Describe("messagesPrefixSource", func() {
|
||||
mk := func(role, content string) schema.Message {
|
||||
return schema.Message{Role: role, StringContent: content}
|
||||
}
|
||||
|
||||
It("serializes messages head-first in turn order", func() {
|
||||
got := messagesPrefixSource(schema.Messages{
|
||||
mk("system", "You are helpful."),
|
||||
mk("user", "Hi"),
|
||||
})
|
||||
Expect(got).To(Equal("system\nYou are helpful.\nuser\nHi\n"))
|
||||
})
|
||||
|
||||
It("is deterministic across calls for the same conversation", func() {
|
||||
conv := schema.Messages{mk("system", "S"), mk("user", "U")}
|
||||
Expect(messagesPrefixSource(conv)).To(Equal(messagesPrefixSource(conv)))
|
||||
})
|
||||
|
||||
It("shares a leading byte prefix when the system prompt is shared", func() {
|
||||
shared := "system\nShared system prompt.\nuser\n"
|
||||
a := messagesPrefixSource(schema.Messages{mk("system", "Shared system prompt."), mk("user", "Question A")})
|
||||
b := messagesPrefixSource(schema.Messages{mk("system", "Shared system prompt."), mk("user", "Question B")})
|
||||
Expect(strings.HasPrefix(a, shared)).To(BeTrue())
|
||||
Expect(strings.HasPrefix(b, shared)).To(BeTrue())
|
||||
})
|
||||
|
||||
It("does NOT share a prefix when the system prompt differs", func() {
|
||||
a := messagesPrefixSource(schema.Messages{mk("system", "Prompt A"), mk("user", "Q")})
|
||||
b := messagesPrefixSource(schema.Messages{mk("system", "Prompt B"), mk("user", "Q")})
|
||||
Expect(strings.HasPrefix(a, "system\nPrompt A")).To(BeTrue())
|
||||
Expect(strings.HasPrefix(b, "system\nPrompt B")).To(BeTrue())
|
||||
})
|
||||
|
||||
It("returns empty for no messages", func() {
|
||||
Expect(messagesPrefixSource(nil)).To(Equal(""))
|
||||
})
|
||||
|
||||
It("falls back to Content when StringContent is empty", func() {
|
||||
got := messagesPrefixSource(schema.Messages{{Role: "user", Content: "plain"}})
|
||||
Expect(got).To(Equal("user\nplain\n"))
|
||||
})
|
||||
})
|
||||
@@ -145,19 +145,21 @@ type RunCMD struct {
|
||||
DefaultAPIKeyExpiry string `env:"LOCALAI_DEFAULT_API_KEY_EXPIRY" help:"Default expiry for API keys (e.g. 90d, 1y; empty = no expiry)" group:"auth"`
|
||||
|
||||
// Distributed / Horizontal Scaling
|
||||
Distributed bool `env:"LOCALAI_DISTRIBUTED" default:"false" help:"Enable distributed mode (requires PostgreSQL + NATS)" group:"distributed"`
|
||||
InstanceID string `env:"LOCALAI_INSTANCE_ID" help:"Unique instance ID for distributed mode (auto-generated UUID if empty)" group:"distributed"`
|
||||
NatsURL string `env:"LOCALAI_NATS_URL" help:"NATS server URL (e.g., nats://localhost:4222)" group:"distributed"`
|
||||
StorageURL string `env:"LOCALAI_STORAGE_URL" help:"S3-compatible storage endpoint URL (e.g., http://minio:9000)" group:"distributed"`
|
||||
StorageBucket string `env:"LOCALAI_STORAGE_BUCKET" default:"localai" help:"S3 bucket name for object storage" group:"distributed"`
|
||||
StorageRegion string `env:"LOCALAI_STORAGE_REGION" default:"us-east-1" help:"S3 region" group:"distributed"`
|
||||
StorageAccessKey string `env:"LOCALAI_STORAGE_ACCESS_KEY" help:"S3 access key ID" group:"distributed"`
|
||||
StorageSecretKey string `env:"LOCALAI_STORAGE_SECRET_KEY" help:"S3 secret access key" group:"distributed"`
|
||||
RegistrationToken string `env:"LOCALAI_REGISTRATION_TOKEN" help:"Token that backend nodes must provide to register (empty = no auth required)" group:"distributed"`
|
||||
AutoApproveNodes bool `env:"LOCALAI_AUTO_APPROVE_NODES" default:"false" help:"Auto-approve new worker nodes (skip admin approval)" group:"distributed"`
|
||||
BackendInstallTimeout string `env:"LOCALAI_NATS_BACKEND_INSTALL_TIMEOUT" help:"NATS round-trip timeout for backend.install requests sent to worker nodes (default 15m). Increase for slow links pulling multi-GB images." group:"distributed"`
|
||||
BackendUpgradeTimeout string `env:"LOCALAI_NATS_BACKEND_UPGRADE_TIMEOUT" help:"NATS round-trip timeout for backend.upgrade requests (default 15m)." group:"distributed"`
|
||||
ExposeNodeHeader bool `env:"LOCALAI_EXPOSE_NODE_HEADER" default:"false" help:"Set the X-LocalAI-Node response header on inference responses (OpenAI chat/completions/embeddings, Anthropic /v1/messages, Ollama /api/chat,/api/generate,/api/embed) with the ID of the worker that served the request. Disabled by default: the node ID reveals internal topology and should not be exposed on a public endpoint. Best-effort: under heavy concurrency the header may reflect a recent routing decision rather than this exact request's." group:"distributed"`
|
||||
Distributed bool `env:"LOCALAI_DISTRIBUTED" default:"false" help:"Enable distributed mode (requires PostgreSQL + NATS)" group:"distributed"`
|
||||
InstanceID string `env:"LOCALAI_INSTANCE_ID" help:"Unique instance ID for distributed mode (auto-generated UUID if empty)" group:"distributed"`
|
||||
NatsURL string `env:"LOCALAI_NATS_URL" help:"NATS server URL (e.g., nats://localhost:4222)" group:"distributed"`
|
||||
StorageURL string `env:"LOCALAI_STORAGE_URL" help:"S3-compatible storage endpoint URL (e.g., http://minio:9000)" group:"distributed"`
|
||||
StorageBucket string `env:"LOCALAI_STORAGE_BUCKET" default:"localai" help:"S3 bucket name for object storage" group:"distributed"`
|
||||
StorageRegion string `env:"LOCALAI_STORAGE_REGION" default:"us-east-1" help:"S3 region" group:"distributed"`
|
||||
StorageAccessKey string `env:"LOCALAI_STORAGE_ACCESS_KEY" help:"S3 access key ID" group:"distributed"`
|
||||
StorageSecretKey string `env:"LOCALAI_STORAGE_SECRET_KEY" help:"S3 secret access key" group:"distributed"`
|
||||
RegistrationToken string `env:"LOCALAI_REGISTRATION_TOKEN" help:"Token that backend nodes must provide to register (empty = no auth required)" group:"distributed"`
|
||||
AutoApproveNodes bool `env:"LOCALAI_AUTO_APPROVE_NODES" default:"false" help:"Auto-approve new worker nodes (skip admin approval)" group:"distributed"`
|
||||
DistributedPrefixCache bool `env:"LOCALAI_DISTRIBUTED_PREFIX_CACHE" default:"true" help:"Enable prefix-cache-aware routing in distributed mode (default true). When false, routing falls back to round-robin." group:"distributed"`
|
||||
DistributedPrefixCacheTTL string `env:"LOCALAI_DISTRIBUTED_PREFIX_CACHE_TTL" help:"Idle-timeout for prefix-cache index entries; also drives the background eviction cadence (every TTL/2). Default 5m." group:"distributed"`
|
||||
BackendInstallTimeout string `env:"LOCALAI_NATS_BACKEND_INSTALL_TIMEOUT" help:"NATS round-trip timeout for backend.install requests sent to worker nodes (default 15m). Increase for slow links pulling multi-GB images." group:"distributed"`
|
||||
BackendUpgradeTimeout string `env:"LOCALAI_NATS_BACKEND_UPGRADE_TIMEOUT" help:"NATS round-trip timeout for backend.upgrade requests (default 15m)." group:"distributed"`
|
||||
ExposeNodeHeader bool `env:"LOCALAI_EXPOSE_NODE_HEADER" default:"false" help:"Set the X-LocalAI-Node response header on inference responses (OpenAI chat/completions/embeddings, Anthropic /v1/messages, Ollama /api/chat,/api/generate,/api/embed) with the ID of the worker that served the request. Disabled by default: the node ID reveals internal topology and should not be exposed on a public endpoint. Best-effort: under heavy concurrency the header may reflect a recent routing decision rather than this exact request's." group:"distributed"`
|
||||
|
||||
Version bool
|
||||
|
||||
@@ -284,6 +286,16 @@ func (r *RunCMD) Run(ctx *cliContext.Context) error {
|
||||
if r.AutoApproveNodes {
|
||||
opts = append(opts, config.EnableAutoApproveNodes)
|
||||
}
|
||||
if !r.DistributedPrefixCache {
|
||||
opts = append(opts, config.DisablePrefixCache)
|
||||
}
|
||||
if r.DistributedPrefixCacheTTL != "" {
|
||||
d, err := time.ParseDuration(r.DistributedPrefixCacheTTL)
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid LOCALAI_DISTRIBUTED_PREFIX_CACHE_TTL %q: %w", r.DistributedPrefixCacheTTL, err)
|
||||
}
|
||||
opts = append(opts, config.WithPrefixCacheTTL(d))
|
||||
}
|
||||
if r.ExposeNodeHeader {
|
||||
opts = append(opts, config.WithExposeNodeHeader(true))
|
||||
}
|
||||
|
||||
@@ -49,6 +49,17 @@ type DistributedConfig struct {
|
||||
|
||||
AgentWorkerConcurrency int `yaml:"agent_worker_concurrency" json:"agent_worker_concurrency" env:"LOCALAI_AGENT_WORKER_CONCURRENCY"`
|
||||
JobWorkerConcurrency int `yaml:"job_worker_concurrency" json:"job_worker_concurrency" env:"LOCALAI_JOB_WORKER_CONCURRENCY"`
|
||||
|
||||
// PrefixCacheDisabled turns off prefix-cache-aware routing, falling back to
|
||||
// round-robin (the floor). Prefix-cache routing is ON by default in
|
||||
// distributed mode; this flag exists so operators can opt out. The CLI
|
||||
// surfaces a default-true --distributed-prefix-cache enable flag and sets
|
||||
// this when the operator passes --distributed-prefix-cache=false.
|
||||
PrefixCacheDisabled bool
|
||||
// PrefixCacheTTL is the idle-timeout for prefix-cache index entries and
|
||||
// drives the background eviction cadence (eviction runs every TTL/2). Zero
|
||||
// means use the prefixcache package default (5m).
|
||||
PrefixCacheTTL time.Duration
|
||||
}
|
||||
|
||||
// Validate checks that the distributed configuration is internally consistent.
|
||||
@@ -158,6 +169,20 @@ var EnableAutoApproveNodes = func(o *ApplicationConfig) {
|
||||
o.Distributed.AutoApproveNodes = true
|
||||
}
|
||||
|
||||
// DisablePrefixCache turns off prefix-cache-aware routing (falls back to
|
||||
// round-robin). Prefix-cache routing is enabled by default in distributed mode.
|
||||
var DisablePrefixCache = func(o *ApplicationConfig) {
|
||||
o.Distributed.PrefixCacheDisabled = true
|
||||
}
|
||||
|
||||
// WithPrefixCacheTTL sets the prefix-cache index idle-timeout (and the
|
||||
// background eviction cadence, which runs every TTL/2).
|
||||
func WithPrefixCacheTTL(d time.Duration) AppOption {
|
||||
return func(o *ApplicationConfig) {
|
||||
o.Distributed.PrefixCacheTTL = d
|
||||
}
|
||||
}
|
||||
|
||||
// Flag names for distributed timeout / interval configuration. These are
|
||||
// the kebab-case identifiers kong derives from the matching RunCMD struct
|
||||
// fields; they appear in Validate error messages and any other operator-
|
||||
|
||||
@@ -694,6 +694,18 @@ func (c *ModelConfig) IsModelURL() bool {
|
||||
return uri.LooksLikeURL()
|
||||
}
|
||||
|
||||
// ModelID returns the identifier used to reference this model across the
|
||||
// system: the configured Name, falling back to Model when Name is empty.
|
||||
// This is the single source of truth for the id fed to model.WithModelID and
|
||||
// the prefix-cache chain salt; both MUST agree with the router's tracking key
|
||||
// or the prefix-cache salt diverges silently.
|
||||
func (c ModelConfig) ModelID() string {
|
||||
if c.Name != "" {
|
||||
return c.Name
|
||||
}
|
||||
return c.Model
|
||||
}
|
||||
|
||||
// ModelFileName returns the filename of the model
|
||||
// If the model is a URL, it will return the MD5 of the URL which is the filename
|
||||
func (c *ModelConfig) ModelFileName() string {
|
||||
|
||||
@@ -10,6 +10,23 @@ import (
|
||||
)
|
||||
|
||||
var _ = Describe("Test cases for config related functions", func() {
|
||||
Context("ModelID", func() {
|
||||
It("returns Name when set", func() {
|
||||
c := ModelConfig{Name: "my-name"}
|
||||
c.Model = "my-model"
|
||||
Expect(c.ModelID()).To(Equal("my-name"))
|
||||
})
|
||||
It("falls back to Model when Name is empty", func() {
|
||||
c := ModelConfig{}
|
||||
c.Model = "my-model"
|
||||
Expect(c.ModelID()).To(Equal("my-model"))
|
||||
})
|
||||
It("returns empty string when both are empty", func() {
|
||||
c := ModelConfig{}
|
||||
Expect(c.ModelID()).To(Equal(""))
|
||||
})
|
||||
})
|
||||
|
||||
Context("Test Read configuration functions", func() {
|
||||
It("Test Validate", func() {
|
||||
tmp, err := os.CreateTemp("", "config.yaml")
|
||||
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
"crypto/subtle"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
@@ -25,6 +26,7 @@ import (
|
||||
"github.com/mudler/LocalAI/core/schema"
|
||||
"github.com/mudler/LocalAI/core/services/galleryop"
|
||||
"github.com/mudler/LocalAI/core/services/nodes"
|
||||
"github.com/mudler/LocalAI/core/services/nodes/prefixcache"
|
||||
"github.com/mudler/LocalAI/pkg/httpclient"
|
||||
)
|
||||
|
||||
@@ -911,14 +913,56 @@ func GetSchedulingEndpoint(registry *nodes.NodeRegistry) echo.HandlerFunc {
|
||||
}
|
||||
|
||||
// SetSchedulingRequest is the request body for creating/updating a scheduling config.
|
||||
//
|
||||
// The four prefix-cache fields are POINTERS so an omitted field is
|
||||
// distinguishable from an explicit zero. On update, an omitted prefix-cache
|
||||
// field preserves the model's previously-configured value instead of resetting
|
||||
// it (see SetSchedulingEndpoint's PATCH-style merge). ModelName, NodeSelector,
|
||||
// MinReplicas and MaxReplicas keep their full-replace PUT semantics.
|
||||
type SetSchedulingRequest struct {
|
||||
ModelName string `json:"model_name"`
|
||||
NodeSelector map[string]string `json:"node_selector,omitempty"`
|
||||
MinReplicas int `json:"min_replicas"`
|
||||
MaxReplicas int `json:"max_replicas"`
|
||||
ModelName string `json:"model_name"`
|
||||
NodeSelector map[string]string `json:"node_selector,omitempty"`
|
||||
MinReplicas int `json:"min_replicas"`
|
||||
MaxReplicas int `json:"max_replicas"`
|
||||
RoutePolicy *string `json:"route_policy,omitempty"`
|
||||
BalanceAbsThreshold *int `json:"balance_abs_threshold,omitempty"`
|
||||
BalanceRelThreshold *float64 `json:"balance_rel_threshold,omitempty"`
|
||||
MinPrefixMatch *float64 `json:"min_prefix_match,omitempty"`
|
||||
}
|
||||
|
||||
// validateSchedulingRequest enforces the invariants of a scheduling config.
|
||||
// The prefix-cache bounds are delegated to prefixcache.ValidateThresholds (the
|
||||
// single source of truth), and are checked against the RESOLVED values passed
|
||||
// in (provided-or-preserved), so validation only rejects bad values the caller
|
||||
// actually supplied. It returns nil when valid, or an error with a user-facing
|
||||
// message describing the first violation.
|
||||
func validateSchedulingRequest(req SetSchedulingRequest, routePolicy string, absThr int, relThr, minMatch float64) error {
|
||||
if req.ModelName == "" {
|
||||
return errors.New("model_name is required")
|
||||
}
|
||||
if req.MinReplicas < 0 {
|
||||
return errors.New("min_replicas must be >= 0")
|
||||
}
|
||||
if req.MaxReplicas < 0 {
|
||||
return errors.New("max_replicas must be >= 0")
|
||||
}
|
||||
if req.MaxReplicas > 0 && req.MinReplicas > req.MaxReplicas {
|
||||
return errors.New("min_replicas must be <= max_replicas")
|
||||
}
|
||||
if err := prefixcache.ValidateThresholds(routePolicy, absThr, relThr, minMatch); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// SetSchedulingEndpoint creates or updates a model scheduling config.
|
||||
//
|
||||
// The registry upsert full-replaces all columns, so a request that omits the
|
||||
// prefix-cache fields would otherwise wipe a model's previously-configured
|
||||
// routing settings. To avoid that footgun the four prefix-cache fields are
|
||||
// merged PATCH-style: a non-nil request pointer wins; a nil one preserves the
|
||||
// existing config's value (or the zero default when no config exists yet). The
|
||||
// non-prefix fields keep their full-replace PUT behavior.
|
||||
func SetSchedulingEndpoint(registry *nodes.NodeRegistry) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
ctx := c.Request().Context()
|
||||
@@ -926,17 +970,45 @@ func SetSchedulingEndpoint(registry *nodes.NodeRegistry) echo.HandlerFunc {
|
||||
if err := c.Bind(&req); err != nil {
|
||||
return c.JSON(http.StatusBadRequest, nodeError(http.StatusBadRequest, "invalid request body"))
|
||||
}
|
||||
if req.ModelName == "" {
|
||||
return c.JSON(http.StatusBadRequest, nodeError(http.StatusBadRequest, "model_name is required"))
|
||||
|
||||
// Fetch the existing config (may be nil) so omitted prefix-cache fields
|
||||
// can fall back to the stored value rather than resetting to zero.
|
||||
var existing *nodes.ModelSchedulingConfig
|
||||
if req.ModelName != "" {
|
||||
var err error
|
||||
existing, err = registry.GetModelScheduling(ctx, req.ModelName)
|
||||
if err != nil {
|
||||
return c.JSON(http.StatusInternalServerError, nodeError(http.StatusInternalServerError, "failed to load existing scheduling config"))
|
||||
}
|
||||
}
|
||||
if req.MinReplicas < 0 {
|
||||
return c.JSON(http.StatusBadRequest, nodeError(http.StatusBadRequest, "min_replicas must be >= 0"))
|
||||
|
||||
// Resolve each prefix-cache field: provided pointer wins, otherwise keep
|
||||
// the existing value (zero/default when there is no existing config).
|
||||
routePolicy := ""
|
||||
absThr := 0
|
||||
relThr := 0.0
|
||||
minMatch := 0.0
|
||||
if existing != nil {
|
||||
routePolicy = existing.RoutePolicy
|
||||
absThr = existing.BalanceAbsThreshold
|
||||
relThr = existing.BalanceRelThreshold
|
||||
minMatch = existing.MinPrefixMatch
|
||||
}
|
||||
if req.MaxReplicas < 0 {
|
||||
return c.JSON(http.StatusBadRequest, nodeError(http.StatusBadRequest, "max_replicas must be >= 0"))
|
||||
if req.RoutePolicy != nil {
|
||||
routePolicy = *req.RoutePolicy
|
||||
}
|
||||
if req.MaxReplicas > 0 && req.MinReplicas > req.MaxReplicas {
|
||||
return c.JSON(http.StatusBadRequest, nodeError(http.StatusBadRequest, "min_replicas must be <= max_replicas"))
|
||||
if req.BalanceAbsThreshold != nil {
|
||||
absThr = *req.BalanceAbsThreshold
|
||||
}
|
||||
if req.BalanceRelThreshold != nil {
|
||||
relThr = *req.BalanceRelThreshold
|
||||
}
|
||||
if req.MinPrefixMatch != nil {
|
||||
minMatch = *req.MinPrefixMatch
|
||||
}
|
||||
|
||||
if err := validateSchedulingRequest(req, routePolicy, absThr, relThr, minMatch); err != nil {
|
||||
return c.JSON(http.StatusBadRequest, nodeError(http.StatusBadRequest, err.Error()))
|
||||
}
|
||||
|
||||
// Serialize node selector to JSON
|
||||
@@ -950,10 +1022,14 @@ func SetSchedulingEndpoint(registry *nodes.NodeRegistry) echo.HandlerFunc {
|
||||
}
|
||||
|
||||
config := &nodes.ModelSchedulingConfig{
|
||||
ModelName: req.ModelName,
|
||||
NodeSelector: selectorJSON,
|
||||
MinReplicas: req.MinReplicas,
|
||||
MaxReplicas: req.MaxReplicas,
|
||||
ModelName: req.ModelName,
|
||||
NodeSelector: selectorJSON,
|
||||
MinReplicas: req.MinReplicas,
|
||||
MaxReplicas: req.MaxReplicas,
|
||||
RoutePolicy: routePolicy,
|
||||
BalanceAbsThreshold: absThr,
|
||||
BalanceRelThreshold: relThr,
|
||||
MinPrefixMatch: minMatch,
|
||||
}
|
||||
if err := registry.SetModelScheduling(ctx, config); err != nil {
|
||||
return c.JSON(http.StatusInternalServerError, nodeError(http.StatusInternalServerError, "failed to set scheduling config"))
|
||||
|
||||
@@ -0,0 +1,66 @@
|
||||
package localai
|
||||
|
||||
import (
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
var _ = Describe("validateSchedulingRequest", func() {
|
||||
base := func() SetSchedulingRequest {
|
||||
return SetSchedulingRequest{ModelName: "m"}
|
||||
}
|
||||
|
||||
It("accepts an empty route policy (inherit) with valid thresholds", func() {
|
||||
Expect(validateSchedulingRequest(base(), "", 3, 0, 0.4)).To(Succeed())
|
||||
})
|
||||
|
||||
It("accepts the prefix_cache policy", func() {
|
||||
Expect(validateSchedulingRequest(base(), "prefix_cache", 0, 0, 0)).To(Succeed())
|
||||
})
|
||||
|
||||
It("accepts the round_robin policy", func() {
|
||||
Expect(validateSchedulingRequest(base(), "round_robin", 0, 0, 0)).To(Succeed())
|
||||
})
|
||||
|
||||
It("accepts balance_rel_threshold >= 1", func() {
|
||||
Expect(validateSchedulingRequest(base(), "", 0, 1.5, 0)).To(Succeed())
|
||||
})
|
||||
|
||||
It("rejects a missing model_name", func() {
|
||||
req := base()
|
||||
req.ModelName = ""
|
||||
err := validateSchedulingRequest(req, "", 0, 0, 0)
|
||||
Expect(err).To(HaveOccurred())
|
||||
Expect(err.Error()).To(ContainSubstring("model_name is required"))
|
||||
})
|
||||
|
||||
It("rejects an unknown route_policy (no silent default)", func() {
|
||||
err := validateSchedulingRequest(base(), "bogus", 0, 0, 0)
|
||||
Expect(err).To(HaveOccurred())
|
||||
Expect(err.Error()).To(ContainSubstring("route_policy"))
|
||||
})
|
||||
|
||||
It("rejects min_prefix_match above 1", func() {
|
||||
err := validateSchedulingRequest(base(), "", 0, 0, 2)
|
||||
Expect(err).To(HaveOccurred())
|
||||
Expect(err.Error()).To(ContainSubstring("min_prefix_match"))
|
||||
})
|
||||
|
||||
It("rejects a negative min_prefix_match", func() {
|
||||
err := validateSchedulingRequest(base(), "", 0, 0, -0.1)
|
||||
Expect(err).To(HaveOccurred())
|
||||
Expect(err.Error()).To(ContainSubstring("min_prefix_match"))
|
||||
})
|
||||
|
||||
It("rejects a negative balance_abs_threshold", func() {
|
||||
err := validateSchedulingRequest(base(), "", -1, 0, 0)
|
||||
Expect(err).To(HaveOccurred())
|
||||
Expect(err.Error()).To(ContainSubstring("balance_abs_threshold"))
|
||||
})
|
||||
|
||||
It("rejects balance_rel_threshold between 0 and 1 exclusive", func() {
|
||||
err := validateSchedulingRequest(base(), "", 0, 0.5, 0)
|
||||
Expect(err).To(HaveOccurred())
|
||||
Expect(err.Error()).To(ContainSubstring("balance_rel_threshold"))
|
||||
})
|
||||
})
|
||||
@@ -230,6 +230,114 @@ var _ = Describe("Node HTTP handlers", func() {
|
||||
})
|
||||
})
|
||||
|
||||
Describe("SetSchedulingEndpoint", func() {
|
||||
postScheduling := func(body string) *httptest.ResponseRecorder {
|
||||
e := echo.New()
|
||||
req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(body))
|
||||
req.Header.Set(echo.HeaderContentType, echo.MIMEApplicationJSON)
|
||||
rec := httptest.NewRecorder()
|
||||
c := e.NewContext(req, rec)
|
||||
handler := SetSchedulingEndpoint(registry)
|
||||
Expect(handler(c)).To(Succeed())
|
||||
return rec
|
||||
}
|
||||
|
||||
It("persists prefix-cache fields and round-trips them via GET", func() {
|
||||
ctx := context.Background()
|
||||
rec := postScheduling(`{"model_name":"pc-model","route_policy":"prefix_cache","balance_abs_threshold":3,"min_prefix_match":0.4}`)
|
||||
Expect(rec.Code).To(Equal(http.StatusOK))
|
||||
|
||||
cfg, err := registry.GetModelScheduling(ctx, "pc-model")
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(cfg).ToNot(BeNil())
|
||||
Expect(cfg.RoutePolicy).To(Equal("prefix_cache"))
|
||||
Expect(cfg.BalanceAbsThreshold).To(Equal(3))
|
||||
Expect(cfg.MinPrefixMatch).To(BeNumerically("~", 0.4, 1e-9))
|
||||
|
||||
e := echo.New()
|
||||
getReq := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
getRec := httptest.NewRecorder()
|
||||
gc := e.NewContext(getReq, getRec)
|
||||
gc.SetParamNames("model")
|
||||
gc.SetParamValues("pc-model")
|
||||
Expect(GetSchedulingEndpoint(registry)(gc)).To(Succeed())
|
||||
Expect(getRec.Code).To(Equal(http.StatusOK))
|
||||
|
||||
var got nodes.ModelSchedulingConfig
|
||||
Expect(json.Unmarshal(getRec.Body.Bytes(), &got)).To(Succeed())
|
||||
Expect(got.RoutePolicy).To(Equal("prefix_cache"))
|
||||
Expect(got.BalanceAbsThreshold).To(Equal(3))
|
||||
Expect(got.MinPrefixMatch).To(BeNumerically("~", 0.4, 1e-9))
|
||||
})
|
||||
|
||||
It("returns 400 for an out-of-range min_prefix_match", func() {
|
||||
rec := postScheduling(`{"model_name":"bad-mpm","min_prefix_match":2}`)
|
||||
Expect(rec.Code).To(Equal(http.StatusBadRequest))
|
||||
var resp map[string]any
|
||||
Expect(json.Unmarshal(rec.Body.Bytes(), &resp)).To(Succeed())
|
||||
errObj, ok := resp["error"].(map[string]any)
|
||||
Expect(ok).To(BeTrue())
|
||||
Expect(errObj["message"]).To(ContainSubstring("min_prefix_match"))
|
||||
})
|
||||
|
||||
It("returns 400 for an unknown route_policy", func() {
|
||||
rec := postScheduling(`{"model_name":"bad-policy","route_policy":"bogus"}`)
|
||||
Expect(rec.Code).To(Equal(http.StatusBadRequest))
|
||||
var resp map[string]any
|
||||
Expect(json.Unmarshal(rec.Body.Bytes(), &resp)).To(Succeed())
|
||||
errObj, ok := resp["error"].(map[string]any)
|
||||
Expect(ok).To(BeTrue())
|
||||
Expect(errObj["message"]).To(ContainSubstring("route_policy"))
|
||||
})
|
||||
|
||||
It("returns 400 for a balance_rel_threshold between 0 and 1", func() {
|
||||
rec := postScheduling(`{"model_name":"bad-rel","balance_rel_threshold":0.5}`)
|
||||
Expect(rec.Code).To(Equal(http.StatusBadRequest))
|
||||
var resp map[string]any
|
||||
Expect(json.Unmarshal(rec.Body.Bytes(), &resp)).To(Succeed())
|
||||
errObj, ok := resp["error"].(map[string]any)
|
||||
Expect(ok).To(BeTrue())
|
||||
Expect(errObj["message"]).To(ContainSubstring("balance_rel_threshold"))
|
||||
})
|
||||
|
||||
// Regression for the partial-update footgun: a min/max-only POST used to
|
||||
// full-replace every column and silently reset the prefix-cache settings
|
||||
// to empty/zero. The pointer-merge must preserve omitted prefix fields.
|
||||
It("preserves prefix-cache settings across a min_replicas-only update", func() {
|
||||
ctx := context.Background()
|
||||
|
||||
rec := postScheduling(`{"model_name":"merge-model","route_policy":"prefix_cache","min_prefix_match":0.4}`)
|
||||
Expect(rec.Code).To(Equal(http.StatusOK))
|
||||
|
||||
// Update only min_replicas - omits all prefix-cache fields.
|
||||
rec = postScheduling(`{"model_name":"merge-model","min_replicas":2}`)
|
||||
Expect(rec.Code).To(Equal(http.StatusOK))
|
||||
|
||||
cfg, err := registry.GetModelScheduling(ctx, "merge-model")
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(cfg).ToNot(BeNil())
|
||||
Expect(cfg.MinReplicas).To(Equal(2), "the provided non-prefix field must update")
|
||||
Expect(cfg.RoutePolicy).To(Equal("prefix_cache"), "omitted route_policy must be preserved")
|
||||
Expect(cfg.MinPrefixMatch).To(BeNumerically("~", 0.4, 1e-9), "omitted min_prefix_match must be preserved")
|
||||
})
|
||||
|
||||
It("updates a prefix-cache field when it is explicitly provided", func() {
|
||||
ctx := context.Background()
|
||||
|
||||
rec := postScheduling(`{"model_name":"update-model","route_policy":"prefix_cache","min_prefix_match":0.4}`)
|
||||
Expect(rec.Code).To(Equal(http.StatusOK))
|
||||
|
||||
rec = postScheduling(`{"model_name":"update-model","route_policy":"round_robin"}`)
|
||||
Expect(rec.Code).To(Equal(http.StatusOK))
|
||||
|
||||
cfg, err := registry.GetModelScheduling(ctx, "update-model")
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(cfg).ToNot(BeNil())
|
||||
Expect(cfg.RoutePolicy).To(Equal("round_robin"), "explicitly provided route_policy must update")
|
||||
Expect(cfg.MinPrefixMatch).To(BeNumerically("~", 0.4, 1e-9), "omitted min_prefix_match must still be preserved")
|
||||
})
|
||||
})
|
||||
|
||||
Describe("ListNodesEndpoint", func() {
|
||||
It("returns an empty list when no nodes are registered", func() {
|
||||
e := echo.New()
|
||||
|
||||
@@ -493,6 +493,13 @@ function SchedulingForm({ onSave, onCancel }) {
|
||||
const [selector, setSelector] = useState({})
|
||||
const [minReplicas, setMinReplicas] = useState(1)
|
||||
const [maxReplicas, setMaxReplicas] = useState(0)
|
||||
// Prefix-cache routing controls. Empty routePolicy means "inherit the
|
||||
// cluster default"; the three thresholds at 0 likewise inherit, so they
|
||||
// stay out of the POST body's effective override only when explicitly set.
|
||||
const [routePolicy, setRoutePolicy] = useState('')
|
||||
const [balanceAbsThreshold, setBalanceAbsThreshold] = useState(0)
|
||||
const [balanceRelThreshold, setBalanceRelThreshold] = useState(0)
|
||||
const [minPrefixMatch, setMinPrefixMatch] = useState(0)
|
||||
|
||||
const hasSelector = Object.keys(selector).length > 0
|
||||
|
||||
@@ -508,6 +515,10 @@ function SchedulingForm({ onSave, onCancel }) {
|
||||
node_selector: hasSelector ? selector : undefined,
|
||||
min_replicas: mode === 'placement' ? 0 : minReplicas,
|
||||
max_replicas: mode === 'placement' ? 0 : maxReplicas,
|
||||
route_policy: routePolicy,
|
||||
balance_abs_threshold: balanceAbsThreshold,
|
||||
balance_rel_threshold: balanceRelThreshold,
|
||||
min_prefix_match: minPrefixMatch,
|
||||
})
|
||||
}
|
||||
|
||||
@@ -593,6 +604,76 @@ function SchedulingForm({ onSave, onCancel }) {
|
||||
/>
|
||||
</div>
|
||||
)}
|
||||
|
||||
{/* Per-model routing policy. Left empty/zero these inherit the
|
||||
cluster-wide defaults; set them to override how requests for this
|
||||
model are spread across replicas. */}
|
||||
<div>
|
||||
<label className="form-label" htmlFor="sched-route-policy">Routing policy</label>
|
||||
<select
|
||||
id="sched-route-policy"
|
||||
className="input"
|
||||
value={routePolicy}
|
||||
onChange={e => setRoutePolicy(e.target.value)}
|
||||
>
|
||||
<option value="">Default (cluster setting)</option>
|
||||
<option value="round_robin">Round Robin</option>
|
||||
<option value="prefix_cache">Prefix Cache</option>
|
||||
</select>
|
||||
<span style={{ fontSize: '0.75rem', color: 'var(--color-text-muted)', display: 'block', marginTop: 6 }}>
|
||||
Prefix Cache routes shared-prefix requests to the same replica to reuse its KV cache, falling back to round-robin when replicas are imbalanced.
|
||||
</span>
|
||||
</div>
|
||||
|
||||
{routePolicy === 'prefix_cache' && (
|
||||
<div style={{ display: 'flex', gap: 'var(--spacing-md)' }}>
|
||||
<div style={{ flex: 1 }}>
|
||||
<label className="form-label" htmlFor="sched-min-prefix-match">Min prefix match</label>
|
||||
<input
|
||||
id="sched-min-prefix-match"
|
||||
className="input"
|
||||
type="number"
|
||||
step="0.05"
|
||||
min="0"
|
||||
max="1"
|
||||
value={minPrefixMatch}
|
||||
onChange={e => setMinPrefixMatch(parseFloat(e.target.value) || 0)}
|
||||
/>
|
||||
<span style={{ fontSize: '0.75rem', color: 'var(--color-text-muted)', display: 'block', marginTop: 6 }}>
|
||||
Fraction of the prompt (0..1) that must match a cached prefix before affinity kicks in. 0 inherits the default.
|
||||
</span>
|
||||
</div>
|
||||
<div style={{ flex: 1 }}>
|
||||
<label className="form-label" htmlFor="sched-balance-abs">Balance abs threshold</label>
|
||||
<input
|
||||
id="sched-balance-abs"
|
||||
className="input"
|
||||
type="number"
|
||||
min="0"
|
||||
value={balanceAbsThreshold}
|
||||
onChange={e => setBalanceAbsThreshold(parseInt(e.target.value) || 0)}
|
||||
/>
|
||||
<span style={{ fontSize: '0.75rem', color: 'var(--color-text-muted)', display: 'block', marginTop: 6 }}>
|
||||
Max absolute in-flight gap allowed before falling back to round-robin. 0 inherits the default.
|
||||
</span>
|
||||
</div>
|
||||
<div style={{ flex: 1 }}>
|
||||
<label className="form-label" htmlFor="sched-balance-rel">Balance rel threshold</label>
|
||||
<input
|
||||
id="sched-balance-rel"
|
||||
className="input"
|
||||
type="number"
|
||||
step="0.1"
|
||||
min="0"
|
||||
value={balanceRelThreshold}
|
||||
onChange={e => setBalanceRelThreshold(parseFloat(e.target.value) || 0)}
|
||||
/>
|
||||
<span style={{ fontSize: '0.75rem', color: 'var(--color-text-muted)', display: 'block', marginTop: 6 }}>
|
||||
Max relative in-flight ratio (>= 1) allowed before falling back to round-robin. 0 inherits the default.
|
||||
</span>
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
|
||||
{/* Hairline divider above the actions, matching the project's form pattern. */}
|
||||
@@ -1475,6 +1556,8 @@ export default function Nodes() {
|
||||
<th>Node Selector</th>
|
||||
<th>Min Replicas</th>
|
||||
<th>Max Replicas</th>
|
||||
<th>Routing</th>
|
||||
<th>Thresholds</th>
|
||||
<th>Status</th>
|
||||
<th style={{ textAlign: 'right' }}>Actions</th>
|
||||
</tr></thead>
|
||||
@@ -1519,6 +1602,18 @@ export default function Nodes() {
|
||||
<td style={{ fontFamily: 'var(--font-mono)' }}>
|
||||
{isAutoScaling ? (cfg.max_replicas || 'no limit') : '-'}
|
||||
</td>
|
||||
<td style={{ fontSize: '0.8125rem' }}>
|
||||
{cfg.route_policy || 'default'}
|
||||
</td>
|
||||
<td style={{ fontFamily: 'var(--font-mono)', fontSize: '0.75rem', color: 'var(--color-text-muted)' }}>
|
||||
{cfg.route_policy === 'prefix_cache' ? (
|
||||
<>
|
||||
<div>match: {cfg.min_prefix_match ? cfg.min_prefix_match : 'inherit'}</div>
|
||||
<div>abs: {cfg.balance_abs_threshold ? cfg.balance_abs_threshold : 'inherit'}</div>
|
||||
<div>rel: {cfg.balance_rel_threshold ? cfg.balance_rel_threshold : 'inherit'}</div>
|
||||
</>
|
||||
) : '-'}
|
||||
</td>
|
||||
<td>
|
||||
{isUnsatisfiable ? (
|
||||
<span
|
||||
|
||||
@@ -355,3 +355,35 @@ type CacheInvalidateEvent struct {
|
||||
func SubjectCacheInvalidateCollection(name string) string {
|
||||
return "cache.invalidate.collections." + sanitizeSubjectToken(name)
|
||||
}
|
||||
|
||||
// Prefix-Cache Routing Sync (Pub/Sub - broadcast to all frontends)
|
||||
//
|
||||
// Frontends share prefix-cache observations so a request routed to any replica
|
||||
// benefits from the prefix-affinity another replica already learned. This
|
||||
// mirrors the OpCache live-sync pattern: plain NATS Core pub/sub, no JetStream.
|
||||
const (
|
||||
SubjectPrefixCacheObserve = "prefixcache.observe"
|
||||
SubjectPrefixCacheInvalidate = "prefixcache.invalidate"
|
||||
)
|
||||
|
||||
// PrefixCacheObserveEvent announces that the replica (NodeID, Replica) served a
|
||||
// request whose prefix chain ends at the given hashes for model. Chain is the
|
||||
// full shallow-to-deep hash chain so peers can insert the same path. Affinity is
|
||||
// per replica (a backend process with its own KV cache), not per node, so the
|
||||
// replica index is carried so peers attribute the observation to the same one.
|
||||
type PrefixCacheObserveEvent struct {
|
||||
Model string `json:"model"`
|
||||
Chain []uint64 `json:"chain"`
|
||||
NodeID string `json:"node_id"`
|
||||
Replica int `json:"replica"`
|
||||
}
|
||||
|
||||
// PrefixCacheInvalidateEvent tells peers to drop entries for a replica. When
|
||||
// Replica >= 0 it targets the single replica (Model, NodeID, Replica). When
|
||||
// Replica < 0 it targets ALL replicas of (Model, NodeID), for example when a
|
||||
// whole node goes offline.
|
||||
type PrefixCacheInvalidateEvent struct {
|
||||
Model string `json:"model"`
|
||||
NodeID string `json:"node_id"`
|
||||
Replica int `json:"replica"`
|
||||
}
|
||||
|
||||
27
core/services/messaging/subjects_prefixcache_test.go
Normal file
27
core/services/messaging/subjects_prefixcache_test.go
Normal file
@@ -0,0 +1,27 @@
|
||||
package messaging_test
|
||||
|
||||
import (
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
|
||||
"github.com/mudler/LocalAI/core/services/messaging"
|
||||
)
|
||||
|
||||
var _ = Describe("PrefixCache subjects", func() {
|
||||
It("exposes stable subject constants", func() {
|
||||
Expect(messaging.SubjectPrefixCacheObserve).To(Equal("prefixcache.observe"))
|
||||
Expect(messaging.SubjectPrefixCacheInvalidate).To(Equal("prefixcache.invalidate"))
|
||||
})
|
||||
|
||||
It("carries a replica index on the observe event", func() {
|
||||
ev := messaging.PrefixCacheObserveEvent{Model: "m", Chain: []uint64{1, 2}, NodeID: "A", Replica: 3}
|
||||
Expect(ev.Replica).To(Equal(3))
|
||||
})
|
||||
|
||||
It("uses a negative replica on the invalidate event to mean all replicas of a node", func() {
|
||||
all := messaging.PrefixCacheInvalidateEvent{Model: "m", NodeID: "A", Replica: -1}
|
||||
Expect(all.Replica).To(BeNumerically("<", 0))
|
||||
one := messaging.PrefixCacheInvalidateEvent{Model: "m", NodeID: "A", Replica: 0}
|
||||
Expect(one.Replica).To(Equal(0))
|
||||
})
|
||||
})
|
||||
@@ -9,7 +9,7 @@ import (
|
||||
|
||||
// ModelRouter is used by SmartRouter for routing decisions and model lifecycle.
|
||||
type ModelRouter interface {
|
||||
FindAndLockNodeWithModel(ctx context.Context, modelName string, candidateNodeIDs []string) (*BackendNode, *NodeModel, error)
|
||||
FindAndLockNodeWithModel(ctx context.Context, modelName string, candidateNodeIDs []string, pref *RoutePreference) (*BackendNode, *NodeModel, error)
|
||||
DecrementInFlight(ctx context.Context, nodeID, modelName string, replicaIndex int) error
|
||||
IncrementInFlight(ctx context.Context, nodeID, modelName string, replicaIndex int) error
|
||||
RemoveNodeModel(ctx context.Context, nodeID, modelName string, replicaIndex int) error
|
||||
@@ -37,6 +37,7 @@ type ModelRouter interface {
|
||||
FindLeastLoadedNodeFromSet(ctx context.Context, nodeIDs []string) (*BackendNode, error)
|
||||
GetNodeLabels(ctx context.Context, nodeID string) ([]NodeLabel, error)
|
||||
FindNodesWithModel(ctx context.Context, modelName string) ([]BackendNode, error)
|
||||
LoadedReplicaStats(ctx context.Context, modelName string, candidateNodeIDs []string) ([]ReplicaCandidate, error)
|
||||
}
|
||||
|
||||
// ConcurrencyConflictResolver returns the names of configured models that
|
||||
|
||||
@@ -27,7 +27,7 @@ func newFakeModelRouterForSmartRouter() *fakeModelRouterForSmartRouter {
|
||||
}
|
||||
}
|
||||
|
||||
func (f *fakeModelRouterForSmartRouter) FindAndLockNodeWithModel(_ context.Context, _ string, _ []string) (*BackendNode, *NodeModel, error) {
|
||||
func (f *fakeModelRouterForSmartRouter) FindAndLockNodeWithModel(_ context.Context, _ string, _ []string, _ *RoutePreference) (*BackendNode, *NodeModel, error) {
|
||||
f.mu.Lock()
|
||||
defer f.mu.Unlock()
|
||||
return f.node, f.nodeModel, f.findErr
|
||||
@@ -121,6 +121,9 @@ func (f *fakeModelRouterForSmartRouter) GetNodeLabels(_ context.Context, _ strin
|
||||
func (f *fakeModelRouterForSmartRouter) FindNodesWithModel(_ context.Context, _ string) ([]BackendNode, error) {
|
||||
return nil, nil
|
||||
}
|
||||
func (f *fakeModelRouterForSmartRouter) LoadedReplicaStats(_ context.Context, _ string, _ []string) ([]ReplicaCandidate, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// Compile-time check
|
||||
var _ ModelRouter = (*fakeModelRouterForSmartRouter)(nil)
|
||||
|
||||
95
core/services/nodes/prefixcache/config.go
Normal file
95
core/services/nodes/prefixcache/config.go
Normal file
@@ -0,0 +1,95 @@
|
||||
package prefixcache
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Config holds prefix-cache-aware routing settings. Per-model overrides
|
||||
// (policy, abs/rel thresholds, min-match) live on ModelSchedulingConfig; TTL
|
||||
// and window/depth are global-only.
|
||||
type Config struct {
|
||||
GlobalPolicy RoutePolicy
|
||||
MinPrefixMatch float64 // ratio matched/total, [0,1]
|
||||
BalanceAbsThreshold int // absolute in-flight slack
|
||||
BalanceRelThreshold float64 // relative load ratio, >= 1
|
||||
TTL time.Duration // idle-timeout for entries
|
||||
HalfLife time.Duration // recency decay for cacheWeight
|
||||
WindowBytes int // chunk window size
|
||||
MaxDepth int // max trailing blocks hashed
|
||||
// PressureWindow is the rolling window over which forced-disturb events are
|
||||
// counted for the autoscale signal (see Pressure). Default 1 minute.
|
||||
PressureWindow time.Duration
|
||||
// PressureScaleThreshold is the minimum forced-disturb count within
|
||||
// PressureWindow that makes the reconciler treat the cache-warm replica as
|
||||
// saturated and scale up (subject to MaxReplicas and capacity). Default 1,
|
||||
// i.e. any sustained forced-disturb.
|
||||
PressureScaleThreshold int
|
||||
}
|
||||
|
||||
func DefaultConfig() Config {
|
||||
return Config{
|
||||
GlobalPolicy: RoutePolicyPrefixCache,
|
||||
MinPrefixMatch: 0.3,
|
||||
BalanceAbsThreshold: 2,
|
||||
BalanceRelThreshold: 1.5,
|
||||
TTL: 5 * time.Minute,
|
||||
HalfLife: 2 * time.Minute,
|
||||
WindowBytes: 256,
|
||||
MaxDepth: 64,
|
||||
PressureWindow: time.Minute,
|
||||
PressureScaleThreshold: 1,
|
||||
}
|
||||
}
|
||||
|
||||
// validateThresholdBounds enforces the numeric bounds shared between the
|
||||
// per-model override validator (ValidateThresholds) and Config.Validate:
|
||||
// minMatch in [0,1]; absThr >= 0; relThr == 0 (inherit) or >= 1. It is the
|
||||
// single source of truth for those bounds so the endpoint and the global
|
||||
// config cannot drift apart.
|
||||
func validateThresholdBounds(absThr int, relThr, minMatch float64) error {
|
||||
if minMatch < 0 || minMatch > 1 {
|
||||
return fmt.Errorf("prefixcache: min_prefix_match must be in [0,1], got %v", minMatch)
|
||||
}
|
||||
if absThr < 0 {
|
||||
return fmt.Errorf("prefixcache: balance_abs_threshold must be >= 0, got %d", absThr)
|
||||
}
|
||||
if relThr != 0 && relThr < 1 {
|
||||
return fmt.Errorf("prefixcache: balance_rel_threshold must be 0 (inherit) or >= 1, got %v", relThr)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// ValidateThresholds checks per-model override bounds. routePolicy must be one
|
||||
// of "", "round_robin", "prefix_cache" (explicit allow-list - NOT ParsePolicy,
|
||||
// which maps unknown to Default and would accept typos). minMatch in [0,1];
|
||||
// absThr >= 0; relThr == 0 (inherit) or >= 1.
|
||||
func ValidateThresholds(routePolicy string, absThr int, relThr, minMatch float64) error {
|
||||
switch routePolicy {
|
||||
case "", "round_robin", "prefix_cache":
|
||||
default:
|
||||
return fmt.Errorf(`prefixcache: route_policy must be one of "", "round_robin", "prefix_cache", got %q`, routePolicy)
|
||||
}
|
||||
return validateThresholdBounds(absThr, relThr, minMatch)
|
||||
}
|
||||
|
||||
func (c Config) Validate() error {
|
||||
// Config.BalanceRelThreshold has no "inherit" sentinel - it is a concrete
|
||||
// global value that must be >= 1 - so pass 0 for relThr to the shared
|
||||
// numeric check and assert the >= 1 floor here separately.
|
||||
if err := validateThresholdBounds(c.BalanceAbsThreshold, 0, c.MinPrefixMatch); err != nil {
|
||||
return err
|
||||
}
|
||||
if c.BalanceRelThreshold < 1 {
|
||||
return fmt.Errorf("prefixcache: balance_rel_threshold must be >= 1, got %v", c.BalanceRelThreshold)
|
||||
}
|
||||
if c.WindowBytes <= 0 || c.MaxDepth <= 0 {
|
||||
return fmt.Errorf("prefixcache: window_bytes and max_depth must be > 0")
|
||||
}
|
||||
// TTL must be positive: it is the entry idle-lifetime and the eviction
|
||||
// ticker runs at TTL/2, so time.NewTicker would panic on TTL <= 0.
|
||||
if c.TTL <= 0 {
|
||||
return fmt.Errorf("prefixcache: ttl must be > 0, got %v", c.TTL)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
73
core/services/nodes/prefixcache/config_test.go
Normal file
73
core/services/nodes/prefixcache/config_test.go
Normal file
@@ -0,0 +1,73 @@
|
||||
package prefixcache_test
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
|
||||
"github.com/mudler/LocalAI/core/services/nodes/prefixcache"
|
||||
)
|
||||
|
||||
var _ = Describe("Config", func() {
|
||||
It("supplies defaults", func() {
|
||||
c := prefixcache.DefaultConfig()
|
||||
Expect(c.GlobalPolicy).To(Equal(prefixcache.RoutePolicyPrefixCache)) // default ON
|
||||
Expect(c.MinPrefixMatch).To(BeNumerically("==", 0.3))
|
||||
Expect(c.BalanceAbsThreshold).To(Equal(2))
|
||||
Expect(c.BalanceRelThreshold).To(BeNumerically("==", 1.5))
|
||||
Expect(c.TTL).To(Equal(5 * time.Minute))
|
||||
Expect(c.WindowBytes).To(Equal(256))
|
||||
Expect(c.MaxDepth).To(Equal(64))
|
||||
})
|
||||
|
||||
It("rejects invalid values", func() {
|
||||
c := prefixcache.DefaultConfig()
|
||||
c.MinPrefixMatch = 1.5
|
||||
Expect(c.Validate()).To(HaveOccurred())
|
||||
c = prefixcache.DefaultConfig()
|
||||
c.BalanceAbsThreshold = -1
|
||||
Expect(c.Validate()).To(HaveOccurred())
|
||||
c = prefixcache.DefaultConfig()
|
||||
c.TTL = 0
|
||||
Expect(c.Validate()).To(HaveOccurred()) // TTL/2 ticker would panic
|
||||
})
|
||||
})
|
||||
|
||||
var _ = Describe("ValidateThresholds", func() {
|
||||
It("accepts valid values across all route policies", func() {
|
||||
Expect(prefixcache.ValidateThresholds("", 3, 0, 0.4)).To(Succeed())
|
||||
Expect(prefixcache.ValidateThresholds("round_robin", 0, 1.5, 0)).To(Succeed())
|
||||
Expect(prefixcache.ValidateThresholds("prefix_cache", 2, 2.0, 1.0)).To(Succeed())
|
||||
})
|
||||
|
||||
It("rejects an unknown route_policy (explicit allow-list, no silent default)", func() {
|
||||
err := prefixcache.ValidateThresholds("bogus", 0, 0, 0)
|
||||
Expect(err).To(HaveOccurred())
|
||||
Expect(err.Error()).To(ContainSubstring("route_policy"))
|
||||
})
|
||||
|
||||
It("rejects min_prefix_match above 1", func() {
|
||||
err := prefixcache.ValidateThresholds("", 0, 0, 1.5)
|
||||
Expect(err).To(HaveOccurred())
|
||||
Expect(err.Error()).To(ContainSubstring("min_prefix_match"))
|
||||
})
|
||||
|
||||
It("rejects a negative min_prefix_match", func() {
|
||||
err := prefixcache.ValidateThresholds("", 0, 0, -0.1)
|
||||
Expect(err).To(HaveOccurred())
|
||||
Expect(err.Error()).To(ContainSubstring("min_prefix_match"))
|
||||
})
|
||||
|
||||
It("rejects a negative balance_abs_threshold", func() {
|
||||
err := prefixcache.ValidateThresholds("", -1, 0, 0)
|
||||
Expect(err).To(HaveOccurred())
|
||||
Expect(err.Error()).To(ContainSubstring("balance_abs_threshold"))
|
||||
})
|
||||
|
||||
It("rejects balance_rel_threshold between 0 and 1 exclusive", func() {
|
||||
err := prefixcache.ValidateThresholds("", 0, 0.5, 0)
|
||||
Expect(err).To(HaveOccurred())
|
||||
Expect(err.Error()).To(ContainSubstring("balance_rel_threshold"))
|
||||
})
|
||||
})
|
||||
18
core/services/nodes/prefixcache/export_test.go
Normal file
18
core/services/nodes/prefixcache/export_test.go
Normal file
@@ -0,0 +1,18 @@
|
||||
package prefixcache
|
||||
|
||||
// LenForTest exposes the internal per-model slice length so black-box tests can
|
||||
// assert that Record bounds its backing slice. Test-only.
|
||||
func (p *Pressure) LenForTest(model string) int {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
return len(p.events[model])
|
||||
}
|
||||
|
||||
// TreeCountForTest exposes the number of per-model radix trees the Index
|
||||
// currently retains, so black-box tests can assert that Invalidate does not
|
||||
// intern empty trees for models that never used the prefix cache. Test-only.
|
||||
func (ix *Index) TreeCountForTest() int {
|
||||
ix.mu.RLock()
|
||||
defer ix.mu.RUnlock()
|
||||
return len(ix.trees)
|
||||
}
|
||||
57
core/services/nodes/prefixcache/extractor.go
Normal file
57
core/services/nodes/prefixcache/extractor.go
Normal file
@@ -0,0 +1,57 @@
|
||||
package prefixcache
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
|
||||
"github.com/cespare/xxhash/v2"
|
||||
)
|
||||
|
||||
// ExtractChain renders prompt into a cumulative chain of prefix hashes:
|
||||
// h[0]=H(salt,block0), h[i]=H(h[i-1],block_i). Blocks are fixed
|
||||
// cfg.WindowBytes-byte windows over the prompt bytes, chunked from absolute
|
||||
// offset 0 with fixed boundaries [0,W), [W,2W), ... and the chain is capped to
|
||||
// the FIRST cfg.MaxDepth blocks (the head).
|
||||
//
|
||||
// Head-first chunking is what makes this a true prefix-chain. The reusable
|
||||
// KV/prefix cache is always at the HEAD of the prompt: the system prompt and
|
||||
// early turns are stable, new content is appended at the end, and the KV cache
|
||||
// is valid up to the first differing token scanning from the start. Because the
|
||||
// boundaries are anchored at offset 0 (never length-dependent), a prompt P and
|
||||
// any extension P+suffix share their entire leading overlap, so turn N and turn
|
||||
// N+1 match for longest-prefix routing. Prefixes deeper than
|
||||
// MaxDepth*WindowBytes bytes are treated as equal (two prompts agreeing on the
|
||||
// first MaxDepth head blocks yield identical chains): an accepted routing-hint
|
||||
// limitation, since the cap bounds the chain length for very long prompts.
|
||||
//
|
||||
// xxhash is used (not hash/maphash) because the hash MUST be identical across
|
||||
// frontend processes: peers exchange these hashes over NATS, and maphash uses a
|
||||
// per-process random seed that would make peers disagree.
|
||||
func ExtractChain(model, prompt string, cfg Config) []uint64 {
|
||||
if prompt == "" {
|
||||
return nil
|
||||
}
|
||||
data := []byte(prompt)
|
||||
nBlocks := (len(data) + cfg.WindowBytes - 1) / cfg.WindowBytes
|
||||
depth := min(nBlocks, cfg.MaxDepth)
|
||||
salt := xxhash.Sum64String(model)
|
||||
// One Digest reused across blocks: Reset() restores the seed-0 initial
|
||||
// state, so Reset()+Write produces the byte-identical value to a fresh
|
||||
// New()+Write. xxhash seed 0 is stateless, so output is unchanged while we
|
||||
// avoid allocating a Digest per block. The output determinism across
|
||||
// processes (peers exchange these hashes over NATS) is preserved.
|
||||
h := xxhash.New()
|
||||
chain := make([]uint64, 0, depth)
|
||||
prev := salt
|
||||
var pb [8]byte
|
||||
for i := range depth {
|
||||
off := i * cfg.WindowBytes
|
||||
end := min(off+cfg.WindowBytes, len(data))
|
||||
h.Reset()
|
||||
binary.LittleEndian.PutUint64(pb[:], prev)
|
||||
_, _ = h.Write(pb[:])
|
||||
_, _ = h.Write(data[off:end])
|
||||
prev = h.Sum64()
|
||||
chain = append(chain, prev)
|
||||
}
|
||||
return chain
|
||||
}
|
||||
75
core/services/nodes/prefixcache/extractor_test.go
Normal file
75
core/services/nodes/prefixcache/extractor_test.go
Normal file
@@ -0,0 +1,75 @@
|
||||
package prefixcache_test
|
||||
|
||||
import (
|
||||
"strings"
|
||||
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
|
||||
"github.com/mudler/LocalAI/core/services/nodes/prefixcache"
|
||||
)
|
||||
|
||||
var _ = Describe("Extractor", func() {
|
||||
cfg := prefixcache.DefaultConfig()
|
||||
|
||||
It("produces a deterministic chain for the same prompt and model", func() {
|
||||
a := prefixcache.ExtractChain("modelX", "hello world", cfg)
|
||||
b := prefixcache.ExtractChain("modelX", "hello world", cfg)
|
||||
Expect(a).To(Equal(b))
|
||||
Expect(len(a)).To(BeNumerically(">", 0))
|
||||
})
|
||||
|
||||
It("shares the head but diverges on a volatile tail", func() {
|
||||
base := strings.Repeat("system rules ", 100) // > one window
|
||||
x := prefixcache.ExtractChain("m", base+"Current time 12:00:00", cfg)
|
||||
y := prefixcache.ExtractChain("m", base+"Current time 12:00:01", cfg)
|
||||
// leading hashes (the stable head) are identical
|
||||
Expect(x[0]).To(Equal(y[0]))
|
||||
// the final (tail) hash differs
|
||||
Expect(x[len(x)-1]).NotTo(Equal(y[len(y)-1]))
|
||||
})
|
||||
|
||||
It("salts by model so identical text yields different chains per model", func() {
|
||||
Expect(prefixcache.ExtractChain("m1", "abc", cfg)[0]).
|
||||
NotTo(Equal(prefixcache.ExtractChain("m2", "abc", cfg)[0]))
|
||||
})
|
||||
|
||||
It("caps depth", func() {
|
||||
small := cfg
|
||||
small.WindowBytes = 1
|
||||
small.MaxDepth = 4
|
||||
chain := prefixcache.ExtractChain("m", "abcdefghij", small)
|
||||
Expect(len(chain)).To(Equal(4))
|
||||
})
|
||||
|
||||
It("returns nil for empty prompt", func() {
|
||||
Expect(prefixcache.ExtractChain("m", "", cfg)).To(BeNil())
|
||||
})
|
||||
|
||||
It("stays stable across turns once the prompt grows past the depth cap", func() {
|
||||
small := cfg
|
||||
small.WindowBytes = 4
|
||||
small.MaxDepth = 3 // 12-byte head budget
|
||||
|
||||
// base is longer than MaxDepth*WindowBytes so the chain is capped to
|
||||
// the first 3 head blocks.
|
||||
base := "system-rules-stable-prefix-that-exceeds-the-budget"
|
||||
Expect(len(base)).To(BeNumerically(">", small.WindowBytes*small.MaxDepth))
|
||||
|
||||
turnN := prefixcache.ExtractChain("m", base, small)
|
||||
turnN1 := prefixcache.ExtractChain("m", base+"more text appended", small)
|
||||
// Both capped to the same first MaxDepth head blocks -> identical chains.
|
||||
Expect(turnN).To(HaveLen(small.MaxDepth))
|
||||
Expect(turnN1).To(HaveLen(small.MaxDepth))
|
||||
Expect(turnN1).To(Equal(turnN))
|
||||
|
||||
// A prompt diverging WITHIN the budget shares the leading hashes up to
|
||||
// the divergence block and differs after. "system-r" matches base for
|
||||
// the first two 4-byte blocks ("syst","em-r"), then block 2 differs.
|
||||
divergent := prefixcache.ExtractChain("m", "system-rDIFFERENT-tail", small)
|
||||
Expect(divergent).To(HaveLen(small.MaxDepth))
|
||||
Expect(divergent[0]).To(Equal(turnN[0]))
|
||||
Expect(divergent[1]).To(Equal(turnN[1]))
|
||||
Expect(divergent[2]).NotTo(Equal(turnN[2]))
|
||||
})
|
||||
})
|
||||
129
core/services/nodes/prefixcache/index.go
Normal file
129
core/services/nodes/prefixcache/index.go
Normal file
@@ -0,0 +1,129 @@
|
||||
package prefixcache
|
||||
|
||||
import (
|
||||
"sort"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/mudler/LocalAI/pkg/radixtree"
|
||||
)
|
||||
|
||||
// Index is the guessed (routing-history) Provider backed by per-model radix
|
||||
// trees keyed by ReplicaKey. Affinity is per replica, so the same prefix served
|
||||
// by two replicas of one node resolves back to the exact replica that served it.
|
||||
// Safe for concurrent use.
|
||||
type Index struct {
|
||||
cfg Config
|
||||
mu sync.RWMutex
|
||||
trees map[string]*radixtree.Tree[ReplicaKey]
|
||||
}
|
||||
|
||||
func NewIndex(cfg Config) *Index {
|
||||
return &Index{cfg: cfg, trees: map[string]*radixtree.Tree[ReplicaKey]{}}
|
||||
}
|
||||
|
||||
// existingTree returns the tree for model without creating one. The bool
|
||||
// reports whether a tree already existed.
|
||||
func (ix *Index) existingTree(model string) (*radixtree.Tree[ReplicaKey], bool) {
|
||||
ix.mu.RLock()
|
||||
defer ix.mu.RUnlock()
|
||||
t, ok := ix.trees[model]
|
||||
return t, ok
|
||||
}
|
||||
|
||||
func (ix *Index) tree(model string) *radixtree.Tree[ReplicaKey] {
|
||||
ix.mu.RLock()
|
||||
t, ok := ix.trees[model]
|
||||
ix.mu.RUnlock()
|
||||
if ok {
|
||||
return t
|
||||
}
|
||||
ix.mu.Lock()
|
||||
defer ix.mu.Unlock()
|
||||
if t, ok = ix.trees[model]; ok {
|
||||
return t
|
||||
}
|
||||
t = radixtree.New[ReplicaKey](radixtree.Options{TTL: ix.cfg.TTL, HalfLife: ix.cfg.HalfLife})
|
||||
ix.trees[model] = t
|
||||
return t
|
||||
}
|
||||
|
||||
func (ix *Index) Decide(model string, chain []uint64, candidates []ReplicaKey, now time.Time) PrefixDecision {
|
||||
t := ix.tree(model)
|
||||
var d PrefixDecision
|
||||
// WeightsFor computes every candidate weight in a single tree walk and
|
||||
// returns a map pre-populated with an entry (weight 0 by default) for every
|
||||
// requested candidate. Candidacy is therefore exactly "is a key in weights",
|
||||
// so we derive the hot-match membership check from it rather than building a
|
||||
// second set.
|
||||
weights := t.WeightsFor(candidates, now)
|
||||
if len(chain) > 0 {
|
||||
if key, depth, ok := t.LongestMatch(chain, now); ok {
|
||||
// LongestMatch searches the whole tree, so the deepest match can be
|
||||
// a replica that is offline / unloaded / not in the candidate set.
|
||||
// Treating that as a hot match produces a false forced-disturb signal
|
||||
// upstream (the warm replica was absent, not load-saturated). Only honor
|
||||
// the match when the matched replica is an actual candidate; otherwise
|
||||
// fall back to cold placement.
|
||||
if _, ok := weights[key]; ok {
|
||||
d.Hot = key
|
||||
d.HasHot = true
|
||||
d.MatchRatio = float64(depth) / float64(len(chain))
|
||||
}
|
||||
}
|
||||
}
|
||||
// Cold order: candidates ascending by cacheWeight, tie-break by NodeID then
|
||||
// Replica. The sort comparator reads precomputed weights instead of triggering
|
||||
// an O(tree size) Weight call per comparison. With at most one candidate the
|
||||
// input order is already the cold order, so skip the sort.
|
||||
order := make([]ReplicaKey, len(candidates))
|
||||
copy(order, candidates)
|
||||
if len(order) > 1 {
|
||||
sort.Slice(order, func(i, j int) bool {
|
||||
if weights[order[i]] != weights[order[j]] {
|
||||
return weights[order[i]] < weights[order[j]]
|
||||
}
|
||||
return order[i].less(order[j])
|
||||
})
|
||||
}
|
||||
d.ColdOrder = order
|
||||
return d
|
||||
}
|
||||
|
||||
func (ix *Index) Observe(model string, chain []uint64, key ReplicaKey, now time.Time) bool {
|
||||
if len(chain) == 0 || key.NodeID == "" {
|
||||
return false
|
||||
}
|
||||
t := ix.tree(model)
|
||||
// New/extended iff the current deepest match for this exact chain is not
|
||||
// already this replica at full depth.
|
||||
cur, depth, ok := t.LongestMatch(chain, now)
|
||||
t.Insert(chain, key, now)
|
||||
return !ok || depth < len(chain) || cur != key
|
||||
}
|
||||
|
||||
// Invalidate drops all entries for ONE replica. It never interns an empty tree
|
||||
// (a registry chokepoint fires Invalidate for every replica removal of every
|
||||
// model, including round-robin models that never used the prefix cache, so
|
||||
// lazily creating a tree here would grow the trees map unboundedly).
|
||||
func (ix *Index) Invalidate(model string, key ReplicaKey) {
|
||||
if t, ok := ix.existingTree(model); ok {
|
||||
t.RemoveFunc(func(k ReplicaKey) bool { return k == key })
|
||||
}
|
||||
}
|
||||
|
||||
// InvalidateNode drops entries for ALL replicas of nodeID. Like Invalidate it
|
||||
// does not intern an empty tree.
|
||||
func (ix *Index) InvalidateNode(model, nodeID string) {
|
||||
if t, ok := ix.existingTree(model); ok {
|
||||
t.RemoveFunc(func(k ReplicaKey) bool { return k.NodeID == nodeID })
|
||||
}
|
||||
}
|
||||
|
||||
func (ix *Index) Evict(now time.Time) {
|
||||
ix.mu.RLock()
|
||||
defer ix.mu.RUnlock()
|
||||
for _, t := range ix.trees {
|
||||
t.Evict(now)
|
||||
}
|
||||
}
|
||||
169
core/services/nodes/prefixcache/index_test.go
Normal file
169
core/services/nodes/prefixcache/index_test.go
Normal file
@@ -0,0 +1,169 @@
|
||||
package prefixcache_test
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
|
||||
"github.com/mudler/LocalAI/core/services/nodes/prefixcache"
|
||||
)
|
||||
|
||||
var t0 = time.Date(2026, 5, 29, 12, 0, 0, 0, time.UTC)
|
||||
|
||||
var _ = Describe("Index provider", func() {
|
||||
cfg := prefixcache.DefaultConfig()
|
||||
|
||||
It("returns no hot match before anything is observed", func() {
|
||||
idx := prefixcache.NewIndex(cfg)
|
||||
d := idx.Decide("m", []uint64{1, 2, 3}, []prefixcache.ReplicaKey{rk("A", 0), rk("B", 0)}, t0)
|
||||
Expect(d.HasHot).To(BeFalse())
|
||||
// cold order present (all weights zero -> deterministic by node id)
|
||||
Expect(d.ColdOrder).To(ConsistOf(rk("A", 0), rk("B", 0)))
|
||||
})
|
||||
|
||||
It("returns the observed replica as hot match with the right ratio", func() {
|
||||
idx := prefixcache.NewIndex(cfg)
|
||||
idx.Observe("m", []uint64{1, 2, 3, 4}, rk("A", 0), t0)
|
||||
d := idx.Decide("m", []uint64{1, 2, 3, 4, 5}, []prefixcache.ReplicaKey{rk("A", 0), rk("B", 0)}, t0)
|
||||
Expect(d.HasHot).To(BeTrue())
|
||||
Expect(d.Hot).To(Equal(rk("A", 0)))
|
||||
Expect(d.MatchRatio).To(BeNumerically("~", 4.0/5.0, 0.001))
|
||||
})
|
||||
|
||||
It("orders cold candidates by ascending cacheWeight", func() {
|
||||
idx := prefixcache.NewIndex(cfg)
|
||||
idx.Observe("m", []uint64{1}, rk("A", 0), t0)
|
||||
idx.Observe("m", []uint64{2}, rk("A", 0), t0) // A weight 2
|
||||
idx.Observe("m", []uint64{3}, rk("B", 0), t0) // B weight 1
|
||||
d := idx.Decide("m", []uint64{9}, []prefixcache.ReplicaKey{rk("A", 0), rk("B", 0)}, t0)
|
||||
Expect(d.HasHot).To(BeFalse())
|
||||
Expect(d.ColdOrder).To(Equal([]prefixcache.ReplicaKey{rk("B", 0), rk("A", 0)})) // B lower weight first
|
||||
})
|
||||
|
||||
It("drops the hot match when the matched replica is not in the candidate set", func() {
|
||||
idx := prefixcache.NewIndex(cfg)
|
||||
idx.Observe("m", []uint64{1, 2, 3, 4}, rk("A", 0), t0)
|
||||
// A holds the longest match, but A is not a candidate (offline /
|
||||
// unloaded). The matched replica must be ignored so cold placement runs
|
||||
// and no false forced-disturb fires upstream.
|
||||
d := idx.Decide("m", []uint64{1, 2, 3, 4, 5}, []prefixcache.ReplicaKey{rk("B", 0), rk("C", 0)}, t0)
|
||||
Expect(d.HasHot).To(BeFalse())
|
||||
Expect(d.MatchRatio).To(Equal(0.0))
|
||||
Expect(d.ColdOrder).To(ConsistOf(rk("B", 0), rk("C", 0)))
|
||||
})
|
||||
|
||||
It("returns a hot match for a query that only shares a prefix with an observed chain", func() {
|
||||
// The real-world case: a replica served chain [1,2,3,4]; a new request
|
||||
// shares the leading block [1,2,3] but diverges at the tail ([1,2,3,9]).
|
||||
// With prefix matching (value recorded at every node) Decide must still
|
||||
// route to the warm replica, matching at the depth of the shared prefix.
|
||||
idx := prefixcache.NewIndex(cfg)
|
||||
idx.Observe("m", []uint64{1, 2, 3, 4}, rk("A", 0), t0)
|
||||
d := idx.Decide("m", []uint64{1, 2, 3, 9}, []prefixcache.ReplicaKey{rk("A", 0), rk("B", 0)}, t0)
|
||||
Expect(d.HasHot).To(BeTrue())
|
||||
Expect(d.Hot).To(Equal(rk("A", 0)))
|
||||
Expect(d.MatchRatio).To(BeNumerically("~", 3.0/4.0, 0.001)) // shared [1,2,3] of len-4 query
|
||||
})
|
||||
|
||||
It("keeps the hot match when the matched replica is a candidate", func() {
|
||||
idx := prefixcache.NewIndex(cfg)
|
||||
idx.Observe("m", []uint64{1, 2, 3, 4}, rk("A", 0), t0)
|
||||
d := idx.Decide("m", []uint64{1, 2, 3, 4, 5}, []prefixcache.ReplicaKey{rk("A", 0), rk("B", 0)}, t0)
|
||||
Expect(d.HasHot).To(BeTrue())
|
||||
Expect(d.Hot).To(Equal(rk("A", 0)))
|
||||
Expect(d.MatchRatio).To(BeNumerically("~", 4.0/5.0, 0.001))
|
||||
})
|
||||
|
||||
It("tracks affinity per replica, not per node", func() {
|
||||
// Two replicas on the SAME node, each serving a different chain that share
|
||||
// a leading block. The hot match for a query extending chain1 must be the
|
||||
// EXACT replica that served chain1, not the other replica on the same node.
|
||||
idx := prefixcache.NewIndex(cfg)
|
||||
idx.Observe("m", []uint64{1, 2, 3, 4}, rk("A", 0), t0) // replica 0 owns [1,2,3,4]
|
||||
idx.Observe("m", []uint64{1, 2, 5, 6}, rk("A", 1), t0) // replica 1 owns [1,2,5,6]
|
||||
cands := []prefixcache.ReplicaKey{rk("A", 0), rk("A", 1)}
|
||||
d := idx.Decide("m", []uint64{1, 2, 3, 4, 7}, cands, t0)
|
||||
Expect(d.HasHot).To(BeTrue())
|
||||
Expect(d.Hot).To(Equal(rk("A", 0))) // distinct replicas on one node have distinct affinity
|
||||
d2 := idx.Decide("m", []uint64{1, 2, 5, 6, 7}, cands, t0)
|
||||
Expect(d2.HasHot).To(BeTrue())
|
||||
Expect(d2.Hot).To(Equal(rk("A", 1)))
|
||||
})
|
||||
|
||||
It("Invalidate drops one replica while InvalidateNode drops all replicas of a node", func() {
|
||||
idx := prefixcache.NewIndex(cfg)
|
||||
idx.Observe("m", []uint64{1, 2, 3, 4}, rk("A", 0), t0)
|
||||
idx.Observe("m", []uint64{5, 6, 7, 8}, rk("A", 1), t0)
|
||||
cands := []prefixcache.ReplicaKey{rk("A", 0), rk("A", 1)}
|
||||
|
||||
// Invalidate replica 0 only: replica 1 survives.
|
||||
idx.Invalidate("m", rk("A", 0))
|
||||
Expect(idx.Decide("m", []uint64{1, 2, 3, 4}, cands, t0).HasHot).To(BeFalse())
|
||||
d1 := idx.Decide("m", []uint64{5, 6, 7, 8}, cands, t0)
|
||||
Expect(d1.HasHot).To(BeTrue())
|
||||
Expect(d1.Hot).To(Equal(rk("A", 1)))
|
||||
|
||||
// Re-observe both, then InvalidateNode drops BOTH replicas.
|
||||
idx.Observe("m", []uint64{1, 2, 3, 4}, rk("A", 0), t0)
|
||||
idx.InvalidateNode("m", "A")
|
||||
Expect(idx.Decide("m", []uint64{1, 2, 3, 4}, cands, t0).HasHot).To(BeFalse())
|
||||
Expect(idx.Decide("m", []uint64{5, 6, 7, 8}, cands, t0).HasHot).To(BeFalse())
|
||||
})
|
||||
|
||||
It("forgets a replica on Invalidate", func() {
|
||||
idx := prefixcache.NewIndex(cfg)
|
||||
idx.Observe("m", []uint64{1, 2}, rk("A", 0), t0)
|
||||
idx.Invalidate("m", rk("A", 0))
|
||||
d := idx.Decide("m", []uint64{1, 2}, []prefixcache.ReplicaKey{rk("A", 0)}, t0)
|
||||
Expect(d.HasHot).To(BeFalse())
|
||||
})
|
||||
|
||||
It("does not intern an empty tree when invalidating a model that has none", func() {
|
||||
idx := prefixcache.NewIndex(cfg)
|
||||
Expect(idx.TreeCountForTest()).To(Equal(0))
|
||||
// Round-robin model that never used the prefix cache: invalidating a
|
||||
// replica removal must be a no-op and must not retain a tree.
|
||||
idx.Invalidate("never-cached", rk("A", 0))
|
||||
idx.Invalidate("never-cached", rk("B", 0))
|
||||
idx.InvalidateNode("other", "C")
|
||||
Expect(idx.TreeCountForTest()).To(Equal(0))
|
||||
// And a Decide afterwards still works without a hot match.
|
||||
d := idx.Decide("never-cached", []uint64{1}, []prefixcache.ReplicaKey{rk("A", 0)}, t0)
|
||||
Expect(d.HasHot).To(BeFalse())
|
||||
})
|
||||
|
||||
It("is safe for concurrent Decide/Observe/Invalidate (run with -race)", func() {
|
||||
idx := prefixcache.NewIndex(cfg)
|
||||
models := []string{"m1", "m2"}
|
||||
nodes := []string{"A", "B", "C"}
|
||||
cands := []prefixcache.ReplicaKey{rk("A", 0), rk("B", 0), rk("C", 0)}
|
||||
var wg sync.WaitGroup
|
||||
for g := range 8 {
|
||||
wg.Add(1)
|
||||
go func(g int) {
|
||||
defer GinkgoRecover()
|
||||
defer wg.Done()
|
||||
model := models[g%len(models)]
|
||||
node := nodes[g%len(nodes)]
|
||||
now := t0
|
||||
for i := range 200 {
|
||||
chain := []uint64{uint64(g), uint64(i % 7), uint64(i)}
|
||||
switch i % 4 {
|
||||
case 0:
|
||||
idx.Observe(model, chain, prefixcache.ReplicaKey{NodeID: node, Replica: i % 2}, now)
|
||||
case 1:
|
||||
idx.Decide(model, chain, cands, now)
|
||||
case 2:
|
||||
idx.Invalidate(model, prefixcache.ReplicaKey{NodeID: node, Replica: i % 2})
|
||||
case 3:
|
||||
idx.InvalidateNode(model, node)
|
||||
}
|
||||
now = now.Add(time.Millisecond)
|
||||
}
|
||||
}(g)
|
||||
}
|
||||
wg.Wait()
|
||||
})
|
||||
})
|
||||
47
core/services/nodes/prefixcache/policy.go
Normal file
47
core/services/nodes/prefixcache/policy.go
Normal file
@@ -0,0 +1,47 @@
|
||||
// Package prefixcache implements prefix-cache-aware routing for distributed
|
||||
// mode: it turns a request prompt into a chain of prefix hashes, tracks which
|
||||
// node served which prefix in an in-memory radix tree, and provides a
|
||||
// load-guarded preferred-node decision. See docs/content/features/distributed-mode.md.
|
||||
package prefixcache
|
||||
|
||||
// RoutePolicy selects the routing strategy for a model. The zero value is
|
||||
// RoutePolicyDefault, meaning "inherit the cluster-wide default".
|
||||
type RoutePolicy int
|
||||
|
||||
const (
|
||||
RoutePolicyDefault RoutePolicy = iota // inherit global default
|
||||
RoutePolicyRoundRobin // today's behavior (the floor)
|
||||
RoutePolicyPrefixCache // cache-aware routing
|
||||
)
|
||||
|
||||
// ParsePolicy maps a config string to a RoutePolicy. Unknown or empty strings
|
||||
// map to RoutePolicyDefault.
|
||||
func ParsePolicy(s string) RoutePolicy {
|
||||
switch s {
|
||||
case "round_robin":
|
||||
return RoutePolicyRoundRobin
|
||||
case "prefix_cache":
|
||||
return RoutePolicyPrefixCache
|
||||
default:
|
||||
return RoutePolicyDefault
|
||||
}
|
||||
}
|
||||
|
||||
func (p RoutePolicy) String() string {
|
||||
switch p {
|
||||
case RoutePolicyRoundRobin:
|
||||
return "round_robin"
|
||||
case RoutePolicyPrefixCache:
|
||||
return "prefix_cache"
|
||||
default:
|
||||
return "default"
|
||||
}
|
||||
}
|
||||
|
||||
// Resolve returns p unless it is Default, in which case it returns global.
|
||||
func (p RoutePolicy) Resolve(global RoutePolicy) RoutePolicy {
|
||||
if p == RoutePolicyDefault {
|
||||
return global
|
||||
}
|
||||
return p
|
||||
}
|
||||
29
core/services/nodes/prefixcache/policy_test.go
Normal file
29
core/services/nodes/prefixcache/policy_test.go
Normal file
@@ -0,0 +1,29 @@
|
||||
package prefixcache_test
|
||||
|
||||
import (
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
|
||||
"github.com/mudler/LocalAI/core/services/nodes/prefixcache"
|
||||
)
|
||||
|
||||
var _ = Describe("RoutePolicy", func() {
|
||||
It("parses known values and defaults unknown to Default (zero)", func() {
|
||||
Expect(prefixcache.ParsePolicy("round_robin")).To(Equal(prefixcache.RoutePolicyRoundRobin))
|
||||
Expect(prefixcache.ParsePolicy("prefix_cache")).To(Equal(prefixcache.RoutePolicyPrefixCache))
|
||||
Expect(prefixcache.ParsePolicy("")).To(Equal(prefixcache.RoutePolicyDefault))
|
||||
Expect(prefixcache.ParsePolicy("bogus")).To(Equal(prefixcache.RoutePolicyDefault))
|
||||
})
|
||||
|
||||
It("stringifies", func() {
|
||||
Expect(prefixcache.RoutePolicyPrefixCache.String()).To(Equal("prefix_cache"))
|
||||
Expect(prefixcache.RoutePolicyRoundRobin.String()).To(Equal("round_robin"))
|
||||
})
|
||||
|
||||
It("resolves per-model against a global default", func() {
|
||||
Expect(prefixcache.RoutePolicyDefault.Resolve(prefixcache.RoutePolicyPrefixCache)).
|
||||
To(Equal(prefixcache.RoutePolicyPrefixCache))
|
||||
Expect(prefixcache.RoutePolicyRoundRobin.Resolve(prefixcache.RoutePolicyPrefixCache)).
|
||||
To(Equal(prefixcache.RoutePolicyRoundRobin))
|
||||
})
|
||||
})
|
||||
13
core/services/nodes/prefixcache/prefixcache_suite_test.go
Normal file
13
core/services/nodes/prefixcache/prefixcache_suite_test.go
Normal file
@@ -0,0 +1,13 @@
|
||||
package prefixcache_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
func TestPrefixCache(t *testing.T) {
|
||||
RegisterFailHandler(Fail)
|
||||
RunSpecs(t, "PrefixCache Suite")
|
||||
}
|
||||
82
core/services/nodes/prefixcache/pressure.go
Normal file
82
core/services/nodes/prefixcache/pressure.go
Normal file
@@ -0,0 +1,82 @@
|
||||
package prefixcache
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Pressure is a concurrency-safe rolling per-model counter of forced-disturb
|
||||
// events. A forced-disturb is recorded by the router when a usable hot prefix
|
||||
// match existed but the load guard forced the request off the warm node (see
|
||||
// SmartRouter.buildPreference). The reconciler reads Count to decide whether
|
||||
// the cache-warm replica is saturated enough to warrant a scale-up.
|
||||
//
|
||||
// Entries older than the window are dropped on both Record and Count, so the
|
||||
// slice never grows unbounded - even for a model that takes records but is
|
||||
// never Counted (e.g. one with zero loaded replicas the reconciler skips). An
|
||||
// idle model's history also decays to zero on the next read.
|
||||
type Pressure struct {
|
||||
mu sync.Mutex
|
||||
window time.Duration
|
||||
events map[string][]time.Time
|
||||
}
|
||||
|
||||
// NewPressure creates a Pressure counter that remembers events for the given
|
||||
// rolling window.
|
||||
func NewPressure(window time.Duration) *Pressure {
|
||||
return &Pressure{
|
||||
window: window,
|
||||
events: make(map[string][]time.Time),
|
||||
}
|
||||
}
|
||||
|
||||
// pruneLocked drops entries older than cutoff, compacting in place. The cutoff
|
||||
// boundary itself is inclusive so an event exactly window-old still counts.
|
||||
// Callers must hold p.mu.
|
||||
func pruneLocked(ts []time.Time, cutoff time.Time) []time.Time {
|
||||
kept := ts[:0]
|
||||
for _, t := range ts {
|
||||
if !t.Before(cutoff) {
|
||||
kept = append(kept, t)
|
||||
}
|
||||
}
|
||||
return kept
|
||||
}
|
||||
|
||||
// Record appends a forced-disturb timestamp for the model and prunes entries
|
||||
// older than the window, so the per-model slice stays bounded regardless of how
|
||||
// often Count runs.
|
||||
func (p *Pressure) Record(model string, now time.Time) {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
cutoff := now.Add(-p.window)
|
||||
kept := append(pruneLocked(p.events[model], cutoff), now)
|
||||
p.events[model] = kept
|
||||
}
|
||||
|
||||
// Count returns the number of records for the model within [now-window, now],
|
||||
// dropping any entries older than the window so the backing slice stays bounded.
|
||||
func (p *Pressure) Count(model string, now time.Time) int {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
ts := p.events[model]
|
||||
if len(ts) == 0 {
|
||||
return 0
|
||||
}
|
||||
kept := pruneLocked(ts, now.Add(-p.window))
|
||||
if len(kept) == 0 {
|
||||
delete(p.events, model)
|
||||
return 0
|
||||
}
|
||||
p.events[model] = kept
|
||||
return len(kept)
|
||||
}
|
||||
|
||||
// Reset clears all recorded events for model. Call after acting on the signal
|
||||
// (a pressure-triggered scale-up) so a single burst does not trigger repeated
|
||||
// scale-ups across consecutive ticks.
|
||||
func (p *Pressure) Reset(model string) {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
delete(p.events, model)
|
||||
}
|
||||
98
core/services/nodes/prefixcache/pressure_test.go
Normal file
98
core/services/nodes/prefixcache/pressure_test.go
Normal file
@@ -0,0 +1,98 @@
|
||||
package prefixcache_test
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/mudler/LocalAI/core/services/nodes/prefixcache"
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
var _ = Describe("Pressure counter", func() {
|
||||
t0 := time.Unix(1700000000, 0)
|
||||
|
||||
It("counts events within the window and forgets older ones", func() {
|
||||
p := prefixcache.NewPressure(time.Minute)
|
||||
p.Record("m", t0)
|
||||
p.Record("m", t0.Add(30*time.Second))
|
||||
Expect(p.Count("m", t0.Add(40*time.Second))).To(Equal(2))
|
||||
Expect(p.Count("m", t0.Add(90*time.Second))).To(Equal(1)) // first expired
|
||||
})
|
||||
|
||||
It("tracks pressure per model independently", func() {
|
||||
p := prefixcache.NewPressure(time.Minute)
|
||||
p.Record("a", t0)
|
||||
p.Record("a", t0.Add(10*time.Second))
|
||||
p.Record("b", t0.Add(20*time.Second))
|
||||
Expect(p.Count("a", t0.Add(30*time.Second))).To(Equal(2))
|
||||
Expect(p.Count("b", t0.Add(30*time.Second))).To(Equal(1))
|
||||
Expect(p.Count("c", t0.Add(30*time.Second))).To(Equal(0))
|
||||
})
|
||||
|
||||
It("returns zero for a model that was never recorded", func() {
|
||||
p := prefixcache.NewPressure(time.Minute)
|
||||
Expect(p.Count("never", t0)).To(Equal(0))
|
||||
})
|
||||
|
||||
It("includes the boundary timestamp at exactly now-window", func() {
|
||||
p := prefixcache.NewPressure(time.Minute)
|
||||
p.Record("m", t0)
|
||||
// now-window == t0 exactly, so the entry is still within [now-window, now].
|
||||
Expect(p.Count("m", t0.Add(time.Minute))).To(Equal(1))
|
||||
// one nanosecond past the window drops it.
|
||||
Expect(p.Count("m", t0.Add(time.Minute+1))).To(Equal(0))
|
||||
})
|
||||
|
||||
It("bounds the backing slice in Record without any Count calls", func() {
|
||||
p := prefixcache.NewPressure(time.Minute)
|
||||
// Record many timestamps, advancing now well past the window between
|
||||
// each, and never call Count. Each Record must prune the entries that
|
||||
// have fallen out of [now-window, now] so the slice cannot accumulate.
|
||||
var last time.Time
|
||||
for i := range 1000 {
|
||||
last = t0.Add(time.Duration(i) * 10 * time.Second)
|
||||
p.Record("m", last)
|
||||
}
|
||||
// With a 1m window and 10s spacing, at most ~7 records (the boundary is
|
||||
// inclusive) can be within [last-window, last]. The slice must stay that
|
||||
// bounded, never growing toward 1000.
|
||||
Expect(p.LenForTest("m")).To(BeNumerically("<=", 7))
|
||||
// And the in-window count must reflect only those bounded entries.
|
||||
Expect(p.Count("m", last)).To(Equal(p.LenForTest("m")))
|
||||
})
|
||||
|
||||
It("clears all recorded events on Reset", func() {
|
||||
p := prefixcache.NewPressure(time.Minute)
|
||||
p.Record("m", t0)
|
||||
p.Record("m", t0.Add(10*time.Second))
|
||||
p.Record("m", t0.Add(20*time.Second))
|
||||
Expect(p.Count("m", t0.Add(30*time.Second))).To(BeNumerically(">", 0))
|
||||
|
||||
p.Reset("m")
|
||||
|
||||
// After Reset the model has no in-window events even though the
|
||||
// timestamps would otherwise still be within [now-window, now].
|
||||
Expect(p.Count("m", t0.Add(30*time.Second))).To(Equal(0))
|
||||
Expect(p.LenForTest("m")).To(Equal(0))
|
||||
})
|
||||
|
||||
It("Reset only clears the named model", func() {
|
||||
p := prefixcache.NewPressure(time.Minute)
|
||||
p.Record("a", t0)
|
||||
p.Record("b", t0)
|
||||
p.Reset("a")
|
||||
Expect(p.Count("a", t0.Add(time.Second))).To(Equal(0))
|
||||
Expect(p.Count("b", t0.Add(time.Second))).To(Equal(1))
|
||||
})
|
||||
|
||||
It("does not accumulate repeated out-of-window Records", func() {
|
||||
p := prefixcache.NewPressure(time.Minute)
|
||||
// Each record is more than a window apart, so every Record prunes the
|
||||
// previous one. The slice should never hold more than a single entry.
|
||||
for i := range 100 {
|
||||
p.Record("m", t0.Add(time.Duration(i)*2*time.Minute))
|
||||
}
|
||||
Expect(p.LenForTest("m")).To(Equal(1))
|
||||
Expect(p.Count("m", t0.Add(198*time.Minute))).To(Equal(1))
|
||||
})
|
||||
})
|
||||
24
core/services/nodes/prefixcache/provider.go
Normal file
24
core/services/nodes/prefixcache/provider.go
Normal file
@@ -0,0 +1,24 @@
|
||||
package prefixcache
|
||||
|
||||
import "time"
|
||||
|
||||
// Provider is the seam between SmartRouter and the prefix-cache implementation.
|
||||
// The radix-tree (guessed) implementation is the only one today; a future
|
||||
// KV-event (reported) implementation can satisfy the same interface without
|
||||
// changing SmartRouter (epic #10063 / #10064). Affinity is tracked per replica:
|
||||
// each loaded replica is a separate process with its own KV cache.
|
||||
type Provider interface {
|
||||
// Decide computes the prefix decision for a request given the candidate
|
||||
// replicas (the selector-filtered set). It does not consult load - load
|
||||
// filtering happens in the DB transaction.
|
||||
Decide(model string, chain []uint64, candidates []ReplicaKey, now time.Time) PrefixDecision
|
||||
// Observe records that the replica served the request whose prefix is chain.
|
||||
// Returns true when the assignment was new or extended (caller broadcasts).
|
||||
Observe(model string, chain []uint64, key ReplicaKey, now time.Time) bool
|
||||
// Invalidate drops all entries for ONE replica.
|
||||
Invalidate(model string, key ReplicaKey)
|
||||
// InvalidateNode drops entries for ALL replicas of a node.
|
||||
InvalidateNode(model, nodeID string)
|
||||
// Evict sweeps expired entries for all models.
|
||||
Evict(now time.Time)
|
||||
}
|
||||
93
core/services/nodes/prefixcache/select.go
Normal file
93
core/services/nodes/prefixcache/select.go
Normal file
@@ -0,0 +1,93 @@
|
||||
package prefixcache
|
||||
|
||||
// ReplicaKey identifies a specific loaded replica (a backend process). Affinity
|
||||
// is tracked per replica, not per node, because each replica is a separate
|
||||
// process with its own KV cache.
|
||||
type ReplicaKey struct {
|
||||
NodeID string
|
||||
Replica int
|
||||
}
|
||||
|
||||
// less reports whether a sorts before b, ordering by NodeID then Replica. It is
|
||||
// the deterministic tiebreak used wherever two replicas are otherwise equal.
|
||||
func (a ReplicaKey) less(b ReplicaKey) bool {
|
||||
if a.NodeID != b.NodeID {
|
||||
return a.NodeID < b.NodeID
|
||||
}
|
||||
return a.Replica < b.Replica
|
||||
}
|
||||
|
||||
// Candidate is a load-eligible-or-not replica view from the registry. There is
|
||||
// one Candidate per LOADED replica: the router no longer collapses replicas per
|
||||
// node, so two replicas of the same model on the same node are two candidates.
|
||||
type Candidate struct {
|
||||
Key ReplicaKey
|
||||
InFlight int
|
||||
}
|
||||
|
||||
// PrefixDecision is computed from the in-memory tree before the DB transaction.
|
||||
// Hot is the replica holding the longest prefix match and HasHot reports whether
|
||||
// there is one (a ReplicaKey has no "" sentinel). MatchRatio is matched/total
|
||||
// for that match. ColdOrder lists candidate replicas ascending by cacheWeight
|
||||
// (lowest = least valuable warm cache = best cold target).
|
||||
type PrefixDecision struct {
|
||||
Hot ReplicaKey
|
||||
HasHot bool
|
||||
MatchRatio float64
|
||||
ColdOrder []ReplicaKey
|
||||
}
|
||||
|
||||
// Select implements filter-then-score per replica: keep candidates within the
|
||||
// load guard (relative to the min in-flight across ALL candidate replicas), then
|
||||
// prefer the exact hot-match replica, else the lowest-cacheWeight eligible
|
||||
// replica via ColdOrder, else a deterministic eligible fallback (least in-flight,
|
||||
// tiebreak by NodeID then Replica). Returns (ReplicaKey{}, false) when nothing is
|
||||
// selectable.
|
||||
func Select(cands []Candidate, d PrefixDecision, cfg Config) (ReplicaKey, bool) {
|
||||
if len(cands) == 0 {
|
||||
return ReplicaKey{}, false
|
||||
}
|
||||
minIF := cands[0].InFlight
|
||||
for _, c := range cands {
|
||||
minIF = min(minIF, c.InFlight)
|
||||
}
|
||||
eligible := map[ReplicaKey]bool{}
|
||||
for _, c := range cands {
|
||||
withinAbs := c.InFlight <= minIF+cfg.BalanceAbsThreshold
|
||||
// +1 softens the relative guard when minIF==0 so a zero baseline does
|
||||
// not require exact-zero in-flight; the absolute guard governs near 0.
|
||||
withinRel := float64(c.InFlight) <= float64(minIF)*cfg.BalanceRelThreshold+1
|
||||
if withinAbs && withinRel {
|
||||
eligible[c.Key] = true
|
||||
}
|
||||
}
|
||||
// Hot match wins if eligible and strong enough.
|
||||
if d.HasHot && d.MatchRatio >= cfg.MinPrefixMatch && eligible[d.Hot] {
|
||||
return d.Hot, true
|
||||
}
|
||||
// Cold placement: lowest cacheWeight eligible replica.
|
||||
for _, k := range d.ColdOrder {
|
||||
if eligible[k] {
|
||||
return k, true
|
||||
}
|
||||
}
|
||||
// Deterministic eligible fallback: least in-flight, tiebreak NodeID then
|
||||
// Replica. ColdOrder may not cover the eligible set (the caller may pass an
|
||||
// empty ColdOrder), so this guarantees Select still returns the best eligible
|
||||
// replica rather than failing.
|
||||
var best Candidate
|
||||
found := false
|
||||
for _, c := range cands {
|
||||
if !eligible[c.Key] {
|
||||
continue
|
||||
}
|
||||
if !found || c.InFlight < best.InFlight ||
|
||||
(c.InFlight == best.InFlight && c.Key.less(best.Key)) {
|
||||
best, found = c, true
|
||||
}
|
||||
}
|
||||
if found {
|
||||
return best.Key, true
|
||||
}
|
||||
return ReplicaKey{}, false
|
||||
}
|
||||
139
core/services/nodes/prefixcache/select_test.go
Normal file
139
core/services/nodes/prefixcache/select_test.go
Normal file
@@ -0,0 +1,139 @@
|
||||
package prefixcache_test
|
||||
|
||||
import (
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
|
||||
"github.com/mudler/LocalAI/core/services/nodes/prefixcache"
|
||||
)
|
||||
|
||||
func rk(node string, replica int) prefixcache.ReplicaKey {
|
||||
return prefixcache.ReplicaKey{NodeID: node, Replica: replica}
|
||||
}
|
||||
|
||||
var _ = Describe("Select (filter-then-score)", func() {
|
||||
cfg := prefixcache.DefaultConfig() // abs=2, rel=1.5, minMatch=0.3
|
||||
|
||||
cand := func(node string, replica, inflight int) prefixcache.Candidate {
|
||||
return prefixcache.Candidate{Key: rk(node, replica), InFlight: inflight}
|
||||
}
|
||||
|
||||
It("returns the hot-match replica when it is load-eligible and match >= min", func() {
|
||||
cands := []prefixcache.Candidate{cand("A", 0, 1), cand("B", 0, 0)}
|
||||
got, ok := prefixcache.Select(cands, prefixcache.PrefixDecision{
|
||||
Hot: rk("A", 0), HasHot: true, MatchRatio: 0.5,
|
||||
}, cfg)
|
||||
Expect(ok).To(BeTrue())
|
||||
Expect(got).To(Equal(rk("A", 0))) // A in-flight 1 <= min(0)+2 and <= 0*1.5+1
|
||||
})
|
||||
|
||||
It("rejects the hot match when it violates the absolute load guard", func() {
|
||||
cands := []prefixcache.Candidate{cand("A", 0, 5), cand("B", 0, 0)}
|
||||
got, ok := prefixcache.Select(cands, prefixcache.PrefixDecision{
|
||||
Hot: rk("A", 0), HasHot: true, MatchRatio: 0.9,
|
||||
ColdOrder: []prefixcache.ReplicaKey{rk("B", 0), rk("A", 0)},
|
||||
}, cfg)
|
||||
Expect(ok).To(BeTrue())
|
||||
Expect(got).To(Equal(rk("B", 0))) // A 5 > min(0)+2, drop to cold placement
|
||||
})
|
||||
|
||||
It("ignores a match below min_prefix_match", func() {
|
||||
cands := []prefixcache.Candidate{cand("A", 0, 0), cand("B", 0, 0)}
|
||||
got, ok := prefixcache.Select(cands, prefixcache.PrefixDecision{
|
||||
Hot: rk("A", 0), HasHot: true, MatchRatio: 0.2, // < 0.3
|
||||
ColdOrder: []prefixcache.ReplicaKey{rk("B", 0), rk("A", 0)},
|
||||
}, cfg)
|
||||
Expect(ok).To(BeTrue())
|
||||
Expect(got).To(Equal(rk("B", 0))) // cold placement: lowest cacheWeight eligible
|
||||
})
|
||||
|
||||
It("cold-places to lowest-cacheWeight replica within the eligible subset", func() {
|
||||
cands := []prefixcache.Candidate{cand("A", 0, 0), cand("B", 0, 0), cand("C", 0, 9)}
|
||||
got, ok := prefixcache.Select(cands, prefixcache.PrefixDecision{
|
||||
ColdOrder: []prefixcache.ReplicaKey{rk("C", 0), rk("B", 0), rk("A", 0)},
|
||||
}, cfg)
|
||||
Expect(ok).To(BeTrue())
|
||||
Expect(got).To(Equal(rk("B", 0))) // C filtered out by load; B is next in cold order
|
||||
})
|
||||
|
||||
It("returns false when no candidates", func() {
|
||||
_, ok := prefixcache.Select(nil, prefixcache.PrefixDecision{}, cfg)
|
||||
Expect(ok).To(BeFalse())
|
||||
})
|
||||
|
||||
It("falls back to the least-in-flight eligible replica when ColdOrder is empty", func() {
|
||||
// Deterministic eligible fallback: ColdOrder does not cover the eligible
|
||||
// set, so Select picks the least-in-flight eligible replica, tiebreaking by
|
||||
// NodeID then Replica.
|
||||
cands := []prefixcache.Candidate{cand("B", 1, 0), cand("B", 0, 0), cand("A", 0, 0)}
|
||||
got, ok := prefixcache.Select(cands, prefixcache.PrefixDecision{}, cfg)
|
||||
Expect(ok).To(BeTrue())
|
||||
Expect(got).To(Equal(rk("A", 0))) // all in-flight 0; A < B; within B, replica 0 < 1
|
||||
})
|
||||
|
||||
It("returns false when no candidate is eligible", func() {
|
||||
// Impossible in practice (min is always eligible) but guards the contract:
|
||||
// an empty eligible set yields no selection. Here every candidate is the
|
||||
// min, so one is always eligible; instead test the documented zero value.
|
||||
cands := []prefixcache.Candidate{cand("A", 0, 0)}
|
||||
got, ok := prefixcache.Select(cands, prefixcache.PrefixDecision{}, cfg)
|
||||
Expect(ok).To(BeTrue())
|
||||
Expect(got).To(Equal(rk("A", 0)))
|
||||
})
|
||||
})
|
||||
|
||||
var _ = Describe("Select replica granularity", func() {
|
||||
cfg := prefixcache.DefaultConfig()
|
||||
|
||||
It("distinguishes two replicas of the same node as separate candidates", func() {
|
||||
// Two replicas on NodeA: replica 0 is hot but saturated, replica 1 is cool.
|
||||
// The round-robin floor must drop to replica 1, NOT collapse them per node.
|
||||
cands := []prefixcache.Candidate{
|
||||
{Key: rk("A", 0), InFlight: 50},
|
||||
{Key: rk("A", 1), InFlight: 0},
|
||||
}
|
||||
got, ok := prefixcache.Select(cands, prefixcache.PrefixDecision{
|
||||
Hot: rk("A", 0), HasHot: true, MatchRatio: 1.0,
|
||||
ColdOrder: []prefixcache.ReplicaKey{rk("A", 1), rk("A", 0)},
|
||||
}, cfg)
|
||||
Expect(ok).To(BeTrue())
|
||||
Expect(got).To(Equal(rk("A", 1)))
|
||||
})
|
||||
|
||||
It("pins back to the exact hot replica when it is within slack", func() {
|
||||
cands := []prefixcache.Candidate{
|
||||
{Key: rk("A", 0), InFlight: 1},
|
||||
{Key: rk("A", 1), InFlight: 0},
|
||||
}
|
||||
got, ok := prefixcache.Select(cands, prefixcache.PrefixDecision{
|
||||
Hot: rk("A", 0), HasHot: true, MatchRatio: 1.0,
|
||||
ColdOrder: []prefixcache.ReplicaKey{rk("A", 1), rk("A", 0)},
|
||||
}, cfg)
|
||||
Expect(ok).To(BeTrue())
|
||||
Expect(got).To(Equal(rk("A", 0))) // within slack -> reuse exact replica
|
||||
})
|
||||
})
|
||||
|
||||
var _ = Describe("Select round-robin floor invariant", func() {
|
||||
It("never pins to a saturated hot replica (round-robin floor)", func() {
|
||||
cfg := prefixcache.DefaultConfig()
|
||||
cands := []prefixcache.Candidate{{Key: rk("hot", 0), InFlight: 50}, {Key: rk("cool", 0), InFlight: 0}}
|
||||
got, ok := prefixcache.Select(cands, prefixcache.PrefixDecision{
|
||||
Hot: rk("hot", 0), HasHot: true, MatchRatio: 1.0,
|
||||
ColdOrder: []prefixcache.ReplicaKey{rk("cool", 0), rk("hot", 0)},
|
||||
}, cfg)
|
||||
Expect(ok).To(BeTrue())
|
||||
Expect(got).To(Equal(rk("cool", 0)))
|
||||
})
|
||||
|
||||
It("improves reuse when balanced", func() {
|
||||
cfg := prefixcache.DefaultConfig()
|
||||
cands := []prefixcache.Candidate{{Key: rk("hot", 0), InFlight: 1}, {Key: rk("cool", 0), InFlight: 0}}
|
||||
got, ok := prefixcache.Select(cands, prefixcache.PrefixDecision{
|
||||
Hot: rk("hot", 0), HasHot: true, MatchRatio: 1.0,
|
||||
ColdOrder: []prefixcache.ReplicaKey{rk("cool", 0), rk("hot", 0)},
|
||||
}, cfg)
|
||||
Expect(ok).To(BeTrue())
|
||||
Expect(got).To(Equal(rk("hot", 0))) // within slack -> reuse
|
||||
})
|
||||
})
|
||||
91
core/services/nodes/prefixcache/sync.go
Normal file
91
core/services/nodes/prefixcache/sync.go
Normal file
@@ -0,0 +1,91 @@
|
||||
package prefixcache
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/mudler/LocalAI/core/services/messaging"
|
||||
"github.com/mudler/xlog"
|
||||
)
|
||||
|
||||
// publisher is the minimal slice of messaging.Client that Sync needs.
|
||||
type publisher interface {
|
||||
Publish(subject string, v any) error
|
||||
}
|
||||
|
||||
// Sync wraps an Index, broadcasting new/extended observations to peers and
|
||||
// applying peers' broadcasts. It is the cross-frontend coherence layer.
|
||||
type Sync struct {
|
||||
idx Provider
|
||||
pub publisher
|
||||
}
|
||||
|
||||
func NewSync(idx Provider, pub publisher) *Sync { return &Sync{idx: idx, pub: pub} }
|
||||
|
||||
// Observe records locally and, if new/extended, broadcasts to peers. It returns
|
||||
// whether the local index treated the assignment as new or extended, so Sync
|
||||
// satisfies prefixcache.Provider.
|
||||
func (s *Sync) Observe(model string, chain []uint64, key ReplicaKey, now time.Time) bool {
|
||||
changed := s.idx.Observe(model, chain, key, now)
|
||||
if changed && s.pub != nil {
|
||||
ev := messaging.PrefixCacheObserveEvent{Model: model, Chain: chain, NodeID: key.NodeID, Replica: key.Replica}
|
||||
if err := s.pub.Publish(messaging.SubjectPrefixCacheObserve, ev); err != nil {
|
||||
xlog.Debug("prefixcache: observe publish failed", "error", err)
|
||||
}
|
||||
}
|
||||
return changed
|
||||
}
|
||||
|
||||
// Invalidate drops the local entry for one replica and broadcasts to peers. The
|
||||
// local drop is a no-op for models that were never cached (Index.Invalidate does
|
||||
// not intern a tree). The broadcast is UNCONDITIONAL (when a publisher is
|
||||
// configured): the registry chokepoint fires for every replica removal, and a
|
||||
// peer frontend may hold a stale entry for the model even when THIS frontend
|
||||
// never cached it, so gating the broadcast on local-tree existence would drop
|
||||
// cross-frontend invalidations and leave peers routing to a removed replica
|
||||
// until their TTL.
|
||||
func (s *Sync) Invalidate(model string, key ReplicaKey) {
|
||||
s.idx.Invalidate(model, key)
|
||||
if s.pub != nil {
|
||||
ev := messaging.PrefixCacheInvalidateEvent{Model: model, NodeID: key.NodeID, Replica: key.Replica}
|
||||
if err := s.pub.Publish(messaging.SubjectPrefixCacheInvalidate, ev); err != nil {
|
||||
xlog.Debug("prefixcache: invalidate publish failed", "error", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// InvalidateNode drops the local entries for ALL replicas of node and broadcasts
|
||||
// to peers. Like Invalidate the broadcast is unconditional for cross-frontend
|
||||
// coherence. A negative Replica on the wire means "all replicas of the node".
|
||||
func (s *Sync) InvalidateNode(model, node string) {
|
||||
s.idx.InvalidateNode(model, node)
|
||||
if s.pub != nil {
|
||||
ev := messaging.PrefixCacheInvalidateEvent{Model: model, NodeID: node, Replica: -1}
|
||||
if err := s.pub.Publish(messaging.SubjectPrefixCacheInvalidate, ev); err != nil {
|
||||
xlog.Debug("prefixcache: invalidate-node publish failed", "error", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ApplyObserve applies a peer observe event locally (no re-broadcast).
|
||||
func (s *Sync) ApplyObserve(ev messaging.PrefixCacheObserveEvent, now time.Time) {
|
||||
s.idx.Observe(ev.Model, ev.Chain, ReplicaKey{NodeID: ev.NodeID, Replica: ev.Replica}, now)
|
||||
}
|
||||
|
||||
// ApplyInvalidate applies a peer invalidate event locally (no re-broadcast). A
|
||||
// negative Replica targets all replicas of the node.
|
||||
func (s *Sync) ApplyInvalidate(ev messaging.PrefixCacheInvalidateEvent) {
|
||||
if ev.Replica < 0 {
|
||||
s.idx.InvalidateNode(ev.Model, ev.NodeID)
|
||||
return
|
||||
}
|
||||
s.idx.Invalidate(ev.Model, ReplicaKey{NodeID: ev.NodeID, Replica: ev.Replica})
|
||||
}
|
||||
|
||||
// Decide delegates to the wrapped index.
|
||||
func (s *Sync) Decide(model string, chain []uint64, candidates []ReplicaKey, now time.Time) PrefixDecision {
|
||||
return s.idx.Decide(model, chain, candidates, now)
|
||||
}
|
||||
|
||||
// Evict delegates eviction of expired entries to the wrapped index. It does not
|
||||
// broadcast: each frontend evicts its own copy on its own TTL clock.
|
||||
func (s *Sync) Evict(now time.Time) { s.idx.Evict(now) }
|
||||
118
core/services/nodes/prefixcache/sync_test.go
Normal file
118
core/services/nodes/prefixcache/sync_test.go
Normal file
@@ -0,0 +1,118 @@
|
||||
package prefixcache_test
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
|
||||
"github.com/mudler/LocalAI/core/services/messaging"
|
||||
"github.com/mudler/LocalAI/core/services/nodes/prefixcache"
|
||||
)
|
||||
|
||||
type fakePub struct{ published []any }
|
||||
|
||||
func (f *fakePub) Publish(subject string, v any) error {
|
||||
f.published = append(f.published, v)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Sync must satisfy the Provider seam so SmartRouter can hold a single
|
||||
// prefixcache.Provider that broadcasts via NATS.
|
||||
var _ prefixcache.Provider = (*prefixcache.Sync)(nil)
|
||||
|
||||
var _ = Describe("Sync", func() {
|
||||
It("delegates Evict to the wrapped index", func() {
|
||||
cfg := prefixcache.DefaultConfig()
|
||||
cfg.TTL = time.Minute
|
||||
idx := prefixcache.NewIndex(cfg)
|
||||
s := prefixcache.NewSync(idx, &fakePub{})
|
||||
s.Observe("m", []uint64{1, 2}, rk("A", 0), t0)
|
||||
// Before TTL: still hot.
|
||||
Expect(idx.Decide("m", []uint64{1, 2}, []prefixcache.ReplicaKey{rk("A", 0)}, t0).HasHot).To(BeTrue())
|
||||
// After TTL via Sync.Evict: entry is swept.
|
||||
s.Evict(t0.Add(2 * time.Minute))
|
||||
Expect(idx.Decide("m", []uint64{1, 2}, []prefixcache.ReplicaKey{rk("A", 0)}, t0.Add(2*time.Minute)).HasHot).To(BeFalse())
|
||||
})
|
||||
|
||||
It("publishes an observe event with the replica when Observe is new", func() {
|
||||
idx := prefixcache.NewIndex(prefixcache.DefaultConfig())
|
||||
pub := &fakePub{}
|
||||
s := prefixcache.NewSync(idx, pub)
|
||||
s.Observe("m", []uint64{1, 2}, rk("A", 1), t0) // first time -> publish
|
||||
Expect(pub.published).To(HaveLen(1))
|
||||
ev := pub.published[0].(messaging.PrefixCacheObserveEvent)
|
||||
Expect(ev.NodeID).To(Equal("A"))
|
||||
Expect(ev.Replica).To(Equal(1))
|
||||
s.Observe("m", []uint64{1, 2}, rk("A", 1), t0) // same -> no publish
|
||||
Expect(pub.published).To(HaveLen(1))
|
||||
})
|
||||
|
||||
It("broadcasts an invalidate even for a model with no local tree, without interning one", func() {
|
||||
idx := prefixcache.NewIndex(prefixcache.DefaultConfig())
|
||||
pub := &fakePub{}
|
||||
s := prefixcache.NewSync(idx, pub)
|
||||
// A peer frontend may hold a stale entry for this model even though THIS
|
||||
// frontend never cached it, so the invalidate MUST be broadcast for
|
||||
// cross-frontend coherence. The local drop must still not intern a tree.
|
||||
s.Invalidate("never-cached", rk("A", 0))
|
||||
Expect(pub.published).To(HaveLen(1))
|
||||
ev := pub.published[0].(messaging.PrefixCacheInvalidateEvent)
|
||||
Expect(ev.NodeID).To(Equal("A"))
|
||||
Expect(ev.Replica).To(Equal(0))
|
||||
Expect(idx.TreeCountForTest()).To(Equal(0))
|
||||
})
|
||||
|
||||
It("broadcasts an invalidate for a cached replica too", func() {
|
||||
idx := prefixcache.NewIndex(prefixcache.DefaultConfig())
|
||||
pub := &fakePub{}
|
||||
s := prefixcache.NewSync(idx, pub)
|
||||
s.Observe("m", []uint64{1, 2}, rk("A", 0), t0) // creates the tree (also publishes observe)
|
||||
pub.published = nil
|
||||
s.Invalidate("m", rk("A", 0))
|
||||
Expect(pub.published).To(HaveLen(1))
|
||||
Expect(pub.published[0]).To(BeAssignableToTypeOf(messaging.PrefixCacheInvalidateEvent{}))
|
||||
})
|
||||
|
||||
It("broadcasts a node-wide invalidate with a negative replica", func() {
|
||||
idx := prefixcache.NewIndex(prefixcache.DefaultConfig())
|
||||
pub := &fakePub{}
|
||||
s := prefixcache.NewSync(idx, pub)
|
||||
s.InvalidateNode("m", "A")
|
||||
Expect(pub.published).To(HaveLen(1))
|
||||
ev := pub.published[0].(messaging.PrefixCacheInvalidateEvent)
|
||||
Expect(ev.NodeID).To(Equal("A"))
|
||||
Expect(ev.Replica).To(BeNumerically("<", 0))
|
||||
})
|
||||
|
||||
It("applies a peer observe event into the local index with the replica", func() {
|
||||
idx := prefixcache.NewIndex(prefixcache.DefaultConfig())
|
||||
s := prefixcache.NewSync(idx, &fakePub{})
|
||||
s.ApplyObserve(messaging.PrefixCacheObserveEvent{Model: "m", Chain: []uint64{1, 2}, NodeID: "A", Replica: 2}, t0)
|
||||
d := idx.Decide("m", []uint64{1, 2}, []prefixcache.ReplicaKey{rk("A", 2)}, t0)
|
||||
Expect(d.HasHot).To(BeTrue())
|
||||
Expect(d.Hot).To(Equal(rk("A", 2)))
|
||||
})
|
||||
|
||||
It("applies a peer single-replica invalidate", func() {
|
||||
idx := prefixcache.NewIndex(prefixcache.DefaultConfig())
|
||||
s := prefixcache.NewSync(idx, &fakePub{})
|
||||
s.Observe("m", []uint64{1, 2}, rk("A", 0), t0)
|
||||
s.Observe("m", []uint64{3, 4}, rk("A", 1), t0)
|
||||
s.ApplyInvalidate(messaging.PrefixCacheInvalidateEvent{Model: "m", NodeID: "A", Replica: 0})
|
||||
cands := []prefixcache.ReplicaKey{rk("A", 0), rk("A", 1)}
|
||||
Expect(idx.Decide("m", []uint64{1, 2}, cands, t0).HasHot).To(BeFalse())
|
||||
Expect(idx.Decide("m", []uint64{3, 4}, cands, t0).HasHot).To(BeTrue())
|
||||
})
|
||||
|
||||
It("applies a peer node-wide invalidate when replica is negative", func() {
|
||||
idx := prefixcache.NewIndex(prefixcache.DefaultConfig())
|
||||
s := prefixcache.NewSync(idx, &fakePub{})
|
||||
s.Observe("m", []uint64{1, 2}, rk("A", 0), t0)
|
||||
s.Observe("m", []uint64{3, 4}, rk("A", 1), t0)
|
||||
s.ApplyInvalidate(messaging.PrefixCacheInvalidateEvent{Model: "m", NodeID: "A", Replica: -1})
|
||||
cands := []prefixcache.ReplicaKey{rk("A", 0), rk("A", 1)}
|
||||
Expect(idx.Decide("m", []uint64{1, 2}, cands, t0).HasHot).To(BeFalse())
|
||||
Expect(idx.Decide("m", []uint64{3, 4}, cands, t0).HasHot).To(BeFalse())
|
||||
})
|
||||
})
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/mudler/LocalAI/core/services/advisorylock"
|
||||
"github.com/mudler/LocalAI/core/services/nodes/prefixcache"
|
||||
grpcclient "github.com/mudler/LocalAI/pkg/grpc"
|
||||
"github.com/mudler/xlog"
|
||||
"github.com/nats-io/nats.go"
|
||||
@@ -56,6 +57,13 @@ type ReplicaReconciler struct {
|
||||
// probeStaleAfter: only probe node_models rows older than this so we
|
||||
// don't hammer every worker every tick for models we just heard from.
|
||||
probeStaleAfter time.Duration
|
||||
// pressure is the shared forced-disturb counter written by the router. When
|
||||
// a model's count within the Pressure's rolling window reaches pressureThreshold the
|
||||
// reconciler treats its cache-warm replica as saturated and scales up,
|
||||
// subject to the same MaxReplicas/capacity/UnsatisfiableUntil machinery as
|
||||
// the other scale-up paths. nil disables this signal (a true no-op).
|
||||
pressure *prefixcache.Pressure
|
||||
pressureThreshold int
|
||||
}
|
||||
|
||||
// ModelScheduler abstracts the scheduling logic needed by the reconciler.
|
||||
@@ -83,6 +91,12 @@ type ReplicaReconcilerOptions struct {
|
||||
Interval time.Duration // default 30s
|
||||
ScaleDownDelay time.Duration // default 5m
|
||||
ProbeStaleAfter time.Duration // default 2m
|
||||
// Pressure is the shared forced-disturb counter written by the router. nil
|
||||
// disables the cache-saturation autoscale signal (a true no-op).
|
||||
Pressure *prefixcache.Pressure
|
||||
// PressureThreshold is the forced-disturb count within PressureWindow that
|
||||
// triggers a scale-up. Default prefixcache.DefaultConfig().PressureScaleThreshold (1).
|
||||
PressureThreshold int
|
||||
}
|
||||
|
||||
// NewReplicaReconciler creates a new ReplicaReconciler.
|
||||
@@ -103,16 +117,22 @@ func NewReplicaReconciler(opts ReplicaReconcilerOptions) *ReplicaReconciler {
|
||||
if prober == nil {
|
||||
prober = grpcModelProber{token: opts.RegistrationToken}
|
||||
}
|
||||
pressureThreshold := opts.PressureThreshold
|
||||
if pressureThreshold == 0 {
|
||||
pressureThreshold = prefixcache.DefaultConfig().PressureScaleThreshold
|
||||
}
|
||||
return &ReplicaReconciler{
|
||||
registry: opts.Registry,
|
||||
scheduler: opts.Scheduler,
|
||||
unloader: opts.Unloader,
|
||||
adapter: opts.Adapter,
|
||||
prober: prober,
|
||||
db: opts.DB,
|
||||
interval: interval,
|
||||
scaleDownDelay: scaleDownDelay,
|
||||
probeStaleAfter: probeStaleAfter,
|
||||
registry: opts.Registry,
|
||||
scheduler: opts.Scheduler,
|
||||
unloader: opts.Unloader,
|
||||
adapter: opts.Adapter,
|
||||
prober: prober,
|
||||
db: opts.DB,
|
||||
interval: interval,
|
||||
scaleDownDelay: scaleDownDelay,
|
||||
probeStaleAfter: probeStaleAfter,
|
||||
pressure: opts.Pressure,
|
||||
pressureThreshold: pressureThreshold,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -409,13 +429,25 @@ func (rc *ReplicaReconciler) reconcileModel(ctx context.Context, cfg ModelSchedu
|
||||
}
|
||||
xlog.Info("Reconciler: scaling up to meet minimum", "model", cfg.ModelName,
|
||||
"current", current, "min", cfg.MinReplicas, "adding", needed)
|
||||
rc.scaleUp(ctx, cfg, needed)
|
||||
// Successful (or partial) scale-up clears the hysteresis so a future
|
||||
// dip starts fresh.
|
||||
_ = rc.registry.ClearUnsatisfiable(ctx, cfg.ModelName)
|
||||
if rc.scaleUp(ctx, cfg, needed) {
|
||||
// A real (or partial) scale-up clears the hysteresis so a future
|
||||
// dip starts fresh. If scaleUp added nothing (scheduler errored or
|
||||
// no node could be loaded) we leave the hysteresis intact so the
|
||||
// next tick retries from where it left off rather than resetting
|
||||
// the unsatisfiable counter on a failed attempt.
|
||||
_ = rc.registry.ClearUnsatisfiable(ctx, cfg.ModelName)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// scaledUp tracks whether a scale-up already fired in this tick. The two
|
||||
// scale-up paths below (busy-burst and pressure) share the single `current`
|
||||
// value read once above; scaleUp does not re-check it. So at most one of
|
||||
// them may fire per tick, otherwise a model that is both busy AND over the
|
||||
// pressure threshold would scale +2 and could overshoot MaxReplicas by one.
|
||||
// Scale-down is also skipped in a tick that scaled up.
|
||||
scaledUp := false
|
||||
|
||||
// 2. Auto-scale up if all replicas are busy
|
||||
if current > 0 && (cfg.MaxReplicas == 0 || int(current) < cfg.MaxReplicas) {
|
||||
if rc.allReplicasBusy(ctx, cfg.ModelName) {
|
||||
@@ -432,17 +464,63 @@ func (rc *ReplicaReconciler) reconcileModel(ctx context.Context, cfg ModelSchedu
|
||||
}
|
||||
xlog.Info("Reconciler: all replicas busy, scaling up", "model", cfg.ModelName,
|
||||
"current", current)
|
||||
rc.scaleUp(ctx, cfg, 1)
|
||||
// Only mark the tick as having scaled up if a replica was actually
|
||||
// added. On a failed scaleUp, leave scaledUp false so the pressure
|
||||
// path below and the scale-down logic still apply as they would
|
||||
// have if the busy-burst path had not run.
|
||||
scaledUp = rc.scaleUp(ctx, cfg, 1)
|
||||
}
|
||||
}
|
||||
|
||||
// 3. Scale down idle replicas above minimum
|
||||
floor := cfg.MinReplicas
|
||||
if floor < 1 {
|
||||
floor = 1
|
||||
// 2b. Auto-scale up on prefix-cache forced-disturb pressure. A forced-disturb
|
||||
// is recorded by the router when a request had a usable hot prefix match
|
||||
// but the load guard forced it off the warm node: the cache-warm replica
|
||||
// is saturated. We reuse the same MaxReplicas + capacity guards as the
|
||||
// busy-burst path, and the same UnsatisfiableUntil cooldown gates this
|
||||
// block at the top of reconcileModel, so a no-capacity model will not
|
||||
// spin. Pressure never overrides MaxReplicas or force-evicts.
|
||||
//
|
||||
// Skipped when the busy-burst path already scaled up this tick: at most
|
||||
// one scaleUp(+1) per tick (see scaledUp above).
|
||||
if !scaledUp && rc.pressure != nil && current > 0 && (cfg.MaxReplicas == 0 || int(current) < cfg.MaxReplicas) {
|
||||
if pressureCount := rc.pressure.Count(cfg.ModelName, time.Now()); pressureCount >= rc.pressureThreshold {
|
||||
candidateNodeIDs, selectorMatched := rc.candidateNodeIDsForSelector(ctx, cfg)
|
||||
if selectorMatched {
|
||||
capacity, capErr := rc.registry.ClusterCapacityForModel(ctx, cfg.ModelName, candidateNodeIDs)
|
||||
if capErr == nil && capacity > 0 {
|
||||
xlog.Info("Reconciler: prefix-cache forced-disturb pressure, scaling up",
|
||||
"model", cfg.ModelName, "current", current,
|
||||
"pressure", pressureCount,
|
||||
"threshold", rc.pressureThreshold)
|
||||
if rc.scaleUp(ctx, cfg, 1) {
|
||||
scaledUp = true
|
||||
// Consume the signal only on a real scale-up:
|
||||
// Pressure.Count is non-draining (it prunes only by
|
||||
// age), so a single burst stays in-window for the whole
|
||||
// window and would re-fire scaleUp on every tick. Reset
|
||||
// clears the model's events so a fresh scale-up needs
|
||||
// fresh forced-disturbs to accumulate. If scaleUp added
|
||||
// nothing (scheduler errored or no node could be loaded)
|
||||
// we preserve the signal so the next tick retries off
|
||||
// the same accumulated pressure instead of having to
|
||||
// re-accumulate a full window from scratch.
|
||||
rc.pressure.Reset(cfg.ModelName)
|
||||
}
|
||||
}
|
||||
// No capacity: transient demand, not a misconfig - let the next
|
||||
// tick retry naturally (mirrors the busy-burst path's choice not
|
||||
// to enter cooldown for burst load).
|
||||
}
|
||||
}
|
||||
}
|
||||
if int(current) > floor {
|
||||
rc.scaleDownIdle(ctx, cfg, int(current), floor)
|
||||
|
||||
// 3. Scale down idle replicas above minimum. Skipped in a tick that already
|
||||
// scaled up so we never scale up and down in the same pass.
|
||||
if !scaledUp {
|
||||
floor := max(cfg.MinReplicas, 1)
|
||||
if int(current) > floor {
|
||||
rc.scaleDownIdle(ctx, cfg, int(current), floor)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -470,10 +548,17 @@ func (rc *ReplicaReconciler) markCapacityProblem(ctx context.Context, modelName,
|
||||
// scaleUp schedules additional replicas of the model. Callers in
|
||||
// reconcileModel are expected to have already capped `count` against
|
||||
// ClusterCapacityForModel so this function never tries to overshoot.
|
||||
func (rc *ReplicaReconciler) scaleUp(ctx context.Context, cfg ModelSchedulingConfig, count int) {
|
||||
//
|
||||
// Returns true if at least one replica was actually scheduled. Callers use
|
||||
// this to gate signal-consuming side effects (Pressure.Reset,
|
||||
// ClearUnsatisfiable) on a real scale-up: a failed/no-op scaleUp must not
|
||||
// discard the accumulated forced-disturb pressure or clear the unsatisfiable
|
||||
// hysteresis, otherwise the signal has to re-accumulate from scratch and the
|
||||
// next tick can't simply retry.
|
||||
func (rc *ReplicaReconciler) scaleUp(ctx context.Context, cfg ModelSchedulingConfig, count int) bool {
|
||||
if rc.scheduler == nil {
|
||||
xlog.Warn("Reconciler: no scheduler available, cannot scale up")
|
||||
return
|
||||
return false
|
||||
}
|
||||
|
||||
// Resolve selector → candidate node IDs (nil when no selector → "any
|
||||
@@ -481,18 +566,21 @@ func (rc *ReplicaReconciler) scaleUp(ctx context.Context, cfg ModelSchedulingCon
|
||||
// reconcileModel, but defensively short-circuit here too.
|
||||
candidateNodeIDs, ok := rc.candidateNodeIDsForSelector(ctx, cfg)
|
||||
if !ok {
|
||||
return
|
||||
return false
|
||||
}
|
||||
|
||||
scheduled := 0
|
||||
for i := 0; i < count; i++ {
|
||||
node, err := rc.scheduler.ScheduleAndLoadModel(ctx, cfg.ModelName, candidateNodeIDs)
|
||||
if err != nil {
|
||||
xlog.Warn("Reconciler: failed to scale up replica", "model", cfg.ModelName,
|
||||
"attempt", i+1, "error", err)
|
||||
return // stop trying on first failure
|
||||
break // stop trying on first failure
|
||||
}
|
||||
scheduled++
|
||||
xlog.Info("Reconciler: scaled up replica", "model", cfg.ModelName, "node", node.Name)
|
||||
}
|
||||
return scheduled > 0
|
||||
}
|
||||
|
||||
// scaleDownIdle removes idle replicas above the floor.
|
||||
|
||||
@@ -2,12 +2,14 @@ package nodes
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"runtime"
|
||||
"time"
|
||||
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
|
||||
"github.com/mudler/LocalAI/core/services/nodes/prefixcache"
|
||||
"github.com/mudler/LocalAI/core/services/testutil"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
@@ -245,6 +247,225 @@ var _ = Describe("ReplicaReconciler", func() {
|
||||
})
|
||||
})
|
||||
|
||||
Describe("Forced-disturb pressure autoscale (Phase 6)", func() {
|
||||
It("scales up when pressure exceeds threshold, replicas<max, and capacity exists", func() {
|
||||
// One node with spare slots, one loaded idle replica (so the
|
||||
// all-busy path does not fire). Pressure for the model is above the
|
||||
// threshold, which is the only reason to scale here.
|
||||
node := registerNode("pressure-node", "10.0.0.60:50051")
|
||||
Expect(registry.SetNodeModel(context.Background(), node.ID, "pressure-model", 0, "loaded", "addr1", 0)).To(Succeed())
|
||||
setSchedulingConfig("pressure-model", 1, 4, "")
|
||||
|
||||
pressure := prefixcache.NewPressure(time.Minute)
|
||||
pressure.Record("pressure-model", time.Now())
|
||||
|
||||
scheduler := &fakeScheduler{scheduleNode: node}
|
||||
reconciler := NewReplicaReconciler(ReplicaReconcilerOptions{
|
||||
Registry: registry,
|
||||
Scheduler: scheduler,
|
||||
DB: db,
|
||||
Pressure: pressure,
|
||||
})
|
||||
|
||||
reconciler.reconcile(context.Background())
|
||||
|
||||
Expect(scheduler.scheduleCalls).To(HaveLen(1),
|
||||
"forced-disturb pressure above threshold must trigger a scale-up")
|
||||
Expect(scheduler.scheduleCalls[0].modelName).To(Equal("pressure-model"))
|
||||
})
|
||||
|
||||
It("does not scale up on pressure when already at max_replicas", func() {
|
||||
// Two nodes, both loaded (idle), MaxReplicas=2 → at max. Pressure is
|
||||
// high but MaxReplicas must never be overridden.
|
||||
node1 := registerNode("pmax-1", "10.0.0.61:50051")
|
||||
node2 := registerNode("pmax-2", "10.0.0.62:50051")
|
||||
Expect(registry.SetNodeModel(context.Background(), node1.ID, "pmax-model", 0, "loaded", "addr1", 0)).To(Succeed())
|
||||
Expect(registry.SetNodeModel(context.Background(), node2.ID, "pmax-model", 0, "loaded", "addr2", 0)).To(Succeed())
|
||||
setSchedulingConfig("pmax-model", 1, 2, "")
|
||||
|
||||
pressure := prefixcache.NewPressure(time.Minute)
|
||||
pressure.Record("pmax-model", time.Now())
|
||||
pressure.Record("pmax-model", time.Now())
|
||||
|
||||
scheduler := &fakeScheduler{scheduleNode: node1}
|
||||
reconciler := NewReplicaReconciler(ReplicaReconcilerOptions{
|
||||
Registry: registry,
|
||||
Scheduler: scheduler,
|
||||
DB: db,
|
||||
Pressure: pressure,
|
||||
})
|
||||
|
||||
reconciler.reconcile(context.Background())
|
||||
|
||||
Expect(scheduler.scheduleCalls).To(BeEmpty(),
|
||||
"pressure must never override MaxReplicas")
|
||||
})
|
||||
|
||||
It("consumes the pressure signal so a single burst scales up only once", func() {
|
||||
// A single burst of forced-disturbs (well within the window) must
|
||||
// trigger exactly ONE pressure scale-up. A subsequent tick, with the
|
||||
// SAME events still in-window, must NOT scale again: the first
|
||||
// scale-up consumed (Reset) the signal. Without the fix, the
|
||||
// non-draining Count keeps returning >= threshold every tick and
|
||||
// drives the model toward MaxReplicas off a single burst.
|
||||
node := registerNode("consume-node", "10.0.0.64:50051")
|
||||
Expect(registry.SetNodeModel(context.Background(), node.ID, "consume-model", 0, "loaded", "addr1", 0)).To(Succeed())
|
||||
setSchedulingConfig("consume-model", 1, 4, "")
|
||||
|
||||
pressure := prefixcache.NewPressure(time.Minute)
|
||||
now := time.Now()
|
||||
pressure.Record("consume-model", now)
|
||||
pressure.Record("consume-model", now)
|
||||
pressure.Record("consume-model", now)
|
||||
|
||||
scheduler := &fakeScheduler{scheduleNode: node}
|
||||
reconciler := NewReplicaReconciler(ReplicaReconcilerOptions{
|
||||
Registry: registry,
|
||||
Scheduler: scheduler,
|
||||
DB: db,
|
||||
Pressure: pressure,
|
||||
})
|
||||
|
||||
// First tick: pressure above threshold → one scale-up.
|
||||
reconciler.reconcile(context.Background())
|
||||
Expect(scheduler.scheduleCalls).To(HaveLen(1),
|
||||
"first tick must scale up once on the burst")
|
||||
|
||||
// Second tick: the burst's events are still inside the window, but
|
||||
// the first scale-up Reset them, so no further scale-up occurs.
|
||||
reconciler.reconcile(context.Background())
|
||||
Expect(scheduler.scheduleCalls).To(HaveLen(1),
|
||||
"a single burst must not re-trigger scale-up on the next in-window tick")
|
||||
})
|
||||
|
||||
It("does not consume the pressure signal when scaleUp fails", func() {
|
||||
// Pressure above threshold and capacity exists, but the scheduler
|
||||
// errors so no replica is actually added. The forced-disturb signal
|
||||
// must be preserved (NOT Reset) so the next tick retries the
|
||||
// scale-up off the same accumulated pressure, instead of having to
|
||||
// re-accumulate a full window of forced-disturbs from scratch.
|
||||
node := registerNode("fail-node", "10.0.0.66:50051")
|
||||
Expect(registry.SetNodeModel(context.Background(), node.ID, "fail-model", 0, "loaded", "addr1", 0)).To(Succeed())
|
||||
setSchedulingConfig("fail-model", 1, 4, "")
|
||||
|
||||
pressure := prefixcache.NewPressure(time.Minute)
|
||||
now := time.Now()
|
||||
pressure.Record("fail-model", now)
|
||||
pressure.Record("fail-model", now)
|
||||
pressure.Record("fail-model", now)
|
||||
Expect(pressure.Count("fail-model", time.Now())).To(BeNumerically(">=", 1))
|
||||
|
||||
// Scheduler errors: scaleUp attempts but adds nothing.
|
||||
scheduler := &fakeScheduler{scheduleErr: errors.New("schedule boom")}
|
||||
reconciler := NewReplicaReconciler(ReplicaReconcilerOptions{
|
||||
Registry: registry,
|
||||
Scheduler: scheduler,
|
||||
DB: db,
|
||||
Pressure: pressure,
|
||||
})
|
||||
|
||||
reconciler.reconcile(context.Background())
|
||||
|
||||
Expect(scheduler.scheduleCalls).To(HaveLen(1),
|
||||
"scaleUp must have attempted exactly one schedule call")
|
||||
Expect(pressure.Count("fail-model", time.Now())).To(BeNumerically(">=", 1),
|
||||
"a failed scaleUp must NOT consume (Reset) the pressure signal — next tick should retry")
|
||||
})
|
||||
|
||||
It("consumes the pressure signal only when scaleUp succeeds", func() {
|
||||
// Mirror of the failure case: when the scheduler succeeds and a
|
||||
// replica is actually added, the forced-disturb signal IS consumed
|
||||
// (Reset to 0) so a single burst scales up only once.
|
||||
node := registerNode("ok-node", "10.0.0.67:50051")
|
||||
Expect(registry.SetNodeModel(context.Background(), node.ID, "ok-model", 0, "loaded", "addr1", 0)).To(Succeed())
|
||||
setSchedulingConfig("ok-model", 1, 4, "")
|
||||
|
||||
pressure := prefixcache.NewPressure(time.Minute)
|
||||
now := time.Now()
|
||||
pressure.Record("ok-model", now)
|
||||
pressure.Record("ok-model", now)
|
||||
pressure.Record("ok-model", now)
|
||||
|
||||
scheduler := &fakeScheduler{scheduleNode: node}
|
||||
reconciler := NewReplicaReconciler(ReplicaReconcilerOptions{
|
||||
Registry: registry,
|
||||
Scheduler: scheduler,
|
||||
DB: db,
|
||||
Pressure: pressure,
|
||||
})
|
||||
|
||||
reconciler.reconcile(context.Background())
|
||||
|
||||
Expect(scheduler.scheduleCalls).To(HaveLen(1),
|
||||
"successful scaleUp must have scheduled one replica")
|
||||
Expect(pressure.Count("ok-model", time.Now())).To(Equal(0),
|
||||
"a successful scaleUp must consume (Reset) the pressure signal to 0")
|
||||
})
|
||||
|
||||
It("performs at most one scale-up per tick when both busy and over pressure", func() {
|
||||
// The single loaded replica is busy (all-replicas-busy fires) AND
|
||||
// pressure is above threshold. Both scale-up paths are eligible in
|
||||
// the same tick. The invariant is at-most-one scaleUp(+1) per tick,
|
||||
// so exactly one schedule call must happen, not two.
|
||||
node := registerNode("dual-node", "10.0.0.65:50051")
|
||||
Expect(registry.SetNodeModel(context.Background(), node.ID, "dual-model", 0, "loaded", "addr1", 0)).To(Succeed())
|
||||
Expect(registry.IncrementInFlight(context.Background(), node.ID, "dual-model", 0)).To(Succeed())
|
||||
setSchedulingConfig("dual-model", 1, 4, "")
|
||||
|
||||
pressure := prefixcache.NewPressure(time.Minute)
|
||||
pressure.Record("dual-model", time.Now())
|
||||
pressure.Record("dual-model", time.Now())
|
||||
|
||||
scheduler := &fakeScheduler{scheduleNode: node}
|
||||
reconciler := NewReplicaReconciler(ReplicaReconcilerOptions{
|
||||
Registry: registry,
|
||||
Scheduler: scheduler,
|
||||
DB: db,
|
||||
Pressure: pressure,
|
||||
})
|
||||
|
||||
reconciler.reconcile(context.Background())
|
||||
|
||||
Expect(scheduler.scheduleCalls).To(HaveLen(1),
|
||||
"busy + pressure in one tick must still scale up by exactly one, not two")
|
||||
})
|
||||
|
||||
It("does not spin when pressure is high but no capacity exists", func() {
|
||||
// Single node, cap 1, already loaded → capacity 0. Pressure is high
|
||||
// but there is nowhere to place a replica: must not call scheduler.
|
||||
registerCappedNodeFn := func(name, address string, cap int) *BackendNode {
|
||||
node := &BackendNode{
|
||||
Name: name,
|
||||
NodeType: NodeTypeBackend,
|
||||
Address: address,
|
||||
MaxReplicasPerModel: cap,
|
||||
}
|
||||
Expect(registry.Register(context.Background(), node, true)).To(Succeed())
|
||||
return node
|
||||
}
|
||||
node := registerCappedNodeFn("pcap-node", "10.0.0.63:50051", 1)
|
||||
Expect(registry.SetNodeModel(context.Background(), node.ID, "pcap-model", 0, "loaded", "addr1", 0)).To(Succeed())
|
||||
// MaxReplicas high enough that replicas<max, so only capacity guards it.
|
||||
setSchedulingConfig("pcap-model", 1, 4, "")
|
||||
|
||||
pressure := prefixcache.NewPressure(time.Minute)
|
||||
pressure.Record("pcap-model", time.Now())
|
||||
|
||||
scheduler := &fakeScheduler{scheduleNode: node}
|
||||
reconciler := NewReplicaReconciler(ReplicaReconcilerOptions{
|
||||
Registry: registry,
|
||||
Scheduler: scheduler,
|
||||
DB: db,
|
||||
Pressure: pressure,
|
||||
})
|
||||
|
||||
reconciler.reconcile(context.Background())
|
||||
|
||||
Expect(scheduler.scheduleCalls).To(BeEmpty(),
|
||||
"no capacity means no scale-up: must not spin the scheduler")
|
||||
})
|
||||
})
|
||||
|
||||
Describe("Capacity gating + circuit breaker (PR4)", func() {
|
||||
// Helper: register a node with an explicit per-model replica cap.
|
||||
// Tests in this Describe block want to exercise both "fits" and
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
@@ -141,6 +142,13 @@ type ModelSchedulingConfig struct {
|
||||
NodeSelector string `gorm:"type:text" json:"node_selector,omitempty"` // JSON {"key":"value",...}
|
||||
MinReplicas int `gorm:"default:0" json:"min_replicas"`
|
||||
MaxReplicas int `gorm:"default:0" json:"max_replicas"`
|
||||
// Prefix-cache-aware routing (epic #10063). RoutePolicy "" means inherit
|
||||
// the cluster-wide default. Thresholds are per-model overrides; 0 means
|
||||
// inherit the global default.
|
||||
RoutePolicy string `gorm:"column:route_policy;size:32" json:"route_policy,omitempty"`
|
||||
BalanceAbsThreshold int `gorm:"column:balance_abs_threshold;default:0" json:"balance_abs_threshold,omitempty"`
|
||||
BalanceRelThreshold float64 `gorm:"column:balance_rel_threshold;default:0" json:"balance_rel_threshold,omitempty"`
|
||||
MinPrefixMatch float64 `gorm:"column:min_prefix_match;default:0" json:"min_prefix_match,omitempty"`
|
||||
// UnsatisfiableUntil is set by the reconciler when no candidate node has
|
||||
// free capacity for this model; while in the future, the reconciler skips
|
||||
// scale-up attempts for this model. Cleared on cluster events that could
|
||||
@@ -196,6 +204,58 @@ const (
|
||||
// NodeRegistry manages backend node registration and lookup in PostgreSQL.
|
||||
type NodeRegistry struct {
|
||||
db *gorm.DB
|
||||
// replicaRemovedHook is invoked after a replica row for (modelName, nodeID)
|
||||
// is removed. It is the single chokepoint that lets the prefix-cache index
|
||||
// be invalidated no matter which removal path (router eviction, reconciler
|
||||
// scale-down, probe reaper, health-monitor reap, RemoteUnloaderAdapter) ran.
|
||||
// The replicaIndex argument is the SPECIFIC replica removed, or negative to
|
||||
// signal "all replicas of (modelName, nodeID)". Stored in an atomic.Pointer
|
||||
// so the startup wiring (setter) and request / reconcile handling (fire) are
|
||||
// race-free.
|
||||
replicaRemovedHook atomic.Pointer[func(modelName, nodeID string, replicaIndex int)]
|
||||
}
|
||||
|
||||
// SetReplicaRemovedHook registers a callback invoked after a replica row for
|
||||
// (modelName, nodeID) is removed from the registry. replicaIndex is the
|
||||
// specific replica removed, or negative to mean "all replicas of the node".
|
||||
// Used to invalidate the prefix-cache index so it never points at a replica
|
||||
// that no longer hosts the model. Set once at startup before serving. Safe to
|
||||
// leave unset (no-op).
|
||||
func (r *NodeRegistry) SetReplicaRemovedHook(fn func(modelName, nodeID string, replicaIndex int)) {
|
||||
r.replicaRemovedHook.Store(&fn)
|
||||
}
|
||||
|
||||
// fireReplicaRemoved invokes the replica-removed hook if one is set. A negative
|
||||
// replicaIndex means all replicas of (modelName, nodeID). Nil-safe.
|
||||
func (r *NodeRegistry) fireReplicaRemoved(modelName, nodeID string, replicaIndex int) {
|
||||
if fn := r.replicaRemovedHook.Load(); fn != nil && *fn != nil {
|
||||
(*fn)(modelName, nodeID, replicaIndex)
|
||||
}
|
||||
}
|
||||
|
||||
// nodeModelNames returns the DISTINCT model names that have node_models rows
|
||||
// for nodeID, using db (which may be a transaction handle). Used by the bulk
|
||||
// node-scoped delete paths (Register re-register cleanup, MarkOffline,
|
||||
// MarkDraining, Deregister) to capture what will be removed BEFORE the delete
|
||||
// so they can fire the replica-removed hook once per distinct model afterwards
|
||||
// and keep the prefix-cache index from pointing at a node that no longer hosts
|
||||
// the model. Skips the query entirely when no hook is set (these are lifecycle
|
||||
// ops, not the request hot path, but the query is pure overhead with no hook).
|
||||
func (r *NodeRegistry) nodeModelNames(ctx context.Context, db *gorm.DB, nodeID string) []string {
|
||||
if fn := r.replicaRemovedHook.Load(); fn == nil || *fn == nil {
|
||||
return nil
|
||||
}
|
||||
var names []string
|
||||
if err := db.WithContext(ctx).Model(&NodeModel{}).
|
||||
Where("node_id = ?", nodeID).
|
||||
Distinct().
|
||||
Pluck("model_name", &names).Error; err != nil {
|
||||
// Non-fatal: proceed with the delete, just skip hook invalidation.
|
||||
// A stale prefix-cache entry self-heals on the next routing miss.
|
||||
xlog.Warn("Failed to enumerate node models before bulk delete; skipping prefix-cache invalidation", "node", nodeID, "error", err)
|
||||
return nil
|
||||
}
|
||||
return names
|
||||
}
|
||||
|
||||
// NewNodeRegistry creates a NodeRegistry and auto-migrates the schema.
|
||||
@@ -280,9 +340,24 @@ func (r *NodeRegistry) Register(ctx context.Context, node *BackendNode, autoAppr
|
||||
if node.APIKeyID == "" {
|
||||
node.APIKeyID = existing.APIKeyID
|
||||
}
|
||||
// Clear stale model records — the node restarted and has nothing loaded
|
||||
if err := r.db.WithContext(ctx).Where("node_id = ?", existing.ID).Delete(&NodeModel{}).Error; err != nil {
|
||||
// Clear stale model records — the node restarted and has nothing loaded.
|
||||
// Capture the distinct models and run the bulk delete inside a single
|
||||
// transaction so the set of fired hooks equals exactly the set of rows
|
||||
// deleted: a SetNodeModel landing between the capture and the delete can
|
||||
// no longer be deleted without its hook firing (no interleaving gap).
|
||||
// Fire the hooks only after the transaction commits so a rollback does
|
||||
// not invalidate the prefix-cache index for a removal that did not
|
||||
// persist (the single chokepoint must cover this path too).
|
||||
var removedModels []string
|
||||
if err := r.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
|
||||
removedModels = r.nodeModelNames(ctx, tx, existing.ID)
|
||||
return tx.Where("node_id = ?", existing.ID).Delete(&NodeModel{}).Error
|
||||
}); err != nil {
|
||||
xlog.Warn("Failed to clear stale model records on re-register", "node", node.Name, "error", err)
|
||||
} else {
|
||||
for _, m := range removedModels {
|
||||
r.fireReplicaRemoved(m, existing.ID, -1)
|
||||
}
|
||||
}
|
||||
} else if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
// Create new node
|
||||
@@ -359,9 +434,24 @@ func (r *NodeRegistry) MarkOffline(ctx context.Context, nodeID string) error {
|
||||
if err := r.setStatus(ctx, nodeID, StatusOffline); err != nil {
|
||||
return err
|
||||
}
|
||||
// Clear model records — node is shutting down
|
||||
if err := r.db.WithContext(ctx).Where("node_id = ?", nodeID).Delete(&NodeModel{}).Error; err != nil {
|
||||
// Clear model records — node is shutting down. Capture the distinct models
|
||||
// and run the bulk delete inside a single transaction so the set of fired
|
||||
// hooks equals exactly the set of rows deleted: a SetNodeModel landing
|
||||
// between the capture and the delete can no longer be deleted without its
|
||||
// hook firing (no interleaving gap). The status flip above is a separate,
|
||||
// pre-existing operation and routing already filters non-healthy nodes, so
|
||||
// it stays outside this transaction. Fire hooks only after commit so a
|
||||
// rollback does not invalidate the index for a removal that did not persist.
|
||||
var removedModels []string
|
||||
if err := r.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
|
||||
removedModels = r.nodeModelNames(ctx, tx, nodeID)
|
||||
return tx.Where("node_id = ?", nodeID).Delete(&NodeModel{}).Error
|
||||
}); err != nil {
|
||||
xlog.Warn("Failed to clear model records on offline", "node", nodeID, "error", err)
|
||||
} else {
|
||||
for _, m := range removedModels {
|
||||
r.fireReplicaRemoved(m, nodeID, -1)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -468,7 +558,11 @@ func (r *NodeRegistry) Deregister(ctx context.Context, nodeID string) error {
|
||||
return fmt.Errorf("node %s not found: %w", nodeID, err)
|
||||
}
|
||||
|
||||
return db.Transaction(func(tx *gorm.DB) error {
|
||||
// Capture the distinct models removed so the prefix-cache index can be
|
||||
// invalidated once the transaction commits.
|
||||
var removedModels []string
|
||||
if err := db.Transaction(func(tx *gorm.DB) error {
|
||||
removedModels = r.nodeModelNames(ctx, tx, nodeID)
|
||||
if err := tx.Where("node_id = ?", nodeID).Delete(&NodeModel{}).Error; err != nil {
|
||||
return fmt.Errorf("deleting node models for %s: %w", nodeID, err)
|
||||
}
|
||||
@@ -483,7 +577,13 @@ func (r *NodeRegistry) Deregister(ctx context.Context, nodeID string) error {
|
||||
}
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}); err != nil {
|
||||
return err
|
||||
}
|
||||
for _, m := range removedModels {
|
||||
r.fireReplicaRemoved(m, nodeID, -1)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// HeartbeatUpdate contains optional fields to update on heartbeat.
|
||||
@@ -600,8 +700,23 @@ func (r *NodeRegistry) MarkDraining(ctx context.Context, nodeID string) error {
|
||||
if err := r.setStatus(ctx, nodeID, StatusDraining); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := r.db.WithContext(ctx).Where("node_id = ?", nodeID).Delete(&NodeModel{}).Error; err != nil {
|
||||
// Capture the distinct models and run the bulk delete inside a single
|
||||
// transaction so the set of fired hooks equals exactly the set of rows
|
||||
// deleted: a SetNodeModel landing between the capture and the delete can no
|
||||
// longer be deleted without its hook firing (no interleaving gap). The
|
||||
// status flip above is a separate, pre-existing operation and stays outside
|
||||
// this transaction. Fire hooks only after commit so a rollback does not
|
||||
// invalidate the index for a removal that did not persist.
|
||||
var removedModels []string
|
||||
if err := r.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
|
||||
removedModels = r.nodeModelNames(ctx, tx, nodeID)
|
||||
return tx.Where("node_id = ?", nodeID).Delete(&NodeModel{}).Error
|
||||
}); err != nil {
|
||||
xlog.Warn("Failed to clear model records on draining", "node", nodeID, "error", err)
|
||||
} else {
|
||||
for _, m := range removedModels {
|
||||
r.fireReplicaRemoved(m, nodeID, -1)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -712,16 +827,25 @@ func (r *NodeRegistry) GetModelLoadInfo(ctx context.Context, modelName string) (
|
||||
// to keep the contract explicit (probeLoadedModels and scaleDownIdle iterate
|
||||
// per-row and must not orphan healthy siblings).
|
||||
func (r *NodeRegistry) RemoveNodeModel(ctx context.Context, nodeID, modelName string, replicaIndex int) error {
|
||||
return r.db.WithContext(ctx).Where("node_id = ? AND model_name = ? AND replica_index = ?", nodeID, modelName, replicaIndex).
|
||||
Delete(&NodeModel{}).Error
|
||||
if err := r.db.WithContext(ctx).Where("node_id = ? AND model_name = ? AND replica_index = ?", nodeID, modelName, replicaIndex).
|
||||
Delete(&NodeModel{}).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
r.fireReplicaRemoved(modelName, nodeID, replicaIndex)
|
||||
return nil
|
||||
}
|
||||
|
||||
// RemoveAllNodeModelReplicas removes every replica of modelName on nodeID.
|
||||
// Used by callers (e.g. node deregistration, full backend stop) that genuinely
|
||||
// want to clear all replicas, not just one.
|
||||
func (r *NodeRegistry) RemoveAllNodeModelReplicas(ctx context.Context, nodeID, modelName string) error {
|
||||
return r.db.WithContext(ctx).Where("node_id = ? AND model_name = ?", nodeID, modelName).
|
||||
Delete(&NodeModel{}).Error
|
||||
if err := r.db.WithContext(ctx).Where("node_id = ? AND model_name = ?", nodeID, modelName).
|
||||
Delete(&NodeModel{}).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
// Negative index signals "all replicas of (modelName, nodeID)".
|
||||
r.fireReplicaRemoved(modelName, nodeID, -1)
|
||||
return nil
|
||||
}
|
||||
|
||||
// FindNodesWithModel returns nodes that have the given model loaded.
|
||||
@@ -737,6 +861,17 @@ func (r *NodeRegistry) FindNodesWithModel(ctx context.Context, modelName string)
|
||||
return nodes, nil
|
||||
}
|
||||
|
||||
// RoutePreference biases FindAndLockNodeWithModel. PreferredNodeID +
|
||||
// PreferredReplica, when set and the exact (node, replica) row is still
|
||||
// loaded/healthy, is locked instead of the default ORDER BY pick. The caller
|
||||
// (the prefix-cache router) has already applied the load guard, so the lock
|
||||
// targets the EXACT replica it chose, not the least-loaded replica on the node.
|
||||
// Nil preference => unchanged behavior.
|
||||
type RoutePreference struct {
|
||||
PreferredNodeID string
|
||||
PreferredReplica int
|
||||
}
|
||||
|
||||
// FindAndLockNodeWithModel atomically finds the best loaded replica of the
|
||||
// given model and increments its in-flight counter within a single
|
||||
// transaction. The SELECT FOR UPDATE row lock prevents concurrent eviction
|
||||
@@ -758,7 +893,7 @@ func (r *NodeRegistry) FindNodesWithModel(ctx context.Context, modelName string)
|
||||
// NodeSelector so a cached replica on a now-excluded node isn't picked over a
|
||||
// matching replica elsewhere — the selector-mismatch fall-through path used to
|
||||
// trigger an eviction-busy loop when both sides had the model loaded.
|
||||
func (r *NodeRegistry) FindAndLockNodeWithModel(ctx context.Context, modelName string, candidateNodeIDs []string) (*BackendNode, *NodeModel, error) {
|
||||
func (r *NodeRegistry) FindAndLockNodeWithModel(ctx context.Context, modelName string, candidateNodeIDs []string, pref *RoutePreference) (*BackendNode, *NodeModel, error) {
|
||||
var nm NodeModel
|
||||
var node BackendNode
|
||||
|
||||
@@ -781,17 +916,33 @@ func (r *NodeRegistry) FindAndLockNodeWithModel(ctx context.Context, modelName s
|
||||
// stale row in the same window, and other helpers that mirror this
|
||||
// JOIN need the same invariant. Belt-and-braces: status filter here
|
||||
// AND the status-checked node fetch below.
|
||||
q := tx.Clauses(clause.Locking{Strength: "UPDATE"}).
|
||||
base := tx.Clauses(clause.Locking{Strength: "UPDATE"}).
|
||||
Joins("JOIN backend_nodes ON backend_nodes.id = node_models.node_id").
|
||||
Where("node_models.model_name = ? AND node_models.state = ? AND backend_nodes.status = ?",
|
||||
modelName, "loaded", StatusHealthy)
|
||||
if len(candidateNodeIDs) > 0 {
|
||||
q = q.Where("node_models.node_id IN ?", candidateNodeIDs)
|
||||
base = base.Where("node_models.node_id IN ?", candidateNodeIDs)
|
||||
}
|
||||
if err := q.
|
||||
Order("node_models.in_flight ASC, node_models.last_used ASC, backend_nodes.available_vram DESC").
|
||||
First(&nm).Error; err != nil {
|
||||
return err
|
||||
|
||||
picked := false
|
||||
if pref != nil && pref.PreferredNodeID != "" {
|
||||
// Lock the EXACT (node_id, replica_index) row the caller chose. The
|
||||
// caller (prefix-cache router) has already applied the load guard
|
||||
// per replica, so here we only require that exact replica still be
|
||||
// loaded+healthy. Fall through to the default ORDER BY when that
|
||||
// specific replica is not found/loaded.
|
||||
q := base.Session(&gorm.Session{}).
|
||||
Where("node_models.node_id = ? AND node_models.replica_index = ?", pref.PreferredNodeID, pref.PreferredReplica)
|
||||
if err := q.First(&nm).Error; err == nil {
|
||||
picked = true
|
||||
}
|
||||
}
|
||||
if !picked {
|
||||
if err := base.
|
||||
Order("node_models.in_flight ASC, node_models.last_used ASC, backend_nodes.available_vram DESC").
|
||||
First(&nm).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
if err := tx.Model(&nm).Updates(map[string]any{
|
||||
@@ -815,6 +966,47 @@ func (r *NodeRegistry) FindAndLockNodeWithModel(ctx context.Context, modelName s
|
||||
return &node, &nm, nil
|
||||
}
|
||||
|
||||
// LoadedReplicaStats returns one ReplicaCandidate per loaded+healthy replica of
|
||||
// modelName, carrying its current in-flight count. It is a read used by the
|
||||
// prefix-cache router to apply the load guard when choosing a preferred node.
|
||||
// When candidateNodeIDs is non-empty, only replicas on those nodes are
|
||||
// returned; pass nil to consider any healthy node. The result is never nil;
|
||||
// an empty slice means no loaded replica exists.
|
||||
func (r *NodeRegistry) LoadedReplicaStats(ctx context.Context, modelName string, candidateNodeIDs []string) ([]ReplicaCandidate, error) {
|
||||
type row struct {
|
||||
NodeID string
|
||||
Address string
|
||||
ReplicaIndex int
|
||||
InFlight int
|
||||
LastUsed time.Time
|
||||
AvailableVRAM uint64
|
||||
}
|
||||
q := r.db.WithContext(ctx).Model(&NodeModel{}).
|
||||
Joins("JOIN backend_nodes ON backend_nodes.id = node_models.node_id").
|
||||
Where("node_models.model_name = ? AND node_models.state = ? AND backend_nodes.status = ?",
|
||||
modelName, "loaded", StatusHealthy)
|
||||
if len(candidateNodeIDs) > 0 {
|
||||
q = q.Where("node_models.node_id IN ?", candidateNodeIDs)
|
||||
}
|
||||
|
||||
// Narrow to only the columns the sole consumer (router buildPreference)
|
||||
// reads: NodeID and InFlight. The other ReplicaCandidate fields stay at
|
||||
// their zero value, which the consumer does not read. This avoids the
|
||||
// JOIN-side available_vram fetch and the extra column transfer.
|
||||
var rows []row
|
||||
err := q.Select("node_models.node_id AS node_id, node_models.in_flight AS in_flight").
|
||||
Scan(&rows).Error
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("loading replica stats for %s: %w", modelName, err)
|
||||
}
|
||||
|
||||
out := make([]ReplicaCandidate, 0, len(rows))
|
||||
for _, rw := range rows {
|
||||
out = append(out, ReplicaCandidate(rw))
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
// TouchNodeModel updates the last_used timestamp for LRU tracking on a single
|
||||
// replica row.
|
||||
func (r *NodeRegistry) TouchNodeModel(ctx context.Context, nodeID, modelName string, replicaIndex int) {
|
||||
@@ -1198,8 +1390,12 @@ func (r *NodeRegistry) SetModelScheduling(ctx context.Context, config *ModelSche
|
||||
}
|
||||
return r.db.WithContext(ctx).
|
||||
Clauses(clause.OnConflict{
|
||||
Columns: []clause.Column{{Name: "model_name"}},
|
||||
DoUpdates: clause.AssignmentColumns([]string{"node_selector", "min_replicas", "max_replicas", "updated_at"}),
|
||||
Columns: []clause.Column{{Name: "model_name"}},
|
||||
DoUpdates: clause.AssignmentColumns([]string{
|
||||
"node_selector", "min_replicas", "max_replicas",
|
||||
"route_policy", "balance_abs_threshold", "balance_rel_threshold", "min_prefix_match",
|
||||
"updated_at",
|
||||
}),
|
||||
}).
|
||||
Create(config).Error
|
||||
}
|
||||
|
||||
@@ -245,7 +245,7 @@ var _ = Describe("NodeRegistry", func() {
|
||||
Expect(registry.Register(context.Background(), node, true)).To(Succeed())
|
||||
Expect(registry.SetNodeModel(context.Background(), node.ID, "my-model", 0, "loaded", "10.0.0.40:50052", 0)).To(Succeed())
|
||||
|
||||
foundNode, foundNM, err := registry.FindAndLockNodeWithModel(context.Background(), "my-model", nil)
|
||||
foundNode, foundNM, err := registry.FindAndLockNodeWithModel(context.Background(), "my-model", nil, nil)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(foundNode.ID).To(Equal(node.ID))
|
||||
Expect(foundNM.ModelName).To(Equal("my-model"))
|
||||
@@ -257,7 +257,7 @@ var _ = Describe("NodeRegistry", func() {
|
||||
})
|
||||
|
||||
It("returns error when model is not loaded anywhere", func() {
|
||||
_, _, err := registry.FindAndLockNodeWithModel(context.Background(), "nonexistent-model", nil)
|
||||
_, _, err := registry.FindAndLockNodeWithModel(context.Background(), "nonexistent-model", nil, nil)
|
||||
Expect(err).To(HaveOccurred())
|
||||
})
|
||||
|
||||
@@ -274,7 +274,7 @@ var _ = Describe("NodeRegistry", func() {
|
||||
Expect(registry.IncrementInFlight(context.Background(), n1.ID, "shared-model", 0)).To(Succeed())
|
||||
Expect(registry.IncrementInFlight(context.Background(), n1.ID, "shared-model", 0)).To(Succeed())
|
||||
|
||||
foundNode, _, err := registry.FindAndLockNodeWithModel(context.Background(), "shared-model", nil)
|
||||
foundNode, _, err := registry.FindAndLockNodeWithModel(context.Background(), "shared-model", nil, nil)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(foundNode.Name).To(Equal("lock-light"))
|
||||
})
|
||||
@@ -299,7 +299,7 @@ var _ = Describe("NodeRegistry", func() {
|
||||
Expect(registry.IncrementInFlight(context.Background(), included.ID, "filtered-model", 0)).To(Succeed())
|
||||
Expect(registry.IncrementInFlight(context.Background(), included.ID, "filtered-model", 0)).To(Succeed())
|
||||
|
||||
foundNode, foundNM, err := registry.FindAndLockNodeWithModel(context.Background(), "filtered-model", []string{included.ID})
|
||||
foundNode, foundNM, err := registry.FindAndLockNodeWithModel(context.Background(), "filtered-model", []string{included.ID}, nil)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(foundNode.ID).To(Equal(included.ID))
|
||||
Expect(foundNM.NodeID).To(Equal(included.ID))
|
||||
@@ -326,7 +326,7 @@ var _ = Describe("NodeRegistry", func() {
|
||||
// (FindAndLockNodeWithModel atomically increments to lock the row.)
|
||||
picks := make([]string, 0, 9)
|
||||
for i := 0; i < 9; i++ {
|
||||
n, nm, err := registry.FindAndLockNodeWithModel(context.Background(), "rr-model", nil)
|
||||
n, nm, err := registry.FindAndLockNodeWithModel(context.Background(), "rr-model", nil, nil)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
picks = append(picks, n.Name)
|
||||
Expect(registry.DecrementInFlight(context.Background(), n.ID, "rr-model", nm.ReplicaIndex)).To(Succeed())
|
||||
@@ -355,7 +355,7 @@ var _ = Describe("NodeRegistry", func() {
|
||||
// query must return an error so Route() falls through to schedule
|
||||
// a fresh load on a matching node instead of reusing the excluded
|
||||
// replica.
|
||||
_, _, err := registry.FindAndLockNodeWithModel(context.Background(), "no-match-model", []string{emptyIncluded.ID})
|
||||
_, _, err := registry.FindAndLockNodeWithModel(context.Background(), "no-match-model", []string{emptyIncluded.ID}, nil)
|
||||
Expect(err).To(HaveOccurred())
|
||||
})
|
||||
|
||||
@@ -422,7 +422,7 @@ var _ = Describe("NodeRegistry", func() {
|
||||
goPick := PickBestReplica(candidates)
|
||||
Expect(goPick).ToNot(BeNil())
|
||||
|
||||
sqlNode, _, err := registry.FindAndLockNodeWithModel(context.Background(), "mirror-model", nil)
|
||||
sqlNode, _, err := registry.FindAndLockNodeWithModel(context.Background(), "mirror-model", nil, nil)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
Expect(sqlNode.ID).To(Equal(goPick.NodeID),
|
||||
@@ -433,6 +433,124 @@ var _ = Describe("NodeRegistry", func() {
|
||||
})
|
||||
})
|
||||
|
||||
Describe("FindAndLockNodeWithModel preference", func() {
|
||||
var nodeA, nodeB *BackendNode
|
||||
|
||||
BeforeEach(func() {
|
||||
nodeA = makeNode("pref-a", "10.0.0.70:50051", 8_000_000_000)
|
||||
nodeB = makeNode("pref-b", "10.0.0.71:50051", 8_000_000_000)
|
||||
Expect(registry.Register(context.Background(), nodeA, true)).To(Succeed())
|
||||
Expect(registry.Register(context.Background(), nodeB, true)).To(Succeed())
|
||||
// Both loaded+healthy for model "pref-model", in_flight 0.
|
||||
Expect(registry.SetNodeModel(context.Background(), nodeA.ID, "pref-model", 0, "loaded", "", 0)).To(Succeed())
|
||||
Expect(registry.SetNodeModel(context.Background(), nodeB.ID, "pref-model", 0, "loaded", "", 0)).To(Succeed())
|
||||
})
|
||||
|
||||
It("locks the preferred node when eligible", func() {
|
||||
node, nm, err := registry.FindAndLockNodeWithModel(context.Background(), "pref-model", nil, &RoutePreference{PreferredNodeID: nodeB.ID})
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(node.ID).To(Equal(nodeB.ID))
|
||||
Expect(nm.NodeID).To(Equal(nodeB.ID))
|
||||
|
||||
// in_flight is incremented atomically via gorm.Expr, so verify the
|
||||
// persisted value through a re-fetch (the returned struct mirrors
|
||||
// the pre-increment read, like the default-pick path).
|
||||
persisted, err := registry.GetNodeModel(context.Background(), nodeB.ID, "pref-model", 0)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(persisted.InFlight).To(Equal(1))
|
||||
})
|
||||
|
||||
It("falls back to default order when preferred not loaded", func() {
|
||||
node, _, err := registry.FindAndLockNodeWithModel(context.Background(), "pref-model", nil, &RoutePreference{PreferredNodeID: "ZZZ"})
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(node.ID).To(BeElementOf(nodeA.ID, nodeB.ID))
|
||||
})
|
||||
|
||||
It("nil preference behaves like before", func() {
|
||||
node, _, err := registry.FindAndLockNodeWithModel(context.Background(), "pref-model", nil, nil)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(node).ToNot(BeNil())
|
||||
})
|
||||
|
||||
It("locks the EXACT preferred replica when the node hosts two replicas", func() {
|
||||
// A single node hosts replica 0 and replica 1 of a model, both
|
||||
// loaded+healthy. The preference must lock the SPECIFIC replica
|
||||
// requested, not the least-loaded replica on the node.
|
||||
node := makeNode("pref-multi", "10.0.0.72:50051", 16_000_000_000)
|
||||
node.MaxReplicasPerModel = 2
|
||||
Expect(registry.Register(context.Background(), node, true)).To(Succeed())
|
||||
Expect(registry.SetNodeModel(context.Background(), node.ID, "multi-model", 0, "loaded", "addr0", 0)).To(Succeed())
|
||||
Expect(registry.SetNodeModel(context.Background(), node.ID, "multi-model", 1, "loaded", "addr1", 0)).To(Succeed())
|
||||
|
||||
// pref={node, 1} must lock replica 1 specifically.
|
||||
gotNode, nm1, err := registry.FindAndLockNodeWithModel(context.Background(), "multi-model", nil,
|
||||
&RoutePreference{PreferredNodeID: node.ID, PreferredReplica: 1})
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(gotNode.ID).To(Equal(node.ID))
|
||||
Expect(nm1.ReplicaIndex).To(Equal(1))
|
||||
|
||||
// pref={node, 0} must lock replica 0 specifically.
|
||||
_, nm0, err := registry.FindAndLockNodeWithModel(context.Background(), "multi-model", nil,
|
||||
&RoutePreference{PreferredNodeID: node.ID, PreferredReplica: 0})
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(nm0.ReplicaIndex).To(Equal(0))
|
||||
})
|
||||
})
|
||||
|
||||
Describe("LoadedReplicaStats", func() {
|
||||
var n1, n2, n3 *BackendNode
|
||||
|
||||
BeforeEach(func() {
|
||||
n1 = makeNode("stats-1", "10.0.0.80:50051", 8_000_000_000)
|
||||
n2 = makeNode("stats-2", "10.0.0.81:50051", 8_000_000_000)
|
||||
n3 = makeNode("stats-3", "10.0.0.82:50051", 8_000_000_000)
|
||||
Expect(registry.Register(context.Background(), n1, true)).To(Succeed())
|
||||
Expect(registry.Register(context.Background(), n2, true)).To(Succeed())
|
||||
Expect(registry.Register(context.Background(), n3, true)).To(Succeed())
|
||||
// n1 loaded+busy, n2 loaded+idle, n3 has a different model only.
|
||||
Expect(registry.SetNodeModel(context.Background(), n1.ID, "stats-model", 0, "loaded", "10.0.0.80:6000", 0)).To(Succeed())
|
||||
Expect(registry.SetNodeModel(context.Background(), n2.ID, "stats-model", 0, "loaded", "10.0.0.81:6000", 0)).To(Succeed())
|
||||
Expect(registry.SetNodeModel(context.Background(), n3.ID, "other-model", 0, "loaded", "", 0)).To(Succeed())
|
||||
Expect(registry.IncrementInFlight(context.Background(), n1.ID, "stats-model", 0)).To(Succeed())
|
||||
Expect(registry.IncrementInFlight(context.Background(), n1.ID, "stats-model", 0)).To(Succeed())
|
||||
})
|
||||
|
||||
It("returns loaded healthy replicas with in-flight counts", func() {
|
||||
stats, err := registry.LoadedReplicaStats(context.Background(), "stats-model", nil)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(stats).To(HaveLen(2))
|
||||
byNode := map[string]ReplicaCandidate{}
|
||||
for _, s := range stats {
|
||||
byNode[s.NodeID] = s
|
||||
}
|
||||
Expect(byNode).To(HaveKey(n1.ID))
|
||||
Expect(byNode).To(HaveKey(n2.ID))
|
||||
Expect(byNode[n1.ID].InFlight).To(Equal(2))
|
||||
Expect(byNode[n2.ID].InFlight).To(Equal(0))
|
||||
})
|
||||
|
||||
It("filters to the candidate node set when provided", func() {
|
||||
stats, err := registry.LoadedReplicaStats(context.Background(), "stats-model", []string{n2.ID})
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(stats).To(HaveLen(1))
|
||||
Expect(stats[0].NodeID).To(Equal(n2.ID))
|
||||
})
|
||||
|
||||
It("excludes unhealthy nodes", func() {
|
||||
Expect(registry.MarkUnhealthy(context.Background(), n1.ID)).To(Succeed())
|
||||
stats, err := registry.LoadedReplicaStats(context.Background(), "stats-model", nil)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(stats).To(HaveLen(1))
|
||||
Expect(stats[0].NodeID).To(Equal(n2.ID))
|
||||
})
|
||||
|
||||
It("returns empty for a model with no loaded replicas", func() {
|
||||
stats, err := registry.LoadedReplicaStats(context.Background(), "no-such-model", nil)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(stats).To(BeEmpty())
|
||||
})
|
||||
})
|
||||
|
||||
Describe("MarkHealthy and MarkUnhealthy round-trip", func() {
|
||||
It("transitions healthy -> unhealthy -> healthy", func() {
|
||||
node := makeNode("roundtrip-node", "10.0.0.60:50051", 8_000_000_000)
|
||||
@@ -632,6 +750,30 @@ var _ = Describe("NodeRegistry", func() {
|
||||
Expect(fetched.MaxReplicas).To(Equal(5))
|
||||
})
|
||||
|
||||
It("persists and updates route policy and thresholds", func() {
|
||||
err := registry.SetModelScheduling(context.Background(), &ModelSchedulingConfig{
|
||||
ModelName: "prefix-cache-model", RoutePolicy: "prefix_cache",
|
||||
BalanceAbsThreshold: 3, BalanceRelThreshold: 2.0, MinPrefixMatch: 0.4,
|
||||
})
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
got, err := registry.GetModelScheduling(context.Background(), "prefix-cache-model")
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(got.RoutePolicy).To(Equal("prefix_cache"))
|
||||
Expect(got.BalanceAbsThreshold).To(Equal(3))
|
||||
Expect(got.BalanceRelThreshold).To(BeNumerically("==", 2.0))
|
||||
Expect(got.MinPrefixMatch).To(BeNumerically("==", 0.4))
|
||||
|
||||
// Update must not be dropped on conflict.
|
||||
Expect(registry.SetModelScheduling(context.Background(), &ModelSchedulingConfig{
|
||||
ModelName: "prefix-cache-model", RoutePolicy: "round_robin",
|
||||
})).ToNot(HaveOccurred())
|
||||
|
||||
got, err = registry.GetModelScheduling(context.Background(), "prefix-cache-model")
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(got.RoutePolicy).To(Equal("round_robin"))
|
||||
})
|
||||
|
||||
It("lists all configs", func() {
|
||||
Expect(registry.SetModelScheduling(context.Background(), &ModelSchedulingConfig{ModelName: "list-a", MinReplicas: 1})).To(Succeed())
|
||||
Expect(registry.SetModelScheduling(context.Background(), &ModelSchedulingConfig{ModelName: "list-b", MaxReplicas: 2})).To(Succeed())
|
||||
@@ -903,6 +1045,187 @@ var _ = Describe("NodeRegistry", func() {
|
||||
})
|
||||
})
|
||||
|
||||
Describe("SetReplicaRemovedHook", func() {
|
||||
type removed struct {
|
||||
model, node string
|
||||
replica int
|
||||
}
|
||||
|
||||
It("fires once with the specific replica after RemoveNodeModel", func() {
|
||||
node := makeNode("hook-remove-one", "10.0.0.230:50051", 8_000_000_000)
|
||||
Expect(registry.Register(context.Background(), node, true)).To(Succeed())
|
||||
Expect(registry.SetNodeModel(context.Background(), node.ID, "hook-model", 1, "loaded", "a", 0)).To(Succeed())
|
||||
|
||||
var fired []removed
|
||||
registry.SetReplicaRemovedHook(func(modelName, nodeID string, replicaIndex int) {
|
||||
fired = append(fired, removed{model: modelName, node: nodeID, replica: replicaIndex})
|
||||
})
|
||||
|
||||
// RemoveNodeModel(replica 1) must fire with the SPECIFIC replica index.
|
||||
Expect(registry.RemoveNodeModel(context.Background(), node.ID, "hook-model", 1)).To(Succeed())
|
||||
Expect(fired).To(HaveLen(1))
|
||||
Expect(fired[0]).To(Equal(removed{model: "hook-model", node: node.ID, replica: 1}))
|
||||
})
|
||||
|
||||
It("fires once with replica<0 after RemoveAllNodeModelReplicas", func() {
|
||||
node := makeNode("hook-remove-all", "10.0.0.231:50051", 16_000_000_000)
|
||||
Expect(registry.Register(context.Background(), node, true)).To(Succeed())
|
||||
Expect(registry.SetNodeModel(context.Background(), node.ID, "hook-all-model", 0, "loaded", "a", 0)).To(Succeed())
|
||||
Expect(registry.SetNodeModel(context.Background(), node.ID, "hook-all-model", 1, "loaded", "b", 0)).To(Succeed())
|
||||
|
||||
var fired []removed
|
||||
registry.SetReplicaRemovedHook(func(modelName, nodeID string, replicaIndex int) {
|
||||
fired = append(fired, removed{model: modelName, node: nodeID, replica: replicaIndex})
|
||||
})
|
||||
|
||||
// One call covers all replicas of that model on the node: a negative
|
||||
// replica index signals "all replicas", and the consumer's
|
||||
// InvalidateNode drops every entry for the (model, node) pair.
|
||||
Expect(registry.RemoveAllNodeModelReplicas(context.Background(), node.ID, "hook-all-model")).To(Succeed())
|
||||
Expect(fired).To(HaveLen(1))
|
||||
Expect(fired[0].model).To(Equal("hook-all-model"))
|
||||
Expect(fired[0].node).To(Equal(node.ID))
|
||||
Expect(fired[0].replica).To(BeNumerically("<", 0))
|
||||
})
|
||||
|
||||
It("does not panic when no hook is set", func() {
|
||||
node := makeNode("hook-unset", "10.0.0.232:50051", 8_000_000_000)
|
||||
Expect(registry.Register(context.Background(), node, true)).To(Succeed())
|
||||
Expect(registry.SetNodeModel(context.Background(), node.ID, "no-hook-model", 0, "loaded", "a", 0)).To(Succeed())
|
||||
|
||||
Expect(func() {
|
||||
Expect(registry.RemoveNodeModel(context.Background(), node.ID, "no-hook-model", 0)).To(Succeed())
|
||||
Expect(registry.RemoveAllNodeModelReplicas(context.Background(), node.ID, "no-hook-model")).To(Succeed())
|
||||
}).ToNot(Panic())
|
||||
})
|
||||
|
||||
// firedModelSet collects the distinct model names the hook saw for the
|
||||
// given node. The bulk node-scoped deletes below remove every replica of
|
||||
// every model on the node in one statement, so the chokepoint must fire
|
||||
// the hook once per distinct model name (the consumer's Invalidate
|
||||
// drops all entries for that (model, node) pair).
|
||||
seedTwoModels := func(node *BackendNode) {
|
||||
Expect(registry.Register(context.Background(), node, true)).To(Succeed())
|
||||
Expect(registry.SetNodeModel(context.Background(), node.ID, "model-a", 0, "loaded", "a0", 0)).To(Succeed())
|
||||
Expect(registry.SetNodeModel(context.Background(), node.ID, "model-a", 1, "loaded", "a1", 0)).To(Succeed())
|
||||
Expect(registry.SetNodeModel(context.Background(), node.ID, "model-b", 0, "loaded", "b0", 0)).To(Succeed())
|
||||
}
|
||||
|
||||
It("fires once per distinct model after MarkOffline", func() {
|
||||
node := makeNode("hook-offline", "10.0.0.240:50051", 8_000_000_000)
|
||||
seedTwoModels(node)
|
||||
|
||||
fired := map[removed]int{}
|
||||
registry.SetReplicaRemovedHook(func(modelName, nodeID string, replicaIndex int) {
|
||||
// Bulk node-scoped deletes signal "all replicas" with replica<0.
|
||||
Expect(replicaIndex).To(BeNumerically("<", 0))
|
||||
fired[removed{model: modelName, node: nodeID}]++
|
||||
})
|
||||
|
||||
Expect(registry.MarkOffline(context.Background(), node.ID)).To(Succeed())
|
||||
Expect(fired).To(HaveLen(2))
|
||||
Expect(fired[removed{model: "model-a", node: node.ID}]).To(Equal(1))
|
||||
Expect(fired[removed{model: "model-b", node: node.ID}]).To(Equal(1))
|
||||
})
|
||||
|
||||
It("fires once per distinct model after MarkDraining", func() {
|
||||
node := makeNode("hook-draining", "10.0.0.241:50051", 8_000_000_000)
|
||||
seedTwoModels(node)
|
||||
|
||||
fired := map[removed]int{}
|
||||
registry.SetReplicaRemovedHook(func(modelName, nodeID string, replicaIndex int) {
|
||||
// Bulk node-scoped deletes signal "all replicas" with replica<0.
|
||||
Expect(replicaIndex).To(BeNumerically("<", 0))
|
||||
fired[removed{model: modelName, node: nodeID}]++
|
||||
})
|
||||
|
||||
Expect(registry.MarkDraining(context.Background(), node.ID)).To(Succeed())
|
||||
Expect(fired).To(HaveLen(2))
|
||||
Expect(fired[removed{model: "model-a", node: node.ID}]).To(Equal(1))
|
||||
Expect(fired[removed{model: "model-b", node: node.ID}]).To(Equal(1))
|
||||
})
|
||||
|
||||
It("fires once per distinct model after Deregister", func() {
|
||||
node := makeNode("hook-deregister", "10.0.0.242:50051", 8_000_000_000)
|
||||
seedTwoModels(node)
|
||||
|
||||
fired := map[removed]int{}
|
||||
registry.SetReplicaRemovedHook(func(modelName, nodeID string, replicaIndex int) {
|
||||
// Bulk node-scoped deletes signal "all replicas" with replica<0.
|
||||
Expect(replicaIndex).To(BeNumerically("<", 0))
|
||||
fired[removed{model: modelName, node: nodeID}]++
|
||||
})
|
||||
|
||||
Expect(registry.Deregister(context.Background(), node.ID)).To(Succeed())
|
||||
Expect(fired).To(HaveLen(2))
|
||||
Expect(fired[removed{model: "model-a", node: node.ID}]).To(Equal(1))
|
||||
Expect(fired[removed{model: "model-b", node: node.ID}]).To(Equal(1))
|
||||
})
|
||||
|
||||
It("fires once per distinct model when re-registration clears stale rows", func() {
|
||||
node := makeNode("hook-reregister", "10.0.0.243:50051", 8_000_000_000)
|
||||
seedTwoModels(node)
|
||||
|
||||
fired := map[removed]int{}
|
||||
registry.SetReplicaRemovedHook(func(modelName, nodeID string, replicaIndex int) {
|
||||
// Bulk node-scoped deletes signal "all replicas" with replica<0.
|
||||
Expect(replicaIndex).To(BeNumerically("<", 0))
|
||||
fired[removed{model: modelName, node: nodeID}]++
|
||||
})
|
||||
|
||||
// Re-register the same node (same name): the re-register path
|
||||
// clears the stale model rows, which must fire the hook.
|
||||
again := makeNode("hook-reregister", "10.0.0.243:50052", 8_000_000_000)
|
||||
Expect(registry.Register(context.Background(), again, true)).To(Succeed())
|
||||
Expect(fired).To(HaveLen(2))
|
||||
Expect(fired[removed{model: "model-a", node: node.ID}]).To(Equal(1))
|
||||
Expect(fired[removed{model: "model-b", node: node.ID}]).To(Equal(1))
|
||||
})
|
||||
|
||||
// Atomicity: the bulk node-scoped delete in MarkOffline/MarkDraining/
|
||||
// re-register now captures the model names and deletes the rows inside a
|
||||
// single transaction. A true SetNodeModel-between-capture-and-delete race
|
||||
// can't be forced deterministically here, but we can assert the
|
||||
// post-condition the transaction guarantees: the set of fired hooks
|
||||
// equals exactly the set of node_models rows the operation removed, with
|
||||
// nothing left behind. If the capture and delete ever saw inconsistent
|
||||
// snapshots, either a surviving row (delete missed it) or a missing hook
|
||||
// (capture missed it) would break one of these assertions.
|
||||
It("MarkOffline fires hooks for exactly the rows it deletes (consistent snapshot)", func() {
|
||||
node := makeNode("hook-atomic-offline", "10.0.0.244:50051", 8_000_000_000)
|
||||
seedTwoModels(node)
|
||||
|
||||
// Capture what the transaction should remove, straight from the DB,
|
||||
// before running the operation.
|
||||
before, err := registry.GetNodeModels(context.Background(), node.ID)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
expectedModels := map[string]struct{}{}
|
||||
for _, nm := range before {
|
||||
expectedModels[nm.ModelName] = struct{}{}
|
||||
}
|
||||
Expect(expectedModels).To(HaveLen(2), "seed should create two distinct models")
|
||||
|
||||
fired := map[string]struct{}{}
|
||||
registry.SetReplicaRemovedHook(func(modelName, nodeID string, replicaIndex int) {
|
||||
Expect(nodeID).To(Equal(node.ID))
|
||||
Expect(replicaIndex).To(BeNumerically("<", 0))
|
||||
fired[modelName] = struct{}{}
|
||||
})
|
||||
|
||||
Expect(registry.MarkOffline(context.Background(), node.ID)).To(Succeed())
|
||||
|
||||
// Hooks fired for exactly the distinct models that existed.
|
||||
Expect(fired).To(Equal(expectedModels),
|
||||
"hooks must fire for exactly the set of models the transaction deleted")
|
||||
|
||||
// And the delete actually emptied the node_models rows for the node:
|
||||
// no row survives that did not get a hook.
|
||||
after, err := registry.GetNodeModels(context.Background(), node.ID)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(after).To(BeEmpty(), "no node_models row should survive the bulk delete")
|
||||
})
|
||||
})
|
||||
|
||||
Describe("ApplyAutoLabels", func() {
|
||||
It("mirrors MaxReplicasPerModel as the node.replica-slots label", func() {
|
||||
node := makeNode("auto-label-replicas", "10.0.0.220:50051", 16_000_000_000)
|
||||
|
||||
@@ -12,6 +12,8 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/mudler/LocalAI/core/services/advisorylock"
|
||||
"github.com/mudler/LocalAI/core/services/nodes/prefixcache"
|
||||
"github.com/mudler/LocalAI/pkg/distributedhdr"
|
||||
grpc "github.com/mudler/LocalAI/pkg/grpc"
|
||||
pb "github.com/mudler/LocalAI/pkg/grpc/proto"
|
||||
"github.com/mudler/LocalAI/pkg/vram"
|
||||
@@ -43,6 +45,22 @@ type SmartRouterOptions struct {
|
||||
// anti-affinity is disabled at the scheduler layer; the per-node
|
||||
// watchdog still enforces the rule on arrival.
|
||||
ConflictResolver ConcurrencyConflictResolver
|
||||
// PrefixProvider, when set, enables prefix-cache-aware routing: requests
|
||||
// carrying a prompt prefix chain (distributedhdr.PrefixChain) are biased
|
||||
// toward the node that already holds the longest matching prefix, subject
|
||||
// to the load guard in prefixcache.Select. nil disables it entirely and
|
||||
// routing is byte-for-byte the round-robin floor. At runtime this is the
|
||||
// *prefixcache.Sync so Observe/Invalidate broadcast to peers.
|
||||
PrefixProvider prefixcache.Provider
|
||||
// PrefixConfig holds the global policy + thresholds. Per-model overrides on
|
||||
// ModelSchedulingConfig refine it per request. Unused when PrefixProvider
|
||||
// is nil.
|
||||
PrefixConfig prefixcache.Config
|
||||
// Pressure, when set, records a forced-disturb each time a request had a
|
||||
// usable hot prefix match but the load guard forced it off the warm node.
|
||||
// The reconciler reads the same instance to autoscale a saturated cache-warm
|
||||
// replica. nil disables recording (the disabled path stays a no-op).
|
||||
Pressure *prefixcache.Pressure
|
||||
}
|
||||
|
||||
// SmartRouter routes inference requests to the best available backend node.
|
||||
@@ -56,6 +74,14 @@ type SmartRouter struct {
|
||||
db *gorm.DB // for advisory locks during routing
|
||||
stagingTracker *StagingTracker // tracks file staging progress for UI visibility
|
||||
conflictResolver ConcurrencyConflictResolver
|
||||
// prefixProvider is the prefix-cache routing seam (nil disables it; see
|
||||
// SmartRouterOptions.PrefixProvider). prefixConfig holds the global policy
|
||||
// and thresholds.
|
||||
prefixProvider prefixcache.Provider
|
||||
prefixConfig prefixcache.Config
|
||||
// pressure records forced-disturb events (hot match forced off the warm
|
||||
// node by the load guard). nil disables recording. See SmartRouterOptions.
|
||||
pressure *prefixcache.Pressure
|
||||
// installFlight coalesces concurrent identical NATS install requests
|
||||
// (same nodeID + backend + modelID + replica) so 6 simultaneous chat
|
||||
// completions for one not-yet-loaded model produce ONE round-trip, not
|
||||
@@ -91,6 +117,9 @@ func NewSmartRouter(registry ModelRouter, opts SmartRouterOptions) *SmartRouter
|
||||
stagingTracker: NewStagingTracker(),
|
||||
conflictResolver: opts.ConflictResolver,
|
||||
probeCache: newProbeCache(probeCacheTTL),
|
||||
prefixProvider: opts.PrefixProvider,
|
||||
prefixConfig: opts.PrefixConfig,
|
||||
pressure: opts.Pressure,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -230,18 +259,31 @@ func (r *SmartRouter) Route(ctx context.Context, modelID, modelName, backendType
|
||||
trackingKey = modelName
|
||||
}
|
||||
|
||||
// Fetch the model's scheduling config once: it is immutable for the life of
|
||||
// this request, and resolveSelectorCandidates, buildPreference, and
|
||||
// nodeMatchesScheduling all read it. Fetching once gives a consistent
|
||||
// snapshot and avoids three DB round-trips for one row. nil sched means
|
||||
// "no scheduling constraints", same as before.
|
||||
sched, _ := r.registry.GetModelScheduling(ctx, trackingKey)
|
||||
|
||||
// Resolve the model's NodeSelector once so cached-replica lookup and the
|
||||
// new-load scheduler agree on the candidate set. Without this, a cached
|
||||
// replica on a node the selector now excludes was picked over a matching
|
||||
// replica elsewhere, and the fall-through then tried to load on the
|
||||
// matching node where the model was already at capacity (eviction-busy).
|
||||
candidateNodeIDs, err := r.resolveSelectorCandidates(ctx, trackingKey)
|
||||
candidateNodeIDs, err := r.resolveSelectorCandidates(ctx, trackingKey, sched)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Compute the prefix-cache preference once for this request. pref biases
|
||||
// FindAndLockNodeWithModel toward the warm-cache node; observeChain is
|
||||
// non-nil only when this model uses prefix_cache, gating the Observe calls
|
||||
// below. Both are nil (no-op) when prefix-cache routing is disabled.
|
||||
pref, observeChain := r.buildPreference(ctx, trackingKey, candidateNodeIDs, sched)
|
||||
|
||||
// Step 1: Find and atomically lock a node with this model loaded
|
||||
node, nm, err := r.registry.FindAndLockNodeWithModel(ctx, trackingKey, candidateNodeIDs)
|
||||
node, nm, err := r.registry.FindAndLockNodeWithModel(ctx, trackingKey, candidateNodeIDs, pref)
|
||||
if err == nil && node != nil {
|
||||
modelAddr := node.Address
|
||||
if nm.Address != "" {
|
||||
@@ -258,7 +300,7 @@ func (r *SmartRouter) Route(ctx context.Context, modelID, modelName, backendType
|
||||
"node", node.Name, "model", modelName, "replica", replicaIdx)
|
||||
} else {
|
||||
// Verify node still matches scheduling constraints
|
||||
if !r.nodeMatchesScheduling(ctx, node, trackingKey) {
|
||||
if !r.nodeMatchesScheduling(ctx, node, sched) {
|
||||
r.registry.DecrementInFlight(ctx, node.ID, trackingKey, replicaIdx)
|
||||
xlog.Info("Cached model on node that no longer matches selector, falling through",
|
||||
"node", node.Name, "model", trackingKey, "replica", replicaIdx)
|
||||
@@ -269,6 +311,7 @@ func (r *SmartRouter) Route(ctx context.Context, modelID, modelName, backendType
|
||||
// onFirstComplete callback releases the reservation after the first inference
|
||||
// call finishes, so in-flight returns to 0 when idle.
|
||||
r.registry.TouchNodeModel(ctx, node.ID, trackingKey, replicaIdx)
|
||||
r.observePrefix(trackingKey, observeChain, prefixcache.ReplicaKey{NodeID: node.ID, Replica: replicaIdx})
|
||||
grpcClient := r.buildClientForAddr(node, modelAddr, parallel)
|
||||
tracked := NewInFlightTrackingClient(grpcClient, r.registry, node.ID, trackingKey, replicaIdx)
|
||||
tracked.OnFirstComplete(func() {
|
||||
@@ -288,7 +331,7 @@ func (r *SmartRouter) Route(ctx context.Context, modelID, modelName, backendType
|
||||
// Step 2: Model not loaded — schedule loading with distributed lock to prevent duplicates
|
||||
loadModel := func() (*RouteResult, error) {
|
||||
// Re-check after acquiring lock — another request may have loaded it
|
||||
node, nm, err := r.registry.FindAndLockNodeWithModel(ctx, trackingKey, candidateNodeIDs)
|
||||
node, nm, err := r.registry.FindAndLockNodeWithModel(ctx, trackingKey, candidateNodeIDs, pref)
|
||||
if err == nil && node != nil {
|
||||
modelAddr := node.Address
|
||||
if nm.Address != "" {
|
||||
@@ -305,7 +348,7 @@ func (r *SmartRouter) Route(ctx context.Context, modelID, modelName, backendType
|
||||
"node", node.Name, "model", modelName, "replica", replicaIdx)
|
||||
} else {
|
||||
// Verify node still matches scheduling constraints
|
||||
if !r.nodeMatchesScheduling(ctx, node, trackingKey) {
|
||||
if !r.nodeMatchesScheduling(ctx, node, sched) {
|
||||
r.registry.DecrementInFlight(ctx, node.ID, trackingKey, replicaIdx)
|
||||
xlog.Info("Cached model on node that no longer matches selector, falling through",
|
||||
"node", node.Name, "model", trackingKey, "replica", replicaIdx)
|
||||
@@ -314,6 +357,7 @@ func (r *SmartRouter) Route(ctx context.Context, modelID, modelName, backendType
|
||||
// Model loaded while we waited — FindAndLockNodeWithModel already incremented
|
||||
// in-flight as a reservation. Release it after the first inference completes.
|
||||
r.registry.TouchNodeModel(ctx, node.ID, trackingKey, replicaIdx)
|
||||
r.observePrefix(trackingKey, observeChain, prefixcache.ReplicaKey{NodeID: node.ID, Replica: replicaIdx})
|
||||
grpcClient := r.buildClientForAddr(node, modelAddr, parallel)
|
||||
tracked := NewInFlightTrackingClient(grpcClient, r.registry, node.ID, trackingKey, replicaIdx)
|
||||
tracked.OnFirstComplete(func() {
|
||||
@@ -337,6 +381,10 @@ func (r *SmartRouter) Route(ctx context.Context, modelID, modelName, backendType
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Cold load landed on result.Node replica result.ReplicaIndex: record the
|
||||
// assignment so subsequent requests with the same prefix prefer it.
|
||||
r.observePrefix(trackingKey, observeChain, prefixcache.ReplicaKey{NodeID: result.Node.ID, Replica: result.ReplicaIndex})
|
||||
|
||||
replicaIdx := result.ReplicaIndex
|
||||
tracked := NewInFlightTrackingClient(result.Client, r.registry, result.Node.ID, trackingKey, replicaIdx)
|
||||
tracked.OnFirstComplete(func() {
|
||||
@@ -389,13 +437,117 @@ func extractNodeIDs(nodes []BackendNode) []string {
|
||||
return ids
|
||||
}
|
||||
|
||||
// buildPreference computes the per-request route preference from the prefix
|
||||
// chain on ctx and the model's resolved policy. The returned observeChain is
|
||||
// non-nil only when the resolved policy is prefix_cache, signalling Route to
|
||||
// record the assignment after a successful pick; for round-robin models it is
|
||||
// nil so the tree is never polluted. The *RoutePreference is non-nil only when
|
||||
// a load-eligible preferred node was chosen.
|
||||
//
|
||||
// When prefix-cache routing is disabled (nil provider), no chain is present,
|
||||
// or the policy resolves to round-robin, both returns are nil and routing is
|
||||
// the unchanged round-robin floor.
|
||||
func (r *SmartRouter) buildPreference(ctx context.Context, modelID string, candidateNodeIDs []string, sched *ModelSchedulingConfig) (*RoutePreference, []uint64) {
|
||||
if r.prefixProvider == nil {
|
||||
return nil, nil
|
||||
}
|
||||
chain := distributedhdr.PrefixChain(ctx)
|
||||
if len(chain) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// Resolve per-model policy + thresholds over the global config.
|
||||
policy := r.prefixConfig.GlobalPolicy
|
||||
cfg := r.prefixConfig
|
||||
if sched != nil {
|
||||
policy = prefixcache.ParsePolicy(sched.RoutePolicy).Resolve(r.prefixConfig.GlobalPolicy)
|
||||
if sched.BalanceAbsThreshold > 0 {
|
||||
cfg.BalanceAbsThreshold = sched.BalanceAbsThreshold
|
||||
}
|
||||
if sched.BalanceRelThreshold > 0 {
|
||||
cfg.BalanceRelThreshold = sched.BalanceRelThreshold
|
||||
}
|
||||
if sched.MinPrefixMatch > 0 {
|
||||
cfg.MinPrefixMatch = sched.MinPrefixMatch
|
||||
}
|
||||
}
|
||||
if policy != prefixcache.RoutePolicyPrefixCache {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// Load the candidate replicas PER REPLICA. Affinity is tracked per replica
|
||||
// (each replica is a separate process with its own KV cache), so two
|
||||
// replicas of the same model on the same node are two distinct candidates.
|
||||
// FindAndLockNodeWithModel then locks the EXACT (node, replica) the policy
|
||||
// chose.
|
||||
stats, err := r.registry.LoadedReplicaStats(ctx, modelID, candidateNodeIDs)
|
||||
if err != nil {
|
||||
xlog.Debug("prefixcache: loading replica stats failed, skipping preference", "model", modelID, "error", err)
|
||||
return nil, chain
|
||||
}
|
||||
if len(stats) == 0 {
|
||||
return nil, chain
|
||||
}
|
||||
cands := make([]prefixcache.Candidate, 0, len(stats))
|
||||
keys := make([]prefixcache.ReplicaKey, 0, len(stats))
|
||||
for _, s := range stats {
|
||||
key := prefixcache.ReplicaKey{NodeID: s.NodeID, Replica: s.ReplicaIndex}
|
||||
cands = append(cands, prefixcache.Candidate{Key: key, InFlight: s.InFlight})
|
||||
keys = append(keys, key)
|
||||
}
|
||||
|
||||
d := r.prefixProvider.Decide(modelID, chain, keys, time.Now())
|
||||
chosen, ok := prefixcache.Select(cands, d, cfg)
|
||||
|
||||
// Observability for the prefix-cache routing decision. One line per request
|
||||
// at Debug: enable with DEBUG=true on the frontend to assess cache-aware
|
||||
// routing. hotMatchHonored=true means we routed to the cache-warm replica;
|
||||
// false with HasHot means the load guard forced a cold pick.
|
||||
xlog.Debug("prefix-cache routing decision",
|
||||
"model", modelID,
|
||||
"chainDepth", len(chain),
|
||||
"candidates", len(cands),
|
||||
"hotNode", d.Hot.NodeID,
|
||||
"hotReplica", d.Hot.Replica,
|
||||
"hasHot", d.HasHot,
|
||||
"matchRatio", d.MatchRatio,
|
||||
"minMatch", cfg.MinPrefixMatch,
|
||||
"chosen", fmt.Sprintf("%s/%d", chosen.NodeID, chosen.Replica),
|
||||
"hotMatchHonored", d.HasHot && chosen == d.Hot)
|
||||
|
||||
// Forced-disturb: a usable hot prefix match existed but the load guard
|
||||
// forced us off the warm replica (Select picked a different replica). This
|
||||
// is the scale-worthy signal - the cache-warm replica is saturated. It
|
||||
// deliberately does not fire for all-unique workloads (no hot match),
|
||||
// avoiding false-positive scale-ups. nil pressure is a no-op.
|
||||
if r.pressure != nil && d.HasHot && d.MatchRatio >= cfg.MinPrefixMatch && chosen != d.Hot {
|
||||
r.pressure.Record(modelID, time.Now())
|
||||
}
|
||||
|
||||
if !ok {
|
||||
return nil, chain
|
||||
}
|
||||
return &RoutePreference{PreferredNodeID: chosen.NodeID, PreferredReplica: chosen.Replica}, chain
|
||||
}
|
||||
|
||||
// observePrefix records that the replica `key` served the request whose prompt
|
||||
// prefix is chain. It is a no-op when prefix-cache routing is disabled or the
|
||||
// chain is empty (round-robin models pass a nil chain so the tree is never
|
||||
// polluted).
|
||||
func (r *SmartRouter) observePrefix(modelID string, chain []uint64, key prefixcache.ReplicaKey) {
|
||||
if r.prefixProvider == nil || len(chain) == 0 {
|
||||
return
|
||||
}
|
||||
r.prefixProvider.Observe(modelID, chain, key, time.Now())
|
||||
xlog.Debug("prefix-cache observed assignment", "model", modelID, "node", key.NodeID, "replica", key.Replica, "chainDepth", len(chain))
|
||||
}
|
||||
|
||||
// resolveSelectorCandidates returns the node IDs that match the model's
|
||||
// NodeSelector. Returns nil when no selector is configured ("any healthy node"
|
||||
// — registry helpers treat nil as no filter). Returns an error when a
|
||||
// non-empty selector matches zero healthy nodes, since there is nothing to
|
||||
// route or schedule on.
|
||||
func (r *SmartRouter) resolveSelectorCandidates(ctx context.Context, modelID string) ([]string, error) {
|
||||
sched, _ := r.registry.GetModelScheduling(ctx, modelID)
|
||||
func (r *SmartRouter) resolveSelectorCandidates(ctx context.Context, modelID string, sched *ModelSchedulingConfig) ([]string, error) {
|
||||
if sched == nil || sched.NodeSelector == "" {
|
||||
return nil, nil
|
||||
}
|
||||
@@ -469,9 +621,8 @@ func (r *SmartRouter) narrowByGroupAntiAffinity(ctx context.Context, modelID str
|
||||
|
||||
// nodeMatchesScheduling checks if a node satisfies the scheduling constraints for a model.
|
||||
// Returns true if no constraints exist or the node matches all selector labels.
|
||||
func (r *SmartRouter) nodeMatchesScheduling(ctx context.Context, node *BackendNode, modelName string) bool {
|
||||
sched, err := r.registry.GetModelScheduling(ctx, modelName)
|
||||
if err != nil || sched == nil || sched.NodeSelector == "" {
|
||||
func (r *SmartRouter) nodeMatchesScheduling(ctx context.Context, node *BackendNode, sched *ModelSchedulingConfig) bool {
|
||||
if sched == nil || sched.NodeSelector == "" {
|
||||
return true // no constraints
|
||||
}
|
||||
|
||||
@@ -518,7 +669,8 @@ func (r *SmartRouter) scheduleNewModel(ctx context.Context, backendType, modelID
|
||||
// Check for scheduling constraints (node selector). If a selector is set,
|
||||
// we restrict the candidate pool to matching nodes; otherwise nil means
|
||||
// "any healthy node".
|
||||
candidateNodeIDs, err := r.resolveSelectorCandidates(ctx, modelID)
|
||||
sched, _ := r.registry.GetModelScheduling(ctx, modelID)
|
||||
candidateNodeIDs, err := r.resolveSelectorCandidates(ctx, modelID, sched)
|
||||
if err != nil {
|
||||
return nil, "", 0, err
|
||||
}
|
||||
|
||||
@@ -12,7 +12,9 @@ import (
|
||||
. "github.com/onsi/gomega"
|
||||
|
||||
"github.com/mudler/LocalAI/core/services/messaging"
|
||||
"github.com/mudler/LocalAI/core/services/nodes/prefixcache"
|
||||
"github.com/mudler/LocalAI/core/services/testutil"
|
||||
"github.com/mudler/LocalAI/pkg/distributedhdr"
|
||||
grpc "github.com/mudler/LocalAI/pkg/grpc"
|
||||
pb "github.com/mudler/LocalAI/pkg/grpc/proto"
|
||||
ggrpc "google.golang.org/grpc"
|
||||
@@ -111,18 +113,34 @@ type fakeModelRouter struct {
|
||||
findNodesWithModelByName map[string][]BackendNode
|
||||
findNodesWithModelErr error
|
||||
|
||||
// LoadedReplicaStats returns (keyed by model name)
|
||||
loadedReplicaStatsByName map[string][]ReplicaCandidate
|
||||
loadedReplicaStatsErr error
|
||||
|
||||
// Track calls for assertions
|
||||
decrementCalls []string // "nodeID:modelName"
|
||||
incrementCalls []string
|
||||
removeCalls []string
|
||||
setCalls []string
|
||||
touchCalls []string
|
||||
|
||||
// Preferences passed to FindAndLockNodeWithModel, in call order. nil
|
||||
// entries are recorded too, so tests can assert "preference was nil".
|
||||
findAndLockPrefs []*RoutePreference
|
||||
}
|
||||
|
||||
func (f *fakeModelRouter) FindAndLockNodeWithModel(_ context.Context, modelName string, _ []string) (*BackendNode, *NodeModel, error) {
|
||||
func (f *fakeModelRouter) FindAndLockNodeWithModel(_ context.Context, modelName string, _ []string, pref *RoutePreference) (*BackendNode, *NodeModel, error) {
|
||||
f.findAndLockPrefs = append(f.findAndLockPrefs, pref)
|
||||
return f.findAndLockNode, f.findAndLockNM, f.findAndLockErr
|
||||
}
|
||||
|
||||
func (f *fakeModelRouter) LoadedReplicaStats(_ context.Context, modelName string, _ []string) ([]ReplicaCandidate, error) {
|
||||
if f.loadedReplicaStatsErr != nil {
|
||||
return nil, f.loadedReplicaStatsErr
|
||||
}
|
||||
return f.loadedReplicaStatsByName[modelName], nil
|
||||
}
|
||||
|
||||
func (f *fakeModelRouter) DecrementInFlight(_ context.Context, nodeID, modelName string, _ int) error {
|
||||
f.decrementCalls = append(f.decrementCalls, nodeID+":"+modelName)
|
||||
return nil
|
||||
@@ -1055,3 +1073,355 @@ var _ = Describe("SmartRouter", func() {
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Fake prefixcache.Provider for SmartRouter prefix-cache routing tests
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
type observeRecord struct {
|
||||
model string
|
||||
chain []uint64
|
||||
key prefixcache.ReplicaKey
|
||||
}
|
||||
|
||||
type invalidateRecord struct {
|
||||
model string
|
||||
key prefixcache.ReplicaKey
|
||||
}
|
||||
|
||||
// fakePrefixProvider records all interactions and returns a configurable
|
||||
// decision.
|
||||
type fakePrefixProvider struct {
|
||||
decideCalls int
|
||||
observed []observeRecord
|
||||
invalidated []invalidateRecord
|
||||
invalidatedNode []string
|
||||
decision prefixcache.PrefixDecision
|
||||
}
|
||||
|
||||
func (f *fakePrefixProvider) Decide(_ string, _ []uint64, _ []prefixcache.ReplicaKey, _ time.Time) prefixcache.PrefixDecision {
|
||||
f.decideCalls++
|
||||
return f.decision
|
||||
}
|
||||
|
||||
func (f *fakePrefixProvider) Observe(model string, chain []uint64, key prefixcache.ReplicaKey, _ time.Time) bool {
|
||||
f.observed = append(f.observed, observeRecord{model: model, chain: append([]uint64(nil), chain...), key: key})
|
||||
return true
|
||||
}
|
||||
|
||||
func (f *fakePrefixProvider) Invalidate(model string, key prefixcache.ReplicaKey) {
|
||||
f.invalidated = append(f.invalidated, invalidateRecord{model: model, key: key})
|
||||
}
|
||||
|
||||
func (f *fakePrefixProvider) InvalidateNode(model, nodeID string) {
|
||||
f.invalidatedNode = append(f.invalidatedNode, model+":"+nodeID)
|
||||
}
|
||||
|
||||
func (f *fakePrefixProvider) Evict(_ time.Time) {}
|
||||
|
||||
var _ = Describe("SmartRouter prefix-cache routing", func() {
|
||||
var (
|
||||
backend *stubBackend
|
||||
factory *stubClientFactory
|
||||
unloader *fakeUnloader
|
||||
)
|
||||
|
||||
BeforeEach(func() {
|
||||
backend = &stubBackend{healthResult: true}
|
||||
factory = &stubClientFactory{client: backend}
|
||||
unloader = &fakeUnloader{
|
||||
installReply: &messaging.BackendInstallReply{Success: true, Address: "10.0.0.1:9001"},
|
||||
}
|
||||
})
|
||||
|
||||
// loadedReg builds a fake registry with one loaded healthy replica for
|
||||
// "m" on node "X", plus matching replica stats so buildPreference can run.
|
||||
loadedReg := func() *fakeModelRouter {
|
||||
node := &BackendNode{ID: "X", Name: "node-x", Address: "10.0.0.1:50051"}
|
||||
nm := &NodeModel{NodeID: "X", ModelName: "m", Address: "10.0.0.1:9001"}
|
||||
return &fakeModelRouter{
|
||||
findAndLockNode: node,
|
||||
findAndLockNM: nm,
|
||||
getModelScheduling: &ModelSchedulingConfig{
|
||||
RoutePolicy: "prefix_cache",
|
||||
},
|
||||
loadedReplicaStatsByName: map[string][]ReplicaCandidate{
|
||||
"m": {{NodeID: "X", InFlight: 0}},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
Context("nil provider (round-robin floor)", func() {
|
||||
It("passes a nil preference and never decides or observes", func() {
|
||||
reg := loadedReg()
|
||||
router := NewSmartRouter(reg, SmartRouterOptions{Unloader: unloader, ClientFactory: factory})
|
||||
|
||||
ctx := distributedhdr.WithPrefixChain(context.Background(), []uint64{1, 2, 3})
|
||||
_, err := router.Route(ctx, "m", "models/m.gguf", "llama-cpp", nil, false)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
Expect(reg.findAndLockPrefs).ToNot(BeEmpty())
|
||||
for _, p := range reg.findAndLockPrefs {
|
||||
Expect(p).To(BeNil())
|
||||
}
|
||||
})
|
||||
})
|
||||
|
||||
Context("with a provider", func() {
|
||||
It("passes the decided node as the preference and observes the pick", func() {
|
||||
reg := loadedReg()
|
||||
prov := &fakePrefixProvider{decision: prefixcache.PrefixDecision{Hot: prefixcache.ReplicaKey{NodeID: "X"}, HasHot: true, MatchRatio: 1.0}}
|
||||
router := NewSmartRouter(reg, SmartRouterOptions{
|
||||
Unloader: unloader,
|
||||
ClientFactory: factory,
|
||||
PrefixProvider: prov,
|
||||
PrefixConfig: prefixcache.DefaultConfig(),
|
||||
})
|
||||
|
||||
ctx := distributedhdr.WithPrefixChain(context.Background(), []uint64{1, 2, 3})
|
||||
_, err := router.Route(ctx, "m", "models/m.gguf", "llama-cpp", nil, false)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
Expect(prov.decideCalls).To(BeNumerically(">=", 1))
|
||||
Expect(reg.findAndLockPrefs[0]).ToNot(BeNil())
|
||||
Expect(reg.findAndLockPrefs[0].PreferredNodeID).To(Equal("X"))
|
||||
Expect(reg.findAndLockPrefs[0].PreferredReplica).To(Equal(0))
|
||||
Expect(prov.observed).To(HaveLen(1))
|
||||
Expect(prov.observed[0].key).To(Equal(prefixcache.ReplicaKey{NodeID: "X", Replica: 0}))
|
||||
Expect(prov.observed[0].chain).To(Equal([]uint64{1, 2, 3}))
|
||||
})
|
||||
|
||||
It("routes a recurring prefix back to the previously observed node", func() {
|
||||
// Real Index as the provider: first request observes X, second
|
||||
// request with the same chain must yield PreferredNodeID == X.
|
||||
idx := prefixcache.NewIndex(prefixcache.DefaultConfig())
|
||||
reg := loadedReg()
|
||||
router := NewSmartRouter(reg, SmartRouterOptions{
|
||||
Unloader: unloader,
|
||||
ClientFactory: factory,
|
||||
PrefixProvider: idx,
|
||||
PrefixConfig: prefixcache.DefaultConfig(),
|
||||
})
|
||||
|
||||
ctx := distributedhdr.WithPrefixChain(context.Background(), []uint64{7, 8, 9})
|
||||
_, err := router.Route(ctx, "m", "models/m.gguf", "llama-cpp", nil, false)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
// First request landed on X (cold placement on the only candidate)
|
||||
// and observed the prefix there.
|
||||
dFirst := idx.Decide("m", []uint64{7, 8, 9}, []prefixcache.ReplicaKey{{NodeID: "X", Replica: 0}}, time.Now())
|
||||
Expect(dFirst.HasHot).To(BeTrue())
|
||||
Expect(dFirst.Hot).To(Equal(prefixcache.ReplicaKey{NodeID: "X", Replica: 0}))
|
||||
|
||||
// Second request, same chain: X is now the warm-cache hot match, so
|
||||
// the preference must point at it.
|
||||
_, err = router.Route(ctx, "m", "models/m.gguf", "llama-cpp", nil, false)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
last := reg.findAndLockPrefs[len(reg.findAndLockPrefs)-1]
|
||||
Expect(last).ToNot(BeNil())
|
||||
Expect(last.PreferredNodeID).To(Equal("X"))
|
||||
Expect(last.PreferredReplica).To(Equal(0))
|
||||
})
|
||||
|
||||
It("prefers the exact hot replica when two replicas share a node", func() {
|
||||
// Two replicas of "m" live on the SAME node X: replica 0 and replica
|
||||
// 1. A hot prefix observed on (X,0) must produce a preference that
|
||||
// locks replica 0 specifically, NOT the sibling replica 1 on the same
|
||||
// node. This is the replica-granular regression this change fixes.
|
||||
idx := prefixcache.NewIndex(prefixcache.DefaultConfig())
|
||||
node := &BackendNode{ID: "X", Name: "node-x", Address: "10.0.0.1:50051"}
|
||||
nm := &NodeModel{NodeID: "X", ModelName: "m", ReplicaIndex: 0, Address: "10.0.0.1:9001"}
|
||||
reg := &fakeModelRouter{
|
||||
findAndLockNode: node,
|
||||
findAndLockNM: nm,
|
||||
getModelScheduling: &ModelSchedulingConfig{
|
||||
RoutePolicy: "prefix_cache",
|
||||
},
|
||||
loadedReplicaStatsByName: map[string][]ReplicaCandidate{
|
||||
"m": {
|
||||
{NodeID: "X", ReplicaIndex: 0, InFlight: 0},
|
||||
{NodeID: "X", ReplicaIndex: 1, InFlight: 0},
|
||||
},
|
||||
},
|
||||
}
|
||||
// Seed the index so (X,0) is the warm replica for this chain.
|
||||
idx.Observe("m", []uint64{1, 2, 3}, prefixcache.ReplicaKey{NodeID: "X", Replica: 0}, time.Now())
|
||||
|
||||
router := NewSmartRouter(reg, SmartRouterOptions{
|
||||
Unloader: unloader,
|
||||
ClientFactory: factory,
|
||||
PrefixProvider: idx,
|
||||
PrefixConfig: prefixcache.DefaultConfig(),
|
||||
})
|
||||
|
||||
ctx := distributedhdr.WithPrefixChain(context.Background(), []uint64{1, 2, 3})
|
||||
_, err := router.Route(ctx, "m", "models/m.gguf", "llama-cpp", nil, false)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
pref := reg.findAndLockPrefs[0]
|
||||
Expect(pref).ToNot(BeNil())
|
||||
Expect(pref.PreferredNodeID).To(Equal("X"))
|
||||
Expect(pref.PreferredReplica).To(Equal(0),
|
||||
"the hot prefix lives on replica 0; the same-node sibling replica 1 must NOT be chosen")
|
||||
})
|
||||
|
||||
It("does not decide or observe when no prefix chain is present", func() {
|
||||
reg := loadedReg()
|
||||
prov := &fakePrefixProvider{decision: prefixcache.PrefixDecision{Hot: prefixcache.ReplicaKey{NodeID: "X"}, HasHot: true, MatchRatio: 1.0}}
|
||||
router := NewSmartRouter(reg, SmartRouterOptions{
|
||||
Unloader: unloader,
|
||||
ClientFactory: factory,
|
||||
PrefixProvider: prov,
|
||||
PrefixConfig: prefixcache.DefaultConfig(),
|
||||
})
|
||||
|
||||
_, err := router.Route(context.Background(), "m", "models/m.gguf", "llama-cpp", nil, false)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(prov.decideCalls).To(Equal(0))
|
||||
Expect(prov.observed).To(BeEmpty())
|
||||
Expect(reg.findAndLockPrefs[0]).To(BeNil())
|
||||
})
|
||||
|
||||
It("does not observe for round-robin models even with a chain", func() {
|
||||
reg := loadedReg()
|
||||
reg.getModelScheduling = &ModelSchedulingConfig{RoutePolicy: "round_robin"}
|
||||
prov := &fakePrefixProvider{decision: prefixcache.PrefixDecision{Hot: prefixcache.ReplicaKey{NodeID: "X"}, HasHot: true, MatchRatio: 1.0}}
|
||||
router := NewSmartRouter(reg, SmartRouterOptions{
|
||||
Unloader: unloader,
|
||||
ClientFactory: factory,
|
||||
PrefixProvider: prov,
|
||||
PrefixConfig: prefixcache.DefaultConfig(),
|
||||
})
|
||||
|
||||
ctx := distributedhdr.WithPrefixChain(context.Background(), []uint64{1, 2, 3})
|
||||
_, err := router.Route(ctx, "m", "models/m.gguf", "llama-cpp", nil, false)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(prov.decideCalls).To(Equal(0))
|
||||
Expect(prov.observed).To(BeEmpty())
|
||||
Expect(reg.findAndLockPrefs[0]).To(BeNil())
|
||||
})
|
||||
})
|
||||
|
||||
Context("forced-disturb pressure", func() {
|
||||
// disturbReg builds a registry with two candidate replicas for "m":
|
||||
// the hot node X is saturated (high in_flight) and Y is free. Select
|
||||
// will therefore reject the hot node and pick Y, which is the
|
||||
// forced-disturb signal. findAndLockNode returns Y so Route succeeds.
|
||||
disturbReg := func() *fakeModelRouter {
|
||||
nodeY := &BackendNode{ID: "Y", Name: "node-y", Address: "10.0.0.2:50051"}
|
||||
nm := &NodeModel{NodeID: "Y", ModelName: "m", Address: "10.0.0.2:9001"}
|
||||
return &fakeModelRouter{
|
||||
findAndLockNode: nodeY,
|
||||
findAndLockNM: nm,
|
||||
getModelScheduling: &ModelSchedulingConfig{
|
||||
RoutePolicy: "prefix_cache",
|
||||
},
|
||||
loadedReplicaStatsByName: map[string][]ReplicaCandidate{
|
||||
"m": {{NodeID: "X", InFlight: 50}, {NodeID: "Y", InFlight: 0}},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
It("records pressure when a strong hot match was forced off the warm node", func() {
|
||||
reg := disturbReg()
|
||||
prov := &fakePrefixProvider{decision: prefixcache.PrefixDecision{
|
||||
Hot: prefixcache.ReplicaKey{NodeID: "X"},
|
||||
HasHot: true,
|
||||
MatchRatio: 1.0,
|
||||
ColdOrder: []prefixcache.ReplicaKey{{NodeID: "Y"}, {NodeID: "X"}},
|
||||
}}
|
||||
pressure := prefixcache.NewPressure(time.Minute)
|
||||
router := NewSmartRouter(reg, SmartRouterOptions{
|
||||
Unloader: unloader,
|
||||
ClientFactory: factory,
|
||||
PrefixProvider: prov,
|
||||
PrefixConfig: prefixcache.DefaultConfig(),
|
||||
Pressure: pressure,
|
||||
})
|
||||
|
||||
ctx := distributedhdr.WithPrefixChain(context.Background(), []uint64{1, 2, 3})
|
||||
_, err := router.Route(ctx, "m", "models/m.gguf", "llama-cpp", nil, false)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
Expect(pressure.Count("m", time.Now())).To(BeNumerically(">", 0),
|
||||
"hot match existed but the load guard forced us off X: must record pressure")
|
||||
})
|
||||
|
||||
It("does not record pressure when the hot node is itself eligible", func() {
|
||||
reg := loadedReg() // single node X, in_flight 0 → X stays eligible
|
||||
prov := &fakePrefixProvider{decision: prefixcache.PrefixDecision{
|
||||
Hot: prefixcache.ReplicaKey{NodeID: "X"},
|
||||
HasHot: true,
|
||||
MatchRatio: 1.0,
|
||||
}}
|
||||
pressure := prefixcache.NewPressure(time.Minute)
|
||||
router := NewSmartRouter(reg, SmartRouterOptions{
|
||||
Unloader: unloader,
|
||||
ClientFactory: factory,
|
||||
PrefixProvider: prov,
|
||||
PrefixConfig: prefixcache.DefaultConfig(),
|
||||
Pressure: pressure,
|
||||
})
|
||||
|
||||
ctx := distributedhdr.WithPrefixChain(context.Background(), []uint64{1, 2, 3})
|
||||
_, err := router.Route(ctx, "m", "models/m.gguf", "llama-cpp", nil, false)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
Expect(pressure.Count("m", time.Now())).To(Equal(0),
|
||||
"chosen == hot node, no disturb")
|
||||
})
|
||||
|
||||
It("does not record pressure for an all-unique workload with no hot match", func() {
|
||||
reg := loadedReg()
|
||||
prov := &fakePrefixProvider{decision: prefixcache.PrefixDecision{
|
||||
HasHot: false, // no prefix match at all
|
||||
MatchRatio: 0,
|
||||
ColdOrder: []prefixcache.ReplicaKey{{NodeID: "X"}},
|
||||
}}
|
||||
pressure := prefixcache.NewPressure(time.Minute)
|
||||
router := NewSmartRouter(reg, SmartRouterOptions{
|
||||
Unloader: unloader,
|
||||
ClientFactory: factory,
|
||||
PrefixProvider: prov,
|
||||
PrefixConfig: prefixcache.DefaultConfig(),
|
||||
Pressure: pressure,
|
||||
})
|
||||
|
||||
ctx := distributedhdr.WithPrefixChain(context.Background(), []uint64{1, 2, 3})
|
||||
_, err := router.Route(ctx, "m", "models/m.gguf", "llama-cpp", nil, false)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
Expect(pressure.Count("m", time.Now())).To(Equal(0),
|
||||
"no hot match means no cache to disturb: must not false-positive")
|
||||
})
|
||||
})
|
||||
|
||||
Context("removal chokepoint on unload", func() {
|
||||
It("removes the replica via the registry so the removal hook invalidates the prefix entry", func() {
|
||||
idx := prefixcache.NewIndex(prefixcache.DefaultConfig())
|
||||
reg := loadedReg()
|
||||
router := NewSmartRouter(reg, SmartRouterOptions{
|
||||
Unloader: unloader,
|
||||
ClientFactory: factory,
|
||||
PrefixProvider: idx,
|
||||
PrefixConfig: prefixcache.DefaultConfig(),
|
||||
})
|
||||
|
||||
ctx := distributedhdr.WithPrefixChain(context.Background(), []uint64{5, 6})
|
||||
// Warm the cache: X now holds the prefix.
|
||||
_, err := router.Route(ctx, "m", "models/m.gguf", "llama-cpp", nil, false)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(idx.Decide("m", []uint64{5, 6}, []prefixcache.ReplicaKey{{NodeID: "X", Replica: 0}}, time.Now()).Hot).To(Equal(prefixcache.ReplicaKey{NodeID: "X", Replica: 0}))
|
||||
|
||||
// UnloadModel must route the eviction through the registry removal
|
||||
// chokepoint (RemoveAllNodeModelReplicas). The registry's
|
||||
// SetReplicaRemovedHook is what invalidates the prefix index in
|
||||
// production; the router no longer invalidates directly. Here the
|
||||
// fake registry records the removal but fires no hook, so we assert
|
||||
// the chokepoint is exercised rather than the downstream
|
||||
// invalidation (covered by the registry hook integration tests).
|
||||
Expect(router.UnloadModel(context.Background(), "X", "m")).To(Succeed())
|
||||
Expect(reg.removeCalls).To(ContainElement("X:m"),
|
||||
"UnloadModel must remove the replica via the registry removal chokepoint")
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
@@ -558,3 +558,17 @@ All fields are optional and composable:
|
||||
- Ensure the port range is not blocked by firewalls or used by other services
|
||||
- Verify the backend gallery configuration is correct
|
||||
- The worker needs network access to download backends from the gallery
|
||||
|
||||
## Roadmap: Routing and Caching Enhancements
|
||||
|
||||
The scheduling algorithm above is load-based (least in-flight, then least-recently-used). Work is underway to make routing **prefix-cache-aware**: bias each request toward the replica that already holds the relevant KV/prefix cache (multi-turn conversations and shared system prompts), so backends reuse cache instead of recomputing it. The first step is a router-side radix tree of prompt-prefix hashes mapped to nodes, with longest-prefix match, a load guard that preserves round-robin behavior under imbalance, and NATS sync across frontends. It is purely a routing-layer hint (no backend changes) and never routes worse than today's round-robin.
|
||||
|
||||
Further enhancements, surfaced from a survey of SGLang, vLLM production-stack, Ray Serve, llm-d, AIBrix, and NVIDIA Dynamo, are tracked under the routing roadmap epic ([#10063](https://github.com/mudler/LocalAI/issues/10063)):
|
||||
|
||||
- **Reported/precise KV-event mode** ([#10064](https://github.com/mudler/LocalAI/issues/10064)): subscribe to actual backend KV-cache events for exact residency instead of inferring it from routing history.
|
||||
- **Multi-tier cache-overlap scoring** ([#10065](https://github.com/mudler/LocalAI/issues/10065)): credit GPU/CPU/disk cache tiers separately.
|
||||
- **Pluggable scorer/filter/picker pipeline** ([#10066](https://github.com/mudler/LocalAI/issues/10066)): composable multi-signal routing (cache, queue depth, KV utilization, latency).
|
||||
- **Load-shaping** ([#10067](https://github.com/mudler/LocalAI/issues/10067)): anti-herding (softmax/temperature) and dispatch-time freshness.
|
||||
- **Prefill/decode disaggregation routing** ([#10068](https://github.com/mudler/LocalAI/issues/10068)): route prefill and decode to separate pools with KV transfer.
|
||||
- **Per-user fairness (VTC)** ([#10069](https://github.com/mudler/LocalAI/issues/10069)): balance per-user token usage against pod load.
|
||||
- **Minor tuning + MCP parity** ([#10070](https://github.com/mudler/LocalAI/issues/10070)): per-model TTL override, probabilistic LRU updates, and MCP scheduling-config tool parity.
|
||||
|
||||
37
pkg/distributedhdr/prefixhash.go
Normal file
37
pkg/distributedhdr/prefixhash.go
Normal file
@@ -0,0 +1,37 @@
|
||||
package distributedhdr
|
||||
|
||||
import "context"
|
||||
|
||||
type prefixChainKey struct{}
|
||||
|
||||
// WithPrefixChain attaches a prompt prefix-hash chain to ctx so the distributed
|
||||
// router can make a prefix-cache-aware decision. Set at inference entry where
|
||||
// the rendered prompt is known; read in SmartRouter.Route.
|
||||
func WithPrefixChain(ctx context.Context, chain []uint64) context.Context {
|
||||
return context.WithValue(ctx, prefixChainKey{}, chain)
|
||||
}
|
||||
|
||||
// PrefixChain returns the chain attached by WithPrefixChain, or nil.
|
||||
func PrefixChain(ctx context.Context) []uint64 {
|
||||
if v, ok := ctx.Value(prefixChainKey{}).([]uint64); ok {
|
||||
return v
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// PrefixChainHook, when set at startup (distributed mode only), builds a prefix
|
||||
// hash chain from a model id and rendered prompt. Left nil in single-process
|
||||
// mode so there is zero overhead. See core/application/distributed.go.
|
||||
var PrefixChainHook func(model, prompt string) []uint64
|
||||
|
||||
// MaybeWithPrefixChain attaches a prefix chain to ctx iff the hook is set and
|
||||
// returns a non-empty chain. Otherwise returns ctx unchanged.
|
||||
func MaybeWithPrefixChain(ctx context.Context, model, prompt string) context.Context {
|
||||
if PrefixChainHook == nil {
|
||||
return ctx
|
||||
}
|
||||
if chain := PrefixChainHook(model, prompt); len(chain) > 0 {
|
||||
return WithPrefixChain(ctx, chain)
|
||||
}
|
||||
return ctx
|
||||
}
|
||||
37
pkg/distributedhdr/prefixhash_test.go
Normal file
37
pkg/distributedhdr/prefixhash_test.go
Normal file
@@ -0,0 +1,37 @@
|
||||
package distributedhdr_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/mudler/LocalAI/pkg/distributedhdr"
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
var _ = Describe("prefix chain ctx", func() {
|
||||
It("round-trips the chain through ctx", func() {
|
||||
ctx := distributedhdr.WithPrefixChain(context.Background(), []uint64{1, 2, 3})
|
||||
Expect(distributedhdr.PrefixChain(ctx)).To(Equal([]uint64{1, 2, 3}))
|
||||
})
|
||||
It("returns nil when absent", func() {
|
||||
Expect(distributedhdr.PrefixChain(context.Background())).To(BeNil())
|
||||
})
|
||||
|
||||
It("uses the hook to build the chain when set", func() {
|
||||
distributedhdr.PrefixChainHook = func(model, prompt string) []uint64 { return []uint64{42} }
|
||||
defer func() { distributedhdr.PrefixChainHook = nil }()
|
||||
ctx := distributedhdr.MaybeWithPrefixChain(context.Background(), "m", "hi")
|
||||
Expect(distributedhdr.PrefixChain(ctx)).To(Equal([]uint64{42}))
|
||||
})
|
||||
It("is a no-op when the hook is nil", func() {
|
||||
distributedhdr.PrefixChainHook = nil
|
||||
ctx := distributedhdr.MaybeWithPrefixChain(context.Background(), "m", "hi")
|
||||
Expect(distributedhdr.PrefixChain(ctx)).To(BeNil())
|
||||
})
|
||||
It("is a no-op when the hook returns an empty chain", func() {
|
||||
distributedhdr.PrefixChainHook = func(model, prompt string) []uint64 { return nil }
|
||||
defer func() { distributedhdr.PrefixChainHook = nil }()
|
||||
ctx := distributedhdr.MaybeWithPrefixChain(context.Background(), "m", "hi")
|
||||
Expect(distributedhdr.PrefixChain(ctx)).To(BeNil())
|
||||
})
|
||||
})
|
||||
248
pkg/radixtree/radixtree.go
Normal file
248
pkg/radixtree/radixtree.go
Normal file
@@ -0,0 +1,248 @@
|
||||
// Package radixtree implements a generic prefix tree over sequences of uint64
|
||||
// key-elements, mapping the longest stored prefix of a query sequence to a
|
||||
// value. Entries carry a TTL and the tree tracks a recency-weighted score per
|
||||
// value. The clock is injected (callers pass `now`) so behavior is fully
|
||||
// deterministic and testable. It has no external dependencies.
|
||||
package radixtree
|
||||
|
||||
import (
|
||||
"math"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Options configures a Tree.
|
||||
type Options struct {
|
||||
// TTL is the idle lifetime of an entry. An entry whose lastSeen is older
|
||||
// than TTL (relative to the `now` passed in) is treated as absent and is
|
||||
// swept by Evict. Refreshed on every Insert that traverses it. The boundary
|
||||
// is strict greater-than: an entry whose age is exactly equal to TTL is
|
||||
// still live; it expires only once age exceeds TTL.
|
||||
TTL time.Duration
|
||||
// HalfLife controls recency weighting in Weight(). An entry contributes
|
||||
// 0.5^(age/HalfLife). Zero means "no decay" (every live entry counts 1).
|
||||
HalfLife time.Duration
|
||||
// MaxEntries bounds the number of value-bearing nodes. Zero means
|
||||
// unbounded. When exceeded, Insert evicts the least-recently-seen entry.
|
||||
MaxEntries int
|
||||
}
|
||||
|
||||
// Tree is a prefix tree. V is the stored value type (for prefix-cache routing,
|
||||
// a node identifier). Safe for concurrent use.
|
||||
type Tree[V comparable] struct {
|
||||
mu sync.RWMutex
|
||||
opts Options
|
||||
root *node[V]
|
||||
size int
|
||||
}
|
||||
|
||||
type node[V comparable] struct {
|
||||
children map[uint64]*node[V]
|
||||
value V
|
||||
hasValue bool
|
||||
lastSeen time.Time
|
||||
}
|
||||
|
||||
// New creates an empty Tree.
|
||||
func New[V comparable](opts Options) *Tree[V] {
|
||||
return &Tree[V]{opts: opts, root: &node[V]{children: map[uint64]*node[V]{}}}
|
||||
}
|
||||
|
||||
// LongestMatch returns the value at the deepest stored, non-expired prefix of
|
||||
// key, the matched depth (number of key elements consumed), and ok.
|
||||
func (t *Tree[V]) LongestMatch(key []uint64, now time.Time) (V, int, bool) {
|
||||
t.mu.RLock()
|
||||
defer t.mu.RUnlock()
|
||||
var best V
|
||||
bestDepth, found := 0, false
|
||||
cur := t.root
|
||||
for i, k := range key {
|
||||
next, ok := cur.children[k]
|
||||
if !ok {
|
||||
break
|
||||
}
|
||||
cur = next
|
||||
if cur.hasValue && !t.expired(cur, now) {
|
||||
best, bestDepth, found = cur.value, i+1, true
|
||||
}
|
||||
}
|
||||
return best, bestDepth, found
|
||||
}
|
||||
|
||||
// expired reports whether n's lastSeen is older than the configured TTL. The
|
||||
// comparison is strict greater-than: an entry whose age equals TTL exactly is
|
||||
// still considered live. With TTL == 0 (unbounded) nothing ever expires.
|
||||
func (t *Tree[V]) expired(n *node[V], now time.Time) bool {
|
||||
return t.opts.TTL > 0 && now.Sub(n.lastSeen) > t.opts.TTL
|
||||
}
|
||||
|
||||
// Insert records value at EVERY node along the key chain, not just the leaf,
|
||||
// so each prefix-block node remembers the value (node id) that served that
|
||||
// prefix. This is what makes LongestMatch find a shared prefix even when the
|
||||
// query tail diverges (SGLang/vLLM-style prefix matching). Re-inserting a
|
||||
// different value over a shared prefix node overwrites it: the last writer
|
||||
// owns the shared prefix node (a recency heuristic, and the correct one - the
|
||||
// most recent chain that traversed that block is the one most likely warm).
|
||||
// lastSeen is refreshed on every traversed node so active prefixes stay live.
|
||||
// Inserting an empty key is a no-op: the root never holds a value.
|
||||
func (t *Tree[V]) Insert(key []uint64, value V, now time.Time) {
|
||||
if len(key) == 0 {
|
||||
return
|
||||
}
|
||||
t.mu.Lock()
|
||||
defer t.mu.Unlock()
|
||||
cur := t.root
|
||||
for _, k := range key {
|
||||
next, ok := cur.children[k]
|
||||
if !ok {
|
||||
next = &node[V]{children: map[uint64]*node[V]{}}
|
||||
cur.children[k] = next
|
||||
}
|
||||
cur = next
|
||||
if !cur.hasValue {
|
||||
t.size++
|
||||
}
|
||||
cur.value, cur.hasValue, cur.lastSeen = value, true, now
|
||||
}
|
||||
if t.opts.MaxEntries > 0 && t.size > t.opts.MaxEntries {
|
||||
t.evictOldestLocked(now)
|
||||
}
|
||||
}
|
||||
|
||||
// evictOldestLocked drops the single least-recently-seen value-bearing node and
|
||||
// prunes any empty branches the removal leaves behind. Called with t.mu held.
|
||||
func (t *Tree[V]) evictOldestLocked(now time.Time) {
|
||||
var victim *node[V]
|
||||
var walk func(n *node[V])
|
||||
walk = func(n *node[V]) {
|
||||
if n.hasValue && (victim == nil || n.lastSeen.Before(victim.lastSeen)) {
|
||||
victim = n
|
||||
}
|
||||
for _, c := range n.children {
|
||||
walk(c)
|
||||
}
|
||||
}
|
||||
walk(t.root)
|
||||
if victim != nil {
|
||||
// Clear the victim's value and reclaim it plus any ancestors that are
|
||||
// now both value-less and childless.
|
||||
t.pruneWalk(t.root, func(n *node[V]) bool { return n == victim })
|
||||
}
|
||||
}
|
||||
|
||||
// pruneWalk clears the value of every node for which shouldClear returns true,
|
||||
// then removes the now empty (value-less and childless) branches that result.
|
||||
// It keeps t.size accurate by decrementing once per cleared node. Returns true
|
||||
// if n itself should be removed from its parent. Called with t.mu held.
|
||||
func (t *Tree[V]) pruneWalk(n *node[V], shouldClear func(*node[V]) bool) bool {
|
||||
for k, c := range n.children {
|
||||
if t.pruneWalk(c, shouldClear) {
|
||||
delete(n.children, k)
|
||||
}
|
||||
}
|
||||
if n.hasValue && shouldClear(n) {
|
||||
n.hasValue = false
|
||||
var zero V
|
||||
n.value = zero
|
||||
t.size--
|
||||
}
|
||||
return n != t.root && !n.hasValue && len(n.children) == 0
|
||||
}
|
||||
|
||||
// Len returns the number of live (value-bearing) entries, including not-yet-
|
||||
// swept expired ones. Use after Evict for the post-sweep count.
|
||||
func (t *Tree[V]) Len() int {
|
||||
t.mu.RLock()
|
||||
defer t.mu.RUnlock()
|
||||
return t.size
|
||||
}
|
||||
|
||||
// Evict removes expired value-bearing nodes and prunes resulting empty
|
||||
// branches. O(n) in tree size; call periodically from a background sweeper.
|
||||
func (t *Tree[V]) Evict(now time.Time) {
|
||||
t.mu.Lock()
|
||||
defer t.mu.Unlock()
|
||||
t.pruneWalk(t.root, func(n *node[V]) bool { return t.expired(n, now) })
|
||||
}
|
||||
|
||||
// contribution returns the recency-weighted score a single live, non-expired
|
||||
// node adds to its value's weight: 1.0 when HalfLife<=0 (a plain count), else
|
||||
// 0.5^(age/HalfLife). It does not check hasValue or expiry; callers must filter
|
||||
// those first. Shared by Weight and WeightsFor so the metric stays identical.
|
||||
func (t *Tree[V]) contribution(n *node[V], now time.Time) float64 {
|
||||
if t.opts.HalfLife <= 0 {
|
||||
return 1
|
||||
}
|
||||
age := now.Sub(n.lastSeen).Seconds()
|
||||
return math.Pow(0.5, age/t.opts.HalfLife.Seconds())
|
||||
}
|
||||
|
||||
// Weight returns the recency-weighted count of live entries anchored to value:
|
||||
// sum over non-expired entries of 0.5^(age/HalfLife). With HalfLife==0 every
|
||||
// live entry contributes 1.0 (a plain count). This is the "valuable warm cache"
|
||||
// proxy used for cold placement and autoscale.
|
||||
func (t *Tree[V]) Weight(value V, now time.Time) float64 {
|
||||
t.mu.RLock()
|
||||
defer t.mu.RUnlock()
|
||||
var sum float64
|
||||
var walk func(n *node[V])
|
||||
walk = func(n *node[V]) {
|
||||
if n.hasValue && n.value == value && !t.expired(n, now) {
|
||||
sum += t.contribution(n, now)
|
||||
}
|
||||
for _, c := range n.children {
|
||||
walk(c)
|
||||
}
|
||||
}
|
||||
walk(t.root)
|
||||
return sum
|
||||
}
|
||||
|
||||
// WeightsFor returns the recency-weighted weight (same metric as Weight) for
|
||||
// each value in values, computed in a single tree traversal. Values not present
|
||||
// in the tree map to 0. This is O(N + len(values)) versus calling Weight once
|
||||
// per value (O(len(values) * N)). Concurrency-safe (read lock).
|
||||
func (t *Tree[V]) WeightsFor(values []V, now time.Time) map[V]float64 {
|
||||
want := make(map[V]struct{}, len(values))
|
||||
result := make(map[V]float64, len(values))
|
||||
for _, v := range values {
|
||||
want[v] = struct{}{}
|
||||
result[v] = 0
|
||||
}
|
||||
if len(want) == 0 {
|
||||
return result
|
||||
}
|
||||
t.mu.RLock()
|
||||
defer t.mu.RUnlock()
|
||||
var walk func(n *node[V])
|
||||
walk = func(n *node[V]) {
|
||||
if n.hasValue && !t.expired(n, now) {
|
||||
if _, ok := want[n.value]; ok {
|
||||
result[n.value] += t.contribution(n, now)
|
||||
}
|
||||
}
|
||||
for _, c := range n.children {
|
||||
walk(c)
|
||||
}
|
||||
}
|
||||
walk(t.root)
|
||||
return result
|
||||
}
|
||||
|
||||
// Remove drops every entry whose value equals value, then prunes empty
|
||||
// branches. Used when a replica is unloaded or its node goes offline so the
|
||||
// tree never points at a node that no longer holds the model. It is the
|
||||
// equality special case of RemoveFunc.
|
||||
func (t *Tree[V]) Remove(value V) {
|
||||
t.RemoveFunc(func(v V) bool { return v == value })
|
||||
}
|
||||
|
||||
// RemoveFunc drops every entry whose value satisfies pred, then prunes empty
|
||||
// branches. Generalizes Remove (Remove(v) == RemoveFunc(func(x V) bool { return
|
||||
// x == v })). Used to drop, in one walk, every entry that belongs to a class of
|
||||
// values (for example all replicas of a single node).
|
||||
func (t *Tree[V]) RemoveFunc(pred func(V) bool) {
|
||||
t.mu.Lock()
|
||||
defer t.mu.Unlock()
|
||||
t.pruneWalk(t.root, func(n *node[V]) bool { return pred(n.value) })
|
||||
}
|
||||
13
pkg/radixtree/radixtree_suite_test.go
Normal file
13
pkg/radixtree/radixtree_suite_test.go
Normal file
@@ -0,0 +1,13 @@
|
||||
package radixtree_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
func TestRadixTree(t *testing.T) {
|
||||
RegisterFailHandler(Fail)
|
||||
RunSpecs(t, "RadixTree Suite")
|
||||
}
|
||||
354
pkg/radixtree/radixtree_test.go
Normal file
354
pkg/radixtree/radixtree_test.go
Normal file
@@ -0,0 +1,354 @@
|
||||
package radixtree_test
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
|
||||
"github.com/mudler/LocalAI/pkg/radixtree"
|
||||
)
|
||||
|
||||
var t0 = time.Date(2026, 5, 29, 12, 0, 0, 0, time.UTC)
|
||||
|
||||
var _ = Describe("Tree construction", func() {
|
||||
It("returns an empty tree that matches nothing", func() {
|
||||
tr := radixtree.New[string](radixtree.Options{TTL: time.Minute})
|
||||
_, depth, ok := tr.LongestMatch([]uint64{1, 2, 3}, t0)
|
||||
Expect(ok).To(BeFalse())
|
||||
Expect(depth).To(Equal(0))
|
||||
})
|
||||
})
|
||||
|
||||
var _ = Describe("Insert and LongestMatch", func() {
|
||||
It("returns the deepest matching prefix value", func() {
|
||||
// Non-overlapping chains keep the longest-prefix intent clean: every
|
||||
// node on the value's own chain records that value, and no other Insert
|
||||
// overwrites a shared prefix node. A query that runs off the end of a
|
||||
// chain stops matching at the deepest stored element it reached.
|
||||
tr := radixtree.New[string](radixtree.Options{TTL: time.Hour})
|
||||
tr.Insert([]uint64{1, 2, 3, 4}, "nodeB", t0)
|
||||
tr.Insert([]uint64{7, 8}, "nodeA", t0)
|
||||
|
||||
v, depth, ok := tr.LongestMatch([]uint64{1, 2, 3, 4, 5}, t0)
|
||||
Expect(ok).To(BeTrue())
|
||||
Expect(v).To(Equal("nodeB"))
|
||||
Expect(depth).To(Equal(4))
|
||||
|
||||
v, depth, ok = tr.LongestMatch([]uint64{7, 8, 9}, t0)
|
||||
Expect(ok).To(BeTrue())
|
||||
Expect(v).To(Equal("nodeA"))
|
||||
Expect(depth).To(Equal(2))
|
||||
})
|
||||
|
||||
It("lets the last writer own a shared prefix node", func() {
|
||||
// When two chains share a leading block, value-at-every-node means the
|
||||
// later Insert overwrites the shared prefix node. Inserting nodeA on
|
||||
// [1,2] then nodeB on [1,2,3,4] makes nodeB own [1] and [1,2], so a
|
||||
// query that diverges within the shared block resolves to nodeB. This
|
||||
// is the intended recency heuristic: the most recent chain through that
|
||||
// block is the one most likely still warm.
|
||||
tr := radixtree.New[string](radixtree.Options{TTL: time.Hour})
|
||||
tr.Insert([]uint64{1, 2}, "nodeA", t0)
|
||||
tr.Insert([]uint64{1, 2, 3, 4}, "nodeB", t0)
|
||||
|
||||
v, depth, ok := tr.LongestMatch([]uint64{1, 2, 3, 4, 5}, t0)
|
||||
Expect(ok).To(BeTrue())
|
||||
Expect(v).To(Equal("nodeB"))
|
||||
Expect(depth).To(Equal(4))
|
||||
|
||||
// The shared prefix [1,2] is now owned by nodeB (last writer wins).
|
||||
v, depth, ok = tr.LongestMatch([]uint64{1, 2, 9}, t0)
|
||||
Expect(ok).To(BeTrue())
|
||||
Expect(v).To(Equal("nodeB"))
|
||||
Expect(depth).To(Equal(2))
|
||||
})
|
||||
|
||||
It("returns ok=false when no prefix is stored", func() {
|
||||
tr := radixtree.New[string](radixtree.Options{TTL: time.Hour})
|
||||
tr.Insert([]uint64{7, 8}, "nodeA", t0)
|
||||
_, _, ok := tr.LongestMatch([]uint64{1, 2}, t0)
|
||||
Expect(ok).To(BeFalse())
|
||||
})
|
||||
|
||||
It("matches a shared prefix when the query tail diverges", func() {
|
||||
// SGLang/vLLM-style prefix matching: a single Insert of a full chain
|
||||
// must let any query that shares a leading block match at the depth of
|
||||
// the deepest shared element, even though the tails differ. This is the
|
||||
// core use case (shared system prompt / multi-turn extension / volatile
|
||||
// tail), not exact-repeat.
|
||||
tr := radixtree.New[string](radixtree.Options{TTL: time.Hour})
|
||||
tr.Insert([]uint64{1, 2, 3, 4, 5}, "nodeA", t0)
|
||||
v, depth, ok := tr.LongestMatch([]uint64{1, 2, 3, 9, 9}, t0)
|
||||
Expect(ok).To(BeTrue())
|
||||
Expect(depth).To(Equal(3)) // shared prefix [1,2,3]
|
||||
Expect(v).To(Equal("nodeA"))
|
||||
})
|
||||
})
|
||||
|
||||
var _ = Describe("TTL expiry", func() {
|
||||
It("does not match an entry past its TTL", func() {
|
||||
tr := radixtree.New[string](radixtree.Options{TTL: time.Minute})
|
||||
tr.Insert([]uint64{1, 2}, "nodeA", t0)
|
||||
_, _, ok := tr.LongestMatch([]uint64{1, 2}, t0.Add(2*time.Minute))
|
||||
Expect(ok).To(BeFalse())
|
||||
})
|
||||
|
||||
It("refreshes lastSeen on re-insert so a live path survives", func() {
|
||||
tr := radixtree.New[string](radixtree.Options{TTL: time.Minute})
|
||||
tr.Insert([]uint64{1, 2}, "nodeA", t0)
|
||||
tr.Insert([]uint64{1, 2}, "nodeA", t0.Add(50*time.Second))
|
||||
_, _, ok := tr.LongestMatch([]uint64{1, 2}, t0.Add(90*time.Second))
|
||||
Expect(ok).To(BeTrue())
|
||||
})
|
||||
|
||||
It("Evict reclaims expired nodes", func() {
|
||||
tr := radixtree.New[string](radixtree.Options{TTL: time.Minute})
|
||||
// Value-at-every-node: Insert of a 2-element chain records nodeA at both
|
||||
// {1} and {1,2}, so Len is 2 (one valued node per distinct prefix).
|
||||
tr.Insert([]uint64{1, 2}, "nodeA", t0)
|
||||
Expect(tr.Len()).To(Equal(2))
|
||||
tr.Evict(t0.Add(2 * time.Minute))
|
||||
Expect(tr.Len()).To(Equal(0))
|
||||
})
|
||||
})
|
||||
|
||||
var _ = Describe("Weight", func() {
|
||||
It("counts live entries for a value with no decay", func() {
|
||||
tr := radixtree.New[string](radixtree.Options{TTL: time.Hour}) // HalfLife=0
|
||||
tr.Insert([]uint64{1}, "A", t0)
|
||||
tr.Insert([]uint64{1, 2}, "A", t0)
|
||||
tr.Insert([]uint64{9}, "B", t0)
|
||||
Expect(tr.Weight("A", t0)).To(BeNumerically("==", 2))
|
||||
Expect(tr.Weight("B", t0)).To(BeNumerically("==", 1))
|
||||
Expect(tr.Weight("C", t0)).To(BeNumerically("==", 0))
|
||||
})
|
||||
|
||||
It("decays older entries by half-life", func() {
|
||||
tr := radixtree.New[string](radixtree.Options{TTL: time.Hour, HalfLife: time.Minute})
|
||||
tr.Insert([]uint64{1}, "A", t0)
|
||||
// one half-life later, the entry weighs 0.5
|
||||
Expect(tr.Weight("A", t0.Add(time.Minute))).To(BeNumerically("~", 0.5, 0.001))
|
||||
})
|
||||
|
||||
It("ignores expired entries", func() {
|
||||
tr := radixtree.New[string](radixtree.Options{TTL: time.Minute})
|
||||
tr.Insert([]uint64{1}, "A", t0)
|
||||
Expect(tr.Weight("A", t0.Add(2*time.Minute))).To(BeNumerically("==", 0))
|
||||
})
|
||||
})
|
||||
|
||||
var _ = Describe("WeightsFor", func() {
|
||||
It("matches per-value Weight with no decay", func() {
|
||||
tr := radixtree.New[string](radixtree.Options{TTL: time.Hour}) // HalfLife=0
|
||||
tr.Insert([]uint64{1}, "A", t0)
|
||||
tr.Insert([]uint64{1, 2}, "A", t0)
|
||||
tr.Insert([]uint64{9}, "B", t0)
|
||||
|
||||
got := tr.WeightsFor([]string{"A", "B", "C"}, t0)
|
||||
Expect(got).To(HaveLen(3))
|
||||
Expect(got["A"]).To(BeNumerically("==", 2))
|
||||
Expect(got["B"]).To(BeNumerically("==", 1))
|
||||
Expect(got["C"]).To(BeNumerically("==", 0))
|
||||
})
|
||||
|
||||
It("matches per-value Weight under decay", func() {
|
||||
tr := radixtree.New[string](radixtree.Options{TTL: time.Hour, HalfLife: time.Minute})
|
||||
tr.Insert([]uint64{1}, "A", t0)
|
||||
tr.Insert([]uint64{1, 2}, "A", t0.Add(30*time.Second))
|
||||
tr.Insert([]uint64{9}, "B", t0)
|
||||
|
||||
now := t0.Add(time.Minute)
|
||||
got := tr.WeightsFor([]string{"A", "B", "C"}, now)
|
||||
Expect(got["A"]).To(BeNumerically("~", tr.Weight("A", now), 1e-12))
|
||||
Expect(got["B"]).To(BeNumerically("~", tr.Weight("B", now), 1e-12))
|
||||
Expect(got["C"]).To(BeNumerically("==", 0))
|
||||
})
|
||||
|
||||
It("respects TTL expiry and matches Weight at a non-zero age under decay", func() {
|
||||
tr := radixtree.New[string](radixtree.Options{TTL: time.Minute, HalfLife: 30 * time.Second})
|
||||
tr.Insert([]uint64{1}, "A", t0) // will be expired at now
|
||||
tr.Insert([]uint64{2}, "A", t0.Add(90*time.Second)) // live, aged 30s at now
|
||||
tr.Insert([]uint64{9}, "B", t0) // expired at now
|
||||
|
||||
now := t0.Add(2 * time.Minute)
|
||||
got := tr.WeightsFor([]string{"A", "B"}, now)
|
||||
Expect(got["A"]).To(BeNumerically("~", tr.Weight("A", now), 1e-12))
|
||||
Expect(got["A"]).To(BeNumerically("~", 0.5, 0.001)) // single live entry aged one half-life
|
||||
Expect(got["B"]).To(BeNumerically("==", 0))
|
||||
})
|
||||
|
||||
It("returns an empty map for an empty values slice", func() {
|
||||
tr := radixtree.New[string](radixtree.Options{TTL: time.Hour})
|
||||
tr.Insert([]uint64{1}, "A", t0)
|
||||
Expect(tr.WeightsFor(nil, t0)).To(BeEmpty())
|
||||
Expect(tr.WeightsFor([]string{}, t0)).To(BeEmpty())
|
||||
})
|
||||
|
||||
It("maps a value not present in the tree to 0", func() {
|
||||
tr := radixtree.New[string](radixtree.Options{TTL: time.Hour})
|
||||
tr.Insert([]uint64{1}, "A", t0)
|
||||
got := tr.WeightsFor([]string{"Z"}, t0)
|
||||
Expect(got).To(HaveLen(1))
|
||||
Expect(got["Z"]).To(BeNumerically("==", 0))
|
||||
})
|
||||
})
|
||||
|
||||
var _ = Describe("Remove", func() {
|
||||
It("drops every entry anchored to a value and prunes", func() {
|
||||
// Non-overlapping chains so Remove("A") and the survival of B are both
|
||||
// meaningful: with value-at-every-node, overlapping chains would let the
|
||||
// later writer own the shared prefix nodes, so A could own nothing and
|
||||
// the test would be vacuous.
|
||||
tr := radixtree.New[string](radixtree.Options{TTL: time.Hour})
|
||||
tr.Insert([]uint64{1, 2}, "A", t0)
|
||||
tr.Insert([]uint64{7, 8, 9}, "B", t0)
|
||||
tr.Remove("A")
|
||||
_, _, ok := tr.LongestMatch([]uint64{1, 2}, t0)
|
||||
Expect(ok).To(BeFalse()) // A gone; its branch is fully reclaimed
|
||||
v, _, ok := tr.LongestMatch([]uint64{7, 8, 9}, t0)
|
||||
Expect(ok).To(BeTrue())
|
||||
Expect(v).To(Equal("B")) // B survives
|
||||
Expect(tr.Weight("A", t0)).To(BeNumerically("==", 0))
|
||||
})
|
||||
})
|
||||
|
||||
var _ = Describe("RemoveFunc", func() {
|
||||
It("drops every entry matching the predicate, prunes, and keeps the rest", func() {
|
||||
tr := radixtree.New[string](radixtree.Options{TTL: time.Hour})
|
||||
tr.Insert([]uint64{1, 2}, "drop-a", t0)
|
||||
tr.Insert([]uint64{3, 4}, "drop-b", t0)
|
||||
tr.Insert([]uint64{7, 8, 9}, "keep", t0)
|
||||
// Drop everything whose value starts with "drop".
|
||||
tr.RemoveFunc(func(v string) bool { return len(v) >= 4 && v[:4] == "drop" })
|
||||
_, _, ok := tr.LongestMatch([]uint64{1, 2}, t0)
|
||||
Expect(ok).To(BeFalse())
|
||||
_, _, ok = tr.LongestMatch([]uint64{3, 4}, t0)
|
||||
Expect(ok).To(BeFalse())
|
||||
v, _, ok := tr.LongestMatch([]uint64{7, 8, 9}, t0)
|
||||
Expect(ok).To(BeTrue())
|
||||
Expect(v).To(Equal("keep"))
|
||||
Expect(tr.Len()).To(Equal(3)) // only the 3-node "keep" chain remains
|
||||
})
|
||||
|
||||
It("makes Remove a special case of RemoveFunc", func() {
|
||||
tr := radixtree.New[string](radixtree.Options{TTL: time.Hour})
|
||||
tr.Insert([]uint64{1, 2}, "A", t0)
|
||||
tr.Insert([]uint64{7, 8, 9}, "B", t0)
|
||||
tr.RemoveFunc(func(v string) bool { return v == "A" })
|
||||
_, _, ok := tr.LongestMatch([]uint64{1, 2}, t0)
|
||||
Expect(ok).To(BeFalse())
|
||||
v, _, ok := tr.LongestMatch([]uint64{7, 8, 9}, t0)
|
||||
Expect(ok).To(BeTrue())
|
||||
Expect(v).To(Equal("B"))
|
||||
})
|
||||
})
|
||||
|
||||
var _ = Describe("TTL boundary", func() {
|
||||
It("treats age exactly equal to TTL as still live, and one tick past as expired", func() {
|
||||
tr := radixtree.New[string](radixtree.Options{TTL: time.Minute})
|
||||
tr.Insert([]uint64{1, 2}, "A", t0)
|
||||
|
||||
// age == TTL: strict greater-than means this is still live.
|
||||
_, _, ok := tr.LongestMatch([]uint64{1, 2}, t0.Add(time.Minute))
|
||||
Expect(ok).To(BeTrue())
|
||||
|
||||
// one nanosecond past TTL: expired.
|
||||
_, _, ok = tr.LongestMatch([]uint64{1, 2}, t0.Add(time.Minute+time.Nanosecond))
|
||||
Expect(ok).To(BeFalse())
|
||||
})
|
||||
})
|
||||
|
||||
var _ = Describe("MaxEntries eviction", func() {
|
||||
It("drops the least-recently-seen entry when the cap is exceeded", func() {
|
||||
tr := radixtree.New[string](radixtree.Options{TTL: time.Hour, MaxEntries: 2})
|
||||
tr.Insert([]uint64{1}, "A", t0)
|
||||
tr.Insert([]uint64{2}, "B", t0.Add(time.Second))
|
||||
tr.Insert([]uint64{3}, "C", t0.Add(2*time.Second))
|
||||
|
||||
Expect(tr.Len()).To(Equal(2))
|
||||
|
||||
// A was the least-recently-seen, so it is the one dropped.
|
||||
_, _, ok := tr.LongestMatch([]uint64{1}, t0.Add(2*time.Second))
|
||||
Expect(ok).To(BeFalse())
|
||||
|
||||
// B and C survive.
|
||||
_, _, ok = tr.LongestMatch([]uint64{2}, t0.Add(2*time.Second))
|
||||
Expect(ok).To(BeTrue())
|
||||
_, _, ok = tr.LongestMatch([]uint64{3}, t0.Add(2*time.Second))
|
||||
Expect(ok).To(BeTrue())
|
||||
})
|
||||
|
||||
It("prunes value-less ancestors left behind by an eviction", func() {
|
||||
// Value-at-every-node: Inserting the deep chain B = [1,2,3] records B at
|
||||
// {1}, {1,2}, and {1,2,3} (three valued nodes). With the cap at 2, the
|
||||
// least-recently-seen valued nodes are evicted one per subsequent Insert.
|
||||
// The two fresh single-element keys (C, D) are newer, so eviction keeps
|
||||
// peeling B's nodes off until B's entire branch is reclaimed - none of
|
||||
// its internal nodes may linger and inflate Len past the cap.
|
||||
tr := radixtree.New[string](radixtree.Options{TTL: time.Hour, MaxEntries: 2})
|
||||
tr.Insert([]uint64{1, 2, 3}, "B", t0)
|
||||
tr.Insert([]uint64{5}, "C", t0.Add(time.Second))
|
||||
tr.Insert([]uint64{6}, "D", t0.Add(2*time.Second))
|
||||
|
||||
Expect(tr.Len()).To(Equal(2))
|
||||
// B (oldest) evicted; its deep branch reclaimed.
|
||||
_, _, ok := tr.LongestMatch([]uint64{1, 2, 3}, t0.Add(2*time.Second))
|
||||
Expect(ok).To(BeFalse())
|
||||
_, _, ok = tr.LongestMatch([]uint64{1, 2}, t0.Add(2*time.Second))
|
||||
Expect(ok).To(BeFalse())
|
||||
Expect(tr.Weight("B", t0.Add(2*time.Second))).To(BeNumerically("==", 0))
|
||||
})
|
||||
|
||||
It("reclaims structure so the tree never grows past the cap under churn", func() {
|
||||
tr := radixtree.New[string](radixtree.Options{TTL: time.Hour, MaxEntries: 2})
|
||||
tr.Insert([]uint64{1}, "A", t0)
|
||||
tr.Insert([]uint64{2}, "B", t0.Add(time.Second))
|
||||
Expect(tr.Len()).To(Equal(2))
|
||||
|
||||
for i := range 10 {
|
||||
tr.Insert([]uint64{uint64(100 + i)}, "X", t0.Add(time.Duration(i+2)*time.Second))
|
||||
Expect(tr.Len()).To(Equal(2))
|
||||
}
|
||||
})
|
||||
})
|
||||
|
||||
var _ = Describe("Empty key", func() {
|
||||
It("LongestMatch on an empty key returns ok=false", func() {
|
||||
tr := radixtree.New[string](radixtree.Options{TTL: time.Hour})
|
||||
tr.Insert([]uint64{1, 2}, "A", t0)
|
||||
_, depth, ok := tr.LongestMatch([]uint64{}, t0)
|
||||
Expect(ok).To(BeFalse())
|
||||
Expect(depth).To(Equal(0))
|
||||
})
|
||||
|
||||
It("Insert with an empty key is a no-op that creates no root value", func() {
|
||||
tr := radixtree.New[string](radixtree.Options{TTL: time.Hour})
|
||||
Expect(func() { tr.Insert([]uint64{}, "A", t0) }).NotTo(Panic())
|
||||
Expect(tr.Len()).To(Equal(0))
|
||||
_, _, ok := tr.LongestMatch([]uint64{}, t0)
|
||||
Expect(ok).To(BeFalse())
|
||||
Expect(tr.Weight("A", t0)).To(BeNumerically("==", 0))
|
||||
})
|
||||
})
|
||||
|
||||
var _ = Describe("Concurrent access", func() {
|
||||
It("is race-free under parallel insert/match/weight", func() {
|
||||
tr := radixtree.New[string](radixtree.Options{TTL: time.Hour})
|
||||
done := make(chan struct{})
|
||||
for g := range 8 {
|
||||
go func(g int) {
|
||||
defer GinkgoRecover()
|
||||
for i := range 1000 {
|
||||
tr.Insert([]uint64{uint64(g), uint64(i % 10)}, "n", t0)
|
||||
tr.LongestMatch([]uint64{uint64(g), 1}, t0)
|
||||
tr.Weight("n", t0)
|
||||
}
|
||||
done <- struct{}{}
|
||||
}(g)
|
||||
}
|
||||
for range 8 {
|
||||
<-done
|
||||
}
|
||||
})
|
||||
})
|
||||
@@ -63,7 +63,7 @@ var _ = Describe("Model Routing", Label("Distributed"), func() {
|
||||
Expect(models[0].InFlight).To(Equal(2))
|
||||
|
||||
// FindAndLockNodeWithModel should return this node and atomically increment in-flight
|
||||
foundNode, foundModel, err := registry.FindAndLockNodeWithModel(context.Background(), "llama3", nil)
|
||||
foundNode, foundModel, err := registry.FindAndLockNodeWithModel(context.Background(), "llama3", nil, nil)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(foundNode.ID).To(Equal(node.ID))
|
||||
Expect(foundModel.ModelName).To(Equal("llama3"))
|
||||
|
||||
234
tests/e2e/distributed/prefix_cache_routing_test.go
Normal file
234
tests/e2e/distributed/prefix_cache_routing_test.go
Normal file
@@ -0,0 +1,234 @@
|
||||
package distributed_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"github.com/mudler/LocalAI/core/services/nodes"
|
||||
"github.com/mudler/LocalAI/core/services/nodes/prefixcache"
|
||||
"github.com/mudler/LocalAI/pkg/distributedhdr"
|
||||
grpcPkg "github.com/mudler/LocalAI/pkg/grpc"
|
||||
pb "github.com/mudler/LocalAI/pkg/grpc/proto"
|
||||
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
|
||||
ggrpc "google.golang.org/grpc"
|
||||
|
||||
pgdriver "gorm.io/driver/postgres"
|
||||
gormDB "gorm.io/gorm"
|
||||
"gorm.io/gorm/logger"
|
||||
)
|
||||
|
||||
// prefixStubBackend implements grpc.Backend with a canned-success HealthCheck
|
||||
// and LoadModel so SmartRouter.probeHealth passes and any cold load returns
|
||||
// success — no real inference happens. Mirrors the stubBackend pattern used by
|
||||
// the SmartRouter unit tests in core/services/nodes/router_test.go, reproduced
|
||||
// here because that fake lives in the internal (unexported) nodes package.
|
||||
type prefixStubBackend struct {
|
||||
grpcPkg.Backend // embed so unused methods satisfy the interface; they panic only if called
|
||||
|
||||
healthResult bool
|
||||
}
|
||||
|
||||
func (f *prefixStubBackend) HealthCheck(_ context.Context) (bool, error) {
|
||||
return f.healthResult, nil
|
||||
}
|
||||
|
||||
func (f *prefixStubBackend) LoadModel(_ context.Context, _ *pb.ModelOptions, _ ...ggrpc.CallOption) (*pb.Result, error) {
|
||||
return &pb.Result{Success: true}, nil
|
||||
}
|
||||
|
||||
func (f *prefixStubBackend) IsBusy() bool { return false }
|
||||
|
||||
// prefixStubClientFactory hands the same fake backend to every NewClient call,
|
||||
// so the SmartRouter never opens a real gRPC connection during routing.
|
||||
type prefixStubClientFactory struct {
|
||||
client *prefixStubBackend
|
||||
}
|
||||
|
||||
func (f *prefixStubClientFactory) NewClient(_ string, _ bool) grpcPkg.Backend {
|
||||
return f.client
|
||||
}
|
||||
|
||||
var _ = Describe("Prefix-cache aware routing", Label("Distributed"), func() {
|
||||
const model = "model"
|
||||
|
||||
var (
|
||||
infra *TestInfra
|
||||
db *gormDB.DB
|
||||
registry *nodes.NodeRegistry
|
||||
router *nodes.SmartRouter
|
||||
idx *prefixcache.Index
|
||||
|
||||
nodeXID string
|
||||
nodeYID string
|
||||
|
||||
chainA = []uint64{1, 2, 3, 4, 5} // conversation A
|
||||
chainShared = []uint64{1, 2, 3, 9, 9} // shares leading prefix [1,2,3] with A
|
||||
chainUnrelated = []uint64{7, 8, 9} // no shared prefix with A
|
||||
)
|
||||
|
||||
// routeAndSettle drives one request through the router for the given prefix
|
||||
// chain and immediately settles the in-flight reservation the way a real
|
||||
// inference completion would (Release closes the client; the DecrementInFlight
|
||||
// emulates the OnFirstComplete callback that fires after the first inference).
|
||||
// Settling keeps both nodes balanced at in_flight=0 so the prefix-cache load
|
||||
// guard never falsely forces a request off its warm node between steps.
|
||||
routeAndSettle := func(chain []uint64) string {
|
||||
GinkgoHelper()
|
||||
ctx := distributedhdr.WithPrefixChain(context.Background(), chain)
|
||||
result, err := router.Route(ctx, model, model, "llama-cpp",
|
||||
&pb.ModelOptions{ModelFile: model}, false)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(result).ToNot(BeNil())
|
||||
Expect(result.Node).ToNot(BeNil())
|
||||
nodeID := result.Node.ID
|
||||
result.Release()
|
||||
Expect(registry.DecrementInFlight(context.Background(), nodeID, model, 0)).To(Succeed())
|
||||
return nodeID
|
||||
}
|
||||
|
||||
BeforeEach(func() {
|
||||
infra = SetupInfra("localai_prefix_cache_routing_test")
|
||||
|
||||
var err error
|
||||
db, err = gormDB.Open(pgdriver.Open(infra.PGURL), &gormDB.Config{
|
||||
Logger: logger.Default.LogMode(logger.Silent),
|
||||
})
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
registry, err = nodes.NewNodeRegistry(db)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
// The prefix-cache index is the real radix-tree provider. Keep a handle so
|
||||
// the specs can assert Decide() directly in addition to observing Route().
|
||||
idx = prefixcache.NewIndex(prefixcache.DefaultConfig())
|
||||
|
||||
// Wire the registry chokepoint hook ourselves. In production distributed.go
|
||||
// wires this; a bare SmartRouter test must register it so removal-path
|
||||
// invalidation is exercised end to end. A negative replica index means
|
||||
// "all replicas of the node" (InvalidateNode); otherwise drop the exact
|
||||
// replica.
|
||||
registry.SetReplicaRemovedHook(func(modelName, nodeID string, replica int) {
|
||||
if replica < 0 {
|
||||
idx.InvalidateNode(modelName, nodeID)
|
||||
} else {
|
||||
idx.Invalidate(modelName, prefixcache.ReplicaKey{NodeID: nodeID, Replica: replica})
|
||||
}
|
||||
})
|
||||
|
||||
// Register TWO healthy nodes and mark the model loaded on both (replica 0).
|
||||
nodeX := &nodes.BackendNode{Name: "node-x", Address: "127.0.0.1:50051"}
|
||||
nodeY := &nodes.BackendNode{Name: "node-y", Address: "127.0.0.1:50052"}
|
||||
Expect(registry.Register(context.Background(), nodeX, true)).To(Succeed())
|
||||
Expect(registry.Register(context.Background(), nodeY, true)).To(Succeed())
|
||||
nodeXID = nodeX.ID
|
||||
nodeYID = nodeY.ID
|
||||
Expect(registry.SetNodeModel(context.Background(), nodeXID, model, 0, "loaded", "", 0)).To(Succeed())
|
||||
Expect(registry.SetNodeModel(context.Background(), nodeYID, model, 0, "loaded", "", 0)).To(Succeed())
|
||||
|
||||
factory := &prefixStubClientFactory{client: &prefixStubBackend{healthResult: true}}
|
||||
router = nodes.NewSmartRouter(registry, nodes.SmartRouterOptions{
|
||||
ClientFactory: factory,
|
||||
PrefixProvider: idx,
|
||||
PrefixConfig: prefixcache.DefaultConfig(),
|
||||
DB: db,
|
||||
})
|
||||
})
|
||||
|
||||
It("locks affinity, honors shared prefixes, isolates unrelated chains, and re-homes on failover", func() {
|
||||
now := time.Now()
|
||||
// Both nodes host replica 0 of the model.
|
||||
keys := []prefixcache.ReplicaKey{{NodeID: nodeXID, Replica: 0}, {NodeID: nodeYID, Replica: 0}}
|
||||
|
||||
// --- Step 1: cold miss + observe -------------------------------------
|
||||
// chainA's prefix has never been seen, so there is no hot match yet; the
|
||||
// request cold-places on some loaded node X and the assignment is recorded.
|
||||
Expect(idx.Decide(model, chainA, keys, now).HasHot).To(BeFalse(),
|
||||
"step 1: chainA must be a cold miss (no prior affinity)")
|
||||
placedNode := routeAndSettle(chainA)
|
||||
Expect(placedNode).To(Or(Equal(nodeXID), Equal(nodeYID)))
|
||||
// From here on, "X" is whichever node served chainA first.
|
||||
nodeX := placedNode
|
||||
var nodeY string
|
||||
if nodeX == nodeXID {
|
||||
nodeY = nodeYID
|
||||
} else {
|
||||
nodeY = nodeXID
|
||||
}
|
||||
hotX := prefixcache.ReplicaKey{NodeID: nodeX, Replica: 0}
|
||||
Expect(idx.Decide(model, chainA, keys, time.Now()).Hot).To(Equal(hotX),
|
||||
"step 1: chainA must now be recorded against the replica that served it")
|
||||
|
||||
// --- Step 2: hot-match affinity --------------------------------------
|
||||
// The SAME chain routes back to X.
|
||||
Expect(routeAndSettle(chainA)).To(Equal(nodeX),
|
||||
"step 2: a repeat of chainA must return to its warm node X")
|
||||
|
||||
// --- Step 3: shared-prefix match (the regression we fixed) -----------
|
||||
// A DIFFERENT chain that shares the leading prefix [1,2,3] with X's chain
|
||||
// but diverges at the tail still matches the shared head and routes to X.
|
||||
// Before the radix-tree fix this fell through to a cold placement.
|
||||
Expect(idx.Decide(model, chainShared, keys, time.Now()).Hot).To(Equal(hotX),
|
||||
"step 3: chainShared must hot-match X on the shared prefix")
|
||||
Expect(routeAndSettle(chainShared)).To(Equal(nodeX),
|
||||
"step 3: chainShared must route to X via the shared-prefix match")
|
||||
|
||||
// --- Step 4: negative control ----------------------------------------
|
||||
// A completely unrelated chain shares no prefix with X's chain, so it must
|
||||
// NOT hot-match X's affinity. (Cold placement may still pick X or Y by
|
||||
// load/cacheWeight, but it must not be a false hot match.) Asserting the
|
||||
// provider decision directly is the robust check.
|
||||
Expect(idx.Decide(model, chainUnrelated, keys, time.Now()).HasHot).To(BeFalse(),
|
||||
"step 4: chainUnrelated must be a cold miss, not a false hot match on X")
|
||||
|
||||
// --- Step 5: failover + invalidation ---------------------------------
|
||||
// Remove node X's replica of the model. This fires the registry chokepoint
|
||||
// hook, which invalidates the prefix-cache entry for X. A request for X's
|
||||
// chain must then fail over to the surviving node Y, and the prefix entry
|
||||
// must no longer pin to X (it re-homes to Y on the next observe).
|
||||
Expect(registry.RemoveAllNodeModelReplicas(context.Background(), nodeX, model)).To(Succeed())
|
||||
|
||||
yKeys := []prefixcache.ReplicaKey{{NodeID: nodeY, Replica: 0}}
|
||||
// The chokepoint hook dropped X from the index immediately.
|
||||
Expect(idx.Decide(model, chainA, yKeys, time.Now()).Hot).ToNot(Equal(hotX),
|
||||
"step 5: after X's replica is removed, chainA must no longer pin to X")
|
||||
|
||||
// Route(chainA): only Y still hosts the model, so it fails over to Y.
|
||||
Expect(routeAndSettle(chainA)).To(Equal(nodeY),
|
||||
"step 5: chainA must fail over to the surviving node Y")
|
||||
|
||||
// And the entry has re-homed: chainA now hot-matches Y, never X.
|
||||
reHomed := idx.Decide(model, chainA, yKeys, time.Now())
|
||||
hotY := prefixcache.ReplicaKey{NodeID: nodeY, Replica: 0}
|
||||
Expect(reHomed.Hot).ToNot(Equal(hotX),
|
||||
"step 5: chainA must not re-home to the removed node X")
|
||||
Expect(reHomed.Hot).To(Equal(hotY),
|
||||
"step 5: chainA must re-home to the surviving node Y")
|
||||
})
|
||||
|
||||
It("tracks affinity per replica when ONE node hosts TWO replicas of the model", func() {
|
||||
// This is the bug the replica-granular change fixes: two replicas of the
|
||||
// same model on the SAME node are distinct KV caches. A prefix observed
|
||||
// on replica (node,0) must NOT be reported as hot on the sibling replica
|
||||
// (node,1) of the same node.
|
||||
const multiNodeModel = "multi-replica-model"
|
||||
multiNode := &nodes.BackendNode{Name: "node-multi", Address: "127.0.0.1:50060", MaxReplicasPerModel: 2}
|
||||
Expect(registry.Register(context.Background(), multiNode, true)).To(Succeed())
|
||||
Expect(registry.SetNodeModel(context.Background(), multiNode.ID, multiNodeModel, 0, "loaded", "addr0", 0)).To(Succeed())
|
||||
Expect(registry.SetNodeModel(context.Background(), multiNode.ID, multiNodeModel, 1, "loaded", "addr1", 0)).To(Succeed())
|
||||
|
||||
chain := []uint64{42, 43, 44}
|
||||
key0 := prefixcache.ReplicaKey{NodeID: multiNode.ID, Replica: 0}
|
||||
key1 := prefixcache.ReplicaKey{NodeID: multiNode.ID, Replica: 1}
|
||||
|
||||
// Observe the chain on replica 0 only.
|
||||
idx.Observe(multiNodeModel, chain, key0, time.Now())
|
||||
|
||||
d := idx.Decide(multiNodeModel, chain, []prefixcache.ReplicaKey{key0, key1}, time.Now())
|
||||
Expect(d.HasHot).To(BeTrue())
|
||||
Expect(d.Hot).To(Equal(key0),
|
||||
"the prefix was served by replica 0; the SAME-node sibling replica 1 must NOT be chosen")
|
||||
})
|
||||
})
|
||||
Reference in New Issue
Block a user