diff --git a/core/application/distributed.go b/core/application/distributed.go
index a82ca1931..64e8dc12e 100644
--- a/core/application/distributed.go
+++ b/core/application/distributed.go
@@ -16,7 +16,9 @@ import (
"github.com/mudler/LocalAI/core/services/jobs"
"github.com/mudler/LocalAI/core/services/messaging"
"github.com/mudler/LocalAI/core/services/nodes"
+ "github.com/mudler/LocalAI/core/services/nodes/prefixcache"
"github.com/mudler/LocalAI/core/services/storage"
+ "github.com/mudler/LocalAI/pkg/distributedhdr"
"github.com/mudler/LocalAI/pkg/sanitize"
"github.com/mudler/xlog"
"gorm.io/gorm"
@@ -240,6 +242,84 @@ func initDistributed(cfg *config.ApplicationConfig, authDB *gorm.DB, configLoade
cfg.Distributed.BackendUpgradeTimeoutOrDefault(),
)
+ // Prefix-cache-aware routing. Enabled by default; an operator can opt out
+ // with --distributed-prefix-cache=false, which leaves prefixProvider and
+ // pressure nil so the SmartRouter and reconciler behave exactly as the
+ // round-robin floor (true no-op). When enabled we build the local index,
+ // wrap it in a NATS-backed Sync (publishes our observations, applies peers'
+ // via the subscriptions below), install the extraction hook used by
+ // core/backend/llm.go, and run a background eviction ticker on the app ctx.
+ var prefixProvider prefixcache.Provider
+ var pressure *prefixcache.Pressure
+ var prefixCfg prefixcache.Config
+ if !cfg.Distributed.PrefixCacheDisabled {
+ prefixCfg = prefixcache.DefaultConfig()
+ if cfg.Distributed.PrefixCacheTTL > 0 {
+ prefixCfg.TTL = cfg.Distributed.PrefixCacheTTL
+ }
+ if err := prefixCfg.Validate(); err != nil {
+ return nil, fmt.Errorf("invalid prefix-cache configuration: %w", err)
+ }
+ idx := prefixcache.NewIndex(prefixCfg)
+ prefixSync := prefixcache.NewSync(idx, natsClient)
+ pressure = prefixcache.NewPressure(prefixCfg.PressureWindow)
+ prefixProvider = prefixSync
+
+ // Invalidate the prefix-cache index whenever a replica row is removed.
+ // SetReplicaRemovedHook fires from the single chokepoint all removal paths
+ // funnel through (RemoveNodeModel / RemoveAllNodeModelReplicas), so this
+ // one hook covers every path: reconciler scale-down, probe reaper,
+ // health-monitor reap, RemoteUnloaderAdapter, and the router. Registering
+ // it only inside this enabled block keeps the disabled path a true no-op
+ // (the registry stays hook-less).
+ registry.SetReplicaRemovedHook(func(model, node string, replica int) {
+ if replica < 0 {
+ prefixSync.InvalidateNode(model, node)
+ } else {
+ prefixSync.Invalidate(model, prefixcache.ReplicaKey{NodeID: node, Replica: replica})
+ }
+ })
+
+ distributedhdr.PrefixChainHook = func(model, prompt string) []uint64 {
+ return prefixcache.ExtractChain(model, prompt, prefixCfg)
+ }
+
+ // Apply peers' observations/invalidations to the same Sync. ApplyObserve
+ // and ApplyInvalidate update only the local index and do not re-publish,
+ // so there is no broadcast loop.
+ if _, err := messaging.SubscribeJSON(natsClient, messaging.SubjectPrefixCacheObserve, func(ev messaging.PrefixCacheObserveEvent) {
+ prefixSync.ApplyObserve(ev, time.Now())
+ }); err != nil {
+ return nil, fmt.Errorf("subscribing to %s: %w", messaging.SubjectPrefixCacheObserve, err)
+ }
+ if _, err := messaging.SubscribeJSON(natsClient, messaging.SubjectPrefixCacheInvalidate, func(ev messaging.PrefixCacheInvalidateEvent) {
+ prefixSync.ApplyInvalidate(ev)
+ }); err != nil {
+ return nil, fmt.Errorf("subscribing to %s: %w", messaging.SubjectPrefixCacheInvalidate, err)
+ }
+
+		// Background eviction: sweep idle entries on the app context. Stopped
+		// when the app context is cancelled (mirrors the reconciler loop which
+		// also runs on options.Context). TTL/2 keeps stale entries from
+		// outliving their idle window by more than half a TTL.
+		//
+		// The interval is floored at minEvictInterval: time.NewTicker panics on
+		// a non-positive duration, and the integer division above turns a 1ns
+		// TTL into a zero interval. The floor also stops a very small TTL from
+		// turning the sweep into a busy loop.
+		const minEvictInterval = time.Second
+		evictInterval := prefixCfg.TTL / 2
+		if evictInterval < minEvictInterval {
+			evictInterval = minEvictInterval
+		}
+		go func() {
+			ticker := time.NewTicker(evictInterval)
+			defer ticker.Stop()
+			for {
+				select {
+				case <-cfg.Context.Done():
+					return
+				case <-ticker.C:
+					prefixSync.Evict(time.Now())
+				}
+			}
+		}()
+		xlog.Info("Prefix-cache-aware routing enabled", "ttl", prefixCfg.TTL, "evictInterval", evictInterval)
+ } else {
+ xlog.Info("Prefix-cache-aware routing disabled: using round-robin routing")
+ }
+
// All dependencies ready — build SmartRouter with all options at once
var conflictResolver nodes.ConcurrencyConflictResolver
if configLoader != nil {
@@ -252,6 +332,9 @@ func initDistributed(cfg *config.ApplicationConfig, authDB *gorm.DB, configLoade
AuthToken: routerAuthToken,
DB: authDB,
ConflictResolver: conflictResolver,
+ PrefixProvider: prefixProvider,
+ PrefixConfig: prefixCfg,
+ Pressure: pressure,
})
// Create ReplicaReconciler for auto-scaling model replicas. Adapter +
@@ -268,6 +351,8 @@ func initDistributed(cfg *config.ApplicationConfig, authDB *gorm.DB, configLoade
Interval: 30 * time.Second,
ScaleDownDelay: 5 * time.Minute,
ProbeStaleAfter: 2 * time.Minute,
+ Pressure: pressure,
+ PressureThreshold: prefixCfg.PressureScaleThreshold,
})
// Create ModelRouterAdapter to wire into ModelLoader
diff --git a/core/backend/llm.go b/core/backend/llm.go
index 053e984e8..4f6b4d216 100644
--- a/core/backend/llm.go
+++ b/core/backend/llm.go
@@ -19,6 +19,7 @@ import (
"github.com/mudler/LocalAI/core/trace"
"github.com/mudler/LocalAI/core/gallery"
+ "github.com/mudler/LocalAI/pkg/distributedhdr"
"github.com/mudler/LocalAI/pkg/grpc/proto"
model "github.com/mudler/LocalAI/pkg/model"
"github.com/mudler/LocalAI/pkg/utils"
@@ -94,6 +95,22 @@ func ModelInference(ctx context.Context, s string, messages schema.Messages, ima
}
}
+ // Make the rendered prompt's prefix chain available to the distributed router
+ // for prefix-cache-aware node selection. No-op in single-process mode. The
+ // model id MUST match the id ModelOptions feeds to model.WithModelID, so both
+ // use the shared config.ModelConfig.ModelID() helper (Name with a fallback to
+ // Model) or the chain salt and the tracking key would diverge.
+ //
+ // s is empty for UseTokenizerTemplate models (the backend tokenizes the
+ // structured messages itself), so fall back to a prefix-stable serialization
+ // of the messages - otherwise prefix routing would silently degrade to
+ // round-robin for the bulk of modern chat models.
+ chainSource := s
+ if chainSource == "" {
+ chainSource = messagesPrefixSource(messages)
+ }
+ ctx = distributedhdr.MaybeWithPrefixChain(ctx, c.ModelID(), chainSource)
+
opts := ModelOptions(*c, o, model.WithContext(ctx))
inferenceModel, err := loader.Load(opts...)
if err != nil {
diff --git a/core/backend/options.go b/core/backend/options.go
index 0215bf37a..c891b6d67 100644
--- a/core/backend/options.go
+++ b/core/backend/options.go
@@ -34,16 +34,11 @@ func recordModelLoadFailure(appConfig *config.ApplicationConfig, modelName, back
}
func ModelOptions(c config.ModelConfig, so *config.ApplicationConfig, opts ...model.Option) []model.Option {
- name := c.Name
- if name == "" {
- name = c.Model
- }
-
defOpts := []model.Option{
model.WithBackendString(c.Backend),
model.WithModel(c.Model),
model.WithContext(so.Context),
- model.WithModelID(name),
+ model.WithModelID(c.ModelID()),
}
threads := 1
diff --git a/core/backend/prefix_source.go b/core/backend/prefix_source.go
new file mode 100644
index 000000000..2623033e3
--- /dev/null
+++ b/core/backend/prefix_source.go
@@ -0,0 +1,36 @@
+package backend
+
+import (
+ "strings"
+
+ "github.com/mudler/LocalAI/core/schema"
+)
+
+// messagesPrefixSource flattens a chat conversation into a deterministic,
+// prefix-stable string for prefix-cache-aware routing.
+//
+// It is the fallback chain source for requests where the frontend rendered no
+// prompt string: models with config.TemplateConfig.UseTokenizerTemplate
+// tokenize the structured messages backend-side, so the rendered prompt is
+// empty and a chain built from it would always be empty too, silently
+// degrading prefix routing to round-robin.
+//
+// Each message contributes two newline-terminated lines, its role and then its
+// text, in turn order starting from the head of the conversation. Two
+// conversations that share a system prompt and early turns therefore share a
+// leading byte prefix, which ExtractChain hashes into a shared chain prefix so
+// both requests land on the same cache-warm replica.
+//
+// The text of a message is its StringContent; when that is empty and Content
+// holds a plain string, that string is used instead. Any other Content shape
+// yields an empty text line. A nil or empty conversation yields "".
+func messagesPrefixSource(messages schema.Messages) string {
+	var out strings.Builder
+	for i := range messages {
+		msg := &messages[i]
+
+		// Prefer the already-flattened text; only consult the raw Content
+		// when there is none and Content is itself a string.
+		text := msg.StringContent
+		if raw, isString := msg.Content.(string); text == "" && isString {
+			text = raw
+		}
+
+		out.WriteString(msg.Role)
+		out.WriteByte('\n')
+		out.WriteString(text)
+		out.WriteByte('\n')
+	}
+	return out.String()
+}
diff --git a/core/backend/prefix_source_internal_test.go b/core/backend/prefix_source_internal_test.go
new file mode 100644
index 000000000..0395b35e4
--- /dev/null
+++ b/core/backend/prefix_source_internal_test.go
@@ -0,0 +1,53 @@
+package backend
+
+import (
+ "strings"
+
+ "github.com/mudler/LocalAI/core/schema"
+
+ . "github.com/onsi/ginkgo/v2"
+ . "github.com/onsi/gomega"
+)
+
+// Specs for messagesPrefixSource: the serialization format, its determinism,
+// the shared-prefix property routing relies on, and how a message's text is
+// resolved from StringContent / Content.
+var _ = Describe("messagesPrefixSource", func() {
+	// mk builds a message whose text is carried in StringContent.
+	mk := func(role, content string) schema.Message {
+		return schema.Message{Role: role, StringContent: content}
+	}
+
+	It("serializes messages head-first in turn order", func() {
+		got := messagesPrefixSource(schema.Messages{
+			mk("system", "You are helpful."),
+			mk("user", "Hi"),
+		})
+		Expect(got).To(Equal("system\nYou are helpful.\nuser\nHi\n"))
+	})
+
+	It("is deterministic across calls for the same conversation", func() {
+		conv := schema.Messages{mk("system", "S"), mk("user", "U")}
+		Expect(messagesPrefixSource(conv)).To(Equal(messagesPrefixSource(conv)))
+	})
+
+	It("shares a leading byte prefix when the system prompt is shared", func() {
+		shared := "system\nShared system prompt.\nuser\n"
+		a := messagesPrefixSource(schema.Messages{mk("system", "Shared system prompt."), mk("user", "Question A")})
+		b := messagesPrefixSource(schema.Messages{mk("system", "Shared system prompt."), mk("user", "Question B")})
+		Expect(strings.HasPrefix(a, shared)).To(BeTrue())
+		Expect(strings.HasPrefix(b, shared)).To(BeTrue())
+	})
+
+	It("does NOT share a prefix when the system prompt differs", func() {
+		a := messagesPrefixSource(schema.Messages{mk("system", "Prompt A"), mk("user", "Q")})
+		b := messagesPrefixSource(schema.Messages{mk("system", "Prompt B"), mk("user", "Q")})
+		Expect(strings.HasPrefix(a, "system\nPrompt A")).To(BeTrue())
+		Expect(strings.HasPrefix(b, "system\nPrompt B")).To(BeTrue())
+	})
+
+	It("returns empty for no messages", func() {
+		Expect(messagesPrefixSource(nil)).To(Equal(""))
+	})
+
+	It("falls back to Content when StringContent is empty", func() {
+		got := messagesPrefixSource(schema.Messages{{Role: "user", Content: "plain"}})
+		Expect(got).To(Equal("user\nplain\n"))
+	})
+
+	It("prefers StringContent over a string Content when both are set", func() {
+		got := messagesPrefixSource(schema.Messages{{Role: "user", StringContent: "flattened", Content: "raw"}})
+		Expect(got).To(Equal("user\nflattened\n"))
+	})
+
+	It("emits an empty content line when Content is not a string", func() {
+		// Only a string Content is used as a fallback; any other shape is
+		// ignored rather than stringified.
+		got := messagesPrefixSource(schema.Messages{{Role: "user", Content: []string{"part"}}})
+		Expect(got).To(Equal("user\n\n"))
+	})
+})
diff --git a/core/cli/run.go b/core/cli/run.go
index 09a58971b..8bbc2b20c 100644
--- a/core/cli/run.go
+++ b/core/cli/run.go
@@ -145,19 +145,21 @@ type RunCMD struct {
DefaultAPIKeyExpiry string `env:"LOCALAI_DEFAULT_API_KEY_EXPIRY" help:"Default expiry for API keys (e.g. 90d, 1y; empty = no expiry)" group:"auth"`
// Distributed / Horizontal Scaling
- Distributed bool `env:"LOCALAI_DISTRIBUTED" default:"false" help:"Enable distributed mode (requires PostgreSQL + NATS)" group:"distributed"`
- InstanceID string `env:"LOCALAI_INSTANCE_ID" help:"Unique instance ID for distributed mode (auto-generated UUID if empty)" group:"distributed"`
- NatsURL string `env:"LOCALAI_NATS_URL" help:"NATS server URL (e.g., nats://localhost:4222)" group:"distributed"`
- StorageURL string `env:"LOCALAI_STORAGE_URL" help:"S3-compatible storage endpoint URL (e.g., http://minio:9000)" group:"distributed"`
- StorageBucket string `env:"LOCALAI_STORAGE_BUCKET" default:"localai" help:"S3 bucket name for object storage" group:"distributed"`
- StorageRegion string `env:"LOCALAI_STORAGE_REGION" default:"us-east-1" help:"S3 region" group:"distributed"`
- StorageAccessKey string `env:"LOCALAI_STORAGE_ACCESS_KEY" help:"S3 access key ID" group:"distributed"`
- StorageSecretKey string `env:"LOCALAI_STORAGE_SECRET_KEY" help:"S3 secret access key" group:"distributed"`
- RegistrationToken string `env:"LOCALAI_REGISTRATION_TOKEN" help:"Token that backend nodes must provide to register (empty = no auth required)" group:"distributed"`
- AutoApproveNodes bool `env:"LOCALAI_AUTO_APPROVE_NODES" default:"false" help:"Auto-approve new worker nodes (skip admin approval)" group:"distributed"`
- BackendInstallTimeout string `env:"LOCALAI_NATS_BACKEND_INSTALL_TIMEOUT" help:"NATS round-trip timeout for backend.install requests sent to worker nodes (default 15m). Increase for slow links pulling multi-GB images." group:"distributed"`
- BackendUpgradeTimeout string `env:"LOCALAI_NATS_BACKEND_UPGRADE_TIMEOUT" help:"NATS round-trip timeout for backend.upgrade requests (default 15m)." group:"distributed"`
- ExposeNodeHeader bool `env:"LOCALAI_EXPOSE_NODE_HEADER" default:"false" help:"Set the X-LocalAI-Node response header on inference responses (OpenAI chat/completions/embeddings, Anthropic /v1/messages, Ollama /api/chat,/api/generate,/api/embed) with the ID of the worker that served the request. Disabled by default: the node ID reveals internal topology and should not be exposed on a public endpoint. Best-effort: under heavy concurrency the header may reflect a recent routing decision rather than this exact request's." group:"distributed"`
+ Distributed bool `env:"LOCALAI_DISTRIBUTED" default:"false" help:"Enable distributed mode (requires PostgreSQL + NATS)" group:"distributed"`
+ InstanceID string `env:"LOCALAI_INSTANCE_ID" help:"Unique instance ID for distributed mode (auto-generated UUID if empty)" group:"distributed"`
+ NatsURL string `env:"LOCALAI_NATS_URL" help:"NATS server URL (e.g., nats://localhost:4222)" group:"distributed"`
+ StorageURL string `env:"LOCALAI_STORAGE_URL" help:"S3-compatible storage endpoint URL (e.g., http://minio:9000)" group:"distributed"`
+ StorageBucket string `env:"LOCALAI_STORAGE_BUCKET" default:"localai" help:"S3 bucket name for object storage" group:"distributed"`
+ StorageRegion string `env:"LOCALAI_STORAGE_REGION" default:"us-east-1" help:"S3 region" group:"distributed"`
+ StorageAccessKey string `env:"LOCALAI_STORAGE_ACCESS_KEY" help:"S3 access key ID" group:"distributed"`
+ StorageSecretKey string `env:"LOCALAI_STORAGE_SECRET_KEY" help:"S3 secret access key" group:"distributed"`
+ RegistrationToken string `env:"LOCALAI_REGISTRATION_TOKEN" help:"Token that backend nodes must provide to register (empty = no auth required)" group:"distributed"`
+ AutoApproveNodes bool `env:"LOCALAI_AUTO_APPROVE_NODES" default:"false" help:"Auto-approve new worker nodes (skip admin approval)" group:"distributed"`
+ DistributedPrefixCache bool `env:"LOCALAI_DISTRIBUTED_PREFIX_CACHE" default:"true" help:"Enable prefix-cache-aware routing in distributed mode (default true). When false, routing falls back to round-robin." group:"distributed"`
+ DistributedPrefixCacheTTL string `env:"LOCALAI_DISTRIBUTED_PREFIX_CACHE_TTL" help:"Idle-timeout for prefix-cache index entries; also drives the background eviction cadence (every TTL/2). Default 5m." group:"distributed"`
+ BackendInstallTimeout string `env:"LOCALAI_NATS_BACKEND_INSTALL_TIMEOUT" help:"NATS round-trip timeout for backend.install requests sent to worker nodes (default 15m). Increase for slow links pulling multi-GB images." group:"distributed"`
+ BackendUpgradeTimeout string `env:"LOCALAI_NATS_BACKEND_UPGRADE_TIMEOUT" help:"NATS round-trip timeout for backend.upgrade requests (default 15m)." group:"distributed"`
+ ExposeNodeHeader bool `env:"LOCALAI_EXPOSE_NODE_HEADER" default:"false" help:"Set the X-LocalAI-Node response header on inference responses (OpenAI chat/completions/embeddings, Anthropic /v1/messages, Ollama /api/chat,/api/generate,/api/embed) with the ID of the worker that served the request. Disabled by default: the node ID reveals internal topology and should not be exposed on a public endpoint. Best-effort: under heavy concurrency the header may reflect a recent routing decision rather than this exact request's." group:"distributed"`
Version bool
@@ -284,6 +286,16 @@ func (r *RunCMD) Run(ctx *cliContext.Context) error {
if r.AutoApproveNodes {
opts = append(opts, config.EnableAutoApproveNodes)
}
+ if !r.DistributedPrefixCache {
+ opts = append(opts, config.DisablePrefixCache)
+ }
+	if r.DistributedPrefixCacheTTL != "" {
+		d, err := time.ParseDuration(r.DistributedPrefixCacheTTL)
+		if err != nil {
+			return fmt.Errorf("invalid LOCALAI_DISTRIBUTED_PREFIX_CACHE_TTL %q: %w", r.DistributedPrefixCacheTTL, err)
+		}
+		// time.ParseDuration accepts signed values. A negative idle-timeout is
+		// meaningless, and passing it on would only have it ignored in favour
+		// of the default, so fail loudly here. Zero stays valid: it means
+		// "use the package default".
+		if d < 0 {
+			return fmt.Errorf("invalid LOCALAI_DISTRIBUTED_PREFIX_CACHE_TTL %q: must not be negative", r.DistributedPrefixCacheTTL)
+		}
+		opts = append(opts, config.WithPrefixCacheTTL(d))
+	}
if r.ExposeNodeHeader {
opts = append(opts, config.WithExposeNodeHeader(true))
}
diff --git a/core/config/distributed_config.go b/core/config/distributed_config.go
index 0ff783321..e0e0454d9 100644
--- a/core/config/distributed_config.go
+++ b/core/config/distributed_config.go
@@ -49,6 +49,17 @@ type DistributedConfig struct {
AgentWorkerConcurrency int `yaml:"agent_worker_concurrency" json:"agent_worker_concurrency" env:"LOCALAI_AGENT_WORKER_CONCURRENCY"`
JobWorkerConcurrency int `yaml:"job_worker_concurrency" json:"job_worker_concurrency" env:"LOCALAI_JOB_WORKER_CONCURRENCY"`
+
+ // PrefixCacheDisabled turns off prefix-cache-aware routing, falling back to
+ // round-robin (the floor). Prefix-cache routing is ON by default in
+ // distributed mode; this flag exists so operators can opt out. The CLI
+ // surfaces a default-true --distributed-prefix-cache enable flag and sets
+ // this when the operator passes --distributed-prefix-cache=false.
+ PrefixCacheDisabled bool
+ // PrefixCacheTTL is the idle-timeout for prefix-cache index entries and
+ // drives the background eviction cadence (eviction runs every TTL/2). Zero
+ // means use the prefixcache package default (5m).
+ PrefixCacheTTL time.Duration
}
// Validate checks that the distributed configuration is internally consistent.
@@ -158,6 +169,20 @@ var EnableAutoApproveNodes = func(o *ApplicationConfig) {
o.Distributed.AutoApproveNodes = true
}
+// DisablePrefixCache is an application option that turns off
+// prefix-cache-aware routing, so routing falls back to round-robin. It does so
+// by setting Distributed.PrefixCacheDisabled on the ApplicationConfig it is
+// applied to. Prefix-cache routing is enabled by default in distributed mode,
+// so this option is only needed to opt out.
+var DisablePrefixCache = func(o *ApplicationConfig) {
+ o.Distributed.PrefixCacheDisabled = true
+}
+
+// WithPrefixCacheTTL returns an application option that stores d as the
+// prefix-cache index idle-timeout (Distributed.PrefixCacheTTL). The same value
+// drives the background eviction cadence, which runs every TTL/2.
+func WithPrefixCacheTTL(d time.Duration) AppOption {
+	setTTL := func(appCfg *ApplicationConfig) {
+		appCfg.Distributed.PrefixCacheTTL = d
+	}
+	return setTTL
+}
+
// Flag names for distributed timeout / interval configuration. These are
// the kebab-case identifiers kong derives from the matching RunCMD struct
// fields; they appear in Validate error messages and any other operator-
diff --git a/core/config/model_config.go b/core/config/model_config.go
index d57544c6f..a1b000798 100644
--- a/core/config/model_config.go
+++ b/core/config/model_config.go
@@ -694,6 +694,18 @@ func (c *ModelConfig) IsModelURL() bool {
return uri.LooksLikeURL()
}
+// ModelID returns the identifier this model is referenced by across the
+// system: Name when it is set, otherwise Model (which may itself be empty).
+//
+// It is the single source of truth for the id handed to model.WithModelID and
+// for the prefix-cache chain salt. Both MUST agree with the router's tracking
+// key, or the prefix-cache salt diverges silently.
+func (c ModelConfig) ModelID() string {
+	if c.Name == "" {
+		return c.Model
+	}
+	return c.Name
+}
+
// ModelFileName returns the filename of the model
// If the model is a URL, it will return the MD5 of the URL which is the filename
func (c *ModelConfig) ModelFileName() string {
diff --git a/core/config/model_config_test.go b/core/config/model_config_test.go
index a93912e1b..5abfb8eaf 100644
--- a/core/config/model_config_test.go
+++ b/core/config/model_config_test.go
@@ -10,6 +10,23 @@ import (
)
var _ = Describe("Test cases for config related functions", func() {
+	Context("ModelID", func() {
+		// newConfig builds a ModelConfig with the given Name and Model. The
+		// fields are assigned one by one, as Model is not set through the
+		// composite literal.
+		newConfig := func(name, model string) ModelConfig {
+			var cfg ModelConfig
+			cfg.Name = name
+			cfg.Model = model
+			return cfg
+		}
+
+		It("returns Name when set", func() {
+			cfg := newConfig("my-name", "my-model")
+			Expect(cfg.ModelID()).To(Equal("my-name"))
+		})
+		It("falls back to Model when Name is empty", func() {
+			cfg := newConfig("", "my-model")
+			Expect(cfg.ModelID()).To(Equal("my-model"))
+		})
+		It("returns empty string when both are empty", func() {
+			cfg := newConfig("", "")
+			Expect(cfg.ModelID()).To(Equal(""))
+		})
+	})
+
Context("Test Read configuration functions", func() {
It("Test Validate", func() {
tmp, err := os.CreateTemp("", "config.yaml")
diff --git a/core/http/endpoints/localai/nodes.go b/core/http/endpoints/localai/nodes.go
index 83fd2170c..90dbe70b0 100644
--- a/core/http/endpoints/localai/nodes.go
+++ b/core/http/endpoints/localai/nodes.go
@@ -6,6 +6,7 @@ import (
"crypto/subtle"
"encoding/hex"
"encoding/json"
+ "errors"
"fmt"
"io"
"net/http"
@@ -25,6 +26,7 @@ import (
"github.com/mudler/LocalAI/core/schema"
"github.com/mudler/LocalAI/core/services/galleryop"
"github.com/mudler/LocalAI/core/services/nodes"
+ "github.com/mudler/LocalAI/core/services/nodes/prefixcache"
"github.com/mudler/LocalAI/pkg/httpclient"
)
@@ -911,14 +913,56 @@ func GetSchedulingEndpoint(registry *nodes.NodeRegistry) echo.HandlerFunc {
}
// SetSchedulingRequest is the request body for creating/updating a scheduling config.
+//
+// The four prefix-cache fields are POINTERS so an omitted field is
+// distinguishable from an explicit zero. On update, an omitted prefix-cache
+// field preserves the model's previously-configured value instead of resetting
+// it (see SetSchedulingEndpoint's PATCH-style merge). ModelName, NodeSelector,
+// MinReplicas and MaxReplicas keep their full-replace PUT semantics.
type SetSchedulingRequest struct {
- ModelName string `json:"model_name"`
- NodeSelector map[string]string `json:"node_selector,omitempty"`
- MinReplicas int `json:"min_replicas"`
- MaxReplicas int `json:"max_replicas"`
+ // ModelName identifies the model the config applies to; it is required.
+ ModelName string `json:"model_name"`
+ // NodeSelector is serialized to JSON by the handler before being stored.
+ NodeSelector map[string]string `json:"node_selector,omitempty"`
+ // MinReplicas must be >= 0 and, when MaxReplicas is > 0, <= MaxReplicas.
+ MinReplicas int `json:"min_replicas"`
+ // MaxReplicas must be >= 0; zero skips the min <= max comparison.
+ MaxReplicas int `json:"max_replicas"`
+ // RoutePolicy, BalanceAbsThreshold, BalanceRelThreshold and MinPrefixMatch
+ // are the prefix-cache routing settings. nil means "not provided": the
+ // stored value is kept (or the zero value when no config exists yet). The
+ // resolved values are bounds-checked by prefixcache.ValidateThresholds.
+ RoutePolicy *string `json:"route_policy,omitempty"`
+ BalanceAbsThreshold *int `json:"balance_abs_threshold,omitempty"`
+ BalanceRelThreshold *float64 `json:"balance_rel_threshold,omitempty"`
+ MinPrefixMatch *float64 `json:"min_prefix_match,omitempty"`
+}
+
+// validateSchedulingRequest checks a scheduling request against its invariants.
+// It returns nil when the request is valid, or an error carrying a user-facing
+// message for the first violation found.
+//
+// The model name and replica bounds are read from req. The prefix-cache
+// settings are NOT read from req: the caller passes the already-RESOLVED values
+// (request-provided, otherwise preserved from the stored config), so only bad
+// values the caller actually supplied are rejected. Their bounds are delegated
+// to prefixcache.ValidateThresholds, the single source of truth.
+func validateSchedulingRequest(req SetSchedulingRequest, routePolicy string, absThr int, relThr, minMatch float64) error {
+	// Cases are evaluated top to bottom, so the first failing check wins.
+	switch {
+	case req.ModelName == "":
+		return errors.New("model_name is required")
+	case req.MinReplicas < 0:
+		return errors.New("min_replicas must be >= 0")
+	case req.MaxReplicas < 0:
+		return errors.New("max_replicas must be >= 0")
+	case req.MaxReplicas > 0 && req.MinReplicas > req.MaxReplicas:
+		// A zero max_replicas skips the upper-bound comparison entirely.
+		return errors.New("min_replicas must be <= max_replicas")
+	}
+
+	if thresholdErr := prefixcache.ValidateThresholds(routePolicy, absThr, relThr, minMatch); thresholdErr != nil {
+		return thresholdErr
+	}
+	return nil
+}
// SetSchedulingEndpoint creates or updates a model scheduling config.
+//
+// The registry upsert full-replaces all columns, so a request that omits the
+// prefix-cache fields would otherwise wipe a model's previously-configured
+// routing settings. To avoid that footgun the four prefix-cache fields are
+// merged PATCH-style: a non-nil request pointer wins; a nil one preserves the
+// existing config's value (or the zero default when no config exists yet). The
+// non-prefix fields keep their full-replace PUT behavior.
func SetSchedulingEndpoint(registry *nodes.NodeRegistry) echo.HandlerFunc {
return func(c echo.Context) error {
ctx := c.Request().Context()
@@ -926,17 +970,45 @@ func SetSchedulingEndpoint(registry *nodes.NodeRegistry) echo.HandlerFunc {
if err := c.Bind(&req); err != nil {
return c.JSON(http.StatusBadRequest, nodeError(http.StatusBadRequest, "invalid request body"))
}
- if req.ModelName == "" {
- return c.JSON(http.StatusBadRequest, nodeError(http.StatusBadRequest, "model_name is required"))
+
+ // Fetch the existing config (may be nil) so omitted prefix-cache fields
+ // can fall back to the stored value rather than resetting to zero.
+ var existing *nodes.ModelSchedulingConfig
+ if req.ModelName != "" {
+ var err error
+ existing, err = registry.GetModelScheduling(ctx, req.ModelName)
+ if err != nil {
+ return c.JSON(http.StatusInternalServerError, nodeError(http.StatusInternalServerError, "failed to load existing scheduling config"))
+ }
}
- if req.MinReplicas < 0 {
- return c.JSON(http.StatusBadRequest, nodeError(http.StatusBadRequest, "min_replicas must be >= 0"))
+
+ // Resolve each prefix-cache field: provided pointer wins, otherwise keep
+ // the existing value (zero/default when there is no existing config).
+ routePolicy := ""
+ absThr := 0
+ relThr := 0.0
+ minMatch := 0.0
+ if existing != nil {
+ routePolicy = existing.RoutePolicy
+ absThr = existing.BalanceAbsThreshold
+ relThr = existing.BalanceRelThreshold
+ minMatch = existing.MinPrefixMatch
}
- if req.MaxReplicas < 0 {
- return c.JSON(http.StatusBadRequest, nodeError(http.StatusBadRequest, "max_replicas must be >= 0"))
+ if req.RoutePolicy != nil {
+ routePolicy = *req.RoutePolicy
}
- if req.MaxReplicas > 0 && req.MinReplicas > req.MaxReplicas {
- return c.JSON(http.StatusBadRequest, nodeError(http.StatusBadRequest, "min_replicas must be <= max_replicas"))
+ if req.BalanceAbsThreshold != nil {
+ absThr = *req.BalanceAbsThreshold
+ }
+ if req.BalanceRelThreshold != nil {
+ relThr = *req.BalanceRelThreshold
+ }
+ if req.MinPrefixMatch != nil {
+ minMatch = *req.MinPrefixMatch
+ }
+
+ if err := validateSchedulingRequest(req, routePolicy, absThr, relThr, minMatch); err != nil {
+ return c.JSON(http.StatusBadRequest, nodeError(http.StatusBadRequest, err.Error()))
}
// Serialize node selector to JSON
@@ -950,10 +1022,14 @@ func SetSchedulingEndpoint(registry *nodes.NodeRegistry) echo.HandlerFunc {
}
config := &nodes.ModelSchedulingConfig{
- ModelName: req.ModelName,
- NodeSelector: selectorJSON,
- MinReplicas: req.MinReplicas,
- MaxReplicas: req.MaxReplicas,
+ ModelName: req.ModelName,
+ NodeSelector: selectorJSON,
+ MinReplicas: req.MinReplicas,
+ MaxReplicas: req.MaxReplicas,
+ RoutePolicy: routePolicy,
+ BalanceAbsThreshold: absThr,
+ BalanceRelThreshold: relThr,
+ MinPrefixMatch: minMatch,
}
if err := registry.SetModelScheduling(ctx, config); err != nil {
return c.JSON(http.StatusInternalServerError, nodeError(http.StatusInternalServerError, "failed to set scheduling config"))
diff --git a/core/http/endpoints/localai/nodes_scheduling_validation_test.go b/core/http/endpoints/localai/nodes_scheduling_validation_test.go
new file mode 100644
index 000000000..800a14d7a
--- /dev/null
+++ b/core/http/endpoints/localai/nodes_scheduling_validation_test.go
@@ -0,0 +1,66 @@
+package localai
+
+import (
+ . "github.com/onsi/ginkgo/v2"
+ . "github.com/onsi/gomega"
+)
+
+// Specs for validateSchedulingRequest: the request-level invariants (model
+// name, replica bounds) and the prefix-cache bounds applied to the resolved
+// values passed alongside the request.
+var _ = Describe("validateSchedulingRequest", func() {
+	// base returns the smallest request that passes the request-level checks.
+	base := func() SetSchedulingRequest {
+		return SetSchedulingRequest{ModelName: "m"}
+	}
+
+	It("accepts an empty route policy (inherit) with valid thresholds", func() {
+		Expect(validateSchedulingRequest(base(), "", 3, 0, 0.4)).To(Succeed())
+	})
+
+	It("accepts the prefix_cache policy", func() {
+		Expect(validateSchedulingRequest(base(), "prefix_cache", 0, 0, 0)).To(Succeed())
+	})
+
+	It("accepts the round_robin policy", func() {
+		Expect(validateSchedulingRequest(base(), "round_robin", 0, 0, 0)).To(Succeed())
+	})
+
+	It("accepts balance_rel_threshold >= 1", func() {
+		Expect(validateSchedulingRequest(base(), "", 0, 1.5, 0)).To(Succeed())
+	})
+
+	It("accepts min_replicas above a zero max_replicas (no upper bound)", func() {
+		req := base()
+		req.MinReplicas = 2
+		req.MaxReplicas = 0
+		Expect(validateSchedulingRequest(req, "prefix_cache", 0, 0, 0)).To(Succeed())
+	})
+
+	It("rejects a missing model_name", func() {
+		req := base()
+		req.ModelName = ""
+		err := validateSchedulingRequest(req, "", 0, 0, 0)
+		Expect(err).To(HaveOccurred())
+		Expect(err.Error()).To(ContainSubstring("model_name is required"))
+	})
+
+	It("rejects a negative min_replicas", func() {
+		req := base()
+		req.MinReplicas = -1
+		err := validateSchedulingRequest(req, "", 0, 0, 0)
+		Expect(err).To(HaveOccurred())
+		Expect(err.Error()).To(ContainSubstring("min_replicas must be >= 0"))
+	})
+
+	It("rejects a negative max_replicas", func() {
+		req := base()
+		req.MaxReplicas = -1
+		err := validateSchedulingRequest(req, "", 0, 0, 0)
+		Expect(err).To(HaveOccurred())
+		Expect(err.Error()).To(ContainSubstring("max_replicas must be >= 0"))
+	})
+
+	It("rejects min_replicas above a positive max_replicas", func() {
+		req := base()
+		req.MinReplicas = 3
+		req.MaxReplicas = 2
+		err := validateSchedulingRequest(req, "", 0, 0, 0)
+		Expect(err).To(HaveOccurred())
+		Expect(err.Error()).To(ContainSubstring("min_replicas must be <= max_replicas"))
+	})
+
+	It("rejects an unknown route_policy (no silent default)", func() {
+		err := validateSchedulingRequest(base(), "bogus", 0, 0, 0)
+		Expect(err).To(HaveOccurred())
+		Expect(err.Error()).To(ContainSubstring("route_policy"))
+	})
+
+	It("rejects min_prefix_match above 1", func() {
+		err := validateSchedulingRequest(base(), "", 0, 0, 2)
+		Expect(err).To(HaveOccurred())
+		Expect(err.Error()).To(ContainSubstring("min_prefix_match"))
+	})
+
+	It("rejects a negative min_prefix_match", func() {
+		err := validateSchedulingRequest(base(), "", 0, 0, -0.1)
+		Expect(err).To(HaveOccurred())
+		Expect(err.Error()).To(ContainSubstring("min_prefix_match"))
+	})
+
+	It("rejects a negative balance_abs_threshold", func() {
+		err := validateSchedulingRequest(base(), "", -1, 0, 0)
+		Expect(err).To(HaveOccurred())
+		Expect(err.Error()).To(ContainSubstring("balance_abs_threshold"))
+	})
+
+	It("rejects balance_rel_threshold between 0 and 1 exclusive", func() {
+		err := validateSchedulingRequest(base(), "", 0, 0.5, 0)
+		Expect(err).To(HaveOccurred())
+		Expect(err.Error()).To(ContainSubstring("balance_rel_threshold"))
+	})
+})
diff --git a/core/http/endpoints/localai/nodes_test.go b/core/http/endpoints/localai/nodes_test.go
index f9eb8071b..fdb29987d 100644
--- a/core/http/endpoints/localai/nodes_test.go
+++ b/core/http/endpoints/localai/nodes_test.go
@@ -230,6 +230,114 @@ var _ = Describe("Node HTTP handlers", func() {
})
})
+	Describe("SetSchedulingEndpoint", func() {
+		// postScheduling POSTs body as JSON through SetSchedulingEndpoint and
+		// returns the recorder so each spec can inspect status and payload.
+		postScheduling := func(body string) *httptest.ResponseRecorder {
+			recorder := httptest.NewRecorder()
+			request := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(body))
+			request.Header.Set(echo.HeaderContentType, echo.MIMEApplicationJSON)
+			echoCtx := echo.New().NewContext(request, recorder)
+			Expect(SetSchedulingEndpoint(registry)(echoCtx)).To(Succeed())
+			return recorder
+		}
+
+		// expectRejected asserts that body is refused with a 400 whose
+		// error.message mentions the offending field.
+		expectRejected := func(body, field string) {
+			recorder := postScheduling(body)
+			Expect(recorder.Code).To(Equal(http.StatusBadRequest))
+			var payload map[string]any
+			Expect(json.Unmarshal(recorder.Body.Bytes(), &payload)).To(Succeed())
+			errObj, isObject := payload["error"].(map[string]any)
+			Expect(isObject).To(BeTrue())
+			Expect(errObj["message"]).To(ContainSubstring(field))
+		}
+
+		It("persists prefix-cache fields and round-trips them via GET", func() {
+			created := postScheduling(`{"model_name":"pc-model","route_policy":"prefix_cache","balance_abs_threshold":3,"min_prefix_match":0.4}`)
+			Expect(created.Code).To(Equal(http.StatusOK))
+
+			// First check what the registry stored...
+			stored, err := registry.GetModelScheduling(context.Background(), "pc-model")
+			Expect(err).ToNot(HaveOccurred())
+			Expect(stored).ToNot(BeNil())
+			Expect(stored.RoutePolicy).To(Equal("prefix_cache"))
+			Expect(stored.BalanceAbsThreshold).To(Equal(3))
+			Expect(stored.MinPrefixMatch).To(BeNumerically("~", 0.4, 1e-9))
+
+			// ...then check what the GET endpoint serialises back.
+			getRec := httptest.NewRecorder()
+			getCtx := echo.New().NewContext(httptest.NewRequest(http.MethodGet, "/", nil), getRec)
+			getCtx.SetParamNames("model")
+			getCtx.SetParamValues("pc-model")
+			Expect(GetSchedulingEndpoint(registry)(getCtx)).To(Succeed())
+			Expect(getRec.Code).To(Equal(http.StatusOK))
+
+			var fetched nodes.ModelSchedulingConfig
+			Expect(json.Unmarshal(getRec.Body.Bytes(), &fetched)).To(Succeed())
+			Expect(fetched.RoutePolicy).To(Equal("prefix_cache"))
+			Expect(fetched.BalanceAbsThreshold).To(Equal(3))
+			Expect(fetched.MinPrefixMatch).To(BeNumerically("~", 0.4, 1e-9))
+		})
+
+		It("returns 400 for an out-of-range min_prefix_match", func() {
+			expectRejected(`{"model_name":"bad-mpm","min_prefix_match":2}`, "min_prefix_match")
+		})
+
+		It("returns 400 for an unknown route_policy", func() {
+			expectRejected(`{"model_name":"bad-policy","route_policy":"bogus"}`, "route_policy")
+		})
+
+		It("returns 400 for a balance_rel_threshold between 0 and 1", func() {
+			expectRejected(`{"model_name":"bad-rel","balance_rel_threshold":0.5}`, "balance_rel_threshold")
+		})
+
+		// Regression guard: a POST carrying only min/max replicas once replaced
+		// every column and silently wiped the prefix-cache settings. Fields that
+		// are omitted from the body must survive the merge.
+		It("preserves prefix-cache settings across a min_replicas-only update", func() {
+			first := postScheduling(`{"model_name":"merge-model","route_policy":"prefix_cache","min_prefix_match":0.4}`)
+			Expect(first.Code).To(Equal(http.StatusOK))
+
+			// Second request touches min_replicas only; no prefix-cache fields.
+			second := postScheduling(`{"model_name":"merge-model","min_replicas":2}`)
+			Expect(second.Code).To(Equal(http.StatusOK))
+
+			stored, err := registry.GetModelScheduling(context.Background(), "merge-model")
+			Expect(err).ToNot(HaveOccurred())
+			Expect(stored).ToNot(BeNil())
+			Expect(stored.MinReplicas).To(Equal(2), "the provided non-prefix field must update")
+			Expect(stored.RoutePolicy).To(Equal("prefix_cache"), "omitted route_policy must be preserved")
+			Expect(stored.MinPrefixMatch).To(BeNumerically("~", 0.4, 1e-9), "omitted min_prefix_match must be preserved")
+		})
+
+		It("updates a prefix-cache field when it is explicitly provided", func() {
+			first := postScheduling(`{"model_name":"update-model","route_policy":"prefix_cache","min_prefix_match":0.4}`)
+			Expect(first.Code).To(Equal(http.StatusOK))
+
+			second := postScheduling(`{"model_name":"update-model","route_policy":"round_robin"}`)
+			Expect(second.Code).To(Equal(http.StatusOK))
+
+			stored, err := registry.GetModelScheduling(context.Background(), "update-model")
+			Expect(err).ToNot(HaveOccurred())
+			Expect(stored).ToNot(BeNil())
+			Expect(stored.RoutePolicy).To(Equal("round_robin"), "explicitly provided route_policy must update")
+			Expect(stored.MinPrefixMatch).To(BeNumerically("~", 0.4, 1e-9), "omitted min_prefix_match must still be preserved")
+		})
+	})
+
Describe("ListNodesEndpoint", func() {
It("returns an empty list when no nodes are registered", func() {
e := echo.New()
diff --git a/core/http/react-ui/src/pages/Nodes.jsx b/core/http/react-ui/src/pages/Nodes.jsx
index 4acfb5ebd..f2eb9d955 100644
--- a/core/http/react-ui/src/pages/Nodes.jsx
+++ b/core/http/react-ui/src/pages/Nodes.jsx
@@ -493,6 +493,13 @@ function SchedulingForm({ onSave, onCancel }) {
const [selector, setSelector] = useState({})
const [minReplicas, setMinReplicas] = useState(1)
const [maxReplicas, setMaxReplicas] = useState(0)
+ // Prefix-cache routing controls. An empty routePolicy means "inherit the
+ // cluster default"; the three thresholds likewise inherit when left at 0,
+ // so they only act as per-model overrides when set to a non-zero value.
+ const [routePolicy, setRoutePolicy] = useState('')
+ const [balanceAbsThreshold, setBalanceAbsThreshold] = useState(0)
+ const [balanceRelThreshold, setBalanceRelThreshold] = useState(0)
+ const [minPrefixMatch, setMinPrefixMatch] = useState(0)
const hasSelector = Object.keys(selector).length > 0
@@ -508,6 +515,10 @@ function SchedulingForm({ onSave, onCancel }) {
node_selector: hasSelector ? selector : undefined,
min_replicas: mode === 'placement' ? 0 : minReplicas,
max_replicas: mode === 'placement' ? 0 : maxReplicas,
+ route_policy: routePolicy,
+ balance_abs_threshold: balanceAbsThreshold,
+ balance_rel_threshold: balanceRelThreshold,
+ min_prefix_match: minPrefixMatch,
})
}
@@ -593,6 +604,76 @@ function SchedulingForm({ onSave, onCancel }) {
/>
)}
+
+ {/* Per-model routing policy. Left empty/zero these inherit the
+ cluster-wide defaults; set them to override how requests for this
+ model are spread across replicas. */}
+
+
+
+
+ Prefix Cache routes shared-prefix requests to the same replica to reuse its KV cache, falling back to round-robin when replicas are imbalanced.
+
+
+
+ {routePolicy === 'prefix_cache' && (
+
+
+
+ setMinPrefixMatch(parseFloat(e.target.value) || 0)}
+ />
+
+ Fraction of the prompt (0..1) that must match a cached prefix before affinity kicks in. 0 inherits the default.
+
+
+
+
+ setBalanceAbsThreshold(parseInt(e.target.value) || 0)}
+ />
+
+ Max absolute in-flight gap allowed before falling back to round-robin. 0 inherits the default.
+
+
+
+
+ setBalanceRelThreshold(parseFloat(e.target.value) || 0)}
+ />
+
+ Max relative in-flight ratio (>= 1) allowed before falling back to round-robin. 0 inherits the default.
+
+
+
+ )}
{/* Hairline divider above the actions, matching the project's form pattern. */}
@@ -1475,6 +1556,8 @@ export default function Nodes() {
Node Selector
Min Replicas
Max Replicas
+
Routing
+
Thresholds
Status
Actions
@@ -1519,6 +1602,18 @@ export default function Nodes() {
{isUnsatisfiable ? (
+// When Replica >= 0 it targets the single replica (Model, NodeID, Replica). When
+// Replica < 0 it targets ALL replicas of (Model, NodeID), for example when a
+// whole node goes offline.
+type PrefixCacheInvalidateEvent struct {
+ Model string `json:"model"` // model whose prefix-cache entries are invalidated
+ NodeID string `json:"node_id"` // node that owns the targeted replica(s)
+ Replica int `json:"replica"` // >= 0: that one replica; < 0: every replica of NodeID
+}
diff --git a/core/services/messaging/subjects_prefixcache_test.go b/core/services/messaging/subjects_prefixcache_test.go
new file mode 100644
index 000000000..2b8eb5771
--- /dev/null
+++ b/core/services/messaging/subjects_prefixcache_test.go
@@ -0,0 +1,27 @@
+package messaging_test
+
+import (
+ . "github.com/onsi/ginkgo/v2"
+ . "github.com/onsi/gomega"
+
+ "github.com/mudler/LocalAI/core/services/messaging"
+)
+
+var _ = Describe("PrefixCache subjects", func() {
+	// The subject strings are wire-level names; pin them so a rename is noticed.
+	It("exposes stable subject constants", func() {
+		observe, invalidate := messaging.SubjectPrefixCacheObserve, messaging.SubjectPrefixCacheInvalidate
+		Expect(observe).To(Equal("prefixcache.observe"))
+		Expect(invalidate).To(Equal("prefixcache.invalidate"))
+	})
+
+	It("carries a replica index on the observe event", func() {
+		observed := messaging.PrefixCacheObserveEvent{
+			Model:   "m",
+			Chain:   []uint64{1, 2},
+			NodeID:  "A",
+			Replica: 3,
+		}
+		Expect(observed.Replica).To(Equal(3))
+	})
+
+	It("uses a negative replica on the invalidate event to mean all replicas of a node", func() {
+		wholeNode := messaging.PrefixCacheInvalidateEvent{Model: "m", NodeID: "A", Replica: -1}
+		Expect(wholeNode.Replica).To(BeNumerically("<", 0))
+
+		singleReplica := messaging.PrefixCacheInvalidateEvent{Model: "m", NodeID: "A", Replica: 0}
+		Expect(singleReplica.Replica).To(Equal(0))
+	})
+})
diff --git a/core/services/nodes/interfaces.go b/core/services/nodes/interfaces.go
index c1dcd2af9..4e82d56cf 100644
--- a/core/services/nodes/interfaces.go
+++ b/core/services/nodes/interfaces.go
@@ -9,7 +9,7 @@ import (
// ModelRouter is used by SmartRouter for routing decisions and model lifecycle.
type ModelRouter interface {
- FindAndLockNodeWithModel(ctx context.Context, modelName string, candidateNodeIDs []string) (*BackendNode, *NodeModel, error)
+ FindAndLockNodeWithModel(ctx context.Context, modelName string, candidateNodeIDs []string, pref *RoutePreference) (*BackendNode, *NodeModel, error)
DecrementInFlight(ctx context.Context, nodeID, modelName string, replicaIndex int) error
IncrementInFlight(ctx context.Context, nodeID, modelName string, replicaIndex int) error
RemoveNodeModel(ctx context.Context, nodeID, modelName string, replicaIndex int) error
@@ -37,6 +37,7 @@ type ModelRouter interface {
FindLeastLoadedNodeFromSet(ctx context.Context, nodeIDs []string) (*BackendNode, error)
GetNodeLabels(ctx context.Context, nodeID string) ([]NodeLabel, error)
FindNodesWithModel(ctx context.Context, modelName string) ([]BackendNode, error)
+ LoadedReplicaStats(ctx context.Context, modelName string, candidateNodeIDs []string) ([]ReplicaCandidate, error)
}
// ConcurrencyConflictResolver returns the names of configured models that
diff --git a/core/services/nodes/model_router_test.go b/core/services/nodes/model_router_test.go
index 190dac731..1670808c7 100644
--- a/core/services/nodes/model_router_test.go
+++ b/core/services/nodes/model_router_test.go
@@ -27,7 +27,7 @@ func newFakeModelRouterForSmartRouter() *fakeModelRouterForSmartRouter {
}
}
-func (f *fakeModelRouterForSmartRouter) FindAndLockNodeWithModel(_ context.Context, _ string, _ []string) (*BackendNode, *NodeModel, error) {
+func (f *fakeModelRouterForSmartRouter) FindAndLockNodeWithModel(_ context.Context, _ string, _ []string, _ *RoutePreference) (*BackendNode, *NodeModel, error) {
f.mu.Lock()
defer f.mu.Unlock()
return f.node, f.nodeModel, f.findErr
@@ -121,6 +121,9 @@ func (f *fakeModelRouterForSmartRouter) GetNodeLabels(_ context.Context, _ strin
func (f *fakeModelRouterForSmartRouter) FindNodesWithModel(_ context.Context, _ string) ([]BackendNode, error) {
return nil, nil
}
+func (f *fakeModelRouterForSmartRouter) LoadedReplicaStats(_ context.Context, _ string, _ []string) ([]ReplicaCandidate, error) {
+ return nil, nil // stub: reports no replicas and no error so the fake keeps satisfying ModelRouter
+}
// Compile-time check
var _ ModelRouter = (*fakeModelRouterForSmartRouter)(nil)
diff --git a/core/services/nodes/prefixcache/config.go b/core/services/nodes/prefixcache/config.go
new file mode 100644
index 000000000..dc897f0b0
--- /dev/null
+++ b/core/services/nodes/prefixcache/config.go
@@ -0,0 +1,95 @@
+package prefixcache
+
+import (
+ "fmt"
+ "time"
+)
+
+// Config holds prefix-cache-aware routing settings. Per-model overrides
+// (policy, abs/rel thresholds, min-match) live on ModelSchedulingConfig; TTL
+// and window/depth are global-only.
+type Config struct {
+ GlobalPolicy RoutePolicy // cluster-wide routing policy; presumably applied when a model sets no route_policy override - confirm in the router
+ MinPrefixMatch float64 // ratio matched/total, [0,1]
+ BalanceAbsThreshold int // absolute in-flight slack
+ BalanceRelThreshold float64 // relative load ratio, >= 1
+ TTL time.Duration // idle-timeout for entries
+ HalfLife time.Duration // recency decay for cacheWeight
+ WindowBytes int // chunk window size
+ MaxDepth int // max trailing blocks hashed
+ // PressureWindow is the rolling window over which forced-disturb events are
+ // counted for the autoscale signal (see Pressure). Default 1 minute.
+ PressureWindow time.Duration
+ // PressureScaleThreshold is the minimum forced-disturb count within
+ // PressureWindow that makes the reconciler treat the cache-warm replica as
+ // saturated and scale up (subject to MaxReplicas and capacity). Default 1,
+ // i.e. any sustained forced-disturb.
+ PressureScaleThreshold int
+}
+
+func DefaultConfig() Config {
+ return Config{
+ GlobalPolicy: RoutePolicyPrefixCache,
+ MinPrefixMatch: 0.3,
+ BalanceAbsThreshold: 2,
+ BalanceRelThreshold: 1.5,
+ TTL: 5 * time.Minute,
+ HalfLife: 2 * time.Minute,
+ WindowBytes: 256,
+ MaxDepth: 64,
+ PressureWindow: time.Minute,
+ PressureScaleThreshold: 1,
+ }
+}
+
+// validateThresholdBounds enforces the numeric bounds shared between the
+// per-model override validator (ValidateThresholds) and Config.Validate, so
+// the endpoint and the global config cannot drift apart:
+//
+//   - minMatch must lie in [0,1];
+//   - absThr must be >= 0;
+//   - relThr must be 0 (inherit) or >= 1.
+//
+// NaN is rejected for both float arguments. It returns nil when every bound
+// holds, otherwise an error naming the offending field and value.
+func validateThresholdBounds(absThr int, relThr, minMatch float64) error {
+	// Phrased as a negated in-range test on purpose: every ordered comparison
+	// with NaN is false, so "x < 0 || x > 1" would wave NaN through as valid.
+	if !(minMatch >= 0 && minMatch <= 1) {
+		return fmt.Errorf("prefixcache: min_prefix_match must be in [0,1], got %v", minMatch)
+	}
+	if absThr < 0 {
+		return fmt.Errorf("prefixcache: balance_abs_threshold must be >= 0, got %d", absThr)
+	}
+	// Same NaN reasoning: NaN != 0 is true and NaN < 1 is false, so the floor
+	// is asserted positively and negated.
+	if relThr != 0 && !(relThr >= 1) {
+		return fmt.Errorf("prefixcache: balance_rel_threshold must be 0 (inherit) or >= 1, got %v", relThr)
+	}
+	return nil
+}
+
+// ValidateThresholds vets a per-model override. The route policy is checked
+// against an explicit allow-list ("", "round_robin", "prefix_cache") rather
+// than ParsePolicy, which maps unknown values to the default and would let
+// typos through. The numeric arguments are then held to the shared bounds:
+// minMatch in [0,1]; absThr >= 0; relThr == 0 (inherit) or >= 1.
+func ValidateThresholds(routePolicy string, absThr int, relThr, minMatch float64) error {
+	allowed := routePolicy == "" || routePolicy == "round_robin" || routePolicy == "prefix_cache"
+	if !allowed {
+		return fmt.Errorf(`prefixcache: route_policy must be one of "", "round_robin", "prefix_cache", got %q`, routePolicy)
+	}
+	return validateThresholdBounds(absThr, relThr, minMatch)
+}
+
+// Validate reports the first setting in c that is out of range, or nil when
+// the configuration is usable. It checks the shared threshold bounds, the
+// global relative-threshold floor, the chunking geometry and the TTL.
+func (c Config) Validate() error {
+	// Config.BalanceRelThreshold has no "inherit" sentinel - it is a concrete
+	// global value that must be >= 1 - so pass 0 for relThr to the shared
+	// numeric check and assert the >= 1 floor here separately.
+	if err := validateThresholdBounds(c.BalanceAbsThreshold, 0, c.MinPrefixMatch); err != nil {
+		return err
+	}
+	// Negated ">= 1" instead of "< 1": every ordered comparison with NaN is
+	// false, so the "< 1" form would accept a NaN threshold as valid.
+	if !(c.BalanceRelThreshold >= 1) {
+		return fmt.Errorf("prefixcache: balance_rel_threshold must be >= 1, got %v", c.BalanceRelThreshold)
+	}
+	if c.WindowBytes <= 0 || c.MaxDepth <= 0 {
+		return fmt.Errorf("prefixcache: window_bytes and max_depth must be > 0")
+	}
+	// TTL must be positive: it is the entry idle-lifetime and the eviction
+	// ticker runs at TTL/2, so time.NewTicker would panic on TTL <= 0.
+	if c.TTL <= 0 {
+		return fmt.Errorf("prefixcache: ttl must be > 0, got %v", c.TTL)
+	}
+	return nil
+}
diff --git a/core/services/nodes/prefixcache/config_test.go b/core/services/nodes/prefixcache/config_test.go
new file mode 100644
index 000000000..07ed0466f
--- /dev/null
+++ b/core/services/nodes/prefixcache/config_test.go
@@ -0,0 +1,73 @@
+package prefixcache_test
+
+import (
+ "time"
+
+ . "github.com/onsi/ginkgo/v2"
+ . "github.com/onsi/gomega"
+
+ "github.com/mudler/LocalAI/core/services/nodes/prefixcache"
+)
+
+var _ = Describe("Config", func() {
+	It("supplies defaults", func() {
+		defaults := prefixcache.DefaultConfig()
+		Expect(defaults.GlobalPolicy).To(Equal(prefixcache.RoutePolicyPrefixCache)) // prefix-cache routing is on by default
+		Expect(defaults.MinPrefixMatch).To(BeNumerically("==", 0.3))
+		Expect(defaults.BalanceAbsThreshold).To(Equal(2))
+		Expect(defaults.BalanceRelThreshold).To(BeNumerically("==", 1.5))
+		Expect(defaults.TTL).To(Equal(5 * time.Minute))
+		Expect(defaults.WindowBytes).To(Equal(256))
+		Expect(defaults.MaxDepth).To(Equal(64))
+	})
+
+	It("rejects invalid values", func() {
+		// Each mutation is applied to a fresh default config so exactly one
+		// field is out of range per Validate call.
+		mutations := []func(*prefixcache.Config){
+			func(c *prefixcache.Config) { c.MinPrefixMatch = 1.5 },
+			func(c *prefixcache.Config) { c.BalanceAbsThreshold = -1 },
+			func(c *prefixcache.Config) { c.TTL = 0 }, // a TTL/2 ticker would panic
+		}
+		for _, mutate := range mutations {
+			candidate := prefixcache.DefaultConfig()
+			mutate(&candidate)
+			Expect(candidate.Validate()).To(HaveOccurred())
+		}
+	})
+})
+
+var _ = Describe("ValidateThresholds", func() {
+	// expectRejected asserts that err is set and that its text names field.
+	expectRejected := func(err error, field string) {
+		Expect(err).To(HaveOccurred())
+		Expect(err.Error()).To(ContainSubstring(field))
+	}
+
+	It("accepts valid values across all route policies", func() {
+		Expect(prefixcache.ValidateThresholds("", 3, 0, 0.4)).To(Succeed())
+		Expect(prefixcache.ValidateThresholds("round_robin", 0, 1.5, 0)).To(Succeed())
+		Expect(prefixcache.ValidateThresholds("prefix_cache", 2, 2.0, 1.0)).To(Succeed())
+	})
+
+	It("rejects an unknown route_policy (explicit allow-list, no silent default)", func() {
+		expectRejected(prefixcache.ValidateThresholds("bogus", 0, 0, 0), "route_policy")
+	})
+
+	It("rejects min_prefix_match above 1", func() {
+		expectRejected(prefixcache.ValidateThresholds("", 0, 0, 1.5), "min_prefix_match")
+	})
+
+	It("rejects a negative min_prefix_match", func() {
+		expectRejected(prefixcache.ValidateThresholds("", 0, 0, -0.1), "min_prefix_match")
+	})
+
+	It("rejects a negative balance_abs_threshold", func() {
+		expectRejected(prefixcache.ValidateThresholds("", -1, 0, 0), "balance_abs_threshold")
+	})
+
+	It("rejects balance_rel_threshold between 0 and 1 exclusive", func() {
+		expectRejected(prefixcache.ValidateThresholds("", 0, 0.5, 0), "balance_rel_threshold")
+	})
+})
diff --git a/core/services/nodes/prefixcache/export_test.go b/core/services/nodes/prefixcache/export_test.go
new file mode 100644
index 000000000..dbfe62166
--- /dev/null
+++ b/core/services/nodes/prefixcache/export_test.go
@@ -0,0 +1,18 @@
+package prefixcache
+
+// LenForTest reports how many events Pressure currently holds for model, read
+// under the Pressure mutex. It exists so black-box tests can check that Record
+// keeps the per-model slice bounded. Test-only.
+func (p *Pressure) LenForTest(model string) int {
+	p.mu.Lock()
+	count := len(p.events[model])
+	p.mu.Unlock()
+	return count
+}
+
+// TreeCountForTest reports how many per-model radix trees the Index holds,
+// read under the Index read-lock. Black-box tests use it to check that
+// Invalidate does not intern empty trees for models that never used the
+// prefix cache. Test-only.
+func (ix *Index) TreeCountForTest() int {
+	ix.mu.RLock()
+	count := len(ix.trees)
+	ix.mu.RUnlock()
+	return count
+}
diff --git a/core/services/nodes/prefixcache/extractor.go b/core/services/nodes/prefixcache/extractor.go
new file mode 100644
index 000000000..65531bafb
--- /dev/null
+++ b/core/services/nodes/prefixcache/extractor.go
@@ -0,0 +1,57 @@
+package prefixcache
+
+import (
+ "encoding/binary"
+
+ "github.com/cespare/xxhash/v2"
+)
+
+// ExtractChain renders prompt into a cumulative chain of prefix hashes:
+// h[0]=H(salt,block0), h[i]=H(h[i-1],block_i), where salt is the xxhash of
+// model. Blocks are fixed cfg.WindowBytes-byte windows over the prompt bytes,
+// chunked from absolute offset 0 with fixed boundaries [0,W), [W,2W), ... and
+// the chain is capped to the FIRST cfg.MaxDepth blocks (the head).
+//
+// Head-first chunking is what makes this a true prefix-chain. The reusable
+// KV/prefix cache is always at the HEAD of the prompt: the system prompt and
+// early turns are stable, new content is appended at the end. Because the
+// boundaries are anchored at offset 0 (never length-dependent), a prompt P and
+// any extension P+suffix share their entire leading overlap. Prompts that
+// agree on the first MaxDepth head blocks yield identical chains: an accepted
+// routing-hint limitation, since the cap bounds the chain length.
+//
+// xxhash is used (not hash/maphash) because the hash MUST be identical across
+// frontend processes: peers exchange these hashes over NATS, and maphash uses a
+// per-process random seed that would make peers disagree.
+//
+// It returns nil for an empty prompt, and also when cfg.WindowBytes or
+// cfg.MaxDepth is not positive (for example a zero-value Config), instead of
+// panicking on a division by zero or a negative slice capacity.
+func ExtractChain(model, prompt string, cfg Config) []uint64 {
+	if prompt == "" || cfg.WindowBytes <= 0 || cfg.MaxDepth <= 0 {
+		return nil
+	}
+	// Ceil-divide without the "+W-1" form so a very large WindowBytes cannot
+	// overflow the intermediate sum.
+	nBlocks := len(prompt) / cfg.WindowBytes
+	if len(prompt)%cfg.WindowBytes != 0 {
+		nBlocks++
+	}
+	depth := min(nBlocks, cfg.MaxDepth)
+	// One Digest reused across blocks: Reset() restores the initial state, so
+	// Reset()+Write yields the same value as a fresh New()+Write while
+	// avoiding a Digest allocation per block.
+	h := xxhash.New()
+	chain := make([]uint64, 0, depth)
+	prev := xxhash.Sum64String(model)
+	var pb [8]byte
+	for i := range depth {
+		off := i * cfg.WindowBytes
+		end := off + min(cfg.WindowBytes, len(prompt)-off)
+		h.Reset()
+		binary.LittleEndian.PutUint64(pb[:], prev)
+		// Digest writes are documented to never fail, so the results are
+		// deliberately discarded.
+		_, _ = h.Write(pb[:])
+		// Hash the window straight from the string: only the first
+		// MaxDepth*WindowBytes bytes are ever read, so copying the whole
+		// prompt into a []byte up front would be wasted work on long prompts.
+		_, _ = h.WriteString(prompt[off:end])
+		prev = h.Sum64()
+		chain = append(chain, prev)
+	}
+	return chain
+}
diff --git a/core/services/nodes/prefixcache/extractor_test.go b/core/services/nodes/prefixcache/extractor_test.go
new file mode 100644
index 000000000..53ee95d3f
--- /dev/null
+++ b/core/services/nodes/prefixcache/extractor_test.go
@@ -0,0 +1,75 @@
+package prefixcache_test
+
+import (
+ "strings"
+
+ . "github.com/onsi/ginkgo/v2"
+ . "github.com/onsi/gomega"
+
+ "github.com/mudler/LocalAI/core/services/nodes/prefixcache"
+)
+
+var _ = Describe("Extractor", func() {
+	cfg := prefixcache.DefaultConfig()
+
+	// geometry returns a copy of the default config with the chunking window
+	// and depth cap overridden.
+	geometry := func(windowBytes, maxDepth int) prefixcache.Config {
+		scaled := cfg
+		scaled.WindowBytes = windowBytes
+		scaled.MaxDepth = maxDepth
+		return scaled
+	}
+
+	It("produces a deterministic chain for the same prompt and model", func() {
+		first := prefixcache.ExtractChain("modelX", "hello world", cfg)
+		second := prefixcache.ExtractChain("modelX", "hello world", cfg)
+		Expect(first).To(Equal(second))
+		Expect(len(first)).To(BeNumerically(">", 0))
+	})
+
+	It("shares the head but diverges on a volatile tail", func() {
+		stableHead := strings.Repeat("system rules ", 100) // longer than one window
+		atNoon := prefixcache.ExtractChain("m", stableHead+"Current time 12:00:00", cfg)
+		oneSecondLater := prefixcache.ExtractChain("m", stableHead+"Current time 12:00:01", cfg)
+		// The first hash covers only the stable head, so it is shared.
+		Expect(atNoon[0]).To(Equal(oneSecondLater[0]))
+		// The last hash covers the differing timestamp, so it is not.
+		Expect(atNoon[len(atNoon)-1]).NotTo(Equal(oneSecondLater[len(oneSecondLater)-1]))
+	})
+
+	It("salts by model so identical text yields different chains per model", func() {
+		underM1 := prefixcache.ExtractChain("m1", "abc", cfg)
+		underM2 := prefixcache.ExtractChain("m2", "abc", cfg)
+		Expect(underM1[0]).NotTo(Equal(underM2[0]))
+	})
+
+	It("caps depth", func() {
+		capped := prefixcache.ExtractChain("m", "abcdefghij", geometry(1, 4))
+		Expect(len(capped)).To(Equal(4))
+	})
+
+	It("returns nil for empty prompt", func() {
+		Expect(prefixcache.ExtractChain("m", "", cfg)).To(BeNil())
+	})
+
+	It("stays stable across turns once the prompt grows past the depth cap", func() {
+		small := geometry(4, 3) // 4-byte windows, 3 blocks: a 12-byte head budget
+
+		// The base prompt is longer than MaxDepth*WindowBytes, so its chain is
+		// cut off after the first 3 head blocks.
+		base := "system-rules-stable-prefix-that-exceeds-the-budget"
+		Expect(len(base)).To(BeNumerically(">", small.WindowBytes*small.MaxDepth))
+
+		turnN := prefixcache.ExtractChain("m", base, small)
+		turnN1 := prefixcache.ExtractChain("m", base+"more text appended", small)
+		// Both are capped to the same leading MaxDepth blocks, so the chains match.
+		Expect(turnN).To(HaveLen(small.MaxDepth))
+		Expect(turnN1).To(HaveLen(small.MaxDepth))
+		Expect(turnN1).To(Equal(turnN))
+
+		// A prompt that diverges INSIDE the budget keeps the hashes before the
+		// divergence and changes from there on. "system-r" equals base over
+		// the first two 4-byte blocks ("syst", "em-r"); block 2 differs.
+		divergent := prefixcache.ExtractChain("m", "system-rDIFFERENT-tail", small)
+		Expect(divergent).To(HaveLen(small.MaxDepth))
+		Expect(divergent[0]).To(Equal(turnN[0]))
+		Expect(divergent[1]).To(Equal(turnN[1]))
+		Expect(divergent[2]).NotTo(Equal(turnN[2]))
+	})
+})
diff --git a/core/services/nodes/prefixcache/index.go b/core/services/nodes/prefixcache/index.go
new file mode 100644
index 000000000..fe629ff56
--- /dev/null
+++ b/core/services/nodes/prefixcache/index.go
@@ -0,0 +1,129 @@
+package prefixcache
+
+import (
+ "sort"
+ "sync"
+ "time"
+
+ "github.com/mudler/LocalAI/pkg/radixtree"
+)
+
+// Index is the guessed (routing-history) Provider backed by per-model radix
+// trees keyed by ReplicaKey. Affinity is per replica, so the same prefix served
+// by two replicas of one node resolves back to the exact replica that served it.
+// Safe for concurrent use.
+type Index struct {
+ cfg Config // supplies TTL and HalfLife to every tree created by tree()
+ mu sync.RWMutex // guards the trees map only; trees are used after it is released, so presumably they synchronize internally - confirm in radixtree
+ trees map[string]*radixtree.Tree[ReplicaKey] // one radix tree per model name, created lazily by tree()
+}
+
+// NewIndex returns an empty Index that remembers cfg for the trees it creates
+// later. No tree exists until the first lookup for a model interns one.
+func NewIndex(cfg Config) *Index {
+	ix := new(Index)
+	ix.cfg = cfg
+	ix.trees = make(map[string]*radixtree.Tree[ReplicaKey])
+	return ix
+}
+
+// existingTree looks up the tree for model under the read-lock and never
+// creates one. The bool is true only when a tree was already interned.
+func (ix *Index) existingTree(model string) (*radixtree.Tree[ReplicaKey], bool) {
+	ix.mu.RLock()
+	found, present := ix.trees[model]
+	ix.mu.RUnlock()
+	return found, present
+}
+
+// tree returns the radix tree for model, creating and interning it on first
+// use. The common case is served under the read-lock; only a miss takes the
+// write-lock, where the map is re-checked because another goroutine may have
+// interned the tree in between.
+func (ix *Index) tree(model string) *radixtree.Tree[ReplicaKey] {
+	if existing, found := ix.existingTree(model); found {
+		return existing
+	}
+
+	ix.mu.Lock()
+	defer ix.mu.Unlock()
+	if existing, found := ix.trees[model]; found {
+		return existing
+	}
+	opts := radixtree.Options{TTL: ix.cfg.TTL, HalfLife: ix.cfg.HalfLife}
+	created := radixtree.New[ReplicaKey](opts)
+	ix.trees[model] = created
+	return created
+}
+
+// Decide builds the routing hint for one request. The result carries a hot
+// match (Hot, HasHot, MatchRatio) when the deepest replica recorded for chain
+// is one of candidates, and always carries ColdOrder: a copy of candidates
+// sorted by ascending cache weight, ties broken by ReplicaKey.less.
+func (ix *Index) Decide(model string, chain []uint64, candidates []ReplicaKey, now time.Time) PrefixDecision {
+	modelTree := ix.tree(model)
+	// All candidate weights come from one WeightsFor call. That map is expected
+	// to hold an entry (weight 0 by default) for every requested candidate, so
+	// "key present in weights" doubles as the candidate-membership test below
+	// and no second set is built.
+	weights := modelTree.WeightsFor(candidates, now)
+
+	var decision PrefixDecision
+	if len(chain) != 0 {
+		matched, depth, found := modelTree.LongestMatch(chain, now)
+		// LongestMatch searches the whole tree, so the deepest match can be a
+		// replica that is offline / unloaded / not a candidate. Treating that
+		// as hot would raise a false forced-disturb signal upstream (the warm
+		// replica was absent, not saturated), so such a match is dropped and
+		// cold placement is used instead.
+		_, isCandidate := weights[matched]
+		if found && isCandidate {
+			decision.HasHot = true
+			decision.Hot = matched
+			decision.MatchRatio = float64(depth) / float64(len(chain))
+		}
+	}
+
+	// Cold order is computed on a private copy so the caller's slice keeps its
+	// order. The comparator only reads the precomputed weights. Zero or one
+	// candidate is already in cold order, so the sort is skipped.
+	coldOrder := make([]ReplicaKey, len(candidates))
+	copy(coldOrder, candidates)
+	if len(coldOrder) >= 2 {
+		sort.Slice(coldOrder, func(a, b int) bool {
+			wa, wb := weights[coldOrder[a]], weights[coldOrder[b]]
+			if wa == wb {
+				return coldOrder[a].less(coldOrder[b])
+			}
+			return wa < wb
+		})
+	}
+	decision.ColdOrder = coldOrder
+	return decision
+}
+
+// Observe records that replica key served chain for model at time now. It
+// returns false without touching the index when chain is empty or key has no
+// NodeID. Otherwise it inserts the chain and reports whether the observation
+// was new: true unless key already held the full-depth match for this chain.
+func (ix *Index) Observe(model string, chain []uint64, key ReplicaKey, now time.Time) bool {
+	if len(chain) == 0 {
+		return false
+	}
+	if key.NodeID == "" {
+		return false
+	}
+	modelTree := ix.tree(model)
+	// The previous owner is read BEFORE the insert, so the comparison is
+	// against the state this observation is about to overwrite.
+	owner, depth, found := modelTree.LongestMatch(chain, now)
+	alreadyKnown := found && depth >= len(chain) && owner == key
+	modelTree.Insert(chain, key, now)
+	return !alreadyKnown
+}
+
+// Invalidate removes every entry owned by exactly one replica (key) from the
+// tree of model. When model has no tree it does nothing and, importantly,
+// creates none: a registry chokepoint calls Invalidate for every replica
+// removal of every model, including round-robin models that never used the
+// prefix cache, so interning a tree here would grow the trees map unboundedly.
+func (ix *Index) Invalidate(model string, key ReplicaKey) {
+	modelTree, present := ix.existingTree(model)
+	if !present {
+		return
+	}
+	modelTree.RemoveFunc(func(candidate ReplicaKey) bool {
+		return candidate == key
+	})
+}
+
+// InvalidateNode removes the entries of ALL replicas that live on nodeID from
+// the tree of model. As with Invalidate, a model without a tree is left
+// without one.
+func (ix *Index) InvalidateNode(model, nodeID string) {
+	modelTree, present := ix.existingTree(model)
+	if !present {
+		return
+	}
+	modelTree.RemoveFunc(func(candidate ReplicaKey) bool {
+		return candidate.NodeID == nodeID
+	})
+}
+
+// Evict asks every per-model tree to drop the entries it considers expired
+// at time now.
+//
+// The map read-lock is held only long enough to snapshot the tree pointers.
+// Keeping it across the eviction walk would stall routing: once tree() queues
+// for the write-lock to intern a new model, sync.RWMutex blocks all new
+// readers (every Decide/Observe lookup) until the walk ends. Running
+// tree.Evict outside the lock matches how Decide, Observe and Invalidate
+// already use trees after releasing ix.mu.
+func (ix *Index) Evict(now time.Time) {
+	ix.mu.RLock()
+	snapshot := make([]*radixtree.Tree[ReplicaKey], 0, len(ix.trees))
+	for _, t := range ix.trees {
+		snapshot = append(snapshot, t)
+	}
+	ix.mu.RUnlock()
+
+	for _, t := range snapshot {
+		t.Evict(now)
+	}
+}
diff --git a/core/services/nodes/prefixcache/index_test.go b/core/services/nodes/prefixcache/index_test.go
new file mode 100644
index 000000000..4719551d1
--- /dev/null
+++ b/core/services/nodes/prefixcache/index_test.go
@@ -0,0 +1,169 @@
+package prefixcache_test
+
+import (
+ "sync"
+ "time"
+
+ . "github.com/onsi/ginkgo/v2"
+ . "github.com/onsi/gomega"
+
+ "github.com/mudler/LocalAI/core/services/nodes/prefixcache"
+)
+
+var t0 = time.Date(2026, 5, 29, 12, 0, 0, 0, time.UTC) // fixed reference "now" for the specs, so time-based assertions are deterministic
+
+var _ = Describe("Index provider", func() {
+ cfg := prefixcache.DefaultConfig()
+
+	It("returns no hot match before anything is observed", func() {
+		idx := prefixcache.NewIndex(cfg)
+		d := idx.Decide("m", []uint64{1, 2, 3}, []prefixcache.ReplicaKey{rk("A", 0), rk("B", 0)}, t0)
+		Expect(d.HasHot).To(BeFalse())
+		// Nothing observed, so all weights tie and the NodeID/Replica tiebreak fixes the order: pin it exactly.
+		Expect(d.ColdOrder).To(Equal([]prefixcache.ReplicaKey{rk("A", 0), rk("B", 0)}))
+	})
+
+ It("returns the observed replica as hot match with the right ratio", func() {
+ idx := prefixcache.NewIndex(cfg)
+ idx.Observe("m", []uint64{1, 2, 3, 4}, rk("A", 0), t0)
+ d := idx.Decide("m", []uint64{1, 2, 3, 4, 5}, []prefixcache.ReplicaKey{rk("A", 0), rk("B", 0)}, t0)
+ Expect(d.HasHot).To(BeTrue())
+ Expect(d.Hot).To(Equal(rk("A", 0)))
+ Expect(d.MatchRatio).To(BeNumerically("~", 4.0/5.0, 0.001))
+ })
+
+ It("orders cold candidates by ascending cacheWeight", func() {
+ idx := prefixcache.NewIndex(cfg)
+ idx.Observe("m", []uint64{1}, rk("A", 0), t0)
+ idx.Observe("m", []uint64{2}, rk("A", 0), t0) // A weight 2
+ idx.Observe("m", []uint64{3}, rk("B", 0), t0) // B weight 1
+ d := idx.Decide("m", []uint64{9}, []prefixcache.ReplicaKey{rk("A", 0), rk("B", 0)}, t0)
+ Expect(d.HasHot).To(BeFalse())
+ Expect(d.ColdOrder).To(Equal([]prefixcache.ReplicaKey{rk("B", 0), rk("A", 0)})) // B lower weight first
+ })
+
+ It("drops the hot match when the matched replica is not in the candidate set", func() {
+ idx := prefixcache.NewIndex(cfg)
+ idx.Observe("m", []uint64{1, 2, 3, 4}, rk("A", 0), t0)
+ // A holds the longest match, but A is not a candidate (offline /
+ // unloaded). The matched replica must be ignored so cold placement runs
+ // and no false forced-disturb fires upstream.
+ d := idx.Decide("m", []uint64{1, 2, 3, 4, 5}, []prefixcache.ReplicaKey{rk("B", 0), rk("C", 0)}, t0)
+ Expect(d.HasHot).To(BeFalse())
+ Expect(d.MatchRatio).To(Equal(0.0))
+ Expect(d.ColdOrder).To(ConsistOf(rk("B", 0), rk("C", 0)))
+ })
+
+ It("returns a hot match for a query that only shares a prefix with an observed chain", func() {
+ // The real-world case: a replica served chain [1,2,3,4]; a new request
+ // shares the leading block [1,2,3] but diverges at the tail ([1,2,3,9]).
+ // With prefix matching (value recorded at every node) Decide must still
+ // route to the warm replica, matching at the depth of the shared prefix.
+ idx := prefixcache.NewIndex(cfg)
+ idx.Observe("m", []uint64{1, 2, 3, 4}, rk("A", 0), t0)
+ d := idx.Decide("m", []uint64{1, 2, 3, 9}, []prefixcache.ReplicaKey{rk("A", 0), rk("B", 0)}, t0)
+ Expect(d.HasHot).To(BeTrue())
+ Expect(d.Hot).To(Equal(rk("A", 0)))
+ Expect(d.MatchRatio).To(BeNumerically("~", 3.0/4.0, 0.001)) // shared [1,2,3] of len-4 query
+ })
+
+ It("keeps the hot match when the matched replica is a candidate", func() {
+ idx := prefixcache.NewIndex(cfg)
+ idx.Observe("m", []uint64{1, 2, 3, 4}, rk("A", 0), t0)
+ d := idx.Decide("m", []uint64{1, 2, 3, 4, 5}, []prefixcache.ReplicaKey{rk("A", 0), rk("B", 0)}, t0)
+ Expect(d.HasHot).To(BeTrue())
+ Expect(d.Hot).To(Equal(rk("A", 0)))
+ Expect(d.MatchRatio).To(BeNumerically("~", 4.0/5.0, 0.001))
+ })
+
+ It("tracks affinity per replica, not per node", func() {
+ // Two replicas on the SAME node, each serving a different chain that share
+ // a leading block. The hot match for a query extending chain1 must be the
+ // EXACT replica that served chain1, not the other replica on the same node.
+ idx := prefixcache.NewIndex(cfg)
+ idx.Observe("m", []uint64{1, 2, 3, 4}, rk("A", 0), t0) // replica 0 owns [1,2,3,4]
+ idx.Observe("m", []uint64{1, 2, 5, 6}, rk("A", 1), t0) // replica 1 owns [1,2,5,6]
+ cands := []prefixcache.ReplicaKey{rk("A", 0), rk("A", 1)}
+ d := idx.Decide("m", []uint64{1, 2, 3, 4, 7}, cands, t0)
+ Expect(d.HasHot).To(BeTrue())
+ Expect(d.Hot).To(Equal(rk("A", 0))) // distinct replicas on one node have distinct affinity
+ d2 := idx.Decide("m", []uint64{1, 2, 5, 6, 7}, cands, t0)
+ Expect(d2.HasHot).To(BeTrue())
+ Expect(d2.Hot).To(Equal(rk("A", 1)))
+ })
+
+ It("Invalidate drops one replica while InvalidateNode drops all replicas of a node", func() {
+ idx := prefixcache.NewIndex(cfg)
+ idx.Observe("m", []uint64{1, 2, 3, 4}, rk("A", 0), t0)
+ idx.Observe("m", []uint64{5, 6, 7, 8}, rk("A", 1), t0)
+ cands := []prefixcache.ReplicaKey{rk("A", 0), rk("A", 1)}
+
+ // Invalidate replica 0 only: replica 1 survives.
+ idx.Invalidate("m", rk("A", 0))
+ Expect(idx.Decide("m", []uint64{1, 2, 3, 4}, cands, t0).HasHot).To(BeFalse())
+ d1 := idx.Decide("m", []uint64{5, 6, 7, 8}, cands, t0)
+ Expect(d1.HasHot).To(BeTrue())
+ Expect(d1.Hot).To(Equal(rk("A", 1)))
+
+ // Re-observe replica 0 (replica 1 survived above), then InvalidateNode drops BOTH replicas.
+ idx.Observe("m", []uint64{1, 2, 3, 4}, rk("A", 0), t0)
+ idx.InvalidateNode("m", "A")
+ Expect(idx.Decide("m", []uint64{1, 2, 3, 4}, cands, t0).HasHot).To(BeFalse())
+ Expect(idx.Decide("m", []uint64{5, 6, 7, 8}, cands, t0).HasHot).To(BeFalse())
+ })
+
+ It("forgets a replica on Invalidate", func() {
+ idx := prefixcache.NewIndex(cfg)
+ idx.Observe("m", []uint64{1, 2}, rk("A", 0), t0)
+ idx.Invalidate("m", rk("A", 0))
+ d := idx.Decide("m", []uint64{1, 2}, []prefixcache.ReplicaKey{rk("A", 0)}, t0)
+ Expect(d.HasHot).To(BeFalse())
+ })
+
+ It("does not intern an empty tree when invalidating a model that has none", func() {
+ idx := prefixcache.NewIndex(cfg)
+ Expect(idx.TreeCountForTest()).To(Equal(0))
+ // Round-robin model that never used the prefix cache: invalidating a
+ // replica removal must be a no-op and must not retain a tree.
+ idx.Invalidate("never-cached", rk("A", 0))
+ idx.Invalidate("never-cached", rk("B", 0))
+ idx.InvalidateNode("other", "C")
+ Expect(idx.TreeCountForTest()).To(Equal(0))
+ // And a Decide afterwards still works without a hot match.
+ d := idx.Decide("never-cached", []uint64{1}, []prefixcache.ReplicaKey{rk("A", 0)}, t0)
+ Expect(d.HasHot).To(BeFalse())
+ })
+
+ It("is safe for concurrent Decide/Observe/Invalidate (run with -race)", func() {
+ idx := prefixcache.NewIndex(cfg)
+ models := []string{"m1", "m2"}
+ nodes := []string{"A", "B", "C"}
+ cands := []prefixcache.ReplicaKey{rk("A", 0), rk("B", 0), rk("C", 0)}
+ var wg sync.WaitGroup
+ for g := range 8 {
+ wg.Add(1)
+ go func(g int) {
+ defer GinkgoRecover()
+ defer wg.Done()
+ model := models[g%len(models)]
+ node := nodes[g%len(nodes)]
+ now := t0
+ for i := range 200 {
+ chain := []uint64{uint64(g), uint64(i % 7), uint64(i)}
+ switch i % 4 {
+ case 0:
+ idx.Observe(model, chain, prefixcache.ReplicaKey{NodeID: node, Replica: i % 2}, now)
+ case 1:
+ idx.Decide(model, chain, cands, now)
+ case 2:
+ idx.Invalidate(model, prefixcache.ReplicaKey{NodeID: node, Replica: i % 2})
+ case 3:
+ idx.InvalidateNode(model, node)
+ }
+ now = now.Add(time.Millisecond)
+ }
+ }(g)
+ }
+ wg.Wait()
+ })
+})
diff --git a/core/services/nodes/prefixcache/policy.go b/core/services/nodes/prefixcache/policy.go
new file mode 100644
index 000000000..3a87f25b7
--- /dev/null
+++ b/core/services/nodes/prefixcache/policy.go
@@ -0,0 +1,47 @@
+// Package prefixcache implements prefix-cache-aware routing for distributed
+// mode: it turns a request prompt into a chain of prefix hashes, tracks which
+// replica served which prefix in an in-memory radix tree, and provides a
+// load-guarded preferred-replica decision. See docs/content/features/distributed-mode.md.
+package prefixcache
+
+// RoutePolicy selects the routing strategy for a model. The zero value is
+// RoutePolicyDefault, meaning "inherit the cluster-wide default" (see Resolve).
+type RoutePolicy int
+
+const (
+ RoutePolicyDefault RoutePolicy = iota // zero value: defer to the global policy via Resolve
+ RoutePolicyRoundRobin // today's round-robin behavior (the floor); config string "round_robin"
+ RoutePolicyPrefixCache // prefix-cache-aware routing; config string "prefix_cache"
+)
+
+// ParsePolicy converts a config string into its RoutePolicy. Matching is exact
+// (case-sensitive); anything else, the empty string included, yields
+// RoutePolicyDefault.
+func ParsePolicy(s string) RoutePolicy {
+	if s == "round_robin" {
+		return RoutePolicyRoundRobin
+	}
+	if s == "prefix_cache" {
+		return RoutePolicyPrefixCache
+	}
+	return RoutePolicyDefault
+}
+
+// String returns the config spelling of p; every value other than the two named policies renders as "default".
+func (p RoutePolicy) String() string {
+	if p == RoutePolicyRoundRobin {
+		return "round_robin"
+	}
+	if p == RoutePolicyPrefixCache {
+		return "prefix_cache"
+	}
+	return "default"
+}
+
+// Resolve yields p itself when it names a concrete policy, and global when p is RoutePolicyDefault.
+func (p RoutePolicy) Resolve(global RoutePolicy) RoutePolicy {
+	if p != RoutePolicyDefault {
+		return p
+	}
+	return global
+}
diff --git a/core/services/nodes/prefixcache/policy_test.go b/core/services/nodes/prefixcache/policy_test.go
new file mode 100644
index 000000000..5529de8a6
--- /dev/null
+++ b/core/services/nodes/prefixcache/policy_test.go
@@ -0,0 +1,29 @@
+package prefixcache_test
+
+import (
+ . "github.com/onsi/ginkgo/v2"
+ . "github.com/onsi/gomega"
+
+ "github.com/mudler/LocalAI/core/services/nodes/prefixcache"
+)
+
+var _ = Describe("RoutePolicy", func() {
+ It("parses known values and defaults unknown to Default (zero)", func() {
+ Expect(prefixcache.ParsePolicy("round_robin")).To(Equal(prefixcache.RoutePolicyRoundRobin))
+ Expect(prefixcache.ParsePolicy("prefix_cache")).To(Equal(prefixcache.RoutePolicyPrefixCache))
+ Expect(prefixcache.ParsePolicy("")).To(Equal(prefixcache.RoutePolicyDefault))
+ Expect(prefixcache.ParsePolicy("bogus")).To(Equal(prefixcache.RoutePolicyDefault))
+ })
+
+ It("stringifies", func() {
+ Expect(prefixcache.RoutePolicyPrefixCache.String()).To(Equal("prefix_cache"))
+ Expect(prefixcache.RoutePolicyRoundRobin.String()).To(Equal("round_robin"))
+ })
+
+ It("resolves per-model against a global default", func() {
+ Expect(prefixcache.RoutePolicyDefault.Resolve(prefixcache.RoutePolicyPrefixCache)).
+ To(Equal(prefixcache.RoutePolicyPrefixCache))
+ Expect(prefixcache.RoutePolicyRoundRobin.Resolve(prefixcache.RoutePolicyPrefixCache)).
+ To(Equal(prefixcache.RoutePolicyRoundRobin))
+ })
+})
diff --git a/core/services/nodes/prefixcache/prefixcache_suite_test.go b/core/services/nodes/prefixcache/prefixcache_suite_test.go
new file mode 100644
index 000000000..6c2c86073
--- /dev/null
+++ b/core/services/nodes/prefixcache/prefixcache_suite_test.go
@@ -0,0 +1,13 @@
+package prefixcache_test
+
+import (
+ "testing"
+
+ . "github.com/onsi/ginkgo/v2"
+ . "github.com/onsi/gomega"
+)
+
+func TestPrefixCache(t *testing.T) { // go-test entry point that runs the Ginkgo specs of this package
+ RegisterFailHandler(Fail) // route Gomega assertion failures to Ginkgo's Fail
+ RunSpecs(t, "PrefixCache Suite")
+}
diff --git a/core/services/nodes/prefixcache/pressure.go b/core/services/nodes/prefixcache/pressure.go
new file mode 100644
index 000000000..29e3e0f15
--- /dev/null
+++ b/core/services/nodes/prefixcache/pressure.go
@@ -0,0 +1,82 @@
+package prefixcache
+
+import (
+ "sync"
+ "time"
+)
+
+// Pressure is a concurrency-safe rolling per-model counter of forced-disturb
+// events. A forced-disturb is recorded by the router when a usable hot prefix
+// match existed but the load guard forced the request off the warm node (see
+// SmartRouter.buildPreference). The reconciler reads Count to decide whether
+// the cache-warm replica is saturated enough to warrant a scale-up.
+//
+// Entries older than the window are dropped on both Record and Count, so the
+// slice never grows unbounded - even for a model that takes records but is
+// never Counted (e.g. one with zero loaded replicas the reconciler skips). An
+// idle model's history also decays to zero on the next read.
+type Pressure struct {
+ mu sync.Mutex // guards events; taken by Record, Count and Reset
+ window time.Duration // how far back from "now" an event still counts
+ events map[string][]time.Time // model name -> recorded forced-disturb timestamps
+}
+
+// NewPressure returns an empty Pressure counter whose events are remembered
+// for the given rolling window.
+func NewPressure(window time.Duration) *Pressure {
+	p := new(Pressure)
+	p.window = window
+	p.events = map[string][]time.Time{}
+	return p
+}
+
+// pruneLocked filters ts in place, keeping only timestamps at or after cutoff;
+// the boundary is inclusive, so an event exactly window-old still counts. The
+// result shares ts's backing array. Callers must hold p.mu.
+func pruneLocked(ts []time.Time, cutoff time.Time) []time.Time {
+	live := ts[:0]
+	for i := range ts {
+		if stamp := ts[i]; stamp.After(cutoff) || stamp.Equal(cutoff) {
+			live = append(live, stamp)
+		}
+	}
+	return live
+}
+
+// Record notes a forced-disturb for model at now. Entries older than now-window
+// are pruned in the same step, so the per-model slice stays bounded no matter
+// how rarely Count runs.
+func (p *Pressure) Record(model string, now time.Time) {
+	p.mu.Lock()
+	defer p.mu.Unlock()
+	// Prune first, then append now and store the compacted slice back.
+	fresh := pruneLocked(p.events[model], now.Add(-p.window))
+	p.events[model] = append(fresh, now)
+}
+
+// Count reports how many records model has inside [now-window, now], pruning
+// expired entries and deleting the map key once nothing is left.
+func (p *Pressure) Count(model string, now time.Time) int {
+	p.mu.Lock()
+	defer p.mu.Unlock()
+	recorded := p.events[model]
+	if len(recorded) == 0 {
+		return 0
+	}
+	live := pruneLocked(recorded, now.Add(-p.window))
+	if n := len(live); n > 0 {
+		p.events[model] = live
+		return n
+	}
+	delete(p.events, model)
+	return 0
+}
+
+// Reset forgets every recorded event for model. Call it after acting on the
+// signal (a pressure-triggered scale-up) so one burst cannot keep triggering
+// scale-ups on consecutive ticks.
+func (p *Pressure) Reset(model string) {
+	p.mu.Lock()
+	delete(p.events, model)
+	p.mu.Unlock()
+}
diff --git a/core/services/nodes/prefixcache/pressure_test.go b/core/services/nodes/prefixcache/pressure_test.go
new file mode 100644
index 000000000..e6b741cb0
--- /dev/null
+++ b/core/services/nodes/prefixcache/pressure_test.go
@@ -0,0 +1,98 @@
+package prefixcache_test
+
+import (
+ "time"
+
+ "github.com/mudler/LocalAI/core/services/nodes/prefixcache"
+ . "github.com/onsi/ginkgo/v2"
+ . "github.com/onsi/gomega"
+)
+
+var _ = Describe("Pressure counter", func() {
+ t0 := time.Unix(1700000000, 0)
+
+ It("counts events within the window and forgets older ones", func() {
+ p := prefixcache.NewPressure(time.Minute)
+ p.Record("m", t0)
+ p.Record("m", t0.Add(30*time.Second))
+ Expect(p.Count("m", t0.Add(40*time.Second))).To(Equal(2))
+ Expect(p.Count("m", t0.Add(90*time.Second))).To(Equal(1)) // first expired
+ })
+
+ It("tracks pressure per model independently", func() {
+ p := prefixcache.NewPressure(time.Minute)
+ p.Record("a", t0)
+ p.Record("a", t0.Add(10*time.Second))
+ p.Record("b", t0.Add(20*time.Second))
+ Expect(p.Count("a", t0.Add(30*time.Second))).To(Equal(2))
+ Expect(p.Count("b", t0.Add(30*time.Second))).To(Equal(1))
+ Expect(p.Count("c", t0.Add(30*time.Second))).To(Equal(0))
+ })
+
+ It("returns zero for a model that was never recorded", func() {
+ p := prefixcache.NewPressure(time.Minute)
+ Expect(p.Count("never", t0)).To(Equal(0))
+ })
+
+ It("includes the boundary timestamp at exactly now-window", func() {
+ p := prefixcache.NewPressure(time.Minute)
+ p.Record("m", t0)
+ // now-window == t0 exactly, so the entry is still within [now-window, now].
+ Expect(p.Count("m", t0.Add(time.Minute))).To(Equal(1))
+ // one nanosecond past the window drops it.
+ Expect(p.Count("m", t0.Add(time.Minute+1))).To(Equal(0))
+ })
+
+ It("bounds the backing slice in Record without any Count calls", func() {
+ p := prefixcache.NewPressure(time.Minute)
+ // Record many timestamps, advancing now well past the window between
+ // each, and never call Count. Each Record must prune the entries that
+ // have fallen out of [now-window, now] so the slice cannot accumulate.
+ var last time.Time
+ for i := range 1000 {
+ last = t0.Add(time.Duration(i) * 10 * time.Second)
+ p.Record("m", last)
+ }
+ // With a 1m window and 10s spacing, at most ~7 records (the boundary is
+ // inclusive) can be within [last-window, last]. The slice must stay that
+ // bounded, never growing toward 1000.
+ Expect(p.LenForTest("m")).To(BeNumerically("<=", 7))
+ // And the in-window count must reflect only those bounded entries.
+ Expect(p.Count("m", last)).To(Equal(p.LenForTest("m")))
+ })
+
+ It("clears all recorded events on Reset", func() {
+ p := prefixcache.NewPressure(time.Minute)
+ p.Record("m", t0)
+ p.Record("m", t0.Add(10*time.Second))
+ p.Record("m", t0.Add(20*time.Second))
+ Expect(p.Count("m", t0.Add(30*time.Second))).To(BeNumerically(">", 0))
+
+ p.Reset("m")
+
+ // After Reset the model has no in-window events even though the
+ // timestamps would otherwise still be within [now-window, now].
+ Expect(p.Count("m", t0.Add(30*time.Second))).To(Equal(0))
+ Expect(p.LenForTest("m")).To(Equal(0))
+ })
+
+ It("Reset only clears the named model", func() {
+ p := prefixcache.NewPressure(time.Minute)
+ p.Record("a", t0)
+ p.Record("b", t0)
+ p.Reset("a")
+ Expect(p.Count("a", t0.Add(time.Second))).To(Equal(0))
+ Expect(p.Count("b", t0.Add(time.Second))).To(Equal(1))
+ })
+
+ It("does not accumulate repeated out-of-window Records", func() {
+ p := prefixcache.NewPressure(time.Minute)
+ // Each record is more than a window apart, so every Record prunes the
+ // previous one. The slice should never hold more than a single entry.
+ for i := range 100 {
+ p.Record("m", t0.Add(time.Duration(i)*2*time.Minute))
+ }
+ Expect(p.LenForTest("m")).To(Equal(1))
+ Expect(p.Count("m", t0.Add(198*time.Minute))).To(Equal(1))
+ })
+})
diff --git a/core/services/nodes/prefixcache/provider.go b/core/services/nodes/prefixcache/provider.go
new file mode 100644
index 000000000..976713aac
--- /dev/null
+++ b/core/services/nodes/prefixcache/provider.go
@@ -0,0 +1,24 @@
+package prefixcache
+
+import "time"
+
+// Provider is the seam between SmartRouter and the prefix-cache implementation.
+// The radix-tree (guessed) Index, optionally wrapped by Sync, is the only one
+// today; a future KV-event (reported) implementation can satisfy the same interface
+// without changing SmartRouter (epic #10063 / #10064). Affinity is tracked per replica:
+// each loaded replica is a separate process with its own KV cache.
+type Provider interface {
+ // Decide computes the prefix decision for a request given the candidate
+ // replicas (the selector-filtered set). It does not consult load; load
+ // filtering happens later, in the DB transaction.
+ Decide(model string, chain []uint64, candidates []ReplicaKey, now time.Time) PrefixDecision
+ // Observe records that the replica key served the request whose prefix is chain.
+ // It returns true when the assignment was new or extended (caller broadcasts).
+ Observe(model string, chain []uint64, key ReplicaKey, now time.Time) bool
+ // Invalidate drops all entries for ONE replica of model.
+ Invalidate(model string, key ReplicaKey)
+ // InvalidateNode drops entries for ALL replicas of nodeID within model.
+ InvalidateNode(model, nodeID string)
+ // Evict sweeps entries that have expired as of now, for all models.
+ Evict(now time.Time)
+}
diff --git a/core/services/nodes/prefixcache/select.go b/core/services/nodes/prefixcache/select.go
new file mode 100644
index 000000000..9ede1f5b1
--- /dev/null
+++ b/core/services/nodes/prefixcache/select.go
@@ -0,0 +1,93 @@
+package prefixcache
+
+// ReplicaKey identifies a specific loaded replica (a backend process). Affinity
+// is tracked per replica, not per node, because each replica is a separate
+// process with its own KV cache. It is comparable, so it can be used as a map key.
+type ReplicaKey struct {
+ NodeID string // node hosting the replica
+ Replica int // distinguishes replicas that share a node
+}
+
+// less is the deterministic tiebreak between otherwise-equal replicas: it
+// reports whether a orders before b by NodeID first, then by Replica.
+func (a ReplicaKey) less(b ReplicaKey) bool {
+	if a.NodeID == b.NodeID {
+		return a.Replica < b.Replica
+	}
+	return a.NodeID < b.NodeID
+}
+
+// Candidate is a load-eligible-or-not replica view from the registry. There is
+// one Candidate per LOADED replica: the router no longer collapses replicas per
+// node, so two replicas of the same model on the same node are two candidates.
+type Candidate struct {
+ Key ReplicaKey // the replica this entry describes
+ InFlight int // in-flight count that Select checks against its balance guards
+}
+
+// PrefixDecision is computed from the in-memory tree before the DB transaction.
+// Hot is the replica holding the longest prefix match and HasHot reports whether
+// there is one (a ReplicaKey has no "" sentinel). MatchRatio is matched/total
+// for that match. ColdOrder lists candidate replicas ascending by cacheWeight
+// (lowest = least valuable warm cache = best cold target).
+type PrefixDecision struct {
+ Hot ReplicaKey // meaningful only when HasHot is true
+ HasHot bool // whether a hot match exists
+ MatchRatio float64 // Select compares this against cfg.MinPrefixMatch
+ ColdOrder []ReplicaKey // Select walks this in order and takes the first eligible entry
+}
+
+// Select picks one replica in two phases, filter then score.
+//
+// Filter: a candidate is eligible when its in-flight count passes both balance
+// guards, each measured against the minimum in-flight across ALL candidates.
+// Score, in order: the exact hot replica if it is eligible and its match ratio
+// reaches cfg.MinPrefixMatch; else the first eligible entry of d.ColdOrder
+// (lowest cacheWeight); else the least-in-flight eligible replica, ties broken
+// by NodeID then Replica. Yields (ReplicaKey{}, false) when nothing qualifies.
+func Select(cands []Candidate, d PrefixDecision, cfg Config) (ReplicaKey, bool) {
+	if len(cands) == 0 {
+		return ReplicaKey{}, false
+	}
+	floor := cands[0].InFlight
+	for i := 1; i < len(cands); i++ {
+		floor = min(floor, cands[i].InFlight)
+	}
+	// The +1 keeps the relative ceiling from collapsing to zero when floor is 0;
+	// that close to zero the absolute ceiling is the guard that governs.
+	absCeil := floor + cfg.BalanceAbsThreshold
+	relCeil := float64(floor)*cfg.BalanceRelThreshold + 1
+	eligible := make(map[ReplicaKey]bool, len(cands))
+	for _, c := range cands {
+		if c.InFlight <= absCeil && float64(c.InFlight) <= relCeil {
+			eligible[c.Key] = true
+		}
+	}
+	// A strong-enough hot match wins outright, provided its replica is eligible.
+	if hotUsable := d.HasHot && d.MatchRatio >= cfg.MinPrefixMatch; hotUsable && eligible[d.Hot] {
+		return d.Hot, true
+	}
+	// Cold placement: first eligible replica in ascending-cacheWeight order.
+	for _, key := range d.ColdOrder {
+		if eligible[key] {
+			return key, true
+		}
+	}
+	// Last resort for when ColdOrder misses the eligible set (the caller may pass
+	// it empty): scan for the least-loaded eligible replica so Select still
+	// returns deterministically instead of failing.
+	var (
+		best  Candidate
+		found bool
+	)
+	for _, c := range cands {
+		switch {
+		case !eligible[c.Key]:
+			// Outside the load guard: never a fallback target.
+		case !found, c.InFlight < best.InFlight, c.InFlight == best.InFlight && c.Key.less(best.Key):
+			best, found = c, true
+		}
+	}
+	// best is still the zero Candidate when found is false, so this is (ReplicaKey{}, false).
+	return best.Key, found
+}
diff --git a/core/services/nodes/prefixcache/select_test.go b/core/services/nodes/prefixcache/select_test.go
new file mode 100644
index 000000000..f11bc1879
--- /dev/null
+++ b/core/services/nodes/prefixcache/select_test.go
@@ -0,0 +1,139 @@
+package prefixcache_test
+
+import (
+ . "github.com/onsi/ginkgo/v2"
+ . "github.com/onsi/gomega"
+
+ "github.com/mudler/LocalAI/core/services/nodes/prefixcache"
+)
+
+func rk(node string, replica int) prefixcache.ReplicaKey { // shorthand ReplicaKey constructor for the specs
+ return prefixcache.ReplicaKey{NodeID: node, Replica: replica}
+}
+
+var _ = Describe("Select (filter-then-score)", func() {
+ cfg := prefixcache.DefaultConfig() // abs=2, rel=1.5, minMatch=0.3
+
+ cand := func(node string, replica, inflight int) prefixcache.Candidate {
+ return prefixcache.Candidate{Key: rk(node, replica), InFlight: inflight}
+ }
+
+ It("returns the hot-match replica when it is load-eligible and match >= min", func() {
+ cands := []prefixcache.Candidate{cand("A", 0, 1), cand("B", 0, 0)}
+ got, ok := prefixcache.Select(cands, prefixcache.PrefixDecision{
+ Hot: rk("A", 0), HasHot: true, MatchRatio: 0.5,
+ }, cfg)
+ Expect(ok).To(BeTrue())
+ Expect(got).To(Equal(rk("A", 0))) // A in-flight 1 <= min(0)+2 and <= 0*1.5+1
+ })
+
+ It("rejects the hot match when it violates the absolute load guard", func() {
+ cands := []prefixcache.Candidate{cand("A", 0, 5), cand("B", 0, 0)}
+ got, ok := prefixcache.Select(cands, prefixcache.PrefixDecision{
+ Hot: rk("A", 0), HasHot: true, MatchRatio: 0.9,
+ ColdOrder: []prefixcache.ReplicaKey{rk("B", 0), rk("A", 0)},
+ }, cfg)
+ Expect(ok).To(BeTrue())
+ Expect(got).To(Equal(rk("B", 0))) // A 5 > min(0)+2, drop to cold placement
+ })
+
+ It("ignores a match below min_prefix_match", func() {
+ cands := []prefixcache.Candidate{cand("A", 0, 0), cand("B", 0, 0)}
+ got, ok := prefixcache.Select(cands, prefixcache.PrefixDecision{
+ Hot: rk("A", 0), HasHot: true, MatchRatio: 0.2, // < 0.3
+ ColdOrder: []prefixcache.ReplicaKey{rk("B", 0), rk("A", 0)},
+ }, cfg)
+ Expect(ok).To(BeTrue())
+ Expect(got).To(Equal(rk("B", 0))) // cold placement: lowest cacheWeight eligible
+ })
+
+ It("cold-places to lowest-cacheWeight replica within the eligible subset", func() {
+ cands := []prefixcache.Candidate{cand("A", 0, 0), cand("B", 0, 0), cand("C", 0, 9)}
+ got, ok := prefixcache.Select(cands, prefixcache.PrefixDecision{
+ ColdOrder: []prefixcache.ReplicaKey{rk("C", 0), rk("B", 0), rk("A", 0)},
+ }, cfg)
+ Expect(ok).To(BeTrue())
+ Expect(got).To(Equal(rk("B", 0))) // C filtered out by load; B is next in cold order
+ })
+
+ It("returns false when no candidates", func() {
+ _, ok := prefixcache.Select(nil, prefixcache.PrefixDecision{}, cfg)
+ Expect(ok).To(BeFalse())
+ })
+
+ It("falls back to the least-in-flight eligible replica when ColdOrder is empty", func() {
+ // Deterministic eligible fallback: ColdOrder does not cover the eligible
+ // set, so Select picks the least-in-flight eligible replica, tiebreaking by
+ // NodeID then Replica.
+ cands := []prefixcache.Candidate{cand("B", 1, 0), cand("B", 0, 0), cand("A", 0, 0)}
+ got, ok := prefixcache.Select(cands, prefixcache.PrefixDecision{}, cfg)
+ Expect(ok).To(BeTrue())
+ Expect(got).To(Equal(rk("A", 0))) // all in-flight 0; A < B; within B, replica 0 < 1
+ })
+
+	It("selects a lone candidate even when ColdOrder is empty", func() {
+		// An empty eligible set cannot be built with the default config: the
+		// min-in-flight candidate always passes both guards. So this pins the
+		// single-candidate path instead: the only replica is returned with ok=true.
+		cands := []prefixcache.Candidate{cand("A", 0, 0)}
+		got, ok := prefixcache.Select(cands, prefixcache.PrefixDecision{}, cfg)
+		Expect(ok).To(BeTrue())
+		Expect(got).To(Equal(rk("A", 0)))
+	})
+})
+
+var _ = Describe("Select replica granularity", func() {
+ cfg := prefixcache.DefaultConfig()
+
+ It("distinguishes two replicas of the same node as separate candidates", func() {
+ // Two replicas on NodeA: replica 0 is hot but saturated, replica 1 is cool.
+ // The round-robin floor must drop to replica 1, NOT collapse them per node.
+ cands := []prefixcache.Candidate{
+ {Key: rk("A", 0), InFlight: 50},
+ {Key: rk("A", 1), InFlight: 0},
+ }
+ got, ok := prefixcache.Select(cands, prefixcache.PrefixDecision{
+ Hot: rk("A", 0), HasHot: true, MatchRatio: 1.0,
+ ColdOrder: []prefixcache.ReplicaKey{rk("A", 1), rk("A", 0)},
+ }, cfg)
+ Expect(ok).To(BeTrue())
+ Expect(got).To(Equal(rk("A", 1)))
+ })
+
+ It("pins back to the exact hot replica when it is within slack", func() {
+ cands := []prefixcache.Candidate{
+ {Key: rk("A", 0), InFlight: 1},
+ {Key: rk("A", 1), InFlight: 0},
+ }
+ got, ok := prefixcache.Select(cands, prefixcache.PrefixDecision{
+ Hot: rk("A", 0), HasHot: true, MatchRatio: 1.0,
+ ColdOrder: []prefixcache.ReplicaKey{rk("A", 1), rk("A", 0)},
+ }, cfg)
+ Expect(ok).To(BeTrue())
+ Expect(got).To(Equal(rk("A", 0))) // within slack -> reuse exact replica
+ })
+})
+
+var _ = Describe("Select round-robin floor invariant", func() {
+ It("never pins to a saturated hot replica (round-robin floor)", func() {
+ cfg := prefixcache.DefaultConfig()
+ cands := []prefixcache.Candidate{{Key: rk("hot", 0), InFlight: 50}, {Key: rk("cool", 0), InFlight: 0}}
+ got, ok := prefixcache.Select(cands, prefixcache.PrefixDecision{
+ Hot: rk("hot", 0), HasHot: true, MatchRatio: 1.0,
+ ColdOrder: []prefixcache.ReplicaKey{rk("cool", 0), rk("hot", 0)},
+ }, cfg)
+ Expect(ok).To(BeTrue())
+ Expect(got).To(Equal(rk("cool", 0)))
+ })
+
+ It("improves reuse when balanced", func() {
+ cfg := prefixcache.DefaultConfig()
+ cands := []prefixcache.Candidate{{Key: rk("hot", 0), InFlight: 1}, {Key: rk("cool", 0), InFlight: 0}}
+ got, ok := prefixcache.Select(cands, prefixcache.PrefixDecision{
+ Hot: rk("hot", 0), HasHot: true, MatchRatio: 1.0,
+ ColdOrder: []prefixcache.ReplicaKey{rk("cool", 0), rk("hot", 0)},
+ }, cfg)
+ Expect(ok).To(BeTrue())
+ Expect(got).To(Equal(rk("hot", 0))) // within slack -> reuse
+ })
+})
diff --git a/core/services/nodes/prefixcache/sync.go b/core/services/nodes/prefixcache/sync.go
new file mode 100644
index 000000000..421fea7e2
--- /dev/null
+++ b/core/services/nodes/prefixcache/sync.go
@@ -0,0 +1,91 @@
+package prefixcache
+
+import (
+ "time"
+
+ "github.com/mudler/LocalAI/core/services/messaging"
+ "github.com/mudler/xlog"
+)
+
+// publisher is the minimal slice of messaging.Client that Sync needs; keeping it this small lets tests pass a fake.
+type publisher interface {
+ Publish(subject string, v any) error // v is the event payload; Sync logs a returned error and carries on
+}
+
+// Sync wraps an Index, broadcasting new/extended observations to peers and
+// applying peers' broadcasts. It is the cross-frontend coherence layer.
+type Sync struct {
+ idx Provider // wrapped index that every call is delegated to
+ pub publisher // peer broadcast channel; when nil nothing is published
+}
+
+func NewSync(idx Provider, pub publisher) *Sync { return &Sync{idx: idx, pub: pub} } // NewSync wraps idx; a nil pub means local updates only
+
+// Observe records the assignment locally and broadcasts it to peers only when the
+// index reports it as new or extended; that verdict is returned unchanged.
+func (s *Sync) Observe(model string, chain []uint64, key ReplicaKey, now time.Time) bool {
+	changed := s.idx.Observe(model, chain, key, now)
+	if !changed || s.pub == nil {
+		return changed
+	}
+	ev := messaging.PrefixCacheObserveEvent{Model: model, Chain: chain, NodeID: key.NodeID, Replica: key.Replica}
+	if err := s.pub.Publish(messaging.SubjectPrefixCacheObserve, ev); err != nil {
+		xlog.Debug("prefixcache: observe publish failed", "error", err)
+	}
+	return changed
+}
+
+// Invalidate drops the local entry for one replica, then broadcasts the removal
+// to peers. Locally this is a no-op for a model that was never cached
+// (Index.Invalidate does not intern a tree), but the broadcast is UNCONDITIONAL
+// whenever a publisher is configured: the registry chokepoint fires for every
+// replica removal, and a peer frontend may hold a stale entry for a model THIS
+// frontend never cached. Gating on local-tree existence would drop that
+// invalidation and leave peers routing to a removed replica until their TTL.
+func (s *Sync) Invalidate(model string, key ReplicaKey) {
+	s.idx.Invalidate(model, key)
+	if s.pub == nil {
+		return
+	}
+	ev := messaging.PrefixCacheInvalidateEvent{Model: model, NodeID: key.NodeID, Replica: key.Replica}
+	if err := s.pub.Publish(messaging.SubjectPrefixCacheInvalidate, ev); err != nil {
+		xlog.Debug("prefixcache: invalidate publish failed", "error", err)
+	}
+}
+
+// InvalidateNode drops local entries for ALL replicas of node and, like Invalidate, broadcasts
+// unconditionally when a publisher is set. Replica -1 on the wire means "all replicas of the node".
+func (s *Sync) InvalidateNode(model, node string) {
+	s.idx.InvalidateNode(model, node)
+	if s.pub == nil {
+		return
+	}
+	ev := messaging.PrefixCacheInvalidateEvent{Model: model, NodeID: node, Replica: -1}
+	if err := s.pub.Publish(messaging.SubjectPrefixCacheInvalidate, ev); err != nil {
+		xlog.Debug("prefixcache: invalidate-node publish failed", "error", err)
+	}
+}
+
+// ApplyObserve replays a peer's observe event into the local index; the verdict is dropped and nothing is re-broadcast.
+func (s *Sync) ApplyObserve(ev messaging.PrefixCacheObserveEvent, now time.Time) {
+	_ = s.idx.Observe(ev.Model, ev.Chain, ReplicaKey{NodeID: ev.NodeID, Replica: ev.Replica}, now)
+}
+
+// ApplyInvalidate replays a peer's invalidate event into the local index without
+// re-broadcasting. A negative Replica is the wire encoding for "every replica of the node".
+func (s *Sync) ApplyInvalidate(ev messaging.PrefixCacheInvalidateEvent) {
+	if ev.Replica >= 0 {
+		s.idx.Invalidate(ev.Model, ReplicaKey{NodeID: ev.NodeID, Replica: ev.Replica})
+		return
+	}
+	s.idx.InvalidateNode(ev.Model, ev.NodeID)
+}
+
+// Decide forwards to the wrapped index and returns its decision unchanged; Sync publishes nothing here.
+func (s *Sync) Decide(model string, chain []uint64, candidates []ReplicaKey, now time.Time) PrefixDecision {
+ return s.idx.Decide(model, chain, candidates, now)
+}
+
+// Evict forwards to the wrapped index's Evict with the same now. Nothing is
+// broadcast: each frontend evicts its own copy on its own TTL clock.
+func (s *Sync) Evict(now time.Time) { s.idx.Evict(now) }
diff --git a/core/services/nodes/prefixcache/sync_test.go b/core/services/nodes/prefixcache/sync_test.go
new file mode 100644
index 000000000..001454618
--- /dev/null
+++ b/core/services/nodes/prefixcache/sync_test.go
@@ -0,0 +1,118 @@
+package prefixcache_test
+
+import (
+ "time"
+
+ . "github.com/onsi/ginkgo/v2"
+ . "github.com/onsi/gomega"
+
+ "github.com/mudler/LocalAI/core/services/messaging"
+ "github.com/mudler/LocalAI/core/services/nodes/prefixcache"
+)
+
+type fakePub struct{ published []any } // test double for the publisher handed to NewSync; keeps every payload in call order
+// Publish appends v to f.published, ignores subject, and always reports success.
+func (f *fakePub) Publish(subject string, v any) error {
+ f.published = append(f.published, v) // payload only; the subject is not recorded
+ return nil
+}
+
+// Compile-time check: *Sync must satisfy the Provider seam so SmartRouter can
+// hold a single prefixcache.Provider that broadcasts via NATS.
+var _ prefixcache.Provider = (*prefixcache.Sync)(nil)
+
+var _ = Describe("Sync", func() { // rk and t0 are not defined in this file; they come from elsewhere in package prefixcache_test
+ It("delegates Evict to the wrapped index", func() {
+ cfg := prefixcache.DefaultConfig()
+ cfg.TTL = time.Minute // one-minute TTL so the sweep two minutes after t0 is past expiry
+ idx := prefixcache.NewIndex(cfg)
+ s := prefixcache.NewSync(idx, &fakePub{})
+ s.Observe("m", []uint64{1, 2}, rk("A", 0), t0)
+ // At t0, before the TTL elapses: the entry is still hot.
+ Expect(idx.Decide("m", []uint64{1, 2}, []prefixcache.ReplicaKey{rk("A", 0)}, t0).HasHot).To(BeTrue())
+ // Two minutes later, past the TTL: Sync.Evict sweeps the entry.
+ s.Evict(t0.Add(2 * time.Minute))
+ Expect(idx.Decide("m", []uint64{1, 2}, []prefixcache.ReplicaKey{rk("A", 0)}, t0.Add(2*time.Minute)).HasHot).To(BeFalse())
+ })
+
+ It("publishes an observe event with the replica when Observe is new", func() {
+ idx := prefixcache.NewIndex(prefixcache.DefaultConfig())
+ pub := &fakePub{}
+ s := prefixcache.NewSync(idx, pub)
+ s.Observe("m", []uint64{1, 2}, rk("A", 1), t0) // first time -> publish
+ Expect(pub.published).To(HaveLen(1))
+ ev := pub.published[0].(messaging.PrefixCacheObserveEvent)
+ Expect(ev.NodeID).To(Equal("A"))
+ Expect(ev.Replica).To(Equal(1))
+ s.Observe("m", []uint64{1, 2}, rk("A", 1), t0) // same -> no publish
+ Expect(pub.published).To(HaveLen(1))
+ })
+
+ It("broadcasts an invalidate even for a model with no local tree, without interning one", func() {
+ idx := prefixcache.NewIndex(prefixcache.DefaultConfig())
+ pub := &fakePub{}
+ s := prefixcache.NewSync(idx, pub)
+ // A peer frontend may hold a stale entry for this model even though THIS
+ // frontend never cached it, so the invalidate MUST be broadcast for
+ // cross-frontend coherence. The local drop must still not intern a tree.
+ s.Invalidate("never-cached", rk("A", 0))
+ Expect(pub.published).To(HaveLen(1))
+ ev := pub.published[0].(messaging.PrefixCacheInvalidateEvent)
+ Expect(ev.NodeID).To(Equal("A"))
+ Expect(ev.Replica).To(Equal(0))
+ Expect(idx.TreeCountForTest()).To(Equal(0))
+ })
+
+ It("broadcasts an invalidate for a cached replica too", func() {
+ idx := prefixcache.NewIndex(prefixcache.DefaultConfig())
+ pub := &fakePub{}
+ s := prefixcache.NewSync(idx, pub)
+ s.Observe("m", []uint64{1, 2}, rk("A", 0), t0) // creates the tree (also publishes observe)
+ pub.published = nil // discard the observe event so only the invalidate is counted
+ s.Invalidate("m", rk("A", 0))
+ Expect(pub.published).To(HaveLen(1))
+ Expect(pub.published[0]).To(BeAssignableToTypeOf(messaging.PrefixCacheInvalidateEvent{}))
+ })
+
+ It("broadcasts a node-wide invalidate with a negative replica", func() {
+ idx := prefixcache.NewIndex(prefixcache.DefaultConfig())
+ pub := &fakePub{}
+ s := prefixcache.NewSync(idx, pub)
+ s.InvalidateNode("m", "A")
+ Expect(pub.published).To(HaveLen(1))
+ ev := pub.published[0].(messaging.PrefixCacheInvalidateEvent)
+ Expect(ev.NodeID).To(Equal("A"))
+ Expect(ev.Replica).To(BeNumerically("<", 0)) // any negative value marks the event as node-wide
+ })
+
+ It("applies a peer observe event into the local index with the replica", func() {
+ idx := prefixcache.NewIndex(prefixcache.DefaultConfig())
+ s := prefixcache.NewSync(idx, &fakePub{})
+ s.ApplyObserve(messaging.PrefixCacheObserveEvent{Model: "m", Chain: []uint64{1, 2}, NodeID: "A", Replica: 2}, t0)
+ d := idx.Decide("m", []uint64{1, 2}, []prefixcache.ReplicaKey{rk("A", 2)}, t0)
+ Expect(d.HasHot).To(BeTrue())
+ Expect(d.Hot).To(Equal(rk("A", 2)))
+ })
+
+ It("applies a peer single-replica invalidate", func() {
+ idx := prefixcache.NewIndex(prefixcache.DefaultConfig())
+ s := prefixcache.NewSync(idx, &fakePub{})
+ s.Observe("m", []uint64{1, 2}, rk("A", 0), t0)
+ s.Observe("m", []uint64{3, 4}, rk("A", 1), t0)
+ s.ApplyInvalidate(messaging.PrefixCacheInvalidateEvent{Model: "m", NodeID: "A", Replica: 0}) // targets replica 0 only; replica 1 must survive
+ cands := []prefixcache.ReplicaKey{rk("A", 0), rk("A", 1)}
+ Expect(idx.Decide("m", []uint64{1, 2}, cands, t0).HasHot).To(BeFalse())
+ Expect(idx.Decide("m", []uint64{3, 4}, cands, t0).HasHot).To(BeTrue())
+ })
+
+ It("applies a peer node-wide invalidate when replica is negative", func() {
+ idx := prefixcache.NewIndex(prefixcache.DefaultConfig())
+ s := prefixcache.NewSync(idx, &fakePub{})
+ s.Observe("m", []uint64{1, 2}, rk("A", 0), t0)
+ s.Observe("m", []uint64{3, 4}, rk("A", 1), t0)
+ s.ApplyInvalidate(messaging.PrefixCacheInvalidateEvent{Model: "m", NodeID: "A", Replica: -1}) // negative replica: both replicas of node A must be dropped
+ cands := []prefixcache.ReplicaKey{rk("A", 0), rk("A", 1)}
+ Expect(idx.Decide("m", []uint64{1, 2}, cands, t0).HasHot).To(BeFalse())
+ Expect(idx.Decide("m", []uint64{3, 4}, cands, t0).HasHot).To(BeFalse())
+ })
+})
diff --git a/core/services/nodes/reconciler.go b/core/services/nodes/reconciler.go
index 9606c3f52..bf3cfd18f 100644
--- a/core/services/nodes/reconciler.go
+++ b/core/services/nodes/reconciler.go
@@ -8,6 +8,7 @@ import (
"time"
"github.com/mudler/LocalAI/core/services/advisorylock"
+ "github.com/mudler/LocalAI/core/services/nodes/prefixcache"
grpcclient "github.com/mudler/LocalAI/pkg/grpc"
"github.com/mudler/xlog"
"github.com/nats-io/nats.go"
@@ -56,6 +57,13 @@ type ReplicaReconciler struct {
// probeStaleAfter: only probe node_models rows older than this so we
// don't hammer every worker every tick for models we just heard from.
probeStaleAfter time.Duration
+ // pressure is the shared forced-disturb counter written by the router. When
+ // a model's count within the Pressure's rolling window reaches pressureThreshold the
+ // reconciler treats its cache-warm replica as saturated and scales up,
+ // subject to the same MaxReplicas/capacity/UnsatisfiableUntil machinery as
+ // the other scale-up paths. nil disables this signal (a true no-op).
+ pressure *prefixcache.Pressure
+ pressureThreshold int
}
// ModelScheduler abstracts the scheduling logic needed by the reconciler.
@@ -83,6 +91,12 @@ type ReplicaReconcilerOptions struct {
Interval time.Duration // default 30s
ScaleDownDelay time.Duration // default 5m
ProbeStaleAfter time.Duration // default 2m
+ // Pressure is the shared forced-disturb counter written by the router. nil
+ // disables the cache-saturation autoscale signal (a true no-op).
+ Pressure *prefixcache.Pressure
+ // PressureThreshold is the forced-disturb count within PressureWindow that
+ // triggers a scale-up. Default prefixcache.DefaultConfig().PressureScaleThreshold (1).
+ PressureThreshold int
}
// NewReplicaReconciler creates a new ReplicaReconciler.
@@ -103,16 +117,22 @@ func NewReplicaReconciler(opts ReplicaReconcilerOptions) *ReplicaReconciler {
if prober == nil {
prober = grpcModelProber{token: opts.RegistrationToken}
}
+ pressureThreshold := opts.PressureThreshold
+ if pressureThreshold == 0 {
+ pressureThreshold = prefixcache.DefaultConfig().PressureScaleThreshold
+ }
return &ReplicaReconciler{
- registry: opts.Registry,
- scheduler: opts.Scheduler,
- unloader: opts.Unloader,
- adapter: opts.Adapter,
- prober: prober,
- db: opts.DB,
- interval: interval,
- scaleDownDelay: scaleDownDelay,
- probeStaleAfter: probeStaleAfter,
+ registry: opts.Registry,
+ scheduler: opts.Scheduler,
+ unloader: opts.Unloader,
+ adapter: opts.Adapter,
+ prober: prober,
+ db: opts.DB,
+ interval: interval,
+ scaleDownDelay: scaleDownDelay,
+ probeStaleAfter: probeStaleAfter,
+ pressure: opts.Pressure,
+ pressureThreshold: pressureThreshold,
}
}
@@ -409,13 +429,25 @@ func (rc *ReplicaReconciler) reconcileModel(ctx context.Context, cfg ModelSchedu
}
xlog.Info("Reconciler: scaling up to meet minimum", "model", cfg.ModelName,
"current", current, "min", cfg.MinReplicas, "adding", needed)
- rc.scaleUp(ctx, cfg, needed)
- // Successful (or partial) scale-up clears the hysteresis so a future
- // dip starts fresh.
- _ = rc.registry.ClearUnsatisfiable(ctx, cfg.ModelName)
+ if rc.scaleUp(ctx, cfg, needed) {
+ // A real (or partial) scale-up clears the hysteresis so a future
+ // dip starts fresh. If scaleUp added nothing (scheduler errored or
+ // no node could be loaded) we leave the hysteresis intact so the
+ // next tick retries from where it left off rather than resetting
+ // the unsatisfiable counter on a failed attempt.
+ _ = rc.registry.ClearUnsatisfiable(ctx, cfg.ModelName)
+ }
return
}
+ // scaledUp tracks whether a scale-up already fired in this tick. The two
+ // scale-up paths below (busy-burst and pressure) share the single `current`
+ // value read once above; scaleUp does not re-check it. So at most one of
+ // them may fire per tick, otherwise a model that is both busy AND over the
+ // pressure threshold would scale +2 and could overshoot MaxReplicas by one.
+ // Scale-down is also skipped in a tick that scaled up.
+ scaledUp := false
+
// 2. Auto-scale up if all replicas are busy
if current > 0 && (cfg.MaxReplicas == 0 || int(current) < cfg.MaxReplicas) {
if rc.allReplicasBusy(ctx, cfg.ModelName) {
@@ -432,17 +464,63 @@ func (rc *ReplicaReconciler) reconcileModel(ctx context.Context, cfg ModelSchedu
}
xlog.Info("Reconciler: all replicas busy, scaling up", "model", cfg.ModelName,
"current", current)
- rc.scaleUp(ctx, cfg, 1)
+ // Only mark the tick as having scaled up if a replica was actually
+ // added. On a failed scaleUp, leave scaledUp false so the pressure
+ // path below and the scale-down logic still apply as they would
+ // have if the busy-burst path had not run.
+ scaledUp = rc.scaleUp(ctx, cfg, 1)
}
}
- // 3. Scale down idle replicas above minimum
- floor := cfg.MinReplicas
- if floor < 1 {
- floor = 1
+ // 2b. Auto-scale up on prefix-cache forced-disturb pressure. A forced-disturb
+ // is recorded by the router when a request had a usable hot prefix match
+ // but the load guard forced it off the warm node: the cache-warm replica
+ // is saturated. We reuse the same MaxReplicas + capacity guards as the
+ // busy-burst path, and the same UnsatisfiableUntil cooldown gates this
+ // block at the top of reconcileModel, so a no-capacity model will not
+ // spin. Pressure never overrides MaxReplicas or force-evicts.
+ //
+ // Skipped when the busy-burst path already scaled up this tick: at most
+ // one scaleUp(+1) per tick (see scaledUp above).
+ if !scaledUp && rc.pressure != nil && current > 0 && (cfg.MaxReplicas == 0 || int(current) < cfg.MaxReplicas) {
+ if pressureCount := rc.pressure.Count(cfg.ModelName, time.Now()); pressureCount >= rc.pressureThreshold {
+ candidateNodeIDs, selectorMatched := rc.candidateNodeIDsForSelector(ctx, cfg)
+ if selectorMatched {
+ capacity, capErr := rc.registry.ClusterCapacityForModel(ctx, cfg.ModelName, candidateNodeIDs)
+ if capErr == nil && capacity > 0 {
+ xlog.Info("Reconciler: prefix-cache forced-disturb pressure, scaling up",
+ "model", cfg.ModelName, "current", current,
+ "pressure", pressureCount,
+ "threshold", rc.pressureThreshold)
+ if rc.scaleUp(ctx, cfg, 1) {
+ scaledUp = true
+ // Consume the signal only on a real scale-up:
+ // Pressure.Count is non-draining (it prunes only by
+ // age), so a single burst stays in-window for the whole
+ // window and would re-fire scaleUp on every tick. Reset
+ // clears the model's events so a fresh scale-up needs
+ // fresh forced-disturbs to accumulate. If scaleUp added
+ // nothing (scheduler errored or no node could be loaded)
+ // we preserve the signal so the next tick retries off
+ // the same accumulated pressure instead of having to
+ // re-accumulate a full window from scratch.
+ rc.pressure.Reset(cfg.ModelName)
+ }
+ }
+ // No capacity: transient demand, not a misconfig - let the next
+ // tick retry naturally (mirrors the busy-burst path's choice not
+ // to enter cooldown for burst load).
+ }
+ }
}
- if int(current) > floor {
- rc.scaleDownIdle(ctx, cfg, int(current), floor)
+
+ // 3. Scale down idle replicas above minimum. Skipped in a tick that already
+ // scaled up so we never scale up and down in the same pass.
+ if !scaledUp {
+ floor := max(cfg.MinReplicas, 1)
+ if int(current) > floor {
+ rc.scaleDownIdle(ctx, cfg, int(current), floor)
+ }
}
}
@@ -470,10 +548,17 @@ func (rc *ReplicaReconciler) markCapacityProblem(ctx context.Context, modelName,
// scaleUp schedules additional replicas of the model. Callers in
// reconcileModel are expected to have already capped `count` against
// ClusterCapacityForModel so this function never tries to overshoot.
-func (rc *ReplicaReconciler) scaleUp(ctx context.Context, cfg ModelSchedulingConfig, count int) {
+//
+// Returns true if at least one replica was actually scheduled. Callers use
+// this to gate signal-consuming side effects (Pressure.Reset,
+// ClearUnsatisfiable) on a real scale-up: a failed/no-op scaleUp must not
+// discard the accumulated forced-disturb pressure or clear the unsatisfiable
+// hysteresis, otherwise the signal has to re-accumulate from scratch and the
+// next tick can't simply retry.
+func (rc *ReplicaReconciler) scaleUp(ctx context.Context, cfg ModelSchedulingConfig, count int) bool {
if rc.scheduler == nil {
xlog.Warn("Reconciler: no scheduler available, cannot scale up")
- return
+ return false
}
// Resolve selector → candidate node IDs (nil when no selector → "any
@@ -481,18 +566,21 @@ func (rc *ReplicaReconciler) scaleUp(ctx context.Context, cfg ModelSchedulingCon
// reconcileModel, but defensively short-circuit here too.
candidateNodeIDs, ok := rc.candidateNodeIDsForSelector(ctx, cfg)
if !ok {
- return
+ return false
}
+ scheduled := 0
for i := 0; i < count; i++ {
node, err := rc.scheduler.ScheduleAndLoadModel(ctx, cfg.ModelName, candidateNodeIDs)
if err != nil {
xlog.Warn("Reconciler: failed to scale up replica", "model", cfg.ModelName,
"attempt", i+1, "error", err)
- return // stop trying on first failure
+ break // stop trying on first failure
}
+ scheduled++
xlog.Info("Reconciler: scaled up replica", "model", cfg.ModelName, "node", node.Name)
}
+ return scheduled > 0
}
// scaleDownIdle removes idle replicas above the floor.
diff --git a/core/services/nodes/reconciler_test.go b/core/services/nodes/reconciler_test.go
index 0697374a8..ce8800caf 100644
--- a/core/services/nodes/reconciler_test.go
+++ b/core/services/nodes/reconciler_test.go
@@ -2,12 +2,14 @@ package nodes
import (
"context"
+ "errors"
"runtime"
"time"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
+ "github.com/mudler/LocalAI/core/services/nodes/prefixcache"
"github.com/mudler/LocalAI/core/services/testutil"
"gorm.io/gorm"
)
@@ -245,6 +247,225 @@ var _ = Describe("ReplicaReconciler", func() {
})
})
+ Describe("Forced-disturb pressure autoscale (Phase 6)", func() {
+ It("scales up when pressure exceeds threshold, replicas= threshold every tick and
+ // drives the model toward MaxReplicas off a single burst.
+ node := registerNode("consume-node", "10.0.0.64:50051")
+ Expect(registry.SetNodeModel(context.Background(), node.ID, "consume-model", 0, "loaded", "addr1", 0)).To(Succeed())
+ setSchedulingConfig("consume-model", 1, 4, "")
+
+ pressure := prefixcache.NewPressure(time.Minute)
+ now := time.Now()
+ pressure.Record("consume-model", now)
+ pressure.Record("consume-model", now)
+ pressure.Record("consume-model", now)
+
+ scheduler := &fakeScheduler{scheduleNode: node}
+ reconciler := NewReplicaReconciler(ReplicaReconcilerOptions{
+ Registry: registry,
+ Scheduler: scheduler,
+ DB: db,
+ Pressure: pressure,
+ })
+
+ // First tick: pressure above threshold → one scale-up.
+ reconciler.reconcile(context.Background())
+ Expect(scheduler.scheduleCalls).To(HaveLen(1),
+ "first tick must scale up once on the burst")
+
+ // Second tick: the burst's events are still inside the window, but
+ // the first scale-up Reset them, so no further scale-up occurs.
+ reconciler.reconcile(context.Background())
+ Expect(scheduler.scheduleCalls).To(HaveLen(1),
+ "a single burst must not re-trigger scale-up on the next in-window tick")
+ })
+
+ It("does not consume the pressure signal when scaleUp fails", func() {
+ // Pressure above threshold and capacity exists, but the scheduler
+ // errors so no replica is actually added. The forced-disturb signal
+ // must be preserved (NOT Reset) so the next tick retries the
+ // scale-up off the same accumulated pressure, instead of having to
+ // re-accumulate a full window of forced-disturbs from scratch.
+ node := registerNode("fail-node", "10.0.0.66:50051")
+ Expect(registry.SetNodeModel(context.Background(), node.ID, "fail-model", 0, "loaded", "addr1", 0)).To(Succeed())
+ setSchedulingConfig("fail-model", 1, 4, "")
+
+ pressure := prefixcache.NewPressure(time.Minute)
+ now := time.Now()
+ pressure.Record("fail-model", now)
+ pressure.Record("fail-model", now)
+ pressure.Record("fail-model", now)
+ Expect(pressure.Count("fail-model", time.Now())).To(BeNumerically(">=", 1))
+
+ // Scheduler errors: scaleUp attempts but adds nothing.
+ scheduler := &fakeScheduler{scheduleErr: errors.New("schedule boom")}
+ reconciler := NewReplicaReconciler(ReplicaReconcilerOptions{
+ Registry: registry,
+ Scheduler: scheduler,
+ DB: db,
+ Pressure: pressure,
+ })
+
+ reconciler.reconcile(context.Background())
+
+ Expect(scheduler.scheduleCalls).To(HaveLen(1),
+ "scaleUp must have attempted exactly one schedule call")
+ Expect(pressure.Count("fail-model", time.Now())).To(BeNumerically(">=", 1),
+ "a failed scaleUp must NOT consume (Reset) the pressure signal — next tick should retry")
+ })
+
+ It("consumes the pressure signal only when scaleUp succeeds", func() {
+ // Mirror of the failure case: when the scheduler succeeds and a
+ // replica is actually added, the forced-disturb signal IS consumed
+ // (Reset to 0) so a single burst scales up only once.
+ node := registerNode("ok-node", "10.0.0.67:50051")
+ Expect(registry.SetNodeModel(context.Background(), node.ID, "ok-model", 0, "loaded", "addr1", 0)).To(Succeed())
+ setSchedulingConfig("ok-model", 1, 4, "")
+
+ pressure := prefixcache.NewPressure(time.Minute)
+ now := time.Now()
+ pressure.Record("ok-model", now)
+ pressure.Record("ok-model", now)
+ pressure.Record("ok-model", now)
+
+ scheduler := &fakeScheduler{scheduleNode: node}
+ reconciler := NewReplicaReconciler(ReplicaReconcilerOptions{
+ Registry: registry,
+ Scheduler: scheduler,
+ DB: db,
+ Pressure: pressure,
+ })
+
+ reconciler.reconcile(context.Background())
+
+ Expect(scheduler.scheduleCalls).To(HaveLen(1),
+ "successful scaleUp must have scheduled one replica")
+ Expect(pressure.Count("ok-model", time.Now())).To(Equal(0),
+ "a successful scaleUp must consume (Reset) the pressure signal to 0")
+ })
+
+ It("performs at most one scale-up per tick when both busy and over pressure", func() {
+ // The single loaded replica is busy (all-replicas-busy fires) AND
+ // pressure is above threshold. Both scale-up paths are eligible in
+ // the same tick. The invariant is at-most-one scaleUp(+1) per tick,
+ // so exactly one schedule call must happen, not two.
+ node := registerNode("dual-node", "10.0.0.65:50051")
+ Expect(registry.SetNodeModel(context.Background(), node.ID, "dual-model", 0, "loaded", "addr1", 0)).To(Succeed())
+ Expect(registry.IncrementInFlight(context.Background(), node.ID, "dual-model", 0)).To(Succeed())
+ setSchedulingConfig("dual-model", 1, 4, "")
+
+ pressure := prefixcache.NewPressure(time.Minute)
+ pressure.Record("dual-model", time.Now())
+ pressure.Record("dual-model", time.Now())
+
+ scheduler := &fakeScheduler{scheduleNode: node}
+ reconciler := NewReplicaReconciler(ReplicaReconcilerOptions{
+ Registry: registry,
+ Scheduler: scheduler,
+ DB: db,
+ Pressure: pressure,
+ })
+
+ reconciler.reconcile(context.Background())
+
+ Expect(scheduler.scheduleCalls).To(HaveLen(1),
+ "busy + pressure in one tick must still scale up by exactly one, not two")
+ })
+
+ It("does not spin when pressure is high but no capacity exists", func() {
+ // Single node, cap 1, already loaded → capacity 0. Pressure is high
+ // but there is nowhere to place a replica: must not call scheduler.
+ registerCappedNodeFn := func(name, address string, cap int) *BackendNode {
+ node := &BackendNode{
+ Name: name,
+ NodeType: NodeTypeBackend,
+ Address: address,
+ MaxReplicasPerModel: cap,
+ }
+ Expect(registry.Register(context.Background(), node, true)).To(Succeed())
+ return node
+ }
+ node := registerCappedNodeFn("pcap-node", "10.0.0.63:50051", 1)
+ Expect(registry.SetNodeModel(context.Background(), node.ID, "pcap-model", 0, "loaded", "addr1", 0)).To(Succeed())
+ // MaxReplicas high enough that replicas unchanged behavior.
+type RoutePreference struct { // optional hint for FindAndLockNodeWithModel naming the exact replica row to try first
+ PreferredNodeID string // node to prefer; an empty value disables the preference
+ PreferredReplica int // matched against node_models.replica_index on that node
+}
+
// FindAndLockNodeWithModel atomically finds the best loaded replica of the
// given model and increments its in-flight counter within a single
// transaction. The SELECT FOR UPDATE row lock prevents concurrent eviction
@@ -758,7 +893,7 @@ func (r *NodeRegistry) FindNodesWithModel(ctx context.Context, modelName string)
// NodeSelector so a cached replica on a now-excluded node isn't picked over a
// matching replica elsewhere — the selector-mismatch fall-through path used to
// trigger an eviction-busy loop when both sides had the model loaded.
-func (r *NodeRegistry) FindAndLockNodeWithModel(ctx context.Context, modelName string, candidateNodeIDs []string) (*BackendNode, *NodeModel, error) {
+func (r *NodeRegistry) FindAndLockNodeWithModel(ctx context.Context, modelName string, candidateNodeIDs []string, pref *RoutePreference) (*BackendNode, *NodeModel, error) {
var nm NodeModel
var node BackendNode
@@ -781,17 +916,33 @@ func (r *NodeRegistry) FindAndLockNodeWithModel(ctx context.Context, modelName s
// stale row in the same window, and other helpers that mirror this
// JOIN need the same invariant. Belt-and-braces: status filter here
// AND the status-checked node fetch below.
- q := tx.Clauses(clause.Locking{Strength: "UPDATE"}).
+ base := tx.Clauses(clause.Locking{Strength: "UPDATE"}).
Joins("JOIN backend_nodes ON backend_nodes.id = node_models.node_id").
Where("node_models.model_name = ? AND node_models.state = ? AND backend_nodes.status = ?",
modelName, "loaded", StatusHealthy)
if len(candidateNodeIDs) > 0 {
- q = q.Where("node_models.node_id IN ?", candidateNodeIDs)
+ base = base.Where("node_models.node_id IN ?", candidateNodeIDs)
}
- if err := q.
- Order("node_models.in_flight ASC, node_models.last_used ASC, backend_nodes.available_vram DESC").
- First(&nm).Error; err != nil {
- return err
+
+ picked := false
+ if pref != nil && pref.PreferredNodeID != "" {
+ // Lock the EXACT (node_id, replica_index) row the caller chose. The
+ // caller (prefix-cache router) has already applied the load guard
+ // per replica, so here we only require that exact replica still be
+ // loaded+healthy. Fall through to the default ORDER BY when that
+ // specific replica is not found/loaded.
+ q := base.Session(&gorm.Session{}).
+ Where("node_models.node_id = ? AND node_models.replica_index = ?", pref.PreferredNodeID, pref.PreferredReplica)
+ if err := q.First(&nm).Error; err == nil {
+ picked = true
+ }
+ }
+ if !picked {
+ if err := base.
+ Order("node_models.in_flight ASC, node_models.last_used ASC, backend_nodes.available_vram DESC").
+ First(&nm).Error; err != nil {
+ return err
+ }
}
if err := tx.Model(&nm).Updates(map[string]any{
@@ -815,6 +966,47 @@ func (r *NodeRegistry) FindAndLockNodeWithModel(ctx context.Context, modelName s
return &node, &nm, nil
}
+// LoadedReplicaStats returns one ReplicaCandidate per loaded+healthy replica of
+// modelName, identified by (NodeID, ReplicaIndex) and carrying its in-flight
+// count. It is a read used by the prefix-cache router to apply the load guard.
+// When candidateNodeIDs is non-empty, only replicas on those nodes are
+// returned; pass nil to consider any healthy node. On success the result is
+// non-nil (empty means no loaded replica); on a query error it is nil.
+func (r *NodeRegistry) LoadedReplicaStats(ctx context.Context, modelName string, candidateNodeIDs []string) ([]ReplicaCandidate, error) {
+ type row struct {
+ NodeID string
+ Address string
+ ReplicaIndex int
+ InFlight int
+ LastUsed time.Time
+ AvailableVRAM uint64
+ }
+ q := r.db.WithContext(ctx).Model(&NodeModel{}).
+ Joins("JOIN backend_nodes ON backend_nodes.id = node_models.node_id").
+ Where("node_models.model_name = ? AND node_models.state = ? AND backend_nodes.status = ?",
+ modelName, "loaded", StatusHealthy)
+ if len(candidateNodeIDs) > 0 {
+ q = q.Where("node_models.node_id IN ?", candidateNodeIDs)
+ }
+
+ // Select only what identifies a replica and its load: node_id, replica_index
+ // and in_flight. replica_index is required because a node can host several
+ // replicas of a model; without it every row would report replica 0. Address,
+ // LastUsed and AvailableVRAM are left at their zero value.
+ var rows []row
+ err := q.Select("node_models.node_id AS node_id, node_models.replica_index AS replica_index, node_models.in_flight AS in_flight").
+ Scan(&rows).Error
+ if err != nil {
+ return nil, fmt.Errorf("loading replica stats for %s: %w", modelName, err)
+ }
+
+ out := make([]ReplicaCandidate, 0, len(rows))
+ for _, rw := range rows {
+ out = append(out, ReplicaCandidate(rw))
+ }
+ return out, nil
+}
+
// TouchNodeModel updates the last_used timestamp for LRU tracking on a single
// replica row.
func (r *NodeRegistry) TouchNodeModel(ctx context.Context, nodeID, modelName string, replicaIndex int) {
@@ -1198,8 +1390,12 @@ func (r *NodeRegistry) SetModelScheduling(ctx context.Context, config *ModelSche
}
return r.db.WithContext(ctx).
Clauses(clause.OnConflict{
- Columns: []clause.Column{{Name: "model_name"}},
- DoUpdates: clause.AssignmentColumns([]string{"node_selector", "min_replicas", "max_replicas", "updated_at"}),
+ Columns: []clause.Column{{Name: "model_name"}},
+ DoUpdates: clause.AssignmentColumns([]string{
+ "node_selector", "min_replicas", "max_replicas",
+ "route_policy", "balance_abs_threshold", "balance_rel_threshold", "min_prefix_match",
+ "updated_at",
+ }),
}).
Create(config).Error
}
diff --git a/core/services/nodes/registry_test.go b/core/services/nodes/registry_test.go
index a07398909..85ae598c9 100644
--- a/core/services/nodes/registry_test.go
+++ b/core/services/nodes/registry_test.go
@@ -245,7 +245,7 @@ var _ = Describe("NodeRegistry", func() {
Expect(registry.Register(context.Background(), node, true)).To(Succeed())
Expect(registry.SetNodeModel(context.Background(), node.ID, "my-model", 0, "loaded", "10.0.0.40:50052", 0)).To(Succeed())
- foundNode, foundNM, err := registry.FindAndLockNodeWithModel(context.Background(), "my-model", nil)
+ foundNode, foundNM, err := registry.FindAndLockNodeWithModel(context.Background(), "my-model", nil, nil)
Expect(err).ToNot(HaveOccurred())
Expect(foundNode.ID).To(Equal(node.ID))
Expect(foundNM.ModelName).To(Equal("my-model"))
@@ -257,7 +257,7 @@ var _ = Describe("NodeRegistry", func() {
})
It("returns error when model is not loaded anywhere", func() {
- _, _, err := registry.FindAndLockNodeWithModel(context.Background(), "nonexistent-model", nil)
+ _, _, err := registry.FindAndLockNodeWithModel(context.Background(), "nonexistent-model", nil, nil)
Expect(err).To(HaveOccurred())
})
@@ -274,7 +274,7 @@ var _ = Describe("NodeRegistry", func() {
Expect(registry.IncrementInFlight(context.Background(), n1.ID, "shared-model", 0)).To(Succeed())
Expect(registry.IncrementInFlight(context.Background(), n1.ID, "shared-model", 0)).To(Succeed())
- foundNode, _, err := registry.FindAndLockNodeWithModel(context.Background(), "shared-model", nil)
+ foundNode, _, err := registry.FindAndLockNodeWithModel(context.Background(), "shared-model", nil, nil)
Expect(err).ToNot(HaveOccurred())
Expect(foundNode.Name).To(Equal("lock-light"))
})
@@ -299,7 +299,7 @@ var _ = Describe("NodeRegistry", func() {
Expect(registry.IncrementInFlight(context.Background(), included.ID, "filtered-model", 0)).To(Succeed())
Expect(registry.IncrementInFlight(context.Background(), included.ID, "filtered-model", 0)).To(Succeed())
- foundNode, foundNM, err := registry.FindAndLockNodeWithModel(context.Background(), "filtered-model", []string{included.ID})
+ foundNode, foundNM, err := registry.FindAndLockNodeWithModel(context.Background(), "filtered-model", []string{included.ID}, nil)
Expect(err).ToNot(HaveOccurred())
Expect(foundNode.ID).To(Equal(included.ID))
Expect(foundNM.NodeID).To(Equal(included.ID))
@@ -326,7 +326,7 @@ var _ = Describe("NodeRegistry", func() {
// (FindAndLockNodeWithModel atomically increments to lock the row.)
picks := make([]string, 0, 9)
for i := 0; i < 9; i++ {
- n, nm, err := registry.FindAndLockNodeWithModel(context.Background(), "rr-model", nil)
+ n, nm, err := registry.FindAndLockNodeWithModel(context.Background(), "rr-model", nil, nil)
Expect(err).ToNot(HaveOccurred())
picks = append(picks, n.Name)
Expect(registry.DecrementInFlight(context.Background(), n.ID, "rr-model", nm.ReplicaIndex)).To(Succeed())
@@ -355,7 +355,7 @@ var _ = Describe("NodeRegistry", func() {
// query must return an error so Route() falls through to schedule
// a fresh load on a matching node instead of reusing the excluded
// replica.
- _, _, err := registry.FindAndLockNodeWithModel(context.Background(), "no-match-model", []string{emptyIncluded.ID})
+ _, _, err := registry.FindAndLockNodeWithModel(context.Background(), "no-match-model", []string{emptyIncluded.ID}, nil)
Expect(err).To(HaveOccurred())
})
@@ -422,7 +422,7 @@ var _ = Describe("NodeRegistry", func() {
goPick := PickBestReplica(candidates)
Expect(goPick).ToNot(BeNil())
- sqlNode, _, err := registry.FindAndLockNodeWithModel(context.Background(), "mirror-model", nil)
+ sqlNode, _, err := registry.FindAndLockNodeWithModel(context.Background(), "mirror-model", nil, nil)
Expect(err).ToNot(HaveOccurred())
Expect(sqlNode.ID).To(Equal(goPick.NodeID),
@@ -433,6 +433,124 @@ var _ = Describe("NodeRegistry", func() {
})
})
+ Describe("FindAndLockNodeWithModel preference", func() {
+ var nodeA, nodeB *BackendNode
+
+ BeforeEach(func() {
+ nodeA = makeNode("pref-a", "10.0.0.70:50051", 8_000_000_000)
+ nodeB = makeNode("pref-b", "10.0.0.71:50051", 8_000_000_000)
+ Expect(registry.Register(context.Background(), nodeA, true)).To(Succeed())
+ Expect(registry.Register(context.Background(), nodeB, true)).To(Succeed())
+ // Both loaded+healthy for model "pref-model", in_flight 0.
+ Expect(registry.SetNodeModel(context.Background(), nodeA.ID, "pref-model", 0, "loaded", "", 0)).To(Succeed())
+ Expect(registry.SetNodeModel(context.Background(), nodeB.ID, "pref-model", 0, "loaded", "", 0)).To(Succeed())
+ })
+
+ It("locks the preferred node when eligible", func() {
+ node, nm, err := registry.FindAndLockNodeWithModel(context.Background(), "pref-model", nil, &RoutePreference{PreferredNodeID: nodeB.ID})
+ Expect(err).ToNot(HaveOccurred())
+ Expect(node.ID).To(Equal(nodeB.ID))
+ Expect(nm.NodeID).To(Equal(nodeB.ID))
+
+ // in_flight is incremented atomically via gorm.Expr, so verify the
+ // persisted value through a re-fetch (the returned struct mirrors
+ // the pre-increment read, like the default-pick path).
+ persisted, err := registry.GetNodeModel(context.Background(), nodeB.ID, "pref-model", 0)
+ Expect(err).ToNot(HaveOccurred())
+ Expect(persisted.InFlight).To(Equal(1))
+ })
+
+ It("falls back to default order when preferred not loaded", func() {
+ node, _, err := registry.FindAndLockNodeWithModel(context.Background(), "pref-model", nil, &RoutePreference{PreferredNodeID: "ZZZ"})
+ Expect(err).ToNot(HaveOccurred())
+ Expect(node.ID).To(BeElementOf(nodeA.ID, nodeB.ID))
+ })
+
+ It("nil preference behaves like before", func() {
+ node, _, err := registry.FindAndLockNodeWithModel(context.Background(), "pref-model", nil, nil)
+ Expect(err).ToNot(HaveOccurred())
+ Expect(node).ToNot(BeNil())
+ })
+
+ It("locks the EXACT preferred replica when the node hosts two replicas", func() {
+ // A single node hosts replica 0 and replica 1 of a model, both
+ // loaded+healthy. The preference must lock the SPECIFIC replica
+ // requested, not the least-loaded replica on the node.
+ node := makeNode("pref-multi", "10.0.0.72:50051", 16_000_000_000)
+ node.MaxReplicasPerModel = 2
+ Expect(registry.Register(context.Background(), node, true)).To(Succeed())
+ Expect(registry.SetNodeModel(context.Background(), node.ID, "multi-model", 0, "loaded", "addr0", 0)).To(Succeed())
+ Expect(registry.SetNodeModel(context.Background(), node.ID, "multi-model", 1, "loaded", "addr1", 0)).To(Succeed())
+
+ // pref={node, 1} must lock replica 1 specifically.
+ gotNode, nm1, err := registry.FindAndLockNodeWithModel(context.Background(), "multi-model", nil,
+ &RoutePreference{PreferredNodeID: node.ID, PreferredReplica: 1})
+ Expect(err).ToNot(HaveOccurred())
+ Expect(gotNode.ID).To(Equal(node.ID))
+ Expect(nm1.ReplicaIndex).To(Equal(1))
+
+ // pref={node, 0} must lock replica 0 specifically.
+ _, nm0, err := registry.FindAndLockNodeWithModel(context.Background(), "multi-model", nil,
+ &RoutePreference{PreferredNodeID: node.ID, PreferredReplica: 0})
+ Expect(err).ToNot(HaveOccurred())
+ Expect(nm0.ReplicaIndex).To(Equal(0))
+ })
+ })
+
+ Describe("LoadedReplicaStats", func() {
+ var n1, n2, n3 *BackendNode
+
+ BeforeEach(func() {
+ n1 = makeNode("stats-1", "10.0.0.80:50051", 8_000_000_000)
+ n2 = makeNode("stats-2", "10.0.0.81:50051", 8_000_000_000)
+ n3 = makeNode("stats-3", "10.0.0.82:50051", 8_000_000_000)
+ Expect(registry.Register(context.Background(), n1, true)).To(Succeed())
+ Expect(registry.Register(context.Background(), n2, true)).To(Succeed())
+ Expect(registry.Register(context.Background(), n3, true)).To(Succeed())
+ // n1 loaded+busy, n2 loaded+idle, n3 has a different model only.
+ Expect(registry.SetNodeModel(context.Background(), n1.ID, "stats-model", 0, "loaded", "10.0.0.80:6000", 0)).To(Succeed())
+ Expect(registry.SetNodeModel(context.Background(), n2.ID, "stats-model", 0, "loaded", "10.0.0.81:6000", 0)).To(Succeed())
+ Expect(registry.SetNodeModel(context.Background(), n3.ID, "other-model", 0, "loaded", "", 0)).To(Succeed())
+ Expect(registry.IncrementInFlight(context.Background(), n1.ID, "stats-model", 0)).To(Succeed())
+ Expect(registry.IncrementInFlight(context.Background(), n1.ID, "stats-model", 0)).To(Succeed())
+ })
+
+ It("returns loaded healthy replicas with in-flight counts", func() {
+ stats, err := registry.LoadedReplicaStats(context.Background(), "stats-model", nil)
+ Expect(err).ToNot(HaveOccurred())
+ Expect(stats).To(HaveLen(2))
+ byNode := map[string]ReplicaCandidate{}
+ for _, s := range stats {
+ byNode[s.NodeID] = s
+ }
+ Expect(byNode).To(HaveKey(n1.ID))
+ Expect(byNode).To(HaveKey(n2.ID))
+ Expect(byNode[n1.ID].InFlight).To(Equal(2))
+ Expect(byNode[n2.ID].InFlight).To(Equal(0))
+ })
+
+ It("filters to the candidate node set when provided", func() {
+ stats, err := registry.LoadedReplicaStats(context.Background(), "stats-model", []string{n2.ID})
+ Expect(err).ToNot(HaveOccurred())
+ Expect(stats).To(HaveLen(1))
+ Expect(stats[0].NodeID).To(Equal(n2.ID))
+ })
+
+ It("excludes unhealthy nodes", func() {
+ Expect(registry.MarkUnhealthy(context.Background(), n1.ID)).To(Succeed())
+ stats, err := registry.LoadedReplicaStats(context.Background(), "stats-model", nil)
+ Expect(err).ToNot(HaveOccurred())
+ Expect(stats).To(HaveLen(1))
+ Expect(stats[0].NodeID).To(Equal(n2.ID))
+ })
+
+ It("returns empty for a model with no loaded replicas", func() {
+ stats, err := registry.LoadedReplicaStats(context.Background(), "no-such-model", nil)
+ Expect(err).ToNot(HaveOccurred())
+ Expect(stats).To(BeEmpty())
+ })
+ })
+
Describe("MarkHealthy and MarkUnhealthy round-trip", func() {
It("transitions healthy -> unhealthy -> healthy", func() {
node := makeNode("roundtrip-node", "10.0.0.60:50051", 8_000_000_000)
@@ -632,6 +750,30 @@ var _ = Describe("NodeRegistry", func() {
Expect(fetched.MaxReplicas).To(Equal(5))
})
+ It("persists and updates route policy and thresholds", func() {
+ err := registry.SetModelScheduling(context.Background(), &ModelSchedulingConfig{
+ ModelName: "prefix-cache-model", RoutePolicy: "prefix_cache",
+ BalanceAbsThreshold: 3, BalanceRelThreshold: 2.0, MinPrefixMatch: 0.4,
+ })
+ Expect(err).ToNot(HaveOccurred())
+
+ got, err := registry.GetModelScheduling(context.Background(), "prefix-cache-model")
+ Expect(err).ToNot(HaveOccurred())
+ Expect(got.RoutePolicy).To(Equal("prefix_cache"))
+ Expect(got.BalanceAbsThreshold).To(Equal(3))
+ Expect(got.BalanceRelThreshold).To(BeNumerically("==", 2.0))
+ Expect(got.MinPrefixMatch).To(BeNumerically("==", 0.4))
+
+ // Update must not be dropped on conflict.
+ Expect(registry.SetModelScheduling(context.Background(), &ModelSchedulingConfig{
+ ModelName: "prefix-cache-model", RoutePolicy: "round_robin",
+ })).ToNot(HaveOccurred())
+
+ got, err = registry.GetModelScheduling(context.Background(), "prefix-cache-model")
+ Expect(err).ToNot(HaveOccurred())
+ Expect(got.RoutePolicy).To(Equal("round_robin"))
+ })
+
It("lists all configs", func() {
Expect(registry.SetModelScheduling(context.Background(), &ModelSchedulingConfig{ModelName: "list-a", MinReplicas: 1})).To(Succeed())
Expect(registry.SetModelScheduling(context.Background(), &ModelSchedulingConfig{ModelName: "list-b", MaxReplicas: 2})).To(Succeed())
@@ -903,6 +1045,187 @@ var _ = Describe("NodeRegistry", func() {
})
})
+ Describe("SetReplicaRemovedHook", func() {
+ type removed struct {
+ model, node string
+ replica int
+ }
+
+ It("fires once with the specific replica after RemoveNodeModel", func() {
+ node := makeNode("hook-remove-one", "10.0.0.230:50051", 8_000_000_000)
+ Expect(registry.Register(context.Background(), node, true)).To(Succeed())
+ Expect(registry.SetNodeModel(context.Background(), node.ID, "hook-model", 1, "loaded", "a", 0)).To(Succeed())
+
+ var fired []removed
+ registry.SetReplicaRemovedHook(func(modelName, nodeID string, replicaIndex int) {
+ fired = append(fired, removed{model: modelName, node: nodeID, replica: replicaIndex})
+ })
+
+ // RemoveNodeModel(replica 1) must fire with the SPECIFIC replica index.
+ Expect(registry.RemoveNodeModel(context.Background(), node.ID, "hook-model", 1)).To(Succeed())
+ Expect(fired).To(HaveLen(1))
+ Expect(fired[0]).To(Equal(removed{model: "hook-model", node: node.ID, replica: 1}))
+ })
+
+ It("fires once with replica<0 after RemoveAllNodeModelReplicas", func() {
+ node := makeNode("hook-remove-all", "10.0.0.231:50051", 16_000_000_000)
+ Expect(registry.Register(context.Background(), node, true)).To(Succeed())
+ Expect(registry.SetNodeModel(context.Background(), node.ID, "hook-all-model", 0, "loaded", "a", 0)).To(Succeed())
+ Expect(registry.SetNodeModel(context.Background(), node.ID, "hook-all-model", 1, "loaded", "b", 0)).To(Succeed())
+
+ var fired []removed
+ registry.SetReplicaRemovedHook(func(modelName, nodeID string, replicaIndex int) {
+ fired = append(fired, removed{model: modelName, node: nodeID, replica: replicaIndex})
+ })
+
+ // One call covers all replicas of that model on the node: a negative
+ // replica index signals "all replicas", and the consumer's
+ // InvalidateNode drops every entry for the (model, node) pair.
+ Expect(registry.RemoveAllNodeModelReplicas(context.Background(), node.ID, "hook-all-model")).To(Succeed())
+ Expect(fired).To(HaveLen(1))
+ Expect(fired[0].model).To(Equal("hook-all-model"))
+ Expect(fired[0].node).To(Equal(node.ID))
+ Expect(fired[0].replica).To(BeNumerically("<", 0))
+ })
+
+ It("does not panic when no hook is set", func() {
+ node := makeNode("hook-unset", "10.0.0.232:50051", 8_000_000_000)
+ Expect(registry.Register(context.Background(), node, true)).To(Succeed())
+ Expect(registry.SetNodeModel(context.Background(), node.ID, "no-hook-model", 0, "loaded", "a", 0)).To(Succeed())
+
+ Expect(func() {
+ Expect(registry.RemoveNodeModel(context.Background(), node.ID, "no-hook-model", 0)).To(Succeed())
+ Expect(registry.RemoveAllNodeModelReplicas(context.Background(), node.ID, "no-hook-model")).To(Succeed())
+ }).ToNot(Panic())
+ })
+
+ // seedTwoModels registers the node and marks three replicas as loaded on
+ // it: model-a replicas 0 and 1, and model-b replica 0. The bulk node-scoped
+ // deletes below remove every replica of every model on the node in one
+ // statement, so the chokepoint must fire the hook once per distinct model
+ // name (the consumer's Invalidate drops all entries for that (model, node) pair).
+ seedTwoModels := func(node *BackendNode) {
+ Expect(registry.Register(context.Background(), node, true)).To(Succeed())
+ Expect(registry.SetNodeModel(context.Background(), node.ID, "model-a", 0, "loaded", "a0", 0)).To(Succeed())
+ Expect(registry.SetNodeModel(context.Background(), node.ID, "model-a", 1, "loaded", "a1", 0)).To(Succeed())
+ Expect(registry.SetNodeModel(context.Background(), node.ID, "model-b", 0, "loaded", "b0", 0)).To(Succeed())
+ }
+
+ It("fires once per distinct model after MarkOffline", func() {
+ node := makeNode("hook-offline", "10.0.0.240:50051", 8_000_000_000)
+ seedTwoModels(node)
+
+ fired := map[removed]int{}
+ registry.SetReplicaRemovedHook(func(modelName, nodeID string, replicaIndex int) {
+ // Bulk node-scoped deletes signal "all replicas" with replica<0.
+ Expect(replicaIndex).To(BeNumerically("<", 0))
+ fired[removed{model: modelName, node: nodeID}]++
+ })
+
+ Expect(registry.MarkOffline(context.Background(), node.ID)).To(Succeed())
+ Expect(fired).To(HaveLen(2))
+ Expect(fired[removed{model: "model-a", node: node.ID}]).To(Equal(1))
+ Expect(fired[removed{model: "model-b", node: node.ID}]).To(Equal(1))
+ })
+
+ It("fires once per distinct model after MarkDraining", func() {
+ node := makeNode("hook-draining", "10.0.0.241:50051", 8_000_000_000)
+ seedTwoModels(node)
+
+ fired := map[removed]int{}
+ registry.SetReplicaRemovedHook(func(modelName, nodeID string, replicaIndex int) {
+ // Bulk node-scoped deletes signal "all replicas" with replica<0.
+ Expect(replicaIndex).To(BeNumerically("<", 0))
+ fired[removed{model: modelName, node: nodeID}]++
+ })
+
+ Expect(registry.MarkDraining(context.Background(), node.ID)).To(Succeed())
+ Expect(fired).To(HaveLen(2))
+ Expect(fired[removed{model: "model-a", node: node.ID}]).To(Equal(1))
+ Expect(fired[removed{model: "model-b", node: node.ID}]).To(Equal(1))
+ })
+
+ It("fires once per distinct model after Deregister", func() {
+ node := makeNode("hook-deregister", "10.0.0.242:50051", 8_000_000_000)
+ seedTwoModels(node)
+
+ fired := map[removed]int{}
+ registry.SetReplicaRemovedHook(func(modelName, nodeID string, replicaIndex int) {
+ // Bulk node-scoped deletes signal "all replicas" with replica<0.
+ Expect(replicaIndex).To(BeNumerically("<", 0))
+ fired[removed{model: modelName, node: nodeID}]++
+ })
+
+ Expect(registry.Deregister(context.Background(), node.ID)).To(Succeed())
+ Expect(fired).To(HaveLen(2))
+ Expect(fired[removed{model: "model-a", node: node.ID}]).To(Equal(1))
+ Expect(fired[removed{model: "model-b", node: node.ID}]).To(Equal(1))
+ })
+
+ It("fires once per distinct model when re-registration clears stale rows", func() {
+ node := makeNode("hook-reregister", "10.0.0.243:50051", 8_000_000_000)
+ seedTwoModels(node)
+
+ fired := map[removed]int{}
+ registry.SetReplicaRemovedHook(func(modelName, nodeID string, replicaIndex int) {
+ // Bulk node-scoped deletes signal "all replicas" with replica<0.
+ Expect(replicaIndex).To(BeNumerically("<", 0))
+ fired[removed{model: modelName, node: nodeID}]++
+ })
+
+ // Re-register the same node (same name): the re-register path
+ // clears the stale model rows, which must fire the hook.
+ again := makeNode("hook-reregister", "10.0.0.243:50052", 8_000_000_000)
+ Expect(registry.Register(context.Background(), again, true)).To(Succeed())
+ Expect(fired).To(HaveLen(2))
+ Expect(fired[removed{model: "model-a", node: node.ID}]).To(Equal(1))
+ Expect(fired[removed{model: "model-b", node: node.ID}]).To(Equal(1))
+ })
+
+ // Atomicity: the bulk node-scoped delete in MarkOffline/MarkDraining/
+ // re-register now captures the model names and deletes the rows inside a
+ // single transaction. A true SetNodeModel-between-capture-and-delete race
+ // can't be forced deterministically here, but we can assert the
+ // post-condition the transaction guarantees: the set of fired hooks
+ // equals exactly the set of node_models rows the operation removed, with
+ // nothing left behind. If the capture and delete ever saw inconsistent
+ // snapshots, either a surviving row (delete missed it) or a missing hook
+ // (capture missed it) would break one of these assertions.
+ It("MarkOffline fires hooks for exactly the rows it deletes (consistent snapshot)", func() {
+ node := makeNode("hook-atomic-offline", "10.0.0.244:50051", 8_000_000_000)
+ seedTwoModels(node)
+
+ // Capture what the transaction should remove, straight from the DB,
+ // before running the operation.
+ before, err := registry.GetNodeModels(context.Background(), node.ID)
+ Expect(err).ToNot(HaveOccurred())
+ expectedModels := map[string]struct{}{}
+ for _, nm := range before {
+ expectedModels[nm.ModelName] = struct{}{}
+ }
+ Expect(expectedModels).To(HaveLen(2), "seed should create two distinct models")
+
+ fired := map[string]struct{}{}
+ registry.SetReplicaRemovedHook(func(modelName, nodeID string, replicaIndex int) {
+ Expect(nodeID).To(Equal(node.ID))
+ Expect(replicaIndex).To(BeNumerically("<", 0))
+ fired[modelName] = struct{}{}
+ })
+
+ Expect(registry.MarkOffline(context.Background(), node.ID)).To(Succeed())
+
+ // Hooks fired for exactly the distinct models that existed.
+ Expect(fired).To(Equal(expectedModels),
+ "hooks must fire for exactly the set of models the transaction deleted")
+
+ // And the delete actually emptied the node_models rows for the node:
+ // no row survives that did not get a hook.
+ after, err := registry.GetNodeModels(context.Background(), node.ID)
+ Expect(err).ToNot(HaveOccurred())
+ Expect(after).To(BeEmpty(), "no node_models row should survive the bulk delete")
+ })
+ })
+
Describe("ApplyAutoLabels", func() {
It("mirrors MaxReplicasPerModel as the node.replica-slots label", func() {
node := makeNode("auto-label-replicas", "10.0.0.220:50051", 16_000_000_000)
diff --git a/core/services/nodes/router.go b/core/services/nodes/router.go
index a9b01e80c..ec783b283 100644
--- a/core/services/nodes/router.go
+++ b/core/services/nodes/router.go
@@ -12,6 +12,8 @@ import (
"time"
"github.com/mudler/LocalAI/core/services/advisorylock"
+ "github.com/mudler/LocalAI/core/services/nodes/prefixcache"
+ "github.com/mudler/LocalAI/pkg/distributedhdr"
grpc "github.com/mudler/LocalAI/pkg/grpc"
pb "github.com/mudler/LocalAI/pkg/grpc/proto"
"github.com/mudler/LocalAI/pkg/vram"
@@ -43,6 +45,22 @@ type SmartRouterOptions struct {
// anti-affinity is disabled at the scheduler layer; the per-node
// watchdog still enforces the rule on arrival.
ConflictResolver ConcurrencyConflictResolver
+ // PrefixProvider, when set, enables prefix-cache-aware routing: requests
+ // carrying a prompt prefix chain (distributedhdr.PrefixChain) are biased
+ // toward the node that already holds the longest matching prefix, subject
+ // to the load guard in prefixcache.Select. nil disables it entirely and
+ // routing is byte-for-byte the round-robin floor. At runtime this is the
+ // *prefixcache.Sync so Observe/Invalidate broadcast to peers.
+ PrefixProvider prefixcache.Provider
+ // PrefixConfig holds the global policy + thresholds. Per-model overrides on
+ // ModelSchedulingConfig refine it per request. Unused when PrefixProvider
+ // is nil.
+ PrefixConfig prefixcache.Config
+ // Pressure, when set, records a forced-disturb each time a request had a
+ // usable hot prefix match but the load guard forced it off the warm node.
+ // The reconciler reads the same instance to autoscale a saturated cache-warm
+ // replica. nil disables recording (the disabled path stays a no-op).
+ Pressure *prefixcache.Pressure
}
// SmartRouter routes inference requests to the best available backend node.
@@ -56,6 +74,14 @@ type SmartRouter struct {
db *gorm.DB // for advisory locks during routing
stagingTracker *StagingTracker // tracks file staging progress for UI visibility
conflictResolver ConcurrencyConflictResolver
+ // prefixProvider is the prefix-cache routing seam (nil disables it; see
+ // SmartRouterOptions.PrefixProvider). prefixConfig holds the global policy
+ // and thresholds.
+ prefixProvider prefixcache.Provider
+ prefixConfig prefixcache.Config
+ // pressure records forced-disturb events (hot match forced off the warm
+ // node by the load guard). nil disables recording. See SmartRouterOptions.
+ pressure *prefixcache.Pressure
// installFlight coalesces concurrent identical NATS install requests
// (same nodeID + backend + modelID + replica) so 6 simultaneous chat
// completions for one not-yet-loaded model produce ONE round-trip, not
@@ -91,6 +117,9 @@ func NewSmartRouter(registry ModelRouter, opts SmartRouterOptions) *SmartRouter
stagingTracker: NewStagingTracker(),
conflictResolver: opts.ConflictResolver,
probeCache: newProbeCache(probeCacheTTL),
+ prefixProvider: opts.PrefixProvider,
+ prefixConfig: opts.PrefixConfig,
+ pressure: opts.Pressure,
}
}
@@ -230,18 +259,31 @@ func (r *SmartRouter) Route(ctx context.Context, modelID, modelName, backendType
trackingKey = modelName
}
+ // Fetch the model's scheduling config once and reuse that snapshot for the
+ // rest of this request: resolveSelectorCandidates, buildPreference, and
+ // nodeMatchesScheduling all read it, so a single fetch keeps them consistent
+ // and avoids repeated lookups of the same row. The lookup error is discarded;
+ // a nil sched means "no scheduling constraints", same as before.
+ sched, _ := r.registry.GetModelScheduling(ctx, trackingKey)
+
// Resolve the model's NodeSelector once so cached-replica lookup and the
// new-load scheduler agree on the candidate set. Without this, a cached
// replica on a node the selector now excludes was picked over a matching
// replica elsewhere, and the fall-through then tried to load on the
// matching node where the model was already at capacity (eviction-busy).
- candidateNodeIDs, err := r.resolveSelectorCandidates(ctx, trackingKey)
+ candidateNodeIDs, err := r.resolveSelectorCandidates(ctx, trackingKey, sched)
if err != nil {
return nil, err
}
+ // Compute the prefix-cache preference once for this request. pref biases
+ // FindAndLockNodeWithModel toward the warm-cache node; observeChain is
+ // non-nil only when this model uses prefix_cache, gating the Observe calls
+ // below. Both are nil (no-op) when prefix-cache routing is disabled.
+ pref, observeChain := r.buildPreference(ctx, trackingKey, candidateNodeIDs, sched)
+
// Step 1: Find and atomically lock a node with this model loaded
- node, nm, err := r.registry.FindAndLockNodeWithModel(ctx, trackingKey, candidateNodeIDs)
+ node, nm, err := r.registry.FindAndLockNodeWithModel(ctx, trackingKey, candidateNodeIDs, pref)
if err == nil && node != nil {
modelAddr := node.Address
if nm.Address != "" {
@@ -258,7 +300,7 @@ func (r *SmartRouter) Route(ctx context.Context, modelID, modelName, backendType
"node", node.Name, "model", modelName, "replica", replicaIdx)
} else {
// Verify node still matches scheduling constraints
- if !r.nodeMatchesScheduling(ctx, node, trackingKey) {
+ if !r.nodeMatchesScheduling(ctx, node, sched) {
r.registry.DecrementInFlight(ctx, node.ID, trackingKey, replicaIdx)
xlog.Info("Cached model on node that no longer matches selector, falling through",
"node", node.Name, "model", trackingKey, "replica", replicaIdx)
@@ -269,6 +311,7 @@ func (r *SmartRouter) Route(ctx context.Context, modelID, modelName, backendType
// onFirstComplete callback releases the reservation after the first inference
// call finishes, so in-flight returns to 0 when idle.
r.registry.TouchNodeModel(ctx, node.ID, trackingKey, replicaIdx)
+ r.observePrefix(trackingKey, observeChain, prefixcache.ReplicaKey{NodeID: node.ID, Replica: replicaIdx})
grpcClient := r.buildClientForAddr(node, modelAddr, parallel)
tracked := NewInFlightTrackingClient(grpcClient, r.registry, node.ID, trackingKey, replicaIdx)
tracked.OnFirstComplete(func() {
@@ -288,7 +331,7 @@ func (r *SmartRouter) Route(ctx context.Context, modelID, modelName, backendType
// Step 2: Model not loaded — schedule loading with distributed lock to prevent duplicates
loadModel := func() (*RouteResult, error) {
// Re-check after acquiring lock — another request may have loaded it
- node, nm, err := r.registry.FindAndLockNodeWithModel(ctx, trackingKey, candidateNodeIDs)
+ node, nm, err := r.registry.FindAndLockNodeWithModel(ctx, trackingKey, candidateNodeIDs, pref)
if err == nil && node != nil {
modelAddr := node.Address
if nm.Address != "" {
@@ -305,7 +348,7 @@ func (r *SmartRouter) Route(ctx context.Context, modelID, modelName, backendType
"node", node.Name, "model", modelName, "replica", replicaIdx)
} else {
// Verify node still matches scheduling constraints
- if !r.nodeMatchesScheduling(ctx, node, trackingKey) {
+ if !r.nodeMatchesScheduling(ctx, node, sched) {
r.registry.DecrementInFlight(ctx, node.ID, trackingKey, replicaIdx)
xlog.Info("Cached model on node that no longer matches selector, falling through",
"node", node.Name, "model", trackingKey, "replica", replicaIdx)
@@ -314,6 +357,7 @@ func (r *SmartRouter) Route(ctx context.Context, modelID, modelName, backendType
// Model loaded while we waited — FindAndLockNodeWithModel already incremented
// in-flight as a reservation. Release it after the first inference completes.
r.registry.TouchNodeModel(ctx, node.ID, trackingKey, replicaIdx)
+ r.observePrefix(trackingKey, observeChain, prefixcache.ReplicaKey{NodeID: node.ID, Replica: replicaIdx})
grpcClient := r.buildClientForAddr(node, modelAddr, parallel)
tracked := NewInFlightTrackingClient(grpcClient, r.registry, node.ID, trackingKey, replicaIdx)
tracked.OnFirstComplete(func() {
@@ -337,6 +381,10 @@ func (r *SmartRouter) Route(ctx context.Context, modelID, modelName, backendType
return nil, err
}
+ // Cold load landed on result.Node replica result.ReplicaIndex: record the
+ // assignment so subsequent requests with the same prefix prefer it.
+ r.observePrefix(trackingKey, observeChain, prefixcache.ReplicaKey{NodeID: result.Node.ID, Replica: result.ReplicaIndex})
+
replicaIdx := result.ReplicaIndex
tracked := NewInFlightTrackingClient(result.Client, r.registry, result.Node.ID, trackingKey, replicaIdx)
tracked.OnFirstComplete(func() {
@@ -389,13 +437,117 @@ func extractNodeIDs(nodes []BackendNode) []string {
return ids
}
+// buildPreference computes the per-request route preference from the prefix
+// chain on ctx and the model's resolved policy; a non-nil sched can override
+// the global policy and thresholds. The returned chain is non-nil only when
+// the resolved policy is prefix_cache, signalling Route to record the
+// assignment after a pick. The *RoutePreference is non-nil only when
+// prefixcache.Select reports a choice; if replica stats fail to load or come
+// back empty it is nil but the chain is still returned.
+//
+// With a nil provider, no chain on ctx, or a policy other than prefix_cache,
+// both returns are nil and routing is the unchanged round-robin floor.
+func (r *SmartRouter) buildPreference(ctx context.Context, modelID string, candidateNodeIDs []string, sched *ModelSchedulingConfig) (*RoutePreference, []uint64) {
+ if r.prefixProvider == nil {
+ return nil, nil
+ }
+ chain := distributedhdr.PrefixChain(ctx)
+ if len(chain) == 0 {
+ return nil, nil
+ }
+
+ // Resolve per-model policy + thresholds over the global config.
+ policy := r.prefixConfig.GlobalPolicy
+ cfg := r.prefixConfig
+ if sched != nil {
+ policy = prefixcache.ParsePolicy(sched.RoutePolicy).Resolve(r.prefixConfig.GlobalPolicy)
+ if sched.BalanceAbsThreshold > 0 {
+ cfg.BalanceAbsThreshold = sched.BalanceAbsThreshold
+ }
+ if sched.BalanceRelThreshold > 0 {
+ cfg.BalanceRelThreshold = sched.BalanceRelThreshold
+ }
+ if sched.MinPrefixMatch > 0 {
+ cfg.MinPrefixMatch = sched.MinPrefixMatch
+ }
+ }
+ if policy != prefixcache.RoutePolicyPrefixCache {
+ return nil, nil
+ }
+
+ // Load the candidate replicas PER REPLICA. Affinity is tracked per replica
+ // (each replica is a separate process with its own KV cache), so two
+ // replicas of the same model on the same node are two distinct candidates.
+ // FindAndLockNodeWithModel then locks the EXACT (node, replica) the policy
+ // chose.
+ stats, err := r.registry.LoadedReplicaStats(ctx, modelID, candidateNodeIDs)
+ if err != nil {
+ xlog.Debug("prefixcache: loading replica stats failed, skipping preference", "model", modelID, "error", err)
+ return nil, chain
+ }
+ if len(stats) == 0 {
+ return nil, chain
+ }
+ cands := make([]prefixcache.Candidate, 0, len(stats))
+ keys := make([]prefixcache.ReplicaKey, 0, len(stats))
+ for _, s := range stats {
+ key := prefixcache.ReplicaKey{NodeID: s.NodeID, Replica: s.ReplicaIndex}
+ cands = append(cands, prefixcache.Candidate{Key: key, InFlight: s.InFlight})
+ keys = append(keys, key)
+ }
+
+ d := r.prefixProvider.Decide(modelID, chain, keys, time.Now())
+ chosen, ok := prefixcache.Select(cands, d, cfg)
+
+ // Observability for the prefix-cache routing decision. One line per request
+ // at Debug: enable with DEBUG=true on the frontend to assess cache-aware
+ // routing. hotMatchHonored=true means we routed to the cache-warm replica;
+ // false with HasHot means the load guard forced a cold pick.
+ xlog.Debug("prefix-cache routing decision",
+ "model", modelID,
+ "chainDepth", len(chain),
+ "candidates", len(cands),
+ "hotNode", d.Hot.NodeID,
+ "hotReplica", d.Hot.Replica,
+ "hasHot", d.HasHot,
+ "matchRatio", d.MatchRatio,
+ "minMatch", cfg.MinPrefixMatch,
+ "chosen", fmt.Sprintf("%s/%d", chosen.NodeID, chosen.Replica),
+ "hotMatchHonored", d.HasHot && chosen == d.Hot)
+
+ // Forced-disturb: a usable hot prefix match existed but the load guard
+ // forced us off the warm replica (Select picked a different replica). This
+ // is the scale-worthy signal - the cache-warm replica is saturated. It
+ // deliberately does not fire for all-unique workloads (no hot match),
+ // avoiding false-positive scale-ups. nil pressure is a no-op.
+ if r.pressure != nil && d.HasHot && d.MatchRatio >= cfg.MinPrefixMatch && chosen != d.Hot {
+ r.pressure.Record(modelID, time.Now())
+ }
+
+ if !ok {
+ return nil, chain
+ }
+ return &RoutePreference{PreferredNodeID: chosen.NodeID, PreferredReplica: chosen.Replica}, chain
+}
+
+// observePrefix tells the prefix provider that replica key was assigned the
+// request whose prompt-prefix chain is chain, stamped with time.Now(), and
+// logs the assignment at Debug. It does nothing when the provider is nil or
+// the chain is empty, which is what Route passes for non-prefix_cache models.
+func (r *SmartRouter) observePrefix(modelID string, chain []uint64, key prefixcache.ReplicaKey) {
+ if r.prefixProvider == nil || len(chain) == 0 {
+ return
+ }
+ r.prefixProvider.Observe(modelID, chain, key, time.Now())
+ xlog.Debug("prefix-cache observed assignment", "model", modelID, "node", key.NodeID, "replica", key.Replica, "chainDepth", len(chain))
+}
+
// resolveSelectorCandidates returns the node IDs that match the model's
// NodeSelector. Returns nil when no selector is configured ("any healthy node"
// — registry helpers treat nil as no filter). Returns an error when a
// non-empty selector matches zero healthy nodes, since there is nothing to
// route or schedule on.
-func (r *SmartRouter) resolveSelectorCandidates(ctx context.Context, modelID string) ([]string, error) {
- sched, _ := r.registry.GetModelScheduling(ctx, modelID)
+func (r *SmartRouter) resolveSelectorCandidates(ctx context.Context, modelID string, sched *ModelSchedulingConfig) ([]string, error) {
if sched == nil || sched.NodeSelector == "" {
return nil, nil
}
@@ -469,9 +621,8 @@ func (r *SmartRouter) narrowByGroupAntiAffinity(ctx context.Context, modelID str
// nodeMatchesScheduling checks if a node satisfies the scheduling constraints for a model.
// Returns true if no constraints exist or the node matches all selector labels.
-func (r *SmartRouter) nodeMatchesScheduling(ctx context.Context, node *BackendNode, modelName string) bool {
- sched, err := r.registry.GetModelScheduling(ctx, modelName)
- if err != nil || sched == nil || sched.NodeSelector == "" {
+func (r *SmartRouter) nodeMatchesScheduling(ctx context.Context, node *BackendNode, sched *ModelSchedulingConfig) bool {
+ if sched == nil || sched.NodeSelector == "" {
return true // no constraints
}
@@ -518,7 +669,8 @@ func (r *SmartRouter) scheduleNewModel(ctx context.Context, backendType, modelID
// Check for scheduling constraints (node selector). If a selector is set,
// we restrict the candidate pool to matching nodes; otherwise nil means
// "any healthy node".
- candidateNodeIDs, err := r.resolveSelectorCandidates(ctx, modelID)
+ sched, _ := r.registry.GetModelScheduling(ctx, modelID)
+ candidateNodeIDs, err := r.resolveSelectorCandidates(ctx, modelID, sched)
if err != nil {
return nil, "", 0, err
}
diff --git a/core/services/nodes/router_test.go b/core/services/nodes/router_test.go
index 26237773e..8c6be3269 100644
--- a/core/services/nodes/router_test.go
+++ b/core/services/nodes/router_test.go
@@ -12,7 +12,9 @@ import (
. "github.com/onsi/gomega"
"github.com/mudler/LocalAI/core/services/messaging"
+ "github.com/mudler/LocalAI/core/services/nodes/prefixcache"
"github.com/mudler/LocalAI/core/services/testutil"
+ "github.com/mudler/LocalAI/pkg/distributedhdr"
grpc "github.com/mudler/LocalAI/pkg/grpc"
pb "github.com/mudler/LocalAI/pkg/grpc/proto"
ggrpc "google.golang.org/grpc"
@@ -111,18 +113,34 @@ type fakeModelRouter struct {
findNodesWithModelByName map[string][]BackendNode
findNodesWithModelErr error
+ // LoadedReplicaStats returns (keyed by model name)
+ loadedReplicaStatsByName map[string][]ReplicaCandidate
+ loadedReplicaStatsErr error
+
// Track calls for assertions
decrementCalls []string // "nodeID:modelName"
incrementCalls []string
removeCalls []string
setCalls []string
touchCalls []string
+
+ // Preferences passed to FindAndLockNodeWithModel, in call order. nil
+ // entries are recorded too, so tests can assert "preference was nil".
+ findAndLockPrefs []*RoutePreference
}
-func (f *fakeModelRouter) FindAndLockNodeWithModel(_ context.Context, modelName string, _ []string) (*BackendNode, *NodeModel, error) {
+func (f *fakeModelRouter) FindAndLockNodeWithModel(_ context.Context, modelName string, _ []string, pref *RoutePreference) (*BackendNode, *NodeModel, error) {
+ f.findAndLockPrefs = append(f.findAndLockPrefs, pref)
return f.findAndLockNode, f.findAndLockNM, f.findAndLockErr
}
+func (f *fakeModelRouter) LoadedReplicaStats(_ context.Context, modelName string, _ []string) ([]ReplicaCandidate, error) {
+	if err := f.loadedReplicaStatsErr; err != nil {
+		return nil, err // configured failure wins over any canned stats
+	}
+	return f.loadedReplicaStatsByName[modelName], nil // nil slice when no stats were configured for modelName
+}
+
func (f *fakeModelRouter) DecrementInFlight(_ context.Context, nodeID, modelName string, _ int) error {
f.decrementCalls = append(f.decrementCalls, nodeID+":"+modelName)
return nil
@@ -1055,3 +1073,355 @@ var _ = Describe("SmartRouter", func() {
})
})
})
+
+// ---------------------------------------------------------------------------
+// Fake prefixcache.Provider for SmartRouter prefix-cache routing tests
+// ---------------------------------------------------------------------------
+
+type observeRecord struct { // one recorded fakePrefixProvider.Observe call
+ model string
+ chain []uint64 // copy of the chain passed to Observe
+ key prefixcache.ReplicaKey
+}
+
+type invalidateRecord struct { // one recorded fakePrefixProvider.Invalidate call
+ model string
+ key prefixcache.ReplicaKey
+}
+
+// fakePrefixProvider records all interactions and returns a configurable
+// decision.
+type fakePrefixProvider struct {
+ decideCalls int // number of Decide calls made so far
+ observed []observeRecord // one record per Observe call, in call order
+ invalidated []invalidateRecord // one record per Invalidate call, in call order
+ invalidatedNode []string // one "model:nodeID" string per InvalidateNode call
+ decision prefixcache.PrefixDecision // returned verbatim by every Decide call
+}
+
+func (f *fakePrefixProvider) Decide(_ string, _ []uint64, _ []prefixcache.ReplicaKey, _ time.Time) prefixcache.PrefixDecision {
+ f.decideCalls++ // counted so specs can assert whether the provider was consulted at all
+ return f.decision // canned decision; every argument is ignored
+}
+
+func (f *fakePrefixProvider) Observe(model string, chain []uint64, key prefixcache.ReplicaKey, _ time.Time) bool {
+ f.observed = append(f.observed, observeRecord{model: model, chain: append([]uint64(nil), chain...), key: key}) // chain is copied so a later mutation by the caller cannot alter the record
+ return true // always true; NOTE(review): the bool's meaning is defined by prefixcache.Provider, not visible here
+}
+
+func (f *fakePrefixProvider) Invalidate(model string, key prefixcache.ReplicaKey) {
+ f.invalidated = append(f.invalidated, invalidateRecord{model: model, key: key}) // record only; nothing is actually invalidated
+}
+
+func (f *fakePrefixProvider) InvalidateNode(model, nodeID string) {
+ f.invalidatedNode = append(f.invalidatedNode, model+":"+nodeID) // recorded as "model:nodeID"
+}
+
+func (f *fakePrefixProvider) Evict(_ time.Time) {} // deliberate no-op: the fake keeps no expiring state
+
+var _ = Describe("SmartRouter prefix-cache routing", func() {
+ var (
+ backend *stubBackend
+ factory *stubClientFactory
+ unloader *fakeUnloader
+ )
+
+ BeforeEach(func() {
+ backend = &stubBackend{healthResult: true}
+ factory = &stubClientFactory{client: backend}
+ unloader = &fakeUnloader{
+ installReply: &messaging.BackendInstallReply{Success: true, Address: "10.0.0.1:9001"},
+ }
+ })
+
+ // loadedReg builds a fake registry with one loaded healthy replica for
+ // "m" on node "X", plus matching replica stats so buildPreference can run.
+ loadedReg := func() *fakeModelRouter {
+ node := &BackendNode{ID: "X", Name: "node-x", Address: "10.0.0.1:50051"}
+ nm := &NodeModel{NodeID: "X", ModelName: "m", Address: "10.0.0.1:9001"}
+ return &fakeModelRouter{
+ findAndLockNode: node,
+ findAndLockNM: nm,
+ getModelScheduling: &ModelSchedulingConfig{
+ RoutePolicy: "prefix_cache",
+ },
+ loadedReplicaStatsByName: map[string][]ReplicaCandidate{
+ "m": {{NodeID: "X", InFlight: 0}},
+ },
+ }
+ }
+
+ Context("nil provider (round-robin floor)", func() {
+ It("passes a nil preference and never decides or observes", func() {
+ reg := loadedReg()
+ router := NewSmartRouter(reg, SmartRouterOptions{Unloader: unloader, ClientFactory: factory})
+
+ ctx := distributedhdr.WithPrefixChain(context.Background(), []uint64{1, 2, 3})
+ _, err := router.Route(ctx, "m", "models/m.gguf", "llama-cpp", nil, false)
+ Expect(err).ToNot(HaveOccurred())
+
+ Expect(reg.findAndLockPrefs).ToNot(BeEmpty())
+ for _, p := range reg.findAndLockPrefs {
+ Expect(p).To(BeNil())
+ }
+ })
+ })
+
+ Context("with a provider", func() {
+ It("passes the decided node as the preference and observes the pick", func() {
+ reg := loadedReg()
+ prov := &fakePrefixProvider{decision: prefixcache.PrefixDecision{Hot: prefixcache.ReplicaKey{NodeID: "X"}, HasHot: true, MatchRatio: 1.0}}
+ router := NewSmartRouter(reg, SmartRouterOptions{
+ Unloader: unloader,
+ ClientFactory: factory,
+ PrefixProvider: prov,
+ PrefixConfig: prefixcache.DefaultConfig(),
+ })
+
+ ctx := distributedhdr.WithPrefixChain(context.Background(), []uint64{1, 2, 3})
+ _, err := router.Route(ctx, "m", "models/m.gguf", "llama-cpp", nil, false)
+ Expect(err).ToNot(HaveOccurred())
+
+ Expect(prov.decideCalls).To(BeNumerically(">=", 1))
+ Expect(reg.findAndLockPrefs[0]).ToNot(BeNil())
+ Expect(reg.findAndLockPrefs[0].PreferredNodeID).To(Equal("X"))
+ Expect(reg.findAndLockPrefs[0].PreferredReplica).To(Equal(0))
+ Expect(prov.observed).To(HaveLen(1))
+ Expect(prov.observed[0].key).To(Equal(prefixcache.ReplicaKey{NodeID: "X", Replica: 0}))
+ Expect(prov.observed[0].chain).To(Equal([]uint64{1, 2, 3}))
+ })
+
+ It("routes a recurring prefix back to the previously observed node", func() {
+ // Real Index as the provider: first request observes X, second
+ // request with the same chain must yield PreferredNodeID == X.
+ idx := prefixcache.NewIndex(prefixcache.DefaultConfig())
+ reg := loadedReg()
+ router := NewSmartRouter(reg, SmartRouterOptions{
+ Unloader: unloader,
+ ClientFactory: factory,
+ PrefixProvider: idx,
+ PrefixConfig: prefixcache.DefaultConfig(),
+ })
+
+ ctx := distributedhdr.WithPrefixChain(context.Background(), []uint64{7, 8, 9})
+ _, err := router.Route(ctx, "m", "models/m.gguf", "llama-cpp", nil, false)
+ Expect(err).ToNot(HaveOccurred())
+ // First request landed on X (cold placement on the only candidate)
+ // and observed the prefix there.
+ dFirst := idx.Decide("m", []uint64{7, 8, 9}, []prefixcache.ReplicaKey{{NodeID: "X", Replica: 0}}, time.Now())
+ Expect(dFirst.HasHot).To(BeTrue())
+ Expect(dFirst.Hot).To(Equal(prefixcache.ReplicaKey{NodeID: "X", Replica: 0}))
+
+ // Second request, same chain: X is now the warm-cache hot match, so
+ // the preference must point at it.
+ _, err = router.Route(ctx, "m", "models/m.gguf", "llama-cpp", nil, false)
+ Expect(err).ToNot(HaveOccurred())
+ last := reg.findAndLockPrefs[len(reg.findAndLockPrefs)-1]
+ Expect(last).ToNot(BeNil())
+ Expect(last.PreferredNodeID).To(Equal("X"))
+ Expect(last.PreferredReplica).To(Equal(0))
+ })
+
+ It("prefers the exact hot replica when two replicas share a node", func() {
+ // Two replicas of "m" live on the SAME node X: replica 0 and replica
+ // 1. A hot prefix observed on (X,0) must produce a preference that
+ // locks replica 0 specifically, NOT the sibling replica 1 on the same
+ // node. This is the replica-granular regression this change fixes.
+ idx := prefixcache.NewIndex(prefixcache.DefaultConfig())
+ node := &BackendNode{ID: "X", Name: "node-x", Address: "10.0.0.1:50051"}
+ nm := &NodeModel{NodeID: "X", ModelName: "m", ReplicaIndex: 0, Address: "10.0.0.1:9001"}
+ reg := &fakeModelRouter{
+ findAndLockNode: node,
+ findAndLockNM: nm,
+ getModelScheduling: &ModelSchedulingConfig{
+ RoutePolicy: "prefix_cache",
+ },
+ loadedReplicaStatsByName: map[string][]ReplicaCandidate{
+ "m": {
+ {NodeID: "X", ReplicaIndex: 0, InFlight: 0},
+ {NodeID: "X", ReplicaIndex: 1, InFlight: 0},
+ },
+ },
+ }
+ // Seed the index so (X,0) is the warm replica for this chain.
+ idx.Observe("m", []uint64{1, 2, 3}, prefixcache.ReplicaKey{NodeID: "X", Replica: 0}, time.Now())
+
+ router := NewSmartRouter(reg, SmartRouterOptions{
+ Unloader: unloader,
+ ClientFactory: factory,
+ PrefixProvider: idx,
+ PrefixConfig: prefixcache.DefaultConfig(),
+ })
+
+ ctx := distributedhdr.WithPrefixChain(context.Background(), []uint64{1, 2, 3})
+ _, err := router.Route(ctx, "m", "models/m.gguf", "llama-cpp", nil, false)
+ Expect(err).ToNot(HaveOccurred())
+
+ pref := reg.findAndLockPrefs[0]
+ Expect(pref).ToNot(BeNil())
+ Expect(pref.PreferredNodeID).To(Equal("X"))
+ Expect(pref.PreferredReplica).To(Equal(0),
+ "the hot prefix lives on replica 0; the same-node sibling replica 1 must NOT be chosen")
+ })
+
+ It("does not decide or observe when no prefix chain is present", func() {
+ reg := loadedReg()
+ prov := &fakePrefixProvider{decision: prefixcache.PrefixDecision{Hot: prefixcache.ReplicaKey{NodeID: "X"}, HasHot: true, MatchRatio: 1.0}}
+ router := NewSmartRouter(reg, SmartRouterOptions{
+ Unloader: unloader,
+ ClientFactory: factory,
+ PrefixProvider: prov,
+ PrefixConfig: prefixcache.DefaultConfig(),
+ })
+
+ _, err := router.Route(context.Background(), "m", "models/m.gguf", "llama-cpp", nil, false)
+ Expect(err).ToNot(HaveOccurred())
+ Expect(prov.decideCalls).To(Equal(0))
+ Expect(prov.observed).To(BeEmpty())
+ Expect(reg.findAndLockPrefs[0]).To(BeNil())
+ })
+
+ It("does not observe for round-robin models even with a chain", func() {
+ reg := loadedReg()
+ reg.getModelScheduling = &ModelSchedulingConfig{RoutePolicy: "round_robin"}
+ prov := &fakePrefixProvider{decision: prefixcache.PrefixDecision{Hot: prefixcache.ReplicaKey{NodeID: "X"}, HasHot: true, MatchRatio: 1.0}}
+ router := NewSmartRouter(reg, SmartRouterOptions{
+ Unloader: unloader,
+ ClientFactory: factory,
+ PrefixProvider: prov,
+ PrefixConfig: prefixcache.DefaultConfig(),
+ })
+
+ ctx := distributedhdr.WithPrefixChain(context.Background(), []uint64{1, 2, 3})
+ _, err := router.Route(ctx, "m", "models/m.gguf", "llama-cpp", nil, false)
+ Expect(err).ToNot(HaveOccurred())
+ Expect(prov.decideCalls).To(Equal(0))
+ Expect(prov.observed).To(BeEmpty())
+ Expect(reg.findAndLockPrefs[0]).To(BeNil())
+ })
+ })
+
+ Context("forced-disturb pressure", func() {
+ // disturbReg builds a registry with two candidate replicas for "m":
+ // the hot node X is saturated (high in_flight) and Y is free. Select
+ // will therefore reject the hot node and pick Y, which is the
+ // forced-disturb signal. findAndLockNode returns Y so Route succeeds.
+ disturbReg := func() *fakeModelRouter {
+ nodeY := &BackendNode{ID: "Y", Name: "node-y", Address: "10.0.0.2:50051"}
+ nm := &NodeModel{NodeID: "Y", ModelName: "m", Address: "10.0.0.2:9001"}
+ return &fakeModelRouter{
+ findAndLockNode: nodeY,
+ findAndLockNM: nm,
+ getModelScheduling: &ModelSchedulingConfig{
+ RoutePolicy: "prefix_cache",
+ },
+ loadedReplicaStatsByName: map[string][]ReplicaCandidate{
+ "m": {{NodeID: "X", InFlight: 50}, {NodeID: "Y", InFlight: 0}},
+ },
+ }
+ }
+
+ It("records pressure when a strong hot match was forced off the warm node", func() {
+ reg := disturbReg()
+ prov := &fakePrefixProvider{decision: prefixcache.PrefixDecision{
+ Hot: prefixcache.ReplicaKey{NodeID: "X"},
+ HasHot: true,
+ MatchRatio: 1.0,
+ ColdOrder: []prefixcache.ReplicaKey{{NodeID: "Y"}, {NodeID: "X"}},
+ }}
+ pressure := prefixcache.NewPressure(time.Minute)
+ router := NewSmartRouter(reg, SmartRouterOptions{
+ Unloader: unloader,
+ ClientFactory: factory,
+ PrefixProvider: prov,
+ PrefixConfig: prefixcache.DefaultConfig(),
+ Pressure: pressure,
+ })
+
+ ctx := distributedhdr.WithPrefixChain(context.Background(), []uint64{1, 2, 3})
+ _, err := router.Route(ctx, "m", "models/m.gguf", "llama-cpp", nil, false)
+ Expect(err).ToNot(HaveOccurred())
+
+ Expect(pressure.Count("m", time.Now())).To(BeNumerically(">", 0),
+ "hot match existed but the load guard forced us off X: must record pressure")
+ })
+
+ It("does not record pressure when the hot node is itself eligible", func() {
+ reg := loadedReg() // single node X, in_flight 0 → X stays eligible
+ prov := &fakePrefixProvider{decision: prefixcache.PrefixDecision{
+ Hot: prefixcache.ReplicaKey{NodeID: "X"},
+ HasHot: true,
+ MatchRatio: 1.0,
+ }}
+ pressure := prefixcache.NewPressure(time.Minute)
+ router := NewSmartRouter(reg, SmartRouterOptions{
+ Unloader: unloader,
+ ClientFactory: factory,
+ PrefixProvider: prov,
+ PrefixConfig: prefixcache.DefaultConfig(),
+ Pressure: pressure,
+ })
+
+ ctx := distributedhdr.WithPrefixChain(context.Background(), []uint64{1, 2, 3})
+ _, err := router.Route(ctx, "m", "models/m.gguf", "llama-cpp", nil, false)
+ Expect(err).ToNot(HaveOccurred())
+
+ Expect(pressure.Count("m", time.Now())).To(Equal(0),
+ "chosen == hot node, no disturb")
+ })
+
+ It("does not record pressure for an all-unique workload with no hot match", func() {
+ reg := loadedReg()
+ prov := &fakePrefixProvider{decision: prefixcache.PrefixDecision{
+ HasHot: false, // no prefix match at all
+ MatchRatio: 0,
+ ColdOrder: []prefixcache.ReplicaKey{{NodeID: "X"}},
+ }}
+ pressure := prefixcache.NewPressure(time.Minute)
+ router := NewSmartRouter(reg, SmartRouterOptions{
+ Unloader: unloader,
+ ClientFactory: factory,
+ PrefixProvider: prov,
+ PrefixConfig: prefixcache.DefaultConfig(),
+ Pressure: pressure,
+ })
+
+ ctx := distributedhdr.WithPrefixChain(context.Background(), []uint64{1, 2, 3})
+ _, err := router.Route(ctx, "m", "models/m.gguf", "llama-cpp", nil, false)
+ Expect(err).ToNot(HaveOccurred())
+
+ Expect(pressure.Count("m", time.Now())).To(Equal(0),
+ "no hot match means no cache to disturb: must not false-positive")
+ })
+ })
+
+ Context("removal chokepoint on unload", func() {
+ It("removes the replica via the registry so the removal hook invalidates the prefix entry", func() {
+ idx := prefixcache.NewIndex(prefixcache.DefaultConfig())
+ reg := loadedReg()
+ router := NewSmartRouter(reg, SmartRouterOptions{
+ Unloader: unloader,
+ ClientFactory: factory,
+ PrefixProvider: idx,
+ PrefixConfig: prefixcache.DefaultConfig(),
+ })
+
+ ctx := distributedhdr.WithPrefixChain(context.Background(), []uint64{5, 6})
+ // Warm the cache: X now holds the prefix.
+ _, err := router.Route(ctx, "m", "models/m.gguf", "llama-cpp", nil, false)
+ Expect(err).ToNot(HaveOccurred())
+ Expect(idx.Decide("m", []uint64{5, 6}, []prefixcache.ReplicaKey{{NodeID: "X", Replica: 0}}, time.Now()).Hot).To(Equal(prefixcache.ReplicaKey{NodeID: "X", Replica: 0}))
+
+ // UnloadModel must route the eviction through the registry removal
+ // chokepoint (RemoveAllNodeModelReplicas). The registry's
+ // SetReplicaRemovedHook is what invalidates the prefix index in
+ // production; the router no longer invalidates directly. Here the
+ // fake registry records the removal but fires no hook, so we assert
+ // the chokepoint is exercised rather than the downstream
+ // invalidation (covered by the registry hook integration tests).
+ Expect(router.UnloadModel(context.Background(), "X", "m")).To(Succeed())
+ Expect(reg.removeCalls).To(ContainElement("X:m"),
+ "UnloadModel must remove the replica via the registry removal chokepoint")
+ })
+ })
+})
diff --git a/docs/content/features/distributed-mode.md b/docs/content/features/distributed-mode.md
index f9b09079a..c9b48d835 100644
--- a/docs/content/features/distributed-mode.md
+++ b/docs/content/features/distributed-mode.md
@@ -558,3 +558,17 @@ All fields are optional and composable:
- Ensure the port range is not blocked by firewalls or used by other services
- Verify the backend gallery configuration is correct
- The worker needs network access to download backends from the gallery
+
+## Roadmap: Routing and Caching Enhancements
+
+The scheduling algorithm above is load-based (least in-flight, then least-recently-used). Work is underway to make routing **prefix-cache-aware**: bias each request toward the replica that already holds the relevant KV/prefix cache (multi-turn conversations and shared system prompts), so backends reuse cache instead of recomputing it. The first step is a router-side radix tree of prompt-prefix hashes mapped to nodes, with longest-prefix match, a load guard that preserves round-robin behavior under imbalance, and NATS sync across frontends. It is purely a routing-layer hint (no backend changes) and never routes worse than today's round-robin.
+
+Further enhancements, surfaced from a survey of SGLang, vLLM production-stack, Ray Serve, llm-d, AIBrix, and NVIDIA Dynamo, are tracked under the routing roadmap epic ([#10063](https://github.com/mudler/LocalAI/issues/10063)):
+
+- **Reported/precise KV-event mode** ([#10064](https://github.com/mudler/LocalAI/issues/10064)): subscribe to actual backend KV-cache events for exact residency instead of inferring it from routing history.
+- **Multi-tier cache-overlap scoring** ([#10065](https://github.com/mudler/LocalAI/issues/10065)): credit GPU/CPU/disk cache tiers separately.
+- **Pluggable scorer/filter/picker pipeline** ([#10066](https://github.com/mudler/LocalAI/issues/10066)): composable multi-signal routing (cache, queue depth, KV utilization, latency).
+- **Load-shaping** ([#10067](https://github.com/mudler/LocalAI/issues/10067)): anti-herding (softmax/temperature) and dispatch-time freshness.
+- **Prefill/decode disaggregation routing** ([#10068](https://github.com/mudler/LocalAI/issues/10068)): route prefill and decode to separate pools with KV transfer.
+- **Per-user fairness (VTC)** ([#10069](https://github.com/mudler/LocalAI/issues/10069)): balance per-user token usage against pod load.
+- **Minor tuning + MCP parity** ([#10070](https://github.com/mudler/LocalAI/issues/10070)): per-model TTL override, probabilistic LRU updates, and MCP scheduling-config tool parity.
diff --git a/pkg/distributedhdr/prefixhash.go b/pkg/distributedhdr/prefixhash.go
new file mode 100644
index 000000000..a4124098f
--- /dev/null
+++ b/pkg/distributedhdr/prefixhash.go
@@ -0,0 +1,37 @@
+package distributedhdr
+
+import "context"
+
+type prefixChainKey struct{} // unexported context-key type, so no other package can construct an equal key
+
+// WithPrefixChain derives a context carrying the prompt prefix-hash chain for the
+// distributed router; the slice is stored as given (no copy). See PrefixChain.
+func WithPrefixChain(ctx context.Context, chain []uint64) context.Context {
+	withChain := context.WithValue(ctx, prefixChainKey{}, chain)
+	return withChain
+}
+
+// PrefixChain returns the chain attached by WithPrefixChain. It returns nil
+// when ctx carries no chain: the failed type assertion leaves the zero value,
+// which for a slice is nil.
+func PrefixChain(ctx context.Context) []uint64 {
+	chain, _ := ctx.Value(prefixChainKey{}).([]uint64)
+	return chain
+}
+
+// PrefixChainHook, when set at startup (distributed mode only), builds a prefix
+// hash chain from a model id and rendered prompt; nil in single-process mode
+// (zero overhead). Plain unsynchronized variable. See core/application/distributed.go.
+var PrefixChainHook func(model, prompt string) []uint64
+
+// MaybeWithPrefixChain attaches the hook-built chain to ctx; a nil hook or an empty chain returns ctx unchanged.
+func MaybeWithPrefixChain(ctx context.Context, model, prompt string) context.Context {
+	hook := PrefixChainHook
+	if hook == nil {
+		return ctx
+	}
+	if chain := hook(model, prompt); len(chain) > 0 {
+		return WithPrefixChain(ctx, chain)
+	}
+	return ctx
+}
diff --git a/pkg/distributedhdr/prefixhash_test.go b/pkg/distributedhdr/prefixhash_test.go
new file mode 100644
index 000000000..5b85db23c
--- /dev/null
+++ b/pkg/distributedhdr/prefixhash_test.go
@@ -0,0 +1,37 @@
+package distributedhdr_test
+
+import (
+ "context"
+
+ "github.com/mudler/LocalAI/pkg/distributedhdr"
+ . "github.com/onsi/ginkgo/v2"
+ . "github.com/onsi/gomega"
+)
+
+var _ = Describe("prefix chain ctx", func() {
+ It("round-trips the chain through ctx", func() {
+ ctx := distributedhdr.WithPrefixChain(context.Background(), []uint64{1, 2, 3})
+ Expect(distributedhdr.PrefixChain(ctx)).To(Equal([]uint64{1, 2, 3}))
+ })
+ It("returns nil when absent", func() {
+ Expect(distributedhdr.PrefixChain(context.Background())).To(BeNil())
+ })
+
+ It("uses the hook to build the chain when set", func() {
+ distributedhdr.PrefixChainHook = func(model, prompt string) []uint64 { return []uint64{42} }
+ defer func() { distributedhdr.PrefixChainHook = nil }()
+ ctx := distributedhdr.MaybeWithPrefixChain(context.Background(), "m", "hi")
+ Expect(distributedhdr.PrefixChain(ctx)).To(Equal([]uint64{42}))
+ })
+ It("is a no-op when the hook is nil", func() {
+ distributedhdr.PrefixChainHook = nil
+ ctx := distributedhdr.MaybeWithPrefixChain(context.Background(), "m", "hi")
+ Expect(distributedhdr.PrefixChain(ctx)).To(BeNil())
+ })
+ It("is a no-op when the hook returns an empty chain", func() {
+ distributedhdr.PrefixChainHook = func(model, prompt string) []uint64 { return nil }
+ defer func() { distributedhdr.PrefixChainHook = nil }()
+ ctx := distributedhdr.MaybeWithPrefixChain(context.Background(), "m", "hi")
+ Expect(distributedhdr.PrefixChain(ctx)).To(BeNil())
+ })
+})
diff --git a/pkg/radixtree/radixtree.go b/pkg/radixtree/radixtree.go
new file mode 100644
index 000000000..9f6602abc
--- /dev/null
+++ b/pkg/radixtree/radixtree.go
@@ -0,0 +1,248 @@
+// Package radixtree implements a generic prefix tree over sequences of uint64
+// key-elements, mapping the longest stored prefix of a query sequence to a
+// value. Entries carry a TTL and the tree tracks a recency-weighted score per
+// value. The clock is injected (callers pass `now`) so behavior is fully
+// deterministic and testable. It has no external dependencies.
+package radixtree
+
+import (
+ "math"
+ "sync"
+ "time"
+)
+
+// Options configures a Tree.
+type Options struct {
+ // TTL is the idle lifetime of an entry. An entry whose lastSeen is older
+ // than TTL (relative to the `now` passed in) is treated as absent and is
+ // swept by Evict. Refreshed on every Insert that traverses it. The boundary
+ // is strict greater-than: an age exactly equal to TTL is still live. A zero
+ // or negative TTL disables expiry entirely.
+ TTL time.Duration
+ // HalfLife controls recency weighting in Weight(). An entry contributes
+ // 0.5^(age/HalfLife). Zero or negative means "no decay" (every live entry counts 1).
+ HalfLife time.Duration
+ // MaxEntries bounds the number of value-bearing nodes. Zero means
+ // unbounded. When exceeded, Insert evicts the least-recently-seen entry.
+ MaxEntries int
+}
+
+// Tree is a prefix tree. V is the stored value type (for prefix-cache routing,
+// a node identifier). Safe for concurrent use.
+type Tree[V comparable] struct {
+ mu sync.RWMutex // guards root and size: readers take RLock, mutators take Lock
+ opts Options // configuration passed to New
+ root *node[V] // sentinel root; Insert never stores a value on it
+ size int // number of value-bearing nodes; this is what Len reports
+}
+
+type node[V comparable] struct {
+ children map[uint64]*node[V] // next key element -> child node
+ value V // meaningful only while hasValue is true
+ hasValue bool // true while this node stores an entry
+ lastSeen time.Time // the `now` of the most recent Insert that passed through this node
+}
+
+// New returns an empty Tree configured by opts; only the value-less root exists.
+func New[V comparable](opts Options) *Tree[V] {
+	return &Tree[V]{root: &node[V]{children: make(map[uint64]*node[V])}, opts: opts}
+}
+
+// LongestMatch walks key from the root and reports the value stored at the
+// deepest non-expired node reached, how many key elements that match consumed,
+// and whether any match was found (otherwise the zero V, 0 and false).
+func (t *Tree[V]) LongestMatch(key []uint64, now time.Time) (V, int, bool) {
+	t.mu.RLock()
+	defer t.mu.RUnlock()
+	var match V
+	depth, ok := 0, false
+	for i, n := 0, t.root; i < len(key); i++ {
+		child, present := n.children[key[i]]
+		if !present {
+			break
+		}
+		n = child
+		if n.hasValue && !t.expired(n, now) {
+			match, depth, ok = n.value, i+1, true
+		}
+	}
+	return match, depth, ok
+}
+
+// expired reports whether n has been idle for strictly longer than the TTL (an
+// age equal to TTL is still live). A TTL of zero or less disables expiry.
+func (t *Tree[V]) expired(n *node[V], now time.Time) bool {
+	ttl := t.opts.TTL
+	return ttl > 0 && now.Sub(n.lastSeen) > ttl
+}
+
+// Insert records value at EVERY node along the key chain, not just the leaf,
+// so LongestMatch can find a shared prefix even when the query tail diverges
+// (SGLang/vLLM-style prefix matching). Re-inserting a different value over a
+// shared prefix node overwrites it: the last writer owns that node, a recency
+// heuristic (the most recent chain through a block is the one most likely
+// warm). lastSeen is refreshed on every traversed node so active prefixes stay
+// live. An empty key is a no-op: the root never holds a value. One Insert can
+// add up to len(key) entries, so when MaxEntries is set the oldest entries are
+// evicted repeatedly until the bound holds again, not just once.
+func (t *Tree[V]) Insert(key []uint64, value V, now time.Time) {
+	if len(key) == 0 {
+		return
+	}
+	t.mu.Lock()
+	defer t.mu.Unlock()
+	cur := t.root
+	for _, k := range key {
+		next, ok := cur.children[k]
+		if !ok {
+			next = &node[V]{children: map[uint64]*node[V]{}}
+			cur.children[k] = next
+		}
+		cur = next
+		if !cur.hasValue {
+			t.size++
+		}
+		cur.value, cur.hasValue, cur.lastSeen = value, true, now
+	}
+	for t.opts.MaxEntries > 0 && t.size > t.opts.MaxEntries {
+		t.evictOldestLocked(now)
+	}
+}
+
+// evictOldestLocked clears the one value-bearing node with the earliest lastSeen
+// and prunes branches left empty. Caller must hold t.mu; now is currently unused.
+func (t *Tree[V]) evictOldestLocked(now time.Time) {
+	var oldest *node[V]
+	var scan func(cur *node[V])
+	scan = func(cur *node[V]) {
+		if cur.hasValue && (oldest == nil || cur.lastSeen.Before(oldest.lastSeen)) {
+			oldest = cur
+		}
+		for _, child := range cur.children {
+			scan(child)
+		}
+	}
+	scan(t.root)
+	if oldest == nil {
+		return // no value-bearing node exists, so there is nothing to evict
+	}
+	// Clear the victim and reclaim it plus any now value-less, childless ancestors.
+	t.pruneWalk(t.root, func(n *node[V]) bool { return n == oldest })
+}
+
+// pruneWalk clears the value of every node for which shouldClear returns true,
+// then removes the now empty (value-less and childless) branches that result.
+// It keeps t.size accurate by decrementing once per cleared node. Returns true
+// if n itself should be removed from its parent. Called with t.mu held.
+func (t *Tree[V]) pruneWalk(n *node[V], shouldClear func(*node[V]) bool) bool {
+ for k, c := range n.children { // children first (post-order), so emptied subtrees are unlinked before n is judged
+ if t.pruneWalk(c, shouldClear) {
+ delete(n.children, k) // removing the current key while ranging over a map is permitted in Go
+ }
+ }
+ if n.hasValue && shouldClear(n) { // shouldClear is consulted only for value-bearing nodes
+ n.hasValue = false
+ var zero V
+ n.value = zero // reset the stored value to V's zero value
+ t.size--
+ }
+ return n != t.root && !n.hasValue && len(n.children) == 0 // the root is never reported as removable
+}
+
+// Len returns the number of value-bearing entries currently stored. Expired
+// entries that Evict has not swept yet are still counted.
+func (t *Tree[V]) Len() int {
+ t.mu.RLock()
+ defer t.mu.RUnlock()
+ return t.size
+}
+
+// Evict clears every entry that is expired as of now and prunes the branches left empty; it walks the whole tree.
+func (t *Tree[V]) Evict(now time.Time) {
+	t.mu.Lock()
+	defer t.mu.Unlock()
+	isExpired := func(n *node[V]) bool { return t.expired(n, now) }
+	t.pruneWalk(t.root, isExpired)
+}
+
+// contribution is the recency-weighted score one entry adds to its value's
+// weight: 1 when HalfLife <= 0 (a plain count), otherwise 0.5^(age/HalfLife)
+// with age = now - lastSeen. It checks neither hasValue nor expiry, so callers
+// filter first. Weight and WeightsFor both use it, keeping the metric identical.
+func (t *Tree[V]) contribution(n *node[V], now time.Time) float64 {
+	halfLife := t.opts.HalfLife
+	if halfLife <= 0 {
+		return 1
+	}
+	return math.Pow(0.5, now.Sub(n.lastSeen).Seconds()/halfLife.Seconds())
+}
+
+// Weight returns the recency-weighted count of non-expired entries whose value
+// equals value: the sum of their contribution (1.0 each when HalfLife <= 0).
+// It is the "valuable warm cache" proxy used for cold placement and autoscale.
+func (t *Tree[V]) Weight(value V, now time.Time) float64 {
+	t.mu.RLock()
+	defer t.mu.RUnlock()
+	var total float64
+	var visit func(cur *node[V])
+	visit = func(cur *node[V]) {
+		counted := cur.hasValue && cur.value == value && !t.expired(cur, now)
+		if counted {
+			total += t.contribution(cur, now)
+		}
+		for _, child := range cur.children {
+			visit(child)
+		}
+	}
+	visit(t.root)
+	return total
+}
+
+// WeightsFor returns, for each value in values, the same recency-weighted
+// weight Weight would report, computed in a single tree traversal rather than
+// one traversal per value. Values with no live entry map to 0. The result map
+// is seeded with a zero for every requested value, so it doubles as the
+// membership set during the walk. An empty values slice returns an empty map
+// without taking the lock. Otherwise the walk runs under the read lock.
+func (t *Tree[V]) WeightsFor(values []V, now time.Time) map[V]float64 {
+	weights := make(map[V]float64, len(values))
+	for _, v := range values {
+		weights[v] = 0
+	}
+	if len(weights) == 0 {
+		return weights
+	}
+	t.mu.RLock()
+	defer t.mu.RUnlock()
+	var visit func(cur *node[V])
+	visit = func(cur *node[V]) {
+		if cur.hasValue && !t.expired(cur, now) {
+			if _, wanted := weights[cur.value]; wanted {
+				weights[cur.value] += t.contribution(cur, now)
+			}
+		}
+		for _, child := range cur.children {
+			visit(child)
+		}
+	}
+	visit(t.root)
+	return weights
+}
+
+// Remove drops every entry whose value equals value and prunes the branches
+// that become empty, so the tree stops pointing at a replica that was unloaded
+// or whose node went offline. It is RemoveFunc with an equality predicate.
+func (t *Tree[V]) Remove(value V) {
+	sameValue := func(v V) bool { return v == value }
+	t.RemoveFunc(sameValue)
+}
+
+// RemoveFunc drops every entry whose value satisfies pred, then prunes empty
+// branches, all in one walk under the write lock. Remove(v) is the special
+// case pred = (x == v); a broader pred can drop e.g. all replicas of one node.
+func (t *Tree[V]) RemoveFunc(pred func(V) bool) {
+	t.mu.Lock()
+	defer t.mu.Unlock()
+	matches := func(n *node[V]) bool { return pred(n.value) }
+	t.pruneWalk(t.root, matches)
+}
diff --git a/pkg/radixtree/radixtree_suite_test.go b/pkg/radixtree/radixtree_suite_test.go
new file mode 100644
index 000000000..1dc9ab0c5
--- /dev/null
+++ b/pkg/radixtree/radixtree_suite_test.go
@@ -0,0 +1,13 @@
+package radixtree_test
+
+import (
+ "testing"
+
+ . "github.com/onsi/ginkgo/v2"
+ . "github.com/onsi/gomega"
+)
+
+func TestRadixTree(t *testing.T) { // `go test` entry point for this package's Ginkgo specs
+ RegisterFailHandler(Fail) // route Gomega assertion failures to Ginkgo's Fail
+ RunSpecs(t, "RadixTree Suite") // run every spec registered in the package under this suite name
+}
diff --git a/pkg/radixtree/radixtree_test.go b/pkg/radixtree/radixtree_test.go
new file mode 100644
index 000000000..507ad9bfb
--- /dev/null
+++ b/pkg/radixtree/radixtree_test.go
@@ -0,0 +1,354 @@
+package radixtree_test
+
+import (
+ "time"
+
+ . "github.com/onsi/ginkgo/v2"
+ . "github.com/onsi/gomega"
+
+ "github.com/mudler/LocalAI/pkg/radixtree"
+)
+
+var t0 = time.Date(2026, 5, 29, 12, 0, 0, 0, time.UTC) // fixed reference instant; specs derive insert and query times from it
+
+var _ = Describe("Tree construction", func() {
+	It("returns an empty tree that matches nothing", func() {
+		empty := radixtree.New[string](radixtree.Options{TTL: time.Minute})
+		_, matched, found := empty.LongestMatch([]uint64{1, 2, 3}, t0)
+		Expect(found).To(BeFalse())
+		Expect(matched).To(Equal(0))
+	})
+})
+
+var _ = Describe("Insert and LongestMatch", func() {
+	It("returns the deepest matching prefix value", func() {
+		// The two chains below share no elements, so neither Insert can
+		// overwrite a node the other one wrote: each node on a chain carries
+		// that chain's value. A lookup longer than the stored chain stops at
+		// the deepest stored element it reached and reports that depth.
+		tree := radixtree.New[string](radixtree.Options{TTL: time.Hour})
+		tree.Insert([]uint64{1, 2, 3, 4}, "nodeB", t0)
+		tree.Insert([]uint64{7, 8}, "nodeA", t0)
+
+		owner, matched, found := tree.LongestMatch([]uint64{1, 2, 3, 4, 5}, t0)
+		Expect(found).To(BeTrue())
+		Expect(owner).To(Equal("nodeB"))
+		Expect(matched).To(Equal(4))
+
+		owner, matched, found = tree.LongestMatch([]uint64{7, 8, 9}, t0)
+		Expect(found).To(BeTrue())
+		Expect(owner).To(Equal("nodeA"))
+		Expect(matched).To(Equal(2))
+	})
+
+	It("lets the last writer own a shared prefix node", func() {
+		// Every node along an inserted chain records the inserted value, so a
+		// later Insert through a shared leading block replaces the value on
+		// those shared nodes. nodeA on [1,2] followed by nodeB on [1,2,3,4]
+		// leaves nodeB owning [1] and [1,2]; a lookup that diverges inside
+		// the shared block therefore resolves to nodeB. That is the intended
+		// recency heuristic: the newest chain is the likeliest to be warm.
+		tree := radixtree.New[string](radixtree.Options{TTL: time.Hour})
+		tree.Insert([]uint64{1, 2}, "nodeA", t0)
+		tree.Insert([]uint64{1, 2, 3, 4}, "nodeB", t0)
+
+		owner, matched, found := tree.LongestMatch([]uint64{1, 2, 3, 4, 5}, t0)
+		Expect(found).To(BeTrue())
+		Expect(owner).To(Equal("nodeB"))
+		Expect(matched).To(Equal(4))
+
+		// Last writer wins: nodeB now owns the shared prefix [1,2].
+		owner, matched, found = tree.LongestMatch([]uint64{1, 2, 9}, t0)
+		Expect(found).To(BeTrue())
+		Expect(owner).To(Equal("nodeB"))
+		Expect(matched).To(Equal(2))
+	})
+
+	It("returns ok=false when no prefix is stored", func() {
+		tree := radixtree.New[string](radixtree.Options{TTL: time.Hour})
+		tree.Insert([]uint64{7, 8}, "nodeA", t0)
+		_, _, found := tree.LongestMatch([]uint64{1, 2}, t0)
+		Expect(found).To(BeFalse())
+	})
+
+	It("matches a shared prefix when the query tail diverges", func() {
+		// Prefix matching in the SGLang/vLLM sense: after one Insert of a
+		// full chain, any lookup sharing a leading block must match at the
+		// depth of the deepest shared element even when the tails differ.
+		// This is the main use case (shared system prompt, multi-turn
+		// extension, volatile tail) rather than an exact repeat.
+		tree := radixtree.New[string](radixtree.Options{TTL: time.Hour})
+		tree.Insert([]uint64{1, 2, 3, 4, 5}, "nodeA", t0)
+		owner, matched, found := tree.LongestMatch([]uint64{1, 2, 3, 9, 9}, t0)
+		Expect(found).To(BeTrue())
+		Expect(matched).To(Equal(3)) // length of the shared prefix [1,2,3]
+		Expect(owner).To(Equal("nodeA"))
+	})
+})
+
+var _ = Describe("TTL expiry", func() {
+	It("does not match an entry past its TTL", func() {
+		tree := radixtree.New[string](radixtree.Options{TTL: time.Minute})
+		tree.Insert([]uint64{1, 2}, "nodeA", t0)
+		_, _, found := tree.LongestMatch([]uint64{1, 2}, t0.Add(2*time.Minute))
+		Expect(found).To(BeFalse())
+	})
+
+	It("refreshes lastSeen on re-insert so a live path survives", func() {
+		tree := radixtree.New[string](radixtree.Options{TTL: time.Minute})
+		tree.Insert([]uint64{1, 2}, "nodeA", t0)
+		tree.Insert([]uint64{1, 2}, "nodeA", t0.Add(50*time.Second))
+		_, _, found := tree.LongestMatch([]uint64{1, 2}, t0.Add(90*time.Second))
+		Expect(found).To(BeTrue())
+	})
+
+	It("Evict reclaims expired nodes", func() {
+		tree := radixtree.New[string](radixtree.Options{TTL: time.Minute})
+		// A 2-element chain stores nodeA on both {1} and {1,2}, i.e. one
+		// valued node per distinct prefix, which is why Len reports 2.
+		tree.Insert([]uint64{1, 2}, "nodeA", t0)
+		Expect(tree.Len()).To(Equal(2))
+		tree.Evict(t0.Add(2 * time.Minute))
+		Expect(tree.Len()).To(Equal(0))
+	})
+})
+
+var _ = Describe("Weight", func() {
+	It("counts live entries for a value with no decay", func() {
+		tree := radixtree.New[string](radixtree.Options{TTL: time.Hour}) // HalfLife left at 0
+		tree.Insert([]uint64{1}, "A", t0)
+		tree.Insert([]uint64{1, 2}, "A", t0)
+		tree.Insert([]uint64{9}, "B", t0)
+		Expect(tree.Weight("A", t0)).To(BeNumerically("==", 2))
+		Expect(tree.Weight("B", t0)).To(BeNumerically("==", 1))
+		Expect(tree.Weight("C", t0)).To(BeNumerically("==", 0))
+	})
+
+	It("decays older entries by half-life", func() {
+		tree := radixtree.New[string](radixtree.Options{TTL: time.Hour, HalfLife: time.Minute})
+		tree.Insert([]uint64{1}, "A", t0)
+		// Queried exactly one half-life after insertion, the entry counts 0.5.
+		Expect(tree.Weight("A", t0.Add(time.Minute))).To(BeNumerically("~", 0.5, 0.001))
+	})
+
+	It("ignores expired entries", func() {
+		tree := radixtree.New[string](radixtree.Options{TTL: time.Minute})
+		tree.Insert([]uint64{1}, "A", t0)
+		Expect(tree.Weight("A", t0.Add(2*time.Minute))).To(BeNumerically("==", 0))
+	})
+})
+
+var _ = Describe("WeightsFor", func() {
+	It("matches per-value Weight with no decay", func() {
+		tree := radixtree.New[string](radixtree.Options{TTL: time.Hour}) // HalfLife left at 0
+		tree.Insert([]uint64{1}, "A", t0)
+		tree.Insert([]uint64{1, 2}, "A", t0)
+		tree.Insert([]uint64{9}, "B", t0)
+
+		weights := tree.WeightsFor([]string{"A", "B", "C"}, t0)
+		Expect(weights).To(HaveLen(3))
+		Expect(weights["A"]).To(BeNumerically("==", 2))
+		Expect(weights["B"]).To(BeNumerically("==", 1))
+		Expect(weights["C"]).To(BeNumerically("==", 0))
+	})
+
+	It("matches per-value Weight under decay", func() {
+		tree := radixtree.New[string](radixtree.Options{TTL: time.Hour, HalfLife: time.Minute})
+		tree.Insert([]uint64{1}, "A", t0)
+		tree.Insert([]uint64{1, 2}, "A", t0.Add(30*time.Second))
+		tree.Insert([]uint64{9}, "B", t0)
+
+		at := t0.Add(time.Minute)
+		weights := tree.WeightsFor([]string{"A", "B", "C"}, at)
+		Expect(weights["A"]).To(BeNumerically("~", tree.Weight("A", at), 1e-12))
+		Expect(weights["B"]).To(BeNumerically("~", tree.Weight("B", at), 1e-12))
+		Expect(weights["C"]).To(BeNumerically("==", 0))
+	})
+
+	It("respects TTL expiry and matches Weight at a non-zero age under decay", func() {
+		tree := radixtree.New[string](radixtree.Options{TTL: time.Minute, HalfLife: 30 * time.Second})
+		tree.Insert([]uint64{1}, "A", t0)                     // past its TTL at query time
+		tree.Insert([]uint64{2}, "A", t0.Add(90*time.Second)) // still live, 30s old at query time
+		tree.Insert([]uint64{9}, "B", t0)                     // past its TTL at query time
+
+		at := t0.Add(2 * time.Minute)
+		weights := tree.WeightsFor([]string{"A", "B"}, at)
+		Expect(weights["A"]).To(BeNumerically("~", tree.Weight("A", at), 1e-12))
+		Expect(weights["A"]).To(BeNumerically("~", 0.5, 0.001)) // one live entry, exactly one half-life old
+		Expect(weights["B"]).To(BeNumerically("==", 0))
+	})
+
+	It("returns an empty map for an empty values slice", func() {
+		tree := radixtree.New[string](radixtree.Options{TTL: time.Hour})
+		tree.Insert([]uint64{1}, "A", t0)
+		Expect(tree.WeightsFor(nil, t0)).To(BeEmpty())
+		Expect(tree.WeightsFor([]string{}, t0)).To(BeEmpty())
+	})
+
+	It("maps a value not present in the tree to 0", func() {
+		tree := radixtree.New[string](radixtree.Options{TTL: time.Hour})
+		tree.Insert([]uint64{1}, "A", t0)
+		weights := tree.WeightsFor([]string{"Z"}, t0)
+		Expect(weights).To(HaveLen(1))
+		Expect(weights["Z"]).To(BeNumerically("==", 0))
+	})
+})
+
+var _ = Describe("Remove", func() {
+	It("drops every entry anchored to a value and prunes", func() {
+		// The chains are disjoint on purpose. Because every node on a chain
+		// carries the inserted value, overlapping chains would hand the
+		// shared prefix nodes to the later writer; A might then own nothing
+		// and both Remove("A") and B's survival would prove nothing.
+		tree := radixtree.New[string](radixtree.Options{TTL: time.Hour})
+		tree.Insert([]uint64{1, 2}, "A", t0)
+		tree.Insert([]uint64{7, 8, 9}, "B", t0)
+		tree.Remove("A")
+		_, _, found := tree.LongestMatch([]uint64{1, 2}, t0)
+		Expect(found).To(BeFalse()) // A's whole branch has been reclaimed
+		survivor, _, found := tree.LongestMatch([]uint64{7, 8, 9}, t0)
+		Expect(found).To(BeTrue())
+		Expect(survivor).To(Equal("B")) // B is untouched
+		Expect(tree.Weight("A", t0)).To(BeNumerically("==", 0))
+	})
+})
+
+var _ = Describe("RemoveFunc", func() {
+	It("drops every entry matching the predicate, prunes, and keeps the rest", func() {
+		tree := radixtree.New[string](radixtree.Options{TTL: time.Hour})
+		tree.Insert([]uint64{1, 2}, "drop-a", t0)
+		tree.Insert([]uint64{3, 4}, "drop-b", t0)
+		tree.Insert([]uint64{7, 8, 9}, "keep", t0)
+		// The predicate selects every value that begins with "drop".
+		tree.RemoveFunc(func(name string) bool { return len(name) >= 4 && name[:4] == "drop" })
+		_, _, found := tree.LongestMatch([]uint64{1, 2}, t0)
+		Expect(found).To(BeFalse())
+		_, _, found = tree.LongestMatch([]uint64{3, 4}, t0)
+		Expect(found).To(BeFalse())
+		survivor, _, found := tree.LongestMatch([]uint64{7, 8, 9}, t0)
+		Expect(found).To(BeTrue())
+		Expect(survivor).To(Equal("keep"))
+		Expect(tree.Len()).To(Equal(3)) // just the three nodes of the "keep" chain are left
+	})
+
+	It("makes Remove a special case of RemoveFunc", func() {
+		tree := radixtree.New[string](radixtree.Options{TTL: time.Hour})
+		tree.Insert([]uint64{1, 2}, "A", t0)
+		tree.Insert([]uint64{7, 8, 9}, "B", t0)
+		tree.RemoveFunc(func(name string) bool { return name == "A" })
+		_, _, found := tree.LongestMatch([]uint64{1, 2}, t0)
+		Expect(found).To(BeFalse())
+		survivor, _, found := tree.LongestMatch([]uint64{7, 8, 9}, t0)
+		Expect(found).To(BeTrue())
+		Expect(survivor).To(Equal("B"))
+	})
+})
+
+var _ = Describe("TTL boundary", func() {
+	It("treats age exactly equal to TTL as still live, and one tick past as expired", func() {
+		tree := radixtree.New[string](radixtree.Options{TTL: time.Minute})
+		tree.Insert([]uint64{1, 2}, "A", t0)
+
+		// Age equals TTL: expiry is a strict greater-than, so the entry is live.
+		_, _, found := tree.LongestMatch([]uint64{1, 2}, t0.Add(time.Minute))
+		Expect(found).To(BeTrue())
+
+		// A single nanosecond beyond the TTL the entry counts as expired.
+		_, _, found = tree.LongestMatch([]uint64{1, 2}, t0.Add(time.Minute+time.Nanosecond))
+		Expect(found).To(BeFalse())
+	})
+})
+
+var _ = Describe("MaxEntries eviction", func() {
+	It("drops the least-recently-seen entry when the cap is exceeded", func() {
+		tree := radixtree.New[string](radixtree.Options{TTL: time.Hour, MaxEntries: 2})
+		tree.Insert([]uint64{1}, "A", t0)
+		tree.Insert([]uint64{2}, "B", t0.Add(time.Second))
+		tree.Insert([]uint64{3}, "C", t0.Add(2*time.Second))
+
+		Expect(tree.Len()).To(Equal(2))
+
+		// A has the oldest lastSeen, so A is the entry that gets dropped.
+		_, _, found := tree.LongestMatch([]uint64{1}, t0.Add(2*time.Second))
+		Expect(found).To(BeFalse())
+
+		// Both B and C are still present.
+		_, _, found = tree.LongestMatch([]uint64{2}, t0.Add(2*time.Second))
+		Expect(found).To(BeTrue())
+		_, _, found = tree.LongestMatch([]uint64{3}, t0.Add(2*time.Second))
+		Expect(found).To(BeTrue())
+	})
+
+	It("prunes value-less ancestors left behind by an eviction", func() {
+		// Every node on a chain holds the value, so the deep chain
+		// B = [1,2,3] produces three valued nodes: {1}, {1,2} and {1,2,3}.
+		// With a cap of 2, each later Insert evicts least-recently-seen
+		// valued nodes. C and D are newer single-element keys, so eviction
+		// strips B's nodes until the whole B branch is gone; no internal node
+		// of B may remain and push Len above the cap.
+		tree := radixtree.New[string](radixtree.Options{TTL: time.Hour, MaxEntries: 2})
+		tree.Insert([]uint64{1, 2, 3}, "B", t0)
+		tree.Insert([]uint64{5}, "C", t0.Add(time.Second))
+		tree.Insert([]uint64{6}, "D", t0.Add(2*time.Second))
+
+		Expect(tree.Len()).To(Equal(2))
+		// B was the oldest and is evicted together with its deep branch.
+		_, _, found := tree.LongestMatch([]uint64{1, 2, 3}, t0.Add(2*time.Second))
+		Expect(found).To(BeFalse())
+		_, _, found = tree.LongestMatch([]uint64{1, 2}, t0.Add(2*time.Second))
+		Expect(found).To(BeFalse())
+		Expect(tree.Weight("B", t0.Add(2*time.Second))).To(BeNumerically("==", 0))
+	})
+
+	It("reclaims structure so the tree never grows past the cap under churn", func() {
+		tree := radixtree.New[string](radixtree.Options{TTL: time.Hour, MaxEntries: 2})
+		tree.Insert([]uint64{1}, "A", t0)
+		tree.Insert([]uint64{2}, "B", t0.Add(time.Second))
+		Expect(tree.Len()).To(Equal(2))
+
+		for round := range 10 {
+			tree.Insert([]uint64{uint64(100 + round)}, "X", t0.Add(time.Duration(round+2)*time.Second))
+			Expect(tree.Len()).To(Equal(2))
+		}
+	})
+})
+
+var _ = Describe("Empty key", func() {
+	It("LongestMatch on an empty key returns ok=false", func() {
+		tree := radixtree.New[string](radixtree.Options{TTL: time.Hour})
+		tree.Insert([]uint64{1, 2}, "A", t0)
+		_, matched, found := tree.LongestMatch([]uint64{}, t0)
+		Expect(found).To(BeFalse())
+		Expect(matched).To(Equal(0))
+	})
+
+	It("Insert with an empty key is a no-op that creates no root value", func() {
+		tree := radixtree.New[string](radixtree.Options{TTL: time.Hour})
+		Expect(func() { tree.Insert([]uint64{}, "A", t0) }).NotTo(Panic())
+		Expect(tree.Len()).To(Equal(0))
+		_, _, found := tree.LongestMatch([]uint64{}, t0)
+		Expect(found).To(BeFalse())
+		Expect(tree.Weight("A", t0)).To(BeNumerically("==", 0))
+	})
+})
+
+var _ = Describe("Concurrent access", func() {
+	It("is race-free under parallel insert/match/weight", func() {
+		tr := radixtree.New[string](radixtree.Options{TTL: time.Hour})
+		done := make(chan struct{}) // one completion signal per worker goroutine
+		for g := range 8 {
+			go func(g int) {
+				defer func() { done <- struct{}{} }() // deferred first, so it runs after GinkgoRecover: a panicking worker still signals and the spec fails instead of hanging
+				defer GinkgoRecover()
+				for i := range 1000 {
+					tr.Insert([]uint64{uint64(g), uint64(i % 10)}, "n", t0)
+					tr.LongestMatch([]uint64{uint64(g), 1}, t0)
+					tr.Weight("n", t0)
+				}
+			}(g)
+		}
+		for range 8 {
+			<-done // wait for every worker before the spec returns
+		}
+	})
+})
diff --git a/tests/e2e/distributed/model_routing_test.go b/tests/e2e/distributed/model_routing_test.go
index d8e3127f0..cb64e4622 100644
--- a/tests/e2e/distributed/model_routing_test.go
+++ b/tests/e2e/distributed/model_routing_test.go
@@ -63,7 +63,7 @@ var _ = Describe("Model Routing", Label("Distributed"), func() {
Expect(models[0].InFlight).To(Equal(2))
// FindAndLockNodeWithModel should return this node and atomically increment in-flight
- foundNode, foundModel, err := registry.FindAndLockNodeWithModel(context.Background(), "llama3", nil)
+ foundNode, foundModel, err := registry.FindAndLockNodeWithModel(context.Background(), "llama3", nil, nil)
Expect(err).ToNot(HaveOccurred())
Expect(foundNode.ID).To(Equal(node.ID))
Expect(foundModel.ModelName).To(Equal("llama3"))
diff --git a/tests/e2e/distributed/prefix_cache_routing_test.go b/tests/e2e/distributed/prefix_cache_routing_test.go
new file mode 100644
index 000000000..ad0cb709d
--- /dev/null
+++ b/tests/e2e/distributed/prefix_cache_routing_test.go
@@ -0,0 +1,234 @@
+package distributed_test
+
+import (
+ "context"
+ "time"
+
+ "github.com/mudler/LocalAI/core/services/nodes"
+ "github.com/mudler/LocalAI/core/services/nodes/prefixcache"
+ "github.com/mudler/LocalAI/pkg/distributedhdr"
+ grpcPkg "github.com/mudler/LocalAI/pkg/grpc"
+ pb "github.com/mudler/LocalAI/pkg/grpc/proto"
+
+ . "github.com/onsi/ginkgo/v2"
+ . "github.com/onsi/gomega"
+
+ ggrpc "google.golang.org/grpc"
+
+ pgdriver "gorm.io/driver/postgres"
+ gormDB "gorm.io/gorm"
+ "gorm.io/gorm/logger"
+)
+
+// prefixStubBackend is a grpc.Backend fake whose HealthCheck and LoadModel
+// always report canned results, so SmartRouter.probeHealth passes and a cold
+// load succeeds without any real inference. It copies the stubBackend pattern
+// from the SmartRouter unit tests in core/services/nodes/router_test.go,
+// because that fake lives in the internal (unexported) nodes package.
+type prefixStubBackend struct {
+	grpcPkg.Backend // embedded so the interface is satisfied; methods not overridden below panic if called
+
+	healthResult bool
+}
+
+func (s *prefixStubBackend) HealthCheck(_ context.Context) (bool, error) {
+	return s.healthResult, nil
+}
+
+func (s *prefixStubBackend) LoadModel(_ context.Context, _ *pb.ModelOptions, _ ...ggrpc.CallOption) (*pb.Result, error) {
+	return &pb.Result{Success: true}, nil
+}
+
+func (s *prefixStubBackend) IsBusy() bool { return false }
+
+// prefixStubClientFactory returns one shared fake backend from every
+// NewClient call, so routing never dials a real gRPC connection.
+type prefixStubClientFactory struct {
+	client *prefixStubBackend
+}
+
+func (fac *prefixStubClientFactory) NewClient(_ string, _ bool) grpcPkg.Backend {
+	return fac.client
+}
+
+var _ = Describe("Prefix-cache aware routing", Label("Distributed"), func() {
+ const model = "model" // model name marked loaded on node-x/node-y and passed to router.Route
+
+ var (
+ infra *TestInfra // per-spec infrastructure from SetupInfra; supplies the Postgres URL
+ db *gormDB.DB // GORM handle opened on infra.PGURL in BeforeEach
+ registry *nodes.NodeRegistry // node registry built on db
+ router *nodes.SmartRouter // router under test, wired to registry and idx
+ idx *prefixcache.Index // prefix-cache index handed to the router as PrefixProvider
+
+ nodeXID string // ID of "node-x" as set by Register
+ nodeYID string // ID of "node-y" as set by Register
+
+ chainA = []uint64{1, 2, 3, 4, 5} // conversation A
+ chainShared = []uint64{1, 2, 3, 9, 9} // shares leading prefix [1,2,3] with A
+ chainUnrelated = []uint64{7, 8, 9} // no shared prefix with A
+ )
+
+ // routeAndSettle drives one request through the router for the given prefix
+ // chain and immediately settles the in-flight reservation the way a real
+ // inference completion would (Release closes the client; the DecrementInFlight
+ // emulates the OnFirstComplete callback that fires after the first inference).
+ // Settling keeps both nodes balanced at in_flight=0 so the prefix-cache load
+ // guard never falsely forces a request off its warm node between steps.
+ routeAndSettle := func(chain []uint64) string {
+ GinkgoHelper()
+ ctx := distributedhdr.WithPrefixChain(context.Background(), chain)
+ result, err := router.Route(ctx, model, model, "llama-cpp",
+ &pb.ModelOptions{ModelFile: model}, false)
+ Expect(err).ToNot(HaveOccurred())
+ Expect(result).ToNot(BeNil())
+ Expect(result.Node).ToNot(BeNil())
+ nodeID := result.Node.ID
+ result.Release()
+ Expect(registry.DecrementInFlight(context.Background(), nodeID, model, 0)).To(Succeed())
+ return nodeID
+ }
+
+ BeforeEach(func() {
+ infra = SetupInfra("localai_prefix_cache_routing_test")
+
+ var err error
+ db, err = gormDB.Open(pgdriver.Open(infra.PGURL), &gormDB.Config{
+ Logger: logger.Default.LogMode(logger.Silent),
+ })
+ Expect(err).ToNot(HaveOccurred())
+
+ registry, err = nodes.NewNodeRegistry(db)
+ Expect(err).ToNot(HaveOccurred())
+
+ // The prefix-cache index is the real radix-tree provider. Keep a handle so
+ // the specs can assert Decide() directly in addition to observing Route().
+ idx = prefixcache.NewIndex(prefixcache.DefaultConfig())
+
+ // Wire the registry chokepoint hook ourselves. In production distributed.go
+ // wires this; a bare SmartRouter test must register it so removal-path
+ // invalidation is exercised end to end. A negative replica index means
+ // "all replicas of the node" (InvalidateNode); otherwise drop the exact
+ // replica.
+ registry.SetReplicaRemovedHook(func(modelName, nodeID string, replica int) {
+ if replica < 0 {
+ idx.InvalidateNode(modelName, nodeID)
+ } else {
+ idx.Invalidate(modelName, prefixcache.ReplicaKey{NodeID: nodeID, Replica: replica})
+ }
+ })
+
+ // Register TWO healthy nodes and mark the model loaded on both (replica 0).
+ nodeX := &nodes.BackendNode{Name: "node-x", Address: "127.0.0.1:50051"}
+ nodeY := &nodes.BackendNode{Name: "node-y", Address: "127.0.0.1:50052"}
+ Expect(registry.Register(context.Background(), nodeX, true)).To(Succeed())
+ Expect(registry.Register(context.Background(), nodeY, true)).To(Succeed())
+ nodeXID = nodeX.ID
+ nodeYID = nodeY.ID
+ Expect(registry.SetNodeModel(context.Background(), nodeXID, model, 0, "loaded", "", 0)).To(Succeed())
+ Expect(registry.SetNodeModel(context.Background(), nodeYID, model, 0, "loaded", "", 0)).To(Succeed())
+
+ factory := &prefixStubClientFactory{client: &prefixStubBackend{healthResult: true}}
+ router = nodes.NewSmartRouter(registry, nodes.SmartRouterOptions{
+ ClientFactory: factory,
+ PrefixProvider: idx,
+ PrefixConfig: prefixcache.DefaultConfig(),
+ DB: db,
+ })
+ })
+
+ It("locks affinity, honors shared prefixes, isolates unrelated chains, and re-homes on failover", func() {
+ now := time.Now()
+ // Both nodes host replica 0 of the model.
+ keys := []prefixcache.ReplicaKey{{NodeID: nodeXID, Replica: 0}, {NodeID: nodeYID, Replica: 0}}
+
+ // --- Step 1: cold miss + observe -------------------------------------
+ // chainA's prefix has never been seen, so there is no hot match yet; the
+ // request cold-places on some loaded node X and the assignment is recorded.
+ Expect(idx.Decide(model, chainA, keys, now).HasHot).To(BeFalse(),
+ "step 1: chainA must be a cold miss (no prior affinity)")
+ placedNode := routeAndSettle(chainA)
+ Expect(placedNode).To(Or(Equal(nodeXID), Equal(nodeYID)))
+ // From here on, "X" is whichever node served chainA first.
+ nodeX := placedNode
+ var nodeY string
+ if nodeX == nodeXID {
+ nodeY = nodeYID
+ } else {
+ nodeY = nodeXID
+ }
+ hotX := prefixcache.ReplicaKey{NodeID: nodeX, Replica: 0}
+ Expect(idx.Decide(model, chainA, keys, time.Now()).Hot).To(Equal(hotX),
+ "step 1: chainA must now be recorded against the replica that served it")
+
+ // --- Step 2: hot-match affinity --------------------------------------
+ // The SAME chain routes back to X.
+ Expect(routeAndSettle(chainA)).To(Equal(nodeX),
+ "step 2: a repeat of chainA must return to its warm node X")
+
+ // --- Step 3: shared-prefix match (the regression we fixed) -----------
+ // A DIFFERENT chain that shares the leading prefix [1,2,3] with X's chain
+ // but diverges at the tail still matches the shared head and routes to X.
+ // Before the radix-tree fix this fell through to a cold placement.
+ Expect(idx.Decide(model, chainShared, keys, time.Now()).Hot).To(Equal(hotX),
+ "step 3: chainShared must hot-match X on the shared prefix")
+ Expect(routeAndSettle(chainShared)).To(Equal(nodeX),
+ "step 3: chainShared must route to X via the shared-prefix match")
+
+ // --- Step 4: negative control ----------------------------------------
+ // A completely unrelated chain shares no prefix with X's chain, so it must
+ // NOT hot-match X's affinity. (Cold placement may still pick X or Y by
+ // load/cacheWeight, but it must not be a false hot match.) Asserting the
+ // provider decision directly is the robust check.
+ Expect(idx.Decide(model, chainUnrelated, keys, time.Now()).HasHot).To(BeFalse(),
+ "step 4: chainUnrelated must be a cold miss, not a false hot match on X")
+
+ // --- Step 5: failover + invalidation ---------------------------------
+ // Remove node X's replica of the model. This fires the registry chokepoint
+ // hook, which invalidates the prefix-cache entry for X. A request for X's
+ // chain must then fail over to the surviving node Y, and the prefix entry
+ // must no longer pin to X (it re-homes to Y on the next observe).
+ Expect(registry.RemoveAllNodeModelReplicas(context.Background(), nodeX, model)).To(Succeed())
+
+ yKeys := []prefixcache.ReplicaKey{{NodeID: nodeY, Replica: 0}}
+ // The chokepoint hook dropped X from the index immediately.
+ Expect(idx.Decide(model, chainA, yKeys, time.Now()).Hot).ToNot(Equal(hotX),
+ "step 5: after X's replica is removed, chainA must no longer pin to X")
+
+ // Route(chainA): only Y still hosts the model, so it fails over to Y.
+ Expect(routeAndSettle(chainA)).To(Equal(nodeY),
+ "step 5: chainA must fail over to the surviving node Y")
+
+ // And the entry has re-homed: chainA now hot-matches Y, never X.
+ reHomed := idx.Decide(model, chainA, yKeys, time.Now())
+ hotY := prefixcache.ReplicaKey{NodeID: nodeY, Replica: 0}
+ Expect(reHomed.Hot).ToNot(Equal(hotX),
+ "step 5: chainA must not re-home to the removed node X")
+ Expect(reHomed.Hot).To(Equal(hotY),
+ "step 5: chainA must re-home to the surviving node Y")
+ })
+
+ It("tracks affinity per replica when ONE node hosts TWO replicas of the model", func() {
+ // This is the bug the replica-granular change fixes: two replicas of the
+ // same model on the SAME node are distinct KV caches. A prefix observed
+ // on replica (node,0) must NOT be reported as hot on the sibling replica
+ // (node,1) of the same node.
+ const multiNodeModel = "multi-replica-model"
+ multiNode := &nodes.BackendNode{Name: "node-multi", Address: "127.0.0.1:50060", MaxReplicasPerModel: 2}
+ Expect(registry.Register(context.Background(), multiNode, true)).To(Succeed())
+ Expect(registry.SetNodeModel(context.Background(), multiNode.ID, multiNodeModel, 0, "loaded", "addr0", 0)).To(Succeed())
+ Expect(registry.SetNodeModel(context.Background(), multiNode.ID, multiNodeModel, 1, "loaded", "addr1", 0)).To(Succeed())
+
+ chain := []uint64{42, 43, 44}
+ key0 := prefixcache.ReplicaKey{NodeID: multiNode.ID, Replica: 0}
+ key1 := prefixcache.ReplicaKey{NodeID: multiNode.ID, Replica: 1}
+
+ // Observe the chain on replica 0 only.
+ idx.Observe(multiNodeModel, chain, key0, time.Now())
+
+ d := idx.Decide(multiNodeModel, chain, []prefixcache.ReplicaKey{key0, key1}, time.Now())
+ Expect(d.HasHot).To(BeTrue())
+ Expect(d.Hot).To(Equal(key0),
+ "the prefix was served by replica 0; the SAME-node sibling replica 1 must NOT be chosen")
+ })
+})