mirror of
https://github.com/fabriziosalmi/caddy-waf.git
synced 2025-12-23 22:27:46 -05:00
refactor: apply SOTA patterns (Atomic HitCount, Zero-Copy Body, Low-Lock RateLimit)
This commit is contained in:
12
caddywaf.go
12
caddywaf.go
@@ -24,6 +24,8 @@ import (
|
||||
"os"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
|
||||
|
||||
"github.com/fsnotify/fsnotify"
|
||||
"github.com/oschwald/maxminddb-golang"
|
||||
@@ -358,15 +360,16 @@ func (m *Middleware) Shutdown(ctx context.Context) error {
|
||||
return true
|
||||
}
|
||||
|
||||
hitCount, ok := value.(HitCount)
|
||||
atomicCounter, ok := value.(*atomic.Int64)
|
||||
if !ok {
|
||||
m.logger.Error("Invalid type for hit count in ruleHits map", zap.Any("value", value))
|
||||
return true
|
||||
}
|
||||
hitCount := atomicCounter.Load()
|
||||
|
||||
m.logger.Info("Rule Hit",
|
||||
zap.String("rule_id", string(ruleID)),
|
||||
zap.Int("hits", int(hitCount)),
|
||||
zap.Int64("hits", hitCount),
|
||||
)
|
||||
return true
|
||||
})
|
||||
@@ -529,12 +532,13 @@ func (m *Middleware) getRuleHitStats() map[string]int {
|
||||
m.logger.Error("Invalid type for rule ID in ruleHits map", zap.Any("key", key))
|
||||
return true // Continue iteration
|
||||
}
|
||||
hitCount, ok := value.(HitCount)
|
||||
// SOTA Pattern: Wait-Free stats collection
|
||||
atomicCounter, ok := value.(*atomic.Int64)
|
||||
if !ok {
|
||||
m.logger.Error("Invalid type for hit count in ruleHits map", zap.Any("value", value))
|
||||
return true // Continue iteration
|
||||
}
|
||||
stats[string(ruleID)] = int(hitCount)
|
||||
stats[string(ruleID)] = int(atomicCounter.Load())
|
||||
return true
|
||||
})
|
||||
return stats
|
||||
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
"net/http"
|
||||
"os"
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"go.uber.org/zap"
|
||||
@@ -25,10 +26,11 @@ func (m *Middleware) DebugRequest(r *http.Request, state *WAFState, msg string)
|
||||
if !ok {
|
||||
return true
|
||||
}
|
||||
hitCount, ok := value.(HitCount)
|
||||
atomicCounter, ok := value.(*atomic.Int64)
|
||||
if !ok {
|
||||
return true
|
||||
}
|
||||
hitCount := atomicCounter.Load()
|
||||
ruleIDs = append(ruleIDs, string(ruleID))
|
||||
scores = append(scores, fmt.Sprintf("%s:%d", string(ruleID), hitCount))
|
||||
return true
|
||||
|
||||
@@ -1757,16 +1757,16 @@ func TestBlockedRequestPhase1_RateLimiting_MultiplePaths(t *testing.T) {
|
||||
middleware := &Middleware{
|
||||
logger: logger,
|
||||
rateLimiter: func() *RateLimiter {
|
||||
rl := &RateLimiter{
|
||||
config: RateLimit{
|
||||
config := RateLimit{
|
||||
Requests: 1,
|
||||
Window: time.Minute,
|
||||
CleanupInterval: time.Minute,
|
||||
Paths: []string{"/api/v1/.*", "/admin/.*"},
|
||||
MatchAllPaths: false,
|
||||
},
|
||||
requests: make(map[string]map[string]*requestCounter),
|
||||
stopCleanup: make(chan struct{}),
|
||||
}
|
||||
rl, err := NewRateLimiter(config)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create rate limiter: %v", err)
|
||||
}
|
||||
rl.startCleanup()
|
||||
return rl
|
||||
|
||||
@@ -58,32 +58,40 @@ func NewRateLimiter(config RateLimit) (*RateLimiter, error) {
|
||||
|
||||
// isRateLimited checks if a given IP is rate limited for a specific path.
|
||||
func (rl *RateLimiter) isRateLimited(ip, path string) bool {
|
||||
now := time.Now()
|
||||
|
||||
rl.Lock() // Use Lock for write operations or potential creation of nested maps.
|
||||
defer rl.Unlock()
|
||||
|
||||
rl.incrementTotalRequestsMetric() // Increment the total requests received
|
||||
|
||||
// SOTA Pattern: Reduce Lock Contention (move expensive regex out of critical section)
|
||||
matched := false
|
||||
var key string
|
||||
|
||||
// 1. Determine if this path needs limiting (Read-only config access, safe without lock if config is immutable)
|
||||
if rl.config.MatchAllPaths {
|
||||
matched = true
|
||||
key = ip
|
||||
} else {
|
||||
// Check if path is matching
|
||||
if len(rl.config.PathRegexes) > 0 {
|
||||
matched := false
|
||||
for _, regex := range rl.config.PathRegexes {
|
||||
if regex.MatchString(path) {
|
||||
matched = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !matched {
|
||||
return false // Path does not match any configured paths, no rate limiting
|
||||
}
|
||||
}
|
||||
if matched {
|
||||
key = ip + path
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if !matched && !rl.config.MatchAllPaths {
|
||||
// Optimization: If no path matched, we don't need to track this request
|
||||
rl.incrementTotalRequestsMetric()
|
||||
return false
|
||||
}
|
||||
|
||||
now := time.Now()
|
||||
|
||||
rl.Lock() // Critical Section Start
|
||||
defer rl.Unlock()
|
||||
|
||||
rl.incrementTotalRequestsMetric() // Metric under lock to ensure consistency (or use atomic outside)
|
||||
|
||||
// Initialize the nested map if it doesn't exist
|
||||
if _, exists := rl.requests[ip]; !exists {
|
||||
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
"net/url"
|
||||
"strconv"
|
||||
"strings"
|
||||
"unsafe"
|
||||
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
@@ -215,7 +216,12 @@ func (rve *RequestValueExtractor) extractBody(r *http.Request, target string) (s
|
||||
return "", fmt.Errorf("failed to read request body for target %s: %w", target, err)
|
||||
}
|
||||
r.Body = io.NopCloser(strings.NewReader(string(bodyBytes))) // Restore body for next read
|
||||
return string(bodyBytes), nil
|
||||
|
||||
// SOTA Pattern: Zero-Copy (avoid allocation for string conversion)
|
||||
if len(bodyBytes) == 0 {
|
||||
return "", nil
|
||||
}
|
||||
return unsafe.String(&bodyBytes[0], len(bodyBytes)), nil
|
||||
}
|
||||
|
||||
// Helper function to extract all headers
|
||||
|
||||
13
rules.go
13
rules.go
@@ -10,6 +10,7 @@ import (
|
||||
"os"
|
||||
"regexp"
|
||||
"sort"
|
||||
"sync/atomic"
|
||||
)
|
||||
|
||||
func (m *Middleware) processRuleMatch(w http.ResponseWriter, r *http.Request, rule *Rule, target, value string, state *WAFState) bool {
|
||||
@@ -112,14 +113,14 @@ func (m *Middleware) processRuleMatch(w http.ResponseWriter, r *http.Request, ru
|
||||
|
||||
// incrementRuleHitCount increments the hit counter for a given rule ID.
|
||||
func (m *Middleware) incrementRuleHitCount(ruleID RuleID) {
|
||||
hitCount := HitCount(1) // Default increment
|
||||
if currentCount, loaded := m.ruleHits.Load(ruleID); loaded {
|
||||
hitCount = currentCount.(HitCount) + 1
|
||||
}
|
||||
m.ruleHits.Store(ruleID, hitCount)
|
||||
// SOTA Pattern: Wait-Free / Lock-Free Data Structures (using atomic)
|
||||
counterInterface, _ := m.ruleHits.LoadOrStore(ruleID, &atomic.Int64{})
|
||||
counter := counterInterface.(*atomic.Int64)
|
||||
newVal := counter.Add(1)
|
||||
|
||||
m.logger.Debug("Rule hit count updated",
|
||||
zap.String("rule_id", string(ruleID)),
|
||||
zap.Int("hit_count", int(hitCount)), // More descriptive log field
|
||||
zap.Int64("hit_count", newVal),
|
||||
)
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user