refactor: apply SOTA patterns (Atomic HitCount, Zero-Copy Body, Low-Lock RateLimit)

This commit is contained in:
Fabrizio Salmi
2025-12-06 22:52:01 +01:00
parent c29a7ce9aa
commit 00c547e2a3
6 changed files with 55 additions and 34 deletions

View File

@@ -24,6 +24,8 @@ import (
"os"
"strings"
"sync"
"sync/atomic"
"github.com/fsnotify/fsnotify"
"github.com/oschwald/maxminddb-golang"
@@ -358,15 +360,16 @@ func (m *Middleware) Shutdown(ctx context.Context) error {
return true
}
hitCount, ok := value.(HitCount)
atomicCounter, ok := value.(*atomic.Int64)
if !ok {
m.logger.Error("Invalid type for hit count in ruleHits map", zap.Any("value", value))
return true
}
hitCount := atomicCounter.Load()
m.logger.Info("Rule Hit",
zap.String("rule_id", string(ruleID)),
zap.Int("hits", int(hitCount)),
zap.Int64("hits", hitCount),
)
return true
})
@@ -529,12 +532,13 @@ func (m *Middleware) getRuleHitStats() map[string]int {
m.logger.Error("Invalid type for rule ID in ruleHits map", zap.Any("key", key))
return true // Continue iteration
}
hitCount, ok := value.(HitCount)
// SOTA Pattern: Wait-Free stats collection
atomicCounter, ok := value.(*atomic.Int64)
if !ok {
m.logger.Error("Invalid type for hit count in ruleHits map", zap.Any("value", value))
return true // Continue iteration
}
stats[string(ruleID)] = int(hitCount)
stats[string(ruleID)] = int(atomicCounter.Load())
return true
})
return stats

View File

@@ -5,6 +5,7 @@ import (
"net/http"
"os"
"strings"
"sync/atomic"
"time"
"go.uber.org/zap"
@@ -25,10 +26,11 @@ func (m *Middleware) DebugRequest(r *http.Request, state *WAFState, msg string)
if !ok {
return true
}
hitCount, ok := value.(HitCount)
atomicCounter, ok := value.(*atomic.Int64)
if !ok {
return true
}
hitCount := atomicCounter.Load()
ruleIDs = append(ruleIDs, string(ruleID))
scores = append(scores, fmt.Sprintf("%s:%d", string(ruleID), hitCount))
return true

View File

@@ -1757,16 +1757,16 @@ func TestBlockedRequestPhase1_RateLimiting_MultiplePaths(t *testing.T) {
middleware := &Middleware{
logger: logger,
rateLimiter: func() *RateLimiter {
rl := &RateLimiter{
config: RateLimit{
config := RateLimit{
Requests: 1,
Window: time.Minute,
CleanupInterval: time.Minute,
Paths: []string{"/api/v1/.*", "/admin/.*"},
MatchAllPaths: false,
},
requests: make(map[string]map[string]*requestCounter),
stopCleanup: make(chan struct{}),
}
rl, err := NewRateLimiter(config)
if err != nil {
t.Fatalf("Failed to create rate limiter: %v", err)
}
rl.startCleanup()
return rl

View File

@@ -58,32 +58,40 @@ func NewRateLimiter(config RateLimit) (*RateLimiter, error) {
// isRateLimited checks if a given IP is rate limited for a specific path.
func (rl *RateLimiter) isRateLimited(ip, path string) bool {
now := time.Now()
rl.Lock() // Use Lock for write operations or potential creation of nested maps.
defer rl.Unlock()
rl.incrementTotalRequestsMetric() // Increment the total requests received
// SOTA Pattern: Reduce Lock Contention (move expensive regex out of critical section)
matched := false
var key string
// 1. Determine if this path needs limiting (Read-only config access, safe without lock if config is immutable)
if rl.config.MatchAllPaths {
matched = true
key = ip
} else {
// Check if path is matching
if len(rl.config.PathRegexes) > 0 {
matched := false
for _, regex := range rl.config.PathRegexes {
if regex.MatchString(path) {
matched = true
break
}
}
if !matched {
return false // Path does not match any configured paths, no rate limiting
}
}
if matched {
key = ip + path
}
}
}
if !matched && !rl.config.MatchAllPaths {
// Optimization: If no path matched, we don't need to track this request
rl.incrementTotalRequestsMetric()
return false
}
now := time.Now()
rl.Lock() // Critical Section Start
defer rl.Unlock()
rl.incrementTotalRequestsMetric() // Metric under lock to ensure consistency (or use atomic outside)
// Initialize the nested map if it doesn't exist
if _, exists := rl.requests[ip]; !exists {

View File

@@ -8,6 +8,7 @@ import (
"net/url"
"strconv"
"strings"
"unsafe"
"go.uber.org/zap"
)
@@ -215,7 +216,12 @@ func (rve *RequestValueExtractor) extractBody(r *http.Request, target string) (s
return "", fmt.Errorf("failed to read request body for target %s: %w", target, err)
}
r.Body = io.NopCloser(strings.NewReader(string(bodyBytes))) // Restore body for next read
return string(bodyBytes), nil
// SOTA Pattern: Zero-Copy (avoid allocation for string conversion)
if len(bodyBytes) == 0 {
return "", nil
}
return unsafe.String(&bodyBytes[0], len(bodyBytes)), nil
}
// Helper function to extract all headers

View File

@@ -10,6 +10,7 @@ import (
"os"
"regexp"
"sort"
"sync/atomic"
)
func (m *Middleware) processRuleMatch(w http.ResponseWriter, r *http.Request, rule *Rule, target, value string, state *WAFState) bool {
@@ -112,14 +113,14 @@ func (m *Middleware) processRuleMatch(w http.ResponseWriter, r *http.Request, ru
// incrementRuleHitCount increments the hit counter for a given rule ID.
func (m *Middleware) incrementRuleHitCount(ruleID RuleID) {
hitCount := HitCount(1) // Default increment
if currentCount, loaded := m.ruleHits.Load(ruleID); loaded {
hitCount = currentCount.(HitCount) + 1
}
m.ruleHits.Store(ruleID, hitCount)
// SOTA Pattern: Wait-Free / Lock-Free Data Structures (using atomic)
counterInterface, _ := m.ruleHits.LoadOrStore(ruleID, &atomic.Int64{})
counter := counterInterface.(*atomic.Int64)
newVal := counter.Add(1)
m.logger.Debug("Rule hit count updated",
zap.String("rule_id", string(ruleID)),
zap.Int("hit_count", int(hitCount)), // More descriptive log field
zap.Int64("hit_count", newVal),
)
}