mirror of
https://github.com/fabriziosalmi/caddy-waf.git
synced 2025-12-23 22:27:46 -05:00
refactor: apply SOTA patterns (Atomic HitCount, Zero-Copy Body, Low-Lock RateLimit)
This commit is contained in:
12
caddywaf.go
12
caddywaf.go
@@ -24,6 +24,8 @@ import (
|
|||||||
"os"
|
"os"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
|
"sync/atomic"
|
||||||
|
|
||||||
|
|
||||||
"github.com/fsnotify/fsnotify"
|
"github.com/fsnotify/fsnotify"
|
||||||
"github.com/oschwald/maxminddb-golang"
|
"github.com/oschwald/maxminddb-golang"
|
||||||
@@ -358,15 +360,16 @@ func (m *Middleware) Shutdown(ctx context.Context) error {
|
|||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
hitCount, ok := value.(HitCount)
|
atomicCounter, ok := value.(*atomic.Int64)
|
||||||
if !ok {
|
if !ok {
|
||||||
m.logger.Error("Invalid type for hit count in ruleHits map", zap.Any("value", value))
|
m.logger.Error("Invalid type for hit count in ruleHits map", zap.Any("value", value))
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
hitCount := atomicCounter.Load()
|
||||||
|
|
||||||
m.logger.Info("Rule Hit",
|
m.logger.Info("Rule Hit",
|
||||||
zap.String("rule_id", string(ruleID)),
|
zap.String("rule_id", string(ruleID)),
|
||||||
zap.Int("hits", int(hitCount)),
|
zap.Int64("hits", hitCount),
|
||||||
)
|
)
|
||||||
return true
|
return true
|
||||||
})
|
})
|
||||||
@@ -529,12 +532,13 @@ func (m *Middleware) getRuleHitStats() map[string]int {
|
|||||||
m.logger.Error("Invalid type for rule ID in ruleHits map", zap.Any("key", key))
|
m.logger.Error("Invalid type for rule ID in ruleHits map", zap.Any("key", key))
|
||||||
return true // Continue iteration
|
return true // Continue iteration
|
||||||
}
|
}
|
||||||
hitCount, ok := value.(HitCount)
|
// SOTA Pattern: Wait-Free stats collection
|
||||||
|
atomicCounter, ok := value.(*atomic.Int64)
|
||||||
if !ok {
|
if !ok {
|
||||||
m.logger.Error("Invalid type for hit count in ruleHits map", zap.Any("value", value))
|
m.logger.Error("Invalid type for hit count in ruleHits map", zap.Any("value", value))
|
||||||
return true // Continue iteration
|
return true // Continue iteration
|
||||||
}
|
}
|
||||||
stats[string(ruleID)] = int(hitCount)
|
stats[string(ruleID)] = int(atomicCounter.Load())
|
||||||
return true
|
return true
|
||||||
})
|
})
|
||||||
return stats
|
return stats
|
||||||
|
|||||||
@@ -5,6 +5,7 @@ import (
|
|||||||
"net/http"
|
"net/http"
|
||||||
"os"
|
"os"
|
||||||
"strings"
|
"strings"
|
||||||
|
"sync/atomic"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"go.uber.org/zap"
|
"go.uber.org/zap"
|
||||||
@@ -25,10 +26,11 @@ func (m *Middleware) DebugRequest(r *http.Request, state *WAFState, msg string)
|
|||||||
if !ok {
|
if !ok {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
hitCount, ok := value.(HitCount)
|
atomicCounter, ok := value.(*atomic.Int64)
|
||||||
if !ok {
|
if !ok {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
hitCount := atomicCounter.Load()
|
||||||
ruleIDs = append(ruleIDs, string(ruleID))
|
ruleIDs = append(ruleIDs, string(ruleID))
|
||||||
scores = append(scores, fmt.Sprintf("%s:%d", string(ruleID), hitCount))
|
scores = append(scores, fmt.Sprintf("%s:%d", string(ruleID), hitCount))
|
||||||
return true
|
return true
|
||||||
|
|||||||
@@ -1757,16 +1757,16 @@ func TestBlockedRequestPhase1_RateLimiting_MultiplePaths(t *testing.T) {
|
|||||||
middleware := &Middleware{
|
middleware := &Middleware{
|
||||||
logger: logger,
|
logger: logger,
|
||||||
rateLimiter: func() *RateLimiter {
|
rateLimiter: func() *RateLimiter {
|
||||||
rl := &RateLimiter{
|
config := RateLimit{
|
||||||
config: RateLimit{
|
Requests: 1,
|
||||||
Requests: 1,
|
Window: time.Minute,
|
||||||
Window: time.Minute,
|
CleanupInterval: time.Minute,
|
||||||
CleanupInterval: time.Minute,
|
Paths: []string{"/api/v1/.*", "/admin/.*"},
|
||||||
Paths: []string{"/api/v1/.*", "/admin/.*"},
|
MatchAllPaths: false,
|
||||||
MatchAllPaths: false,
|
}
|
||||||
},
|
rl, err := NewRateLimiter(config)
|
||||||
requests: make(map[string]map[string]*requestCounter),
|
if err != nil {
|
||||||
stopCleanup: make(chan struct{}),
|
t.Fatalf("Failed to create rate limiter: %v", err)
|
||||||
}
|
}
|
||||||
rl.startCleanup()
|
rl.startCleanup()
|
||||||
return rl
|
return rl
|
||||||
|
|||||||
@@ -58,33 +58,41 @@ func NewRateLimiter(config RateLimit) (*RateLimiter, error) {
|
|||||||
|
|
||||||
// isRateLimited checks if a given IP is rate limited for a specific path.
|
// isRateLimited checks if a given IP is rate limited for a specific path.
|
||||||
func (rl *RateLimiter) isRateLimited(ip, path string) bool {
|
func (rl *RateLimiter) isRateLimited(ip, path string) bool {
|
||||||
now := time.Now()
|
// SOTA Pattern: Reduce Lock Contention (move expensive regex out of critical section)
|
||||||
|
matched := false
|
||||||
rl.Lock() // Use Lock for write operations or potential creation of nested maps.
|
|
||||||
defer rl.Unlock()
|
|
||||||
|
|
||||||
rl.incrementTotalRequestsMetric() // Increment the total requests received
|
|
||||||
|
|
||||||
var key string
|
var key string
|
||||||
|
|
||||||
|
// 1. Determine if this path needs limiting (Read-only config access, safe without lock if config is immutable)
|
||||||
if rl.config.MatchAllPaths {
|
if rl.config.MatchAllPaths {
|
||||||
|
matched = true
|
||||||
key = ip
|
key = ip
|
||||||
} else {
|
} else {
|
||||||
// Check if path is matching
|
|
||||||
if len(rl.config.PathRegexes) > 0 {
|
if len(rl.config.PathRegexes) > 0 {
|
||||||
matched := false
|
|
||||||
for _, regex := range rl.config.PathRegexes {
|
for _, regex := range rl.config.PathRegexes {
|
||||||
if regex.MatchString(path) {
|
if regex.MatchString(path) {
|
||||||
matched = true
|
matched = true
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if !matched {
|
if matched {
|
||||||
return false // Path does not match any configured paths, no rate limiting
|
key = ip + path
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
key = ip + path
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if !matched && !rl.config.MatchAllPaths {
|
||||||
|
// Optimization: If no path matched, we don't need to track this request
|
||||||
|
rl.incrementTotalRequestsMetric()
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
now := time.Now()
|
||||||
|
|
||||||
|
rl.Lock() // Critical Section Start
|
||||||
|
defer rl.Unlock()
|
||||||
|
|
||||||
|
rl.incrementTotalRequestsMetric() // Metric under lock to ensure consistency (or use atomic outside)
|
||||||
|
|
||||||
// Initialize the nested map if it doesn't exist
|
// Initialize the nested map if it doesn't exist
|
||||||
if _, exists := rl.requests[ip]; !exists {
|
if _, exists := rl.requests[ip]; !exists {
|
||||||
rl.requests[ip] = make(map[string]*requestCounter)
|
rl.requests[ip] = make(map[string]*requestCounter)
|
||||||
|
|||||||
@@ -8,6 +8,7 @@ import (
|
|||||||
"net/url"
|
"net/url"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
|
"unsafe"
|
||||||
|
|
||||||
"go.uber.org/zap"
|
"go.uber.org/zap"
|
||||||
)
|
)
|
||||||
@@ -215,7 +216,12 @@ func (rve *RequestValueExtractor) extractBody(r *http.Request, target string) (s
|
|||||||
return "", fmt.Errorf("failed to read request body for target %s: %w", target, err)
|
return "", fmt.Errorf("failed to read request body for target %s: %w", target, err)
|
||||||
}
|
}
|
||||||
r.Body = io.NopCloser(strings.NewReader(string(bodyBytes))) // Restore body for next read
|
r.Body = io.NopCloser(strings.NewReader(string(bodyBytes))) // Restore body for next read
|
||||||
return string(bodyBytes), nil
|
|
||||||
|
// SOTA Pattern: Zero-Copy (avoid allocation for string conversion)
|
||||||
|
if len(bodyBytes) == 0 {
|
||||||
|
return "", nil
|
||||||
|
}
|
||||||
|
return unsafe.String(&bodyBytes[0], len(bodyBytes)), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Helper function to extract all headers
|
// Helper function to extract all headers
|
||||||
|
|||||||
13
rules.go
13
rules.go
@@ -10,6 +10,7 @@ import (
|
|||||||
"os"
|
"os"
|
||||||
"regexp"
|
"regexp"
|
||||||
"sort"
|
"sort"
|
||||||
|
"sync/atomic"
|
||||||
)
|
)
|
||||||
|
|
||||||
func (m *Middleware) processRuleMatch(w http.ResponseWriter, r *http.Request, rule *Rule, target, value string, state *WAFState) bool {
|
func (m *Middleware) processRuleMatch(w http.ResponseWriter, r *http.Request, rule *Rule, target, value string, state *WAFState) bool {
|
||||||
@@ -112,14 +113,14 @@ func (m *Middleware) processRuleMatch(w http.ResponseWriter, r *http.Request, ru
|
|||||||
|
|
||||||
// incrementRuleHitCount increments the hit counter for a given rule ID.
|
// incrementRuleHitCount increments the hit counter for a given rule ID.
|
||||||
func (m *Middleware) incrementRuleHitCount(ruleID RuleID) {
|
func (m *Middleware) incrementRuleHitCount(ruleID RuleID) {
|
||||||
hitCount := HitCount(1) // Default increment
|
// SOTA Pattern: Wait-Free / Lock-Free Data Structures (using atomic)
|
||||||
if currentCount, loaded := m.ruleHits.Load(ruleID); loaded {
|
counterInterface, _ := m.ruleHits.LoadOrStore(ruleID, &atomic.Int64{})
|
||||||
hitCount = currentCount.(HitCount) + 1
|
counter := counterInterface.(*atomic.Int64)
|
||||||
}
|
newVal := counter.Add(1)
|
||||||
m.ruleHits.Store(ruleID, hitCount)
|
|
||||||
m.logger.Debug("Rule hit count updated",
|
m.logger.Debug("Rule hit count updated",
|
||||||
zap.String("rule_id", string(ruleID)),
|
zap.String("rule_id", string(ruleID)),
|
||||||
zap.Int("hit_count", int(hitCount)), // More descriptive log field
|
zap.Int64("hit_count", newVal),
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user