diff --git a/core/services/nodes/inflight.go b/core/services/nodes/inflight.go index 1be7518f3..ad866aed2 100644 --- a/core/services/nodes/inflight.go +++ b/core/services/nodes/inflight.go @@ -2,6 +2,7 @@ package nodes import ( "context" + "sync" "time" "github.com/mudler/LocalAI/pkg/grpc" @@ -18,6 +19,9 @@ type InFlightTrackingClient struct { registry InFlightTracker nodeID string modelName string + + firstOnce sync.Once // guards onFirstComplete + onFirstComplete func() // called once after the first tracked inference call completes } // NewInFlightTrackingClient wraps a gRPC backend client with in-flight tracking. @@ -30,6 +34,14 @@ func NewInFlightTrackingClient(inner grpc.Backend, registry InFlightTracker, nod } } +// OnFirstComplete registers a callback that fires once after the first tracked +// inference call completes. This is used to release the initial in-flight +// reservation (set during model load) after the triggering request finishes, +// so that in-flight returns to 0 when the model is idle. +func (c *InFlightTrackingClient) OnFirstComplete(fn func()) { + c.onFirstComplete = fn +} + func (c *InFlightTrackingClient) track(ctx context.Context) func() { if err := c.registry.IncrementInFlight(ctx, c.nodeID, c.modelName); err != nil { xlog.Warn("Failed to increment in-flight counter", "node", c.nodeID, "model", c.modelName, "error", err) @@ -39,6 +51,10 @@ func (c *InFlightTrackingClient) track(ctx context.Context) func() { decCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() c.registry.DecrementInFlight(decCtx, c.nodeID, c.modelName) + // Release the initial reservation after the first inference call completes + if c.onFirstComplete != nil { + c.firstOnce.Do(c.onFirstComplete) + } } } diff --git a/core/services/nodes/model_router_test.go b/core/services/nodes/model_router_test.go index ce0165e60..946dccc23 100644 --- a/core/services/nodes/model_router_test.go +++ b/core/services/nodes/model_router_test.go @@ -160,13 +160,19 @@ var _ = Describe("ModelRouterAdapter", func() { adapter.mu.Unlock() Expect(hasRelease).To(BeTrue()) - // Verify calling ReleaseModel triggers the release (which decrements in-flight) - adapter.ReleaseModel("test-model") - + // The initial in-flight reservation is released via OnFirstComplete after + // the first inference call, not during ReleaseModel. ReleaseModel only + // closes the client. fakeReg.mu.Lock() - count := fakeReg.decrementCalled["node-1:test-model"] + countBeforeRelease := fakeReg.decrementCalled["node-1:test-model"] fakeReg.mu.Unlock() - Expect(count).To(Equal(1)) + Expect(countBeforeRelease).To(Equal(0)) + + adapter.ReleaseModel("test-model") + fakeReg.mu.Lock() + countAfterRelease := fakeReg.decrementCalled["node-1:test-model"] + fakeReg.mu.Unlock() + Expect(countAfterRelease).To(Equal(0)) }) }) }) diff --git a/core/services/nodes/router.go b/core/services/nodes/router.go index 7d77b62bc..8e7cc0359 100644 --- a/core/services/nodes/router.go +++ b/core/services/nodes/router.go @@ -130,15 +130,20 @@ func (r *SmartRouter) Route(ctx context.Context, modelID, modelName, backendType "node", node.Name, "model", trackingKey) // Fall through to step 2 (scheduleNewModel) } else { - // Node is alive — use raw client; FindAndLockNodeWithModel already incremented in-flight, - // and Release decrements it. No InFlightTrackingClient to avoid double-counting. + // Node is alive — FindAndLockNodeWithModel already incremented in-flight as a + // reservation. InFlightTrackingClient handles per-inference tracking, and its + // onFirstComplete callback releases the reservation after the first inference + // call finishes, so in-flight returns to 0 when idle. r.registry.TouchNodeModel(ctx, node.ID, trackingKey) grpcClient := r.buildClientForAddr(node, modelAddr, parallel) + tracked := NewInFlightTrackingClient(grpcClient, r.registry, node.ID, trackingKey) + tracked.OnFirstComplete(func() { + r.registry.DecrementInFlight(context.Background(), node.ID, trackingKey) + }) return &RouteResult{ Node: node, - Client: grpcClient, + Client: tracked, Release: func() { - r.registry.DecrementInFlight(context.Background(), node.ID, trackingKey) closeClient(grpcClient) }, }, nil @@ -171,14 +176,18 @@ func (r *SmartRouter) Route(ctx context.Context, modelID, modelName, backendType "node", node.Name, "model", trackingKey) // Fall through to scheduling below } else { - // Model loaded while we waited — reuse it; no InFlightTrackingClient to avoid double-counting + // Model loaded while we waited — FindAndLockNodeWithModel already incremented + // in-flight as a reservation. Release it after the first inference completes. r.registry.TouchNodeModel(ctx, node.ID, trackingKey) grpcClient := r.buildClientForAddr(node, modelAddr, parallel) + tracked := NewInFlightTrackingClient(grpcClient, r.registry, node.ID, trackingKey) + tracked.OnFirstComplete(func() { + r.registry.DecrementInFlight(context.Background(), node.ID, trackingKey) + }) return &RouteResult{ Node: node, - Client: grpcClient, + Client: tracked, Release: func() { - r.registry.DecrementInFlight(context.Background(), node.ID, trackingKey) closeClient(grpcClient) }, }, nil @@ -225,11 +234,13 @@ func (r *SmartRouter) Route(ctx context.Context, modelID, modelName, backendType } tracked := NewInFlightTrackingClient(client, r.registry, node.ID, trackingKey) + tracked.OnFirstComplete(func() { + r.registry.DecrementInFlight(context.Background(), node.ID, trackingKey) + }) return &RouteResult{ Node: node, Client: tracked, Release: func() { - r.registry.DecrementInFlight(context.Background(), node.ID, trackingKey) closeClient(client) }, }, nil diff --git a/core/services/nodes/router_test.go b/core/services/nodes/router_test.go index 8c53531fe..7d3668337 100644 --- a/core/services/nodes/router_test.go +++ b/core/services/nodes/router_test.go @@ -307,9 +307,12 @@ var _ = Describe("SmartRouter", func() { // TouchNodeModel should have been called Expect(reg.touchCalls).To(ContainElement("n1:my-model")) - // Call release — should decrement in-flight + // The initial in-flight reservation from FindAndLockNodeWithModel is released + // after the first inference call completes via OnFirstComplete callback. + // Release only closes the client. result.Release() - Expect(reg.decrementCalls).To(ContainElement("n1:my-model")) + // No decrement on Release — it happens via OnFirstComplete after first Predict + Expect(reg.decrementCalls).To(BeEmpty()) }) }) diff --git a/tests/e2e/distributed/router_tracking_test.go b/tests/e2e/distributed/router_tracking_test.go index 30f09b0e7..1643bbc70 100644 --- a/tests/e2e/distributed/router_tracking_test.go +++ b/tests/e2e/distributed/router_tracking_test.go @@ -139,7 +139,7 @@ var _ = Describe("SmartRouter trackingKey", Label("Distributed"), func() { Expect(err).ToNot(HaveOccurred()) defer result.Release() - // Read the baseline in-flight count (Route sets initialInFlight=1) + // Read the baseline in-flight count (Route sets initialInFlight=1, decremented after first inference) models, err := registry.GetNodeModels(context.Background(), nodeID) Expect(err).ToNot(HaveOccurred()) var baseline int