fabriziosalmi
2025-01-29 13:34:27 +01:00
parent c99e875aaf
commit 9223d337fc
7 changed files with 82 additions and 94 deletions

View File

@@ -127,6 +127,7 @@ func extractIP(remoteAddr string, logger *zap.Logger) string {
return host
}
// LoadIPBlacklistFromFile loads IP addresses from a file into the provided map.
// LoadIPBlacklistFromFile loads IP addresses from a file into the provided map.
func (bl *BlacklistLoader) LoadIPBlacklistFromFile(path string, ipBlacklist map[string]struct{}) error {
bl.logger.Debug("Loading IP blacklist", zap.String("path", path))
@@ -141,6 +142,7 @@ func (bl *BlacklistLoader) LoadIPBlacklistFromFile(path string, ipBlacklist map[
scanner := bufio.NewScanner(file)
validEntries := 0
totalLines := 0
invalidEntries := 0
for scanner.Scan() {
totalLines++
@@ -156,6 +158,9 @@ func (bl *BlacklistLoader) LoadIPBlacklistFromFile(path string, ipBlacklist map[
zap.Int("line", totalLines),
zap.String("entry", line),
)
invalidEntries++
// If you want the entire load to fail if any single IP entry is invalid, uncomment the line below
// return fmt.Errorf("failed to add IP entry %s : %w", line, err)
} else {
validEntries++
}
@@ -169,6 +174,7 @@ func (bl *BlacklistLoader) LoadIPBlacklistFromFile(path string, ipBlacklist map[
bl.logger.Info("IP blacklist loaded",
zap.String("path", path),
zap.Int("valid_entries", validEntries),
zap.Int("invalid_entries", invalidEntries),
zap.Int("total_lines", totalLines),
)
return nil

View File

@@ -4,7 +4,6 @@ import (
"context"
"encoding/json"
"fmt"
"net"
"net/http"
"os"
"strings"
@@ -29,6 +28,7 @@ var (
_ caddy.Provisioner = (*Middleware)(nil)
_ caddyhttp.MiddlewareHandler = (*Middleware)(nil)
_ caddyfile.Unmarshaler = (*Middleware)(nil)
_ caddy.Validator = (*Middleware)(nil)
)
// ==================== Initialization and Setup ====================
@@ -209,19 +209,18 @@ func (m *Middleware) Provision(ctx caddy.Context) error {
}
// Load IP blacklist
m.ipBlacklist = NewCIDRTrie()
m.logger.Debug("ipBlacklist initialized in Provision", zap.Bool("isNil", m.ipBlacklist == nil))
if m.IPBlacklistFile != "" {
err = m.loadIPBlacklistIntoMap(m.IPBlacklistFile, m.ipBlacklist)
m.ipBlacklist = NewCIDRTrie()
err = m.loadIPBlacklist(m.IPBlacklistFile, m.ipBlacklist)
if err != nil {
return fmt.Errorf("failed to load IP blacklist: %w", err)
}
}
// Load DNS blacklist
m.dnsBlacklist = make(map[string]struct{}) // Changed to map[string]struct{}
if m.DNSBlacklistFile != "" {
err = m.blacklistLoader.LoadDNSBlacklistFromFile(m.DNSBlacklistFile, m.dnsBlacklist)
m.dnsBlacklist = make(map[string]struct{})
err = m.loadDNSBlacklist(m.DNSBlacklistFile, m.dnsBlacklist)
if err != nil {
return fmt.Errorf("failed to load DNS blacklist: %w", err)
}
@@ -414,25 +413,21 @@ func (m *Middleware) ReloadConfig() error {
defer m.mu.Unlock()
m.logger.Info("Reloading WAF configuration")
newIPBlacklist := NewCIDRTrie()
if m.IPBlacklistFile != "" {
if err := m.loadIPBlacklistIntoMap(m.IPBlacklistFile, newIPBlacklist); err != nil {
newIPBlacklist := NewCIDRTrie()
if err := m.loadIPBlacklist(m.IPBlacklistFile, newIPBlacklist); err != nil {
m.logger.Error("Failed to reload IP blacklist", zap.String("file", m.IPBlacklistFile), zap.Error(err))
return fmt.Errorf("failed to reload IP blacklist: %v", err)
}
} else {
m.logger.Debug("No IP blacklist file specified, skipping reload")
m.ipBlacklist = newIPBlacklist
}
newDNSBlacklist := make(map[string]struct{})
if m.DNSBlacklistFile != "" {
if err := m.loadDNSBlacklistIntoMap(m.DNSBlacklistFile, newDNSBlacklist); err != nil {
newDNSBlacklist := make(map[string]struct{})
if err := m.loadDNSBlacklist(m.DNSBlacklistFile, newDNSBlacklist); err != nil {
m.logger.Error("Failed to reload DNS blacklist", zap.String("file", m.DNSBlacklistFile), zap.Error(err))
return fmt.Errorf("failed to reload DNS blacklist: %v", err)
}
} else {
m.logger.Debug("No DNS blacklist file specified, skipping reload")
m.dnsBlacklist = newDNSBlacklist
}
// Call the external loadRules function
@@ -441,62 +436,38 @@ func (m *Middleware) ReloadConfig() error {
return fmt.Errorf("failed to reload rules: %v", err)
}
m.ipBlacklist = newIPBlacklist
m.dnsBlacklist = newDNSBlacklist
m.logger.Info("WAF configuration reloaded successfully")
return nil
}
func (m *Middleware) loadIPBlacklistIntoMap(path string, blacklistMap *CIDRTrie) error {
content, err := os.ReadFile(path)
if err != nil {
return fmt.Errorf("failed to read IP blacklist file: %v", err)
func (m *Middleware) loadIPBlacklist(path string, blacklistMap *CIDRTrie) error {
if _, err := os.Stat(path); os.IsNotExist(err) {
m.logger.Warn("Skipping IP blacklist load, file does not exist", zap.String("file", path))
return nil
}
lines := strings.Split(string(content), "\n")
for _, line := range lines {
line = strings.TrimSpace(line)
if line == "" || strings.HasPrefix(line, "#") {
continue
}
blacklist := make(map[string]struct{})
err := m.blacklistLoader.LoadIPBlacklistFromFile(path, blacklist)
if err != nil {
return fmt.Errorf("failed to load IP blacklist: %w", err)
}
if !strings.Contains(line, "/") {
// Handle single IP addresses
ip := net.ParseIP(line)
if ip == nil {
m.logger.Warn("Skipping invalid IP address format in blacklist", zap.String("address", line))
continue
}
if ip.To4() != nil {
line = line + "/32"
} else {
line = line + "/128"
}
}
if err := blacklistMap.Insert(line); err != nil {
m.logger.Warn("Failed to insert CIDR into trie", zap.String("cidr", line), zap.Error(err))
}
// Convert the map to CIDRTrie
for ip := range blacklist {
blacklistMap.Insert(ip)
}
return nil
}
func (m *Middleware) loadDNSBlacklistIntoMap(path string, blacklistMap map[string]struct{}) error {
content, err := os.ReadFile(path)
if err != nil {
return fmt.Errorf("failed to read DNS blacklist file: %v", err)
func (m *Middleware) loadDNSBlacklist(path string, blacklistMap map[string]struct{}) error {
if _, err := os.Stat(path); os.IsNotExist(err) {
m.logger.Warn("Skipping DNS blacklist load, file does not exist", zap.String("file", path))
return nil
}
lines := strings.Split(string(content), "\n")
for _, line := range lines {
line = strings.ToLower(strings.TrimSpace(line))
if line == "" || strings.HasPrefix(line, "#") {
continue
}
blacklistMap[line] = struct{}{} // Changed to struct{}{}
err := m.blacklistLoader.LoadDNSBlacklistFromFile(path, blacklistMap)
if err != nil {
return fmt.Errorf("failed to load DNS blacklist: %w", err)
}
return nil
}
@@ -570,3 +541,11 @@ func (m *Middleware) UnmarshalCaddyfile(d *caddyfile.Dispenser) error {
}
return m.configLoader.UnmarshalCaddyfile(d, m)
}
// Validate implements caddy.Validator.
func (m *Middleware) Validate() error {
if m.logLevel == 0 {
m.logLevel = zapcore.InfoLevel // Default log level
}
return nil
}

View File

@@ -11,6 +11,9 @@ import (
"go.uber.org/zap/zapcore"
)
type ContextKeyLogId string
type ContextKeyRule string
// ServeHTTP implements caddyhttp.Handler.
func (m *Middleware) ServeHTTP(w http.ResponseWriter, r *http.Request, next caddyhttp.Handler) error {
logID := uuid.New().String()
@@ -41,7 +44,7 @@ func (m *Middleware) ServeHTTP(w http.ResponseWriter, r *http.Request, next cadd
err := next.ServeHTTP(recorder, r)
// Phase 3: Response Header analysis
if m.isResponseHeaderPhaseBlocked(recorder, r, 3, state) {
if m.isPhaseBlocked(recorder, r, 3, state) {
return nil // Request blocked in Phase 3, short-circuit
}
@@ -83,17 +86,6 @@ func (m *Middleware) isPhaseBlocked(w http.ResponseWriter, r *http.Request, phas
return false
}
// isResponseHeaderPhaseBlocked encapsulates the response header phase handling and blocking check logic.
func (m *Middleware) isResponseHeaderPhaseBlocked(recorder *responseRecorder, r *http.Request, phase int, state *WAFState) bool {
m.handlePhase(recorder, r, phase, state)
if state.Blocked {
m.incrementBlockedRequestsMetric()
recorder.WriteHeader(state.StatusCode)
return true
}
return false
}
// logRequestStart logs the start of WAF evaluation.
func (m *Middleware) logRequestStart(r *http.Request, logID string) {
m.logger.Info("WAF request evaluation started",
@@ -126,15 +118,20 @@ func (m *Middleware) initializeWAFState() *WAFState {
func (m *Middleware) handleResponseBodyPhase(recorder *responseRecorder, r *http.Request, state *WAFState) {
// No need to check if recorder.body is nil here, it's always initialized in NewResponseRecorder
body := recorder.BodyString()
m.logger.Debug("Response body captured for Phase 4 analysis", zap.String("log_id", r.Context().Value(ContextKeyLogId("logID")).(string)))
logID, ok := r.Context().Value(ContextKeyLogId("logID")).(string)
if !ok {
m.logger.Error("Log ID not found in context")
return
}
m.logger.Debug("Response body captured for Phase 4 analysis", zap.String("log_id", logID))
for _, rule := range m.Rules[4] {
if rule.regex.MatchString(body) {
m.processRuleMatch(recorder, r, &rule, body, state)
if state.Blocked {
m.incrementBlockedRequestsMetric()
if m.processRuleMatch(recorder, r, &rule, body, state) {
return
}
}
}
}
@@ -191,9 +188,16 @@ func (m *Middleware) copyResponse(w http.ResponseWriter, recorder *responseRecor
}
w.WriteHeader(recorder.StatusCode())
logID, ok := r.Context().Value(ContextKeyLogId("logID")).(string)
if !ok {
m.logger.Error("Log ID not found in context")
return
}
_, err := w.Write(recorder.body.Bytes()) // Copy body from recorder to original writer
if err != nil {
m.logger.Error("Failed to write recorded response body to client", zap.Error(err), zap.String("log_id", r.Context().Value("logID").(string)))
m.logger.Error("Failed to write recorded response body to client", zap.Error(err), zap.String("log_id", logID))
}
}
@@ -221,7 +225,6 @@ func (m *Middleware) handlePhase(w http.ResponseWriter, r *http.Request, phase i
m.blockRequest(w, r, state, http.StatusForbidden, "country_block", "country_block_rule", r.RemoteAddr,
zap.String("message", "Request blocked by country"),
)
m.logger.Debug("Country blocking phase completed - blocked by country")
return
}
m.logger.Debug("Country blocking phase completed - not blocked")
@@ -235,7 +238,6 @@ func (m *Middleware) handlePhase(w http.ResponseWriter, r *http.Request, phase i
m.blockRequest(w, r, state, http.StatusTooManyRequests, "rate_limit", "rate_limit_rule", r.RemoteAddr,
zap.String("message", "Request blocked by rate limit"),
)
m.logger.Debug("Rate limiting phase completed - blocked by rate limit")
return
}
m.logger.Debug("Rate limiting phase completed - not blocked")
@@ -254,7 +256,6 @@ func (m *Middleware) handlePhase(w http.ResponseWriter, r *http.Request, phase i
m.blockRequest(w, r, state, http.StatusForbidden, "ip_blacklist", "ip_blacklist_rule", firstIP,
zap.String("message", "Request blocked by IP blacklist"),
)
m.logger.Debug("IP blacklist phase completed - blocked")
return
}
} else {
@@ -269,7 +270,6 @@ func (m *Middleware) handlePhase(w http.ResponseWriter, r *http.Request, phase i
m.blockRequest(w, r, state, http.StatusForbidden, "ip_blacklist", "ip_blacklist_rule", r.RemoteAddr,
zap.String("message", "Request blocked by IP blacklist"),
)
m.logger.Debug("IP blacklist phase completed - blocked")
return
}
}
@@ -281,7 +281,6 @@ func (m *Middleware) handlePhase(w http.ResponseWriter, r *http.Request, phase i
zap.String("message", "Request blocked by DNS blacklist"),
zap.String("host", r.Host),
)
m.logger.Debug("DNS blacklist phase completed - blocked")
return
}
@@ -292,6 +291,7 @@ func (m *Middleware) handlePhase(w http.ResponseWriter, r *http.Request, phase i
}
m.logger.Debug("Starting rule evaluation for phase", zap.Int("phase", phase), zap.Int("rule_count", len(rules)))
for _, rule := range rules {
m.logger.Debug("Processing rule", zap.String("rule_id", string(rule.ID)), zap.Int("target_count", len(rule.Targets)))
@@ -338,16 +338,16 @@ func (m *Middleware) handlePhase(w http.ResponseWriter, r *http.Request, phase i
)
if phase == 3 || phase == 4 {
if recorder, ok := w.(*responseRecorder); ok {
if !m.processRuleMatch(recorder, r, &rule, value, state) {
if m.processRuleMatch(recorder, r, &rule, value, state) {
return // Stop processing if the rule match indicates blocking
}
} else {
if !m.processRuleMatch(w, r, &rule, value, state) {
if m.processRuleMatch(w, r, &rule, value, state) {
return // Stop processing if the rule match indicates blocking
}
}
} else {
if !m.processRuleMatch(w, r, &rule, value, state) {
if m.processRuleMatch(w, r, &rule, value, state) {
return // Stop processing if the rule match indicates blocking
}
}

View File

@@ -36,7 +36,7 @@ var sensitiveKeys = []string{
"routing", // Routing number
"mfa", // Multi-factor authentication code
"otp", // One-time password
"code", // Generic code
//"code", // Generic code <------ REMOVED THIS
}
var sensitiveKeysMutex sync.RWMutex // Add mutex for thread safety when modifying

View File

@@ -18,9 +18,6 @@ type RequestValueExtractor struct {
redactSensitiveData bool // Add this field
}
// Define a custom type for context keys
type ContextKeyLogId string
// Extraction Target Constants - Improved Readability and Maintainability
const (
TargetMethod = "METHOD"

View File

@@ -29,14 +29,19 @@ func (m *Middleware) blockRequest(w http.ResponseWriter, r *http.Request, state
w.WriteHeader(resp.StatusCode)
_, err := w.Write([]byte(resp.Body))
if err != nil {
m.logger.Error("Failed to write custom block response body", zap.Error(err), zap.Int("status_code", resp.StatusCode), zap.String("log_id", r.Context().Value("logID").(string)))
logID, ok := r.Context().Value(ContextKeyLogId("logID")).(string)
if !ok {
m.logger.Error("Log ID not found in context, cannot log custom response error")
return
}
m.logger.Error("Failed to write custom block response body", zap.Error(err), zap.Int("status_code", resp.StatusCode), zap.String("log_id", logID))
}
return
}
// Default blocking behavior
logID := uuid.New().String()
if logIDCtx, ok := r.Context().Value("logID").(string); ok {
if logIDCtx, ok := r.Context().Value(ContextKeyLogId("logID")).(string); ok {
logID = logIDCtx
}
@@ -70,7 +75,11 @@ func (m *Middleware) blockRequest(w http.ResponseWriter, r *http.Request, state
w.WriteHeader(statusCode)
} else {
// Debug log when response is already written, including log_id
logID, _ := r.Context().Value("logID").(string)
logID, ok := r.Context().Value(ContextKeyLogId("logID")).(string)
if !ok {
m.logger.Error("Log ID not found in context, cannot log blockRequest when response already written")
return
}
m.logger.Debug("blockRequest called but response already written",
zap.Int("intended_status_code", statusCode),

View File

@@ -25,9 +25,6 @@ var (
_ caddyfile.Unmarshaler = (*Middleware)(nil)
)
// Define a custom type for context keys
type ContextKeyRule string
// Define custom types for rule hits
type RuleID string
type HitCount int