Minor code improvements.

Author: fabriziosalmi
Date:   2025-01-22 13:12:26 +01:00
Parent: 89e2f269c1
Commit: f4bd92c5a0
11 changed files with 1313 additions and 960 deletions

.gitignore (vendored): 1 line changed

@@ -13,3 +13,4 @@ testdata/rules.json
log.json
validation.log
caddy-waf.DS_Store
vendor


@@ -1,6 +1,7 @@
package caddywaf
import (
"bufio" // Optimized reading
"fmt"
"net"
"os"
@@ -25,18 +26,22 @@ func (bl *BlacklistLoader) LoadDNSBlacklistFromFile(path string, dnsBlacklist ma
if bl.logger == nil {
bl.logger = zap.NewNop()
}
bl.logger.Debug("Loading DNS blacklist from file", zap.String("file", path))
bl.logger.Debug("Loading DNS blacklist", zap.String("path", path)) // Improved log message
content, err := os.ReadFile(path)
file, err := os.Open(path)
if err != nil {
bl.logger.Warn("Failed to read DNS blacklist file", zap.String("file", path), zap.Error(err))
return fmt.Errorf("failed to read DNS blacklist file: %w", err)
bl.logger.Warn("Failed to open DNS blacklist file", zap.String("path", path), zap.Error(err)) // Path instead of file for consistency
return fmt.Errorf("failed to open DNS blacklist file: %w", err) // More accurate error message
}
defer file.Close()
lines := strings.Split(string(content), "\n")
scanner := bufio.NewScanner(file)
validEntries := 0
totalLines := 0 // Initialize totalLines
for _, line := range lines {
for scanner.Scan() {
totalLines++
line := scanner.Text()
line = strings.ToLower(strings.TrimSpace(line))
if line == "" || strings.HasPrefix(line, "#") {
continue // Skip empty lines and comments
@@ -45,10 +50,15 @@ func (bl *BlacklistLoader) LoadDNSBlacklistFromFile(path string, dnsBlacklist ma
validEntries++
}
bl.logger.Info("DNS blacklist loaded successfully",
zap.String("file", path),
if err := scanner.Err(); err != nil {
bl.logger.Error("Error reading DNS blacklist file", zap.String("path", path), zap.Error(err)) // More specific error log
return fmt.Errorf("error reading DNS blacklist file: %w", err)
}
bl.logger.Info("DNS blacklist loaded", // Improved log message
zap.String("path", path),
zap.Int("valid_entries", validEntries),
zap.Int("total_lines", len(lines)),
zap.Int("total_lines", totalLines), // Use totalLines which is correctly counted
)
return nil
}
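The loader now streams the file with bufio.Scanner instead of reading it whole with os.ReadFile, and counts lines as it goes. A minimal standalone sketch of the same pattern, following the DNS loader above (the file name and helper are illustrative, not from this commit):

package main

import (
	"bufio"
	"fmt"
	"log"
	"os"
	"strings"
)

func loadLines(path string) (map[string]struct{}, int, error) {
	entries := make(map[string]struct{})
	file, err := os.Open(path)
	if err != nil {
		return nil, 0, fmt.Errorf("failed to open %s: %w", path, err)
	}
	defer file.Close()

	scanner := bufio.NewScanner(file)
	totalLines := 0
	for scanner.Scan() {
		totalLines++
		line := strings.ToLower(strings.TrimSpace(scanner.Text()))
		if line == "" || strings.HasPrefix(line, "#") {
			continue // skip blanks and comments, as the loader above does
		}
		entries[line] = struct{}{}
	}
	// Scanner errors only surface here, which is why the commit adds this check.
	if err := scanner.Err(); err != nil {
		return nil, totalLines, fmt.Errorf("error reading %s: %w", path, err)
	}
	return entries, totalLines, nil
}

func main() {
	entries, total, err := loadLines("blacklist.txt") // hypothetical input file
	if err != nil {
		log.Fatal(err)
	}
	fmt.Printf("loaded %d entries from %d lines\n", len(entries), total)
}

One caveat with this approach: bufio.Scanner fails on lines longer than its buffer (64 KiB by default), so if blacklist entries could ever exceed that, scanner.Buffer can raise the limit.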
@@ -77,14 +87,14 @@ func (m *Middleware) isDNSBlacklisted(host string) bool {
defer m.mu.RUnlock()
if _, exists := m.dnsBlacklist[normalizedHost]; exists {
m.logger.Info("Host is blacklisted",
m.logger.Debug("DNS blacklist hit", // More concise log message, debug level
zap.String("host", host),
zap.String("blacklisted_domain", normalizedHost),
)
return true
}
m.logger.Debug("Host is not blacklisted", zap.String("host", host))
m.logger.Debug("DNS blacklist miss", zap.String("host", host)) // More concise log message, debug level
return false
}
@@ -92,9 +102,9 @@ func (m *Middleware) isDNSBlacklisted(host string) bool {
func extractIP(remoteAddr string, logger *zap.Logger) string {
host, _, err := net.SplitHostPort(remoteAddr)
if err != nil {
logger.Debug("Failed to extract IP from remote address, using full address",
logger.Debug("Using full remote address as IP", // More descriptive debug log
zap.String("remoteAddr", remoteAddr),
zap.Error(err),
zap.Error(err), // Keep error for debugging
)
return remoteAddr // Assume the input is already an IP address
}
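For context on the fallback above: net.SplitHostPort returns an error whenever no port is present, so returning the raw address unchanged is the sensible default. A small sketch of the behavior:

package main

import (
	"fmt"
	"net"
)

// extractIP mirrors the fallback logic above: strip the port if present,
// otherwise assume the input is already a bare IP.
func extractIP(remoteAddr string) string {
	host, _, err := net.SplitHostPort(remoteAddr)
	if err != nil {
		return remoteAddr
	}
	return host
}

func main() {
	fmt.Println(extractIP("192.0.2.1:8080"))    // 192.0.2.1
	fmt.Println(extractIP("192.0.2.1"))         // no port: returned unchanged
	fmt.Println(extractIP("[2001:db8::1]:443")) // 2001:db8::1
}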
@@ -106,18 +116,22 @@ func (bl *BlacklistLoader) LoadIPBlacklistFromFile(path string, ipBlacklist map[
if bl.logger == nil {
bl.logger = zap.NewNop()
}
bl.logger.Debug("Loading IP blacklist from file", zap.String("file", path))
bl.logger.Debug("Loading IP blacklist", zap.String("path", path)) // Improved log message
content, err := os.ReadFile(path)
file, err := os.Open(path)
if err != nil {
bl.logger.Warn("Failed to read IP blacklist file", zap.String("file", path), zap.Error(err))
return fmt.Errorf("failed to read IP blacklist file: %w", err)
bl.logger.Warn("Failed to open IP blacklist file", zap.String("path", path), zap.Error(err)) // Path instead of file for consistency
return fmt.Errorf("failed to open IP blacklist file: %w", err) // More accurate error message
}
defer file.Close()
lines := strings.Split(string(content), "\n")
scanner := bufio.NewScanner(file)
validEntries := 0
totalLines := 0 // Initialize totalLines
for i, line := range lines {
for scanner.Scan() {
totalLines++
line := scanner.Text()
line = strings.TrimSpace(line)
if line == "" || strings.HasPrefix(line, "#") {
continue // Skip empty lines and comments
@@ -127,7 +141,7 @@ func (bl *BlacklistLoader) LoadIPBlacklistFromFile(path string, ipBlacklist map[
// Valid CIDR range
ipBlacklist[line] = struct{}{}
validEntries++
bl.logger.Debug("Added CIDR range to blacklist", zap.String("cidr", line))
bl.logger.Debug("Added CIDR to IP blacklist", zap.String("cidr", line)) // More specific debug log
continue
}
@@ -135,21 +149,26 @@ func (bl *BlacklistLoader) LoadIPBlacklistFromFile(path string, ipBlacklist map[
// Valid IP address
ipBlacklist[line] = struct{}{}
validEntries++
bl.logger.Debug("Added IP to blacklist", zap.String("ip", line))
bl.logger.Debug("Added IP to IP blacklist", zap.String("ip", line)) // More specific debug log
continue
}
bl.logger.Warn("Invalid IP or CIDR range in blacklist file, skipping",
zap.String("file", path),
zap.Int("line", i+1),
bl.logger.Warn("Invalid IP/CIDR entry in blacklist file", // More concise warning message
zap.String("path", path),
zap.Int("line", totalLines), // Use totalLines which is correctly counted
zap.String("entry", line),
)
}
bl.logger.Info("IP blacklist loaded successfully",
zap.String("file", path),
if scanner.Err() != nil {
bl.logger.Error("Error reading IP blacklist file", zap.String("path", path), zap.Error(scanner.Err())) // More specific error log
return fmt.Errorf("error reading IP blacklist file: %w", scanner.Err())
}
bl.logger.Info("IP blacklist loaded", // Improved log message
zap.String("path", path),
zap.Int("valid_entries", validEntries),
zap.Int("total_lines", len(lines)),
zap.Int("total_lines", totalLines), // Use totalLines which is correctly counted
)
return nil
}
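The IP loader accepts both CIDR ranges and single addresses; the classification rests on net.ParseCIDR and net.ParseIP, roughly as in this standalone sketch (sample entries are illustrative):

package main

import (
	"fmt"
	"net"
)

func classify(entry string) string {
	if _, _, err := net.ParseCIDR(entry); err == nil {
		return "CIDR range"
	}
	if net.ParseIP(entry) != nil {
		return "single IP"
	}
	return "invalid"
}

func main() {
	for _, e := range []string{"10.0.0.0/8", "192.0.2.7", "not-an-ip"} {
		fmt.Printf("%-12s -> %s\n", e, classify(e))
	}
}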


@@ -84,7 +84,7 @@ func TestLoadDNSBlacklistFromFile_InvalidFile(t *testing.T) {
dnsBlacklist := make(map[string]struct{})
err := bl.LoadDNSBlacklistFromFile("nonexistent.txt", dnsBlacklist)
assert.Error(t, err)
assert.Contains(t, err.Error(), "failed to read DNS blacklist file")
assert.Contains(t, err.Error(), "failed to open DNS blacklist file") // Updated error message
}
// TestLoadIPBlacklistFromFile tests loading IP addresses and CIDR ranges from a file.
@@ -136,7 +136,7 @@ func TestLoadIPBlacklistFromFile_InvalidFile(t *testing.T) {
ipBlacklist := make(map[string]struct{})
err := bl.LoadIPBlacklistFromFile("nonexistent.txt", ipBlacklist)
assert.Error(t, err)
assert.Contains(t, err.Error(), "failed to read IP blacklist file")
assert.Contains(t, err.Error(), "failed to open IP blacklist file") // Updated error message
}
// TestIsDNSBlacklisted tests checking if a host is blacklisted.
@@ -400,9 +400,10 @@ func TestParseCountryBlock(t *testing.T) {
t.Fatal("Failed to advance to the first directive")
}
err := cl.parseCountryBlock(d, m, true)
handler := cl.parseCountryBlockDirective(true) // Get the directive handler
err := handler(d, m) // Execute the handler
if err != nil {
t.Fatalf("parseCountryBlock failed: %v", err)
t.Fatalf("parseCountryBlockDirective failed: %v", err)
}
if !m.CountryBlock.Enabled {
@@ -459,14 +460,32 @@ func TestParseBlacklistFile(t *testing.T) {
t.Fatal("Failed to advance to the first directive")
}
err := cl.parseBlacklistFile(d, m, true)
handler := cl.parseBlacklistFileDirective(true) // Get the directive handler for IP blacklist
err := handler(d, m) // Execute the handler
if err != nil {
t.Fatalf("parseBlacklistFile failed: %v", err)
t.Fatalf("parseBlacklistFileDirective failed: %v", err)
}
if m.IPBlacklistFile != tmpFile {
t.Errorf("Expected IP blacklist file to be '%s', got '%s'", tmpFile, m.IPBlacklistFile)
}
// Test dns_blacklist_file
tmpDNSFile := filepath.Join(tmpDir, "dns_blacklist.txt")
d = caddyfile.NewTestDispenser(`
dns_blacklist_file ` + tmpDNSFile + `
`)
if !d.Next() {
t.Fatal("Failed to advance to the dns_blacklist_file directive")
}
handler = cl.parseBlacklistFileDirective(false) // Get handler for DNS blacklist
err = handler(d, m)
if err != nil {
t.Fatalf("parseBlacklistFileDirective for dns failed: %v", err)
}
if m.DNSBlacklistFile != tmpDNSFile {
t.Errorf("Expected DNS blacklist file to be '%s', got '%s'", tmpDNSFile, m.DNSBlacklistFile)
}
}
// TestParseAnomalyThreshold tests the parseAnomalyThreshold function.
@@ -1008,10 +1027,10 @@ func TestValidateRule(t *testing.T) {
func TestProcessRuleMatch(t *testing.T) {
logger := newMockLogger()
middleware := &Middleware{
logger: logger.Logger, // Use the embedded *zap.Logger
logger: logger.Logger,
AnomalyThreshold: 10,
ruleHits: sync.Map{}, // Use sync.Map directly
muMetrics: sync.RWMutex{}, // Use sync.RWMutex directly
ruleHits: sync.Map{},
muMetrics: sync.RWMutex{},
}
rule := &Rule{
@@ -1028,12 +1047,22 @@ func TestProcessRuleMatch(t *testing.T) {
}
req := httptest.NewRequest("GET", "http://example.com", nil)
w := httptest.NewRecorder()
// Create a context and add logID to it
ctx := context.Background()
logID := "test-log-id" // Or generate a UUID if needed: uuid.New().String()
ctx = context.WithValue(ctx, ContextKeyLogId("logID"), logID)
// Create a new request with the context
req = req.WithContext(ctx)
// Create a ResponseRecorder
w := NewResponseRecorder(httptest.NewRecorder())
// Test blocking rule
shouldContinue := middleware.processRuleMatch(w, req, rule, "value", state)
assert.False(t, shouldContinue)
assert.Equal(t, http.StatusForbidden, w.Code)
assert.Equal(t, http.StatusForbidden, w.StatusCode())
assert.True(t, state.Blocked)
assert.Equal(t, 5, state.TotalScore)
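Several tests now attach a log ID to the request context via the middleware's ContextKeyLogId type. A dedicated key type is the standard way to avoid collisions between packages storing values in the same context; a minimal sketch of why the type matters (only the key type name is taken from the diff):

package main

import (
	"context"
	"fmt"
)

// ContextKeyLogId is a dedicated key type so context lookups cannot collide
// with plain string keys set by unrelated packages.
type ContextKeyLogId string

func main() {
	ctx := context.WithValue(context.Background(), ContextKeyLogId("logID"), "test-log-id")

	// A plain string key does NOT retrieve the value stored under the typed key.
	fmt.Println(ctx.Value("logID"))                  // <nil>
	fmt.Println(ctx.Value(ContextKeyLogId("logID"))) // test-log-id
}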
@@ -1043,7 +1072,8 @@ func TestProcessRuleMatch(t *testing.T) {
TotalScore: 0,
ResponseWritten: false,
}
w = httptest.NewRecorder()
// Re-create a ResponseRecorder for the second test
w = NewResponseRecorder(httptest.NewRecorder())
shouldContinue = middleware.processRuleMatch(w, req, rule, "value", state)
assert.True(t, shouldContinue)
assert.False(t, state.Blocked)
@@ -1122,6 +1152,13 @@ func TestProcessRuleMatch_HighScore(t *testing.T) {
}
req := httptest.NewRequest("GET", "http://example.com", nil)
// Create a context and add logID to it - FIX: ADD CONTEXT HERE
ctx := context.Background()
logID := "test-log-id-highscore" // Unique log ID for this test
ctx = context.WithValue(ctx, ContextKeyLogId("logID"), logID)
req = req.WithContext(ctx) // Create new request with context
w := httptest.NewRecorder()
// Test blocking rule with high score
@@ -1522,6 +1559,13 @@ func TestHandlePhase_Phase2_NiktoUserAgent(t *testing.T) {
req := httptest.NewRequest("GET", "http://example.com", nil)
req.Header.Set("User-Agent", "nikto")
// Create a context and add logID to it - FIX: ADD CONTEXT HERE
ctx := context.Background()
logID := "test-log-id-nikto" // Unique log ID for this test
ctx = context.WithValue(ctx, ContextKeyLogId("logID"), logID)
req = req.WithContext(ctx) // Create new request with context
w := httptest.NewRecorder()
state := &WAFState{}
@@ -1536,14 +1580,6 @@ func TestHandlePhase_Phase2_NiktoUserAgent(t *testing.T) {
assert.Contains(t, w.Body.String(), "Access Denied", "Response body should contain 'Access Denied'")
}
// func TestBlockedRequestPhase1_IPBlacklist(t *testing.T) {
// func TestBlockedRequestPhase1_IPBlacklist(t *testing.T) {
// func TestBlockedRequestPhase1_HeaderRegex(t *testing.T) {
// func TestBlockedRequestPhase2_BodyRegex(t *testing.T) {
// func TestBlockedRequestPhase3_ResponseHeaderRegex(t *testing.T) {
// func TestBlockedRequestPhase4_ResponseBodyRegex(t *testing.T) {
func TestBlockedRequestPhase1_HeaderRegex(t *testing.T) {
logger := zap.NewNop()
middleware := &Middleware{
@@ -1575,6 +1611,13 @@ func TestBlockedRequestPhase1_HeaderRegex(t *testing.T) {
req := httptest.NewRequest("GET", "http://example.com", nil)
req.Header.Set("X-Custom-Header", "this-is-a-bad-header") // Simulate a request with bad header
// Create a context and add logID to it - FIX: ADD CONTEXT HERE
ctx := context.Background()
logID := "test-log-id-headerregex" // Unique log ID for this test
ctx = context.WithValue(ctx, ContextKeyLogId("logID"), logID)
req = req.WithContext(ctx) // Create new request with context
w := httptest.NewRecorder()
state := &WAFState{}
@@ -1620,6 +1663,13 @@ func TestBlockedRequestPhase1_HeaderRegex_SpecificValue(t *testing.T) {
req := httptest.NewRequest("GET", "http://example.com", nil)
req.Header.Set("X-Specific-Header", "specific-value") // Simulate a request with the specific header
// Create a context and add logID to it - FIX: ADD CONTEXT HERE
ctx := context.Background()
logID := "test-log-id-headerspecificvalue" // Unique log ID for this test
ctx = context.WithValue(ctx, ContextKeyLogId("logID"), logID)
req = req.WithContext(ctx) // Create new request with context
w := httptest.NewRecorder()
state := &WAFState{}
@@ -1666,6 +1716,13 @@ func TestBlockedRequestPhase1_HeaderRegex_CommaSeparatedTargets(t *testing.T) {
req := httptest.NewRequest("GET", "http://example.com", nil)
req.Header.Set("X-Custom-Header1", "good-value")
req.Header.Set("X-Custom-Header2", "bad-value") // Simulate a request with bad value in one of the headers
// Create a context and add logID to it - FIX: ADD CONTEXT HERE
ctx := context.Background()
logID := "test-log-id-headercomma" // Unique log ID for this test
ctx = context.WithValue(ctx, ContextKeyLogId("logID"), logID)
req = req.WithContext(ctx) // Create new request with context
w := httptest.NewRecorder()
state := &WAFState{}
@@ -1712,6 +1769,12 @@ func TestBlockedRequestPhase1_CombinedConditions(t *testing.T) {
req := httptest.NewRequest("GET", "http://bad-host.com", nil)
req.Header.Set("User-Agent", "good-user")
// Create a context and add logID to it
ctx := context.Background()
logID := "test-log-id-combined"
ctx = context.WithValue(ctx, ContextKeyLogId("logID"), logID)
req = req.WithContext(ctx)
w := httptest.NewRecorder()
state := &WAFState{}
@@ -1758,6 +1821,12 @@ func TestBlockedRequestPhase1_NoMatch(t *testing.T) {
req := httptest.NewRequest("GET", "http://example.com", nil)
req.Header.Set("User-Agent", "good-user")
// Create a context and add logID to it
ctx := context.Background()
logID := "test-log-id-nomatch"
ctx = context.WithValue(ctx, ContextKeyLogId("logID"), logID)
req = req.WithContext(ctx)
w := httptest.NewRecorder()
state := &WAFState{}
@@ -1802,6 +1871,13 @@ func TestBlockedRequestPhase1_HeaderRegex_EmptyHeader(t *testing.T) {
}
req := httptest.NewRequest("GET", "http://example.com", nil)
// Create a context and add logID to it
ctx := context.Background()
logID := "test-log-id-headerempty"
ctx = context.WithValue(ctx, ContextKeyLogId("logID"), logID)
req = req.WithContext(ctx)
w := httptest.NewRecorder()
state := &WAFState{}
@@ -1845,6 +1921,13 @@ func TestBlockedRequestPhase1_HeaderRegex_MissingHeader(t *testing.T) {
}
req := httptest.NewRequest("GET", "http://example.com", nil) // Header not set
// Create a context and add logID to it
ctx := context.Background()
logID := "test-log-id-headermissing"
ctx = context.WithValue(ctx, ContextKeyLogId("logID"), logID)
req = req.WithContext(ctx)
w := httptest.NewRecorder()
state := &WAFState{}
@@ -1891,6 +1974,13 @@ func TestBlockedRequestPhase1_HeaderRegex_ComplexPattern(t *testing.T) {
req := httptest.NewRequest("GET", "http://example.com", nil)
req.Header.Set("X-Email-Header", "test@example.com") // Simulate a request with a valid email
// Create a context and add logID to it
ctx := context.Background()
logID := "test-log-id-headercomplex"
ctx = context.WithValue(ctx, ContextKeyLogId("logID"), logID)
req = req.WithContext(ctx)
w := httptest.NewRecorder()
state := &WAFState{}
@@ -1937,6 +2027,13 @@ func TestBlockedRequestPhase1_MultiTargetMatch(t *testing.T) {
req := httptest.NewRequest("GET", "http://example.com", nil)
req.Header.Set("X-Custom-Header", "good-header")
req.Header.Set("User-Agent", "bad-user-agent")
// Create a context and add logID to it - FIX: ADD CONTEXT HERE
ctx := context.Background()
logID := "test-log-id-multimatch" // Unique log ID for this test
ctx = context.WithValue(ctx, ContextKeyLogId("logID"), logID)
req = req.WithContext(ctx)
w := httptest.NewRecorder()
state := &WAFState{}
@@ -1982,6 +2079,13 @@ func TestBlockedRequestPhase1_MultiTargetNoMatch(t *testing.T) {
req := httptest.NewRequest("GET", "http://example.com", nil)
req.Header.Set("X-Custom-Header", "good-header")
req.Header.Set("User-Agent", "good-user-agent")
// Create a context and add logID to it - FIX: ADD CONTEXT HERE
ctx := context.Background()
logID := "test-log-id-multinomatch" // Unique log ID for this test
ctx = context.WithValue(ctx, ContextKeyLogId("logID"), logID)
req = req.WithContext(ctx) // Create new request with context
w := httptest.NewRecorder()
state := &WAFState{}
@@ -2026,6 +2130,13 @@ func TestBlockedRequestPhase1_URLParameterRegex_NoMatch(t *testing.T) {
}
req := httptest.NewRequest("GET", "http://example.com?param1=good-param-value¶m2=good-value", nil)
// Create a context and add logID to it - FIX: ADD CONTEXT HERE
ctx := context.Background()
logID := "test-log-id-urlparamnomatch" // Unique log ID for this test
ctx = context.WithValue(ctx, ContextKeyLogId("logID"), logID)
req = req.WithContext(ctx)
w := httptest.NewRecorder()
state := &WAFState{}
@@ -2080,6 +2191,12 @@ func TestBlockedRequestPhase1_MultipleRules(t *testing.T) {
req := httptest.NewRequest("GET", "http://bad-host.com", nil)
req.Header.Set("User-Agent", "bad-user") // Simulate a request with a bad user agent
// Create a context and add logID to it - FIX: ADD CONTEXT HERE
ctx := context.Background()
logID := "test-log-id-multiplerules" // Unique log ID for this test
ctx = context.WithValue(ctx, ContextKeyLogId("logID"), logID)
req = req.WithContext(ctx)
w := httptest.NewRecorder()
state := &WAFState{}
@@ -2095,6 +2212,13 @@ func TestBlockedRequestPhase1_MultipleRules(t *testing.T) {
req2 := httptest.NewRequest("GET", "http://good-host.com", nil)
req2.Header.Set("User-Agent", "bad-user") // Simulate a request with a bad user agent
// Create a context and add logID to it - FIX: ADD CONTEXT HERE for req2 as well!
ctx2 := context.Background() // New context for the second request!
logID2 := "test-log-id-multiplerules2"
ctx2 = context.WithValue(ctx2, ContextKeyLogId("logID"), logID2)
req2 = req2.WithContext(ctx2)
w2 := httptest.NewRecorder()
state2 := &WAFState{}
@@ -2107,7 +2231,6 @@ func TestBlockedRequestPhase1_MultipleRules(t *testing.T) {
assert.True(t, state2.Blocked, "Request should be blocked")
assert.Equal(t, http.StatusForbidden, w2.Code, "Expected status code 403")
assert.Contains(t, w2.Body.String(), "Blocked by Multiple Rules", "Response body should contain 'Blocked by Multiple Rules'")
}
func TestBlockedRequestPhase2_BodyRegex(t *testing.T) {
@@ -2147,6 +2270,13 @@ func TestBlockedRequestPhase2_BodyRegex(t *testing.T) {
}(), // Simulate a request with bad body
)
req.Header.Set("Content-Type", "text/plain")
// Create a context and add logID to it - FIX: ADD CONTEXT HERE
ctx := context.Background()
logID := "test-log-id-bodyregex" // Unique log ID for this test
ctx = context.WithValue(ctx, ContextKeyLogId("logID"), logID)
req = req.WithContext(ctx)
w := httptest.NewRecorder()
state := &WAFState{}
@@ -2161,8 +2291,6 @@ func TestBlockedRequestPhase2_BodyRegex(t *testing.T) {
assert.Contains(t, w.Body.String(), "Blocked by Body Regex", "Response body should contain 'Blocked by Body Regex'")
}
// new
func TestBlockedRequestPhase2_BodyRegex_JSON(t *testing.T) {
logger := zap.NewNop()
middleware := &Middleware{
@@ -2200,6 +2328,13 @@ func TestBlockedRequestPhase2_BodyRegex_JSON(t *testing.T) {
}(), // Simulate a request with JSON body
)
req.Header.Set("Content-Type", "application/json")
// Create a context and add logID to it - FIX: ADD CONTEXT HERE
ctx := context.Background()
logID := "test-log-id-bodyregexjson" // Unique log ID for this test
ctx = context.WithValue(ctx, ContextKeyLogId("logID"), logID)
req = req.WithContext(ctx)
w := httptest.NewRecorder()
state := &WAFState{}
@@ -2247,6 +2382,13 @@ func TestBlockedRequestPhase2_BodyRegex_FormURLEncoded(t *testing.T) {
strings.NewReader("param1=value1&secret=badvalue&param2=value2"),
)
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
// Create a context and add logID to it - FIX: ADD CONTEXT HERE
ctx := context.Background()
logID := "test-log-id-bodyregexform"
ctx = context.WithValue(ctx, ContextKeyLogId("logID"), logID)
req = req.WithContext(ctx)
w := httptest.NewRecorder()
state := &WAFState{}
@@ -2298,6 +2440,13 @@ func TestBlockedRequestPhase2_BodyRegex_SpecificPattern(t *testing.T) {
}(),
)
req.Header.Set("Content-Type", "text/plain") // Setting content type
// Create a context and add logID to it - FIX: ADD CONTEXT HERE
ctx := context.Background()
logID := "test-log-id-bodyregexspecific"
ctx = context.WithValue(ctx, ContextKeyLogId("logID"), logID)
req = req.WithContext(ctx)
w := httptest.NewRecorder()
state := &WAFState{}
@@ -2350,6 +2499,12 @@ func TestBlockedRequestPhase2_BodyRegex_NoMatch(t *testing.T) {
)
req.Header.Set("Content-Type", "text/plain")
// Create a context and add logID to it - FIX: ADD CONTEXT HERE
ctx := context.Background()
logID := "test-log-id-bodyregexnomatch"
ctx = context.WithValue(ctx, ContextKeyLogId("logID"), logID)
req = req.WithContext(ctx)
w := httptest.NewRecorder()
state := &WAFState{}
@@ -2364,8 +2519,6 @@ func TestBlockedRequestPhase2_BodyRegex_NoMatch(t *testing.T) {
assert.Empty(t, w.Body.String(), "Response body should be empty")
}
////////
func TestBlockedRequestPhase2_BodyRegex_NoMatch_MultipartForm(t *testing.T) {
logger := zap.NewNop()
middleware := &Middleware{
@@ -2412,6 +2565,13 @@ func TestBlockedRequestPhase2_BodyRegex_NoMatch_MultipartForm(t *testing.T) {
req := httptest.NewRequest("POST", "http://example.com", body)
req.Header.Set("Content-Type", writer.FormDataContentType())
// Create a context and add logID to it - FIX: ADD CONTEXT HERE
ctx := context.Background()
logID := "test-log-id-bodyregexmultipartnomatch"
ctx = context.WithValue(ctx, ContextKeyLogId("logID"), logID)
req = req.WithContext(ctx)
w := httptest.NewRecorder()
state := &WAFState{}
@@ -2425,6 +2585,7 @@ func TestBlockedRequestPhase2_BodyRegex_NoMatch_MultipartForm(t *testing.T) {
assert.Equal(t, http.StatusOK, w.Code, "Expected status code 200")
assert.Empty(t, w.Body.String(), "Response body should be empty")
}
func TestBlockedRequestPhase2_BodyRegex_NoBody(t *testing.T) {
logger := zap.NewNop()
middleware := &Middleware{
@@ -2719,6 +2880,13 @@ func TestBlockedRequestPhase1_HeaderRegex_CaseInsensitive(t *testing.T) {
req := httptest.NewRequest("GET", "http://example.com", nil)
req.Header.Set("X-Custom-Header", "bAd-VaLuE") // Test with mixed-case header value
// Create a context and add logID to it - FIX: ADD CONTEXT HERE
ctx := context.Background()
logID := "test-log-id-headercaseinsensitive" // Unique log ID for this test
ctx = context.WithValue(ctx, ContextKeyLogId("logID"), logID)
req = req.WithContext(ctx)
w := httptest.NewRecorder()
state := &WAFState{}
@@ -2765,6 +2933,13 @@ func TestBlockedRequestPhase1_HeaderRegex_MultipleMatchingHeaders(t *testing.T)
req := httptest.NewRequest("GET", "http://example.com", nil)
req.Header.Set("X-Custom-Header1", "bad-value")
req.Header.Set("X-Custom-Header2", "bad-value") // Both headers have a "bad" value
// Create a context and add logID to it - FIX: ADD CONTEXT HERE for req
ctx := context.Background()
logID := "test-log-id-headermultimatch" // Unique log ID for this test
ctx = context.WithValue(ctx, ContextKeyLogId("logID"), logID)
req = req.WithContext(ctx)
w := httptest.NewRecorder()
state := &WAFState{}
@@ -2781,6 +2956,13 @@ func TestBlockedRequestPhase1_HeaderRegex_MultipleMatchingHeaders(t *testing.T)
req2 := httptest.NewRequest("GET", "http://example.com", nil)
req2.Header.Set("X-Custom-Header1", "good-value")
req2.Header.Set("X-Custom-Header2", "bad-value") // One header has a "bad" value
// Create a context and add logID to it - FIX: ADD CONTEXT HERE for req2
ctx2 := context.Background()
logID2 := "test-log-id-headermultimatch2" // Unique log ID for this test
ctx2 = context.WithValue(ctx2, ContextKeyLogId("logID"), logID2)
req2 = req2.WithContext(ctx2)
w2 := httptest.NewRecorder()
state2 := &WAFState{}
@@ -2797,6 +2979,13 @@ func TestBlockedRequestPhase1_HeaderRegex_MultipleMatchingHeaders(t *testing.T)
req3 := httptest.NewRequest("GET", "http://example.com", nil)
req3.Header.Set("X-Custom-Header1", "good-value")
req3.Header.Set("X-Custom-Header2", "good-value") // None headers have a "bad" value
// Create a context and add logID to it - FIX: ADD CONTEXT HERE for req3
ctx3 := context.Background()
logID3 := "test-log-id-headermultimatch3" // Unique log ID for this test
ctx3 = context.WithValue(ctx3, ContextKeyLogId("logID"), logID3)
req3 = req3.WithContext(ctx3)
w3 := httptest.NewRecorder()
state3 := &WAFState{}
@@ -2808,7 +2997,6 @@ func TestBlockedRequestPhase1_HeaderRegex_MultipleMatchingHeaders(t *testing.T)
assert.False(t, state3.Blocked, "Request should not be blocked when neither header matches")
assert.Equal(t, http.StatusOK, w3.Code, "Expected status code 200")
}
func TestBlockedRequestPhase1_RateLimiting_MultiplePaths(t *testing.T) {
@@ -3087,38 +3275,6 @@ func TestParseAnomalyThreshold_Invalid(t *testing.T) {
assert.Contains(t, err.Error(), "invalid syntax")
}
func TestProcessRuleMatch_NoAction(t *testing.T) {
logger := newMockLogger()
middleware := &Middleware{
logger: logger.Logger, // Use the embedded *zap.Logger
AnomalyThreshold: 10,
ruleHits: sync.Map{},
muMetrics: sync.RWMutex{}, // Use sync.RWMutex directly
}
rule := &Rule{
ID: "no_action_rule",
Targets: []string{"header"},
Description: "Test no-action rule",
Score: 5,
Action: "", // No action
}
state := &WAFState{
TotalScore: 0,
ResponseWritten: false,
}
req := httptest.NewRequest("GET", "http://example.com", nil)
w := httptest.NewRecorder()
shouldContinue := middleware.processRuleMatch(w, req, rule, "value", state)
assert.True(t, shouldContinue)
assert.Equal(t, 5, state.TotalScore)
assert.False(t, state.Blocked)
}
func TestExtractValue_HeaderCaseInsensitive(t *testing.T) {
logger := zap.NewNop()
rve := NewRequestValueExtractor(logger, false)

config.go: 562 lines changed

@@ -23,11 +23,11 @@ func NewConfigLoader(logger *zap.Logger) *ConfigLoader {
// parseMetricsEndpoint parses the metrics_endpoint directive.
func (cl *ConfigLoader) parseMetricsEndpoint(d *caddyfile.Dispenser, m *Middleware) error {
if !d.NextArg() {
return fmt.Errorf("file: %s, line: %d: missing value for metrics_endpoint", d.File(), d.Line())
return d.ArgErr() // More specific error type
}
m.MetricsEndpoint = d.Val()
cl.logger.Debug("Metrics endpoint set from Caddyfile",
zap.String("metrics_endpoint", m.MetricsEndpoint),
cl.logger.Debug("Metrics endpoint configured", // Improved log message
zap.String("endpoint", m.MetricsEndpoint), // More descriptive log field
zap.String("file", d.File()),
zap.Int("line", d.Line()),
)
@@ -37,11 +37,11 @@ func (cl *ConfigLoader) parseMetricsEndpoint(d *caddyfile.Dispenser, m *Middlewa
// parseLogPath parses the log_path directive.
func (cl *ConfigLoader) parseLogPath(d *caddyfile.Dispenser, m *Middleware) error {
if !d.NextArg() {
return fmt.Errorf("file: %s, line: %d: missing value for log_path", d.File(), d.Line())
return d.ArgErr() // More specific error type
}
m.LogFilePath = d.Val()
cl.logger.Debug("Log path set from Caddyfile",
zap.String("log_path", m.LogFilePath),
cl.logger.Debug("Log file path configured", // Improved log message
zap.String("path", m.LogFilePath), // More descriptive log field
zap.String("file", d.File()),
zap.Int("line", d.Line()),
)
@@ -51,7 +51,7 @@ func (cl *ConfigLoader) parseLogPath(d *caddyfile.Dispenser, m *Middleware) erro
// parseRateLimit parses the rate_limit directive.
func (cl *ConfigLoader) parseRateLimit(d *caddyfile.Dispenser, m *Middleware) error {
if m.RateLimit.Requests > 0 {
return d.Err("rate_limit specified multiple times")
return d.Err("rate_limit directive already specified") // Improved error message
}
rl := RateLimit{
@@ -62,36 +62,28 @@ func (cl *ConfigLoader) parseRateLimit(d *caddyfile.Dispenser, m *Middleware) er
}
for nesting := d.Nesting(); d.NextBlock(nesting); {
switch d.Val() {
option := d.Val()
switch option {
case "requests":
if !d.NextArg() {
return d.Err("requests requires an argument")
}
reqs, err := strconv.Atoi(d.Val())
reqs, err := cl.parsePositiveInteger(d, "requests")
if err != nil {
return d.Errf("invalid requests value: %v", err)
return err
}
rl.Requests = reqs
cl.logger.Debug("Rate limit requests set", zap.Int("requests", rl.Requests))
case "window":
if !d.NextArg() {
return d.Err("window requires an argument")
}
window, err := time.ParseDuration(d.Val())
window, err := cl.parseDuration(d, "window")
if err != nil {
return d.Errf("invalid window value: %v", err)
return err
}
rl.Window = window
cl.logger.Debug("Rate limit window set", zap.Duration("window", rl.Window))
case "cleanup_interval":
if !d.NextArg() {
return d.Err("cleanup_interval requires an argument")
}
interval, err := time.ParseDuration(d.Val())
interval, err := cl.parseDuration(d, "cleanup_interval")
if err != nil {
return d.Errf("invalid cleanup_interval value: %v", err)
return err
}
rl.CleanupInterval = interval
cl.logger.Debug("Rate limit cleanup interval set", zap.Duration("cleanup_interval", rl.CleanupInterval))
@@ -99,29 +91,26 @@ func (cl *ConfigLoader) parseRateLimit(d *caddyfile.Dispenser, m *Middleware) er
case "paths":
paths := d.RemainingArgs()
if len(paths) == 0 {
return d.Err("paths requires at least one argument")
return d.Err("paths option requires at least one path") // Improved error message
}
rl.Paths = paths
cl.logger.Debug("Rate limit paths set", zap.Strings("paths", rl.Paths))
cl.logger.Debug("Rate limit paths configured", zap.Strings("paths", rl.Paths)) // Improved log message
case "match_all_paths":
if !d.NextArg() {
return d.Err("match_all_paths requires an argument")
}
matchAllPaths, err := strconv.ParseBool(d.Val())
matchAllPaths, err := cl.parseBool(d, "match_all_paths")
if err != nil {
return d.Errf("invalid match_all_paths value: %v", err)
return err
}
rl.MatchAllPaths = matchAllPaths
cl.logger.Debug("Rate limit match_all_paths set", zap.Bool("match_all_paths", rl.MatchAllPaths))
default:
return d.Errf("invalid rate_limit option: %s", d.Val())
return d.Errf("unrecognized rate_limit option: %s", option) // More specific error message
}
}
if rl.Requests <= 0 || rl.Window <= 0 {
return d.Err("requests and window must be greater than zero")
return d.Err("requests and window in rate_limit must be positive values") // Improved error message
}
m.RateLimit = rl
@@ -129,6 +118,7 @@ func (cl *ConfigLoader) parseRateLimit(d *caddyfile.Dispenser, m *Middleware) er
return nil
}
// UnmarshalCaddyfile is the primary parsing function for the middleware configuration.
func (cl *ConfigLoader) UnmarshalCaddyfile(d *caddyfile.Dispenser, m *Middleware) error {
if cl.logger == nil {
cl.logger = zap.NewNop()
@@ -143,7 +133,7 @@ func (cl *ConfigLoader) UnmarshalCaddyfile(d *caddyfile.Dispenser, m *Middleware
RetryInterval: "5m", // Default retry interval
}
cl.logger.Debug("WAF UnmarshalCaddyfile Called", zap.String("file", d.File()), zap.Int("line", d.Line()))
cl.logger.Debug("Parsing WAF configuration", zap.String("file", d.File()), zap.Int("line", d.Line())) // Improved log message
// Set default values
m.LogSeverity = "info"
@@ -154,154 +144,58 @@ func (cl *ConfigLoader) UnmarshalCaddyfile(d *caddyfile.Dispenser, m *Middleware
m.LogFilePath = "debug.json"
m.RedactSensitiveData = false
directiveHandlers := map[string]func(d *caddyfile.Dispenser, m *Middleware) error{
"metrics_endpoint": cl.parseMetricsEndpoint,
"log_path": cl.parseLogPath,
"rate_limit": cl.parseRateLimit,
"block_countries": cl.parseCountryBlockDirective(true), // Use directive-specific helper
"whitelist_countries": cl.parseCountryBlockDirective(false), // Use directive-specific helper
"log_severity": cl.parseLogSeverity,
"log_json": cl.parseLogJSON,
"rule_file": cl.parseRuleFile,
"ip_blacklist_file": cl.parseBlacklistFileDirective(true), // Use directive-specific helper
"dns_blacklist_file": cl.parseBlacklistFileDirective(false), // Use directive-specific helper
"anomaly_threshold": cl.parseAnomalyThreshold,
"custom_response": cl.parseCustomResponse,
"redact_sensitive_data": cl.parseRedactSensitiveData,
"tor": cl.parseTorBlock,
}
for d.Next() {
for d.NextBlock(0) {
directive := d.Val()
cl.logger.Debug("Processing directive", zap.String("directive", directive), zap.String("file", d.File()), zap.Int("line", d.Line()))
switch directive {
case "metrics_endpoint":
if err := cl.parseMetricsEndpoint(d, m); err != nil {
return err
}
case "log_path":
if err := cl.parseLogPath(d, m); err != nil {
return err
}
case "rate_limit":
if err := cl.parseRateLimit(d, m); err != nil {
return err
}
case "block_countries":
if err := cl.parseCountryBlock(d, m, true); err != nil {
return err
}
case "whitelist_countries":
if err := cl.parseCountryBlock(d, m, false); err != nil {
return err
}
case "log_severity":
if err := cl.parseLogSeverity(d, m); err != nil {
return err
}
case "log_json":
m.LogJSON = true
cl.logger.Debug("Log JSON enabled", zap.String("file", d.File()), zap.Int("line", d.Line()))
case "rule_file":
if err := cl.parseRuleFile(d, m); err != nil {
return err
}
case "ip_blacklist_file":
if err := cl.parseBlacklistFile(d, m, true); err != nil {
return err
}
case "dns_blacklist_file":
if err := cl.parseBlacklistFile(d, m, false); err != nil {
return err
}
case "anomaly_threshold":
if err := cl.parseAnomalyThreshold(d, m); err != nil {
return err
}
case "custom_response":
if err := cl.parseCustomResponse(d, m); err != nil {
return err
}
case "redact_sensitive_data":
m.RedactSensitiveData = true
cl.logger.Debug("Redact sensitive data enabled", zap.String("file", d.File()), zap.Int("line", d.Line()))
case "tor":
// Handle the tor block as a nested directive
for nesting := d.Nesting(); d.NextBlock(nesting); {
switch d.Val() {
case "enabled":
if !d.NextArg() {
return d.ArgErr()
}
enabled, err := strconv.ParseBool(d.Val())
if err != nil {
return d.Errf("invalid enabled value: %v", err)
}
m.Tor.Enabled = enabled
cl.logger.Debug("Tor blocking enabled", zap.Bool("enabled", m.Tor.Enabled))
case "tor_ip_blacklist_file": // Updated field name
if !d.NextArg() {
return d.ArgErr()
}
m.Tor.TORIPBlacklistFile = d.Val() // Updated field name
cl.logger.Debug("Tor IP blacklist file set", zap.String("tor_ip_blacklist_file", m.Tor.TORIPBlacklistFile))
case "update_interval":
if !d.NextArg() {
return d.ArgErr()
}
m.Tor.UpdateInterval = d.Val()
cl.logger.Debug("Tor update interval set", zap.String("update_interval", m.Tor.UpdateInterval))
case "retry_on_failure":
if !d.NextArg() {
return d.ArgErr()
}
retryOnFailure, err := strconv.ParseBool(d.Val())
if err != nil {
return d.Errf("invalid retry_on_failure value: %v", err)
}
m.Tor.RetryOnFailure = retryOnFailure
cl.logger.Debug("Tor retry on failure set", zap.Bool("retry_on_failure", m.Tor.RetryOnFailure))
case "retry_interval":
if !d.NextArg() {
return d.ArgErr()
}
m.Tor.RetryInterval = d.Val()
cl.logger.Debug("Tor retry interval set", zap.String("retry_interval", m.Tor.RetryInterval))
default:
return d.Errf("unrecognized tor subdirective: %s", d.Val())
}
}
default:
cl.logger.Warn("WAF Unrecognized SubDirective", zap.String("directive", directive), zap.String("file", d.File()), zap.Int("line", d.Line()))
return fmt.Errorf("file: %s, line: %d: unrecognized subdirective: %s", d.File(), d.Line(), d.Val())
handler, exists := directiveHandlers[directive]
if !exists {
cl.logger.Warn("Unrecognized WAF directive", zap.String("directive", directive), zap.String("file", d.File()), zap.Int("line", d.Line()))
return fmt.Errorf("file: %s, line: %d: unrecognized directive: %s", d.File(), d.Line(), directive)
}
} // Closing brace for the outer for loop
if err := handler(d, m); err != nil {
return err // Handler already provides context in error
}
}
}
if len(m.RuleFiles) == 0 {
return fmt.Errorf("no rule files specified")
return fmt.Errorf("no rule files specified for WAF") // More direct error
}
cl.logger.Debug("WAF configuration parsed successfully", zap.String("file", d.File())) // Success log
return nil
}
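The map-based dispatch above replaces a long switch: each directive name maps to a handler, and unknown names fall through to a single error path, so adding a directive means adding one map entry. A stripped-down sketch of the shape (simplified types; not the real caddyfile.Dispenser API):

package main

import "fmt"

type config struct{ logPath string }

// handler parses the arguments for one directive. In the real code the
// signature takes *caddyfile.Dispenser and *Middleware instead.
type handler func(args []string, c *config) error

func main() {
	handlers := map[string]handler{
		"log_path": func(args []string, c *config) error {
			if len(args) != 1 {
				return fmt.Errorf("log_path expects exactly one argument")
			}
			c.logPath = args[0]
			return nil
		},
	}

	var c config
	directive, args := "log_path", []string{"debug.json"}
	h, ok := handlers[directive]
	if !ok {
		fmt.Printf("unrecognized directive: %s\n", directive)
		return
	}
	if err := h(args, &c); err != nil {
		fmt.Println(err)
		return
	}
	fmt.Printf("parsed config: %+v\n", c)
}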
func (cl *ConfigLoader) parseRuleFile(d *caddyfile.Dispenser, m *Middleware) error {
if !d.NextArg() {
return fmt.Errorf("file: %s, line: %d: missing path for rule_file", d.File(), d.Line())
return d.ArgErr() // More specific error type
}
ruleFile := d.Val()
m.RuleFiles = append(m.RuleFiles, ruleFile)
if m.MetricsEndpoint != "" && !strings.HasPrefix(m.MetricsEndpoint, "/") {
return fmt.Errorf("metrics_endpoint must start with '/'")
return fmt.Errorf("metrics_endpoint must start with a leading '/'") // Improved error message
}
cl.logger.Info("WAF Loading Rule File",
zap.String("file", ruleFile),
cl.logger.Info("Loading WAF rule file", // Improved log message
zap.String("path", ruleFile), // More descriptive log field
zap.String("caddyfile", d.File()),
zap.Int("line", d.Line()),
)
@@ -314,58 +208,49 @@ func (cl *ConfigLoader) parseCustomResponse(d *caddyfile.Dispenser, m *Middlewar
}
if !d.NextArg() {
return fmt.Errorf("file: %s, line: %d: missing status code for custom_response", d.File(), d.Line())
return d.ArgErr() // More specific error type
}
statusCode, err := strconv.Atoi(d.Val())
statusCode, err := cl.parseStatusCode(d)
if err != nil {
return fmt.Errorf("file: %s, line: %d: invalid status code for custom_response: %v", d.File(), d.Line(), err)
return err
}
if m.CustomResponses[statusCode].Headers == nil {
m.CustomResponses[statusCode] = CustomBlockResponse{
StatusCode: statusCode,
Headers: make(map[string]string),
}
if _, exists := m.CustomResponses[statusCode]; exists {
return d.Errf("custom_response for status code %d already defined", statusCode) // Prevent duplicate status codes
}
resp := CustomBlockResponse{
StatusCode: statusCode,
Headers: make(map[string]string),
}
if !d.NextArg() {
return fmt.Errorf("file: %s, line: %d: missing content_type or file path for custom_response", d.File(), d.Line())
return d.ArgErr() // More specific error type
}
contentTypeOrFile := d.Val()
if d.NextArg() {
filePath := d.Val()
content, err := os.ReadFile(filePath)
content, err := cl.readResponseFromFile(d, filePath)
if err != nil {
return fmt.Errorf("file: %s, line: %d: could not read custom response file '%s': %v", d.File(), d.Line(), filePath, err)
}
m.CustomResponses[statusCode] = CustomBlockResponse{
StatusCode: statusCode,
Headers: map[string]string{
"Content-Type": contentTypeOrFile,
},
Body: string(content),
return err
}
resp.Headers["Content-Type"] = contentTypeOrFile
resp.Body = content
cl.logger.Debug("Loaded custom response from file",
zap.Int("status_code", statusCode),
zap.String("file", filePath),
zap.String("file_path", filePath), // More descriptive log field
zap.String("content_type", contentTypeOrFile),
zap.String("caddyfile", d.File()),
zap.Int("line", d.Line()),
)
} else {
remaining := d.RemainingArgs()
if len(remaining) == 0 {
return fmt.Errorf("file: %s, line: %d: missing custom response body", d.File(), d.Line())
}
body := strings.Join(remaining, " ")
m.CustomResponses[statusCode] = CustomBlockResponse{
StatusCode: statusCode,
Headers: map[string]string{
"Content-Type": contentTypeOrFile,
},
Body: body,
body, err := cl.parseInlineResponseBody(d)
if err != nil {
return err
}
resp.Headers["Content-Type"] = contentTypeOrFile
resp.Body = body
cl.logger.Debug("Loaded inline custom response",
zap.Int("status_code", statusCode),
zap.String("content_type", contentTypeOrFile),
@@ -374,95 +259,268 @@ func (cl *ConfigLoader) parseCustomResponse(d *caddyfile.Dispenser, m *Middlewar
zap.Int("line", d.Line()),
)
}
m.CustomResponses[statusCode] = resp
return nil
}
func (cl *ConfigLoader) parseCountryBlock(d *caddyfile.Dispenser, m *Middleware, isBlock bool) error {
target := &m.CountryBlock
if !isBlock {
target = &m.CountryWhitelist
}
target.Enabled = true
// parseCountryBlockDirective returns a closure to handle block_countries and whitelist_countries directives.
func (cl *ConfigLoader) parseCountryBlockDirective(isBlock bool) func(d *caddyfile.Dispenser, m *Middleware) error {
return func(d *caddyfile.Dispenser, m *Middleware) error {
target := &m.CountryBlock
directiveName := "block_countries"
if !isBlock {
target = &m.CountryWhitelist
directiveName = "whitelist_countries"
}
target.Enabled = true
if !d.NextArg() {
return fmt.Errorf("file: %s, line: %d: missing GeoIP DB path", d.File(), d.Line())
}
target.GeoIPDBPath = d.Val()
target.CountryList = []string{}
if !d.NextArg() {
return d.ArgErr() // More specific error type
}
target.GeoIPDBPath = d.Val()
target.CountryList = []string{}
for d.NextArg() {
country := strings.ToUpper(d.Val())
target.CountryList = append(target.CountryList, country)
}
for d.NextArg() {
country := strings.ToUpper(d.Val())
target.CountryList = append(target.CountryList, country)
}
cl.logger.Debug("Country list configured",
zap.Bool("block_mode", isBlock),
zap.Strings("countries", target.CountryList),
zap.String("geoip_db_path", target.GeoIPDBPath),
zap.String("file", d.File()), zap.Int("line", d.Line()),
)
return nil
cl.logger.Debug("Country list configured", // Improved log message
zap.String("directive", directiveName), // Log directive name
zap.Bool("block_mode", isBlock),
zap.Strings("countries", target.CountryList),
zap.String("geoip_db_path", target.GeoIPDBPath),
zap.String("file", d.File()), zap.Int("line", d.Line()),
)
return nil
}
}
func (cl *ConfigLoader) parseLogSeverity(d *caddyfile.Dispenser, m *Middleware) error {
if !d.NextArg() {
return fmt.Errorf("file: %s, line: %d: missing value for log_severity", d.File(), d.Line())
return d.ArgErr() // More specific error type
}
m.LogSeverity = d.Val()
cl.logger.Debug("Log severity set",
severity := d.Val()
validSeverities := []string{"debug", "info", "warn", "error"} // Define valid severities
isValid := false
for _, valid := range validSeverities {
if severity == valid {
isValid = true
break
}
}
if !isValid {
return d.Errf("invalid log_severity value '%s', must be one of: %s", severity, strings.Join(validSeverities, ", ")) // Improved error message
}
m.LogSeverity = severity
cl.logger.Debug("Log severity set", // Improved log message
zap.String("severity", m.LogSeverity),
zap.String("file", d.File()), zap.Int("line", d.Line()),
)
return nil
}
func (cl *ConfigLoader) parseBlacklistFile(d *caddyfile.Dispenser, m *Middleware, isIP bool) error {
if !d.NextArg() {
return fmt.Errorf("file: %s, line: %d: missing blacklist file path", d.File(), d.Line())
}
filePath := d.Val()
// Check if the file exists
if _, err := os.Stat(filePath); os.IsNotExist(err) {
// Create an empty file if it doesn't exist
file, err := os.Create(filePath)
if err != nil {
return fmt.Errorf("file: %s, line: %d: could not create blacklist file '%s': %v", d.File(), d.Line(), filePath, err)
// parseBlacklistFileDirective returns a closure to handle ip_blacklist_file and dns_blacklist_file directives.
func (cl *ConfigLoader) parseBlacklistFileDirective(isIP bool) func(d *caddyfile.Dispenser, m *Middleware) error {
return func(d *caddyfile.Dispenser, m *Middleware) error {
if !d.NextArg() {
return d.ArgErr() // More specific error type
}
filePath := d.Val()
directiveName := "dns_blacklist_file"
if isIP {
directiveName = "ip_blacklist_file"
}
file.Close()
cl.logger.Warn("Blacklist file does not exist, created an empty file",
zap.String("file", filePath),
zap.Bool("is_ip", isIP),
if err := cl.ensureBlacklistFileExists(d, filePath, isIP); err != nil {
return err
}
// Assign the file path to the appropriate field
if isIP {
m.IPBlacklistFile = filePath
} else {
m.DNSBlacklistFile = filePath
}
cl.logger.Info("Blacklist file configured", // Improved log message
zap.String("directive", directiveName), // Log directive name
zap.String("path", filePath), // More descriptive log field
zap.Bool("is_ip_type", isIP),
)
} else if err != nil {
return fmt.Errorf("file: %s, line: %d: could not access blacklist file '%s': %v", d.File(), d.Line(), filePath, err)
return nil
}
// Assign the file path to the appropriate field
if isIP {
m.IPBlacklistFile = filePath
} else {
m.DNSBlacklistFile = filePath
}
cl.logger.Info("Blacklist file loaded",
zap.String("file", filePath),
zap.Bool("is_ip", isIP),
)
return nil
}
func (cl *ConfigLoader) parseAnomalyThreshold(d *caddyfile.Dispenser, m *Middleware) error {
if !d.NextArg() {
return fmt.Errorf("file: %s, line: %d: missing threshold value", d.File(), d.Line())
}
threshold, err := strconv.Atoi(d.Val())
threshold, err := cl.parsePositiveInteger(d, "anomaly_threshold")
if err != nil {
return fmt.Errorf("file: %s, line: %d: invalid threshold: %v", d.File(), d.Line(), err)
return err
}
m.AnomalyThreshold = threshold
cl.logger.Debug("Anomaly threshold set", zap.Int("threshold", threshold))
return nil
}
func (cl *ConfigLoader) parseTorBlock(d *caddyfile.Dispenser, m *Middleware) error {
for nesting := d.Nesting(); d.NextBlock(nesting); {
subDirective := d.Val()
switch subDirective {
case "enabled":
enabled, err := cl.parseBool(d, "tor enabled") // More descriptive arg name
if err != nil {
return err
}
m.Tor.Enabled = enabled
cl.logger.Debug("Tor blocking enabled", zap.Bool("enabled", m.Tor.Enabled))
case "tor_ip_blacklist_file": // Updated field name
if !d.NextArg() {
return d.ArgErr()
}
m.Tor.TORIPBlacklistFile = d.Val() // Updated field name
cl.logger.Debug("Tor IP blacklist file set", zap.String("file_path", m.Tor.TORIPBlacklistFile)) // More descriptive log field
case "update_interval":
if !d.NextArg() {
return d.ArgErr()
}
m.Tor.UpdateInterval = d.Val()
cl.logger.Debug("Tor update interval set", zap.String("interval", m.Tor.UpdateInterval)) // More descriptive log field
case "retry_on_failure":
retryOnFailure, err := cl.parseBool(d, "tor retry_on_failure") // More descriptive arg name
if err != nil {
return err
}
m.Tor.RetryOnFailure = retryOnFailure
cl.logger.Debug("Tor retry on failure set", zap.Bool("retry_on_failure", m.Tor.RetryOnFailure))
case "retry_interval":
if !d.NextArg() {
return d.ArgErr()
}
m.Tor.RetryInterval = d.Val()
cl.logger.Debug("Tor retry interval set", zap.String("interval", m.Tor.RetryInterval)) // More descriptive log field
default:
return d.Errf("unrecognized tor subdirective: %s", subDirective) // More specific error message
}
}
return nil
}
func (cl *ConfigLoader) parseLogJSON(d *caddyfile.Dispenser, m *Middleware) error {
m.LogJSON = true
cl.logger.Debug("Log JSON enabled", zap.String("file", d.File()), zap.Int("line", d.Line()))
return nil
}
func (cl *ConfigLoader) parseRedactSensitiveData(d *caddyfile.Dispenser, m *Middleware) error {
m.RedactSensitiveData = true
cl.logger.Debug("Redact sensitive data enabled", zap.String("file", d.File()), zap.Int("line", d.Line()))
return nil
}
// --- Helper Functions ---
// parsePositiveInteger parses a directive argument as a positive integer.
func (cl *ConfigLoader) parsePositiveInteger(d *caddyfile.Dispenser, directiveName string) (int, error) {
if !d.NextArg() {
return 0, d.ArgErr() // More specific error type
}
valStr := d.Val()
val, err := strconv.Atoi(valStr)
if err != nil {
return 0, d.Errf("invalid %s value '%s': %v", directiveName, valStr, err) // More descriptive error message
}
if val <= 0 {
return 0, d.Errf("%s must be a positive integer, but got '%d'", directiveName, val) // More descriptive error message
}
return val, nil
}
// parseDuration parses a directive argument as a time duration.
func (cl *ConfigLoader) parseDuration(d *caddyfile.Dispenser, directiveName string) (time.Duration, error) {
if !d.NextArg() {
return 0, d.ArgErr() // More specific error type
}
durationStr := d.Val()
duration, err := time.ParseDuration(durationStr)
if err != nil {
return 0, d.Errf("invalid %s value '%s': %v", directiveName, durationStr, err) // More descriptive error message
}
return duration, nil
}
// parseBool parses a directive argument as a boolean.
func (cl *ConfigLoader) parseBool(d *caddyfile.Dispenser, directiveName string) (bool, error) {
if !d.NextArg() {
return false, d.ArgErr() // More specific error type
}
boolStr := d.Val()
val, err := strconv.ParseBool(boolStr)
if err != nil {
return false, d.Errf("invalid %s value '%s': %v, must be 'true' or 'false'", directiveName, boolStr, err) // More descriptive error message
}
return val, nil
}
// parseStatusCode parses a directive argument as an HTTP status code.
func (cl *ConfigLoader) parseStatusCode(d *caddyfile.Dispenser) (int, error) {
statusCodeStr := d.Val()
statusCode, err := strconv.Atoi(statusCodeStr)
if err != nil {
return 0, fmt.Errorf("file: %s, line: %d: invalid status code '%s': %v", d.File(), d.Line(), statusCodeStr, err) // Include file and line in error
}
if statusCode < 100 || statusCode > 599 {
return 0, fmt.Errorf("file: %s, line: %d: status code '%d' out of range, must be between 100 and 599", d.File(), d.Line(), statusCode) // Include file and line in error
}
return statusCode, nil
}
// readResponseFromFile reads custom response body from file.
func (cl *ConfigLoader) readResponseFromFile(d *caddyfile.Dispenser, filePath string) (string, error) {
content, err := os.ReadFile(filePath)
if err != nil {
return "", fmt.Errorf("file: %s, line: %d: could not read custom response file '%s': %v", d.File(), d.Line(), filePath, err) // Include file and line in error
}
return string(content), nil
}
// parseInlineResponseBody parses inline custom response body.
func (cl *ConfigLoader) parseInlineResponseBody(d *caddyfile.Dispenser) (string, error) {
remaining := d.RemainingArgs()
if len(remaining) == 0 {
return "", fmt.Errorf("file: %s, line: %d: missing custom response body", d.File(), d.Line()) // Include file and line in error
}
return strings.Join(remaining, " "), nil
}
// ensureBlacklistFileExists checks if blacklist file exists and creates it if not.
func (cl *ConfigLoader) ensureBlacklistFileExists(d *caddyfile.Dispenser, filePath string, isIP bool) error {
if _, err := os.Stat(filePath); os.IsNotExist(err) {
file, err := os.Create(filePath)
if err != nil {
fileType := "DNS"
if isIP {
fileType = "IP"
}
return fmt.Errorf("file: %s, line: %d: could not create %s blacklist file '%s': %v", d.File(), d.Line(), fileType, filePath, err) // More descriptive error
}
file.Close()
fileType := "DNS"
if isIP {
fileType = "IP"
}
cl.logger.Warn("%s blacklist file does not exist, created an empty file", zap.String("type", fileType), zap.String("path", filePath)) // Improved log
} else if err != nil {
fileType := "DNS"
if isIP {
fileType = "IP"
}
return fmt.Errorf("file: %s, line: %d: could not access %s blacklist file '%s': %v", d.File(), d.Line(), fileType, filePath, err) // Improved error
}
return nil
}

geoip.go: 138 lines changed

@@ -44,20 +44,13 @@ func (gh *GeoIPHandler) LoadGeoIPDatabase(path string) (*maxminddb.Reader, error
if path == "" {
return nil, fmt.Errorf("no GeoIP database path specified")
}
gh.logger.Debug("Attempting to load GeoIP database",
zap.String("path", path),
)
gh.logger.Debug("Loading GeoIP database", zap.String("path", path)) // Slightly improved log message
reader, err := maxminddb.Open(path)
if err != nil {
gh.logger.Error("Failed to load GeoIP database",
zap.String("path", path),
zap.Error(err),
)
gh.logger.Error("Failed to load GeoIP database", zap.String("path", path), zap.Error(err))
return nil, fmt.Errorf("failed to load GeoIP database: %w", err)
}
gh.logger.Info("GeoIP database loaded successfully",
zap.String("path", path),
)
gh.logger.Info("GeoIP database loaded", zap.String("path", path)) // Slightly improved log message
return reader, nil
}
@@ -69,6 +62,7 @@ func (gh *GeoIPHandler) IsCountryInList(remoteAddr string, countryList []string,
if geoIP == nil {
return false, fmt.Errorf("geoip database not loaded")
}
ip, err := gh.extractIPFromRemoteAddr(remoteAddr)
if err != nil {
return false, err
@@ -76,21 +70,16 @@ func (gh *GeoIPHandler) IsCountryInList(remoteAddr string, countryList []string,
parsedIP := net.ParseIP(ip)
if parsedIP == nil {
gh.logger.Error("invalid IP address", zap.String("ip", ip))
gh.logger.Warn("Invalid IP address", zap.String("ip", ip)) // Changed Error to Warn for invalid IP, as it might be client issue.
return false, fmt.Errorf("invalid IP address: %s", ip)
}
// Easy: Add caching of GeoIP lookups for performance.
// Check cache first
if gh.geoIPCache != nil {
gh.geoIPCacheMutex.RLock()
if record, ok := gh.geoIPCache[ip]; ok {
gh.geoIPCacheMutex.RUnlock()
for _, country := range countryList {
if strings.EqualFold(record.Country.ISOCode, country) {
return true, nil
}
}
return false, nil
return gh.isCountryInRecord(record, countryList), nil // Helper function for country check
}
gh.geoIPCacheMutex.RUnlock()
}
@@ -98,52 +87,16 @@ func (gh *GeoIPHandler) IsCountryInList(remoteAddr string, countryList []string,
var record GeoIPRecord
err = geoIP.Lookup(parsedIP, &record)
if err != nil {
gh.logger.Error("geoip lookup failed", zap.String("ip", ip), zap.Error(err))
// Critical: Handle cases where the GeoIP database lookup fails more gracefully.
switch gh.geoIPLookupFallbackBehavior {
case "default":
// Log and treat as not in the list
return false, nil
case "none":
return false, fmt.Errorf("geoip lookup failed: %w", err)
case "": // No fallback configured, maintain existing behavior
return false, fmt.Errorf("geoip lookup failed: %w", err)
default:
// Configurable fallback country code
for _, country := range countryList {
if strings.EqualFold(gh.geoIPLookupFallbackBehavior, country) {
return true, nil
}
}
return false, nil
}
gh.logger.Error("GeoIP lookup failed", zap.String("ip", ip), zap.Error(err))
return gh.handleGeoIPLookupError(err, countryList) // Helper function for error handling
}
// Easy: Add caching of GeoIP lookups for performance.
// Cache the record
if gh.geoIPCache != nil {
gh.geoIPCacheMutex.Lock()
gh.geoIPCache[ip] = record
gh.geoIPCacheMutex.Unlock()
// Invalidate cache entry after TTL (basic implementation)
if gh.geoIPCacheTTL > 0 {
time.AfterFunc(gh.geoIPCacheTTL, func() {
gh.geoIPCacheMutex.Lock()
delete(gh.geoIPCache, ip)
gh.geoIPCacheMutex.Unlock()
})
}
gh.cacheGeoIPRecord(ip, record) // Helper function for caching
}
for _, country := range countryList {
if strings.EqualFold(record.Country.ISOCode, country) {
return true, nil
}
}
return false, nil
return gh.isCountryInRecord(record, countryList), nil // Helper function for country check
}
// getCountryCode extracts the country code for logging purposes
@@ -162,11 +115,28 @@ func (gh *GeoIPHandler) GetCountryCode(remoteAddr string, geoIP *maxminddb.Reade
if parsedIP == nil {
return "N/A"
}
// Check cache first for GetCountryCode as well for consistency and potential perf gain
if gh.geoIPCache != nil {
gh.geoIPCacheMutex.RLock()
if record, ok := gh.geoIPCache[ip]; ok {
gh.geoIPCacheMutex.RUnlock()
return record.Country.ISOCode
}
gh.geoIPCacheMutex.RUnlock()
}
var record GeoIPRecord
err = geoIP.Lookup(parsedIP, &record)
if err != nil {
return "N/A"
return "N/A" // Simply return "N/A" on lookup error for GetCountryCode
}
// Cache the record for GetCountryCode as well
if gh.geoIPCache != nil {
gh.cacheGeoIPRecord(ip, record)
}
return record.Country.ISOCode
}
@@ -178,3 +148,51 @@ func (gh *GeoIPHandler) extractIPFromRemoteAddr(remoteAddr string) (string, erro
}
return host, nil
}
// Helper function to check if the country in the record is in the country list
func (gh *GeoIPHandler) isCountryInRecord(record GeoIPRecord, countryList []string) bool {
for _, country := range countryList {
if strings.EqualFold(record.Country.ISOCode, country) {
return true
}
}
return false
}
// Helper function to handle GeoIP lookup errors based on fallback behavior
func (gh *GeoIPHandler) handleGeoIPLookupError(err error, countryList []string) (bool, error) {
switch gh.geoIPLookupFallbackBehavior {
case "default":
// Log at debug level as it's a fallback scenario, not necessarily an error for the overall operation
gh.logger.Debug("GeoIP lookup failed, using default fallback (not in list)", zap.Error(err))
return false, nil // Treat as not in the list
case "none":
return false, fmt.Errorf("geoip lookup failed: %w", err) // Propagate the error
case "": // No fallback configured, maintain existing behavior
return false, fmt.Errorf("geoip lookup failed: %w", err) // Propagate the error
default: // Configurable fallback country code
gh.logger.Debug("GeoIP lookup failed, using configured fallback", zap.String("fallbackCountry", gh.geoIPLookupFallbackBehavior), zap.Error(err))
for _, country := range countryList {
if strings.EqualFold(gh.geoIPLookupFallbackBehavior, country) {
return true, nil // Treat as in the list for the fallback country
}
}
return false, nil // Treat as not in the list if fallback country is not in the list
}
}
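Concretely, for countryList = []string{"US", "CA"}, the fallback settings above map to these outcomes:

// Illustrative outcomes of handleGeoIPLookupError for countryList = []string{"US", "CA"}:
//   fallback "default" -> (false, nil)   lookup failure treated as not-in-list
//   fallback "none"    -> (false, error) error propagated to the caller
//   fallback ""        -> (false, error) no fallback configured, same as "none"
//   fallback "US"      -> (true, nil)    fallback country is in the list
//   fallback "FR"      -> (false, nil)   fallback country is not in the list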
// Helper function to cache GeoIP record
func (gh *GeoIPHandler) cacheGeoIPRecord(ip string, record GeoIPRecord) {
gh.geoIPCacheMutex.Lock()
gh.geoIPCache[ip] = record
gh.geoIPCacheMutex.Unlock()
// Invalidate cache entry after TTL
if gh.geoIPCacheTTL > 0 {
time.AfterFunc(gh.geoIPCacheTTL, func() {
gh.geoIPCacheMutex.Lock()
delete(gh.geoIPCache, ip)
gh.geoIPCacheMutex.Unlock()
})
}
}
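For symmetry with cacheGeoIPRecord, the cache read path can be factored out the same way. A minimal sketch, assuming the geoIPCache map and geoIPCacheMutex RWMutex used above (lookupCachedGeoIPRecord is a name chosen here for illustration, not one from the repository):

// Sketch of the read path pairing with cacheGeoIPRecord; assumes the fields above.
func (gh *GeoIPHandler) lookupCachedGeoIPRecord(ip string) (GeoIPRecord, bool) {
	if gh.geoIPCache == nil {
		return GeoIPRecord{}, false
	}
	gh.geoIPCacheMutex.RLock()
	defer gh.geoIPCacheMutex.RUnlock()
	record, ok := gh.geoIPCache[ip]
	return record, ok
}

Note that the time.AfterFunc eviction arms one timer per cached IP and does not extend the TTL when an entry is overwritten; under high-cardinality traffic a single periodic sweep goroutine is usually cheaper.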

View File

@@ -1,133 +1,188 @@
package caddywaf
import (
"bytes"
"context"
"net/http"
"github.com/caddyserver/caddy/v2/modules/caddyhttp"
"github.com/google/uuid"
"go.uber.org/zap"
"go.uber.org/zap/zapcore"
)
// ==================== Request Handling and Logic ====================
// ServeHTTP implements caddyhttp.Handler.
func (m *Middleware) ServeHTTP(w http.ResponseWriter, r *http.Request, next caddyhttp.Handler) error {
// Generate a unique log ID for the request
logID := uuid.New().String()
m.logRequestStart(r, logID)
// Use the custom type as the key
ctx := context.WithValue(r.Context(), ContextKeyLogId("logID"), logID)
r = r.WithContext(ctx)
// Increment total requests
m.incrementTotalRequestsMetric()
state := m.initializeWAFState()
if m.isPhaseBlocked(w, r, 1, state) { // Phase 1: Pre-request checks
return nil // Request blocked in Phase 1, short-circuit
}
if m.isPhaseBlocked(w, r, 2, state) { // Phase 2: Request analysis
return nil // Request blocked in Phase 2, short-circuit
}
recorder := NewResponseRecorder(w)
err := next.ServeHTTP(recorder, r)
if m.isResponseHeaderPhaseBlocked(recorder, r, 3, state) { // Phase 3: Response Header analysis
return nil // Request blocked in Phase 3, short-circuit
}
m.handleResponseBodyPhase(recorder, r, state) // Phase 4: Response Body analysis (if not blocked yet)
if state.Blocked {
m.incrementBlockedRequestsMetric()
m.writeCustomResponse(recorder, state.StatusCode)
return nil // Short circuit if blocked in any phase after headers
}
m.incrementAllowedRequestsMetric()
if m.isMetricsRequest(r) {
return m.handleMetricsRequest(w, r) // Handle metrics requests separately
}
// If not blocked, copy recorded response back to original writer
if !state.Blocked {
// Copy headers from recorder to original writer
header := w.Header()
for key, values := range recorder.Header() {
for _, value := range values {
header.Add(key, value)
}
}
w.WriteHeader(recorder.StatusCode()) // Set status code from recorder
// Write body from recorder to original writer
_, writeErr := w.Write(recorder.body.Bytes())
if writeErr != nil {
m.logger.Error("Failed to write recorded response body to client", zap.Error(writeErr), zap.String("log_id", logID))
// We should still return the original error from next.ServeHTTP if available, or a new error if writing body failed and next didn't return error.
if err == nil {
return writeErr // If original handler didn't error, return body write error.
}
}
}
m.logRequestCompletion(logID, state)
return err // Return any error from the next handler
}
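The phase flow above leans on the response recorder to buffer the upstream response until the WAF has finished with it. A minimal sketch of the recorder API used here (NewResponseRecorder, StatusCode, BodyString and the body buffer); the repository's actual type may carry more state:

// Minimal recorder sketch; buffers status and body instead of streaming them.
type responseRecorder struct {
	http.ResponseWriter
	body       *bytes.Buffer
	statusCode int
}

func NewResponseRecorder(w http.ResponseWriter) *responseRecorder {
	return &responseRecorder{ResponseWriter: w, body: new(bytes.Buffer), statusCode: http.StatusOK}
}

// WriteHeader records the status without sending it; ServeHTTP flushes later.
func (rr *responseRecorder) WriteHeader(code int) { rr.statusCode = code }

// Write buffers the body so Phase 4 can inspect it before the client sees it.
func (rr *responseRecorder) Write(b []byte) (int, error) { return rr.body.Write(b) }

func (rr *responseRecorder) StatusCode() int    { return rr.statusCode }
func (rr *responseRecorder) BodyString() string { return rr.body.String() }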
// isPhaseBlocked encapsulates the phase handling and blocking check logic.
func (m *Middleware) isPhaseBlocked(w http.ResponseWriter, r *http.Request, phase int, state *WAFState) bool {
m.handlePhase(w, r, phase, state)
if state.Blocked {
m.incrementBlockedRequestsMetric()
w.WriteHeader(state.StatusCode)
return true
}
return false
}
// isResponseHeaderPhaseBlocked encapsulates the response header phase handling and blocking check logic.
func (m *Middleware) isResponseHeaderPhaseBlocked(recorder *responseRecorder, r *http.Request, phase int, state *WAFState) bool {
m.handlePhase(recorder, r, phase, state)
if state.Blocked {
m.incrementBlockedRequestsMetric()
recorder.WriteHeader(state.StatusCode)
return true
}
return false
}
// logRequestStart logs the start of WAF evaluation.
func (m *Middleware) logRequestStart(r *http.Request, logID string) {
m.logger.Info("WAF request evaluation started",
zap.String("log_id", logID),
zap.String("method", r.Method),
zap.String("uri", r.RequestURI),
zap.String("remote_address", r.RemoteAddr),
zap.String("user_agent", r.UserAgent()),
)
}
// incrementTotalRequestsMetric increments the total requests metric.
func (m *Middleware) incrementTotalRequestsMetric() {
m.muMetrics.Lock()
m.totalRequests++
m.muMetrics.Unlock()
}
// initializeWAFState initializes the WAF state.
func (m *Middleware) initializeWAFState() *WAFState {
return &WAFState{
TotalScore: 0,
Blocked: false,
StatusCode: http.StatusOK,
ResponseWritten: false,
}
}
// handleResponseBodyPhase processes Phase 4 (response body).
func (m *Middleware) handleResponseBodyPhase(recorder *responseRecorder, r *http.Request, state *WAFState) {
// No need to check if recorder.body is nil here, it's always initialized in NewResponseRecorder
body := recorder.BodyString()
m.logger.Debug("Response body captured for Phase 4 analysis", zap.String("log_id", r.Context().Value(ContextKeyLogId("logID")).(string)))
for _, rule := range m.Rules[4] {
if rule.regex.MatchString(body) {
m.processRuleMatch(recorder, r, &rule, body, state)
if state.Blocked {
m.incrementBlockedRequestsMetric()
return
}
}
}
}
// incrementBlockedRequestsMetric increments the blocked requests metric.
func (m *Middleware) incrementBlockedRequestsMetric() {
m.muMetrics.Lock()
m.blockedRequests++
m.muMetrics.Unlock()
}
// incrementAllowedRequestsMetric increments the allowed requests metric.
func (m *Middleware) incrementAllowedRequestsMetric() {
m.muMetrics.Lock()
m.allowedRequests++
m.muMetrics.Unlock()
}
// isMetricsRequest checks if it's a metrics request.
func (m *Middleware) isMetricsRequest(r *http.Request) bool {
return m.MetricsEndpoint != "" && r.URL.Path == m.MetricsEndpoint
}
// writeCustomResponse writes a custom response.
func (m *Middleware) writeCustomResponse(w http.ResponseWriter, statusCode int) {
if customResponse, ok := m.CustomResponses[statusCode]; ok {
for key, value := range customResponse.Headers {
w.Header().Set(key, value)
}
w.WriteHeader(customResponse.StatusCode)
if _, err := w.Write([]byte(customResponse.Body)); err != nil {
m.logger.Error("Failed to write custom response body", zap.Error(err))
}
}
}
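writeCustomResponse implies a response shape keyed by status code. A sketch of the assumed struct and one blocked-request entry (field names inferred from the usage above, not necessarily the repository's exact definition):

// Inferred shape of the custom response configuration; a sketch only.
type CustomResponse struct {
	StatusCode int
	Headers    map[string]string
	Body       string
}

// Example: serve a plain-text page when a request is blocked with 403.
var exampleResponses = map[int]CustomResponse{
	http.StatusForbidden: {
		StatusCode: http.StatusForbidden,
		Headers:    map[string]string{"Content-Type": "text/plain"},
		Body:       "request blocked by WAF\n",
	},
}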
// logRequestCompletion logs the completion of WAF evaluation.
func (m *Middleware) logRequestCompletion(logID string, state *WAFState) {
m.logger.Info("WAF request evaluation completed",
zap.String("log_id", logID),
zap.Int("total_score", state.TotalScore),
zap.Bool("blocked", state.Blocked),
zap.Int("status_code", state.StatusCode),
)
}
// ==================== Utility Functions ====================

View File

@@ -1,4 +1,3 @@
// logging.go
package caddywaf
import (
@@ -11,155 +10,104 @@ import (
"go.uber.org/zap/zapcore"
)
const unknownValue = "unknown" // Define a constant for "unknown" values
var sensitiveKeys = []string{"password", "token", "apikey", "authorization", "secret"} // Define sensitive keys for redaction as package variable
func (m *Middleware) logRequest(level zapcore.Level, msg string, r *http.Request, fields ...zap.Field) {
if m.logger == nil || level < m.logLevel {
return // Early return if logger is nil or level is below threshold
}
allFields := m.prepareLogFields(r, fields) // Prepare all fields in one function
// Send the log entry to the buffered channel
select {
case m.logChan <- LogEntry{Level: level, Message: msg, Fields: allFields}:
// Log entry successfully queued
default:
// If the channel is full, fall back to synchronous logging
m.logger.Warn("Log buffer full, falling back to synchronous logging",
zap.String("message", msg),
zap.Any("fields", allFields),
)
m.logger.Log(level, msg, allFields...)
}
}
// prepareLogFields consolidates the logic for preparing log fields, including common fields and log_id.
func (m *Middleware) prepareLogFields(r *http.Request, fields []zap.Field) []zap.Field {
var logID string
var customFields []zap.Field
// Extract log_id if present, otherwise generate a new one
logID, customFields = m.extractLogIDField(fields)
if logID == "" {
logID = uuid.New().String()
}
// Get common log fields and merge with custom fields, prioritizing custom fields in case of duplicates
commonFields := m.getCommonLogFields(r)
allFields := m.mergeFields(customFields, commonFields, zap.String("log_id", logID)) // Ensure log_id is always present
return allFields
}
// extractLogIDField extracts the log_id from the given fields and returns it along with the remaining fields.
func (m *Middleware) extractLogIDField(fields []zap.Field) (logID string, remainingFields []zap.Field) {
for _, field := range fields {
if field.Key == "log_id" {
logID = field.String
} else {
remainingFields = append(remainingFields, field)
}
}
return logID, remainingFields
}
// mergeFields merges custom fields and common fields, with custom fields taking precedence and ensuring log_id is present.
func (m *Middleware) mergeFields(customFields []zap.Field, commonFields []zap.Field, logIDField zap.Field) []zap.Field {
mergedFields := make([]zap.Field, 0, len(customFields)+len(commonFields)+1) // Preallocate capacity
mergedFields = append(mergedFields, customFields...)
// Add common fields, skip if key already exists in custom fields
for _, commonField := range commonFields {
exists := false
for _, customField := range customFields {
if commonField.Key == customField.Key {
exists = true
break
}
}
if !exists {
mergedFields = append(mergedFields, commonField)
}
}
mergedFields = append(mergedFields, logIDField) // Ensure log_id is always present
return mergedFields
}
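The nested duplicate scan in mergeFields is quadratic in the number of fields, which is harmless for zap's typical handful of fields; a set-based variant is a simple alternative, sketched here under the same contract:

// Set-based merge, behaviorally equivalent to mergeFields above; a sketch.
func mergeFieldsBySet(customFields, commonFields []zap.Field, logIDField zap.Field) []zap.Field {
	seen := make(map[string]struct{}, len(customFields))
	merged := make([]zap.Field, 0, len(customFields)+len(commonFields)+1)
	for _, f := range customFields {
		seen[f.Key] = struct{}{}
		merged = append(merged, f)
	}
	for _, f := range commonFields {
		if _, dup := seen[f.Key]; !dup {
			merged = append(merged, f)
		}
	}
	return append(merged, logIDField)
}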
func (m *Middleware) getCommonLogFields(r *http.Request) []zap.Field {
sourceIP := unknownValue
userAgent := unknownValue
requestMethod := unknownValue
requestPath := unknownValue
queryParams := "" // Initialize to empty string, not "unknown" - More accurate for query params
statusCode := 0 // Default status code is 0 if not explicitly set
if r != nil {
sourceIP = r.RemoteAddr
userAgent = r.UserAgent()
requestMethod = r.Method
requestPath = r.URL.Path
queryParams = r.URL.RawQuery
}
// Fall back to defaults for any value the request did not provide
if sourceIP == "" {
sourceIP = unknownValue
}
if userAgent == "" {
userAgent = unknownValue
}
if requestMethod == "" {
requestMethod = unknownValue
}
if requestPath == "" {
requestPath = unknownValue
}
// Redact query parameters if required
if m.RedactSensitiveData {
queryParams = m.redactQueryParams(queryParams)
}
// Construct and return common fields
return []zap.Field{
zap.String("source_ip", sourceIP),
zap.String("user_agent", userAgent),
@@ -167,7 +115,7 @@ func (m *Middleware) getCommonLogFields(r *http.Request, fields []zap.Field) []z
zap.String("request_path", requestPath),
zap.String("query_params", queryParams),
zap.Int("status_code", statusCode),
zap.Time("timestamp", time.Now()), // Include a timestamp
zap.Time("timestamp", time.Now()),
}
}
@@ -182,7 +130,7 @@ func (m *Middleware) redactQueryParams(queryParams string) string {
keyValue := strings.SplitN(part, "=", 2)
if len(keyValue) == 2 {
key := strings.ToLower(keyValue[0])
if m.isSensitiveQueryParamKey(key) { // Use helper function for sensitive key check
parts[i] = keyValue[0] + "=REDACTED"
}
}
@@ -191,6 +139,15 @@ func (m *Middleware) redactQueryParams(queryParams string) string {
return strings.Join(parts, "&")
}
func (m *Middleware) isSensitiveQueryParamKey(key string) bool {
for _, sensitiveKey := range sensitiveKeys { // Use package level sensitiveKeys variable
if strings.Contains(key, sensitiveKey) {
return true
}
}
return false
}
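As an illustration of the substring match above (assuming redaction is enabled via the RedactSensitiveData check on the caller's side):

// Illustrative input/output for redactQueryParams:
//   in:  "user=alice&api_token=abc123&page=2"
//   out: "user=alice&api_token=REDACTED&page=2"   // "api_token" contains "token"
//   keys without a sensitive substring pass through unchanged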
func caddyTimeEncoder(t time.Time, enc zapcore.PrimitiveArrayEncoder) {
enc.AppendString(t.Format("2006/01/02 15:04:05.000"))
}

View File

@@ -126,11 +126,11 @@ func (rl *RateLimiter) cleanupExpiredEntries() {
// startCleanup starts the goroutine to periodically clean up expired entries.
func (rl *RateLimiter) startCleanup() {
go func() {
ticker := time.NewTicker(rl.config.CleanupInterval)
defer func() {
ticker.Stop()
}()
for {

View File

@@ -1,11 +1,11 @@
package caddywaf
import (
"bytes"
"encoding/json"
"fmt"
"io"
"net/http"
"net/url"
"strconv"
"strings"
@@ -21,6 +21,37 @@ type RequestValueExtractor struct {
// Define a custom type for context keys
type ContextKeyLogId string
// Extraction Target Constants - Improved Readability and Maintainability
const (
TargetMethod = "METHOD"
TargetRemoteIP = "REMOTE_IP"
TargetProtocol = "PROTOCOL"
TargetHost = "HOST"
TargetArgs = "ARGS"
TargetUserAgent = "USER_AGENT"
TargetPath = "PATH"
TargetURI = "URI"
TargetBody = "BODY"
TargetHeaders = "HEADERS" // Full request headers
TargetRequestHeaders = "REQUEST_HEADERS" // Alias for full request headers
TargetResponseHeaders = "RESPONSE_HEADERS"
TargetResponseBody = "RESPONSE_BODY"
TargetFileName = "FILE_NAME"
TargetFileMIMEType = "FILE_MIME_TYPE"
TargetCookies = "COOKIES" // All cookies
TargetRequestCookies = "REQUEST_COOKIES" // Alias for all cookies (TODO: remove duplicate alias)
TargetURLParamPrefix = "URL_PARAM:"
TargetJSONPathPrefix = "JSON_PATH:"
TargetContentType = "CONTENT_TYPE"
TargetURL = "URL"
TargetCookiesPrefix = "COOKIES:" // Dynamic cookie extraction prefix
TargetHeadersPrefix = "HEADERS:" // Dynamic header extraction prefix
TargetRequestHeadersPrefix = "REQUEST_HEADERS:" // Alias for dynamic header extraction prefix
TargetResponseHeadersPrefix = "RESPONSE_HEADERS:" // Dynamic response header extraction prefix
)
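Illustrative dispatch against these targets, assuming an extractor rve, a request r, and the ResponseWriter w that the response-phase targets require (extractSingleValue's signature is taken from its definition further down):

// value, err := rve.extractSingleValue(TargetUserAgent, r, w)
// value, err := rve.extractSingleValue(TargetHeadersPrefix+"X-Forwarded-For", r, w)    // one request header
// value, err := rve.extractSingleValue(TargetURLParamPrefix+"id", r, w)                // one query parameter
// value, err := rve.extractSingleValue(TargetJSONPathPrefix+"data.items.0.name", r, w) // JSON body path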
var sensitiveTargets = []string{"password", "token", "apikey", "authorization", "secret"} // Define sensitive targets for redaction as package variable
// NewRequestValueExtractor creates a new RequestValueExtractor with a given logger
func NewRequestValueExtractor(logger *zap.Logger, redactSensitiveData bool) *RequestValueExtractor {
return &RequestValueExtractor{logger: logger, redactSensitiveData: redactSensitiveData}
@@ -56,249 +87,84 @@ func (rve *RequestValueExtractor) extractSingleValue(target string, r *http.Requ
var unredactedValue string
var err error
// Optimization: Use a map for target extraction logic
extractionLogic := map[string]func() (string, error){
TargetMethod: func() (string, error) { return r.Method, nil },
TargetRemoteIP: func() (string, error) { return r.RemoteAddr, nil },
TargetProtocol: func() (string, error) { return r.Proto, nil },
TargetHost: func() (string, error) { return r.Host, nil },
TargetArgs: func() (string, error) {
return r.URL.RawQuery, rve.checkEmpty(r.URL.RawQuery, target, "Query string is empty")
},
TargetUserAgent: func() (string, error) {
value := r.UserAgent()
rve.logIfEmpty(value, target, "User-Agent is empty")
return value, nil
},
TargetPath: func() (string, error) {
value := r.URL.Path
rve.logIfEmpty(value, target, "Request path is empty")
return value, nil
},
TargetURI: func() (string, error) {
value := r.URL.RequestURI()
rve.logIfEmpty(value, target, "Request URI is empty")
return value, nil
},
TargetBody: func() (string, error) { return rve.extractBody(r, target) }, // Separate body extraction
TargetHeaders: func() (string, error) { return rve.extractHeaders(r.Header, "Request headers", target) }, // Helper for headers
TargetRequestHeaders: func() (string, error) { return rve.extractHeaders(r.Header, "Request headers", target) }, // Alias
TargetResponseHeaders: func() (string, error) { return rve.extractResponseHeaders(w, target) }, // Helper for response headers
TargetResponseBody: func() (string, error) { return rve.extractResponseBody(w, target) }, // Helper for response body
TargetFileName: func() (string, error) { return rve.extractFileName(r, target) }, // Helper for filename
TargetFileMIMEType: func() (string, error) { return rve.extractFileMIMEType(r, target) }, // Helper for mime type
TargetCookies: func() (string, error) { return rve.extractCookies(r.Cookies(), "No cookies found", target) }, // Helper for cookies
TargetRequestCookies: func() (string, error) { return rve.extractCookies(r.Cookies(), "No cookies found", target) }, // Alias
TargetContentType: func() (string, error) {
return r.Header.Get("Content-Type"), rve.checkEmpty(r.Header.Get("Content-Type"), target, "Content-Type header not found")
},
TargetURL: func() (string, error) {
return r.URL.String(), rve.checkEmpty(r.URL.String(), target, "URL could not be extracted")
},
}
if extractor, exists := extractionLogic[target]; exists {
unredactedValue, err = extractor()
if err != nil {
return "", err // Return error from extractor
}
} else if strings.HasPrefix(target, TargetHeadersPrefix) || strings.HasPrefix(target, TargetRequestHeadersPrefix) {
unredactedValue, err = rve.extractDynamicHeader(r.Header, strings.TrimPrefix(strings.TrimPrefix(target, TargetHeadersPrefix), TargetRequestHeadersPrefix), target)
if err != nil {
return "", err
}
} else if strings.HasPrefix(target, TargetResponseHeadersPrefix) {
unredactedValue, err = rve.extractDynamicResponseHeader(w, strings.TrimPrefix(target, TargetResponseHeadersPrefix), target)
if err != nil {
return "", err
}
} else if strings.HasPrefix(target, TargetCookiesPrefix) {
unredactedValue, err = rve.extractDynamicCookie(r, strings.TrimPrefix(target, TargetCookiesPrefix), target)
if err != nil {
return "", err
}
} else if strings.HasPrefix(target, TargetURLParamPrefix) {
unredactedValue, err = rve.extractURLParam(r.URL, strings.TrimPrefix(target, TargetURLParamPrefix), target)
if err != nil {
return "", err
}
} else if strings.HasPrefix(target, TargetJSONPathPrefix) {
unredactedValue, err = rve.extractValueForJSONPath(r, strings.TrimPrefix(target, TargetJSONPathPrefix), target)
if err != nil {
return "", err
}
} else {
rve.logger.Warn("Unknown extraction target", zap.String("target", target))
return "", fmt.Errorf("unknown extraction target: %s", target)
}
// Redact sensitive fields before returning the value
value := rve.redactValueIfSensitive(target, unredactedValue)
// Log the extracted value (redacted if necessary)
rve.logger.Debug("Extracted value",
@@ -310,6 +176,202 @@ func (rve *RequestValueExtractor) extractSingleValue(target string, r *http.Requ
return unredactedValue, nil
}
// Helper function to check for empty value and log debug message if empty
func (rve *RequestValueExtractor) checkEmpty(value string, target, message string) error {
if value == "" {
rve.logger.Debug(message, zap.String("target", target))
return fmt.Errorf("%s for target: %s", message, target)
}
return nil
}
// Helper function to log debug message if value is empty
func (rve *RequestValueExtractor) logIfEmpty(value string, target string, message string) {
if value == "" {
rve.logger.Debug(message, zap.String("target", target))
}
}
// Helper function to extract body
func (rve *RequestValueExtractor) extractBody(r *http.Request, target string) (string, error) {
if r.Body == nil {
rve.logger.Warn("Request body is nil", zap.String("target", target))
return "", fmt.Errorf("request body is nil for target: %s", target)
}
if r.ContentLength == 0 {
rve.logger.Debug("Request body is empty", zap.String("target", target))
return "", fmt.Errorf("request body is empty for target: %s", target)
}
bodyBytes, err := io.ReadAll(r.Body)
if err != nil {
rve.logger.Error("Failed to read request body", zap.Error(err))
return "", fmt.Errorf("failed to read request body for target %s: %w", target, err)
}
r.Body = io.NopCloser(bytes.NewReader(bodyBytes)) // Restore the body so later readers still see it; http.NoBody would discard it
return string(bodyBytes), nil
}
// Helper function to extract headers
func (rve *RequestValueExtractor) extractHeaders(header http.Header, logMessage, target string) (string, error) {
if len(header) == 0 {
rve.logger.Debug(logMessage+" are empty", zap.String("target", target))
return "", fmt.Errorf("%s are empty for target: %s", logMessage, target)
}
headers := make([]string, 0)
for name, values := range header {
headers = append(headers, fmt.Sprintf("%s: %s", name, strings.Join(values, ",")))
}
return strings.Join(headers, "; "), nil
}
// Helper function to extract response headers (for phase 3)
func (rve *RequestValueExtractor) extractResponseHeaders(w http.ResponseWriter, target string) (string, error) {
if w == nil {
return "", fmt.Errorf("response headers not accessible during this phase for target: %s", target)
}
return rve.extractHeaders(w.Header(), "Response headers", target)
}
// Helper function to extract response body (for phase 4)
func (rve *RequestValueExtractor) extractResponseBody(w http.ResponseWriter, target string) (string, error) {
if w == nil {
return "", fmt.Errorf("response body not accessible outside Phase 4 for target: %s", target)
}
recorder, ok := w.(*responseRecorder)
if !ok || recorder == nil {
return "", fmt.Errorf("response recorder not available for target: %s", target)
}
if recorder.body.Len() == 0 {
rve.logger.Debug("Response body is empty", zap.String("target", target))
return "", fmt.Errorf("response body is empty for target: %s", target)
}
return recorder.BodyString(), nil
}
// Helper function to extract filename from multipart form
func (rve *RequestValueExtractor) extractFileName(r *http.Request, target string) (string, error) {
if r.MultipartForm == nil || r.MultipartForm.File == nil {
rve.logger.Debug("Multipart form file not found", zap.String("target", target))
return "", fmt.Errorf("multipart form file not found for target: %s", target)
}
for _, files := range r.MultipartForm.File {
if len(files) > 0 { // Check if there are files
return files[0].Filename, nil // Return the first file's name
}
}
return "", fmt.Errorf("no files found in multipart form for target: %s", target) // No files found but form is present
}
// Helper function to extract MIME type from multipart form
func (rve *RequestValueExtractor) extractFileMIMEType(r *http.Request, target string) (string, error) {
if r.MultipartForm == nil || r.MultipartForm.File == nil {
rve.logger.Debug("Multipart form file not found", zap.String("target", target))
return "", fmt.Errorf("multipart form file not found for target: %s", target)
}
for _, files := range r.MultipartForm.File {
if len(files) > 0 { // Check if files are present
return files[0].Header.Get("Content-Type"), nil // Return MIME type of the first file
}
}
return "", fmt.Errorf("no files found in multipart form for target: %s", target) // No files found even though form is present
}
// Helper function to extract dynamic header value
func (rve *RequestValueExtractor) extractDynamicHeader(header http.Header, headerName, target string) (string, error) {
headerValue := header.Get(headerName)
if headerValue == "" {
rve.logger.Debug("Header not found", zap.String("header", headerName))
return "", fmt.Errorf("header '%s' not found for target: %s", headerName, target)
}
return headerValue, nil
}
// Helper function to extract dynamic response header value (for phase 3)
func (rve *RequestValueExtractor) extractDynamicResponseHeader(w http.ResponseWriter, headerName, target string) (string, error) {
if w == nil {
return "", fmt.Errorf("response headers not available during this phase for target: %s", target)
}
headerValue := w.Header().Get(headerName)
if headerValue == "" {
rve.logger.Debug("Response header not found", zap.String("header", headerName))
return "", fmt.Errorf("response header '%s' not found for target: %s", headerName, target)
}
return headerValue, nil
}
// Helper function to extract cookie value
func (rve *RequestValueExtractor) extractDynamicCookie(r *http.Request, cookieName string, target string) (string, error) {
cookie, err := r.Cookie(cookieName)
if err != nil {
rve.logger.Debug("Cookie not found", zap.String("cookie", cookieName))
return "", fmt.Errorf("cookie '%s' not found for target: %s", cookieName, target)
}
return cookie.Value, nil
}
// Helper function to extract URL parameter value
func (rve *RequestValueExtractor) extractURLParam(url *url.URL, paramName string, target string) (string, error) {
paramValue := url.Query().Get(paramName)
if paramValue == "" {
rve.logger.Debug("URL parameter not found", zap.String("parameter", paramName))
return "", fmt.Errorf("url parameter '%s' not found for target: %s", paramName, target)
}
return paramValue, nil
}
// Helper function to extract value for JSON Path
func (rve *RequestValueExtractor) extractValueForJSONPath(r *http.Request, jsonPath string, target string) (string, error) {
if r.Body == nil {
rve.logger.Warn("Request body is nil", zap.String("target", target))
return "", fmt.Errorf("request body is nil for target: %s", target)
}
if r.ContentLength == 0 {
rve.logger.Debug("Request body is empty", zap.String("target", target))
return "", fmt.Errorf("request body is empty for target: %s", target)
}
bodyBytes, err := io.ReadAll(r.Body)
if err != nil {
rve.logger.Error("Failed to read request body", zap.Error(err))
return "", fmt.Errorf("failed to read request body for JSON_PATH target %s: %w", target, err)
}
r.Body = io.NopCloser(bytes.NewReader(bodyBytes)) // Restore the body so later readers still see it
// Use helper method to dynamically extract value based on JSON path (e.g., 'data.items.0.name').
unredactedValue, err := rve.extractJSONPath(string(bodyBytes), jsonPath)
if err != nil {
rve.logger.Debug("Failed to extract value from JSON path", zap.String("target", target), zap.String("path", jsonPath), zap.Error(err))
return "", fmt.Errorf("failed to extract from JSON path '%s': %w", jsonPath, err)
}
return unredactedValue, nil
}
// Helper function to redact value if target is sensitive
func (rve *RequestValueExtractor) redactValueIfSensitive(target string, value string) string {
if rve.redactSensitiveData {
for _, sensitive := range sensitiveTargets {
if strings.Contains(strings.ToLower(target), sensitive) {
return "REDACTED"
}
}
}
return value
}
// Helper function to extract cookies
func (rve *RequestValueExtractor) extractCookies(cookies []*http.Cookie, logMessage string, target string) (string, error) {
if len(cookies) == 0 {
rve.logger.Debug(logMessage, zap.String("target", target))
return "", fmt.Errorf("%s for target: %s", logMessage, target)
}
cookieStrings := make([]string, 0)
for _, cookie := range cookies {
cookieStrings = append(cookieStrings, fmt.Sprintf("%s=%s", cookie.Name, cookie.Value))
}
return strings.Join(cookieStrings, "; "), nil
}
// Helper function for JSON path extraction.
func (rve *RequestValueExtractor) extractJSONPath(jsonStr string, jsonPath string) (string, error) {
// Validate input JSON string

225
rules.go
View File

@@ -1,3 +1,4 @@
// rules.go
package caddywaf
import (
@@ -9,18 +10,14 @@ import (
"sort"
"strings"
"github.com/google/uuid"
"go.uber.org/zap"
"go.uber.org/zap/zapcore"
)
func (m *Middleware) processRuleMatch(w http.ResponseWriter, r *http.Request, rule *Rule, value string, state *WAFState) bool {
logID, _ := r.Context().Value(ContextKeyLogId("logID")).(string) // Comma-ok avoids a panic if the context value is missing
m.logRequest(zapcore.DebugLevel, "Rule Matched", r, // More concise log message
zap.String("rule_id", string(rule.ID)),
zap.String("target", strings.Join(rule.Targets, ",")),
zap.String("value", value),
@@ -28,76 +25,54 @@ func (m *Middleware) processRuleMatch(w http.ResponseWriter, r *http.Request, ru
zap.Int("score", rule.Score),
)
// Rule Hit Counter - Refactored for clarity
m.incrementRuleHitCount(RuleID(rule.ID))
// Metrics for Rule Hits by Phase - Refactored for clarity
m.incrementRuleHitsByPhaseMetric(rule.Phase)
oldScore := state.TotalScore
state.TotalScore += rule.Score
m.logRequest(zapcore.DebugLevel, "Anomaly score increased", r,
zap.String("log_id", logID),
zap.String("rule_id", string(rule.ID)),
zap.Int("score_increase", rule.Score),
zap.Int("old_total_score", oldScore),
zap.Int("new_total_score", state.TotalScore),
zap.Int("old_score", oldScore),
zap.Int("new_score", state.TotalScore),
zap.Int("anomaly_threshold", m.AnomalyThreshold),
)
shouldBlock := !state.ResponseWritten && (state.TotalScore >= m.AnomalyThreshold || rule.Action == "block")
blockReason := ""
if shouldBlock {
blockReason = "Anomaly threshold exceeded"
if rule.Action == "block" {
blockReason = "Rule action is 'block'"
}
}
m.logRequest(zapcore.DebugLevel, "Processing rule action", r,
zap.String("rule_id", string(rule.ID)),
zap.String("action", rule.Action),
zap.Bool("should_block", shouldBlock),
zap.String("block_reason", blockReason),
)
if shouldBlock && !state.ResponseWritten {
if shouldBlock {
m.blockRequest(w, r, state, http.StatusForbidden, blockReason, string(rule.ID), value,
zap.Int("total_score", state.TotalScore),
zap.Int("anomaly_threshold", m.AnomalyThreshold),
)
return false // Stop further processing
}
if rule.Action == "log" {
m.logRequest(zapcore.InfoLevel, "Rule action: Log", r,
zap.String("log_id", logID),
zap.String("rule_id", string(rule.ID)),
)
} else if !shouldBlock && !state.ResponseWritten {
m.logRequest(zapcore.DebugLevel, "Rule action: No Block", r,
zap.String("log_id", logID),
zap.String("rule_id", string(rule.ID)),
zap.String("action", rule.Action),
@@ -106,7 +81,30 @@ func (m *Middleware) processRuleMatch(w http.ResponseWriter, r *http.Request, ru
)
}
return true // Continue processing
}
// incrementRuleHitCount increments the hit counter for a given rule ID.
func (m *Middleware) incrementRuleHitCount(ruleID RuleID) {
hitCount := HitCount(1) // Default increment
if currentCount, loaded := m.ruleHits.Load(ruleID); loaded {
hitCount = currentCount.(HitCount) + 1
}
m.ruleHits.Store(ruleID, hitCount)
m.logger.Debug("Rule hit count updated",
zap.String("rule_id", string(ruleID)),
zap.Int("hit_count", int(hitCount)), // More descriptive log field
)
}
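One caveat: Load followed by Store is not atomic, so two concurrent rule matches can race and drop a hit. A contention-safe sketch using sync.Map.CompareAndSwap (Go 1.20+), assuming HitCount is an integer type:

// Atomic variant of the hit counter; a sketch, not the repository's code.
func (m *Middleware) incrementRuleHitCountAtomic(ruleID RuleID) {
	for {
		current, _ := m.ruleHits.LoadOrStore(ruleID, HitCount(0))
		if m.ruleHits.CompareAndSwap(ruleID, current, current.(HitCount)+1) {
			return
		}
	}
}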
// incrementRuleHitsByPhaseMetric increments the rule hits by phase metric.
func (m *Middleware) incrementRuleHitsByPhaseMetric(phase int) {
m.muMetrics.Lock()
if m.ruleHitsByPhase == nil {
m.ruleHitsByPhase = make(map[int]int64)
}
m.ruleHitsByPhase[phase]++
m.muMetrics.Unlock()
}
func validateRule(rule *Rule) error {
@@ -131,92 +129,109 @@ func validateRule(rule *Rule) error {
return nil
}
// loadRules updates the RuleCache and Rules map when rules are loaded and sorts rules by priority.
func (m *Middleware) loadRules(paths []string) error {
m.mu.Lock()
defer m.mu.Unlock()
m.logger.Debug("Loading rules from files", zap.Strings("rule_files", paths))
loadedRules := make(map[int][]Rule) // Temporary map to hold loaded rules
totalRules := 0
invalidFiles := []string{}
allInvalidRules := []string{}
ruleIDs := make(map[string]bool)
for _, path := range paths {
fileRules, fileInvalidRules, err := m.loadRulesFromFile(path, ruleIDs) // Load rules from a single file
if err != nil {
m.logger.Error("Failed to load rule file", zap.String("file", path), zap.Error(err))
invalidFiles = append(invalidFiles, path)
continue // Skip to the next file if loading fails
}
if len(fileInvalidRules) > 0 {
m.logger.Warn("Invalid rules in file", zap.String("file", path), zap.Strings("errors", fileInvalidRules))
allInvalidRules = append(allInvalidRules, fileInvalidRules...)
}
m.logger.Info("Rules loaded from file", zap.String("file", path), zap.Int("valid_rules", len(fileRules)), zap.Int("invalid_rules", len(fileInvalidRules)))
// Merge valid rules from the file into the temporary loadedRules map
for phase, rules := range fileRules {
loadedRules[phase] = append(loadedRules[phase], rules...)
}
totalRules += len(fileRules[1]) + len(fileRules[2]) + len(fileRules[3]) + len(fileRules[4]) // Update total rule count
}
m.Rules = loadedRules // Atomically update m.Rules after loading all files
if len(invalidFiles) > 0 {
m.logger.Error("Failed to load rule files", zap.Strings("files", invalidFiles)) // Error level for file loading failures
}
if len(allInvalidRules) > 0 {
m.logger.Warn("Validation errors in rules", zap.Strings("errors", allInvalidRules))
}
if totalRules == 0 && len(invalidFiles) > 0 {
if totalRules == 0 && len(paths) > 0 { // Only return error if paths were provided
return fmt.Errorf("no valid rules were loaded from any file")
} else if totalRules == 0 && len(paths) == 0 {
m.logger.Warn("No rule files specified, WAF will run without rules.") // Warn if no rule files and no rules loaded
}
m.logger.Info("WAF rules loaded successfully", zap.Int("total_rules", totalRules))
return nil
}
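The totalRules computation above hard-codes phases 1 through 4, so adding a phase later would silently undercount. A phase-agnostic sum, as a small sketch:

// Phase-agnostic rule count; a sketch that could replace the fixed 1-4 sum.
func countRules(rules map[int][]Rule) int {
	total := 0
	for _, phaseRules := range rules {
		total += len(phaseRules)
	}
	return total
}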
// loadRulesFromFile loads and validates rules from a single file.
func (m *Middleware) loadRulesFromFile(path string, ruleIDs map[string]bool) (validRules map[int][]Rule, invalidRules []string, err error) {
validRules = make(map[int][]Rule)
var fileInvalidRules []string
content, err := os.ReadFile(path)
if err != nil {
return nil, nil, fmt.Errorf("failed to read rule file: %w", err)
}
var rules []Rule
if err := json.Unmarshal(content, &rules); err != nil {
return nil, nil, fmt.Errorf("failed to unmarshal rules: %w", err)
}
// Sort rules by priority (higher priority first)
sort.Slice(rules, func(i, j int) bool {
return rules[i].Priority > rules[j].Priority
})
for i, rule := range rules {
if err := validateRule(&rule); err != nil {
fileInvalidRules = append(fileInvalidRules, fmt.Sprintf("Rule at index %d: %v", i, err))
continue
}
if _, exists := ruleIDs[string(rule.ID)]; exists {
fileInvalidRules = append(fileInvalidRules, fmt.Sprintf("Duplicate rule ID '%s' at index %d", rule.ID, i))
continue
}
ruleIDs[string(rule.ID)] = true // Track rule IDs to prevent duplicates
// RuleCache handling (compile and cache regex)
if cachedRegex, exists := m.ruleCache.Get(rule.ID); exists {
rule.regex = cachedRegex
} else {
compiledRegex, err := regexp.Compile(rule.Pattern)
if err != nil {
fileInvalidRules = append(fileInvalidRules, fmt.Sprintf("Rule '%s': invalid regex pattern: %v", rule.ID, err))
continue
}
rule.regex = compiledRegex
m.ruleCache.Set(rule.ID, compiledRegex) // Cache regex
}
if _, ok := validRules[rule.Phase]; !ok {
validRules[rule.Phase] = []Rule{}
}
validRules[rule.Phase] = append(validRules[rule.Phase], rule)
}
return validRules, fileInvalidRules, nil
}

38
tor.go
View File

@@ -1,7 +1,7 @@
// tor.go
package caddywaf
import (
"fmt" // Import fmt for improved error formatting
"io"
"net/http"
"os"
@@ -32,7 +32,7 @@ func (t *TorConfig) Provision(ctx caddy.Context) error {
t.logger = ctx.Logger()
if t.Enabled {
if err := t.updateTorExitNodes(); err != nil {
return fmt.Errorf("provisioning tor: %w", err) // Improved error wrapping
}
go t.scheduleUpdates()
}
@@ -41,21 +41,27 @@ func (t *TorConfig) Provision(ctx caddy.Context) error {
// updateTorExitNodes fetches the latest Tor exit nodes and updates the IP blacklist.
func (t *TorConfig) updateTorExitNodes() error {
t.logger.Debug("Updating Tor exit nodes...") // Debug log at start of update
resp, err := http.Get(torExitNodeURL)
if err != nil {
return fmt.Errorf("http get failed for %s: %w", torExitNodeURL, err) // Improved error message with URL
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
return fmt.Errorf("http get returned status %s for %s", resp.Status, torExitNodeURL) // Check for non-200 status
}
data, err := io.ReadAll(resp.Body)
if err != nil {
return fmt.Errorf("failed to read response body from %s: %w", torExitNodeURL, err) // Improved error message with URL
}
torIPs := strings.Split(string(data), "\n")
existingIPs, err := t.readExistingBlacklist()
if err != nil {
return fmt.Errorf("failed to read existing blacklist file %s: %w", t.TORIPBlacklistFile, err) // Improved error message with filename
}
// Merge and deduplicate IPs
@@ -65,15 +71,15 @@ func (t *TorConfig) updateTorExitNodes() error {
// Write updated blacklist to file
if err := t.writeBlacklist(uniqueIPs); err != nil {
return fmt.Errorf("failed to write updated blacklist to file %s: %w", t.TORIPBlacklistFile, err) // Improved error message with filename
}
t.lastUpdated = time.Now()
t.logger.Info("Tor exit nodes updated", zap.Int("count", len(uniqueIPs))) // Improved log message
t.logger.Debug("Tor exit node update completed successfully") // Debug log at end of update
return nil
}
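One subtlety in the merge step: strings.Split on "\n" keeps empty strings from trailing newlines, so deduplication benefits from filtering blanks as well. A sketch of such a helper (the repository's own unique(), truncated at the end of this file, may differ):

// Dedupe helper that also drops blank lines; a sketch only.
func uniqueNonEmpty(ips []string) []string {
	seen := make(map[string]struct{}, len(ips))
	out := make([]string, 0, len(ips))
	for _, ip := range ips {
		ip = strings.TrimSpace(ip)
		if ip == "" {
			continue
		}
		if _, ok := seen[ip]; ok {
			continue
		}
		seen[ip] = struct{}{}
		out = append(out, ip)
	}
	return out
}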
// scheduleUpdates periodically updates the Tor exit node list.
func (t *TorConfig) scheduleUpdates() {
interval, err := time.ParseDuration(t.UpdateInterval)
@@ -96,13 +102,13 @@ func (t *TorConfig) scheduleUpdates() {
// Use for range to iterate over the ticker channel
for range ticker.C {
if updateErr := t.updateTorExitNodes(); updateErr != nil { // Renamed err to updateErr for clarity
if t.RetryOnFailure {
t.logger.Error("Failed to update Tor exit nodes, retrying shortly", zap.Error(updateErr))
time.Sleep(retryInterval)
continue
} else {
t.logger.Error("Failed to update Tor exit nodes, will retry at next scheduled interval", zap.Error(updateErr))
}
}
}
@@ -113,9 +119,10 @@ func (t *TorConfig) readExistingBlacklist() ([]string, error) {
data, err := os.ReadFile(t.TORIPBlacklistFile)
if err != nil {
if os.IsNotExist(err) {
t.logger.Debug("Blacklist file does not exist, assuming empty list", zap.String("path", t.TORIPBlacklistFile)) // Debug log for non-existent file
return nil, nil
}
return nil, fmt.Errorf("failed to read IP blacklist file %s: %w", t.TORIPBlacklistFile, err) // Improved error message with filename
}
return strings.Split(string(data), "\n"), nil
}
@@ -123,7 +130,12 @@ func (t *TorConfig) readExistingBlacklist() ([]string, error) {
// writeBlacklist writes the updated IP blacklist to the file.
func (t *TorConfig) writeBlacklist(ips []string) error {
data := strings.Join(ips, "\n")
err := os.WriteFile(t.TORIPBlacklistFile, []byte(data), 0644)
if err != nil {
return fmt.Errorf("failed to write IP blacklist file %s: %w", t.TORIPBlacklistFile, err) // Improved error message with filename
}
t.logger.Debug("Blacklist file updated", zap.String("path", t.TORIPBlacklistFile), zap.Int("entry_count", len(ips))) // Debug log for file update
return nil
}
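os.WriteFile alone can leave a truncated blacklist if the process dies mid-write. A crash-safe variant using a temporary file plus rename, sketched under the assumption of the same TORIPBlacklistFile field (the .tmp suffix is illustrative):

// Write-to-temp-then-rename sketch for atomic blacklist updates.
func (t *TorConfig) writeBlacklistAtomic(ips []string) error {
	tmp := t.TORIPBlacklistFile + ".tmp" // illustrative temp path
	if err := os.WriteFile(tmp, []byte(strings.Join(ips, "\n")), 0644); err != nil {
		return fmt.Errorf("failed to write temp blacklist %s: %w", tmp, err)
	}
	return os.Rename(tmp, t.TORIPBlacklistFile)
}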
// unique removes duplicate entries from a slice of strings.