diff --git a/core/application/distributed.go b/core/application/distributed.go index a82ca1931..64e8dc12e 100644 --- a/core/application/distributed.go +++ b/core/application/distributed.go @@ -16,7 +16,9 @@ import ( "github.com/mudler/LocalAI/core/services/jobs" "github.com/mudler/LocalAI/core/services/messaging" "github.com/mudler/LocalAI/core/services/nodes" + "github.com/mudler/LocalAI/core/services/nodes/prefixcache" "github.com/mudler/LocalAI/core/services/storage" + "github.com/mudler/LocalAI/pkg/distributedhdr" "github.com/mudler/LocalAI/pkg/sanitize" "github.com/mudler/xlog" "gorm.io/gorm" @@ -240,6 +242,84 @@ func initDistributed(cfg *config.ApplicationConfig, authDB *gorm.DB, configLoade cfg.Distributed.BackendUpgradeTimeoutOrDefault(), ) + // Prefix-cache-aware routing. Enabled by default; an operator can opt out + // with --distributed-prefix-cache=false, which leaves prefixProvider and + // pressure nil so the SmartRouter and reconciler behave exactly as the + // round-robin floor (true no-op). When enabled we build the local index, + // wrap it in a NATS-backed Sync (publishes our observations, applies peers' + // via the subscriptions below), install the extraction hook used by + // core/backend/llm.go, and run a background eviction ticker on the app ctx. + var prefixProvider prefixcache.Provider + var pressure *prefixcache.Pressure + var prefixCfg prefixcache.Config + if !cfg.Distributed.PrefixCacheDisabled { + prefixCfg = prefixcache.DefaultConfig() + if cfg.Distributed.PrefixCacheTTL > 0 { + prefixCfg.TTL = cfg.Distributed.PrefixCacheTTL + } + if err := prefixCfg.Validate(); err != nil { + return nil, fmt.Errorf("invalid prefix-cache configuration: %w", err) + } + idx := prefixcache.NewIndex(prefixCfg) + prefixSync := prefixcache.NewSync(idx, natsClient) + pressure = prefixcache.NewPressure(prefixCfg.PressureWindow) + prefixProvider = prefixSync + + // Invalidate the prefix-cache index whenever a replica row is removed. 
+ // SetReplicaRemovedHook fires from the single chokepoint all removal paths + // funnel through (RemoveNodeModel / RemoveAllNodeModelReplicas), so this + // one hook covers every path: reconciler scale-down, probe reaper, + // health-monitor reap, RemoteUnloaderAdapter, and the router. Registering + // it only inside this enabled block keeps the disabled path a true no-op + // (the registry stays hook-less). + registry.SetReplicaRemovedHook(func(model, node string, replica int) { + if replica < 0 { + prefixSync.InvalidateNode(model, node) + } else { + prefixSync.Invalidate(model, prefixcache.ReplicaKey{NodeID: node, Replica: replica}) + } + }) + + distributedhdr.PrefixChainHook = func(model, prompt string) []uint64 { + return prefixcache.ExtractChain(model, prompt, prefixCfg) + } + + // Apply peers' observations/invalidations to the same Sync. ApplyObserve + // and ApplyInvalidate update only the local index and do not re-publish, + // so there is no broadcast loop. + if _, err := messaging.SubscribeJSON(natsClient, messaging.SubjectPrefixCacheObserve, func(ev messaging.PrefixCacheObserveEvent) { + prefixSync.ApplyObserve(ev, time.Now()) + }); err != nil { + return nil, fmt.Errorf("subscribing to %s: %w", messaging.SubjectPrefixCacheObserve, err) + } + if _, err := messaging.SubscribeJSON(natsClient, messaging.SubjectPrefixCacheInvalidate, func(ev messaging.PrefixCacheInvalidateEvent) { + prefixSync.ApplyInvalidate(ev) + }); err != nil { + return nil, fmt.Errorf("subscribing to %s: %w", messaging.SubjectPrefixCacheInvalidate, err) + } + + // Background eviction: sweep idle entries on the app context. Stopped + // when the app context is cancelled (mirrors the reconciler loop which + // also runs on options.Context). TTL/2 keeps stale entries from + // outliving their idle window by more than half a TTL. 
+ evictInterval := prefixCfg.TTL / 2 + go func() { + ticker := time.NewTicker(evictInterval) + defer ticker.Stop() + for { + select { + case <-cfg.Context.Done(): + return + case <-ticker.C: + prefixSync.Evict(time.Now()) + } + } + }() + xlog.Info("Prefix-cache-aware routing enabled", "ttl", prefixCfg.TTL, "evictInterval", evictInterval) + } else { + xlog.Info("Prefix-cache-aware routing disabled: using round-robin routing") + } + // All dependencies ready — build SmartRouter with all options at once var conflictResolver nodes.ConcurrencyConflictResolver if configLoader != nil { @@ -252,6 +332,9 @@ func initDistributed(cfg *config.ApplicationConfig, authDB *gorm.DB, configLoade AuthToken: routerAuthToken, DB: authDB, ConflictResolver: conflictResolver, + PrefixProvider: prefixProvider, + PrefixConfig: prefixCfg, + Pressure: pressure, }) // Create ReplicaReconciler for auto-scaling model replicas. Adapter + @@ -268,6 +351,8 @@ func initDistributed(cfg *config.ApplicationConfig, authDB *gorm.DB, configLoade Interval: 30 * time.Second, ScaleDownDelay: 5 * time.Minute, ProbeStaleAfter: 2 * time.Minute, + Pressure: pressure, + PressureThreshold: prefixCfg.PressureScaleThreshold, }) // Create ModelRouterAdapter to wire into ModelLoader diff --git a/core/backend/llm.go b/core/backend/llm.go index 053e984e8..4f6b4d216 100644 --- a/core/backend/llm.go +++ b/core/backend/llm.go @@ -19,6 +19,7 @@ import ( "github.com/mudler/LocalAI/core/trace" "github.com/mudler/LocalAI/core/gallery" + "github.com/mudler/LocalAI/pkg/distributedhdr" "github.com/mudler/LocalAI/pkg/grpc/proto" model "github.com/mudler/LocalAI/pkg/model" "github.com/mudler/LocalAI/pkg/utils" @@ -94,6 +95,22 @@ func ModelInference(ctx context.Context, s string, messages schema.Messages, ima } } + // Make the rendered prompt's prefix chain available to the distributed router + // for prefix-cache-aware node selection. No-op in single-process mode. 
The + // model id MUST match the id ModelOptions feeds to model.WithModelID, so both + // use the shared config.ModelConfig.ModelID() helper (Name with a fallback to + // Model) or the chain salt and the tracking key would diverge. + // + // s is empty for UseTokenizerTemplate models (the backend tokenizes the + // structured messages itself), so fall back to a prefix-stable serialization + // of the messages - otherwise prefix routing would silently degrade to + // round-robin for the bulk of modern chat models. + chainSource := s + if chainSource == "" { + chainSource = messagesPrefixSource(messages) + } + ctx = distributedhdr.MaybeWithPrefixChain(ctx, c.ModelID(), chainSource) + opts := ModelOptions(*c, o, model.WithContext(ctx)) inferenceModel, err := loader.Load(opts...) if err != nil { diff --git a/core/backend/options.go b/core/backend/options.go index 0215bf37a..c891b6d67 100644 --- a/core/backend/options.go +++ b/core/backend/options.go @@ -34,16 +34,11 @@ func recordModelLoadFailure(appConfig *config.ApplicationConfig, modelName, back } func ModelOptions(c config.ModelConfig, so *config.ApplicationConfig, opts ...model.Option) []model.Option { - name := c.Name - if name == "" { - name = c.Model - } - defOpts := []model.Option{ model.WithBackendString(c.Backend), model.WithModel(c.Model), model.WithContext(so.Context), - model.WithModelID(name), + model.WithModelID(c.ModelID()), } threads := 1 diff --git a/core/backend/prefix_source.go b/core/backend/prefix_source.go new file mode 100644 index 000000000..2623033e3 --- /dev/null +++ b/core/backend/prefix_source.go @@ -0,0 +1,36 @@ +package backend + +import ( + "strings" + + "github.com/mudler/LocalAI/core/schema" +) + +// messagesPrefixSource builds a deterministic, prefix-stable serialization of a +// chat conversation for prefix-cache-aware routing. 
It is the fallback used when +// the frontend did not render a prompt string: models with +// config.TemplateConfig.UseTokenizerTemplate tokenize the structured messages +// backend-side, so the frontend's rendered prompt is empty and a chain built +// from it would always be empty - silently degrading prefix routing to +// round-robin for the bulk of modern chat models. +// +// Messages are emitted head-first in turn order (role line + content line per +// message), so two conversations sharing a leading system prompt and early turns +// share a leading byte prefix. That is exactly what ExtractChain hashes into a +// shared chain prefix, landing both requests on the same cache-warm replica. +func messagesPrefixSource(messages schema.Messages) string { + var b strings.Builder + for _, m := range messages { + b.WriteString(m.Role) + b.WriteByte('\n') + content := m.StringContent + if content == "" { + if s, ok := m.Content.(string); ok { + content = s + } + } + b.WriteString(content) + b.WriteByte('\n') + } + return b.String() +} diff --git a/core/backend/prefix_source_internal_test.go b/core/backend/prefix_source_internal_test.go new file mode 100644 index 000000000..0395b35e4 --- /dev/null +++ b/core/backend/prefix_source_internal_test.go @@ -0,0 +1,53 @@ +package backend + +import ( + "strings" + + "github.com/mudler/LocalAI/core/schema" + + . "github.com/onsi/ginkgo/v2" + . 
"github.com/onsi/gomega" +) + +var _ = Describe("messagesPrefixSource", func() { + mk := func(role, content string) schema.Message { + return schema.Message{Role: role, StringContent: content} + } + + It("serializes messages head-first in turn order", func() { + got := messagesPrefixSource(schema.Messages{ + mk("system", "You are helpful."), + mk("user", "Hi"), + }) + Expect(got).To(Equal("system\nYou are helpful.\nuser\nHi\n")) + }) + + It("is deterministic across calls for the same conversation", func() { + conv := schema.Messages{mk("system", "S"), mk("user", "U")} + Expect(messagesPrefixSource(conv)).To(Equal(messagesPrefixSource(conv))) + }) + + It("shares a leading byte prefix when the system prompt is shared", func() { + shared := "system\nShared system prompt.\nuser\n" + a := messagesPrefixSource(schema.Messages{mk("system", "Shared system prompt."), mk("user", "Question A")}) + b := messagesPrefixSource(schema.Messages{mk("system", "Shared system prompt."), mk("user", "Question B")}) + Expect(strings.HasPrefix(a, shared)).To(BeTrue()) + Expect(strings.HasPrefix(b, shared)).To(BeTrue()) + }) + + It("does NOT share a prefix when the system prompt differs", func() { + a := messagesPrefixSource(schema.Messages{mk("system", "Prompt A"), mk("user", "Q")}) + b := messagesPrefixSource(schema.Messages{mk("system", "Prompt B"), mk("user", "Q")}) + Expect(strings.HasPrefix(a, "system\nPrompt A")).To(BeTrue()) + Expect(strings.HasPrefix(b, "system\nPrompt B")).To(BeTrue()) + }) + + It("returns empty for no messages", func() { + Expect(messagesPrefixSource(nil)).To(Equal("")) + }) + + It("falls back to Content when StringContent is empty", func() { + got := messagesPrefixSource(schema.Messages{{Role: "user", Content: "plain"}}) + Expect(got).To(Equal("user\nplain\n")) + }) +}) diff --git a/core/cli/run.go b/core/cli/run.go index 09a58971b..8bbc2b20c 100644 --- a/core/cli/run.go +++ b/core/cli/run.go @@ -145,19 +145,21 @@ type RunCMD struct { DefaultAPIKeyExpiry string 
`env:"LOCALAI_DEFAULT_API_KEY_EXPIRY" help:"Default expiry for API keys (e.g. 90d, 1y; empty = no expiry)" group:"auth"` // Distributed / Horizontal Scaling - Distributed bool `env:"LOCALAI_DISTRIBUTED" default:"false" help:"Enable distributed mode (requires PostgreSQL + NATS)" group:"distributed"` - InstanceID string `env:"LOCALAI_INSTANCE_ID" help:"Unique instance ID for distributed mode (auto-generated UUID if empty)" group:"distributed"` - NatsURL string `env:"LOCALAI_NATS_URL" help:"NATS server URL (e.g., nats://localhost:4222)" group:"distributed"` - StorageURL string `env:"LOCALAI_STORAGE_URL" help:"S3-compatible storage endpoint URL (e.g., http://minio:9000)" group:"distributed"` - StorageBucket string `env:"LOCALAI_STORAGE_BUCKET" default:"localai" help:"S3 bucket name for object storage" group:"distributed"` - StorageRegion string `env:"LOCALAI_STORAGE_REGION" default:"us-east-1" help:"S3 region" group:"distributed"` - StorageAccessKey string `env:"LOCALAI_STORAGE_ACCESS_KEY" help:"S3 access key ID" group:"distributed"` - StorageSecretKey string `env:"LOCALAI_STORAGE_SECRET_KEY" help:"S3 secret access key" group:"distributed"` - RegistrationToken string `env:"LOCALAI_REGISTRATION_TOKEN" help:"Token that backend nodes must provide to register (empty = no auth required)" group:"distributed"` - AutoApproveNodes bool `env:"LOCALAI_AUTO_APPROVE_NODES" default:"false" help:"Auto-approve new worker nodes (skip admin approval)" group:"distributed"` - BackendInstallTimeout string `env:"LOCALAI_NATS_BACKEND_INSTALL_TIMEOUT" help:"NATS round-trip timeout for backend.install requests sent to worker nodes (default 15m). Increase for slow links pulling multi-GB images." group:"distributed"` - BackendUpgradeTimeout string `env:"LOCALAI_NATS_BACKEND_UPGRADE_TIMEOUT" help:"NATS round-trip timeout for backend.upgrade requests (default 15m)." 
group:"distributed"` - ExposeNodeHeader bool `env:"LOCALAI_EXPOSE_NODE_HEADER" default:"false" help:"Set the X-LocalAI-Node response header on inference responses (OpenAI chat/completions/embeddings, Anthropic /v1/messages, Ollama /api/chat,/api/generate,/api/embed) with the ID of the worker that served the request. Disabled by default: the node ID reveals internal topology and should not be exposed on a public endpoint. Best-effort: under heavy concurrency the header may reflect a recent routing decision rather than this exact request's." group:"distributed"` + Distributed bool `env:"LOCALAI_DISTRIBUTED" default:"false" help:"Enable distributed mode (requires PostgreSQL + NATS)" group:"distributed"` + InstanceID string `env:"LOCALAI_INSTANCE_ID" help:"Unique instance ID for distributed mode (auto-generated UUID if empty)" group:"distributed"` + NatsURL string `env:"LOCALAI_NATS_URL" help:"NATS server URL (e.g., nats://localhost:4222)" group:"distributed"` + StorageURL string `env:"LOCALAI_STORAGE_URL" help:"S3-compatible storage endpoint URL (e.g., http://minio:9000)" group:"distributed"` + StorageBucket string `env:"LOCALAI_STORAGE_BUCKET" default:"localai" help:"S3 bucket name for object storage" group:"distributed"` + StorageRegion string `env:"LOCALAI_STORAGE_REGION" default:"us-east-1" help:"S3 region" group:"distributed"` + StorageAccessKey string `env:"LOCALAI_STORAGE_ACCESS_KEY" help:"S3 access key ID" group:"distributed"` + StorageSecretKey string `env:"LOCALAI_STORAGE_SECRET_KEY" help:"S3 secret access key" group:"distributed"` + RegistrationToken string `env:"LOCALAI_REGISTRATION_TOKEN" help:"Token that backend nodes must provide to register (empty = no auth required)" group:"distributed"` + AutoApproveNodes bool `env:"LOCALAI_AUTO_APPROVE_NODES" default:"false" help:"Auto-approve new worker nodes (skip admin approval)" group:"distributed"` + DistributedPrefixCache bool `env:"LOCALAI_DISTRIBUTED_PREFIX_CACHE" default:"true" help:"Enable 
prefix-cache-aware routing in distributed mode (default true). When false, routing falls back to round-robin." group:"distributed"` + DistributedPrefixCacheTTL string `env:"LOCALAI_DISTRIBUTED_PREFIX_CACHE_TTL" help:"Idle-timeout for prefix-cache index entries; also drives the background eviction cadence (every TTL/2). Default 5m." group:"distributed"` + BackendInstallTimeout string `env:"LOCALAI_NATS_BACKEND_INSTALL_TIMEOUT" help:"NATS round-trip timeout for backend.install requests sent to worker nodes (default 15m). Increase for slow links pulling multi-GB images." group:"distributed"` + BackendUpgradeTimeout string `env:"LOCALAI_NATS_BACKEND_UPGRADE_TIMEOUT" help:"NATS round-trip timeout for backend.upgrade requests (default 15m)." group:"distributed"` + ExposeNodeHeader bool `env:"LOCALAI_EXPOSE_NODE_HEADER" default:"false" help:"Set the X-LocalAI-Node response header on inference responses (OpenAI chat/completions/embeddings, Anthropic /v1/messages, Ollama /api/chat,/api/generate,/api/embed) with the ID of the worker that served the request. Disabled by default: the node ID reveals internal topology and should not be exposed on a public endpoint. Best-effort: under heavy concurrency the header may reflect a recent routing decision rather than this exact request's." 
group:"distributed"` Version bool @@ -284,6 +286,16 @@ func (r *RunCMD) Run(ctx *cliContext.Context) error { if r.AutoApproveNodes { opts = append(opts, config.EnableAutoApproveNodes) } + if !r.DistributedPrefixCache { + opts = append(opts, config.DisablePrefixCache) + } + if r.DistributedPrefixCacheTTL != "" { + d, err := time.ParseDuration(r.DistributedPrefixCacheTTL) + if err != nil { + return fmt.Errorf("invalid LOCALAI_DISTRIBUTED_PREFIX_CACHE_TTL %q: %w", r.DistributedPrefixCacheTTL, err) + } + opts = append(opts, config.WithPrefixCacheTTL(d)) + } if r.ExposeNodeHeader { opts = append(opts, config.WithExposeNodeHeader(true)) } diff --git a/core/config/distributed_config.go b/core/config/distributed_config.go index 0ff783321..e0e0454d9 100644 --- a/core/config/distributed_config.go +++ b/core/config/distributed_config.go @@ -49,6 +49,17 @@ type DistributedConfig struct { AgentWorkerConcurrency int `yaml:"agent_worker_concurrency" json:"agent_worker_concurrency" env:"LOCALAI_AGENT_WORKER_CONCURRENCY"` JobWorkerConcurrency int `yaml:"job_worker_concurrency" json:"job_worker_concurrency" env:"LOCALAI_JOB_WORKER_CONCURRENCY"` + + // PrefixCacheDisabled turns off prefix-cache-aware routing, falling back to + // round-robin (the floor). Prefix-cache routing is ON by default in + // distributed mode; this flag exists so operators can opt out. The CLI + // surfaces a default-true --distributed-prefix-cache enable flag and sets + // this when the operator passes --distributed-prefix-cache=false. + PrefixCacheDisabled bool + // PrefixCacheTTL is the idle-timeout for prefix-cache index entries and + // drives the background eviction cadence (eviction runs every TTL/2). Zero + // means use the prefixcache package default (5m). + PrefixCacheTTL time.Duration } // Validate checks that the distributed configuration is internally consistent. 
@@ -158,6 +169,20 @@ var EnableAutoApproveNodes = func(o *ApplicationConfig) { o.Distributed.AutoApproveNodes = true } +// DisablePrefixCache turns off prefix-cache-aware routing (falls back to +// round-robin). Prefix-cache routing is enabled by default in distributed mode. +var DisablePrefixCache = func(o *ApplicationConfig) { + o.Distributed.PrefixCacheDisabled = true +} + +// WithPrefixCacheTTL sets the prefix-cache index idle-timeout (and the +// background eviction cadence, which runs every TTL/2). +func WithPrefixCacheTTL(d time.Duration) AppOption { + return func(o *ApplicationConfig) { + o.Distributed.PrefixCacheTTL = d + } +} + // Flag names for distributed timeout / interval configuration. These are // the kebab-case identifiers kong derives from the matching RunCMD struct // fields; they appear in Validate error messages and any other operator- diff --git a/core/config/model_config.go b/core/config/model_config.go index d57544c6f..a1b000798 100644 --- a/core/config/model_config.go +++ b/core/config/model_config.go @@ -694,6 +694,18 @@ func (c *ModelConfig) IsModelURL() bool { return uri.LooksLikeURL() } +// ModelID returns the identifier used to reference this model across the +// system: the configured Name, falling back to Model when Name is empty. +// This is the single source of truth for the id fed to model.WithModelID and +// the prefix-cache chain salt; both MUST agree with the router's tracking key +// or the prefix-cache salt diverges silently. 
+func (c ModelConfig) ModelID() string { + if c.Name != "" { + return c.Name + } + return c.Model +} + // ModelFileName returns the filename of the model // If the model is a URL, it will return the MD5 of the URL which is the filename func (c *ModelConfig) ModelFileName() string { diff --git a/core/config/model_config_test.go b/core/config/model_config_test.go index a93912e1b..5abfb8eaf 100644 --- a/core/config/model_config_test.go +++ b/core/config/model_config_test.go @@ -10,6 +10,23 @@ import ( ) var _ = Describe("Test cases for config related functions", func() { + Context("ModelID", func() { + It("returns Name when set", func() { + c := ModelConfig{Name: "my-name"} + c.Model = "my-model" + Expect(c.ModelID()).To(Equal("my-name")) + }) + It("falls back to Model when Name is empty", func() { + c := ModelConfig{} + c.Model = "my-model" + Expect(c.ModelID()).To(Equal("my-model")) + }) + It("returns empty string when both are empty", func() { + c := ModelConfig{} + Expect(c.ModelID()).To(Equal("")) + }) + }) + Context("Test Read configuration functions", func() { It("Test Validate", func() { tmp, err := os.CreateTemp("", "config.yaml") diff --git a/core/http/endpoints/localai/nodes.go b/core/http/endpoints/localai/nodes.go index 83fd2170c..90dbe70b0 100644 --- a/core/http/endpoints/localai/nodes.go +++ b/core/http/endpoints/localai/nodes.go @@ -6,6 +6,7 @@ import ( "crypto/subtle" "encoding/hex" "encoding/json" + "errors" "fmt" "io" "net/http" @@ -25,6 +26,7 @@ import ( "github.com/mudler/LocalAI/core/schema" "github.com/mudler/LocalAI/core/services/galleryop" "github.com/mudler/LocalAI/core/services/nodes" + "github.com/mudler/LocalAI/core/services/nodes/prefixcache" "github.com/mudler/LocalAI/pkg/httpclient" ) @@ -911,14 +913,56 @@ func GetSchedulingEndpoint(registry *nodes.NodeRegistry) echo.HandlerFunc { } // SetSchedulingRequest is the request body for creating/updating a scheduling config. 
+// +// The four prefix-cache fields are POINTERS so an omitted field is +// distinguishable from an explicit zero. On update, an omitted prefix-cache +// field preserves the model's previously-configured value instead of resetting +// it (see SetSchedulingEndpoint's PATCH-style merge). ModelName, NodeSelector, +// MinReplicas and MaxReplicas keep their full-replace PUT semantics. type SetSchedulingRequest struct { - ModelName string `json:"model_name"` - NodeSelector map[string]string `json:"node_selector,omitempty"` - MinReplicas int `json:"min_replicas"` - MaxReplicas int `json:"max_replicas"` + ModelName string `json:"model_name"` + NodeSelector map[string]string `json:"node_selector,omitempty"` + MinReplicas int `json:"min_replicas"` + MaxReplicas int `json:"max_replicas"` + RoutePolicy *string `json:"route_policy,omitempty"` + BalanceAbsThreshold *int `json:"balance_abs_threshold,omitempty"` + BalanceRelThreshold *float64 `json:"balance_rel_threshold,omitempty"` + MinPrefixMatch *float64 `json:"min_prefix_match,omitempty"` +} + +// validateSchedulingRequest enforces the invariants of a scheduling config. +// The prefix-cache bounds are delegated to prefixcache.ValidateThresholds (the +// single source of truth), and are checked against the RESOLVED values passed +// in (provided-or-preserved), so validation only rejects bad values the caller +// actually supplied. It returns nil when valid, or an error with a user-facing +// message describing the first violation. 
+func validateSchedulingRequest(req SetSchedulingRequest, routePolicy string, absThr int, relThr, minMatch float64) error { + if req.ModelName == "" { + return errors.New("model_name is required") + } + if req.MinReplicas < 0 { + return errors.New("min_replicas must be >= 0") + } + if req.MaxReplicas < 0 { + return errors.New("max_replicas must be >= 0") + } + if req.MaxReplicas > 0 && req.MinReplicas > req.MaxReplicas { + return errors.New("min_replicas must be <= max_replicas") + } + if err := prefixcache.ValidateThresholds(routePolicy, absThr, relThr, minMatch); err != nil { + return err + } + return nil } // SetSchedulingEndpoint creates or updates a model scheduling config. +// +// The registry upsert full-replaces all columns, so a request that omits the +// prefix-cache fields would otherwise wipe a model's previously-configured +// routing settings. To avoid that footgun the four prefix-cache fields are +// merged PATCH-style: a non-nil request pointer wins; a nil one preserves the +// existing config's value (or the zero default when no config exists yet). The +// non-prefix fields keep their full-replace PUT behavior. func SetSchedulingEndpoint(registry *nodes.NodeRegistry) echo.HandlerFunc { return func(c echo.Context) error { ctx := c.Request().Context() @@ -926,17 +970,45 @@ func SetSchedulingEndpoint(registry *nodes.NodeRegistry) echo.HandlerFunc { if err := c.Bind(&req); err != nil { return c.JSON(http.StatusBadRequest, nodeError(http.StatusBadRequest, "invalid request body")) } - if req.ModelName == "" { - return c.JSON(http.StatusBadRequest, nodeError(http.StatusBadRequest, "model_name is required")) + + // Fetch the existing config (may be nil) so omitted prefix-cache fields + // can fall back to the stored value rather than resetting to zero. 
+ var existing *nodes.ModelSchedulingConfig + if req.ModelName != "" { + var err error + existing, err = registry.GetModelScheduling(ctx, req.ModelName) + if err != nil { + return c.JSON(http.StatusInternalServerError, nodeError(http.StatusInternalServerError, "failed to load existing scheduling config")) + } } - if req.MinReplicas < 0 { - return c.JSON(http.StatusBadRequest, nodeError(http.StatusBadRequest, "min_replicas must be >= 0")) + + // Resolve each prefix-cache field: provided pointer wins, otherwise keep + // the existing value (zero/default when there is no existing config). + routePolicy := "" + absThr := 0 + relThr := 0.0 + minMatch := 0.0 + if existing != nil { + routePolicy = existing.RoutePolicy + absThr = existing.BalanceAbsThreshold + relThr = existing.BalanceRelThreshold + minMatch = existing.MinPrefixMatch } - if req.MaxReplicas < 0 { - return c.JSON(http.StatusBadRequest, nodeError(http.StatusBadRequest, "max_replicas must be >= 0")) + if req.RoutePolicy != nil { + routePolicy = *req.RoutePolicy } - if req.MaxReplicas > 0 && req.MinReplicas > req.MaxReplicas { - return c.JSON(http.StatusBadRequest, nodeError(http.StatusBadRequest, "min_replicas must be <= max_replicas")) + if req.BalanceAbsThreshold != nil { + absThr = *req.BalanceAbsThreshold + } + if req.BalanceRelThreshold != nil { + relThr = *req.BalanceRelThreshold + } + if req.MinPrefixMatch != nil { + minMatch = *req.MinPrefixMatch + } + + if err := validateSchedulingRequest(req, routePolicy, absThr, relThr, minMatch); err != nil { + return c.JSON(http.StatusBadRequest, nodeError(http.StatusBadRequest, err.Error())) } // Serialize node selector to JSON @@ -950,10 +1022,14 @@ func SetSchedulingEndpoint(registry *nodes.NodeRegistry) echo.HandlerFunc { } config := &nodes.ModelSchedulingConfig{ - ModelName: req.ModelName, - NodeSelector: selectorJSON, - MinReplicas: req.MinReplicas, - MaxReplicas: req.MaxReplicas, + ModelName: req.ModelName, + NodeSelector: selectorJSON, + MinReplicas: 
req.MinReplicas, + MaxReplicas: req.MaxReplicas, + RoutePolicy: routePolicy, + BalanceAbsThreshold: absThr, + BalanceRelThreshold: relThr, + MinPrefixMatch: minMatch, } if err := registry.SetModelScheduling(ctx, config); err != nil { return c.JSON(http.StatusInternalServerError, nodeError(http.StatusInternalServerError, "failed to set scheduling config")) diff --git a/core/http/endpoints/localai/nodes_scheduling_validation_test.go b/core/http/endpoints/localai/nodes_scheduling_validation_test.go new file mode 100644 index 000000000..800a14d7a --- /dev/null +++ b/core/http/endpoints/localai/nodes_scheduling_validation_test.go @@ -0,0 +1,66 @@ +package localai + +import ( + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" +) + +var _ = Describe("validateSchedulingRequest", func() { + base := func() SetSchedulingRequest { + return SetSchedulingRequest{ModelName: "m"} + } + + It("accepts an empty route policy (inherit) with valid thresholds", func() { + Expect(validateSchedulingRequest(base(), "", 3, 0, 0.4)).To(Succeed()) + }) + + It("accepts the prefix_cache policy", func() { + Expect(validateSchedulingRequest(base(), "prefix_cache", 0, 0, 0)).To(Succeed()) + }) + + It("accepts the round_robin policy", func() { + Expect(validateSchedulingRequest(base(), "round_robin", 0, 0, 0)).To(Succeed()) + }) + + It("accepts balance_rel_threshold >= 1", func() { + Expect(validateSchedulingRequest(base(), "", 0, 1.5, 0)).To(Succeed()) + }) + + It("rejects a missing model_name", func() { + req := base() + req.ModelName = "" + err := validateSchedulingRequest(req, "", 0, 0, 0) + Expect(err).To(HaveOccurred()) + Expect(err.Error()).To(ContainSubstring("model_name is required")) + }) + + It("rejects an unknown route_policy (no silent default)", func() { + err := validateSchedulingRequest(base(), "bogus", 0, 0, 0) + Expect(err).To(HaveOccurred()) + Expect(err.Error()).To(ContainSubstring("route_policy")) + }) + + It("rejects min_prefix_match above 1", func() { + err := 
validateSchedulingRequest(base(), "", 0, 0, 2) + Expect(err).To(HaveOccurred()) + Expect(err.Error()).To(ContainSubstring("min_prefix_match")) + }) + + It("rejects a negative min_prefix_match", func() { + err := validateSchedulingRequest(base(), "", 0, 0, -0.1) + Expect(err).To(HaveOccurred()) + Expect(err.Error()).To(ContainSubstring("min_prefix_match")) + }) + + It("rejects a negative balance_abs_threshold", func() { + err := validateSchedulingRequest(base(), "", -1, 0, 0) + Expect(err).To(HaveOccurred()) + Expect(err.Error()).To(ContainSubstring("balance_abs_threshold")) + }) + + It("rejects balance_rel_threshold between 0 and 1 exclusive", func() { + err := validateSchedulingRequest(base(), "", 0, 0.5, 0) + Expect(err).To(HaveOccurred()) + Expect(err.Error()).To(ContainSubstring("balance_rel_threshold")) + }) +}) diff --git a/core/http/endpoints/localai/nodes_test.go b/core/http/endpoints/localai/nodes_test.go index f9eb8071b..fdb29987d 100644 --- a/core/http/endpoints/localai/nodes_test.go +++ b/core/http/endpoints/localai/nodes_test.go @@ -230,6 +230,114 @@ var _ = Describe("Node HTTP handlers", func() { }) }) + Describe("SetSchedulingEndpoint", func() { + postScheduling := func(body string) *httptest.ResponseRecorder { + e := echo.New() + req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(body)) + req.Header.Set(echo.HeaderContentType, echo.MIMEApplicationJSON) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + handler := SetSchedulingEndpoint(registry) + Expect(handler(c)).To(Succeed()) + return rec + } + + It("persists prefix-cache fields and round-trips them via GET", func() { + ctx := context.Background() + rec := postScheduling(`{"model_name":"pc-model","route_policy":"prefix_cache","balance_abs_threshold":3,"min_prefix_match":0.4}`) + Expect(rec.Code).To(Equal(http.StatusOK)) + + cfg, err := registry.GetModelScheduling(ctx, "pc-model") + Expect(err).ToNot(HaveOccurred()) + Expect(cfg).ToNot(BeNil()) + 
Expect(cfg.RoutePolicy).To(Equal("prefix_cache")) + Expect(cfg.BalanceAbsThreshold).To(Equal(3)) + Expect(cfg.MinPrefixMatch).To(BeNumerically("~", 0.4, 1e-9)) + + e := echo.New() + getReq := httptest.NewRequest(http.MethodGet, "/", nil) + getRec := httptest.NewRecorder() + gc := e.NewContext(getReq, getRec) + gc.SetParamNames("model") + gc.SetParamValues("pc-model") + Expect(GetSchedulingEndpoint(registry)(gc)).To(Succeed()) + Expect(getRec.Code).To(Equal(http.StatusOK)) + + var got nodes.ModelSchedulingConfig + Expect(json.Unmarshal(getRec.Body.Bytes(), &got)).To(Succeed()) + Expect(got.RoutePolicy).To(Equal("prefix_cache")) + Expect(got.BalanceAbsThreshold).To(Equal(3)) + Expect(got.MinPrefixMatch).To(BeNumerically("~", 0.4, 1e-9)) + }) + + It("returns 400 for an out-of-range min_prefix_match", func() { + rec := postScheduling(`{"model_name":"bad-mpm","min_prefix_match":2}`) + Expect(rec.Code).To(Equal(http.StatusBadRequest)) + var resp map[string]any + Expect(json.Unmarshal(rec.Body.Bytes(), &resp)).To(Succeed()) + errObj, ok := resp["error"].(map[string]any) + Expect(ok).To(BeTrue()) + Expect(errObj["message"]).To(ContainSubstring("min_prefix_match")) + }) + + It("returns 400 for an unknown route_policy", func() { + rec := postScheduling(`{"model_name":"bad-policy","route_policy":"bogus"}`) + Expect(rec.Code).To(Equal(http.StatusBadRequest)) + var resp map[string]any + Expect(json.Unmarshal(rec.Body.Bytes(), &resp)).To(Succeed()) + errObj, ok := resp["error"].(map[string]any) + Expect(ok).To(BeTrue()) + Expect(errObj["message"]).To(ContainSubstring("route_policy")) + }) + + It("returns 400 for a balance_rel_threshold between 0 and 1", func() { + rec := postScheduling(`{"model_name":"bad-rel","balance_rel_threshold":0.5}`) + Expect(rec.Code).To(Equal(http.StatusBadRequest)) + var resp map[string]any + Expect(json.Unmarshal(rec.Body.Bytes(), &resp)).To(Succeed()) + errObj, ok := resp["error"].(map[string]any) + Expect(ok).To(BeTrue()) + 
Expect(errObj["message"]).To(ContainSubstring("balance_rel_threshold")) + }) + + // Regression for the partial-update footgun: a min/max-only POST used to + // full-replace every column and silently reset the prefix-cache settings + // to empty/zero. The pointer-merge must preserve omitted prefix fields. + It("preserves prefix-cache settings across a min_replicas-only update", func() { + ctx := context.Background() + + rec := postScheduling(`{"model_name":"merge-model","route_policy":"prefix_cache","min_prefix_match":0.4}`) + Expect(rec.Code).To(Equal(http.StatusOK)) + + // Update only min_replicas - omits all prefix-cache fields. + rec = postScheduling(`{"model_name":"merge-model","min_replicas":2}`) + Expect(rec.Code).To(Equal(http.StatusOK)) + + cfg, err := registry.GetModelScheduling(ctx, "merge-model") + Expect(err).ToNot(HaveOccurred()) + Expect(cfg).ToNot(BeNil()) + Expect(cfg.MinReplicas).To(Equal(2), "the provided non-prefix field must update") + Expect(cfg.RoutePolicy).To(Equal("prefix_cache"), "omitted route_policy must be preserved") + Expect(cfg.MinPrefixMatch).To(BeNumerically("~", 0.4, 1e-9), "omitted min_prefix_match must be preserved") + }) + + It("updates a prefix-cache field when it is explicitly provided", func() { + ctx := context.Background() + + rec := postScheduling(`{"model_name":"update-model","route_policy":"prefix_cache","min_prefix_match":0.4}`) + Expect(rec.Code).To(Equal(http.StatusOK)) + + rec = postScheduling(`{"model_name":"update-model","route_policy":"round_robin"}`) + Expect(rec.Code).To(Equal(http.StatusOK)) + + cfg, err := registry.GetModelScheduling(ctx, "update-model") + Expect(err).ToNot(HaveOccurred()) + Expect(cfg).ToNot(BeNil()) + Expect(cfg.RoutePolicy).To(Equal("round_robin"), "explicitly provided route_policy must update") + Expect(cfg.MinPrefixMatch).To(BeNumerically("~", 0.4, 1e-9), "omitted min_prefix_match must still be preserved") + }) + }) + Describe("ListNodesEndpoint", func() { It("returns an empty list when 
no nodes are registered", func() { e := echo.New() diff --git a/core/http/react-ui/src/pages/Nodes.jsx b/core/http/react-ui/src/pages/Nodes.jsx index 4acfb5ebd..f2eb9d955 100644 --- a/core/http/react-ui/src/pages/Nodes.jsx +++ b/core/http/react-ui/src/pages/Nodes.jsx @@ -493,6 +493,13 @@ function SchedulingForm({ onSave, onCancel }) { const [selector, setSelector] = useState({}) const [minReplicas, setMinReplicas] = useState(1) const [maxReplicas, setMaxReplicas] = useState(0) + // Prefix-cache routing controls. Empty routePolicy means "inherit the + // cluster default"; the three thresholds at 0 likewise inherit, so they + // stay out of the POST body's effective override only when explicitly set. + const [routePolicy, setRoutePolicy] = useState('') + const [balanceAbsThreshold, setBalanceAbsThreshold] = useState(0) + const [balanceRelThreshold, setBalanceRelThreshold] = useState(0) + const [minPrefixMatch, setMinPrefixMatch] = useState(0) const hasSelector = Object.keys(selector).length > 0 @@ -508,6 +515,10 @@ function SchedulingForm({ onSave, onCancel }) { node_selector: hasSelector ? selector : undefined, min_replicas: mode === 'placement' ? 0 : minReplicas, max_replicas: mode === 'placement' ? 0 : maxReplicas, + route_policy: routePolicy, + balance_abs_threshold: balanceAbsThreshold, + balance_rel_threshold: balanceRelThreshold, + min_prefix_match: minPrefixMatch, }) } @@ -593,6 +604,76 @@ function SchedulingForm({ onSave, onCancel }) { /> )} + + {/* Per-model routing policy. Left empty/zero these inherit the + cluster-wide defaults; set them to override how requests for this + model are spread across replicas. */} +
+ + + + Prefix Cache routes shared-prefix requests to the same replica to reuse its KV cache, falling back to round-robin when replicas are imbalanced. + +
+ + {routePolicy === 'prefix_cache' && ( +
+
+ + setMinPrefixMatch(parseFloat(e.target.value) || 0)} + /> + + Fraction of the prompt (0..1) that must match a cached prefix before affinity kicks in. 0 inherits the default. + +
+
+ + setBalanceAbsThreshold(parseInt(e.target.value) || 0)} + /> + + Max absolute in-flight gap allowed before falling back to round-robin. 0 inherits the default. + +
+
+ + setBalanceRelThreshold(parseFloat(e.target.value) || 0)} + /> + + Max relative in-flight ratio (>= 1) allowed before falling back to round-robin. 0 inherits the default. + +
+
+ )} {/* Hairline divider above the actions, matching the project's form pattern. */} @@ -1475,6 +1556,8 @@ export default function Nodes() { Node Selector Min Replicas Max Replicas + Routing + Thresholds Status Actions @@ -1519,6 +1602,18 @@ export default function Nodes() { {isAutoScaling ? (cfg.max_replicas || 'no limit') : '-'} + + {cfg.route_policy || 'default'} + + + {cfg.route_policy === 'prefix_cache' ? ( + <> +
match: {cfg.min_prefix_match ? cfg.min_prefix_match : 'inherit'}
+
abs: {cfg.balance_abs_threshold ? cfg.balance_abs_threshold : 'inherit'}
+
rel: {cfg.balance_rel_threshold ? cfg.balance_rel_threshold : 'inherit'}
+ + ) : '-'} + {isUnsatisfiable ? ( = 0 it targets the single replica (Model, NodeID, Replica). When +// Replica < 0 it targets ALL replicas of (Model, NodeID), for example when a +// whole node goes offline. +type PrefixCacheInvalidateEvent struct { + Model string `json:"model"` + NodeID string `json:"node_id"` + Replica int `json:"replica"` +} diff --git a/core/services/messaging/subjects_prefixcache_test.go b/core/services/messaging/subjects_prefixcache_test.go new file mode 100644 index 000000000..2b8eb5771 --- /dev/null +++ b/core/services/messaging/subjects_prefixcache_test.go @@ -0,0 +1,27 @@ +package messaging_test + +import ( + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" + + "github.com/mudler/LocalAI/core/services/messaging" +) + +var _ = Describe("PrefixCache subjects", func() { + It("exposes stable subject constants", func() { + Expect(messaging.SubjectPrefixCacheObserve).To(Equal("prefixcache.observe")) + Expect(messaging.SubjectPrefixCacheInvalidate).To(Equal("prefixcache.invalidate")) + }) + + It("carries a replica index on the observe event", func() { + ev := messaging.PrefixCacheObserveEvent{Model: "m", Chain: []uint64{1, 2}, NodeID: "A", Replica: 3} + Expect(ev.Replica).To(Equal(3)) + }) + + It("uses a negative replica on the invalidate event to mean all replicas of a node", func() { + all := messaging.PrefixCacheInvalidateEvent{Model: "m", NodeID: "A", Replica: -1} + Expect(all.Replica).To(BeNumerically("<", 0)) + one := messaging.PrefixCacheInvalidateEvent{Model: "m", NodeID: "A", Replica: 0} + Expect(one.Replica).To(Equal(0)) + }) +}) diff --git a/core/services/nodes/interfaces.go b/core/services/nodes/interfaces.go index c1dcd2af9..4e82d56cf 100644 --- a/core/services/nodes/interfaces.go +++ b/core/services/nodes/interfaces.go @@ -9,7 +9,7 @@ import ( // ModelRouter is used by SmartRouter for routing decisions and model lifecycle. 
type ModelRouter interface { - FindAndLockNodeWithModel(ctx context.Context, modelName string, candidateNodeIDs []string) (*BackendNode, *NodeModel, error) + FindAndLockNodeWithModel(ctx context.Context, modelName string, candidateNodeIDs []string, pref *RoutePreference) (*BackendNode, *NodeModel, error) DecrementInFlight(ctx context.Context, nodeID, modelName string, replicaIndex int) error IncrementInFlight(ctx context.Context, nodeID, modelName string, replicaIndex int) error RemoveNodeModel(ctx context.Context, nodeID, modelName string, replicaIndex int) error @@ -37,6 +37,7 @@ type ModelRouter interface { FindLeastLoadedNodeFromSet(ctx context.Context, nodeIDs []string) (*BackendNode, error) GetNodeLabels(ctx context.Context, nodeID string) ([]NodeLabel, error) FindNodesWithModel(ctx context.Context, modelName string) ([]BackendNode, error) + LoadedReplicaStats(ctx context.Context, modelName string, candidateNodeIDs []string) ([]ReplicaCandidate, error) } // ConcurrencyConflictResolver returns the names of configured models that diff --git a/core/services/nodes/model_router_test.go b/core/services/nodes/model_router_test.go index 190dac731..1670808c7 100644 --- a/core/services/nodes/model_router_test.go +++ b/core/services/nodes/model_router_test.go @@ -27,7 +27,7 @@ func newFakeModelRouterForSmartRouter() *fakeModelRouterForSmartRouter { } } -func (f *fakeModelRouterForSmartRouter) FindAndLockNodeWithModel(_ context.Context, _ string, _ []string) (*BackendNode, *NodeModel, error) { +func (f *fakeModelRouterForSmartRouter) FindAndLockNodeWithModel(_ context.Context, _ string, _ []string, _ *RoutePreference) (*BackendNode, *NodeModel, error) { f.mu.Lock() defer f.mu.Unlock() return f.node, f.nodeModel, f.findErr @@ -121,6 +121,9 @@ func (f *fakeModelRouterForSmartRouter) GetNodeLabels(_ context.Context, _ strin func (f *fakeModelRouterForSmartRouter) FindNodesWithModel(_ context.Context, _ string) ([]BackendNode, error) { return nil, nil } +func (f 
*fakeModelRouterForSmartRouter) LoadedReplicaStats(_ context.Context, _ string, _ []string) ([]ReplicaCandidate, error) { + return nil, nil +} // Compile-time check var _ ModelRouter = (*fakeModelRouterForSmartRouter)(nil) diff --git a/core/services/nodes/prefixcache/config.go b/core/services/nodes/prefixcache/config.go new file mode 100644 index 000000000..dc897f0b0 --- /dev/null +++ b/core/services/nodes/prefixcache/config.go @@ -0,0 +1,95 @@ +package prefixcache + +import ( + "fmt" + "time" +) + +// Config holds prefix-cache-aware routing settings. Per-model overrides +// (policy, abs/rel thresholds, min-match) live on ModelSchedulingConfig; TTL +// and window/depth are global-only. +type Config struct { + GlobalPolicy RoutePolicy + MinPrefixMatch float64 // ratio matched/total, [0,1] + BalanceAbsThreshold int // absolute in-flight slack + BalanceRelThreshold float64 // relative load ratio, >= 1 + TTL time.Duration // idle-timeout for entries + HalfLife time.Duration // recency decay for cacheWeight + WindowBytes int // chunk window size + MaxDepth int // max trailing blocks hashed + // PressureWindow is the rolling window over which forced-disturb events are + // counted for the autoscale signal (see Pressure). Default 1 minute. + PressureWindow time.Duration + // PressureScaleThreshold is the minimum forced-disturb count within + // PressureWindow that makes the reconciler treat the cache-warm replica as + // saturated and scale up (subject to MaxReplicas and capacity). Default 1, + // i.e. any sustained forced-disturb. 
+ PressureScaleThreshold int +} + +func DefaultConfig() Config { + return Config{ + GlobalPolicy: RoutePolicyPrefixCache, + MinPrefixMatch: 0.3, + BalanceAbsThreshold: 2, + BalanceRelThreshold: 1.5, + TTL: 5 * time.Minute, + HalfLife: 2 * time.Minute, + WindowBytes: 256, + MaxDepth: 64, + PressureWindow: time.Minute, + PressureScaleThreshold: 1, + } +} + +// validateThresholdBounds enforces the numeric bounds shared between the +// per-model override validator (ValidateThresholds) and Config.Validate: +// minMatch in [0,1]; absThr >= 0; relThr == 0 (inherit) or >= 1. It is the +// single source of truth for those bounds so the endpoint and the global +// config cannot drift apart. +func validateThresholdBounds(absThr int, relThr, minMatch float64) error { + if minMatch < 0 || minMatch > 1 { + return fmt.Errorf("prefixcache: min_prefix_match must be in [0,1], got %v", minMatch) + } + if absThr < 0 { + return fmt.Errorf("prefixcache: balance_abs_threshold must be >= 0, got %d", absThr) + } + if relThr != 0 && relThr < 1 { + return fmt.Errorf("prefixcache: balance_rel_threshold must be 0 (inherit) or >= 1, got %v", relThr) + } + return nil +} + +// ValidateThresholds checks per-model override bounds. routePolicy must be one +// of "", "round_robin", "prefix_cache" (explicit allow-list - NOT ParsePolicy, +// which maps unknown to Default and would accept typos). minMatch in [0,1]; +// absThr >= 0; relThr == 0 (inherit) or >= 1. 
+func ValidateThresholds(routePolicy string, absThr int, relThr, minMatch float64) error { + switch routePolicy { + case "", "round_robin", "prefix_cache": + default: + return fmt.Errorf(`prefixcache: route_policy must be one of "", "round_robin", "prefix_cache", got %q`, routePolicy) + } + return validateThresholdBounds(absThr, relThr, minMatch) +} + +func (c Config) Validate() error { + // Config.BalanceRelThreshold has no "inherit" sentinel - it is a concrete + // global value that must be >= 1 - so pass 0 for relThr to the shared + // numeric check and assert the >= 1 floor here separately. + if err := validateThresholdBounds(c.BalanceAbsThreshold, 0, c.MinPrefixMatch); err != nil { + return err + } + if c.BalanceRelThreshold < 1 { + return fmt.Errorf("prefixcache: balance_rel_threshold must be >= 1, got %v", c.BalanceRelThreshold) + } + if c.WindowBytes <= 0 || c.MaxDepth <= 0 { + return fmt.Errorf("prefixcache: window_bytes and max_depth must be > 0") + } + // TTL must be positive: it is the entry idle-lifetime and the eviction + // ticker runs at TTL/2, so time.NewTicker would panic on TTL <= 0. + if c.TTL <= 0 { + return fmt.Errorf("prefixcache: ttl must be > 0, got %v", c.TTL) + } + return nil +} diff --git a/core/services/nodes/prefixcache/config_test.go b/core/services/nodes/prefixcache/config_test.go new file mode 100644 index 000000000..07ed0466f --- /dev/null +++ b/core/services/nodes/prefixcache/config_test.go @@ -0,0 +1,73 @@ +package prefixcache_test + +import ( + "time" + + . "github.com/onsi/ginkgo/v2" + . 
"github.com/onsi/gomega" + + "github.com/mudler/LocalAI/core/services/nodes/prefixcache" +) + +var _ = Describe("Config", func() { + It("supplies defaults", func() { + c := prefixcache.DefaultConfig() + Expect(c.GlobalPolicy).To(Equal(prefixcache.RoutePolicyPrefixCache)) // default ON + Expect(c.MinPrefixMatch).To(BeNumerically("==", 0.3)) + Expect(c.BalanceAbsThreshold).To(Equal(2)) + Expect(c.BalanceRelThreshold).To(BeNumerically("==", 1.5)) + Expect(c.TTL).To(Equal(5 * time.Minute)) + Expect(c.WindowBytes).To(Equal(256)) + Expect(c.MaxDepth).To(Equal(64)) + }) + + It("rejects invalid values", func() { + c := prefixcache.DefaultConfig() + c.MinPrefixMatch = 1.5 + Expect(c.Validate()).To(HaveOccurred()) + c = prefixcache.DefaultConfig() + c.BalanceAbsThreshold = -1 + Expect(c.Validate()).To(HaveOccurred()) + c = prefixcache.DefaultConfig() + c.TTL = 0 + Expect(c.Validate()).To(HaveOccurred()) // TTL/2 ticker would panic + }) +}) + +var _ = Describe("ValidateThresholds", func() { + It("accepts valid values across all route policies", func() { + Expect(prefixcache.ValidateThresholds("", 3, 0, 0.4)).To(Succeed()) + Expect(prefixcache.ValidateThresholds("round_robin", 0, 1.5, 0)).To(Succeed()) + Expect(prefixcache.ValidateThresholds("prefix_cache", 2, 2.0, 1.0)).To(Succeed()) + }) + + It("rejects an unknown route_policy (explicit allow-list, no silent default)", func() { + err := prefixcache.ValidateThresholds("bogus", 0, 0, 0) + Expect(err).To(HaveOccurred()) + Expect(err.Error()).To(ContainSubstring("route_policy")) + }) + + It("rejects min_prefix_match above 1", func() { + err := prefixcache.ValidateThresholds("", 0, 0, 1.5) + Expect(err).To(HaveOccurred()) + Expect(err.Error()).To(ContainSubstring("min_prefix_match")) + }) + + It("rejects a negative min_prefix_match", func() { + err := prefixcache.ValidateThresholds("", 0, 0, -0.1) + Expect(err).To(HaveOccurred()) + Expect(err.Error()).To(ContainSubstring("min_prefix_match")) + }) + + It("rejects a negative 
balance_abs_threshold", func() { + err := prefixcache.ValidateThresholds("", -1, 0, 0) + Expect(err).To(HaveOccurred()) + Expect(err.Error()).To(ContainSubstring("balance_abs_threshold")) + }) + + It("rejects balance_rel_threshold between 0 and 1 exclusive", func() { + err := prefixcache.ValidateThresholds("", 0, 0.5, 0) + Expect(err).To(HaveOccurred()) + Expect(err.Error()).To(ContainSubstring("balance_rel_threshold")) + }) +}) diff --git a/core/services/nodes/prefixcache/export_test.go b/core/services/nodes/prefixcache/export_test.go new file mode 100644 index 000000000..dbfe62166 --- /dev/null +++ b/core/services/nodes/prefixcache/export_test.go @@ -0,0 +1,18 @@ +package prefixcache + +// LenForTest exposes the internal per-model slice length so black-box tests can +// assert that Record bounds its backing slice. Test-only. +func (p *Pressure) LenForTest(model string) int { + p.mu.Lock() + defer p.mu.Unlock() + return len(p.events[model]) +} + +// TreeCountForTest exposes the number of per-model radix trees the Index +// currently retains, so black-box tests can assert that Invalidate does not +// intern empty trees for models that never used the prefix cache. Test-only. +func (ix *Index) TreeCountForTest() int { + ix.mu.RLock() + defer ix.mu.RUnlock() + return len(ix.trees) +} diff --git a/core/services/nodes/prefixcache/extractor.go b/core/services/nodes/prefixcache/extractor.go new file mode 100644 index 000000000..65531bafb --- /dev/null +++ b/core/services/nodes/prefixcache/extractor.go @@ -0,0 +1,57 @@ +package prefixcache + +import ( + "encoding/binary" + + "github.com/cespare/xxhash/v2" +) + +// ExtractChain renders prompt into a cumulative chain of prefix hashes: +// h[0]=H(salt,block0), h[i]=H(h[i-1],block_i). Blocks are fixed +// cfg.WindowBytes-byte windows over the prompt bytes, chunked from absolute +// offset 0 with fixed boundaries [0,W), [W,2W), ... and the chain is capped to +// the FIRST cfg.MaxDepth blocks (the head). 
+// +// Head-first chunking is what makes this a true prefix-chain. The reusable +// KV/prefix cache is always at the HEAD of the prompt: the system prompt and +// early turns are stable, new content is appended at the end, and the KV cache +// is valid up to the first differing token scanning from the start. Because the +// boundaries are anchored at offset 0 (never length-dependent), a prompt P and +// any extension P+suffix share their entire leading overlap, so turn N and turn +// N+1 match for longest-prefix routing. Prefixes deeper than +// MaxDepth*WindowBytes bytes are treated as equal (two prompts agreeing on the +// first MaxDepth head blocks yield identical chains): an accepted routing-hint +// limitation, since the cap bounds the chain length for very long prompts. +// +// xxhash is used (not hash/maphash) because the hash MUST be identical across +// frontend processes: peers exchange these hashes over NATS, and maphash uses a +// per-process random seed that would make peers disagree. +func ExtractChain(model, prompt string, cfg Config) []uint64 { + if prompt == "" { + return nil + } + data := []byte(prompt) + nBlocks := (len(data) + cfg.WindowBytes - 1) / cfg.WindowBytes + depth := min(nBlocks, cfg.MaxDepth) + salt := xxhash.Sum64String(model) + // One Digest reused across blocks: Reset() restores the seed-0 initial + // state, so Reset()+Write produces the byte-identical value to a fresh + // New()+Write. xxhash seed 0 is stateless, so output is unchanged while we + // avoid allocating a Digest per block. The output determinism across + // processes (peers exchange these hashes over NATS) is preserved. 
+ h := xxhash.New() + chain := make([]uint64, 0, depth) + prev := salt + var pb [8]byte + for i := range depth { + off := i * cfg.WindowBytes + end := min(off+cfg.WindowBytes, len(data)) + h.Reset() + binary.LittleEndian.PutUint64(pb[:], prev) + _, _ = h.Write(pb[:]) + _, _ = h.Write(data[off:end]) + prev = h.Sum64() + chain = append(chain, prev) + } + return chain +} diff --git a/core/services/nodes/prefixcache/extractor_test.go b/core/services/nodes/prefixcache/extractor_test.go new file mode 100644 index 000000000..53ee95d3f --- /dev/null +++ b/core/services/nodes/prefixcache/extractor_test.go @@ -0,0 +1,75 @@ +package prefixcache_test + +import ( + "strings" + + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" + + "github.com/mudler/LocalAI/core/services/nodes/prefixcache" +) + +var _ = Describe("Extractor", func() { + cfg := prefixcache.DefaultConfig() + + It("produces a deterministic chain for the same prompt and model", func() { + a := prefixcache.ExtractChain("modelX", "hello world", cfg) + b := prefixcache.ExtractChain("modelX", "hello world", cfg) + Expect(a).To(Equal(b)) + Expect(len(a)).To(BeNumerically(">", 0)) + }) + + It("shares the head but diverges on a volatile tail", func() { + base := strings.Repeat("system rules ", 100) // > one window + x := prefixcache.ExtractChain("m", base+"Current time 12:00:00", cfg) + y := prefixcache.ExtractChain("m", base+"Current time 12:00:01", cfg) + // leading hashes (the stable head) are identical + Expect(x[0]).To(Equal(y[0])) + // the final (tail) hash differs + Expect(x[len(x)-1]).NotTo(Equal(y[len(y)-1])) + }) + + It("salts by model so identical text yields different chains per model", func() { + Expect(prefixcache.ExtractChain("m1", "abc", cfg)[0]). 
+ NotTo(Equal(prefixcache.ExtractChain("m2", "abc", cfg)[0])) + }) + + It("caps depth", func() { + small := cfg + small.WindowBytes = 1 + small.MaxDepth = 4 + chain := prefixcache.ExtractChain("m", "abcdefghij", small) + Expect(len(chain)).To(Equal(4)) + }) + + It("returns nil for empty prompt", func() { + Expect(prefixcache.ExtractChain("m", "", cfg)).To(BeNil()) + }) + + It("stays stable across turns once the prompt grows past the depth cap", func() { + small := cfg + small.WindowBytes = 4 + small.MaxDepth = 3 // 12-byte head budget + + // base is longer than MaxDepth*WindowBytes so the chain is capped to + // the first 3 head blocks. + base := "system-rules-stable-prefix-that-exceeds-the-budget" + Expect(len(base)).To(BeNumerically(">", small.WindowBytes*small.MaxDepth)) + + turnN := prefixcache.ExtractChain("m", base, small) + turnN1 := prefixcache.ExtractChain("m", base+"more text appended", small) + // Both capped to the same first MaxDepth head blocks -> identical chains. + Expect(turnN).To(HaveLen(small.MaxDepth)) + Expect(turnN1).To(HaveLen(small.MaxDepth)) + Expect(turnN1).To(Equal(turnN)) + + // A prompt diverging WITHIN the budget shares the leading hashes up to + // the divergence block and differs after. "system-r" matches base for + // the first two 4-byte blocks ("syst","em-r"), then block 2 differs. 
+ divergent := prefixcache.ExtractChain("m", "system-rDIFFERENT-tail", small) + Expect(divergent).To(HaveLen(small.MaxDepth)) + Expect(divergent[0]).To(Equal(turnN[0])) + Expect(divergent[1]).To(Equal(turnN[1])) + Expect(divergent[2]).NotTo(Equal(turnN[2])) + }) +}) diff --git a/core/services/nodes/prefixcache/index.go b/core/services/nodes/prefixcache/index.go new file mode 100644 index 000000000..fe629ff56 --- /dev/null +++ b/core/services/nodes/prefixcache/index.go @@ -0,0 +1,129 @@ +package prefixcache + +import ( + "sort" + "sync" + "time" + + "github.com/mudler/LocalAI/pkg/radixtree" +) + +// Index is the guessed (routing-history) Provider backed by per-model radix +// trees keyed by ReplicaKey. Affinity is per replica, so the same prefix served +// by two replicas of one node resolves back to the exact replica that served it. +// Safe for concurrent use. +type Index struct { + cfg Config + mu sync.RWMutex + trees map[string]*radixtree.Tree[ReplicaKey] +} + +func NewIndex(cfg Config) *Index { + return &Index{cfg: cfg, trees: map[string]*radixtree.Tree[ReplicaKey]{}} +} + +// existingTree returns the tree for model without creating one. The bool +// reports whether a tree already existed. 
+func (ix *Index) existingTree(model string) (*radixtree.Tree[ReplicaKey], bool) { + ix.mu.RLock() + defer ix.mu.RUnlock() + t, ok := ix.trees[model] + return t, ok +} + +func (ix *Index) tree(model string) *radixtree.Tree[ReplicaKey] { + ix.mu.RLock() + t, ok := ix.trees[model] + ix.mu.RUnlock() + if ok { + return t + } + ix.mu.Lock() + defer ix.mu.Unlock() + if t, ok = ix.trees[model]; ok { + return t + } + t = radixtree.New[ReplicaKey](radixtree.Options{TTL: ix.cfg.TTL, HalfLife: ix.cfg.HalfLife}) + ix.trees[model] = t + return t +} + +func (ix *Index) Decide(model string, chain []uint64, candidates []ReplicaKey, now time.Time) PrefixDecision { + t := ix.tree(model) + var d PrefixDecision + // WeightsFor computes every candidate weight in a single tree walk and + // returns a map pre-populated with an entry (weight 0 by default) for every + // requested candidate. Candidacy is therefore exactly "is a key in weights", + // so we derive the hot-match membership check from it rather than building a + // second set. + weights := t.WeightsFor(candidates, now) + if len(chain) > 0 { + if key, depth, ok := t.LongestMatch(chain, now); ok { + // LongestMatch searches the whole tree, so the deepest match can be + // a replica that is offline / unloaded / not in the candidate set. + // Treating that as a hot match produces a false forced-disturb signal + // upstream (the warm replica was absent, not load-saturated). Only honor + // the match when the matched replica is an actual candidate; otherwise + // fall back to cold placement. + if _, ok := weights[key]; ok { + d.Hot = key + d.HasHot = true + d.MatchRatio = float64(depth) / float64(len(chain)) + } + } + } + // Cold order: candidates ascending by cacheWeight, tie-break by NodeID then + // Replica. The sort comparator reads precomputed weights instead of triggering + // an O(tree size) Weight call per comparison. With at most one candidate the + // input order is already the cold order, so skip the sort. 
+ order := make([]ReplicaKey, len(candidates)) + copy(order, candidates) + if len(order) > 1 { + sort.Slice(order, func(i, j int) bool { + if weights[order[i]] != weights[order[j]] { + return weights[order[i]] < weights[order[j]] + } + return order[i].less(order[j]) + }) + } + d.ColdOrder = order + return d +} + +func (ix *Index) Observe(model string, chain []uint64, key ReplicaKey, now time.Time) bool { + if len(chain) == 0 || key.NodeID == "" { + return false + } + t := ix.tree(model) + // New/extended iff the current deepest match for this exact chain is not + // already this replica at full depth. + cur, depth, ok := t.LongestMatch(chain, now) + t.Insert(chain, key, now) + return !ok || depth < len(chain) || cur != key +} + +// Invalidate drops all entries for ONE replica. It never interns an empty tree +// (a registry chokepoint fires Invalidate for every replica removal of every +// model, including round-robin models that never used the prefix cache, so +// lazily creating a tree here would grow the trees map unboundedly). +func (ix *Index) Invalidate(model string, key ReplicaKey) { + if t, ok := ix.existingTree(model); ok { + t.RemoveFunc(func(k ReplicaKey) bool { return k == key }) + } +} + +// InvalidateNode drops entries for ALL replicas of nodeID. Like Invalidate it +// does not intern an empty tree. +func (ix *Index) InvalidateNode(model, nodeID string) { + if t, ok := ix.existingTree(model); ok { + t.RemoveFunc(func(k ReplicaKey) bool { return k.NodeID == nodeID }) + } +} + +func (ix *Index) Evict(now time.Time) { + ix.mu.RLock() + defer ix.mu.RUnlock() + for _, t := range ix.trees { + t.Evict(now) + } +} diff --git a/core/services/nodes/prefixcache/index_test.go b/core/services/nodes/prefixcache/index_test.go new file mode 100644 index 000000000..4719551d1 --- /dev/null +++ b/core/services/nodes/prefixcache/index_test.go @@ -0,0 +1,169 @@ +package prefixcache_test + +import ( + "sync" + "time" + + . "github.com/onsi/ginkgo/v2" + . 
"github.com/onsi/gomega" + + "github.com/mudler/LocalAI/core/services/nodes/prefixcache" +) + +var t0 = time.Date(2026, 5, 29, 12, 0, 0, 0, time.UTC) + +var _ = Describe("Index provider", func() { + cfg := prefixcache.DefaultConfig() + + It("returns no hot match before anything is observed", func() { + idx := prefixcache.NewIndex(cfg) + d := idx.Decide("m", []uint64{1, 2, 3}, []prefixcache.ReplicaKey{rk("A", 0), rk("B", 0)}, t0) + Expect(d.HasHot).To(BeFalse()) + // cold order present (all weights zero -> deterministic by node id) + Expect(d.ColdOrder).To(ConsistOf(rk("A", 0), rk("B", 0))) + }) + + It("returns the observed replica as hot match with the right ratio", func() { + idx := prefixcache.NewIndex(cfg) + idx.Observe("m", []uint64{1, 2, 3, 4}, rk("A", 0), t0) + d := idx.Decide("m", []uint64{1, 2, 3, 4, 5}, []prefixcache.ReplicaKey{rk("A", 0), rk("B", 0)}, t0) + Expect(d.HasHot).To(BeTrue()) + Expect(d.Hot).To(Equal(rk("A", 0))) + Expect(d.MatchRatio).To(BeNumerically("~", 4.0/5.0, 0.001)) + }) + + It("orders cold candidates by ascending cacheWeight", func() { + idx := prefixcache.NewIndex(cfg) + idx.Observe("m", []uint64{1}, rk("A", 0), t0) + idx.Observe("m", []uint64{2}, rk("A", 0), t0) // A weight 2 + idx.Observe("m", []uint64{3}, rk("B", 0), t0) // B weight 1 + d := idx.Decide("m", []uint64{9}, []prefixcache.ReplicaKey{rk("A", 0), rk("B", 0)}, t0) + Expect(d.HasHot).To(BeFalse()) + Expect(d.ColdOrder).To(Equal([]prefixcache.ReplicaKey{rk("B", 0), rk("A", 0)})) // B lower weight first + }) + + It("drops the hot match when the matched replica is not in the candidate set", func() { + idx := prefixcache.NewIndex(cfg) + idx.Observe("m", []uint64{1, 2, 3, 4}, rk("A", 0), t0) + // A holds the longest match, but A is not a candidate (offline / + // unloaded). The matched replica must be ignored so cold placement runs + // and no false forced-disturb fires upstream. 
+ d := idx.Decide("m", []uint64{1, 2, 3, 4, 5}, []prefixcache.ReplicaKey{rk("B", 0), rk("C", 0)}, t0) + Expect(d.HasHot).To(BeFalse()) + Expect(d.MatchRatio).To(Equal(0.0)) + Expect(d.ColdOrder).To(ConsistOf(rk("B", 0), rk("C", 0))) + }) + + It("returns a hot match for a query that only shares a prefix with an observed chain", func() { + // The real-world case: a replica served chain [1,2,3,4]; a new request + // shares the leading block [1,2,3] but diverges at the tail ([1,2,3,9]). + // With prefix matching (value recorded at every node) Decide must still + // route to the warm replica, matching at the depth of the shared prefix. + idx := prefixcache.NewIndex(cfg) + idx.Observe("m", []uint64{1, 2, 3, 4}, rk("A", 0), t0) + d := idx.Decide("m", []uint64{1, 2, 3, 9}, []prefixcache.ReplicaKey{rk("A", 0), rk("B", 0)}, t0) + Expect(d.HasHot).To(BeTrue()) + Expect(d.Hot).To(Equal(rk("A", 0))) + Expect(d.MatchRatio).To(BeNumerically("~", 3.0/4.0, 0.001)) // shared [1,2,3] of len-4 query + }) + + It("keeps the hot match when the matched replica is a candidate", func() { + idx := prefixcache.NewIndex(cfg) + idx.Observe("m", []uint64{1, 2, 3, 4}, rk("A", 0), t0) + d := idx.Decide("m", []uint64{1, 2, 3, 4, 5}, []prefixcache.ReplicaKey{rk("A", 0), rk("B", 0)}, t0) + Expect(d.HasHot).To(BeTrue()) + Expect(d.Hot).To(Equal(rk("A", 0))) + Expect(d.MatchRatio).To(BeNumerically("~", 4.0/5.0, 0.001)) + }) + + It("tracks affinity per replica, not per node", func() { + // Two replicas on the SAME node, each serving a different chain that share + // a leading block. The hot match for a query extending chain1 must be the + // EXACT replica that served chain1, not the other replica on the same node. 
+ idx := prefixcache.NewIndex(cfg) + idx.Observe("m", []uint64{1, 2, 3, 4}, rk("A", 0), t0) // replica 0 owns [1,2,3,4] + idx.Observe("m", []uint64{1, 2, 5, 6}, rk("A", 1), t0) // replica 1 owns [1,2,5,6] + cands := []prefixcache.ReplicaKey{rk("A", 0), rk("A", 1)} + d := idx.Decide("m", []uint64{1, 2, 3, 4, 7}, cands, t0) + Expect(d.HasHot).To(BeTrue()) + Expect(d.Hot).To(Equal(rk("A", 0))) // distinct replicas on one node have distinct affinity + d2 := idx.Decide("m", []uint64{1, 2, 5, 6, 7}, cands, t0) + Expect(d2.HasHot).To(BeTrue()) + Expect(d2.Hot).To(Equal(rk("A", 1))) + }) + + It("Invalidate drops one replica while InvalidateNode drops all replicas of a node", func() { + idx := prefixcache.NewIndex(cfg) + idx.Observe("m", []uint64{1, 2, 3, 4}, rk("A", 0), t0) + idx.Observe("m", []uint64{5, 6, 7, 8}, rk("A", 1), t0) + cands := []prefixcache.ReplicaKey{rk("A", 0), rk("A", 1)} + + // Invalidate replica 0 only: replica 1 survives. + idx.Invalidate("m", rk("A", 0)) + Expect(idx.Decide("m", []uint64{1, 2, 3, 4}, cands, t0).HasHot).To(BeFalse()) + d1 := idx.Decide("m", []uint64{5, 6, 7, 8}, cands, t0) + Expect(d1.HasHot).To(BeTrue()) + Expect(d1.Hot).To(Equal(rk("A", 1))) + + // Re-observe both, then InvalidateNode drops BOTH replicas. 
+ idx.Observe("m", []uint64{1, 2, 3, 4}, rk("A", 0), t0) + idx.InvalidateNode("m", "A") + Expect(idx.Decide("m", []uint64{1, 2, 3, 4}, cands, t0).HasHot).To(BeFalse()) + Expect(idx.Decide("m", []uint64{5, 6, 7, 8}, cands, t0).HasHot).To(BeFalse()) + }) + + It("forgets a replica on Invalidate", func() { + idx := prefixcache.NewIndex(cfg) + idx.Observe("m", []uint64{1, 2}, rk("A", 0), t0) + idx.Invalidate("m", rk("A", 0)) + d := idx.Decide("m", []uint64{1, 2}, []prefixcache.ReplicaKey{rk("A", 0)}, t0) + Expect(d.HasHot).To(BeFalse()) + }) + + It("does not intern an empty tree when invalidating a model that has none", func() { + idx := prefixcache.NewIndex(cfg) + Expect(idx.TreeCountForTest()).To(Equal(0)) + // Round-robin model that never used the prefix cache: invalidating a + // replica removal must be a no-op and must not retain a tree. + idx.Invalidate("never-cached", rk("A", 0)) + idx.Invalidate("never-cached", rk("B", 0)) + idx.InvalidateNode("other", "C") + Expect(idx.TreeCountForTest()).To(Equal(0)) + // And a Decide afterwards still works without a hot match. 
+ d := idx.Decide("never-cached", []uint64{1}, []prefixcache.ReplicaKey{rk("A", 0)}, t0) + Expect(d.HasHot).To(BeFalse()) + }) + + It("is safe for concurrent Decide/Observe/Invalidate (run with -race)", func() { + idx := prefixcache.NewIndex(cfg) + models := []string{"m1", "m2"} + nodes := []string{"A", "B", "C"} + cands := []prefixcache.ReplicaKey{rk("A", 0), rk("B", 0), rk("C", 0)} + var wg sync.WaitGroup + for g := range 8 { + wg.Add(1) + go func(g int) { + defer GinkgoRecover() + defer wg.Done() + model := models[g%len(models)] + node := nodes[g%len(nodes)] + now := t0 + for i := range 200 { + chain := []uint64{uint64(g), uint64(i % 7), uint64(i)} + switch i % 4 { + case 0: + idx.Observe(model, chain, prefixcache.ReplicaKey{NodeID: node, Replica: i % 2}, now) + case 1: + idx.Decide(model, chain, cands, now) + case 2: + idx.Invalidate(model, prefixcache.ReplicaKey{NodeID: node, Replica: i % 2}) + case 3: + idx.InvalidateNode(model, node) + } + now = now.Add(time.Millisecond) + } + }(g) + } + wg.Wait() + }) +}) diff --git a/core/services/nodes/prefixcache/policy.go b/core/services/nodes/prefixcache/policy.go new file mode 100644 index 000000000..3a87f25b7 --- /dev/null +++ b/core/services/nodes/prefixcache/policy.go @@ -0,0 +1,47 @@ +// Package prefixcache implements prefix-cache-aware routing for distributed +// mode: it turns a request prompt into a chain of prefix hashes, tracks which +// node served which prefix in an in-memory radix tree, and provides a +// load-guarded preferred-node decision. See docs/content/features/distributed-mode.md. +package prefixcache + +// RoutePolicy selects the routing strategy for a model. The zero value is +// RoutePolicyDefault, meaning "inherit the cluster-wide default". 
type RoutePolicy int

// The declaration order is load-bearing: RoutePolicyDefault has to stay the
// zero value so that an unset per-model policy falls through to the global one.
const (
	RoutePolicyDefault     RoutePolicy = iota // inherit global default
	RoutePolicyRoundRobin                     // today's behavior (the floor)
	RoutePolicyPrefixCache                    // cache-aware routing
)

// ParsePolicy converts a configuration string into a RoutePolicy. Matching is
// exact: anything other than "round_robin" or "prefix_cache" - the empty string
// included - yields RoutePolicyDefault.
func ParsePolicy(s string) RoutePolicy {
	if s == "round_robin" {
		return RoutePolicyRoundRobin
	}
	if s == "prefix_cache" {
		return RoutePolicyPrefixCache
	}
	return RoutePolicyDefault
}

// String returns the configuration spelling of p. RoutePolicyDefault and any
// value outside the declared constants both render as "default".
func (p RoutePolicy) String() string {
	if p == RoutePolicyRoundRobin {
		return "round_robin"
	}
	if p == RoutePolicyPrefixCache {
		return "prefix_cache"
	}
	return "default"
}

// Resolve yields the effective policy: p itself when it names a concrete
// strategy, otherwise the supplied global fallback.
func (p RoutePolicy) Resolve(global RoutePolicy) RoutePolicy {
	if p != RoutePolicyDefault {
		return p
	}
	return global
}
"github.com/onsi/gomega" + + "github.com/mudler/LocalAI/core/services/nodes/prefixcache" +) + +var _ = Describe("RoutePolicy", func() { + It("parses known values and defaults unknown to Default (zero)", func() { + Expect(prefixcache.ParsePolicy("round_robin")).To(Equal(prefixcache.RoutePolicyRoundRobin)) + Expect(prefixcache.ParsePolicy("prefix_cache")).To(Equal(prefixcache.RoutePolicyPrefixCache)) + Expect(prefixcache.ParsePolicy("")).To(Equal(prefixcache.RoutePolicyDefault)) + Expect(prefixcache.ParsePolicy("bogus")).To(Equal(prefixcache.RoutePolicyDefault)) + }) + + It("stringifies", func() { + Expect(prefixcache.RoutePolicyPrefixCache.String()).To(Equal("prefix_cache")) + Expect(prefixcache.RoutePolicyRoundRobin.String()).To(Equal("round_robin")) + }) + + It("resolves per-model against a global default", func() { + Expect(prefixcache.RoutePolicyDefault.Resolve(prefixcache.RoutePolicyPrefixCache)). + To(Equal(prefixcache.RoutePolicyPrefixCache)) + Expect(prefixcache.RoutePolicyRoundRobin.Resolve(prefixcache.RoutePolicyPrefixCache)). + To(Equal(prefixcache.RoutePolicyRoundRobin)) + }) +}) diff --git a/core/services/nodes/prefixcache/prefixcache_suite_test.go b/core/services/nodes/prefixcache/prefixcache_suite_test.go new file mode 100644 index 000000000..6c2c86073 --- /dev/null +++ b/core/services/nodes/prefixcache/prefixcache_suite_test.go @@ -0,0 +1,13 @@ +package prefixcache_test + +import ( + "testing" + + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" +) + +func TestPrefixCache(t *testing.T) { + RegisterFailHandler(Fail) + RunSpecs(t, "PrefixCache Suite") +} diff --git a/core/services/nodes/prefixcache/pressure.go b/core/services/nodes/prefixcache/pressure.go new file mode 100644 index 000000000..29e3e0f15 --- /dev/null +++ b/core/services/nodes/prefixcache/pressure.go @@ -0,0 +1,82 @@ +package prefixcache + +import ( + "sync" + "time" +) + +// Pressure is a concurrency-safe rolling per-model counter of forced-disturb +// events. 
// A forced-disturb is recorded by the router when a usable hot prefix
// match existed but the load guard forced the request off the warm node (see
// SmartRouter.buildPreference). The reconciler reads Count to decide whether
// the cache-warm replica is saturated enough to warrant a scale-up.
//
// Entries older than the window are dropped on both Record and Count, so the
// slice never grows unbounded - even for a model that takes records but is
// never Counted (e.g. one with zero loaded replicas the reconciler skips). An
// idle model's history also decays to zero on the next read.
//
// Use NewPressure to choose the window. A zero-value Pressure does not panic:
// Record allocates the map on first use. Its window is zero, so only events
// that are not older than the query time are counted.
type Pressure struct {
	mu     sync.Mutex
	window time.Duration
	events map[string][]time.Time // per-model timestamps; guarded by mu
}

// NewPressure creates a Pressure counter that remembers events for the given
// rolling window.
func NewPressure(window time.Duration) *Pressure {
	return &Pressure{
		window: window,
		events: make(map[string][]time.Time),
	}
}

// pruneLocked drops entries older than cutoff, compacting in place (the result
// shares ts's backing array). The cutoff boundary itself is inclusive so an
// event exactly window-old still counts. Callers must hold p.mu.
func pruneLocked(ts []time.Time, cutoff time.Time) []time.Time {
	kept := ts[:0]
	for _, t := range ts {
		if !t.Before(cutoff) {
			kept = append(kept, t)
		}
	}
	return kept
}

// Record appends a forced-disturb timestamp for the model and prunes entries
// older than the window, so the per-model slice stays bounded regardless of how
// often Count runs.
func (p *Pressure) Record(model string, now time.Time) {
	p.mu.Lock()
	defer p.mu.Unlock()
	// Writing to a nil map panics; allocate lazily so a Pressure that was not
	// built with NewPressure is still safe to record into.
	if p.events == nil {
		p.events = make(map[string][]time.Time)
	}
	cutoff := now.Add(-p.window)
	p.events[model] = append(pruneLocked(p.events[model], cutoff), now)
}

// Count returns the number of records for the model within [now-window, now],
// dropping any entries older than the window so the backing slice stays bounded.
// When nothing remains in the window the model's key is removed entirely.
func (p *Pressure) Count(model string, now time.Time) int {
	p.mu.Lock()
	defer p.mu.Unlock()
	ts := p.events[model]
	if len(ts) == 0 {
		return 0
	}
	kept := pruneLocked(ts, now.Add(-p.window))
	if len(kept) == 0 {
		delete(p.events, model)
		return 0
	}
	p.events[model] = kept
	return len(kept)
}

// Reset clears all recorded events for model. Call after acting on the signal
// (a pressure-triggered scale-up) so a single burst does not trigger repeated
// scale-ups across consecutive ticks.
func (p *Pressure) Reset(model string) {
	p.mu.Lock()
	defer p.mu.Unlock()
	delete(p.events, model)
}
"github.com/onsi/gomega" +) + +var _ = Describe("Pressure counter", func() { + t0 := time.Unix(1700000000, 0) + + It("counts events within the window and forgets older ones", func() { + p := prefixcache.NewPressure(time.Minute) + p.Record("m", t0) + p.Record("m", t0.Add(30*time.Second)) + Expect(p.Count("m", t0.Add(40*time.Second))).To(Equal(2)) + Expect(p.Count("m", t0.Add(90*time.Second))).To(Equal(1)) // first expired + }) + + It("tracks pressure per model independently", func() { + p := prefixcache.NewPressure(time.Minute) + p.Record("a", t0) + p.Record("a", t0.Add(10*time.Second)) + p.Record("b", t0.Add(20*time.Second)) + Expect(p.Count("a", t0.Add(30*time.Second))).To(Equal(2)) + Expect(p.Count("b", t0.Add(30*time.Second))).To(Equal(1)) + Expect(p.Count("c", t0.Add(30*time.Second))).To(Equal(0)) + }) + + It("returns zero for a model that was never recorded", func() { + p := prefixcache.NewPressure(time.Minute) + Expect(p.Count("never", t0)).To(Equal(0)) + }) + + It("includes the boundary timestamp at exactly now-window", func() { + p := prefixcache.NewPressure(time.Minute) + p.Record("m", t0) + // now-window == t0 exactly, so the entry is still within [now-window, now]. + Expect(p.Count("m", t0.Add(time.Minute))).To(Equal(1)) + // one nanosecond past the window drops it. + Expect(p.Count("m", t0.Add(time.Minute+1))).To(Equal(0)) + }) + + It("bounds the backing slice in Record without any Count calls", func() { + p := prefixcache.NewPressure(time.Minute) + // Record many timestamps, advancing now well past the window between + // each, and never call Count. Each Record must prune the entries that + // have fallen out of [now-window, now] so the slice cannot accumulate. + var last time.Time + for i := range 1000 { + last = t0.Add(time.Duration(i) * 10 * time.Second) + p.Record("m", last) + } + // With a 1m window and 10s spacing, at most ~7 records (the boundary is + // inclusive) can be within [last-window, last]. 
The slice must stay that + // bounded, never growing toward 1000. + Expect(p.LenForTest("m")).To(BeNumerically("<=", 7)) + // And the in-window count must reflect only those bounded entries. + Expect(p.Count("m", last)).To(Equal(p.LenForTest("m"))) + }) + + It("clears all recorded events on Reset", func() { + p := prefixcache.NewPressure(time.Minute) + p.Record("m", t0) + p.Record("m", t0.Add(10*time.Second)) + p.Record("m", t0.Add(20*time.Second)) + Expect(p.Count("m", t0.Add(30*time.Second))).To(BeNumerically(">", 0)) + + p.Reset("m") + + // After Reset the model has no in-window events even though the + // timestamps would otherwise still be within [now-window, now]. + Expect(p.Count("m", t0.Add(30*time.Second))).To(Equal(0)) + Expect(p.LenForTest("m")).To(Equal(0)) + }) + + It("Reset only clears the named model", func() { + p := prefixcache.NewPressure(time.Minute) + p.Record("a", t0) + p.Record("b", t0) + p.Reset("a") + Expect(p.Count("a", t0.Add(time.Second))).To(Equal(0)) + Expect(p.Count("b", t0.Add(time.Second))).To(Equal(1)) + }) + + It("does not accumulate repeated out-of-window Records", func() { + p := prefixcache.NewPressure(time.Minute) + // Each record is more than a window apart, so every Record prunes the + // previous one. The slice should never hold more than a single entry. + for i := range 100 { + p.Record("m", t0.Add(time.Duration(i)*2*time.Minute)) + } + Expect(p.LenForTest("m")).To(Equal(1)) + Expect(p.Count("m", t0.Add(198*time.Minute))).To(Equal(1)) + }) +}) diff --git a/core/services/nodes/prefixcache/provider.go b/core/services/nodes/prefixcache/provider.go new file mode 100644 index 000000000..976713aac --- /dev/null +++ b/core/services/nodes/prefixcache/provider.go @@ -0,0 +1,24 @@ +package prefixcache + +import "time" + +// Provider is the seam between SmartRouter and the prefix-cache implementation. 
+// The radix-tree (guessed) implementation is the only one today; a future +// KV-event (reported) implementation can satisfy the same interface without +// changing SmartRouter (epic #10063 / #10064). Affinity is tracked per replica: +// each loaded replica is a separate process with its own KV cache. +type Provider interface { + // Decide computes the prefix decision for a request given the candidate + // replicas (the selector-filtered set). It does not consult load - load + // filtering happens in the DB transaction. + Decide(model string, chain []uint64, candidates []ReplicaKey, now time.Time) PrefixDecision + // Observe records that the replica served the request whose prefix is chain. + // Returns true when the assignment was new or extended (caller broadcasts). + Observe(model string, chain []uint64, key ReplicaKey, now time.Time) bool + // Invalidate drops all entries for ONE replica. + Invalidate(model string, key ReplicaKey) + // InvalidateNode drops entries for ALL replicas of a node. + InvalidateNode(model, nodeID string) + // Evict sweeps expired entries for all models. + Evict(now time.Time) +} diff --git a/core/services/nodes/prefixcache/select.go b/core/services/nodes/prefixcache/select.go new file mode 100644 index 000000000..9ede1f5b1 --- /dev/null +++ b/core/services/nodes/prefixcache/select.go @@ -0,0 +1,93 @@ +package prefixcache + +// ReplicaKey identifies a specific loaded replica (a backend process). Affinity +// is tracked per replica, not per node, because each replica is a separate +// process with its own KV cache. +type ReplicaKey struct { + NodeID string + Replica int +} + +// less reports whether a sorts before b, ordering by NodeID then Replica. It is +// the deterministic tiebreak used wherever two replicas are otherwise equal. 
+func (a ReplicaKey) less(b ReplicaKey) bool { + if a.NodeID != b.NodeID { + return a.NodeID < b.NodeID + } + return a.Replica < b.Replica +} + +// Candidate is a load-eligible-or-not replica view from the registry. There is +// one Candidate per LOADED replica: the router no longer collapses replicas per +// node, so two replicas of the same model on the same node are two candidates. +type Candidate struct { + Key ReplicaKey + InFlight int +} + +// PrefixDecision is computed from the in-memory tree before the DB transaction. +// Hot is the replica holding the longest prefix match and HasHot reports whether +// there is one (a ReplicaKey has no "" sentinel). MatchRatio is matched/total +// for that match. ColdOrder lists candidate replicas ascending by cacheWeight +// (lowest = least valuable warm cache = best cold target). +type PrefixDecision struct { + Hot ReplicaKey + HasHot bool + MatchRatio float64 + ColdOrder []ReplicaKey +} + +// Select implements filter-then-score per replica: keep candidates within the +// load guard (relative to the min in-flight across ALL candidate replicas), then +// prefer the exact hot-match replica, else the lowest-cacheWeight eligible +// replica via ColdOrder, else a deterministic eligible fallback (least in-flight, +// tiebreak by NodeID then Replica). Returns (ReplicaKey{}, false) when nothing is +// selectable. +func Select(cands []Candidate, d PrefixDecision, cfg Config) (ReplicaKey, bool) { + if len(cands) == 0 { + return ReplicaKey{}, false + } + minIF := cands[0].InFlight + for _, c := range cands { + minIF = min(minIF, c.InFlight) + } + eligible := map[ReplicaKey]bool{} + for _, c := range cands { + withinAbs := c.InFlight <= minIF+cfg.BalanceAbsThreshold + // +1 softens the relative guard when minIF==0 so a zero baseline does + // not require exact-zero in-flight; the absolute guard governs near 0. 
+ withinRel := float64(c.InFlight) <= float64(minIF)*cfg.BalanceRelThreshold+1 + if withinAbs && withinRel { + eligible[c.Key] = true + } + } + // Hot match wins if eligible and strong enough. + if d.HasHot && d.MatchRatio >= cfg.MinPrefixMatch && eligible[d.Hot] { + return d.Hot, true + } + // Cold placement: lowest cacheWeight eligible replica. + for _, k := range d.ColdOrder { + if eligible[k] { + return k, true + } + } + // Deterministic eligible fallback: least in-flight, tiebreak NodeID then + // Replica. ColdOrder may not cover the eligible set (the caller may pass an + // empty ColdOrder), so this guarantees Select still returns the best eligible + // replica rather than failing. + var best Candidate + found := false + for _, c := range cands { + if !eligible[c.Key] { + continue + } + if !found || c.InFlight < best.InFlight || + (c.InFlight == best.InFlight && c.Key.less(best.Key)) { + best, found = c, true + } + } + if found { + return best.Key, true + } + return ReplicaKey{}, false +} diff --git a/core/services/nodes/prefixcache/select_test.go b/core/services/nodes/prefixcache/select_test.go new file mode 100644 index 000000000..f11bc1879 --- /dev/null +++ b/core/services/nodes/prefixcache/select_test.go @@ -0,0 +1,139 @@ +package prefixcache_test + +import ( + . "github.com/onsi/ginkgo/v2" + . 
"github.com/onsi/gomega" + + "github.com/mudler/LocalAI/core/services/nodes/prefixcache" +) + +func rk(node string, replica int) prefixcache.ReplicaKey { + return prefixcache.ReplicaKey{NodeID: node, Replica: replica} +} + +var _ = Describe("Select (filter-then-score)", func() { + cfg := prefixcache.DefaultConfig() // abs=2, rel=1.5, minMatch=0.3 + + cand := func(node string, replica, inflight int) prefixcache.Candidate { + return prefixcache.Candidate{Key: rk(node, replica), InFlight: inflight} + } + + It("returns the hot-match replica when it is load-eligible and match >= min", func() { + cands := []prefixcache.Candidate{cand("A", 0, 1), cand("B", 0, 0)} + got, ok := prefixcache.Select(cands, prefixcache.PrefixDecision{ + Hot: rk("A", 0), HasHot: true, MatchRatio: 0.5, + }, cfg) + Expect(ok).To(BeTrue()) + Expect(got).To(Equal(rk("A", 0))) // A in-flight 1 <= min(0)+2 and <= 0*1.5+1 + }) + + It("rejects the hot match when it violates the absolute load guard", func() { + cands := []prefixcache.Candidate{cand("A", 0, 5), cand("B", 0, 0)} + got, ok := prefixcache.Select(cands, prefixcache.PrefixDecision{ + Hot: rk("A", 0), HasHot: true, MatchRatio: 0.9, + ColdOrder: []prefixcache.ReplicaKey{rk("B", 0), rk("A", 0)}, + }, cfg) + Expect(ok).To(BeTrue()) + Expect(got).To(Equal(rk("B", 0))) // A 5 > min(0)+2, drop to cold placement + }) + + It("ignores a match below min_prefix_match", func() { + cands := []prefixcache.Candidate{cand("A", 0, 0), cand("B", 0, 0)} + got, ok := prefixcache.Select(cands, prefixcache.PrefixDecision{ + Hot: rk("A", 0), HasHot: true, MatchRatio: 0.2, // < 0.3 + ColdOrder: []prefixcache.ReplicaKey{rk("B", 0), rk("A", 0)}, + }, cfg) + Expect(ok).To(BeTrue()) + Expect(got).To(Equal(rk("B", 0))) // cold placement: lowest cacheWeight eligible + }) + + It("cold-places to lowest-cacheWeight replica within the eligible subset", func() { + cands := []prefixcache.Candidate{cand("A", 0, 0), cand("B", 0, 0), cand("C", 0, 9)} + got, ok := 
prefixcache.Select(cands, prefixcache.PrefixDecision{ + ColdOrder: []prefixcache.ReplicaKey{rk("C", 0), rk("B", 0), rk("A", 0)}, + }, cfg) + Expect(ok).To(BeTrue()) + Expect(got).To(Equal(rk("B", 0))) // C filtered out by load; B is next in cold order + }) + + It("returns false when no candidates", func() { + _, ok := prefixcache.Select(nil, prefixcache.PrefixDecision{}, cfg) + Expect(ok).To(BeFalse()) + }) + + It("falls back to the least-in-flight eligible replica when ColdOrder is empty", func() { + // Deterministic eligible fallback: ColdOrder does not cover the eligible + // set, so Select picks the least-in-flight eligible replica, tiebreaking by + // NodeID then Replica. + cands := []prefixcache.Candidate{cand("B", 1, 0), cand("B", 0, 0), cand("A", 0, 0)} + got, ok := prefixcache.Select(cands, prefixcache.PrefixDecision{}, cfg) + Expect(ok).To(BeTrue()) + Expect(got).To(Equal(rk("A", 0))) // all in-flight 0; A < B; within B, replica 0 < 1 + }) + + It("returns false when no candidate is eligible", func() { + // Impossible in practice (min is always eligible) but guards the contract: + // an empty eligible set yields no selection. Here every candidate is the + // min, so one is always eligible; instead test the documented zero value. + cands := []prefixcache.Candidate{cand("A", 0, 0)} + got, ok := prefixcache.Select(cands, prefixcache.PrefixDecision{}, cfg) + Expect(ok).To(BeTrue()) + Expect(got).To(Equal(rk("A", 0))) + }) +}) + +var _ = Describe("Select replica granularity", func() { + cfg := prefixcache.DefaultConfig() + + It("distinguishes two replicas of the same node as separate candidates", func() { + // Two replicas on NodeA: replica 0 is hot but saturated, replica 1 is cool. + // The round-robin floor must drop to replica 1, NOT collapse them per node. 
+ cands := []prefixcache.Candidate{ + {Key: rk("A", 0), InFlight: 50}, + {Key: rk("A", 1), InFlight: 0}, + } + got, ok := prefixcache.Select(cands, prefixcache.PrefixDecision{ + Hot: rk("A", 0), HasHot: true, MatchRatio: 1.0, + ColdOrder: []prefixcache.ReplicaKey{rk("A", 1), rk("A", 0)}, + }, cfg) + Expect(ok).To(BeTrue()) + Expect(got).To(Equal(rk("A", 1))) + }) + + It("pins back to the exact hot replica when it is within slack", func() { + cands := []prefixcache.Candidate{ + {Key: rk("A", 0), InFlight: 1}, + {Key: rk("A", 1), InFlight: 0}, + } + got, ok := prefixcache.Select(cands, prefixcache.PrefixDecision{ + Hot: rk("A", 0), HasHot: true, MatchRatio: 1.0, + ColdOrder: []prefixcache.ReplicaKey{rk("A", 1), rk("A", 0)}, + }, cfg) + Expect(ok).To(BeTrue()) + Expect(got).To(Equal(rk("A", 0))) // within slack -> reuse exact replica + }) +}) + +var _ = Describe("Select round-robin floor invariant", func() { + It("never pins to a saturated hot replica (round-robin floor)", func() { + cfg := prefixcache.DefaultConfig() + cands := []prefixcache.Candidate{{Key: rk("hot", 0), InFlight: 50}, {Key: rk("cool", 0), InFlight: 0}} + got, ok := prefixcache.Select(cands, prefixcache.PrefixDecision{ + Hot: rk("hot", 0), HasHot: true, MatchRatio: 1.0, + ColdOrder: []prefixcache.ReplicaKey{rk("cool", 0), rk("hot", 0)}, + }, cfg) + Expect(ok).To(BeTrue()) + Expect(got).To(Equal(rk("cool", 0))) + }) + + It("improves reuse when balanced", func() { + cfg := prefixcache.DefaultConfig() + cands := []prefixcache.Candidate{{Key: rk("hot", 0), InFlight: 1}, {Key: rk("cool", 0), InFlight: 0}} + got, ok := prefixcache.Select(cands, prefixcache.PrefixDecision{ + Hot: rk("hot", 0), HasHot: true, MatchRatio: 1.0, + ColdOrder: []prefixcache.ReplicaKey{rk("cool", 0), rk("hot", 0)}, + }, cfg) + Expect(ok).To(BeTrue()) + Expect(got).To(Equal(rk("hot", 0))) // within slack -> reuse + }) +}) diff --git a/core/services/nodes/prefixcache/sync.go b/core/services/nodes/prefixcache/sync.go new file mode 
100644 index 000000000..421fea7e2 --- /dev/null +++ b/core/services/nodes/prefixcache/sync.go @@ -0,0 +1,91 @@ +package prefixcache + +import ( + "time" + + "github.com/mudler/LocalAI/core/services/messaging" + "github.com/mudler/xlog" +) + +// publisher is the minimal slice of messaging.Client that Sync needs. +type publisher interface { + Publish(subject string, v any) error +} + +// Sync wraps an Index, broadcasting new/extended observations to peers and +// applying peers' broadcasts. It is the cross-frontend coherence layer. +type Sync struct { + idx Provider + pub publisher +} + +func NewSync(idx Provider, pub publisher) *Sync { return &Sync{idx: idx, pub: pub} } + +// Observe records locally and, if new/extended, broadcasts to peers. It returns +// whether the local index treated the assignment as new or extended, so Sync +// satisfies prefixcache.Provider. +func (s *Sync) Observe(model string, chain []uint64, key ReplicaKey, now time.Time) bool { + changed := s.idx.Observe(model, chain, key, now) + if changed && s.pub != nil { + ev := messaging.PrefixCacheObserveEvent{Model: model, Chain: chain, NodeID: key.NodeID, Replica: key.Replica} + if err := s.pub.Publish(messaging.SubjectPrefixCacheObserve, ev); err != nil { + xlog.Debug("prefixcache: observe publish failed", "error", err) + } + } + return changed +} + +// Invalidate drops the local entry for one replica and broadcasts to peers. The +// local drop is a no-op for models that were never cached (Index.Invalidate does +// not intern a tree). The broadcast is UNCONDITIONAL (when a publisher is +// configured): the registry chokepoint fires for every replica removal, and a +// peer frontend may hold a stale entry for the model even when THIS frontend +// never cached it, so gating the broadcast on local-tree existence would drop +// cross-frontend invalidations and leave peers routing to a removed replica +// until their TTL. 
+func (s *Sync) Invalidate(model string, key ReplicaKey) { + s.idx.Invalidate(model, key) + if s.pub != nil { + ev := messaging.PrefixCacheInvalidateEvent{Model: model, NodeID: key.NodeID, Replica: key.Replica} + if err := s.pub.Publish(messaging.SubjectPrefixCacheInvalidate, ev); err != nil { + xlog.Debug("prefixcache: invalidate publish failed", "error", err) + } + } +} + +// InvalidateNode drops the local entries for ALL replicas of node and broadcasts +// to peers. Like Invalidate the broadcast is unconditional for cross-frontend +// coherence. A negative Replica on the wire means "all replicas of the node". +func (s *Sync) InvalidateNode(model, node string) { + s.idx.InvalidateNode(model, node) + if s.pub != nil { + ev := messaging.PrefixCacheInvalidateEvent{Model: model, NodeID: node, Replica: -1} + if err := s.pub.Publish(messaging.SubjectPrefixCacheInvalidate, ev); err != nil { + xlog.Debug("prefixcache: invalidate-node publish failed", "error", err) + } + } +} + +// ApplyObserve applies a peer observe event locally (no re-broadcast). +func (s *Sync) ApplyObserve(ev messaging.PrefixCacheObserveEvent, now time.Time) { + s.idx.Observe(ev.Model, ev.Chain, ReplicaKey{NodeID: ev.NodeID, Replica: ev.Replica}, now) +} + +// ApplyInvalidate applies a peer invalidate event locally (no re-broadcast). A +// negative Replica targets all replicas of the node. +func (s *Sync) ApplyInvalidate(ev messaging.PrefixCacheInvalidateEvent) { + if ev.Replica < 0 { + s.idx.InvalidateNode(ev.Model, ev.NodeID) + return + } + s.idx.Invalidate(ev.Model, ReplicaKey{NodeID: ev.NodeID, Replica: ev.Replica}) +} + +// Decide delegates to the wrapped index. +func (s *Sync) Decide(model string, chain []uint64, candidates []ReplicaKey, now time.Time) PrefixDecision { + return s.idx.Decide(model, chain, candidates, now) +} + +// Evict delegates eviction of expired entries to the wrapped index. It does not +// broadcast: each frontend evicts its own copy on its own TTL clock. 
+func (s *Sync) Evict(now time.Time) { s.idx.Evict(now) } diff --git a/core/services/nodes/prefixcache/sync_test.go b/core/services/nodes/prefixcache/sync_test.go new file mode 100644 index 000000000..001454618 --- /dev/null +++ b/core/services/nodes/prefixcache/sync_test.go @@ -0,0 +1,118 @@ +package prefixcache_test + +import ( + "time" + + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" + + "github.com/mudler/LocalAI/core/services/messaging" + "github.com/mudler/LocalAI/core/services/nodes/prefixcache" +) + +type fakePub struct{ published []any } + +func (f *fakePub) Publish(subject string, v any) error { + f.published = append(f.published, v) + return nil +} + +// Sync must satisfy the Provider seam so SmartRouter can hold a single +// prefixcache.Provider that broadcasts via NATS. +var _ prefixcache.Provider = (*prefixcache.Sync)(nil) + +var _ = Describe("Sync", func() { + It("delegates Evict to the wrapped index", func() { + cfg := prefixcache.DefaultConfig() + cfg.TTL = time.Minute + idx := prefixcache.NewIndex(cfg) + s := prefixcache.NewSync(idx, &fakePub{}) + s.Observe("m", []uint64{1, 2}, rk("A", 0), t0) + // Before TTL: still hot. + Expect(idx.Decide("m", []uint64{1, 2}, []prefixcache.ReplicaKey{rk("A", 0)}, t0).HasHot).To(BeTrue()) + // After TTL via Sync.Evict: entry is swept. 
+ s.Evict(t0.Add(2 * time.Minute)) + Expect(idx.Decide("m", []uint64{1, 2}, []prefixcache.ReplicaKey{rk("A", 0)}, t0.Add(2*time.Minute)).HasHot).To(BeFalse()) + }) + + It("publishes an observe event with the replica when Observe is new", func() { + idx := prefixcache.NewIndex(prefixcache.DefaultConfig()) + pub := &fakePub{} + s := prefixcache.NewSync(idx, pub) + s.Observe("m", []uint64{1, 2}, rk("A", 1), t0) // first time -> publish + Expect(pub.published).To(HaveLen(1)) + ev := pub.published[0].(messaging.PrefixCacheObserveEvent) + Expect(ev.NodeID).To(Equal("A")) + Expect(ev.Replica).To(Equal(1)) + s.Observe("m", []uint64{1, 2}, rk("A", 1), t0) // same -> no publish + Expect(pub.published).To(HaveLen(1)) + }) + + It("broadcasts an invalidate even for a model with no local tree, without interning one", func() { + idx := prefixcache.NewIndex(prefixcache.DefaultConfig()) + pub := &fakePub{} + s := prefixcache.NewSync(idx, pub) + // A peer frontend may hold a stale entry for this model even though THIS + // frontend never cached it, so the invalidate MUST be broadcast for + // cross-frontend coherence. The local drop must still not intern a tree. 
+ s.Invalidate("never-cached", rk("A", 0)) + Expect(pub.published).To(HaveLen(1)) + ev := pub.published[0].(messaging.PrefixCacheInvalidateEvent) + Expect(ev.NodeID).To(Equal("A")) + Expect(ev.Replica).To(Equal(0)) + Expect(idx.TreeCountForTest()).To(Equal(0)) + }) + + It("broadcasts an invalidate for a cached replica too", func() { + idx := prefixcache.NewIndex(prefixcache.DefaultConfig()) + pub := &fakePub{} + s := prefixcache.NewSync(idx, pub) + s.Observe("m", []uint64{1, 2}, rk("A", 0), t0) // creates the tree (also publishes observe) + pub.published = nil + s.Invalidate("m", rk("A", 0)) + Expect(pub.published).To(HaveLen(1)) + Expect(pub.published[0]).To(BeAssignableToTypeOf(messaging.PrefixCacheInvalidateEvent{})) + }) + + It("broadcasts a node-wide invalidate with a negative replica", func() { + idx := prefixcache.NewIndex(prefixcache.DefaultConfig()) + pub := &fakePub{} + s := prefixcache.NewSync(idx, pub) + s.InvalidateNode("m", "A") + Expect(pub.published).To(HaveLen(1)) + ev := pub.published[0].(messaging.PrefixCacheInvalidateEvent) + Expect(ev.NodeID).To(Equal("A")) + Expect(ev.Replica).To(BeNumerically("<", 0)) + }) + + It("applies a peer observe event into the local index with the replica", func() { + idx := prefixcache.NewIndex(prefixcache.DefaultConfig()) + s := prefixcache.NewSync(idx, &fakePub{}) + s.ApplyObserve(messaging.PrefixCacheObserveEvent{Model: "m", Chain: []uint64{1, 2}, NodeID: "A", Replica: 2}, t0) + d := idx.Decide("m", []uint64{1, 2}, []prefixcache.ReplicaKey{rk("A", 2)}, t0) + Expect(d.HasHot).To(BeTrue()) + Expect(d.Hot).To(Equal(rk("A", 2))) + }) + + It("applies a peer single-replica invalidate", func() { + idx := prefixcache.NewIndex(prefixcache.DefaultConfig()) + s := prefixcache.NewSync(idx, &fakePub{}) + s.Observe("m", []uint64{1, 2}, rk("A", 0), t0) + s.Observe("m", []uint64{3, 4}, rk("A", 1), t0) + s.ApplyInvalidate(messaging.PrefixCacheInvalidateEvent{Model: "m", NodeID: "A", Replica: 0}) + cands := 
[]prefixcache.ReplicaKey{rk("A", 0), rk("A", 1)} + Expect(idx.Decide("m", []uint64{1, 2}, cands, t0).HasHot).To(BeFalse()) + Expect(idx.Decide("m", []uint64{3, 4}, cands, t0).HasHot).To(BeTrue()) + }) + + It("applies a peer node-wide invalidate when replica is negative", func() { + idx := prefixcache.NewIndex(prefixcache.DefaultConfig()) + s := prefixcache.NewSync(idx, &fakePub{}) + s.Observe("m", []uint64{1, 2}, rk("A", 0), t0) + s.Observe("m", []uint64{3, 4}, rk("A", 1), t0) + s.ApplyInvalidate(messaging.PrefixCacheInvalidateEvent{Model: "m", NodeID: "A", Replica: -1}) + cands := []prefixcache.ReplicaKey{rk("A", 0), rk("A", 1)} + Expect(idx.Decide("m", []uint64{1, 2}, cands, t0).HasHot).To(BeFalse()) + Expect(idx.Decide("m", []uint64{3, 4}, cands, t0).HasHot).To(BeFalse()) + }) +}) diff --git a/core/services/nodes/reconciler.go b/core/services/nodes/reconciler.go index 9606c3f52..bf3cfd18f 100644 --- a/core/services/nodes/reconciler.go +++ b/core/services/nodes/reconciler.go @@ -8,6 +8,7 @@ import ( "time" "github.com/mudler/LocalAI/core/services/advisorylock" + "github.com/mudler/LocalAI/core/services/nodes/prefixcache" grpcclient "github.com/mudler/LocalAI/pkg/grpc" "github.com/mudler/xlog" "github.com/nats-io/nats.go" @@ -56,6 +57,13 @@ type ReplicaReconciler struct { // probeStaleAfter: only probe node_models rows older than this so we // don't hammer every worker every tick for models we just heard from. probeStaleAfter time.Duration + // pressure is the shared forced-disturb counter written by the router. When + // a model's count within the Pressure's rolling window reaches pressureThreshold the + // reconciler treats its cache-warm replica as saturated and scales up, + // subject to the same MaxReplicas/capacity/UnsatisfiableUntil machinery as + // the other scale-up paths. nil disables this signal (a true no-op). + pressure *prefixcache.Pressure + pressureThreshold int } // ModelScheduler abstracts the scheduling logic needed by the reconciler. 
@@ -83,6 +91,12 @@ type ReplicaReconcilerOptions struct { Interval time.Duration // default 30s ScaleDownDelay time.Duration // default 5m ProbeStaleAfter time.Duration // default 2m + // Pressure is the shared forced-disturb counter written by the router. nil + // disables the cache-saturation autoscale signal (a true no-op). + Pressure *prefixcache.Pressure + // PressureThreshold is the forced-disturb count within PressureWindow that + // triggers a scale-up. Default prefixcache.DefaultConfig().PressureScaleThreshold (1). + PressureThreshold int } // NewReplicaReconciler creates a new ReplicaReconciler. @@ -103,16 +117,22 @@ func NewReplicaReconciler(opts ReplicaReconcilerOptions) *ReplicaReconciler { if prober == nil { prober = grpcModelProber{token: opts.RegistrationToken} } + pressureThreshold := opts.PressureThreshold + if pressureThreshold == 0 { + pressureThreshold = prefixcache.DefaultConfig().PressureScaleThreshold + } return &ReplicaReconciler{ - registry: opts.Registry, - scheduler: opts.Scheduler, - unloader: opts.Unloader, - adapter: opts.Adapter, - prober: prober, - db: opts.DB, - interval: interval, - scaleDownDelay: scaleDownDelay, - probeStaleAfter: probeStaleAfter, + registry: opts.Registry, + scheduler: opts.Scheduler, + unloader: opts.Unloader, + adapter: opts.Adapter, + prober: prober, + db: opts.DB, + interval: interval, + scaleDownDelay: scaleDownDelay, + probeStaleAfter: probeStaleAfter, + pressure: opts.Pressure, + pressureThreshold: pressureThreshold, } } @@ -409,13 +429,25 @@ func (rc *ReplicaReconciler) reconcileModel(ctx context.Context, cfg ModelSchedu } xlog.Info("Reconciler: scaling up to meet minimum", "model", cfg.ModelName, "current", current, "min", cfg.MinReplicas, "adding", needed) - rc.scaleUp(ctx, cfg, needed) - // Successful (or partial) scale-up clears the hysteresis so a future - // dip starts fresh. 
- _ = rc.registry.ClearUnsatisfiable(ctx, cfg.ModelName) + if rc.scaleUp(ctx, cfg, needed) { + // A real (or partial) scale-up clears the hysteresis so a future + // dip starts fresh. If scaleUp added nothing (scheduler errored or + // no node could be loaded) we leave the hysteresis intact so the + // next tick retries from where it left off rather than resetting + // the unsatisfiable counter on a failed attempt. + _ = rc.registry.ClearUnsatisfiable(ctx, cfg.ModelName) + } return } + // scaledUp tracks whether a scale-up already fired in this tick. The two + // scale-up paths below (busy-burst and pressure) share the single `current` + // value read once above; scaleUp does not re-check it. So at most one of + // them may fire per tick, otherwise a model that is both busy AND over the + // pressure threshold would scale +2 and could overshoot MaxReplicas by one. + // Scale-down is also skipped in a tick that scaled up. + scaledUp := false + // 2. Auto-scale up if all replicas are busy if current > 0 && (cfg.MaxReplicas == 0 || int(current) < cfg.MaxReplicas) { if rc.allReplicasBusy(ctx, cfg.ModelName) { @@ -432,17 +464,63 @@ func (rc *ReplicaReconciler) reconcileModel(ctx context.Context, cfg ModelSchedu } xlog.Info("Reconciler: all replicas busy, scaling up", "model", cfg.ModelName, "current", current) - rc.scaleUp(ctx, cfg, 1) + // Only mark the tick as having scaled up if a replica was actually + // added. On a failed scaleUp, leave scaledUp false so the pressure + // path below and the scale-down logic still apply as they would + // have if the busy-burst path had not run. + scaledUp = rc.scaleUp(ctx, cfg, 1) } } - // 3. Scale down idle replicas above minimum - floor := cfg.MinReplicas - if floor < 1 { - floor = 1 + // 2b. Auto-scale up on prefix-cache forced-disturb pressure. 
A forced-disturb + // is recorded by the router when a request had a usable hot prefix match + // but the load guard forced it off the warm node: the cache-warm replica + // is saturated. We reuse the same MaxReplicas + capacity guards as the + // busy-burst path, and the same UnsatisfiableUntil cooldown gates this + // block at the top of reconcileModel, so a no-capacity model will not + // spin. Pressure never overrides MaxReplicas or force-evicts. + // + // Skipped when the busy-burst path already scaled up this tick: at most + // one scaleUp(+1) per tick (see scaledUp above). + if !scaledUp && rc.pressure != nil && current > 0 && (cfg.MaxReplicas == 0 || int(current) < cfg.MaxReplicas) { + if pressureCount := rc.pressure.Count(cfg.ModelName, time.Now()); pressureCount >= rc.pressureThreshold { + candidateNodeIDs, selectorMatched := rc.candidateNodeIDsForSelector(ctx, cfg) + if selectorMatched { + capacity, capErr := rc.registry.ClusterCapacityForModel(ctx, cfg.ModelName, candidateNodeIDs) + if capErr == nil && capacity > 0 { + xlog.Info("Reconciler: prefix-cache forced-disturb pressure, scaling up", + "model", cfg.ModelName, "current", current, + "pressure", pressureCount, + "threshold", rc.pressureThreshold) + if rc.scaleUp(ctx, cfg, 1) { + scaledUp = true + // Consume the signal only on a real scale-up: + // Pressure.Count is non-draining (it prunes only by + // age), so a single burst stays in-window for the whole + // window and would re-fire scaleUp on every tick. Reset + // clears the model's events so a fresh scale-up needs + // fresh forced-disturbs to accumulate. If scaleUp added + // nothing (scheduler errored or no node could be loaded) + // we preserve the signal so the next tick retries off + // the same accumulated pressure instead of having to + // re-accumulate a full window from scratch. 
+ rc.pressure.Reset(cfg.ModelName) + } + } + // No capacity: transient demand, not a misconfig - let the next + // tick retry naturally (mirrors the busy-burst path's choice not + // to enter cooldown for burst load). + } + } } - if int(current) > floor { - rc.scaleDownIdle(ctx, cfg, int(current), floor) + + // 3. Scale down idle replicas above minimum. Skipped in a tick that already + // scaled up so we never scale up and down in the same pass. + if !scaledUp { + floor := max(cfg.MinReplicas, 1) + if int(current) > floor { + rc.scaleDownIdle(ctx, cfg, int(current), floor) + } } } @@ -470,10 +548,17 @@ func (rc *ReplicaReconciler) markCapacityProblem(ctx context.Context, modelName, // scaleUp schedules additional replicas of the model. Callers in // reconcileModel are expected to have already capped `count` against // ClusterCapacityForModel so this function never tries to overshoot. -func (rc *ReplicaReconciler) scaleUp(ctx context.Context, cfg ModelSchedulingConfig, count int) { +// +// Returns true if at least one replica was actually scheduled. Callers use +// this to gate signal-consuming side effects (Pressure.Reset, +// ClearUnsatisfiable) on a real scale-up: a failed/no-op scaleUp must not +// discard the accumulated forced-disturb pressure or clear the unsatisfiable +// hysteresis, otherwise the signal has to re-accumulate from scratch and the +// next tick can't simply retry. +func (rc *ReplicaReconciler) scaleUp(ctx context.Context, cfg ModelSchedulingConfig, count int) bool { if rc.scheduler == nil { xlog.Warn("Reconciler: no scheduler available, cannot scale up") - return + return false } // Resolve selector → candidate node IDs (nil when no selector → "any @@ -481,18 +566,21 @@ func (rc *ReplicaReconciler) scaleUp(ctx context.Context, cfg ModelSchedulingCon // reconcileModel, but defensively short-circuit here too. 
candidateNodeIDs, ok := rc.candidateNodeIDsForSelector(ctx, cfg) if !ok { - return + return false } + scheduled := 0 for i := 0; i < count; i++ { node, err := rc.scheduler.ScheduleAndLoadModel(ctx, cfg.ModelName, candidateNodeIDs) if err != nil { xlog.Warn("Reconciler: failed to scale up replica", "model", cfg.ModelName, "attempt", i+1, "error", err) - return // stop trying on first failure + break // stop trying on first failure } + scheduled++ xlog.Info("Reconciler: scaled up replica", "model", cfg.ModelName, "node", node.Name) } + return scheduled > 0 } // scaleDownIdle removes idle replicas above the floor. diff --git a/core/services/nodes/reconciler_test.go b/core/services/nodes/reconciler_test.go index 0697374a8..ce8800caf 100644 --- a/core/services/nodes/reconciler_test.go +++ b/core/services/nodes/reconciler_test.go @@ -2,12 +2,14 @@ package nodes import ( "context" + "errors" "runtime" "time" . "github.com/onsi/ginkgo/v2" . "github.com/onsi/gomega" + "github.com/mudler/LocalAI/core/services/nodes/prefixcache" "github.com/mudler/LocalAI/core/services/testutil" "gorm.io/gorm" ) @@ -245,6 +247,225 @@ var _ = Describe("ReplicaReconciler", func() { }) }) + Describe("Forced-disturb pressure autoscale (Phase 6)", func() { + It("scales up when pressure exceeds threshold, replicas= threshold every tick and + // drives the model toward MaxReplicas off a single burst. 
+ node := registerNode("consume-node", "10.0.0.64:50051") + Expect(registry.SetNodeModel(context.Background(), node.ID, "consume-model", 0, "loaded", "addr1", 0)).To(Succeed()) + setSchedulingConfig("consume-model", 1, 4, "") + + pressure := prefixcache.NewPressure(time.Minute) + now := time.Now() + pressure.Record("consume-model", now) + pressure.Record("consume-model", now) + pressure.Record("consume-model", now) + + scheduler := &fakeScheduler{scheduleNode: node} + reconciler := NewReplicaReconciler(ReplicaReconcilerOptions{ + Registry: registry, + Scheduler: scheduler, + DB: db, + Pressure: pressure, + }) + + // First tick: pressure above threshold → one scale-up. + reconciler.reconcile(context.Background()) + Expect(scheduler.scheduleCalls).To(HaveLen(1), + "first tick must scale up once on the burst") + + // Second tick: the burst's events are still inside the window, but + // the first scale-up Reset them, so no further scale-up occurs. + reconciler.reconcile(context.Background()) + Expect(scheduler.scheduleCalls).To(HaveLen(1), + "a single burst must not re-trigger scale-up on the next in-window tick") + }) + + It("does not consume the pressure signal when scaleUp fails", func() { + // Pressure above threshold and capacity exists, but the scheduler + // errors so no replica is actually added. The forced-disturb signal + // must be preserved (NOT Reset) so the next tick retries the + // scale-up off the same accumulated pressure, instead of having to + // re-accumulate a full window of forced-disturbs from scratch. 
+ node := registerNode("fail-node", "10.0.0.66:50051") + Expect(registry.SetNodeModel(context.Background(), node.ID, "fail-model", 0, "loaded", "addr1", 0)).To(Succeed()) + setSchedulingConfig("fail-model", 1, 4, "") + + pressure := prefixcache.NewPressure(time.Minute) + now := time.Now() + pressure.Record("fail-model", now) + pressure.Record("fail-model", now) + pressure.Record("fail-model", now) + Expect(pressure.Count("fail-model", time.Now())).To(BeNumerically(">=", 1)) + + // Scheduler errors: scaleUp attempts but adds nothing. + scheduler := &fakeScheduler{scheduleErr: errors.New("schedule boom")} + reconciler := NewReplicaReconciler(ReplicaReconcilerOptions{ + Registry: registry, + Scheduler: scheduler, + DB: db, + Pressure: pressure, + }) + + reconciler.reconcile(context.Background()) + + Expect(scheduler.scheduleCalls).To(HaveLen(1), + "scaleUp must have attempted exactly one schedule call") + Expect(pressure.Count("fail-model", time.Now())).To(BeNumerically(">=", 1), + "a failed scaleUp must NOT consume (Reset) the pressure signal — next tick should retry") + }) + + It("consumes the pressure signal only when scaleUp succeeds", func() { + // Mirror of the failure case: when the scheduler succeeds and a + // replica is actually added, the forced-disturb signal IS consumed + // (Reset to 0) so a single burst scales up only once. 
+ node := registerNode("ok-node", "10.0.0.67:50051") + Expect(registry.SetNodeModel(context.Background(), node.ID, "ok-model", 0, "loaded", "addr1", 0)).To(Succeed()) + setSchedulingConfig("ok-model", 1, 4, "") + + pressure := prefixcache.NewPressure(time.Minute) + now := time.Now() + pressure.Record("ok-model", now) + pressure.Record("ok-model", now) + pressure.Record("ok-model", now) + + scheduler := &fakeScheduler{scheduleNode: node} + reconciler := NewReplicaReconciler(ReplicaReconcilerOptions{ + Registry: registry, + Scheduler: scheduler, + DB: db, + Pressure: pressure, + }) + + reconciler.reconcile(context.Background()) + + Expect(scheduler.scheduleCalls).To(HaveLen(1), + "successful scaleUp must have scheduled one replica") + Expect(pressure.Count("ok-model", time.Now())).To(Equal(0), + "a successful scaleUp must consume (Reset) the pressure signal to 0") + }) + + It("performs at most one scale-up per tick when both busy and over pressure", func() { + // The single loaded replica is busy (all-replicas-busy fires) AND + // pressure is above threshold. Both scale-up paths are eligible in + // the same tick. The invariant is at-most-one scaleUp(+1) per tick, + // so exactly one schedule call must happen, not two. 
+ node := registerNode("dual-node", "10.0.0.65:50051") + Expect(registry.SetNodeModel(context.Background(), node.ID, "dual-model", 0, "loaded", "addr1", 0)).To(Succeed()) + Expect(registry.IncrementInFlight(context.Background(), node.ID, "dual-model", 0)).To(Succeed()) + setSchedulingConfig("dual-model", 1, 4, "") + + pressure := prefixcache.NewPressure(time.Minute) + pressure.Record("dual-model", time.Now()) + pressure.Record("dual-model", time.Now()) + + scheduler := &fakeScheduler{scheduleNode: node} + reconciler := NewReplicaReconciler(ReplicaReconcilerOptions{ + Registry: registry, + Scheduler: scheduler, + DB: db, + Pressure: pressure, + }) + + reconciler.reconcile(context.Background()) + + Expect(scheduler.scheduleCalls).To(HaveLen(1), + "busy + pressure in one tick must still scale up by exactly one, not two") + }) + + It("does not spin when pressure is high but no capacity exists", func() { + // Single node, cap 1, already loaded → capacity 0. Pressure is high + // but there is nowhere to place a replica: must not call scheduler. + registerCappedNodeFn := func(name, address string, cap int) *BackendNode { + node := &BackendNode{ + Name: name, + NodeType: NodeTypeBackend, + Address: address, + MaxReplicasPerModel: cap, + } + Expect(registry.Register(context.Background(), node, true)).To(Succeed()) + return node + } + node := registerCappedNodeFn("pcap-node", "10.0.0.63:50051", 1) + Expect(registry.SetNodeModel(context.Background(), node.ID, "pcap-model", 0, "loaded", "addr1", 0)).To(Succeed()) + // MaxReplicas high enough that replicas unchanged behavior. +type RoutePreference struct { + PreferredNodeID string + PreferredReplica int +} + // FindAndLockNodeWithModel atomically finds the best loaded replica of the // given model and increments its in-flight counter within a single // transaction. 
The SELECT FOR UPDATE row lock prevents concurrent eviction @@ -758,7 +893,7 @@ func (r *NodeRegistry) FindNodesWithModel(ctx context.Context, modelName string) // NodeSelector so a cached replica on a now-excluded node isn't picked over a // matching replica elsewhere — the selector-mismatch fall-through path used to // trigger an eviction-busy loop when both sides had the model loaded. -func (r *NodeRegistry) FindAndLockNodeWithModel(ctx context.Context, modelName string, candidateNodeIDs []string) (*BackendNode, *NodeModel, error) { +func (r *NodeRegistry) FindAndLockNodeWithModel(ctx context.Context, modelName string, candidateNodeIDs []string, pref *RoutePreference) (*BackendNode, *NodeModel, error) { var nm NodeModel var node BackendNode @@ -781,17 +916,33 @@ func (r *NodeRegistry) FindAndLockNodeWithModel(ctx context.Context, modelName s // stale row in the same window, and other helpers that mirror this // JOIN need the same invariant. Belt-and-braces: status filter here // AND the status-checked node fetch below. - q := tx.Clauses(clause.Locking{Strength: "UPDATE"}). + base := tx.Clauses(clause.Locking{Strength: "UPDATE"}). Joins("JOIN backend_nodes ON backend_nodes.id = node_models.node_id"). Where("node_models.model_name = ? AND node_models.state = ? AND backend_nodes.status = ?", modelName, "loaded", StatusHealthy) if len(candidateNodeIDs) > 0 { - q = q.Where("node_models.node_id IN ?", candidateNodeIDs) + base = base.Where("node_models.node_id IN ?", candidateNodeIDs) } - if err := q. - Order("node_models.in_flight ASC, node_models.last_used ASC, backend_nodes.available_vram DESC"). - First(&nm).Error; err != nil { - return err + + picked := false + if pref != nil && pref.PreferredNodeID != "" { + // Lock the EXACT (node_id, replica_index) row the caller chose. The + // caller (prefix-cache router) has already applied the load guard + // per replica, so here we only require that exact replica still be + // loaded+healthy. 
Fall through to the default ORDER BY when that + // specific replica is not found/loaded. + q := base.Session(&gorm.Session{}). + Where("node_models.node_id = ? AND node_models.replica_index = ?", pref.PreferredNodeID, pref.PreferredReplica) + if err := q.First(&nm).Error; err == nil { + picked = true + } + } + if !picked { + if err := base. + Order("node_models.in_flight ASC, node_models.last_used ASC, backend_nodes.available_vram DESC"). + First(&nm).Error; err != nil { + return err + } } if err := tx.Model(&nm).Updates(map[string]any{ @@ -815,6 +966,47 @@ func (r *NodeRegistry) FindAndLockNodeWithModel(ctx context.Context, modelName s return &node, &nm, nil } +// LoadedReplicaStats returns one ReplicaCandidate per loaded+healthy replica of +// modelName, carrying its current in-flight count. It is a read used by the +// prefix-cache router to apply the load guard when choosing a preferred node. +// When candidateNodeIDs is non-empty, only replicas on those nodes are +// returned; pass nil to consider any healthy node. The result is never nil; +// an empty slice means no loaded replica exists. +func (r *NodeRegistry) LoadedReplicaStats(ctx context.Context, modelName string, candidateNodeIDs []string) ([]ReplicaCandidate, error) { + type row struct { + NodeID string + Address string + ReplicaIndex int + InFlight int + LastUsed time.Time + AvailableVRAM uint64 + } + q := r.db.WithContext(ctx).Model(&NodeModel{}). + Joins("JOIN backend_nodes ON backend_nodes.id = node_models.node_id"). + Where("node_models.model_name = ? AND node_models.state = ? AND backend_nodes.status = ?", + modelName, "loaded", StatusHealthy) + if len(candidateNodeIDs) > 0 { + q = q.Where("node_models.node_id IN ?", candidateNodeIDs) + } + + // Narrow to only the columns the sole consumer (router buildPreference) + // reads: NodeID and InFlight. The other ReplicaCandidate fields stay at + // their zero value, which the consumer does not read. 
This avoids the + // JOIN-side available_vram fetch and the extra column transfer. + var rows []row + err := q.Select("node_models.node_id AS node_id, node_models.in_flight AS in_flight"). + Scan(&rows).Error + if err != nil { + return nil, fmt.Errorf("loading replica stats for %s: %w", modelName, err) + } + + out := make([]ReplicaCandidate, 0, len(rows)) + for _, rw := range rows { + out = append(out, ReplicaCandidate(rw)) + } + return out, nil +} + // TouchNodeModel updates the last_used timestamp for LRU tracking on a single // replica row. func (r *NodeRegistry) TouchNodeModel(ctx context.Context, nodeID, modelName string, replicaIndex int) { @@ -1198,8 +1390,12 @@ func (r *NodeRegistry) SetModelScheduling(ctx context.Context, config *ModelSche } return r.db.WithContext(ctx). Clauses(clause.OnConflict{ - Columns: []clause.Column{{Name: "model_name"}}, - DoUpdates: clause.AssignmentColumns([]string{"node_selector", "min_replicas", "max_replicas", "updated_at"}), + Columns: []clause.Column{{Name: "model_name"}}, + DoUpdates: clause.AssignmentColumns([]string{ + "node_selector", "min_replicas", "max_replicas", + "route_policy", "balance_abs_threshold", "balance_rel_threshold", "min_prefix_match", + "updated_at", + }), }). 
Create(config).Error } diff --git a/core/services/nodes/registry_test.go b/core/services/nodes/registry_test.go index a07398909..85ae598c9 100644 --- a/core/services/nodes/registry_test.go +++ b/core/services/nodes/registry_test.go @@ -245,7 +245,7 @@ var _ = Describe("NodeRegistry", func() { Expect(registry.Register(context.Background(), node, true)).To(Succeed()) Expect(registry.SetNodeModel(context.Background(), node.ID, "my-model", 0, "loaded", "10.0.0.40:50052", 0)).To(Succeed()) - foundNode, foundNM, err := registry.FindAndLockNodeWithModel(context.Background(), "my-model", nil) + foundNode, foundNM, err := registry.FindAndLockNodeWithModel(context.Background(), "my-model", nil, nil) Expect(err).ToNot(HaveOccurred()) Expect(foundNode.ID).To(Equal(node.ID)) Expect(foundNM.ModelName).To(Equal("my-model")) @@ -257,7 +257,7 @@ var _ = Describe("NodeRegistry", func() { }) It("returns error when model is not loaded anywhere", func() { - _, _, err := registry.FindAndLockNodeWithModel(context.Background(), "nonexistent-model", nil) + _, _, err := registry.FindAndLockNodeWithModel(context.Background(), "nonexistent-model", nil, nil) Expect(err).To(HaveOccurred()) }) @@ -274,7 +274,7 @@ var _ = Describe("NodeRegistry", func() { Expect(registry.IncrementInFlight(context.Background(), n1.ID, "shared-model", 0)).To(Succeed()) Expect(registry.IncrementInFlight(context.Background(), n1.ID, "shared-model", 0)).To(Succeed()) - foundNode, _, err := registry.FindAndLockNodeWithModel(context.Background(), "shared-model", nil) + foundNode, _, err := registry.FindAndLockNodeWithModel(context.Background(), "shared-model", nil, nil) Expect(err).ToNot(HaveOccurred()) Expect(foundNode.Name).To(Equal("lock-light")) }) @@ -299,7 +299,7 @@ var _ = Describe("NodeRegistry", func() { Expect(registry.IncrementInFlight(context.Background(), included.ID, "filtered-model", 0)).To(Succeed()) Expect(registry.IncrementInFlight(context.Background(), included.ID, "filtered-model", 0)).To(Succeed()) 
- foundNode, foundNM, err := registry.FindAndLockNodeWithModel(context.Background(), "filtered-model", []string{included.ID}) + foundNode, foundNM, err := registry.FindAndLockNodeWithModel(context.Background(), "filtered-model", []string{included.ID}, nil) Expect(err).ToNot(HaveOccurred()) Expect(foundNode.ID).To(Equal(included.ID)) Expect(foundNM.NodeID).To(Equal(included.ID)) @@ -326,7 +326,7 @@ var _ = Describe("NodeRegistry", func() { // (FindAndLockNodeWithModel atomically increments to lock the row.) picks := make([]string, 0, 9) for i := 0; i < 9; i++ { - n, nm, err := registry.FindAndLockNodeWithModel(context.Background(), "rr-model", nil) + n, nm, err := registry.FindAndLockNodeWithModel(context.Background(), "rr-model", nil, nil) Expect(err).ToNot(HaveOccurred()) picks = append(picks, n.Name) Expect(registry.DecrementInFlight(context.Background(), n.ID, "rr-model", nm.ReplicaIndex)).To(Succeed()) @@ -355,7 +355,7 @@ var _ = Describe("NodeRegistry", func() { // query must return an error so Route() falls through to schedule // a fresh load on a matching node instead of reusing the excluded // replica. 
- _, _, err := registry.FindAndLockNodeWithModel(context.Background(), "no-match-model", []string{emptyIncluded.ID}) + _, _, err := registry.FindAndLockNodeWithModel(context.Background(), "no-match-model", []string{emptyIncluded.ID}, nil) Expect(err).To(HaveOccurred()) }) @@ -422,7 +422,7 @@ var _ = Describe("NodeRegistry", func() { goPick := PickBestReplica(candidates) Expect(goPick).ToNot(BeNil()) - sqlNode, _, err := registry.FindAndLockNodeWithModel(context.Background(), "mirror-model", nil) + sqlNode, _, err := registry.FindAndLockNodeWithModel(context.Background(), "mirror-model", nil, nil) Expect(err).ToNot(HaveOccurred()) Expect(sqlNode.ID).To(Equal(goPick.NodeID), @@ -433,6 +433,124 @@ var _ = Describe("NodeRegistry", func() { }) }) + Describe("FindAndLockNodeWithModel preference", func() { + var nodeA, nodeB *BackendNode + + BeforeEach(func() { + nodeA = makeNode("pref-a", "10.0.0.70:50051", 8_000_000_000) + nodeB = makeNode("pref-b", "10.0.0.71:50051", 8_000_000_000) + Expect(registry.Register(context.Background(), nodeA, true)).To(Succeed()) + Expect(registry.Register(context.Background(), nodeB, true)).To(Succeed()) + // Both loaded+healthy for model "pref-model", in_flight 0. + Expect(registry.SetNodeModel(context.Background(), nodeA.ID, "pref-model", 0, "loaded", "", 0)).To(Succeed()) + Expect(registry.SetNodeModel(context.Background(), nodeB.ID, "pref-model", 0, "loaded", "", 0)).To(Succeed()) + }) + + It("locks the preferred node when eligible", func() { + node, nm, err := registry.FindAndLockNodeWithModel(context.Background(), "pref-model", nil, &RoutePreference{PreferredNodeID: nodeB.ID}) + Expect(err).ToNot(HaveOccurred()) + Expect(node.ID).To(Equal(nodeB.ID)) + Expect(nm.NodeID).To(Equal(nodeB.ID)) + + // in_flight is incremented atomically via gorm.Expr, so verify the + // persisted value through a re-fetch (the returned struct mirrors + // the pre-increment read, like the default-pick path). 
+ persisted, err := registry.GetNodeModel(context.Background(), nodeB.ID, "pref-model", 0) + Expect(err).ToNot(HaveOccurred()) + Expect(persisted.InFlight).To(Equal(1)) + }) + + It("falls back to default order when preferred not loaded", func() { + node, _, err := registry.FindAndLockNodeWithModel(context.Background(), "pref-model", nil, &RoutePreference{PreferredNodeID: "ZZZ"}) + Expect(err).ToNot(HaveOccurred()) + Expect(node.ID).To(BeElementOf(nodeA.ID, nodeB.ID)) + }) + + It("nil preference behaves like before", func() { + node, _, err := registry.FindAndLockNodeWithModel(context.Background(), "pref-model", nil, nil) + Expect(err).ToNot(HaveOccurred()) + Expect(node).ToNot(BeNil()) + }) + + It("locks the EXACT preferred replica when the node hosts two replicas", func() { + // A single node hosts replica 0 and replica 1 of a model, both + // loaded+healthy. The preference must lock the SPECIFIC replica + // requested, not the least-loaded replica on the node. + node := makeNode("pref-multi", "10.0.0.72:50051", 16_000_000_000) + node.MaxReplicasPerModel = 2 + Expect(registry.Register(context.Background(), node, true)).To(Succeed()) + Expect(registry.SetNodeModel(context.Background(), node.ID, "multi-model", 0, "loaded", "addr0", 0)).To(Succeed()) + Expect(registry.SetNodeModel(context.Background(), node.ID, "multi-model", 1, "loaded", "addr1", 0)).To(Succeed()) + + // pref={node, 1} must lock replica 1 specifically. + gotNode, nm1, err := registry.FindAndLockNodeWithModel(context.Background(), "multi-model", nil, + &RoutePreference{PreferredNodeID: node.ID, PreferredReplica: 1}) + Expect(err).ToNot(HaveOccurred()) + Expect(gotNode.ID).To(Equal(node.ID)) + Expect(nm1.ReplicaIndex).To(Equal(1)) + + // pref={node, 0} must lock replica 0 specifically. 
+ _, nm0, err := registry.FindAndLockNodeWithModel(context.Background(), "multi-model", nil, + &RoutePreference{PreferredNodeID: node.ID, PreferredReplica: 0}) + Expect(err).ToNot(HaveOccurred()) + Expect(nm0.ReplicaIndex).To(Equal(0)) + }) + }) + + Describe("LoadedReplicaStats", func() { + var n1, n2, n3 *BackendNode + + BeforeEach(func() { + n1 = makeNode("stats-1", "10.0.0.80:50051", 8_000_000_000) + n2 = makeNode("stats-2", "10.0.0.81:50051", 8_000_000_000) + n3 = makeNode("stats-3", "10.0.0.82:50051", 8_000_000_000) + Expect(registry.Register(context.Background(), n1, true)).To(Succeed()) + Expect(registry.Register(context.Background(), n2, true)).To(Succeed()) + Expect(registry.Register(context.Background(), n3, true)).To(Succeed()) + // n1 loaded+busy, n2 loaded+idle, n3 has a different model only. + Expect(registry.SetNodeModel(context.Background(), n1.ID, "stats-model", 0, "loaded", "10.0.0.80:6000", 0)).To(Succeed()) + Expect(registry.SetNodeModel(context.Background(), n2.ID, "stats-model", 0, "loaded", "10.0.0.81:6000", 0)).To(Succeed()) + Expect(registry.SetNodeModel(context.Background(), n3.ID, "other-model", 0, "loaded", "", 0)).To(Succeed()) + Expect(registry.IncrementInFlight(context.Background(), n1.ID, "stats-model", 0)).To(Succeed()) + Expect(registry.IncrementInFlight(context.Background(), n1.ID, "stats-model", 0)).To(Succeed()) + }) + + It("returns loaded healthy replicas with in-flight counts", func() { + stats, err := registry.LoadedReplicaStats(context.Background(), "stats-model", nil) + Expect(err).ToNot(HaveOccurred()) + Expect(stats).To(HaveLen(2)) + byNode := map[string]ReplicaCandidate{} + for _, s := range stats { + byNode[s.NodeID] = s + } + Expect(byNode).To(HaveKey(n1.ID)) + Expect(byNode).To(HaveKey(n2.ID)) + Expect(byNode[n1.ID].InFlight).To(Equal(2)) + Expect(byNode[n2.ID].InFlight).To(Equal(0)) + }) + + It("filters to the candidate node set when provided", func() { + stats, err := 
registry.LoadedReplicaStats(context.Background(), "stats-model", []string{n2.ID}) + Expect(err).ToNot(HaveOccurred()) + Expect(stats).To(HaveLen(1)) + Expect(stats[0].NodeID).To(Equal(n2.ID)) + }) + + It("excludes unhealthy nodes", func() { + Expect(registry.MarkUnhealthy(context.Background(), n1.ID)).To(Succeed()) + stats, err := registry.LoadedReplicaStats(context.Background(), "stats-model", nil) + Expect(err).ToNot(HaveOccurred()) + Expect(stats).To(HaveLen(1)) + Expect(stats[0].NodeID).To(Equal(n2.ID)) + }) + + It("returns empty for a model with no loaded replicas", func() { + stats, err := registry.LoadedReplicaStats(context.Background(), "no-such-model", nil) + Expect(err).ToNot(HaveOccurred()) + Expect(stats).To(BeEmpty()) + }) + }) + Describe("MarkHealthy and MarkUnhealthy round-trip", func() { It("transitions healthy -> unhealthy -> healthy", func() { node := makeNode("roundtrip-node", "10.0.0.60:50051", 8_000_000_000) @@ -632,6 +750,30 @@ var _ = Describe("NodeRegistry", func() { Expect(fetched.MaxReplicas).To(Equal(5)) }) + It("persists and updates route policy and thresholds", func() { + err := registry.SetModelScheduling(context.Background(), &ModelSchedulingConfig{ + ModelName: "prefix-cache-model", RoutePolicy: "prefix_cache", + BalanceAbsThreshold: 3, BalanceRelThreshold: 2.0, MinPrefixMatch: 0.4, + }) + Expect(err).ToNot(HaveOccurred()) + + got, err := registry.GetModelScheduling(context.Background(), "prefix-cache-model") + Expect(err).ToNot(HaveOccurred()) + Expect(got.RoutePolicy).To(Equal("prefix_cache")) + Expect(got.BalanceAbsThreshold).To(Equal(3)) + Expect(got.BalanceRelThreshold).To(BeNumerically("==", 2.0)) + Expect(got.MinPrefixMatch).To(BeNumerically("==", 0.4)) + + // Update must not be dropped on conflict. 
+ Expect(registry.SetModelScheduling(context.Background(), &ModelSchedulingConfig{ + ModelName: "prefix-cache-model", RoutePolicy: "round_robin", + })).ToNot(HaveOccurred()) + + got, err = registry.GetModelScheduling(context.Background(), "prefix-cache-model") + Expect(err).ToNot(HaveOccurred()) + Expect(got.RoutePolicy).To(Equal("round_robin")) + }) + It("lists all configs", func() { Expect(registry.SetModelScheduling(context.Background(), &ModelSchedulingConfig{ModelName: "list-a", MinReplicas: 1})).To(Succeed()) Expect(registry.SetModelScheduling(context.Background(), &ModelSchedulingConfig{ModelName: "list-b", MaxReplicas: 2})).To(Succeed()) @@ -903,6 +1045,187 @@ var _ = Describe("NodeRegistry", func() { }) }) + Describe("SetReplicaRemovedHook", func() { + type removed struct { + model, node string + replica int + } + + It("fires once with the specific replica after RemoveNodeModel", func() { + node := makeNode("hook-remove-one", "10.0.0.230:50051", 8_000_000_000) + Expect(registry.Register(context.Background(), node, true)).To(Succeed()) + Expect(registry.SetNodeModel(context.Background(), node.ID, "hook-model", 1, "loaded", "a", 0)).To(Succeed()) + + var fired []removed + registry.SetReplicaRemovedHook(func(modelName, nodeID string, replicaIndex int) { + fired = append(fired, removed{model: modelName, node: nodeID, replica: replicaIndex}) + }) + + // RemoveNodeModel(replica 1) must fire with the SPECIFIC replica index. 
+ Expect(registry.RemoveNodeModel(context.Background(), node.ID, "hook-model", 1)).To(Succeed()) + Expect(fired).To(HaveLen(1)) + Expect(fired[0]).To(Equal(removed{model: "hook-model", node: node.ID, replica: 1})) + }) + + It("fires once with replica<0 after RemoveAllNodeModelReplicas", func() { + node := makeNode("hook-remove-all", "10.0.0.231:50051", 16_000_000_000) + Expect(registry.Register(context.Background(), node, true)).To(Succeed()) + Expect(registry.SetNodeModel(context.Background(), node.ID, "hook-all-model", 0, "loaded", "a", 0)).To(Succeed()) + Expect(registry.SetNodeModel(context.Background(), node.ID, "hook-all-model", 1, "loaded", "b", 0)).To(Succeed()) + + var fired []removed + registry.SetReplicaRemovedHook(func(modelName, nodeID string, replicaIndex int) { + fired = append(fired, removed{model: modelName, node: nodeID, replica: replicaIndex}) + }) + + // One call covers all replicas of that model on the node: a negative + // replica index signals "all replicas", and the consumer's + // InvalidateNode drops every entry for the (model, node) pair. 
+ Expect(registry.RemoveAllNodeModelReplicas(context.Background(), node.ID, "hook-all-model")).To(Succeed()) + Expect(fired).To(HaveLen(1)) + Expect(fired[0].model).To(Equal("hook-all-model")) + Expect(fired[0].node).To(Equal(node.ID)) + Expect(fired[0].replica).To(BeNumerically("<", 0)) + }) + + It("does not panic when no hook is set", func() { + node := makeNode("hook-unset", "10.0.0.232:50051", 8_000_000_000) + Expect(registry.Register(context.Background(), node, true)).To(Succeed()) + Expect(registry.SetNodeModel(context.Background(), node.ID, "no-hook-model", 0, "loaded", "a", 0)).To(Succeed()) + + Expect(func() { + Expect(registry.RemoveNodeModel(context.Background(), node.ID, "no-hook-model", 0)).To(Succeed()) + Expect(registry.RemoveAllNodeModelReplicas(context.Background(), node.ID, "no-hook-model")).To(Succeed()) + }).ToNot(Panic()) + }) + + // seedTwoModels registers the node and loads two distinct models on it + // (model-a with two replicas, model-b with one). The bulk node-scoped + // deletes below remove every replica of every model on the node in one + // statement, so the chokepoint must fire the hook once per distinct model + // name (the consumer's InvalidateNode drops all entries for that + // (model, node) pair).
+ seedTwoModels := func(node *BackendNode) { + Expect(registry.Register(context.Background(), node, true)).To(Succeed()) + Expect(registry.SetNodeModel(context.Background(), node.ID, "model-a", 0, "loaded", "a0", 0)).To(Succeed()) + Expect(registry.SetNodeModel(context.Background(), node.ID, "model-a", 1, "loaded", "a1", 0)).To(Succeed()) + Expect(registry.SetNodeModel(context.Background(), node.ID, "model-b", 0, "loaded", "b0", 0)).To(Succeed()) + } + + It("fires once per distinct model after MarkOffline", func() { + node := makeNode("hook-offline", "10.0.0.240:50051", 8_000_000_000) + seedTwoModels(node) + + fired := map[removed]int{} + registry.SetReplicaRemovedHook(func(modelName, nodeID string, replicaIndex int) { + // Bulk node-scoped deletes signal "all replicas" with replica<0. + Expect(replicaIndex).To(BeNumerically("<", 0)) + fired[removed{model: modelName, node: nodeID}]++ + }) + + Expect(registry.MarkOffline(context.Background(), node.ID)).To(Succeed()) + Expect(fired).To(HaveLen(2)) + Expect(fired[removed{model: "model-a", node: node.ID}]).To(Equal(1)) + Expect(fired[removed{model: "model-b", node: node.ID}]).To(Equal(1)) + }) + + It("fires once per distinct model after MarkDraining", func() { + node := makeNode("hook-draining", "10.0.0.241:50051", 8_000_000_000) + seedTwoModels(node) + + fired := map[removed]int{} + registry.SetReplicaRemovedHook(func(modelName, nodeID string, replicaIndex int) { + // Bulk node-scoped deletes signal "all replicas" with replica<0. 
+ Expect(replicaIndex).To(BeNumerically("<", 0)) + fired[removed{model: modelName, node: nodeID}]++ + }) + + Expect(registry.MarkDraining(context.Background(), node.ID)).To(Succeed()) + Expect(fired).To(HaveLen(2)) + Expect(fired[removed{model: "model-a", node: node.ID}]).To(Equal(1)) + Expect(fired[removed{model: "model-b", node: node.ID}]).To(Equal(1)) + }) + + It("fires once per distinct model after Deregister", func() { + node := makeNode("hook-deregister", "10.0.0.242:50051", 8_000_000_000) + seedTwoModels(node) + + fired := map[removed]int{} + registry.SetReplicaRemovedHook(func(modelName, nodeID string, replicaIndex int) { + // Bulk node-scoped deletes signal "all replicas" with replica<0. + Expect(replicaIndex).To(BeNumerically("<", 0)) + fired[removed{model: modelName, node: nodeID}]++ + }) + + Expect(registry.Deregister(context.Background(), node.ID)).To(Succeed()) + Expect(fired).To(HaveLen(2)) + Expect(fired[removed{model: "model-a", node: node.ID}]).To(Equal(1)) + Expect(fired[removed{model: "model-b", node: node.ID}]).To(Equal(1)) + }) + + It("fires once per distinct model when re-registration clears stale rows", func() { + node := makeNode("hook-reregister", "10.0.0.243:50051", 8_000_000_000) + seedTwoModels(node) + + fired := map[removed]int{} + registry.SetReplicaRemovedHook(func(modelName, nodeID string, replicaIndex int) { + // Bulk node-scoped deletes signal "all replicas" with replica<0. + Expect(replicaIndex).To(BeNumerically("<", 0)) + fired[removed{model: modelName, node: nodeID}]++ + }) + + // Re-register the same node (same name): the re-register path + // clears the stale model rows, which must fire the hook. 
+ again := makeNode("hook-reregister", "10.0.0.243:50052", 8_000_000_000) + Expect(registry.Register(context.Background(), again, true)).To(Succeed()) + Expect(fired).To(HaveLen(2)) + Expect(fired[removed{model: "model-a", node: node.ID}]).To(Equal(1)) + Expect(fired[removed{model: "model-b", node: node.ID}]).To(Equal(1)) + }) + + // Atomicity: the bulk node-scoped delete in MarkOffline/MarkDraining/ + // re-register now captures the model names and deletes the rows inside a + // single transaction. A true SetNodeModel-between-capture-and-delete race + // can't be forced deterministically here, but we can assert the + // post-condition the transaction guarantees: the set of fired hooks + // equals exactly the set of node_models rows the operation removed, with + // nothing left behind. If the capture and delete ever saw inconsistent + // snapshots, either a surviving row (delete missed it) or a missing hook + // (capture missed it) would break one of these assertions. + It("MarkOffline fires hooks for exactly the rows it deletes (consistent snapshot)", func() { + node := makeNode("hook-atomic-offline", "10.0.0.244:50051", 8_000_000_000) + seedTwoModels(node) + + // Capture what the transaction should remove, straight from the DB, + // before running the operation. + before, err := registry.GetNodeModels(context.Background(), node.ID) + Expect(err).ToNot(HaveOccurred()) + expectedModels := map[string]struct{}{} + for _, nm := range before { + expectedModels[nm.ModelName] = struct{}{} + } + Expect(expectedModels).To(HaveLen(2), "seed should create two distinct models") + + fired := map[string]struct{}{} + registry.SetReplicaRemovedHook(func(modelName, nodeID string, replicaIndex int) { + Expect(nodeID).To(Equal(node.ID)) + Expect(replicaIndex).To(BeNumerically("<", 0)) + fired[modelName] = struct{}{} + }) + + Expect(registry.MarkOffline(context.Background(), node.ID)).To(Succeed()) + + // Hooks fired for exactly the distinct models that existed. 
+ Expect(fired).To(Equal(expectedModels), + "hooks must fire for exactly the set of models the transaction deleted") + + // And the delete actually emptied the node_models rows for the node: + // no row survives that did not get a hook. + after, err := registry.GetNodeModels(context.Background(), node.ID) + Expect(err).ToNot(HaveOccurred()) + Expect(after).To(BeEmpty(), "no node_models row should survive the bulk delete") + }) + }) + Describe("ApplyAutoLabels", func() { It("mirrors MaxReplicasPerModel as the node.replica-slots label", func() { node := makeNode("auto-label-replicas", "10.0.0.220:50051", 16_000_000_000) diff --git a/core/services/nodes/router.go b/core/services/nodes/router.go index a9b01e80c..ec783b283 100644 --- a/core/services/nodes/router.go +++ b/core/services/nodes/router.go @@ -12,6 +12,8 @@ import ( "time" "github.com/mudler/LocalAI/core/services/advisorylock" + "github.com/mudler/LocalAI/core/services/nodes/prefixcache" + "github.com/mudler/LocalAI/pkg/distributedhdr" grpc "github.com/mudler/LocalAI/pkg/grpc" pb "github.com/mudler/LocalAI/pkg/grpc/proto" "github.com/mudler/LocalAI/pkg/vram" @@ -43,6 +45,22 @@ type SmartRouterOptions struct { // anti-affinity is disabled at the scheduler layer; the per-node // watchdog still enforces the rule on arrival. ConflictResolver ConcurrencyConflictResolver + // PrefixProvider, when set, enables prefix-cache-aware routing: requests + // carrying a prompt prefix chain (distributedhdr.PrefixChain) are biased + // toward the node that already holds the longest matching prefix, subject + // to the load guard in prefixcache.Select. nil disables it entirely and + // routing is byte-for-byte the round-robin floor. At runtime this is the + // *prefixcache.Sync so Observe/Invalidate broadcast to peers. + PrefixProvider prefixcache.Provider + // PrefixConfig holds the global policy + thresholds. Per-model overrides on + // ModelSchedulingConfig refine it per request. Unused when PrefixProvider + // is nil. 
+ PrefixConfig prefixcache.Config + // Pressure, when set, records a forced-disturb each time a request had a + // usable hot prefix match but the load guard forced it off the warm node. + // The reconciler reads the same instance to autoscale a saturated cache-warm + // replica. nil disables recording (the disabled path stays a no-op). + Pressure *prefixcache.Pressure } // SmartRouter routes inference requests to the best available backend node. @@ -56,6 +74,14 @@ type SmartRouter struct { db *gorm.DB // for advisory locks during routing stagingTracker *StagingTracker // tracks file staging progress for UI visibility conflictResolver ConcurrencyConflictResolver + // prefixProvider is the prefix-cache routing seam (nil disables it; see + // SmartRouterOptions.PrefixProvider). prefixConfig holds the global policy + // and thresholds. + prefixProvider prefixcache.Provider + prefixConfig prefixcache.Config + // pressure records forced-disturb events (hot match forced off the warm + // node by the load guard). nil disables recording. See SmartRouterOptions. + pressure *prefixcache.Pressure // installFlight coalesces concurrent identical NATS install requests // (same nodeID + backend + modelID + replica) so 6 simultaneous chat // completions for one not-yet-loaded model produce ONE round-trip, not @@ -91,6 +117,9 @@ func NewSmartRouter(registry ModelRouter, opts SmartRouterOptions) *SmartRouter stagingTracker: NewStagingTracker(), conflictResolver: opts.ConflictResolver, probeCache: newProbeCache(probeCacheTTL), + prefixProvider: opts.PrefixProvider, + prefixConfig: opts.PrefixConfig, + pressure: opts.Pressure, } } @@ -230,18 +259,31 @@ func (r *SmartRouter) Route(ctx context.Context, modelID, modelName, backendType trackingKey = modelName } + // Fetch the model's scheduling config once: it is immutable for the life of + // this request, and resolveSelectorCandidates, buildPreference, and + // nodeMatchesScheduling all read it. 
Fetching once gives a consistent + // snapshot and avoids three DB round-trips for one row. nil sched means + // "no scheduling constraints", same as before. + sched, _ := r.registry.GetModelScheduling(ctx, trackingKey) + // Resolve the model's NodeSelector once so cached-replica lookup and the // new-load scheduler agree on the candidate set. Without this, a cached // replica on a node the selector now excludes was picked over a matching // replica elsewhere, and the fall-through then tried to load on the // matching node where the model was already at capacity (eviction-busy). - candidateNodeIDs, err := r.resolveSelectorCandidates(ctx, trackingKey) + candidateNodeIDs, err := r.resolveSelectorCandidates(ctx, trackingKey, sched) if err != nil { return nil, err } + // Compute the prefix-cache preference once for this request. pref biases + // FindAndLockNodeWithModel toward the warm-cache node; observeChain is + // non-nil only when this model uses prefix_cache, gating the Observe calls + // below. Both are nil (no-op) when prefix-cache routing is disabled. 
+ pref, observeChain := r.buildPreference(ctx, trackingKey, candidateNodeIDs, sched) + // Step 1: Find and atomically lock a node with this model loaded - node, nm, err := r.registry.FindAndLockNodeWithModel(ctx, trackingKey, candidateNodeIDs) + node, nm, err := r.registry.FindAndLockNodeWithModel(ctx, trackingKey, candidateNodeIDs, pref) if err == nil && node != nil { modelAddr := node.Address if nm.Address != "" { @@ -258,7 +300,7 @@ func (r *SmartRouter) Route(ctx context.Context, modelID, modelName, backendType "node", node.Name, "model", modelName, "replica", replicaIdx) } else { // Verify node still matches scheduling constraints - if !r.nodeMatchesScheduling(ctx, node, trackingKey) { + if !r.nodeMatchesScheduling(ctx, node, sched) { r.registry.DecrementInFlight(ctx, node.ID, trackingKey, replicaIdx) xlog.Info("Cached model on node that no longer matches selector, falling through", "node", node.Name, "model", trackingKey, "replica", replicaIdx) @@ -269,6 +311,7 @@ func (r *SmartRouter) Route(ctx context.Context, modelID, modelName, backendType // onFirstComplete callback releases the reservation after the first inference // call finishes, so in-flight returns to 0 when idle. 
r.registry.TouchNodeModel(ctx, node.ID, trackingKey, replicaIdx) + r.observePrefix(trackingKey, observeChain, prefixcache.ReplicaKey{NodeID: node.ID, Replica: replicaIdx}) grpcClient := r.buildClientForAddr(node, modelAddr, parallel) tracked := NewInFlightTrackingClient(grpcClient, r.registry, node.ID, trackingKey, replicaIdx) tracked.OnFirstComplete(func() { @@ -288,7 +331,7 @@ func (r *SmartRouter) Route(ctx context.Context, modelID, modelName, backendType // Step 2: Model not loaded — schedule loading with distributed lock to prevent duplicates loadModel := func() (*RouteResult, error) { // Re-check after acquiring lock — another request may have loaded it - node, nm, err := r.registry.FindAndLockNodeWithModel(ctx, trackingKey, candidateNodeIDs) + node, nm, err := r.registry.FindAndLockNodeWithModel(ctx, trackingKey, candidateNodeIDs, pref) if err == nil && node != nil { modelAddr := node.Address if nm.Address != "" { @@ -305,7 +348,7 @@ func (r *SmartRouter) Route(ctx context.Context, modelID, modelName, backendType "node", node.Name, "model", modelName, "replica", replicaIdx) } else { // Verify node still matches scheduling constraints - if !r.nodeMatchesScheduling(ctx, node, trackingKey) { + if !r.nodeMatchesScheduling(ctx, node, sched) { r.registry.DecrementInFlight(ctx, node.ID, trackingKey, replicaIdx) xlog.Info("Cached model on node that no longer matches selector, falling through", "node", node.Name, "model", trackingKey, "replica", replicaIdx) @@ -314,6 +357,7 @@ func (r *SmartRouter) Route(ctx context.Context, modelID, modelName, backendType // Model loaded while we waited — FindAndLockNodeWithModel already incremented // in-flight as a reservation. Release it after the first inference completes. 
r.registry.TouchNodeModel(ctx, node.ID, trackingKey, replicaIdx) + r.observePrefix(trackingKey, observeChain, prefixcache.ReplicaKey{NodeID: node.ID, Replica: replicaIdx}) grpcClient := r.buildClientForAddr(node, modelAddr, parallel) tracked := NewInFlightTrackingClient(grpcClient, r.registry, node.ID, trackingKey, replicaIdx) tracked.OnFirstComplete(func() { @@ -337,6 +381,10 @@ func (r *SmartRouter) Route(ctx context.Context, modelID, modelName, backendType return nil, err } + // Cold load landed on result.Node replica result.ReplicaIndex: record the + // assignment so subsequent requests with the same prefix prefer it. + r.observePrefix(trackingKey, observeChain, prefixcache.ReplicaKey{NodeID: result.Node.ID, Replica: result.ReplicaIndex}) + replicaIdx := result.ReplicaIndex tracked := NewInFlightTrackingClient(result.Client, r.registry, result.Node.ID, trackingKey, replicaIdx) tracked.OnFirstComplete(func() { @@ -389,13 +437,117 @@ func extractNodeIDs(nodes []BackendNode) []string { return ids } +// buildPreference computes the per-request route preference from the prefix +// chain on ctx and the model's resolved policy. The returned observeChain is +// non-nil only when the resolved policy is prefix_cache, signalling Route to +// record the assignment after a successful pick; for round-robin models it is +// nil so the tree is never polluted. The *RoutePreference is non-nil only when +// a load-eligible preferred node was chosen. +// +// When prefix-cache routing is disabled (nil provider), no chain is present, +// or the policy resolves to round-robin, both returns are nil and routing is +// the unchanged round-robin floor. 
+func (r *SmartRouter) buildPreference(ctx context.Context, modelID string, candidateNodeIDs []string, sched *ModelSchedulingConfig) (*RoutePreference, []uint64) { + if r.prefixProvider == nil { + return nil, nil + } + chain := distributedhdr.PrefixChain(ctx) + if len(chain) == 0 { + return nil, nil + } + + // Resolve per-model policy + thresholds over the global config. + policy := r.prefixConfig.GlobalPolicy + cfg := r.prefixConfig + if sched != nil { + policy = prefixcache.ParsePolicy(sched.RoutePolicy).Resolve(r.prefixConfig.GlobalPolicy) + if sched.BalanceAbsThreshold > 0 { + cfg.BalanceAbsThreshold = sched.BalanceAbsThreshold + } + if sched.BalanceRelThreshold > 0 { + cfg.BalanceRelThreshold = sched.BalanceRelThreshold + } + if sched.MinPrefixMatch > 0 { + cfg.MinPrefixMatch = sched.MinPrefixMatch + } + } + if policy != prefixcache.RoutePolicyPrefixCache { + return nil, nil + } + + // Load the candidate replicas PER REPLICA. Affinity is tracked per replica + // (each replica is a separate process with its own KV cache), so two + // replicas of the same model on the same node are two distinct candidates. + // FindAndLockNodeWithModel then locks the EXACT (node, replica) the policy + // chose. + stats, err := r.registry.LoadedReplicaStats(ctx, modelID, candidateNodeIDs) + if err != nil { + xlog.Debug("prefixcache: loading replica stats failed, skipping preference", "model", modelID, "error", err) + return nil, chain + } + if len(stats) == 0 { + return nil, chain + } + cands := make([]prefixcache.Candidate, 0, len(stats)) + keys := make([]prefixcache.ReplicaKey, 0, len(stats)) + for _, s := range stats { + key := prefixcache.ReplicaKey{NodeID: s.NodeID, Replica: s.ReplicaIndex} + cands = append(cands, prefixcache.Candidate{Key: key, InFlight: s.InFlight}) + keys = append(keys, key) + } + + d := r.prefixProvider.Decide(modelID, chain, keys, time.Now()) + chosen, ok := prefixcache.Select(cands, d, cfg) + + // Observability for the prefix-cache routing decision. 
One line per request + // at Debug: enable with DEBUG=true on the frontend to assess cache-aware + // routing. hotMatchHonored=true means we routed to the cache-warm replica; + // false with HasHot means the load guard forced a cold pick. + xlog.Debug("prefix-cache routing decision", + "model", modelID, + "chainDepth", len(chain), + "candidates", len(cands), + "hotNode", d.Hot.NodeID, + "hotReplica", d.Hot.Replica, + "hasHot", d.HasHot, + "matchRatio", d.MatchRatio, + "minMatch", cfg.MinPrefixMatch, + "chosen", fmt.Sprintf("%s/%d", chosen.NodeID, chosen.Replica), + "hotMatchHonored", d.HasHot && chosen == d.Hot) + + // Forced-disturb: a usable hot prefix match existed but the load guard + // forced us off the warm replica (Select picked a different replica). This + // is the scale-worthy signal - the cache-warm replica is saturated. It + // deliberately does not fire for all-unique workloads (no hot match), + // avoiding false-positive scale-ups. nil pressure is a no-op. + if r.pressure != nil && d.HasHot && d.MatchRatio >= cfg.MinPrefixMatch && chosen != d.Hot { + r.pressure.Record(modelID, time.Now()) + } + + if !ok { + return nil, chain + } + return &RoutePreference{PreferredNodeID: chosen.NodeID, PreferredReplica: chosen.Replica}, chain +} + +// observePrefix records that the replica `key` served the request whose prompt +// prefix is chain. It is a no-op when prefix-cache routing is disabled or the +// chain is empty (round-robin models pass a nil chain so the tree is never +// polluted). +func (r *SmartRouter) observePrefix(modelID string, chain []uint64, key prefixcache.ReplicaKey) { + if r.prefixProvider == nil || len(chain) == 0 { + return + } + r.prefixProvider.Observe(modelID, chain, key, time.Now()) + xlog.Debug("prefix-cache observed assignment", "model", modelID, "node", key.NodeID, "replica", key.Replica, "chainDepth", len(chain)) +} + // resolveSelectorCandidates returns the node IDs that match the model's // NodeSelector. 
Returns nil when no selector is configured ("any healthy node" // — registry helpers treat nil as no filter). Returns an error when a // non-empty selector matches zero healthy nodes, since there is nothing to // route or schedule on. -func (r *SmartRouter) resolveSelectorCandidates(ctx context.Context, modelID string) ([]string, error) { - sched, _ := r.registry.GetModelScheduling(ctx, modelID) +func (r *SmartRouter) resolveSelectorCandidates(ctx context.Context, modelID string, sched *ModelSchedulingConfig) ([]string, error) { if sched == nil || sched.NodeSelector == "" { return nil, nil } @@ -469,9 +621,8 @@ func (r *SmartRouter) narrowByGroupAntiAffinity(ctx context.Context, modelID str // nodeMatchesScheduling checks if a node satisfies the scheduling constraints for a model. // Returns true if no constraints exist or the node matches all selector labels. -func (r *SmartRouter) nodeMatchesScheduling(ctx context.Context, node *BackendNode, modelName string) bool { - sched, err := r.registry.GetModelScheduling(ctx, modelName) - if err != nil || sched == nil || sched.NodeSelector == "" { +func (r *SmartRouter) nodeMatchesScheduling(ctx context.Context, node *BackendNode, sched *ModelSchedulingConfig) bool { + if sched == nil || sched.NodeSelector == "" { return true // no constraints } @@ -518,7 +669,8 @@ func (r *SmartRouter) scheduleNewModel(ctx context.Context, backendType, modelID // Check for scheduling constraints (node selector). If a selector is set, // we restrict the candidate pool to matching nodes; otherwise nil means // "any healthy node". 
- candidateNodeIDs, err := r.resolveSelectorCandidates(ctx, modelID) + sched, _ := r.registry.GetModelScheduling(ctx, modelID) + candidateNodeIDs, err := r.resolveSelectorCandidates(ctx, modelID, sched) if err != nil { return nil, "", 0, err } diff --git a/core/services/nodes/router_test.go b/core/services/nodes/router_test.go index 26237773e..8c6be3269 100644 --- a/core/services/nodes/router_test.go +++ b/core/services/nodes/router_test.go @@ -12,7 +12,9 @@ import ( . "github.com/onsi/gomega" "github.com/mudler/LocalAI/core/services/messaging" + "github.com/mudler/LocalAI/core/services/nodes/prefixcache" "github.com/mudler/LocalAI/core/services/testutil" + "github.com/mudler/LocalAI/pkg/distributedhdr" grpc "github.com/mudler/LocalAI/pkg/grpc" pb "github.com/mudler/LocalAI/pkg/grpc/proto" ggrpc "google.golang.org/grpc" @@ -111,18 +113,34 @@ type fakeModelRouter struct { findNodesWithModelByName map[string][]BackendNode findNodesWithModelErr error + // LoadedReplicaStats returns (keyed by model name) + loadedReplicaStatsByName map[string][]ReplicaCandidate + loadedReplicaStatsErr error + // Track calls for assertions decrementCalls []string // "nodeID:modelName" incrementCalls []string removeCalls []string setCalls []string touchCalls []string + + // Preferences passed to FindAndLockNodeWithModel, in call order. nil + // entries are recorded too, so tests can assert "preference was nil". 
+ findAndLockPrefs []*RoutePreference } -func (f *fakeModelRouter) FindAndLockNodeWithModel(_ context.Context, modelName string, _ []string) (*BackendNode, *NodeModel, error) { +func (f *fakeModelRouter) FindAndLockNodeWithModel(_ context.Context, modelName string, _ []string, pref *RoutePreference) (*BackendNode, *NodeModel, error) { + f.findAndLockPrefs = append(f.findAndLockPrefs, pref) return f.findAndLockNode, f.findAndLockNM, f.findAndLockErr } +func (f *fakeModelRouter) LoadedReplicaStats(_ context.Context, modelName string, _ []string) ([]ReplicaCandidate, error) { + if f.loadedReplicaStatsErr != nil { + return nil, f.loadedReplicaStatsErr + } + return f.loadedReplicaStatsByName[modelName], nil +} + func (f *fakeModelRouter) DecrementInFlight(_ context.Context, nodeID, modelName string, _ int) error { f.decrementCalls = append(f.decrementCalls, nodeID+":"+modelName) return nil @@ -1055,3 +1073,355 @@ var _ = Describe("SmartRouter", func() { }) }) }) + +// --------------------------------------------------------------------------- +// Fake prefixcache.Provider for SmartRouter prefix-cache routing tests +// --------------------------------------------------------------------------- + +type observeRecord struct { + model string + chain []uint64 + key prefixcache.ReplicaKey +} + +type invalidateRecord struct { + model string + key prefixcache.ReplicaKey +} + +// fakePrefixProvider records all interactions and returns a configurable +// decision. 
+type fakePrefixProvider struct { + decideCalls int + observed []observeRecord + invalidated []invalidateRecord + invalidatedNode []string + decision prefixcache.PrefixDecision +} + +func (f *fakePrefixProvider) Decide(_ string, _ []uint64, _ []prefixcache.ReplicaKey, _ time.Time) prefixcache.PrefixDecision { + f.decideCalls++ + return f.decision +} + +func (f *fakePrefixProvider) Observe(model string, chain []uint64, key prefixcache.ReplicaKey, _ time.Time) bool { + f.observed = append(f.observed, observeRecord{model: model, chain: append([]uint64(nil), chain...), key: key}) + return true +} + +func (f *fakePrefixProvider) Invalidate(model string, key prefixcache.ReplicaKey) { + f.invalidated = append(f.invalidated, invalidateRecord{model: model, key: key}) +} + +func (f *fakePrefixProvider) InvalidateNode(model, nodeID string) { + f.invalidatedNode = append(f.invalidatedNode, model+":"+nodeID) +} + +func (f *fakePrefixProvider) Evict(_ time.Time) {} + +var _ = Describe("SmartRouter prefix-cache routing", func() { + var ( + backend *stubBackend + factory *stubClientFactory + unloader *fakeUnloader + ) + + BeforeEach(func() { + backend = &stubBackend{healthResult: true} + factory = &stubClientFactory{client: backend} + unloader = &fakeUnloader{ + installReply: &messaging.BackendInstallReply{Success: true, Address: "10.0.0.1:9001"}, + } + }) + + // loadedReg builds a fake registry with one loaded healthy replica for + // "m" on node "X", plus matching replica stats so buildPreference can run. 
+ loadedReg := func() *fakeModelRouter { + node := &BackendNode{ID: "X", Name: "node-x", Address: "10.0.0.1:50051"} + nm := &NodeModel{NodeID: "X", ModelName: "m", Address: "10.0.0.1:9001"} + return &fakeModelRouter{ + findAndLockNode: node, + findAndLockNM: nm, + getModelScheduling: &ModelSchedulingConfig{ + RoutePolicy: "prefix_cache", + }, + loadedReplicaStatsByName: map[string][]ReplicaCandidate{ + "m": {{NodeID: "X", InFlight: 0}}, + }, + } + } + + Context("nil provider (round-robin floor)", func() { + It("passes a nil preference and never decides or observes", func() { + reg := loadedReg() + router := NewSmartRouter(reg, SmartRouterOptions{Unloader: unloader, ClientFactory: factory}) + + ctx := distributedhdr.WithPrefixChain(context.Background(), []uint64{1, 2, 3}) + _, err := router.Route(ctx, "m", "models/m.gguf", "llama-cpp", nil, false) + Expect(err).ToNot(HaveOccurred()) + + Expect(reg.findAndLockPrefs).ToNot(BeEmpty()) + for _, p := range reg.findAndLockPrefs { + Expect(p).To(BeNil()) + } + }) + }) + + Context("with a provider", func() { + It("passes the decided node as the preference and observes the pick", func() { + reg := loadedReg() + prov := &fakePrefixProvider{decision: prefixcache.PrefixDecision{Hot: prefixcache.ReplicaKey{NodeID: "X"}, HasHot: true, MatchRatio: 1.0}} + router := NewSmartRouter(reg, SmartRouterOptions{ + Unloader: unloader, + ClientFactory: factory, + PrefixProvider: prov, + PrefixConfig: prefixcache.DefaultConfig(), + }) + + ctx := distributedhdr.WithPrefixChain(context.Background(), []uint64{1, 2, 3}) + _, err := router.Route(ctx, "m", "models/m.gguf", "llama-cpp", nil, false) + Expect(err).ToNot(HaveOccurred()) + + Expect(prov.decideCalls).To(BeNumerically(">=", 1)) + Expect(reg.findAndLockPrefs[0]).ToNot(BeNil()) + Expect(reg.findAndLockPrefs[0].PreferredNodeID).To(Equal("X")) + Expect(reg.findAndLockPrefs[0].PreferredReplica).To(Equal(0)) + Expect(prov.observed).To(HaveLen(1)) + 
Expect(prov.observed[0].key).To(Equal(prefixcache.ReplicaKey{NodeID: "X", Replica: 0})) + Expect(prov.observed[0].chain).To(Equal([]uint64{1, 2, 3})) + }) + + It("routes a recurring prefix back to the previously observed node", func() { + // Real Index as the provider: first request observes X, second + // request with the same chain must yield PreferredNodeID == X. + idx := prefixcache.NewIndex(prefixcache.DefaultConfig()) + reg := loadedReg() + router := NewSmartRouter(reg, SmartRouterOptions{ + Unloader: unloader, + ClientFactory: factory, + PrefixProvider: idx, + PrefixConfig: prefixcache.DefaultConfig(), + }) + + ctx := distributedhdr.WithPrefixChain(context.Background(), []uint64{7, 8, 9}) + _, err := router.Route(ctx, "m", "models/m.gguf", "llama-cpp", nil, false) + Expect(err).ToNot(HaveOccurred()) + // First request landed on X (cold placement on the only candidate) + // and observed the prefix there. + dFirst := idx.Decide("m", []uint64{7, 8, 9}, []prefixcache.ReplicaKey{{NodeID: "X", Replica: 0}}, time.Now()) + Expect(dFirst.HasHot).To(BeTrue()) + Expect(dFirst.Hot).To(Equal(prefixcache.ReplicaKey{NodeID: "X", Replica: 0})) + + // Second request, same chain: X is now the warm-cache hot match, so + // the preference must point at it. + _, err = router.Route(ctx, "m", "models/m.gguf", "llama-cpp", nil, false) + Expect(err).ToNot(HaveOccurred()) + last := reg.findAndLockPrefs[len(reg.findAndLockPrefs)-1] + Expect(last).ToNot(BeNil()) + Expect(last.PreferredNodeID).To(Equal("X")) + Expect(last.PreferredReplica).To(Equal(0)) + }) + + It("prefers the exact hot replica when two replicas share a node", func() { + // Two replicas of "m" live on the SAME node X: replica 0 and replica + // 1. A hot prefix observed on (X,0) must produce a preference that + // locks replica 0 specifically, NOT the sibling replica 1 on the same + // node. This is the replica-granular regression this change fixes. 
+ idx := prefixcache.NewIndex(prefixcache.DefaultConfig()) + node := &BackendNode{ID: "X", Name: "node-x", Address: "10.0.0.1:50051"} + nm := &NodeModel{NodeID: "X", ModelName: "m", ReplicaIndex: 0, Address: "10.0.0.1:9001"} + reg := &fakeModelRouter{ + findAndLockNode: node, + findAndLockNM: nm, + getModelScheduling: &ModelSchedulingConfig{ + RoutePolicy: "prefix_cache", + }, + loadedReplicaStatsByName: map[string][]ReplicaCandidate{ + "m": { + {NodeID: "X", ReplicaIndex: 0, InFlight: 0}, + {NodeID: "X", ReplicaIndex: 1, InFlight: 0}, + }, + }, + } + // Seed the index so (X,0) is the warm replica for this chain. + idx.Observe("m", []uint64{1, 2, 3}, prefixcache.ReplicaKey{NodeID: "X", Replica: 0}, time.Now()) + + router := NewSmartRouter(reg, SmartRouterOptions{ + Unloader: unloader, + ClientFactory: factory, + PrefixProvider: idx, + PrefixConfig: prefixcache.DefaultConfig(), + }) + + ctx := distributedhdr.WithPrefixChain(context.Background(), []uint64{1, 2, 3}) + _, err := router.Route(ctx, "m", "models/m.gguf", "llama-cpp", nil, false) + Expect(err).ToNot(HaveOccurred()) + + pref := reg.findAndLockPrefs[0] + Expect(pref).ToNot(BeNil()) + Expect(pref.PreferredNodeID).To(Equal("X")) + Expect(pref.PreferredReplica).To(Equal(0), + "the hot prefix lives on replica 0; the same-node sibling replica 1 must NOT be chosen") + }) + + It("does not decide or observe when no prefix chain is present", func() { + reg := loadedReg() + prov := &fakePrefixProvider{decision: prefixcache.PrefixDecision{Hot: prefixcache.ReplicaKey{NodeID: "X"}, HasHot: true, MatchRatio: 1.0}} + router := NewSmartRouter(reg, SmartRouterOptions{ + Unloader: unloader, + ClientFactory: factory, + PrefixProvider: prov, + PrefixConfig: prefixcache.DefaultConfig(), + }) + + _, err := router.Route(context.Background(), "m", "models/m.gguf", "llama-cpp", nil, false) + Expect(err).ToNot(HaveOccurred()) + Expect(prov.decideCalls).To(Equal(0)) + Expect(prov.observed).To(BeEmpty()) + 
Expect(reg.findAndLockPrefs[0]).To(BeNil()) + }) + + It("does not observe for round-robin models even with a chain", func() { + reg := loadedReg() + reg.getModelScheduling = &ModelSchedulingConfig{RoutePolicy: "round_robin"} + prov := &fakePrefixProvider{decision: prefixcache.PrefixDecision{Hot: prefixcache.ReplicaKey{NodeID: "X"}, HasHot: true, MatchRatio: 1.0}} + router := NewSmartRouter(reg, SmartRouterOptions{ + Unloader: unloader, + ClientFactory: factory, + PrefixProvider: prov, + PrefixConfig: prefixcache.DefaultConfig(), + }) + + ctx := distributedhdr.WithPrefixChain(context.Background(), []uint64{1, 2, 3}) + _, err := router.Route(ctx, "m", "models/m.gguf", "llama-cpp", nil, false) + Expect(err).ToNot(HaveOccurred()) + Expect(prov.decideCalls).To(Equal(0)) + Expect(prov.observed).To(BeEmpty()) + Expect(reg.findAndLockPrefs[0]).To(BeNil()) + }) + }) + + Context("forced-disturb pressure", func() { + // disturbReg builds a registry with two candidate replicas for "m": + // the hot node X is saturated (high in_flight) and Y is free. Select + // will therefore reject the hot node and pick Y, which is the + // forced-disturb signal. findAndLockNode returns Y so Route succeeds. 
+ disturbReg := func() *fakeModelRouter { + nodeY := &BackendNode{ID: "Y", Name: "node-y", Address: "10.0.0.2:50051"} + nm := &NodeModel{NodeID: "Y", ModelName: "m", Address: "10.0.0.2:9001"} + return &fakeModelRouter{ + findAndLockNode: nodeY, + findAndLockNM: nm, + getModelScheduling: &ModelSchedulingConfig{ + RoutePolicy: "prefix_cache", + }, + loadedReplicaStatsByName: map[string][]ReplicaCandidate{ + "m": {{NodeID: "X", InFlight: 50}, {NodeID: "Y", InFlight: 0}}, + }, + } + } + + It("records pressure when a strong hot match was forced off the warm node", func() { + reg := disturbReg() + prov := &fakePrefixProvider{decision: prefixcache.PrefixDecision{ + Hot: prefixcache.ReplicaKey{NodeID: "X"}, + HasHot: true, + MatchRatio: 1.0, + ColdOrder: []prefixcache.ReplicaKey{{NodeID: "Y"}, {NodeID: "X"}}, + }} + pressure := prefixcache.NewPressure(time.Minute) + router := NewSmartRouter(reg, SmartRouterOptions{ + Unloader: unloader, + ClientFactory: factory, + PrefixProvider: prov, + PrefixConfig: prefixcache.DefaultConfig(), + Pressure: pressure, + }) + + ctx := distributedhdr.WithPrefixChain(context.Background(), []uint64{1, 2, 3}) + _, err := router.Route(ctx, "m", "models/m.gguf", "llama-cpp", nil, false) + Expect(err).ToNot(HaveOccurred()) + + Expect(pressure.Count("m", time.Now())).To(BeNumerically(">", 0), + "hot match existed but the load guard forced us off X: must record pressure") + }) + + It("does not record pressure when the hot node is itself eligible", func() { + reg := loadedReg() // single node X, in_flight 0 → X stays eligible + prov := &fakePrefixProvider{decision: prefixcache.PrefixDecision{ + Hot: prefixcache.ReplicaKey{NodeID: "X"}, + HasHot: true, + MatchRatio: 1.0, + }} + pressure := prefixcache.NewPressure(time.Minute) + router := NewSmartRouter(reg, SmartRouterOptions{ + Unloader: unloader, + ClientFactory: factory, + PrefixProvider: prov, + PrefixConfig: prefixcache.DefaultConfig(), + Pressure: pressure, + }) + + ctx := 
distributedhdr.WithPrefixChain(context.Background(), []uint64{1, 2, 3}) + _, err := router.Route(ctx, "m", "models/m.gguf", "llama-cpp", nil, false) + Expect(err).ToNot(HaveOccurred()) + + Expect(pressure.Count("m", time.Now())).To(Equal(0), + "chosen == hot node, no disturb") + }) + + It("does not record pressure for an all-unique workload with no hot match", func() { + reg := loadedReg() + prov := &fakePrefixProvider{decision: prefixcache.PrefixDecision{ + HasHot: false, // no prefix match at all + MatchRatio: 0, + ColdOrder: []prefixcache.ReplicaKey{{NodeID: "X"}}, + }} + pressure := prefixcache.NewPressure(time.Minute) + router := NewSmartRouter(reg, SmartRouterOptions{ + Unloader: unloader, + ClientFactory: factory, + PrefixProvider: prov, + PrefixConfig: prefixcache.DefaultConfig(), + Pressure: pressure, + }) + + ctx := distributedhdr.WithPrefixChain(context.Background(), []uint64{1, 2, 3}) + _, err := router.Route(ctx, "m", "models/m.gguf", "llama-cpp", nil, false) + Expect(err).ToNot(HaveOccurred()) + + Expect(pressure.Count("m", time.Now())).To(Equal(0), + "no hot match means no cache to disturb: must not false-positive") + }) + }) + + Context("removal chokepoint on unload", func() { + It("removes the replica via the registry so the removal hook invalidates the prefix entry", func() { + idx := prefixcache.NewIndex(prefixcache.DefaultConfig()) + reg := loadedReg() + router := NewSmartRouter(reg, SmartRouterOptions{ + Unloader: unloader, + ClientFactory: factory, + PrefixProvider: idx, + PrefixConfig: prefixcache.DefaultConfig(), + }) + + ctx := distributedhdr.WithPrefixChain(context.Background(), []uint64{5, 6}) + // Warm the cache: X now holds the prefix. 
+ _, err := router.Route(ctx, "m", "models/m.gguf", "llama-cpp", nil, false) + Expect(err).ToNot(HaveOccurred()) + Expect(idx.Decide("m", []uint64{5, 6}, []prefixcache.ReplicaKey{{NodeID: "X", Replica: 0}}, time.Now()).Hot).To(Equal(prefixcache.ReplicaKey{NodeID: "X", Replica: 0})) + + // UnloadModel must route the eviction through the registry removal + // chokepoint (RemoveAllNodeModelReplicas). The registry's + // SetReplicaRemovedHook is what invalidates the prefix index in + // production; the router no longer invalidates directly. Here the + // fake registry records the removal but fires no hook, so we assert + // the chokepoint is exercised rather than the downstream + // invalidation (covered by the registry hook integration tests). + Expect(router.UnloadModel(context.Background(), "X", "m")).To(Succeed()) + Expect(reg.removeCalls).To(ContainElement("X:m"), + "UnloadModel must remove the replica via the registry removal chokepoint") + }) + }) +}) diff --git a/docs/content/features/distributed-mode.md b/docs/content/features/distributed-mode.md index f9b09079a..c9b48d835 100644 --- a/docs/content/features/distributed-mode.md +++ b/docs/content/features/distributed-mode.md @@ -558,3 +558,17 @@ All fields are optional and composable: - Ensure the port range is not blocked by firewalls or used by other services - Verify the backend gallery configuration is correct - The worker needs network access to download backends from the gallery + +## Roadmap: Routing and Caching Enhancements + +The scheduling algorithm above is load-based (least in-flight, then least-recently-used). Work is underway to make routing **prefix-cache-aware**: bias each request toward the replica that already holds the relevant KV/prefix cache (multi-turn conversations and shared system prompts), so backends reuse cache instead of recomputing it. 
The first step is a router-side radix tree of prompt-prefix hashes mapped to nodes, with longest-prefix match, a load guard that preserves round-robin behavior under imbalance, and NATS sync across frontends. It is purely a routing-layer hint (no backend changes) and never routes worse than today's round-robin. + +Further enhancements, surfaced from a survey of SGLang, vLLM production-stack, Ray Serve, llm-d, AIBrix, and NVIDIA Dynamo, are tracked under the routing roadmap epic ([#10063](https://github.com/mudler/LocalAI/issues/10063)): + +- **Reported/precise KV-event mode** ([#10064](https://github.com/mudler/LocalAI/issues/10064)): subscribe to actual backend KV-cache events for exact residency instead of inferring it from routing history. +- **Multi-tier cache-overlap scoring** ([#10065](https://github.com/mudler/LocalAI/issues/10065)): credit GPU/CPU/disk cache tiers separately. +- **Pluggable scorer/filter/picker pipeline** ([#10066](https://github.com/mudler/LocalAI/issues/10066)): composable multi-signal routing (cache, queue depth, KV utilization, latency). +- **Load-shaping** ([#10067](https://github.com/mudler/LocalAI/issues/10067)): anti-herding (softmax/temperature) and dispatch-time freshness. +- **Prefill/decode disaggregation routing** ([#10068](https://github.com/mudler/LocalAI/issues/10068)): route prefill and decode to separate pools with KV transfer. +- **Per-user fairness (VTC)** ([#10069](https://github.com/mudler/LocalAI/issues/10069)): balance per-user token usage against pod load. +- **Minor tuning + MCP parity** ([#10070](https://github.com/mudler/LocalAI/issues/10070)): per-model TTL override, probabilistic LRU updates, and MCP scheduling-config tool parity. 
diff --git a/pkg/distributedhdr/prefixhash.go b/pkg/distributedhdr/prefixhash.go new file mode 100644 index 000000000..a4124098f --- /dev/null +++ b/pkg/distributedhdr/prefixhash.go @@ -0,0 +1,37 @@ +package distributedhdr + +import "context" + +type prefixChainKey struct{} + +// WithPrefixChain attaches a prompt prefix-hash chain to ctx so the distributed +// router can make a prefix-cache-aware decision. Set at inference entry where +// the rendered prompt is known; read in SmartRouter.Route. +func WithPrefixChain(ctx context.Context, chain []uint64) context.Context { + return context.WithValue(ctx, prefixChainKey{}, chain) +} + +// PrefixChain returns the chain attached by WithPrefixChain, or nil. +func PrefixChain(ctx context.Context) []uint64 { + if v, ok := ctx.Value(prefixChainKey{}).([]uint64); ok { + return v + } + return nil +} + +// PrefixChainHook, when set at startup (distributed mode only), builds a prefix +// hash chain from a model id and rendered prompt. Left nil in single-process +// mode so there is zero overhead. See core/application/distributed.go. +var PrefixChainHook func(model, prompt string) []uint64 + +// MaybeWithPrefixChain attaches a prefix chain to ctx iff the hook is set and +// returns a non-empty chain. Otherwise returns ctx unchanged. +func MaybeWithPrefixChain(ctx context.Context, model, prompt string) context.Context { + if PrefixChainHook == nil { + return ctx + } + if chain := PrefixChainHook(model, prompt); len(chain) > 0 { + return WithPrefixChain(ctx, chain) + } + return ctx +} diff --git a/pkg/distributedhdr/prefixhash_test.go b/pkg/distributedhdr/prefixhash_test.go new file mode 100644 index 000000000..5b85db23c --- /dev/null +++ b/pkg/distributedhdr/prefixhash_test.go @@ -0,0 +1,37 @@ +package distributedhdr_test + +import ( + "context" + + "github.com/mudler/LocalAI/pkg/distributedhdr" + . "github.com/onsi/ginkgo/v2" + . 
"github.com/onsi/gomega" +) + +var _ = Describe("prefix chain ctx", func() { + It("round-trips the chain through ctx", func() { + ctx := distributedhdr.WithPrefixChain(context.Background(), []uint64{1, 2, 3}) + Expect(distributedhdr.PrefixChain(ctx)).To(Equal([]uint64{1, 2, 3})) + }) + It("returns nil when absent", func() { + Expect(distributedhdr.PrefixChain(context.Background())).To(BeNil()) + }) + + It("uses the hook to build the chain when set", func() { + distributedhdr.PrefixChainHook = func(model, prompt string) []uint64 { return []uint64{42} } + defer func() { distributedhdr.PrefixChainHook = nil }() + ctx := distributedhdr.MaybeWithPrefixChain(context.Background(), "m", "hi") + Expect(distributedhdr.PrefixChain(ctx)).To(Equal([]uint64{42})) + }) + It("is a no-op when the hook is nil", func() { + distributedhdr.PrefixChainHook = nil + ctx := distributedhdr.MaybeWithPrefixChain(context.Background(), "m", "hi") + Expect(distributedhdr.PrefixChain(ctx)).To(BeNil()) + }) + It("is a no-op when the hook returns an empty chain", func() { + distributedhdr.PrefixChainHook = func(model, prompt string) []uint64 { return nil } + defer func() { distributedhdr.PrefixChainHook = nil }() + ctx := distributedhdr.MaybeWithPrefixChain(context.Background(), "m", "hi") + Expect(distributedhdr.PrefixChain(ctx)).To(BeNil()) + }) +}) diff --git a/pkg/radixtree/radixtree.go b/pkg/radixtree/radixtree.go new file mode 100644 index 000000000..9f6602abc --- /dev/null +++ b/pkg/radixtree/radixtree.go @@ -0,0 +1,248 @@ +// Package radixtree implements a generic prefix tree over sequences of uint64 +// key-elements, mapping the longest stored prefix of a query sequence to a +// value. Entries carry a TTL and the tree tracks a recency-weighted score per +// value. The clock is injected (callers pass `now`) so behavior is fully +// deterministic and testable. It has no external dependencies. +package radixtree + +import ( + "math" + "sync" + "time" +) + +// Options configures a Tree. 
type Options struct {
	// TTL is the idle lifetime of an entry. An entry whose lastSeen is older
	// than TTL (relative to the `now` passed in) is treated as absent and is
	// swept by Evict. Refreshed on every Insert that traverses it. The boundary
	// is strict greater-than: an entry whose age is exactly equal to TTL is
	// still live; it expires only once age exceeds TTL.
	TTL time.Duration
	// HalfLife controls recency weighting in Weight(). An entry contributes
	// 0.5^(age/HalfLife). Zero means "no decay" (every live entry counts 1).
	HalfLife time.Duration
	// MaxEntries bounds the number of value-bearing nodes. Zero means
	// unbounded. When an Insert pushes the tree past the bound, the
	// least-recently-seen entries are evicted until the tree is back within
	// it, so a multi-element key cannot leave the tree permanently over cap.
	MaxEntries int
}

// Tree is a prefix tree. V is the stored value type (for prefix-cache routing,
// a node identifier). Safe for concurrent use.
type Tree[V comparable] struct {
	mu   sync.RWMutex
	opts Options
	root *node[V]
	// size is the number of value-bearing nodes. It is incremented in Insert
	// and decremented in pruneWalk, always under mu.
	size int
}

// node is one key element of a stored chain. A node can exist without a value
// (hasValue == false) when it is only structure on the way to deeper entries.
type node[V comparable] struct {
	children map[uint64]*node[V]
	value    V
	hasValue bool
	lastSeen time.Time
}

// New creates an empty Tree.
func New[V comparable](opts Options) *Tree[V] {
	return &Tree[V]{opts: opts, root: &node[V]{children: map[uint64]*node[V]{}}}
}

// LongestMatch returns the value at the deepest stored, non-expired prefix of
// key, the matched depth (number of key elements consumed), and ok.
func (t *Tree[V]) LongestMatch(key []uint64, now time.Time) (V, int, bool) {
	t.mu.RLock()
	defer t.mu.RUnlock()
	var best V
	bestDepth, found := 0, false
	cur := t.root
	for i, k := range key {
		next, ok := cur.children[k]
		if !ok {
			break
		}
		cur = next
		// A value-less or expired node does not stop the descent: a deeper
		// node on the same chain may still hold a live value.
		if cur.hasValue && !t.expired(cur, now) {
			best, bestDepth, found = cur.value, i+1, true
		}
	}
	return best, bestDepth, found
}

// expired reports whether n's lastSeen is older than the configured TTL. The
// comparison is strict greater-than: an entry whose age equals TTL exactly is
// still considered live. With TTL == 0 (unbounded) nothing ever expires.
func (t *Tree[V]) expired(n *node[V], now time.Time) bool {
	return t.opts.TTL > 0 && now.Sub(n.lastSeen) > t.opts.TTL
}

// Insert records value at EVERY node along the key chain, not just the leaf,
// so each prefix-block node remembers the value (node id) that served that
// prefix. This is what makes LongestMatch find a shared prefix even when the
// query tail diverges (SGLang/vLLM-style prefix matching). Re-inserting a
// different value over a shared prefix node overwrites it: the last writer
// owns the shared prefix node (a recency heuristic, and the correct one - the
// most recent chain that traversed that block is the one most likely warm).
// lastSeen is refreshed on every traversed node so active prefixes stay live.
// Inserting an empty key is a no-op: the root never holds a value.
//
// With MaxEntries > 0 the tree is brought back within the bound before Insert
// returns, however many new entries the key added.
func (t *Tree[V]) Insert(key []uint64, value V, now time.Time) {
	if len(key) == 0 {
		return
	}
	t.mu.Lock()
	defer t.mu.Unlock()
	cur := t.root
	for _, k := range key {
		next, ok := cur.children[k]
		if !ok {
			next = &node[V]{children: map[uint64]*node[V]{}}
			cur.children[k] = next
		}
		cur = next
		if !cur.hasValue {
			t.size++
		}
		cur.value, cur.hasValue, cur.lastSeen = value, true, now
	}
	if t.opts.MaxEntries > 0 && t.size > t.opts.MaxEntries {
		t.evictOldestLocked(now)
	}
}

// evictOldestLocked drops the least-recently-seen value-bearing nodes until
// the tree is back within MaxEntries, then prunes any empty branches the
// removals leave behind. A single Insert can add up to len(key) entries, so
// the overflow may be larger than one; all of it is reclaimed with one
// selection walk and one prune walk. Ties on lastSeen keep the node visited
// first (a parent is visited before its children). The time argument is
// unused and kept only for signature stability. Called with t.mu held.
func (t *Tree[V]) evictOldestLocked(_ time.Time) {
	excess := t.size - t.opts.MaxEntries
	if t.opts.MaxEntries <= 0 || excess <= 0 {
		return
	}
	// oldest holds at most `excess` nodes in ascending lastSeen order.
	oldest := make([]*node[V], 0, excess)
	var walk func(n *node[V])
	walk = func(n *node[V]) {
		if n.hasValue && (len(oldest) < excess || n.lastSeen.Before(oldest[len(oldest)-1].lastSeen)) {
			// Insert after every candidate that is not newer than n, so
			// equal timestamps keep their visit order.
			pos := len(oldest)
			for pos > 0 && n.lastSeen.Before(oldest[pos-1].lastSeen) {
				pos--
			}
			if len(oldest) < excess {
				oldest = append(oldest, nil)
			}
			// Shift right by one; when the slice is already full this drops
			// the newest candidate off the end.
			copy(oldest[pos+1:], oldest[pos:])
			oldest[pos] = n
		}
		for _, c := range n.children {
			walk(c)
		}
	}
	walk(t.root)

	victims := make(map[*node[V]]struct{}, len(oldest))
	for _, n := range oldest {
		victims[n] = struct{}{}
	}
	// Clear the victims' values and reclaim them plus any ancestors that are
	// now both value-less and childless.
	t.pruneWalk(t.root, func(n *node[V]) bool {
		_, doomed := victims[n]
		return doomed
	})
}

// pruneWalk clears the value of every node for which shouldClear returns true,
// then removes the now empty (value-less and childless) branches that result.
// It keeps t.size accurate by decrementing once per cleared node. Returns true
// if n itself should be removed from its parent. Called with t.mu held.
func (t *Tree[V]) pruneWalk(n *node[V], shouldClear func(*node[V]) bool) bool {
	// Children first, so a parent sees its final child count before deciding
	// whether it is itself removable.
	for k, c := range n.children {
		if t.pruneWalk(c, shouldClear) {
			delete(n.children, k)
		}
	}
	if n.hasValue && shouldClear(n) {
		n.hasValue = false
		var zero V
		n.value = zero
		t.size--
	}
	return n != t.root && !n.hasValue && len(n.children) == 0
}

// Len returns the number of live (value-bearing) entries, including not-yet-
// swept expired ones. Use after Evict for the post-sweep count.
func (t *Tree[V]) Len() int {
	t.mu.RLock()
	defer t.mu.RUnlock()
	return t.size
}

// Evict removes expired value-bearing nodes and prunes resulting empty
// branches. O(n) in tree size; call periodically from a background sweeper.
func (t *Tree[V]) Evict(now time.Time) {
	t.mu.Lock()
	defer t.mu.Unlock()
	t.pruneWalk(t.root, func(n *node[V]) bool { return t.expired(n, now) })
}

// contribution returns the recency-weighted score a single live, non-expired
// node adds to its value's weight: 1.0 when HalfLife<=0 (a plain count), else
// 0.5^(age/HalfLife). A non-positive age (now at or before lastSeen, which can
// happen when callers pass slightly different clocks) counts as fully fresh,
// so one entry never contributes more than 1. It does not check hasValue or
// expiry; callers must filter those first. Shared by Weight and WeightsFor so
// the metric stays identical.
func (t *Tree[V]) contribution(n *node[V], now time.Time) float64 {
	if t.opts.HalfLife <= 0 {
		return 1
	}
	age := now.Sub(n.lastSeen)
	if age <= 0 {
		return 1
	}
	return math.Pow(0.5, age.Seconds()/t.opts.HalfLife.Seconds())
}

// Weight returns the recency-weighted count of live entries anchored to value:
// sum over non-expired entries of 0.5^(age/HalfLife). With HalfLife==0 every
// live entry contributes 1.0 (a plain count). This is the "valuable warm cache"
// proxy used for cold placement and autoscale.
func (t *Tree[V]) Weight(value V, now time.Time) float64 {
	t.mu.RLock()
	defer t.mu.RUnlock()
	var sum float64
	var walk func(n *node[V])
	walk = func(n *node[V]) {
		if n.hasValue && n.value == value && !t.expired(n, now) {
			sum += t.contribution(n, now)
		}
		for _, c := range n.children {
			walk(c)
		}
	}
	walk(t.root)
	return sum
}

// WeightsFor returns the recency-weighted weight (same metric as Weight) for
// each value in values, computed in a single tree traversal. Values not present
// in the tree map to 0. This is O(N + len(values)) versus calling Weight once
// per value (O(len(values) * N)). Concurrency-safe (read lock).
func (t *Tree[V]) WeightsFor(values []V, now time.Time) map[V]float64 {
	want := make(map[V]struct{}, len(values))
	result := make(map[V]float64, len(values))
	for _, v := range values {
		want[v] = struct{}{}
		result[v] = 0
	}
	if len(want) == 0 {
		return result
	}
	t.mu.RLock()
	defer t.mu.RUnlock()
	var walk func(n *node[V])
	walk = func(n *node[V]) {
		if n.hasValue && !t.expired(n, now) {
			if _, ok := want[n.value]; ok {
				result[n.value] += t.contribution(n, now)
			}
		}
		for _, c := range n.children {
			walk(c)
		}
	}
	walk(t.root)
	return result
}

// Remove drops every entry whose value equals value, then prunes empty
// branches. Used when a replica is unloaded or its node goes offline so the
// tree never points at a node that no longer holds the model. It is the
// equality special case of RemoveFunc and behaves exactly like RemoveFunc
// called with an equality predicate.
func (t *Tree[V]) Remove(value V) {
	t.mu.Lock()
	defer t.mu.Unlock()
	t.pruneWalk(t.root, func(n *node[V]) bool { return n.value == value })
}

// RemoveFunc drops every entry whose value satisfies pred, then prunes empty
// branches. Generalizes Remove (Remove(v) == RemoveFunc(func(x V) bool { return
// x == v })).
Used to drop, in one walk, every entry that belongs to a class of +// values (for example all replicas of a single node). +func (t *Tree[V]) RemoveFunc(pred func(V) bool) { + t.mu.Lock() + defer t.mu.Unlock() + t.pruneWalk(t.root, func(n *node[V]) bool { return pred(n.value) }) +} diff --git a/pkg/radixtree/radixtree_suite_test.go b/pkg/radixtree/radixtree_suite_test.go new file mode 100644 index 000000000..1dc9ab0c5 --- /dev/null +++ b/pkg/radixtree/radixtree_suite_test.go @@ -0,0 +1,13 @@ +package radixtree_test + +import ( + "testing" + + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" +) + +func TestRadixTree(t *testing.T) { + RegisterFailHandler(Fail) + RunSpecs(t, "RadixTree Suite") +} diff --git a/pkg/radixtree/radixtree_test.go b/pkg/radixtree/radixtree_test.go new file mode 100644 index 000000000..507ad9bfb --- /dev/null +++ b/pkg/radixtree/radixtree_test.go @@ -0,0 +1,354 @@ +package radixtree_test + +import ( + "time" + + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" + + "github.com/mudler/LocalAI/pkg/radixtree" +) + +var t0 = time.Date(2026, 5, 29, 12, 0, 0, 0, time.UTC) + +var _ = Describe("Tree construction", func() { + It("returns an empty tree that matches nothing", func() { + tr := radixtree.New[string](radixtree.Options{TTL: time.Minute}) + _, depth, ok := tr.LongestMatch([]uint64{1, 2, 3}, t0) + Expect(ok).To(BeFalse()) + Expect(depth).To(Equal(0)) + }) +}) + +var _ = Describe("Insert and LongestMatch", func() { + It("returns the deepest matching prefix value", func() { + // Non-overlapping chains keep the longest-prefix intent clean: every + // node on the value's own chain records that value, and no other Insert + // overwrites a shared prefix node. A query that runs off the end of a + // chain stops matching at the deepest stored element it reached. 
+ tr := radixtree.New[string](radixtree.Options{TTL: time.Hour}) + tr.Insert([]uint64{1, 2, 3, 4}, "nodeB", t0) + tr.Insert([]uint64{7, 8}, "nodeA", t0) + + v, depth, ok := tr.LongestMatch([]uint64{1, 2, 3, 4, 5}, t0) + Expect(ok).To(BeTrue()) + Expect(v).To(Equal("nodeB")) + Expect(depth).To(Equal(4)) + + v, depth, ok = tr.LongestMatch([]uint64{7, 8, 9}, t0) + Expect(ok).To(BeTrue()) + Expect(v).To(Equal("nodeA")) + Expect(depth).To(Equal(2)) + }) + + It("lets the last writer own a shared prefix node", func() { + // When two chains share a leading block, value-at-every-node means the + // later Insert overwrites the shared prefix node. Inserting nodeA on + // [1,2] then nodeB on [1,2,3,4] makes nodeB own [1] and [1,2], so a + // query that diverges within the shared block resolves to nodeB. This + // is the intended recency heuristic: the most recent chain through that + // block is the one most likely still warm. + tr := radixtree.New[string](radixtree.Options{TTL: time.Hour}) + tr.Insert([]uint64{1, 2}, "nodeA", t0) + tr.Insert([]uint64{1, 2, 3, 4}, "nodeB", t0) + + v, depth, ok := tr.LongestMatch([]uint64{1, 2, 3, 4, 5}, t0) + Expect(ok).To(BeTrue()) + Expect(v).To(Equal("nodeB")) + Expect(depth).To(Equal(4)) + + // The shared prefix [1,2] is now owned by nodeB (last writer wins). + v, depth, ok = tr.LongestMatch([]uint64{1, 2, 9}, t0) + Expect(ok).To(BeTrue()) + Expect(v).To(Equal("nodeB")) + Expect(depth).To(Equal(2)) + }) + + It("returns ok=false when no prefix is stored", func() { + tr := radixtree.New[string](radixtree.Options{TTL: time.Hour}) + tr.Insert([]uint64{7, 8}, "nodeA", t0) + _, _, ok := tr.LongestMatch([]uint64{1, 2}, t0) + Expect(ok).To(BeFalse()) + }) + + It("matches a shared prefix when the query tail diverges", func() { + // SGLang/vLLM-style prefix matching: a single Insert of a full chain + // must let any query that shares a leading block match at the depth of + // the deepest shared element, even though the tails differ. 
This is the + // core use case (shared system prompt / multi-turn extension / volatile + // tail), not exact-repeat. + tr := radixtree.New[string](radixtree.Options{TTL: time.Hour}) + tr.Insert([]uint64{1, 2, 3, 4, 5}, "nodeA", t0) + v, depth, ok := tr.LongestMatch([]uint64{1, 2, 3, 9, 9}, t0) + Expect(ok).To(BeTrue()) + Expect(depth).To(Equal(3)) // shared prefix [1,2,3] + Expect(v).To(Equal("nodeA")) + }) +}) + +var _ = Describe("TTL expiry", func() { + It("does not match an entry past its TTL", func() { + tr := radixtree.New[string](radixtree.Options{TTL: time.Minute}) + tr.Insert([]uint64{1, 2}, "nodeA", t0) + _, _, ok := tr.LongestMatch([]uint64{1, 2}, t0.Add(2*time.Minute)) + Expect(ok).To(BeFalse()) + }) + + It("refreshes lastSeen on re-insert so a live path survives", func() { + tr := radixtree.New[string](radixtree.Options{TTL: time.Minute}) + tr.Insert([]uint64{1, 2}, "nodeA", t0) + tr.Insert([]uint64{1, 2}, "nodeA", t0.Add(50*time.Second)) + _, _, ok := tr.LongestMatch([]uint64{1, 2}, t0.Add(90*time.Second)) + Expect(ok).To(BeTrue()) + }) + + It("Evict reclaims expired nodes", func() { + tr := radixtree.New[string](radixtree.Options{TTL: time.Minute}) + // Value-at-every-node: Insert of a 2-element chain records nodeA at both + // {1} and {1,2}, so Len is 2 (one valued node per distinct prefix). 
+ tr.Insert([]uint64{1, 2}, "nodeA", t0) + Expect(tr.Len()).To(Equal(2)) + tr.Evict(t0.Add(2 * time.Minute)) + Expect(tr.Len()).To(Equal(0)) + }) +}) + +var _ = Describe("Weight", func() { + It("counts live entries for a value with no decay", func() { + tr := radixtree.New[string](radixtree.Options{TTL: time.Hour}) // HalfLife=0 + tr.Insert([]uint64{1}, "A", t0) + tr.Insert([]uint64{1, 2}, "A", t0) + tr.Insert([]uint64{9}, "B", t0) + Expect(tr.Weight("A", t0)).To(BeNumerically("==", 2)) + Expect(tr.Weight("B", t0)).To(BeNumerically("==", 1)) + Expect(tr.Weight("C", t0)).To(BeNumerically("==", 0)) + }) + + It("decays older entries by half-life", func() { + tr := radixtree.New[string](radixtree.Options{TTL: time.Hour, HalfLife: time.Minute}) + tr.Insert([]uint64{1}, "A", t0) + // one half-life later, the entry weighs 0.5 + Expect(tr.Weight("A", t0.Add(time.Minute))).To(BeNumerically("~", 0.5, 0.001)) + }) + + It("ignores expired entries", func() { + tr := radixtree.New[string](radixtree.Options{TTL: time.Minute}) + tr.Insert([]uint64{1}, "A", t0) + Expect(tr.Weight("A", t0.Add(2*time.Minute))).To(BeNumerically("==", 0)) + }) +}) + +var _ = Describe("WeightsFor", func() { + It("matches per-value Weight with no decay", func() { + tr := radixtree.New[string](radixtree.Options{TTL: time.Hour}) // HalfLife=0 + tr.Insert([]uint64{1}, "A", t0) + tr.Insert([]uint64{1, 2}, "A", t0) + tr.Insert([]uint64{9}, "B", t0) + + got := tr.WeightsFor([]string{"A", "B", "C"}, t0) + Expect(got).To(HaveLen(3)) + Expect(got["A"]).To(BeNumerically("==", 2)) + Expect(got["B"]).To(BeNumerically("==", 1)) + Expect(got["C"]).To(BeNumerically("==", 0)) + }) + + It("matches per-value Weight under decay", func() { + tr := radixtree.New[string](radixtree.Options{TTL: time.Hour, HalfLife: time.Minute}) + tr.Insert([]uint64{1}, "A", t0) + tr.Insert([]uint64{1, 2}, "A", t0.Add(30*time.Second)) + tr.Insert([]uint64{9}, "B", t0) + + now := t0.Add(time.Minute) + got := tr.WeightsFor([]string{"A", "B", 
"C"}, now) + Expect(got["A"]).To(BeNumerically("~", tr.Weight("A", now), 1e-12)) + Expect(got["B"]).To(BeNumerically("~", tr.Weight("B", now), 1e-12)) + Expect(got["C"]).To(BeNumerically("==", 0)) + }) + + It("respects TTL expiry and matches Weight at a non-zero age under decay", func() { + tr := radixtree.New[string](radixtree.Options{TTL: time.Minute, HalfLife: 30 * time.Second}) + tr.Insert([]uint64{1}, "A", t0) // will be expired at now + tr.Insert([]uint64{2}, "A", t0.Add(90*time.Second)) // live, aged 30s at now + tr.Insert([]uint64{9}, "B", t0) // expired at now + + now := t0.Add(2 * time.Minute) + got := tr.WeightsFor([]string{"A", "B"}, now) + Expect(got["A"]).To(BeNumerically("~", tr.Weight("A", now), 1e-12)) + Expect(got["A"]).To(BeNumerically("~", 0.5, 0.001)) // single live entry aged one half-life + Expect(got["B"]).To(BeNumerically("==", 0)) + }) + + It("returns an empty map for an empty values slice", func() { + tr := radixtree.New[string](radixtree.Options{TTL: time.Hour}) + tr.Insert([]uint64{1}, "A", t0) + Expect(tr.WeightsFor(nil, t0)).To(BeEmpty()) + Expect(tr.WeightsFor([]string{}, t0)).To(BeEmpty()) + }) + + It("maps a value not present in the tree to 0", func() { + tr := radixtree.New[string](radixtree.Options{TTL: time.Hour}) + tr.Insert([]uint64{1}, "A", t0) + got := tr.WeightsFor([]string{"Z"}, t0) + Expect(got).To(HaveLen(1)) + Expect(got["Z"]).To(BeNumerically("==", 0)) + }) +}) + +var _ = Describe("Remove", func() { + It("drops every entry anchored to a value and prunes", func() { + // Non-overlapping chains so Remove("A") and the survival of B are both + // meaningful: with value-at-every-node, overlapping chains would let the + // later writer own the shared prefix nodes, so A could own nothing and + // the test would be vacuous. 
+ tr := radixtree.New[string](radixtree.Options{TTL: time.Hour}) + tr.Insert([]uint64{1, 2}, "A", t0) + tr.Insert([]uint64{7, 8, 9}, "B", t0) + tr.Remove("A") + _, _, ok := tr.LongestMatch([]uint64{1, 2}, t0) + Expect(ok).To(BeFalse()) // A gone; its branch is fully reclaimed + v, _, ok := tr.LongestMatch([]uint64{7, 8, 9}, t0) + Expect(ok).To(BeTrue()) + Expect(v).To(Equal("B")) // B survives + Expect(tr.Weight("A", t0)).To(BeNumerically("==", 0)) + }) +}) + +var _ = Describe("RemoveFunc", func() { + It("drops every entry matching the predicate, prunes, and keeps the rest", func() { + tr := radixtree.New[string](radixtree.Options{TTL: time.Hour}) + tr.Insert([]uint64{1, 2}, "drop-a", t0) + tr.Insert([]uint64{3, 4}, "drop-b", t0) + tr.Insert([]uint64{7, 8, 9}, "keep", t0) + // Drop everything whose value starts with "drop". + tr.RemoveFunc(func(v string) bool { return len(v) >= 4 && v[:4] == "drop" }) + _, _, ok := tr.LongestMatch([]uint64{1, 2}, t0) + Expect(ok).To(BeFalse()) + _, _, ok = tr.LongestMatch([]uint64{3, 4}, t0) + Expect(ok).To(BeFalse()) + v, _, ok := tr.LongestMatch([]uint64{7, 8, 9}, t0) + Expect(ok).To(BeTrue()) + Expect(v).To(Equal("keep")) + Expect(tr.Len()).To(Equal(3)) // only the 3-node "keep" chain remains + }) + + It("makes Remove a special case of RemoveFunc", func() { + tr := radixtree.New[string](radixtree.Options{TTL: time.Hour}) + tr.Insert([]uint64{1, 2}, "A", t0) + tr.Insert([]uint64{7, 8, 9}, "B", t0) + tr.RemoveFunc(func(v string) bool { return v == "A" }) + _, _, ok := tr.LongestMatch([]uint64{1, 2}, t0) + Expect(ok).To(BeFalse()) + v, _, ok := tr.LongestMatch([]uint64{7, 8, 9}, t0) + Expect(ok).To(BeTrue()) + Expect(v).To(Equal("B")) + }) +}) + +var _ = Describe("TTL boundary", func() { + It("treats age exactly equal to TTL as still live, and one tick past as expired", func() { + tr := radixtree.New[string](radixtree.Options{TTL: time.Minute}) + tr.Insert([]uint64{1, 2}, "A", t0) + + // age == TTL: strict greater-than means this 
is still live. + _, _, ok := tr.LongestMatch([]uint64{1, 2}, t0.Add(time.Minute)) + Expect(ok).To(BeTrue()) + + // one nanosecond past TTL: expired. + _, _, ok = tr.LongestMatch([]uint64{1, 2}, t0.Add(time.Minute+time.Nanosecond)) + Expect(ok).To(BeFalse()) + }) +}) + +var _ = Describe("MaxEntries eviction", func() { + It("drops the least-recently-seen entry when the cap is exceeded", func() { + tr := radixtree.New[string](radixtree.Options{TTL: time.Hour, MaxEntries: 2}) + tr.Insert([]uint64{1}, "A", t0) + tr.Insert([]uint64{2}, "B", t0.Add(time.Second)) + tr.Insert([]uint64{3}, "C", t0.Add(2*time.Second)) + + Expect(tr.Len()).To(Equal(2)) + + // A was the least-recently-seen, so it is the one dropped. + _, _, ok := tr.LongestMatch([]uint64{1}, t0.Add(2*time.Second)) + Expect(ok).To(BeFalse()) + + // B and C survive. + _, _, ok = tr.LongestMatch([]uint64{2}, t0.Add(2*time.Second)) + Expect(ok).To(BeTrue()) + _, _, ok = tr.LongestMatch([]uint64{3}, t0.Add(2*time.Second)) + Expect(ok).To(BeTrue()) + }) + + It("prunes value-less ancestors left behind by an eviction", func() { + // Value-at-every-node: Inserting the deep chain B = [1,2,3] records B at + // {1}, {1,2}, and {1,2,3} (three valued nodes). With the cap at 2, the + // least-recently-seen valued nodes are evicted one per subsequent Insert. + // The two fresh single-element keys (C, D) are newer, so eviction keeps + // peeling B's nodes off until B's entire branch is reclaimed - none of + // its internal nodes may linger and inflate Len past the cap. + tr := radixtree.New[string](radixtree.Options{TTL: time.Hour, MaxEntries: 2}) + tr.Insert([]uint64{1, 2, 3}, "B", t0) + tr.Insert([]uint64{5}, "C", t0.Add(time.Second)) + tr.Insert([]uint64{6}, "D", t0.Add(2*time.Second)) + + Expect(tr.Len()).To(Equal(2)) + // B (oldest) evicted; its deep branch reclaimed. 
+ _, _, ok := tr.LongestMatch([]uint64{1, 2, 3}, t0.Add(2*time.Second)) + Expect(ok).To(BeFalse()) + _, _, ok = tr.LongestMatch([]uint64{1, 2}, t0.Add(2*time.Second)) + Expect(ok).To(BeFalse()) + Expect(tr.Weight("B", t0.Add(2*time.Second))).To(BeNumerically("==", 0)) + }) + + It("reclaims structure so the tree never grows past the cap under churn", func() { + tr := radixtree.New[string](radixtree.Options{TTL: time.Hour, MaxEntries: 2}) + tr.Insert([]uint64{1}, "A", t0) + tr.Insert([]uint64{2}, "B", t0.Add(time.Second)) + Expect(tr.Len()).To(Equal(2)) + + for i := range 10 { + tr.Insert([]uint64{uint64(100 + i)}, "X", t0.Add(time.Duration(i+2)*time.Second)) + Expect(tr.Len()).To(Equal(2)) + } + }) +}) + +var _ = Describe("Empty key", func() { + It("LongestMatch on an empty key returns ok=false", func() { + tr := radixtree.New[string](radixtree.Options{TTL: time.Hour}) + tr.Insert([]uint64{1, 2}, "A", t0) + _, depth, ok := tr.LongestMatch([]uint64{}, t0) + Expect(ok).To(BeFalse()) + Expect(depth).To(Equal(0)) + }) + + It("Insert with an empty key is a no-op that creates no root value", func() { + tr := radixtree.New[string](radixtree.Options{TTL: time.Hour}) + Expect(func() { tr.Insert([]uint64{}, "A", t0) }).NotTo(Panic()) + Expect(tr.Len()).To(Equal(0)) + _, _, ok := tr.LongestMatch([]uint64{}, t0) + Expect(ok).To(BeFalse()) + Expect(tr.Weight("A", t0)).To(BeNumerically("==", 0)) + }) +}) + +var _ = Describe("Concurrent access", func() { + It("is race-free under parallel insert/match/weight", func() { + tr := radixtree.New[string](radixtree.Options{TTL: time.Hour}) + done := make(chan struct{}) + for g := range 8 { + go func(g int) { + defer GinkgoRecover() + for i := range 1000 { + tr.Insert([]uint64{uint64(g), uint64(i % 10)}, "n", t0) + tr.LongestMatch([]uint64{uint64(g), 1}, t0) + tr.Weight("n", t0) + } + done <- struct{}{} + }(g) + } + for range 8 { + <-done + } + }) +}) diff --git a/tests/e2e/distributed/model_routing_test.go 
b/tests/e2e/distributed/model_routing_test.go index d8e3127f0..cb64e4622 100644 --- a/tests/e2e/distributed/model_routing_test.go +++ b/tests/e2e/distributed/model_routing_test.go @@ -63,7 +63,7 @@ var _ = Describe("Model Routing", Label("Distributed"), func() { Expect(models[0].InFlight).To(Equal(2)) // FindAndLockNodeWithModel should return this node and atomically increment in-flight - foundNode, foundModel, err := registry.FindAndLockNodeWithModel(context.Background(), "llama3", nil) + foundNode, foundModel, err := registry.FindAndLockNodeWithModel(context.Background(), "llama3", nil, nil) Expect(err).ToNot(HaveOccurred()) Expect(foundNode.ID).To(Equal(node.ID)) Expect(foundModel.ModelName).To(Equal("llama3")) diff --git a/tests/e2e/distributed/prefix_cache_routing_test.go b/tests/e2e/distributed/prefix_cache_routing_test.go new file mode 100644 index 000000000..ad0cb709d --- /dev/null +++ b/tests/e2e/distributed/prefix_cache_routing_test.go @@ -0,0 +1,234 @@ +package distributed_test + +import ( + "context" + "time" + + "github.com/mudler/LocalAI/core/services/nodes" + "github.com/mudler/LocalAI/core/services/nodes/prefixcache" + "github.com/mudler/LocalAI/pkg/distributedhdr" + grpcPkg "github.com/mudler/LocalAI/pkg/grpc" + pb "github.com/mudler/LocalAI/pkg/grpc/proto" + + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" + + ggrpc "google.golang.org/grpc" + + pgdriver "gorm.io/driver/postgres" + gormDB "gorm.io/gorm" + "gorm.io/gorm/logger" +) + +// prefixStubBackend implements grpc.Backend with a canned-success HealthCheck +// and LoadModel so SmartRouter.probeHealth passes and any cold load returns +// success — no real inference happens. Mirrors the stubBackend pattern used by +// the SmartRouter unit tests in core/services/nodes/router_test.go, reproduced +// here because that fake lives in the internal (unexported) nodes package. 
+type prefixStubBackend struct {
+	grpcPkg.Backend // embedded interface is left nil by the specs: it satisfies grpcPkg.Backend, but any method not overridden below panics if called
+
+	healthResult bool // value reported by HealthCheck
+}
+
+func (f *prefixStubBackend) HealthCheck(_ context.Context) (bool, error) {
+	return f.healthResult, nil // canned answer; never returns an error
+}
+
+func (f *prefixStubBackend) LoadModel(_ context.Context, _ *pb.ModelOptions, _ ...ggrpc.CallOption) (*pb.Result, error) {
+	return &pb.Result{Success: true}, nil // every load reports success without doing any work
+}
+
+func (f *prefixStubBackend) IsBusy() bool { return false } // the stub always reports idle
+
+// prefixStubClientFactory hands the same fake backend to every NewClient call,
+// so the SmartRouter never opens a real gRPC connection during routing.
+type prefixStubClientFactory struct {
+	client *prefixStubBackend
+}
+
+func (f *prefixStubClientFactory) NewClient(_ string, _ bool) grpcPkg.Backend {
+	return f.client // both arguments are ignored; every caller shares this one stub
+}
+
+var _ = Describe("Prefix-cache aware routing", Label("Distributed"), func() {
+	const model = "model" // model name marked loaded on both nodes in BeforeEach
+
+	var (
+		infra    *TestInfra          // per-spec infrastructure from SetupInfra; supplies PGURL
+		db       *gormDB.DB          // connection opened against infra.PGURL
+		registry *nodes.NodeRegistry // registry built on db
+		router   *nodes.SmartRouter  // router under test, rebuilt in every BeforeEach
+		idx      *prefixcache.Index  // index handed to the router as PrefixProvider and also queried directly
+
+		nodeXID string // IDs read back from the two nodes after registry.Register in BeforeEach
+		nodeYID string
+
+		chainA         = []uint64{1, 2, 3, 4, 5} // conversation A
+		chainShared    = []uint64{1, 2, 3, 9, 9} // shares leading prefix [1,2,3] with A
+		chainUnrelated = []uint64{7, 8, 9}       // no shared prefix with A
+	)
+
+	// routeAndSettle drives one request through the router for the given prefix
+	// chain and immediately settles the in-flight reservation the way a real
+	// inference completion would (Release closes the client; the DecrementInFlight
+	// emulates the OnFirstComplete callback that fires after the first inference).
+	// Settling keeps both nodes balanced at in_flight=0 so the prefix-cache load
+	// guard never falsely forces a request off its warm node between steps.
+	routeAndSettle := func(chain []uint64) string {
+		GinkgoHelper() // report assertion failures at the calling spec's line, not inside this helper
+		ctx := distributedhdr.WithPrefixChain(context.Background(), chain)
+		result, err := router.Route(ctx, model, model, "llama-cpp",
+			&pb.ModelOptions{ModelFile: model}, false)
+		Expect(err).ToNot(HaveOccurred())
+		Expect(result).ToNot(BeNil())
+		Expect(result.Node).ToNot(BeNil())
+		nodeID := result.Node.ID // read the ID before the result is released
+		result.Release()
+		Expect(registry.DecrementInFlight(context.Background(), nodeID, model, 0)).To(Succeed())
+		return nodeID
+	}
+
+	BeforeEach(func() {
+		infra = SetupInfra("localai_prefix_cache_routing_test")
+
+		var err error
+		db, err = gormDB.Open(pgdriver.Open(infra.PGURL), &gormDB.Config{
+			Logger: logger.Default.LogMode(logger.Silent),
+		}) // NOTE(review): nothing in this file closes db; presumably SetupInfra owns teardown (confirm)
+		Expect(err).ToNot(HaveOccurred())
+
+		registry, err = nodes.NewNodeRegistry(db)
+		Expect(err).ToNot(HaveOccurred())
+
+		// The prefix-cache index is the real radix-tree provider. Keep a handle so
+		// the specs can assert Decide() directly in addition to observing Route().
+		idx = prefixcache.NewIndex(prefixcache.DefaultConfig())
+
+		// Wire the registry chokepoint hook ourselves. In production distributed.go
+		// wires this; a bare SmartRouter test must register it so removal-path
+		// invalidation is exercised end to end. A negative replica index means
+		// "all replicas of the node" (InvalidateNode); otherwise drop the exact
+		// replica.
+		registry.SetReplicaRemovedHook(func(modelName, nodeID string, replica int) {
+			if replica < 0 {
+				idx.InvalidateNode(modelName, nodeID)
+			} else {
+				idx.Invalidate(modelName, prefixcache.ReplicaKey{NodeID: nodeID, Replica: replica})
+			}
+		})
+
+		// Register TWO healthy nodes and mark the model loaded on both (replica 0).
+		nodeX := &nodes.BackendNode{Name: "node-x", Address: "127.0.0.1:50051"}
+		nodeY := &nodes.BackendNode{Name: "node-y", Address: "127.0.0.1:50052"}
+		Expect(registry.Register(context.Background(), nodeX, true)).To(Succeed())
+		Expect(registry.Register(context.Background(), nodeY, true)).To(Succeed())
+		nodeXID = nodeX.ID
+		nodeYID = nodeY.ID
+		Expect(registry.SetNodeModel(context.Background(), nodeXID, model, 0, "loaded", "", 0)).To(Succeed())
+		Expect(registry.SetNodeModel(context.Background(), nodeYID, model, 0, "loaded", "", 0)).To(Succeed())
+
+		factory := &prefixStubClientFactory{client: &prefixStubBackend{healthResult: true}}
+		router = nodes.NewSmartRouter(registry, nodes.SmartRouterOptions{
+			ClientFactory:  factory,
+			PrefixProvider: idx,
+			PrefixConfig:   prefixcache.DefaultConfig(), // same defaults idx was built with
+			DB:             db,
+		})
+	})
+
+	It("locks affinity, honors shared prefixes, isolates unrelated chains, and re-homes on failover", func() {
+		now := time.Now() // used only for the first cold-miss check; later checks take a fresh time.Now()
+		// Both nodes host replica 0 of the model.
+		keys := []prefixcache.ReplicaKey{{NodeID: nodeXID, Replica: 0}, {NodeID: nodeYID, Replica: 0}}
+
+		// --- Step 1: cold miss + observe -------------------------------------
+		// chainA's prefix has never been seen, so there is no hot match yet; the
+		// request cold-places on some loaded node X and the assignment is recorded.
+		Expect(idx.Decide(model, chainA, keys, now).HasHot).To(BeFalse(),
+			"step 1: chainA must be a cold miss (no prior affinity)")
+		placedNode := routeAndSettle(chainA)
+		Expect(placedNode).To(Or(Equal(nodeXID), Equal(nodeYID)))
+		// From here on, "X" is whichever node served chainA first.
+		nodeX := placedNode
+		var nodeY string
+		if nodeX == nodeXID {
+			nodeY = nodeYID
+		} else {
+			nodeY = nodeXID
+		}
+		hotX := prefixcache.ReplicaKey{NodeID: nodeX, Replica: 0}
+		Expect(idx.Decide(model, chainA, keys, time.Now()).Hot).To(Equal(hotX),
+			"step 1: chainA must now be recorded against the replica that served it")
+
+		// --- Step 2: hot-match affinity --------------------------------------
+		// The SAME chain routes back to X.
+		Expect(routeAndSettle(chainA)).To(Equal(nodeX),
+			"step 2: a repeat of chainA must return to its warm node X")
+
+		// --- Step 3: shared-prefix match (the regression we fixed) -----------
+		// A DIFFERENT chain that shares the leading prefix [1,2,3] with X's chain
+		// but diverges at the tail still matches the shared head and routes to X.
+		// Before the radix-tree fix this fell through to a cold placement.
+		Expect(idx.Decide(model, chainShared, keys, time.Now()).Hot).To(Equal(hotX),
+			"step 3: chainShared must hot-match X on the shared prefix")
+		Expect(routeAndSettle(chainShared)).To(Equal(nodeX),
+			"step 3: chainShared must route to X via the shared-prefix match")
+
+		// --- Step 4: negative control ----------------------------------------
+		// A completely unrelated chain shares no prefix with X's chain, so it must
+		// NOT hot-match X's affinity. (Cold placement may still pick X or Y by
+		// load/cacheWeight, but it must not be a false hot match.) Asserting the
+		// provider decision directly is the robust check.
+		Expect(idx.Decide(model, chainUnrelated, keys, time.Now()).HasHot).To(BeFalse(),
+			"step 4: chainUnrelated must be a cold miss, not a false hot match on X")
+
+		// --- Step 5: failover + invalidation ---------------------------------
+		// Remove node X's replica of the model. This fires the registry chokepoint
+		// hook, which invalidates the prefix-cache entry for X. A request for X's
+		// chain must then fail over to the surviving node Y, and the prefix entry
+		// must no longer pin to X (it re-homes to Y on the next observe).
+		Expect(registry.RemoveAllNodeModelReplicas(context.Background(), nodeX, model)).To(Succeed())
+
+		yKeys := []prefixcache.ReplicaKey{{NodeID: nodeY, Replica: 0}}
+		// The chokepoint hook dropped X from the index immediately.
+		Expect(idx.Decide(model, chainA, yKeys, time.Now()).Hot).ToNot(Equal(hotX),
+			"step 5: after X's replica is removed, chainA must no longer pin to X")
+
+		// Route(chainA): only Y still hosts the model, so it fails over to Y.
+		Expect(routeAndSettle(chainA)).To(Equal(nodeY),
+			"step 5: chainA must fail over to the surviving node Y")
+
+		// And the entry has re-homed: chainA now hot-matches Y, never X.
+		reHomed := idx.Decide(model, chainA, yKeys, time.Now())
+		hotY := prefixcache.ReplicaKey{NodeID: nodeY, Replica: 0}
+		Expect(reHomed.Hot).ToNot(Equal(hotX),
+			"step 5: chainA must not re-home to the removed node X")
+		Expect(reHomed.Hot).To(Equal(hotY),
+			"step 5: chainA must re-home to the surviving node Y")
+	})
+
+	It("tracks affinity per replica when ONE node hosts TWO replicas of the model", func() {
+		// This is the bug the replica-granular change fixes: two replicas of the
+		// same model on the SAME node are distinct KV caches. A prefix observed
+		// on replica (node,0) must NOT be reported as hot on the sibling replica
+		// (node,1) of the same node.
+		const multiNodeModel = "multi-replica-model"
+		multiNode := &nodes.BackendNode{Name: "node-multi", Address: "127.0.0.1:50060", MaxReplicasPerModel: 2}
+		Expect(registry.Register(context.Background(), multiNode, true)).To(Succeed())
+		Expect(registry.SetNodeModel(context.Background(), multiNode.ID, multiNodeModel, 0, "loaded", "addr0", 0)).To(Succeed())
+		Expect(registry.SetNodeModel(context.Background(), multiNode.ID, multiNodeModel, 1, "loaded", "addr1", 0)).To(Succeed())
+
+		chain := []uint64{42, 43, 44}
+		key0 := prefixcache.ReplicaKey{NodeID: multiNode.ID, Replica: 0}
+		key1 := prefixcache.ReplicaKey{NodeID: multiNode.ID, Replica: 1}
+
+		// Observe the chain on replica 0 only.
+		idx.Observe(multiNodeModel, chain, key0, time.Now()) // written straight to the index; the router is not involved in this spec
+
+		d := idx.Decide(multiNodeModel, chain, []prefixcache.ReplicaKey{key0, key1}, time.Now())
+		Expect(d.HasHot).To(BeTrue())
+		Expect(d.Hot).To(Equal(key0),
+			"the prefix was served by replica 0; the SAME-node sibling replica 1 must NOT be chosen")
+	})
+})