Files
LocalAI/core/http/middleware/trace.go
Richard Palethorpe 99b5c5f156 feat(api): Allow tracing of requests and responses (#7609)
* feat(api): Allow tracing of requests and responses

Signed-off-by: Richard Palethorpe <io@richiejp.com>

* feat(traces): Add traces UI

Signed-off-by: Richard Palethorpe <io@richiejp.com>

---------

Signed-off-by: Richard Palethorpe <io@richiejp.com>
2025-12-29 11:06:06 +01:00

157 lines
3.5 KiB
Go

package middleware
import (
	"bytes"
	"io"
	"mime"
	"net/http"
	"sort"
	"sync"
	"time"

	"github.com/emirpasic/gods/v2/queues/circularbuffer"
	"github.com/labstack/echo/v4"
	"github.com/mudler/LocalAI/core/application"
	"github.com/mudler/xlog"
)
// APIExchangeRequest is the request half of a traced API exchange.
type APIExchangeRequest struct {
	// Method is the HTTP method of the traced request.
	Method string `json:"method"`
	// Path is the request path as recorded by the tracing middleware.
	Path string `json:"path"`
	// Headers holds a snapshot of the request headers; serialized as null when nil.
	Headers *http.Header `json:"headers"`
	// Body holds a copy of the raw request body; serialized (base64) as null when nil.
	Body *[]byte `json:"body"`
}
// APIExchangeResponse is the response half of a traced API exchange.
type APIExchangeResponse struct {
	// Status is the HTTP status code that was sent.
	Status int `json:"status"`
	// Headers holds a snapshot of the response headers; serialized as null when nil.
	Headers *http.Header `json:"headers"`
	// Body holds a copy of the bytes written to the client; serialized (base64) as null when nil.
	Body *[]byte `json:"body"`
}
// APIExchange is one recorded request/response pair.
type APIExchange struct {
	// Timestamp is when the middleware began handling the request, taken just
	// before the downstream handler runs (not when the response completed).
	Timestamp time.Time `json:"timestamp"`
	// Request describes what the client sent.
	Request APIExchangeRequest `json:"request"`
	// Response describes what the server answered.
	Response APIExchangeResponse `json:"response"`
}
var (
	// traceBuffer keeps the most recent exchanges in a fixed-capacity circular
	// queue. It stays nil until tracing is set up; access it only with mu held.
	traceBuffer *circularbuffer.Queue[APIExchange]

	// mu serializes every access to traceBuffer.
	mu sync.Mutex

	// logChan hands finished exchanges from request handlers to the goroutine
	// that stores them; up to 100 exchanges can be pending at once.
	logChan = make(chan APIExchange, 100)
)
// bodyWriter wraps an http.ResponseWriter and mirrors every byte that the
// wrapped writer accepts into body, so the response can be recorded.
type bodyWriter struct {
	http.ResponseWriter
	body *bytes.Buffer
}

// Write forwards b to the wrapped writer and then records the bytes it
// accepted. Only b[:n] is captured so that a short or failed write does not
// leave the trace containing data the client never received.
func (w *bodyWriter) Write(b []byte) (int, error) {
	n, err := w.ResponseWriter.Write(b)
	if n > 0 {
		// bytes.Buffer.Write always returns a nil error (it panics if it
		// cannot grow), so its results are intentionally ignored.
		w.body.Write(b[:n])
	}
	return n, err
}

// Flush forwards to the wrapped writer when it implements http.Flusher, so
// streamed responses are still delivered incrementally while being captured.
func (w *bodyWriter) Flush() {
	if flusher, ok := w.ResponseWriter.(http.Flusher); ok {
		flusher.Flush()
	}
}

// Unwrap returns the wrapped writer. http.ResponseController follows Unwrap to
// reach optional capabilities (write deadlines, hijacking, ...) that
// bodyWriter does not implement itself.
func (w *bodyWriter) Unwrap() http.ResponseWriter {
	return w.ResponseWriter
}
// traceCollectorOnce guards the one-time creation of traceBuffer and of the
// goroutine that drains logChan into it.
var traceCollectorOnce sync.Once

// startTraceCollector creates traceBuffer with room for maxItems exchanges and
// starts the goroutine that moves exchanges from logChan into it. Only the
// first call does anything; later calls, whatever their maxItems, are no-ops.
//
// traceBuffer is assigned while holding mu because GetTraces and ClearTraces
// read it under mu.
func startTraceCollector(maxItems int) {
	traceCollectorOnce.Do(func() {
		// NOTE(review): gods' circularbuffer.New appears to panic for
		// maxItems < 1 — confirm TracingMaxItems is validated upstream.
		buffer := circularbuffer.New[APIExchange](maxItems)

		mu.Lock()
		traceBuffer = buffer
		mu.Unlock()

		// logChan is never closed in this file, so this goroutine runs for
		// the rest of the process.
		go func() {
			for exchange := range logChan {
				mu.Lock()
				buffer.Enqueue(exchange)
				mu.Unlock()
			}
		}()
	})
}

// isJSONRequest reports whether r declares a JSON body. The Content-Type is
// parsed as a media type so parameters ("application/json; charset=utf-8")
// and letter case do not prevent a match; a missing or malformed header
// yields false.
func isJSONRequest(r *http.Request) bool {
	mediaType, _, _ := mime.ParseMediaType(r.Header.Get("Content-Type"))
	return mediaType == "application/json"
}

// TraceMiddleware intercepts and logs JSON API requests and responses.
//
// EnableTracing is checked on every request. For requests whose Content-Type
// media type is application/json, the request body is read and restored for
// downstream handlers, and the response writer is wrapped so the response
// body is captured. If the handler returns nil, an APIExchange is queued on
// logChan for the collector goroutine. When logChan is full, the trace is
// dropped with a warning instead of blocking the request. If the handler
// returns an error, the original writer is restored and nothing is recorded.
func TraceMiddleware(app *application.Application) echo.MiddlewareFunc {
	if app.ApplicationConfig().EnableTracing {
		startTraceCollector(app.ApplicationConfig().TracingMaxItems)
	}

	return func(next echo.HandlerFunc) echo.HandlerFunc {
		return func(c echo.Context) error {
			if !app.ApplicationConfig().EnableTracing {
				return next(c)
			}
			if !isJSONRequest(c.Request()) {
				return next(c)
			}
			// Tracing may have been enabled after this middleware was built.
			// Without a collector, exchanges would pile up in logChan and be
			// dropped once it fills.
			startTraceCollector(app.ApplicationConfig().TracingMaxItems)

			body, err := io.ReadAll(c.Request().Body)
			if err != nil {
				xlog.Error("Failed to read request body")
				return err
			}
			// Restore the body for downstream handlers.
			c.Request().Body = io.NopCloser(bytes.NewBuffer(body))

			startTime := time.Now()

			// Wrap the response writer to capture the body as it is written.
			resBody := new(bytes.Buffer)
			mw := &bodyWriter{
				ResponseWriter: c.Response().Writer,
				body:           resBody,
			}
			c.Response().Writer = mw

			if err := next(c); err != nil {
				// Restore the original writer so whatever renders the error
				// writes to it directly; this exchange is not recorded.
				c.Response().Writer = mw.ResponseWriter
				return err
			}

			// Snapshot headers and bodies so the stored trace does not share
			// memory with the live request/response objects.
			requestHeaders := c.Request().Header.Clone()
			requestBody := make([]byte, len(body))
			copy(requestBody, body)
			responseHeaders := c.Response().Header().Clone()
			responseBody := make([]byte, resBody.Len())
			copy(responseBody, resBody.Bytes())

			exchange := APIExchange{
				Timestamp: startTime,
				Request: APIExchangeRequest{
					Method:  c.Request().Method,
					Path:    c.Path(),
					Headers: &requestHeaders,
					Body:    &requestBody,
				},
				Response: APIExchangeResponse{
					Status:  c.Response().Status,
					Headers: &responseHeaders,
					Body:    &responseBody,
				},
			}

			// Never block the request on tracing.
			select {
			case logChan <- exchange:
			default:
				xlog.Warn("Trace channel full, dropping trace")
			}
			return nil
		}
	}
}
// GetTraces returns a snapshot slice of the logged API exchanges for display,
// ordered by Timestamp (request start time), oldest first.
//
// The slice is new, but the Headers/Body pointers inside each exchange are
// shared with the stored traces. If no trace buffer has been created yet
// (tracing never enabled), an empty slice is returned instead of
// dereferencing a nil buffer.
func GetTraces() []APIExchange {
	mu.Lock()
	if traceBuffer == nil {
		mu.Unlock()
		return []APIExchange{}
	}
	traces := traceBuffer.Values()
	mu.Unlock()

	// Exchanges are stored when they complete, but Timestamp is the start
	// time, so re-sort. SliceStable keeps the buffer's order for ties,
	// making the output deterministic.
	sort.SliceStable(traces, func(i, j int) bool {
		return traces[i].Timestamp.Before(traces[j].Timestamp)
	})
	return traces
}
// ClearTraces clears the in-memory logs. It is a no-op when no trace buffer
// has been created yet (tracing never enabled).
//
// Exchanges still queued in logChan at the time of the call are not affected
// and may be stored after it returns.
func ClearTraces() {
	mu.Lock()
	defer mu.Unlock()
	if traceBuffer != nil {
		traceBuffer.Clear()
	}
}