From 4a557f1b2bd5979ef6e294c5967defa990d1b0b3 Mon Sep 17 00:00:00 2001 From: Ettore Di Giacinto Date: Wed, 20 May 2026 22:52:14 +0000 Subject: [PATCH] feat(usage): UsageMiddleware records source + snapshots key name Reads auth_source and auth_apikey from the Echo context (set by auth.Middleware in the previous task). Snapshots UserAPIKey.ID and Name onto each row so revoked keys remain readable in history. Falls back to source=web when no auth_source is set (auth disabled or unrecognised path). Refs: #9862 Signed-off-by: Ettore Di Giacinto --- core/http/middleware/usage.go | 14 ++++ core/http/middleware/usage_test.go | 116 +++++++++++++++++++++++++++++ 2 files changed, 130 insertions(+) create mode 100644 core/http/middleware/usage_test.go diff --git a/core/http/middleware/usage.go b/core/http/middleware/usage.go index b82c1ee3f..59a86ffc5 100644 --- a/core/http/middleware/usage.go +++ b/core/http/middleware/usage.go @@ -149,9 +149,17 @@ func UsageMiddleware(db *gorm.DB) echo.MiddlewareFunc { return handlerErr } + source := auth.GetSource(c) + if source == "" { + // Auth disabled or unrecognised path: classify as web so the row is still + // bucketable rather than silently dropped from per-source aggregates. + source = auth.UsageSourceWeb + } + record := &auth.UsageRecord{ UserID: user.ID, UserName: user.Name, + Source: source, Model: resp.Model, Endpoint: c.Request().URL.Path, PromptTokens: resp.Usage.PromptTokens, @@ -161,6 +169,12 @@ func UsageMiddleware(db *gorm.DB) echo.MiddlewareFunc { CreatedAt: startTime, } + if key := auth.GetAPIKey(c); key != nil { + id := key.ID + record.APIKeyID = &id + record.APIKeyName = key.Name + } + batcher.add(record) return handlerErr diff --git a/core/http/middleware/usage_test.go b/core/http/middleware/usage_test.go new file mode 100644 index 000000000..10ab8e92e --- /dev/null +++ b/core/http/middleware/usage_test.go @@ -0,0 +1,116 @@ +//go:build auth + +package middleware_test + +import ( + "bytes" + "encoding/json" + "net/http" + "net/http/httptest" + "time" + + "github.com/labstack/echo/v4" + "github.com/mudler/LocalAI/core/http/auth" + "github.com/mudler/LocalAI/core/http/middleware" + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" + "gorm.io/gorm" +) + +// testAuthDB returns a fresh in-memory SQLite auth DB. +func testAuthDB() *gorm.DB { + db, err := auth.InitDB(":memory:") + if err != nil { + panic(err) + } + return db +} + +var _ = Describe("UsageMiddleware", func() { + var ( + e *echo.Echo + db *gorm.DB + ) + + BeforeEach(func() { + db = testAuthDB() + e = echo.New() + middleware.InitUsageRecorder(db) + }) + + okHandler := func(c echo.Context) error { + body, _ := json.Marshal(map[string]any{ + "model": "gpt-4", + "usage": map[string]int{ + "prompt_tokens": 10, "completion_tokens": 5, "total_tokens": 15, + }, + }) + c.Response().Header().Set("Content-Type", "application/json") + c.Response().WriteHeader(http.StatusOK) + _, _ = c.Response().Write(body) + return nil + } + + // The batcher flushes every 5s. For tests we wait one tick past that. + flush := func() { time.Sleep(6 * time.Second) } + + It("records source=web when auth_source is web", func() { + e.POST("/v1/chat/completions", okHandler, func(next echo.HandlerFunc) echo.HandlerFunc { + return func(c echo.Context) error { + c.Set("auth_user", &auth.User{ID: "alice", Name: "Alice"}) + c.Set("auth_source", auth.UsageSourceWeb) + return next(c) + } + }, middleware.UsageMiddleware(db)) + + req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewReader([]byte(`{}`))) + e.ServeHTTP(httptest.NewRecorder(), req) + flush() + + var rec auth.UsageRecord + Expect(db.Where("user_id = ?", "alice").First(&rec).Error).To(Succeed()) + Expect(rec.Source).To(Equal(auth.UsageSourceWeb)) + Expect(rec.APIKeyID).To(BeNil()) + Expect(rec.APIKeyName).To(BeEmpty()) + }) + + It("records source=apikey with snapshotted name when auth_apikey is set", func() { + e.POST("/v1/chat/completions", okHandler, func(next echo.HandlerFunc) echo.HandlerFunc { + return func(c echo.Context) error { + c.Set("auth_user", &auth.User{ID: "alice", Name: "Alice"}) + c.Set("auth_source", auth.UsageSourceAPIKey) + c.Set("auth_apikey", &auth.UserAPIKey{ID: "key-1", Name: "ci-runner"}) + return next(c) + } + }, middleware.UsageMiddleware(db)) + + req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewReader([]byte(`{}`))) + e.ServeHTTP(httptest.NewRecorder(), req) + flush() + + var rec auth.UsageRecord + Expect(db.Where("user_id = ?", "alice").First(&rec).Error).To(Succeed()) + Expect(rec.Source).To(Equal(auth.UsageSourceAPIKey)) + Expect(rec.APIKeyID).ToNot(BeNil()) + Expect(*rec.APIKeyID).To(Equal("key-1")) + Expect(rec.APIKeyName).To(Equal("ci-runner")) + }) + + It("falls back to source=web when auth_source is empty", func() { + e.POST("/v1/chat/completions", okHandler, func(next echo.HandlerFunc) echo.HandlerFunc { + return func(c echo.Context) error { + c.Set("auth_user", &auth.User{ID: "alice", Name: "Alice"}) + // no auth_source set + return next(c) + } + }, middleware.UsageMiddleware(db)) + + req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewReader([]byte(`{}`))) + e.ServeHTTP(httptest.NewRecorder(), req) + flush() + + var rec auth.UsageRecord + Expect(db.Where("user_id = ?", "alice").First(&rec).Error).To(Succeed()) + Expect(rec.Source).To(Equal(auth.UsageSourceWeb)) + }) +})