mirror of
https://github.com/fabriziosalmi/caddy-waf.git
synced 2025-12-23 22:27:46 -05:00
fixes for the following issues:
- https://github.com/fabriziosalmi/caddy-waf/issues/41 - https://github.com/fabriziosalmi/caddy-waf/issues/40
This commit is contained in:
@@ -127,6 +127,7 @@ func extractIP(remoteAddr string, logger *zap.Logger) string {
|
||||
return host
|
||||
}
|
||||
|
||||
// LoadIPBlacklistFromFile loads IP addresses from a file into the provided map.
|
||||
// LoadIPBlacklistFromFile loads IP addresses from a file into the provided map.
|
||||
func (bl *BlacklistLoader) LoadIPBlacklistFromFile(path string, ipBlacklist map[string]struct{}) error {
|
||||
bl.logger.Debug("Loading IP blacklist", zap.String("path", path))
|
||||
@@ -141,6 +142,7 @@ func (bl *BlacklistLoader) LoadIPBlacklistFromFile(path string, ipBlacklist map[
|
||||
scanner := bufio.NewScanner(file)
|
||||
validEntries := 0
|
||||
totalLines := 0
|
||||
invalidEntries := 0
|
||||
|
||||
for scanner.Scan() {
|
||||
totalLines++
|
||||
@@ -156,6 +158,9 @@ func (bl *BlacklistLoader) LoadIPBlacklistFromFile(path string, ipBlacklist map[
|
||||
zap.Int("line", totalLines),
|
||||
zap.String("entry", line),
|
||||
)
|
||||
invalidEntries++
|
||||
// If you want the entire load to fail if any single IP entry is invalid, uncomment the line below
|
||||
// return fmt.Errorf("failed to add IP entry %s : %w", line, err)
|
||||
} else {
|
||||
validEntries++
|
||||
}
|
||||
@@ -169,6 +174,7 @@ func (bl *BlacklistLoader) LoadIPBlacklistFromFile(path string, ipBlacklist map[
|
||||
bl.logger.Info("IP blacklist loaded",
|
||||
zap.String("path", path),
|
||||
zap.Int("valid_entries", validEntries),
|
||||
zap.Int("invalid_entries", invalidEntries),
|
||||
zap.Int("total_lines", totalLines),
|
||||
)
|
||||
return nil
|
||||
|
||||
97
caddywaf.go
97
caddywaf.go
@@ -4,7 +4,6 @@ import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"os"
|
||||
"strings"
|
||||
@@ -29,6 +28,7 @@ var (
|
||||
_ caddy.Provisioner = (*Middleware)(nil)
|
||||
_ caddyhttp.MiddlewareHandler = (*Middleware)(nil)
|
||||
_ caddyfile.Unmarshaler = (*Middleware)(nil)
|
||||
_ caddy.Validator = (*Middleware)(nil)
|
||||
)
|
||||
|
||||
// ==================== Initialization and Setup ====================
|
||||
@@ -209,19 +209,18 @@ func (m *Middleware) Provision(ctx caddy.Context) error {
|
||||
}
|
||||
|
||||
// Load IP blacklist
|
||||
m.ipBlacklist = NewCIDRTrie()
|
||||
m.logger.Debug("ipBlacklist initialized in Provision", zap.Bool("isNil", m.ipBlacklist == nil))
|
||||
if m.IPBlacklistFile != "" {
|
||||
err = m.loadIPBlacklistIntoMap(m.IPBlacklistFile, m.ipBlacklist)
|
||||
m.ipBlacklist = NewCIDRTrie()
|
||||
err = m.loadIPBlacklist(m.IPBlacklistFile, m.ipBlacklist)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to load IP blacklist: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Load DNS blacklist
|
||||
m.dnsBlacklist = make(map[string]struct{}) // Changed to map[string]struct{}
|
||||
if m.DNSBlacklistFile != "" {
|
||||
err = m.blacklistLoader.LoadDNSBlacklistFromFile(m.DNSBlacklistFile, m.dnsBlacklist)
|
||||
m.dnsBlacklist = make(map[string]struct{})
|
||||
err = m.loadDNSBlacklist(m.DNSBlacklistFile, m.dnsBlacklist)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to load DNS blacklist: %w", err)
|
||||
}
|
||||
@@ -414,25 +413,21 @@ func (m *Middleware) ReloadConfig() error {
|
||||
defer m.mu.Unlock()
|
||||
|
||||
m.logger.Info("Reloading WAF configuration")
|
||||
|
||||
newIPBlacklist := NewCIDRTrie()
|
||||
if m.IPBlacklistFile != "" {
|
||||
if err := m.loadIPBlacklistIntoMap(m.IPBlacklistFile, newIPBlacklist); err != nil {
|
||||
newIPBlacklist := NewCIDRTrie()
|
||||
if err := m.loadIPBlacklist(m.IPBlacklistFile, newIPBlacklist); err != nil {
|
||||
m.logger.Error("Failed to reload IP blacklist", zap.String("file", m.IPBlacklistFile), zap.Error(err))
|
||||
return fmt.Errorf("failed to reload IP blacklist: %v", err)
|
||||
}
|
||||
} else {
|
||||
m.logger.Debug("No IP blacklist file specified, skipping reload")
|
||||
m.ipBlacklist = newIPBlacklist
|
||||
}
|
||||
|
||||
newDNSBlacklist := make(map[string]struct{})
|
||||
if m.DNSBlacklistFile != "" {
|
||||
if err := m.loadDNSBlacklistIntoMap(m.DNSBlacklistFile, newDNSBlacklist); err != nil {
|
||||
newDNSBlacklist := make(map[string]struct{})
|
||||
if err := m.loadDNSBlacklist(m.DNSBlacklistFile, newDNSBlacklist); err != nil {
|
||||
m.logger.Error("Failed to reload DNS blacklist", zap.String("file", m.DNSBlacklistFile), zap.Error(err))
|
||||
return fmt.Errorf("failed to reload DNS blacklist: %v", err)
|
||||
}
|
||||
} else {
|
||||
m.logger.Debug("No DNS blacklist file specified, skipping reload")
|
||||
m.dnsBlacklist = newDNSBlacklist
|
||||
}
|
||||
|
||||
// Call the external loadRules function
|
||||
@@ -441,62 +436,38 @@ func (m *Middleware) ReloadConfig() error {
|
||||
return fmt.Errorf("failed to reload rules: %v", err)
|
||||
}
|
||||
|
||||
m.ipBlacklist = newIPBlacklist
|
||||
m.dnsBlacklist = newDNSBlacklist
|
||||
|
||||
m.logger.Info("WAF configuration reloaded successfully")
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *Middleware) loadIPBlacklistIntoMap(path string, blacklistMap *CIDRTrie) error {
|
||||
content, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to read IP blacklist file: %v", err)
|
||||
func (m *Middleware) loadIPBlacklist(path string, blacklistMap *CIDRTrie) error {
|
||||
if _, err := os.Stat(path); os.IsNotExist(err) {
|
||||
m.logger.Warn("Skipping IP blacklist load, file does not exist", zap.String("file", path))
|
||||
return nil
|
||||
}
|
||||
|
||||
lines := strings.Split(string(content), "\n")
|
||||
for _, line := range lines {
|
||||
line = strings.TrimSpace(line)
|
||||
if line == "" || strings.HasPrefix(line, "#") {
|
||||
continue
|
||||
}
|
||||
blacklist := make(map[string]struct{})
|
||||
err := m.blacklistLoader.LoadIPBlacklistFromFile(path, blacklist)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to load IP blacklist: %w", err)
|
||||
}
|
||||
|
||||
if !strings.Contains(line, "/") {
|
||||
// Handle single IP addresses
|
||||
ip := net.ParseIP(line)
|
||||
if ip == nil {
|
||||
m.logger.Warn("Skipping invalid IP address format in blacklist", zap.String("address", line))
|
||||
continue
|
||||
}
|
||||
|
||||
if ip.To4() != nil {
|
||||
line = line + "/32"
|
||||
} else {
|
||||
line = line + "/128"
|
||||
}
|
||||
}
|
||||
|
||||
if err := blacklistMap.Insert(line); err != nil {
|
||||
m.logger.Warn("Failed to insert CIDR into trie", zap.String("cidr", line), zap.Error(err))
|
||||
}
|
||||
// Convert the map to CIDRTrie
|
||||
for ip := range blacklist {
|
||||
blacklistMap.Insert(ip)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *Middleware) loadDNSBlacklistIntoMap(path string, blacklistMap map[string]struct{}) error {
|
||||
content, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to read DNS blacklist file: %v", err)
|
||||
func (m *Middleware) loadDNSBlacklist(path string, blacklistMap map[string]struct{}) error {
|
||||
if _, err := os.Stat(path); os.IsNotExist(err) {
|
||||
m.logger.Warn("Skipping DNS blacklist load, file does not exist", zap.String("file", path))
|
||||
return nil
|
||||
}
|
||||
|
||||
lines := strings.Split(string(content), "\n")
|
||||
for _, line := range lines {
|
||||
line = strings.ToLower(strings.TrimSpace(line))
|
||||
if line == "" || strings.HasPrefix(line, "#") {
|
||||
continue
|
||||
}
|
||||
blacklistMap[line] = struct{}{} // Changed to struct{}{}
|
||||
err := m.blacklistLoader.LoadDNSBlacklistFromFile(path, blacklistMap)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to load DNS blacklist: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -570,3 +541,11 @@ func (m *Middleware) UnmarshalCaddyfile(d *caddyfile.Dispenser) error {
|
||||
}
|
||||
return m.configLoader.UnmarshalCaddyfile(d, m)
|
||||
}
|
||||
|
||||
// Validate implements caddy.Validator.
|
||||
func (m *Middleware) Validate() error {
|
||||
if m.logLevel == 0 {
|
||||
m.logLevel = zapcore.InfoLevel // Default log level
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
50
handler.go
50
handler.go
@@ -11,6 +11,9 @@ import (
|
||||
"go.uber.org/zap/zapcore"
|
||||
)
|
||||
|
||||
type ContextKeyLogId string
|
||||
type ContextKeyRule string
|
||||
|
||||
// ServeHTTP implements caddyhttp.Handler.
|
||||
func (m *Middleware) ServeHTTP(w http.ResponseWriter, r *http.Request, next caddyhttp.Handler) error {
|
||||
logID := uuid.New().String()
|
||||
@@ -41,7 +44,7 @@ func (m *Middleware) ServeHTTP(w http.ResponseWriter, r *http.Request, next cadd
|
||||
err := next.ServeHTTP(recorder, r)
|
||||
|
||||
// Phase 3: Response Header analysis
|
||||
if m.isResponseHeaderPhaseBlocked(recorder, r, 3, state) {
|
||||
if m.isPhaseBlocked(recorder, r, 3, state) {
|
||||
return nil // Request blocked in Phase 3, short-circuit
|
||||
}
|
||||
|
||||
@@ -83,17 +86,6 @@ func (m *Middleware) isPhaseBlocked(w http.ResponseWriter, r *http.Request, phas
|
||||
return false
|
||||
}
|
||||
|
||||
// isResponseHeaderPhaseBlocked encapsulates the response header phase handling and blocking check logic.
|
||||
func (m *Middleware) isResponseHeaderPhaseBlocked(recorder *responseRecorder, r *http.Request, phase int, state *WAFState) bool {
|
||||
m.handlePhase(recorder, r, phase, state)
|
||||
if state.Blocked {
|
||||
m.incrementBlockedRequestsMetric()
|
||||
recorder.WriteHeader(state.StatusCode)
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// logRequestStart logs the start of WAF evaluation.
|
||||
func (m *Middleware) logRequestStart(r *http.Request, logID string) {
|
||||
m.logger.Info("WAF request evaluation started",
|
||||
@@ -126,15 +118,20 @@ func (m *Middleware) initializeWAFState() *WAFState {
|
||||
func (m *Middleware) handleResponseBodyPhase(recorder *responseRecorder, r *http.Request, state *WAFState) {
|
||||
// No need to check if recorder.body is nil here, it's always initialized in NewResponseRecorder
|
||||
body := recorder.BodyString()
|
||||
m.logger.Debug("Response body captured for Phase 4 analysis", zap.String("log_id", r.Context().Value(ContextKeyLogId("logID")).(string)))
|
||||
logID, ok := r.Context().Value(ContextKeyLogId("logID")).(string)
|
||||
|
||||
if !ok {
|
||||
m.logger.Error("Log ID not found in context")
|
||||
return
|
||||
}
|
||||
m.logger.Debug("Response body captured for Phase 4 analysis", zap.String("log_id", logID))
|
||||
|
||||
for _, rule := range m.Rules[4] {
|
||||
if rule.regex.MatchString(body) {
|
||||
m.processRuleMatch(recorder, r, &rule, body, state)
|
||||
if state.Blocked {
|
||||
m.incrementBlockedRequestsMetric()
|
||||
if m.processRuleMatch(recorder, r, &rule, body, state) {
|
||||
return
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -191,9 +188,16 @@ func (m *Middleware) copyResponse(w http.ResponseWriter, recorder *responseRecor
|
||||
}
|
||||
w.WriteHeader(recorder.StatusCode())
|
||||
|
||||
logID, ok := r.Context().Value(ContextKeyLogId("logID")).(string)
|
||||
|
||||
if !ok {
|
||||
m.logger.Error("Log ID not found in context")
|
||||
return
|
||||
}
|
||||
|
||||
_, err := w.Write(recorder.body.Bytes()) // Copy body from recorder to original writer
|
||||
if err != nil {
|
||||
m.logger.Error("Failed to write recorded response body to client", zap.Error(err), zap.String("log_id", r.Context().Value("logID").(string)))
|
||||
m.logger.Error("Failed to write recorded response body to client", zap.Error(err), zap.String("log_id", logID))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -221,7 +225,6 @@ func (m *Middleware) handlePhase(w http.ResponseWriter, r *http.Request, phase i
|
||||
m.blockRequest(w, r, state, http.StatusForbidden, "country_block", "country_block_rule", r.RemoteAddr,
|
||||
zap.String("message", "Request blocked by country"),
|
||||
)
|
||||
m.logger.Debug("Country blocking phase completed - blocked by country")
|
||||
return
|
||||
}
|
||||
m.logger.Debug("Country blocking phase completed - not blocked")
|
||||
@@ -235,7 +238,6 @@ func (m *Middleware) handlePhase(w http.ResponseWriter, r *http.Request, phase i
|
||||
m.blockRequest(w, r, state, http.StatusTooManyRequests, "rate_limit", "rate_limit_rule", r.RemoteAddr,
|
||||
zap.String("message", "Request blocked by rate limit"),
|
||||
)
|
||||
m.logger.Debug("Rate limiting phase completed - blocked by rate limit")
|
||||
return
|
||||
}
|
||||
m.logger.Debug("Rate limiting phase completed - not blocked")
|
||||
@@ -254,7 +256,6 @@ func (m *Middleware) handlePhase(w http.ResponseWriter, r *http.Request, phase i
|
||||
m.blockRequest(w, r, state, http.StatusForbidden, "ip_blacklist", "ip_blacklist_rule", firstIP,
|
||||
zap.String("message", "Request blocked by IP blacklist"),
|
||||
)
|
||||
m.logger.Debug("IP blacklist phase completed - blocked")
|
||||
return
|
||||
}
|
||||
} else {
|
||||
@@ -269,7 +270,6 @@ func (m *Middleware) handlePhase(w http.ResponseWriter, r *http.Request, phase i
|
||||
m.blockRequest(w, r, state, http.StatusForbidden, "ip_blacklist", "ip_blacklist_rule", r.RemoteAddr,
|
||||
zap.String("message", "Request blocked by IP blacklist"),
|
||||
)
|
||||
m.logger.Debug("IP blacklist phase completed - blocked")
|
||||
return
|
||||
}
|
||||
}
|
||||
@@ -281,7 +281,6 @@ func (m *Middleware) handlePhase(w http.ResponseWriter, r *http.Request, phase i
|
||||
zap.String("message", "Request blocked by DNS blacklist"),
|
||||
zap.String("host", r.Host),
|
||||
)
|
||||
m.logger.Debug("DNS blacklist phase completed - blocked")
|
||||
return
|
||||
}
|
||||
|
||||
@@ -292,6 +291,7 @@ func (m *Middleware) handlePhase(w http.ResponseWriter, r *http.Request, phase i
|
||||
}
|
||||
|
||||
m.logger.Debug("Starting rule evaluation for phase", zap.Int("phase", phase), zap.Int("rule_count", len(rules)))
|
||||
|
||||
for _, rule := range rules {
|
||||
m.logger.Debug("Processing rule", zap.String("rule_id", string(rule.ID)), zap.Int("target_count", len(rule.Targets)))
|
||||
|
||||
@@ -338,16 +338,16 @@ func (m *Middleware) handlePhase(w http.ResponseWriter, r *http.Request, phase i
|
||||
)
|
||||
if phase == 3 || phase == 4 {
|
||||
if recorder, ok := w.(*responseRecorder); ok {
|
||||
if !m.processRuleMatch(recorder, r, &rule, value, state) {
|
||||
if m.processRuleMatch(recorder, r, &rule, value, state) {
|
||||
return // Stop processing if the rule match indicates blocking
|
||||
}
|
||||
} else {
|
||||
if !m.processRuleMatch(w, r, &rule, value, state) {
|
||||
if m.processRuleMatch(w, r, &rule, value, state) {
|
||||
return // Stop processing if the rule match indicates blocking
|
||||
}
|
||||
}
|
||||
} else {
|
||||
if !m.processRuleMatch(w, r, &rule, value, state) {
|
||||
if m.processRuleMatch(w, r, &rule, value, state) {
|
||||
return // Stop processing if the rule match indicates blocking
|
||||
}
|
||||
}
|
||||
|
||||
@@ -36,7 +36,7 @@ var sensitiveKeys = []string{
|
||||
"routing", // Routing number
|
||||
"mfa", // Multi-factor authentication code
|
||||
"otp", // One-time password
|
||||
"code", // Generic code
|
||||
//"code", // Generic code <------ REMOVED THIS
|
||||
}
|
||||
|
||||
var sensitiveKeysMutex sync.RWMutex // Add mutex for thread safety when modifying
|
||||
|
||||
@@ -18,9 +18,6 @@ type RequestValueExtractor struct {
|
||||
redactSensitiveData bool // Add this field
|
||||
}
|
||||
|
||||
// Define a custom type for context keys
|
||||
type ContextKeyLogId string
|
||||
|
||||
// Extraction Target Constants - Improved Readability and Maintainability
|
||||
const (
|
||||
TargetMethod = "METHOD"
|
||||
|
||||
15
response.go
15
response.go
@@ -29,14 +29,19 @@ func (m *Middleware) blockRequest(w http.ResponseWriter, r *http.Request, state
|
||||
w.WriteHeader(resp.StatusCode)
|
||||
_, err := w.Write([]byte(resp.Body))
|
||||
if err != nil {
|
||||
m.logger.Error("Failed to write custom block response body", zap.Error(err), zap.Int("status_code", resp.StatusCode), zap.String("log_id", r.Context().Value("logID").(string)))
|
||||
logID, ok := r.Context().Value(ContextKeyLogId("logID")).(string)
|
||||
if !ok {
|
||||
m.logger.Error("Log ID not found in context, cannot log custom response error")
|
||||
return
|
||||
}
|
||||
m.logger.Error("Failed to write custom block response body", zap.Error(err), zap.Int("status_code", resp.StatusCode), zap.String("log_id", logID))
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// Default blocking behavior
|
||||
logID := uuid.New().String()
|
||||
if logIDCtx, ok := r.Context().Value("logID").(string); ok {
|
||||
if logIDCtx, ok := r.Context().Value(ContextKeyLogId("logID")).(string); ok {
|
||||
logID = logIDCtx
|
||||
}
|
||||
|
||||
@@ -70,7 +75,11 @@ func (m *Middleware) blockRequest(w http.ResponseWriter, r *http.Request, state
|
||||
w.WriteHeader(statusCode)
|
||||
} else {
|
||||
// Debug log when response is already written, including log_id
|
||||
logID, _ := r.Context().Value("logID").(string)
|
||||
logID, ok := r.Context().Value(ContextKeyLogId("logID")).(string)
|
||||
if !ok {
|
||||
m.logger.Error("Log ID not found in context, cannot log blockRequest when response already written")
|
||||
return
|
||||
}
|
||||
|
||||
m.logger.Debug("blockRequest called but response already written",
|
||||
zap.Int("intended_status_code", statusCode),
|
||||
|
||||
Reference in New Issue
Block a user