refactor: apply SOTA patterns (Atomic HitCount, Zero-Copy Body, Low-Lock RateLimit)

This commit is contained in:
Fabrizio Salmi
2025-12-06 22:52:01 +01:00
parent c29a7ce9aa
commit 00c547e2a3
6 changed files with 55 additions and 34 deletions

View File

@@ -24,6 +24,8 @@ import (
"os" "os"
"strings" "strings"
"sync" "sync"
"sync/atomic"
"github.com/fsnotify/fsnotify" "github.com/fsnotify/fsnotify"
"github.com/oschwald/maxminddb-golang" "github.com/oschwald/maxminddb-golang"
@@ -358,15 +360,16 @@ func (m *Middleware) Shutdown(ctx context.Context) error {
return true return true
} }
hitCount, ok := value.(HitCount) atomicCounter, ok := value.(*atomic.Int64)
if !ok { if !ok {
m.logger.Error("Invalid type for hit count in ruleHits map", zap.Any("value", value)) m.logger.Error("Invalid type for hit count in ruleHits map", zap.Any("value", value))
return true return true
} }
hitCount := atomicCounter.Load()
m.logger.Info("Rule Hit", m.logger.Info("Rule Hit",
zap.String("rule_id", string(ruleID)), zap.String("rule_id", string(ruleID)),
zap.Int("hits", int(hitCount)), zap.Int64("hits", hitCount),
) )
return true return true
}) })
@@ -529,12 +532,13 @@ func (m *Middleware) getRuleHitStats() map[string]int {
m.logger.Error("Invalid type for rule ID in ruleHits map", zap.Any("key", key)) m.logger.Error("Invalid type for rule ID in ruleHits map", zap.Any("key", key))
return true // Continue iteration return true // Continue iteration
} }
hitCount, ok := value.(HitCount) // SOTA Pattern: Wait-Free stats collection
atomicCounter, ok := value.(*atomic.Int64)
if !ok { if !ok {
m.logger.Error("Invalid type for hit count in ruleHits map", zap.Any("value", value)) m.logger.Error("Invalid type for hit count in ruleHits map", zap.Any("value", value))
return true // Continue iteration return true // Continue iteration
} }
stats[string(ruleID)] = int(hitCount) stats[string(ruleID)] = int(atomicCounter.Load())
return true return true
}) })
return stats return stats

View File

@@ -5,6 +5,7 @@ import (
"net/http" "net/http"
"os" "os"
"strings" "strings"
"sync/atomic"
"time" "time"
"go.uber.org/zap" "go.uber.org/zap"
@@ -25,10 +26,11 @@ func (m *Middleware) DebugRequest(r *http.Request, state *WAFState, msg string)
if !ok { if !ok {
return true return true
} }
hitCount, ok := value.(HitCount) atomicCounter, ok := value.(*atomic.Int64)
if !ok { if !ok {
return true return true
} }
hitCount := atomicCounter.Load()
ruleIDs = append(ruleIDs, string(ruleID)) ruleIDs = append(ruleIDs, string(ruleID))
scores = append(scores, fmt.Sprintf("%s:%d", string(ruleID), hitCount)) scores = append(scores, fmt.Sprintf("%s:%d", string(ruleID), hitCount))
return true return true

View File

@@ -1757,16 +1757,16 @@ func TestBlockedRequestPhase1_RateLimiting_MultiplePaths(t *testing.T) {
middleware := &Middleware{ middleware := &Middleware{
logger: logger, logger: logger,
rateLimiter: func() *RateLimiter { rateLimiter: func() *RateLimiter {
rl := &RateLimiter{ config := RateLimit{
config: RateLimit{ Requests: 1,
Requests: 1, Window: time.Minute,
Window: time.Minute, CleanupInterval: time.Minute,
CleanupInterval: time.Minute, Paths: []string{"/api/v1/.*", "/admin/.*"},
Paths: []string{"/api/v1/.*", "/admin/.*"}, MatchAllPaths: false,
MatchAllPaths: false, }
}, rl, err := NewRateLimiter(config)
requests: make(map[string]map[string]*requestCounter), if err != nil {
stopCleanup: make(chan struct{}), t.Fatalf("Failed to create rate limiter: %v", err)
} }
rl.startCleanup() rl.startCleanup()
return rl return rl

View File

@@ -58,33 +58,41 @@ func NewRateLimiter(config RateLimit) (*RateLimiter, error) {
// isRateLimited checks if a given IP is rate limited for a specific path. // isRateLimited checks if a given IP is rate limited for a specific path.
func (rl *RateLimiter) isRateLimited(ip, path string) bool { func (rl *RateLimiter) isRateLimited(ip, path string) bool {
now := time.Now() // SOTA Pattern: Reduce Lock Contention (move expensive regex out of critical section)
matched := false
rl.Lock() // Use Lock for write operations or potential creation of nested maps.
defer rl.Unlock()
rl.incrementTotalRequestsMetric() // Increment the total requests received
var key string var key string
// 1. Determine if this path needs limiting (Read-only config access, safe without lock if config is immutable)
if rl.config.MatchAllPaths { if rl.config.MatchAllPaths {
matched = true
key = ip key = ip
} else { } else {
// Check if path is matching
if len(rl.config.PathRegexes) > 0 { if len(rl.config.PathRegexes) > 0 {
matched := false
for _, regex := range rl.config.PathRegexes { for _, regex := range rl.config.PathRegexes {
if regex.MatchString(path) { if regex.MatchString(path) {
matched = true matched = true
break break
} }
} }
if !matched { if matched {
return false // Path does not match any configured paths, no rate limiting key = ip + path
} }
} }
key = ip + path
} }
if !matched && !rl.config.MatchAllPaths {
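// Optimization: if no path matched, this request is not tracked for rate limiting; it is still counted in the total-requests metric.
// NOTE(review): this branch also fires when PathRegexes is empty, a case the previous code rate limited under key ip+path, and it calls incrementTotalRequestsMetric without holding the lock — confirm both are intended.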
rl.incrementTotalRequestsMetric()
return false
}
now := time.Now()
rl.Lock() // Critical Section Start
defer rl.Unlock()
rl.incrementTotalRequestsMetric() // Metric under lock to ensure consistency (or use atomic outside)
// Initialize the nested map if it doesn't exist // Initialize the nested map if it doesn't exist
if _, exists := rl.requests[ip]; !exists { if _, exists := rl.requests[ip]; !exists {
rl.requests[ip] = make(map[string]*requestCounter) rl.requests[ip] = make(map[string]*requestCounter)

View File

@@ -8,6 +8,7 @@ import (
"net/url" "net/url"
"strconv" "strconv"
"strings" "strings"
"unsafe"
"go.uber.org/zap" "go.uber.org/zap"
) )
@@ -215,7 +216,12 @@ func (rve *RequestValueExtractor) extractBody(r *http.Request, target string) (s
return "", fmt.Errorf("failed to read request body for target %s: %w", target, err) return "", fmt.Errorf("failed to read request body for target %s: %w", target, err)
} }
r.Body = io.NopCloser(strings.NewReader(string(bodyBytes))) // Restore body for next read r.Body = io.NopCloser(strings.NewReader(string(bodyBytes))) // Restore body for next read
return string(bodyBytes), nil
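// SOTA Pattern: Zero-Copy (avoid allocation for string conversion). The returned string shares memory with bodyBytes, so bodyBytes must not be modified for as long as the string is in use.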
if len(bodyBytes) == 0 {
return "", nil
}
return unsafe.String(&bodyBytes[0], len(bodyBytes)), nil
} }
// Helper function to extract all headers // Helper function to extract all headers

View File

@@ -10,6 +10,7 @@ import (
"os" "os"
"regexp" "regexp"
"sort" "sort"
"sync/atomic"
) )
func (m *Middleware) processRuleMatch(w http.ResponseWriter, r *http.Request, rule *Rule, target, value string, state *WAFState) bool { func (m *Middleware) processRuleMatch(w http.ResponseWriter, r *http.Request, rule *Rule, target, value string, state *WAFState) bool {
@@ -112,14 +113,14 @@ func (m *Middleware) processRuleMatch(w http.ResponseWriter, r *http.Request, ru
// incrementRuleHitCount increments the hit counter for a given rule ID. // incrementRuleHitCount increments the hit counter for a given rule ID.
func (m *Middleware) incrementRuleHitCount(ruleID RuleID) { func (m *Middleware) incrementRuleHitCount(ruleID RuleID) {
hitCount := HitCount(1) // Default increment // SOTA Pattern: Wait-Free / Lock-Free Data Structures (using atomic)
if currentCount, loaded := m.ruleHits.Load(ruleID); loaded { counterInterface, _ := m.ruleHits.LoadOrStore(ruleID, &atomic.Int64{})
hitCount = currentCount.(HitCount) + 1 counter := counterInterface.(*atomic.Int64)
} newVal := counter.Add(1)
m.ruleHits.Store(ruleID, hitCount)
m.logger.Debug("Rule hit count updated", m.logger.Debug("Rule hit count updated",
zap.String("rule_id", string(ruleID)), zap.String("rule_id", string(ruleID)),
zap.Int("hit_count", int(hitCount)), // More descriptive log field zap.Int64("hit_count", newVal),
) )
} }