diff --git a/core/cli/federated.go b/core/cli/federated.go index 10935dc0a..a19e11df0 100644 --- a/core/cli/federated.go +++ b/core/cli/federated.go @@ -15,12 +15,13 @@ type FederatedCLI struct { Peer2PeerNetworkID string `env:"LOCALAI_P2P_NETWORK_ID,P2P_NETWORK_ID" help:"Network ID for P2P mode, can be set arbitrarly by the user for grouping a set of instances." group:"p2p"` TargetWorker string `env:"LOCALAI_TARGET_WORKER,TARGET_WORKER" help:"Target worker to run the federated server on" group:"p2p"` UploadLimit int `env:"LOCALAI_UPLOAD_LIMIT,UPLOAD_LIMIT" default:"15" help:"Default upload-size limit in megabytes" group:"api"` + AffinitySync bool `env:"LOCALAI_FEDERATED_AFFINITY_SYNC,FEDERATED_AFFINITY_SYNC" default:"false" help:"Broadcast prefix-cache affinity observations to other federation servers over the p2p generic channel (enable on every federation server that should cohere)" group:"p2p"` } func (f *FederatedCLI) Run(ctx *cliContext.Context) error { warnDeprecatedFlags() - fs := p2p.NewFederatedServer(f.Address, p2p.NetworkID(f.Peer2PeerNetworkID, p2p.FederatedID), f.Peer2PeerToken, !f.RandomWorker, f.TargetWorker, int64(f.UploadLimit)*1024*1024) + fs := p2p.NewFederatedServer(f.Address, p2p.NetworkID(f.Peer2PeerNetworkID, p2p.FederatedID), f.Peer2PeerToken, !f.RandomWorker, f.TargetWorker, int64(f.UploadLimit)*1024*1024, f.AffinitySync) c, cancel := context.WithCancel(context.Background()) diff --git a/core/p2p/federated.go b/core/p2p/federated.go index d12ad7829..3a5853935 100644 --- a/core/p2p/federated.go +++ b/core/p2p/federated.go @@ -1,6 +1,7 @@ package p2p import ( + "context" "encoding/json" "fmt" "math/rand/v2" @@ -33,20 +34,25 @@ type FederatedServer struct { prefixCfg prefixcache.Config prefixIndex *prefixcache.Index prefixSync *prefixcache.Sync + prefixProvider prefixcache.Provider // Index (sync off) or Sync (sync on) + syncAffinity bool } -func NewFederatedServer(listenAddr, service, p2pToken string, loadBalanced bool, workerTarget string, bodyLimit int64) *FederatedServer { +func NewFederatedServer(listenAddr, service, p2pToken string, loadBalanced bool, workerTarget string, bodyLimit int64, syncAffinity bool) *FederatedServer { cfg := prefixcache.DefaultConfig() + idx := prefixcache.NewIndex(cfg) return &FederatedServer{ - listenAddr: listenAddr, - service: service, - p2ptoken: p2pToken, - requestTable: map[string]int{}, - loadBalanced: loadBalanced, - workerTarget: workerTarget, - bodyLimit: bodyLimit, - prefixCfg: cfg, - prefixIndex: prefixcache.NewIndex(cfg), + listenAddr: listenAddr, + service: service, + p2ptoken: p2pToken, + requestTable: map[string]int{}, + loadBalanced: loadBalanced, + workerTarget: workerTarget, + bodyLimit: bodyLimit, + prefixCfg: cfg, + prefixIndex: idx, + prefixProvider: idx, + syncAffinity: syncAffinity, } } @@ -151,7 +157,7 @@ func extractModel(queryModel string, body []byte) string { // chain, or "" when there is no match strong enough among the candidates. It // reuses prefixcache's per-model radix-tree Decide; the final load-guarded pick // is done by clusterrouting.PickWithAffinity so the VRAM tier is preserved. -func affinityPreferred(idx *prefixcache.Index, model string, chain []uint64, candidates []clusterrouting.ReplicaCandidate, cfg prefixcache.Config, now time.Time) string { +func affinityPreferred(idx prefixcache.Provider, model string, chain []uint64, candidates []clusterrouting.ReplicaCandidate, cfg prefixcache.Config, now time.Time) string { if idx == nil || len(chain) == 0 || len(candidates) == 0 { return "" } @@ -186,9 +192,9 @@ func (fs *FederatedServer) selectPeer(model string, body []byte, now time.Time) } var chain []uint64 preferred := "" - if fs.prefixIndex != nil && model != "" && len(body) > 0 { + if fs.prefixProvider != nil && model != "" && len(body) > 0 { chain = prefixcache.ExtractChain(model, string(body), fs.prefixCfg) - preferred = affinityPreferred(fs.prefixIndex, model, chain, candidates, fs.prefixCfg, now) + preferred = affinityPreferred(fs.prefixProvider, model, chain, candidates, fs.prefixCfg, now) } best := clusterrouting.PickWithAffinity(candidates, preferred, fs.prefixCfg.BalanceAbsThreshold) if best == nil { @@ -200,10 +206,29 @@ func (fs *FederatedServer) selectPeer(model string, body []byte, now time.Time) // observeServed records that peerID served the given chain for model, so the // next request sharing that prefix is routed back to the same warm peer. func (fs *FederatedServer) observeServed(model string, chain []uint64, peerID string, now time.Time) { - if fs.prefixIndex == nil || len(chain) == 0 || peerID == "" || model == "" { + if fs.prefixProvider == nil || len(chain) == 0 || peerID == "" || model == "" { return } - fs.prefixIndex.Observe(model, chain, prefixcache.ReplicaKey{NodeID: peerID}, now) + fs.prefixProvider.Observe(model, chain, prefixcache.ReplicaKey{NodeID: peerID}, now) +} + +// evictLoop periodically sweeps expired affinity entries so the in-memory tree +// does not grow unbounded. Runs for the lifetime of the proxy. +func (fs *FederatedServer) evictLoop(ctx context.Context) { + interval := fs.prefixCfg.TTL / 2 + if interval <= 0 { + interval = time.Minute + } + t := time.NewTicker(interval) + defer t.Stop() + for { + select { + case <-ctx.Done(): + return + case now := <-t.C: + fs.prefixProvider.Evict(now) + } + } } func (fs *FederatedServer) RecordRequest(nodeID string) { diff --git a/core/p2p/federated_server.go b/core/p2p/federated_server.go index aaaa8245a..46ba5cf15 100644 --- a/core/p2p/federated_server.go +++ b/core/p2p/federated_server.go @@ -13,6 +13,7 @@ import ( "time" "github.com/mudler/LocalAI/core/schema" + "github.com/mudler/LocalAI/core/services/nodes/prefixcache" "github.com/mudler/edgevpn/pkg/node" "github.com/mudler/xlog" ) @@ -62,10 +63,19 @@ func isWebsocketUpgrade(req *http.Request) bool { } func (f *FederatedServer) Start(ctx context.Context) error { - n, err := NewNode(f.p2ptoken) + var extraOpts []node.Option + if f.syncAffinity { + extraOpts = append(extraOpts, node.EnableGenericHub, node.GenericChannelHandlers(f.affinityHandler())) + } + n, err := NewNode(f.p2ptoken, extraOpts...) if err != nil { return fmt.Errorf("creating a new node: %w", err) } + if f.syncAffinity { + f.prefixSync = prefixcache.NewSync(f.prefixIndex, &genericChannelPublisher{node: n}) + f.prefixProvider = f.prefixSync + xlog.Info("Federation affinity sync enabled (generic channel)") + } err = n.Start(ctx) if err != nil { return fmt.Errorf("creating a new node: %w", err) @@ -77,6 +87,8 @@ func (f *FederatedServer) Start(ctx context.Context) error { return err } + go f.evictLoop(ctx) + return f.proxy(ctx, n) } diff --git a/core/p2p/federated_test.go b/core/p2p/federated_test.go index 817a79564..407058d9c 100644 --- a/core/p2p/federated_test.go +++ b/core/p2p/federated_test.go @@ -179,3 +179,18 @@ var _ = Describe("L7 request handling", func() { Expect(isWebsocketUpgrade(req)).To(BeFalse()) }) }) + +var _ = Describe("affinityPreferred with a sync provider", func() { + ref := time.Date(2026, 1, 1, 0, 0, 0, 0, time.UTC) + + It("returns the warm peer when the provider is a Sync wrapping the index", func() { + cfg := prefixcache.DefaultConfig() + idx := prefixcache.NewIndex(cfg) + sync := prefixcache.NewSync(idx, nil) + chain := prefixcache.ExtractChain("m1", `{"model":"m1","messages":[{"role":"system","content":"a long shared system prompt for affinity"}]}`, cfg) + sync.Observe("m1", chain, prefixcache.ReplicaKey{NodeID: "warm"}, ref) + + cands := []clusterrouting.ReplicaCandidate{{NodeID: "warm"}, {NodeID: "cold"}} + Expect(affinityPreferred(sync, "m1", chain, cands, cfg, ref)).To(Equal("warm")) + }) +}) diff --git a/core/p2p/p2p.go b/core/p2p/p2p.go index f7a0e0a26..9e42b30e2 100644 --- a/core/p2p/p2p.go +++ b/core/p2p/p2p.go @@ -409,11 +409,12 @@ func ExposeService(ctx context.Context, host, port, token, servicesID string, mo return n, err } -func NewNode(token string) (*node.Node, error) { +func NewNode(token string, extraOpts ...node.Option) (*node.Node, error) { nodeOpts, err := newNodeOpts(token) if err != nil { return nil, err } + nodeOpts = append(nodeOpts, extraOpts...) n, err := node.New(nodeOpts...) if err != nil {