mirror of
https://github.com/fabriziosalmi/caddy-waf.git
synced 2025-12-23 22:27:46 -05:00
Minor code improvements.
This commit is contained in:
1
.gitignore
vendored
1
.gitignore
vendored
@@ -13,3 +13,4 @@ testdata/rules.json
|
||||
log.json
|
||||
validation.log
|
||||
caddy-waf.DS_Store
|
||||
vendor
|
||||
73
blacklist.go
73
blacklist.go
@@ -1,6 +1,7 @@
|
||||
package caddywaf
|
||||
|
||||
import (
|
||||
"bufio" // Optimized reading
|
||||
"fmt"
|
||||
"net"
|
||||
"os"
|
||||
@@ -25,18 +26,22 @@ func (bl *BlacklistLoader) LoadDNSBlacklistFromFile(path string, dnsBlacklist ma
|
||||
if bl.logger == nil {
|
||||
bl.logger = zap.NewNop()
|
||||
}
|
||||
bl.logger.Debug("Loading DNS blacklist from file", zap.String("file", path))
|
||||
bl.logger.Debug("Loading DNS blacklist", zap.String("path", path)) // Improved log message
|
||||
|
||||
content, err := os.ReadFile(path)
|
||||
file, err := os.Open(path)
|
||||
if err != nil {
|
||||
bl.logger.Warn("Failed to read DNS blacklist file", zap.String("file", path), zap.Error(err))
|
||||
return fmt.Errorf("failed to read DNS blacklist file: %w", err)
|
||||
bl.logger.Warn("Failed to open DNS blacklist file", zap.String("path", path), zap.Error(err)) // Path instead of file for consistency
|
||||
return fmt.Errorf("failed to open DNS blacklist file: %w", err) // More accurate error message
|
||||
}
|
||||
defer file.Close()
|
||||
|
||||
lines := strings.Split(string(content), "\n")
|
||||
scanner := bufio.NewScanner(file)
|
||||
validEntries := 0
|
||||
totalLines := 0 // Initialize totalLines
|
||||
|
||||
for _, line := range lines {
|
||||
for scanner.Scan() {
|
||||
totalLines++
|
||||
line := scanner.Text()
|
||||
line = strings.ToLower(strings.TrimSpace(line))
|
||||
if line == "" || strings.HasPrefix(line, "#") {
|
||||
continue // Skip empty lines and comments
|
||||
@@ -45,10 +50,15 @@ func (bl *BlacklistLoader) LoadDNSBlacklistFromFile(path string, dnsBlacklist ma
|
||||
validEntries++
|
||||
}
|
||||
|
||||
bl.logger.Info("DNS blacklist loaded successfully",
|
||||
zap.String("file", path),
|
||||
if err := scanner.Err(); err != nil {
|
||||
bl.logger.Error("Error reading DNS blacklist file", zap.String("path", path), zap.Error(err)) // More specific error log
|
||||
return fmt.Errorf("error reading DNS blacklist file: %w", err)
|
||||
}
|
||||
|
||||
bl.logger.Info("DNS blacklist loaded", // Improved log message
|
||||
zap.String("path", path),
|
||||
zap.Int("valid_entries", validEntries),
|
||||
zap.Int("total_lines", len(lines)),
|
||||
zap.Int("total_lines", totalLines), // Use totalLines which is correctly counted
|
||||
)
|
||||
return nil
|
||||
}
|
||||
@@ -77,14 +87,14 @@ func (m *Middleware) isDNSBlacklisted(host string) bool {
|
||||
defer m.mu.RUnlock()
|
||||
|
||||
if _, exists := m.dnsBlacklist[normalizedHost]; exists {
|
||||
m.logger.Info("Host is blacklisted",
|
||||
m.logger.Debug("DNS blacklist hit", // More concise log message, debug level
|
||||
zap.String("host", host),
|
||||
zap.String("blacklisted_domain", normalizedHost),
|
||||
)
|
||||
return true
|
||||
}
|
||||
|
||||
m.logger.Debug("Host is not blacklisted", zap.String("host", host))
|
||||
m.logger.Debug("DNS blacklist miss", zap.String("host", host)) // More concise log message, debug level
|
||||
return false
|
||||
}
|
||||
|
||||
@@ -92,9 +102,9 @@ func (m *Middleware) isDNSBlacklisted(host string) bool {
|
||||
func extractIP(remoteAddr string, logger *zap.Logger) string {
|
||||
host, _, err := net.SplitHostPort(remoteAddr)
|
||||
if err != nil {
|
||||
logger.Debug("Failed to extract IP from remote address, using full address",
|
||||
logger.Debug("Using full remote address as IP", // More descriptive debug log
|
||||
zap.String("remoteAddr", remoteAddr),
|
||||
zap.Error(err),
|
||||
zap.Error(err), // Keep error for debugging
|
||||
)
|
||||
return remoteAddr // Assume the input is already an IP address
|
||||
}
|
||||
@@ -106,18 +116,22 @@ func (bl *BlacklistLoader) LoadIPBlacklistFromFile(path string, ipBlacklist map[
|
||||
if bl.logger == nil {
|
||||
bl.logger = zap.NewNop()
|
||||
}
|
||||
bl.logger.Debug("Loading IP blacklist from file", zap.String("file", path))
|
||||
bl.logger.Debug("Loading IP blacklist", zap.String("path", path)) // Improved log message
|
||||
|
||||
content, err := os.ReadFile(path)
|
||||
file, err := os.Open(path)
|
||||
if err != nil {
|
||||
bl.logger.Warn("Failed to read IP blacklist file", zap.String("file", path), zap.Error(err))
|
||||
return fmt.Errorf("failed to read IP blacklist file: %w", err)
|
||||
bl.logger.Warn("Failed to open IP blacklist file", zap.String("path", path), zap.Error(err)) // Path instead of file for consistency
|
||||
return fmt.Errorf("failed to open IP blacklist file: %w", err) // More accurate error message
|
||||
}
|
||||
defer file.Close()
|
||||
|
||||
lines := strings.Split(string(content), "\n")
|
||||
scanner := bufio.NewScanner(file)
|
||||
validEntries := 0
|
||||
totalLines := 0 // Initialize totalLines
|
||||
|
||||
for i, line := range lines {
|
||||
for scanner.Scan() {
|
||||
totalLines++
|
||||
line := scanner.Text()
|
||||
line = strings.TrimSpace(line)
|
||||
if line == "" || strings.HasPrefix(line, "#") {
|
||||
continue // Skip empty lines and comments
|
||||
@@ -127,7 +141,7 @@ func (bl *BlacklistLoader) LoadIPBlacklistFromFile(path string, ipBlacklist map[
|
||||
// Valid CIDR range
|
||||
ipBlacklist[line] = struct{}{}
|
||||
validEntries++
|
||||
bl.logger.Debug("Added CIDR range to blacklist", zap.String("cidr", line))
|
||||
bl.logger.Debug("Added CIDR to IP blacklist", zap.String("cidr", line)) // More specific debug log
|
||||
continue
|
||||
}
|
||||
|
||||
@@ -135,21 +149,26 @@ func (bl *BlacklistLoader) LoadIPBlacklistFromFile(path string, ipBlacklist map[
|
||||
// Valid IP address
|
||||
ipBlacklist[line] = struct{}{}
|
||||
validEntries++
|
||||
bl.logger.Debug("Added IP to blacklist", zap.String("ip", line))
|
||||
bl.logger.Debug("Added IP to IP blacklist", zap.String("ip", line)) // More specific debug log
|
||||
continue
|
||||
}
|
||||
|
||||
bl.logger.Warn("Invalid IP or CIDR range in blacklist file, skipping",
|
||||
zap.String("file", path),
|
||||
zap.Int("line", i+1),
|
||||
bl.logger.Warn("Invalid IP/CIDR entry in blacklist file", // More concise warning message
|
||||
zap.String("path", path),
|
||||
zap.Int("line", totalLines), // Use totalLines which is correctly counted
|
||||
zap.String("entry", line),
|
||||
)
|
||||
}
|
||||
|
||||
bl.logger.Info("IP blacklist loaded successfully",
|
||||
zap.String("file", path),
|
||||
if scanner.Err() != nil {
|
||||
bl.logger.Error("Error reading IP blacklist file", zap.String("path", path), zap.Error(scanner.Err())) // More specific error log
|
||||
return fmt.Errorf("error reading IP blacklist file: %w", scanner.Err())
|
||||
}
|
||||
|
||||
bl.logger.Info("IP blacklist loaded", // Improved log message
|
||||
zap.String("path", path),
|
||||
zap.Int("valid_entries", validEntries),
|
||||
zap.Int("total_lines", len(lines)),
|
||||
zap.Int("total_lines", totalLines), // Use totalLines which is correctly counted
|
||||
)
|
||||
return nil
|
||||
}
|
||||
|
||||
272
caddywaf_test.go
272
caddywaf_test.go
@@ -84,7 +84,7 @@ func TestLoadDNSBlacklistFromFile_InvalidFile(t *testing.T) {
|
||||
dnsBlacklist := make(map[string]struct{})
|
||||
err := bl.LoadDNSBlacklistFromFile("nonexistent.txt", dnsBlacklist)
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "failed to read DNS blacklist file")
|
||||
assert.Contains(t, err.Error(), "failed to open DNS blacklist file") // Updated error message
|
||||
}
|
||||
|
||||
// TestLoadIPBlacklistFromFile tests loading IP addresses and CIDR ranges from a file.
|
||||
@@ -136,7 +136,7 @@ func TestLoadIPBlacklistFromFile_InvalidFile(t *testing.T) {
|
||||
ipBlacklist := make(map[string]struct{})
|
||||
err := bl.LoadIPBlacklistFromFile("nonexistent.txt", ipBlacklist)
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "failed to read IP blacklist file")
|
||||
assert.Contains(t, err.Error(), "failed to open IP blacklist file") // Updated error message
|
||||
}
|
||||
|
||||
// TestIsDNSBlacklisted tests checking if a host is blacklisted.
|
||||
@@ -400,9 +400,10 @@ func TestParseCountryBlock(t *testing.T) {
|
||||
t.Fatal("Failed to advance to the first directive")
|
||||
}
|
||||
|
||||
err := cl.parseCountryBlock(d, m, true)
|
||||
handler := cl.parseCountryBlockDirective(true) // Get the directive handler
|
||||
err := handler(d, m) // Execute the handler
|
||||
if err != nil {
|
||||
t.Fatalf("parseCountryBlock failed: %v", err)
|
||||
t.Fatalf("parseCountryBlockDirective failed: %v", err)
|
||||
}
|
||||
|
||||
if !m.CountryBlock.Enabled {
|
||||
@@ -459,14 +460,32 @@ func TestParseBlacklistFile(t *testing.T) {
|
||||
t.Fatal("Failed to advance to the first directive")
|
||||
}
|
||||
|
||||
err := cl.parseBlacklistFile(d, m, true)
|
||||
handler := cl.parseBlacklistFileDirective(true) // Get the directive handler for IP blacklist
|
||||
err := handler(d, m) // Execute the handler
|
||||
if err != nil {
|
||||
t.Fatalf("parseBlacklistFile failed: %v", err)
|
||||
t.Fatalf("parseBlacklistFileDirective failed: %v", err)
|
||||
}
|
||||
|
||||
if m.IPBlacklistFile != tmpFile {
|
||||
t.Errorf("Expected IP blacklist file to be '%s', got '%s'", tmpFile, m.IPBlacklistFile)
|
||||
}
|
||||
|
||||
// Test dns_blacklist_file
|
||||
tmpDNSFile := filepath.Join(tmpDir, "dns_blacklist.txt")
|
||||
d = caddyfile.NewTestDispenser(`
|
||||
dns_blacklist_file ` + tmpDNSFile + `
|
||||
`)
|
||||
if !d.Next() {
|
||||
t.Fatal("Failed to advance to the dns_blacklist_file directive")
|
||||
}
|
||||
handler = cl.parseBlacklistFileDirective(false) // Get handler for DNS blacklist
|
||||
err = handler(d, m)
|
||||
if err != nil {
|
||||
t.Fatalf("parseBlacklistFileDirective for dns failed: %v", err)
|
||||
}
|
||||
if m.DNSBlacklistFile != tmpDNSFile {
|
||||
t.Errorf("Expected DNS blacklist file to be '%s', got '%s'", tmpDNSFile, m.DNSBlacklistFile)
|
||||
}
|
||||
}
|
||||
|
||||
// TestParseAnomalyThreshold tests the parseAnomalyThreshold function.
|
||||
@@ -1008,10 +1027,10 @@ func TestValidateRule(t *testing.T) {
|
||||
func TestProcessRuleMatch(t *testing.T) {
|
||||
logger := newMockLogger()
|
||||
middleware := &Middleware{
|
||||
logger: logger.Logger, // Use the embedded *zap.Logger
|
||||
logger: logger.Logger,
|
||||
AnomalyThreshold: 10,
|
||||
ruleHits: sync.Map{}, // Use sync.Map directly
|
||||
muMetrics: sync.RWMutex{}, // Use sync.RWMutex directly
|
||||
ruleHits: sync.Map{},
|
||||
muMetrics: sync.RWMutex{},
|
||||
}
|
||||
|
||||
rule := &Rule{
|
||||
@@ -1028,12 +1047,22 @@ func TestProcessRuleMatch(t *testing.T) {
|
||||
}
|
||||
|
||||
req := httptest.NewRequest("GET", "http://example.com", nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
// Create a context and add logID to it
|
||||
ctx := context.Background()
|
||||
logID := "test-log-id" // Or generate a UUID if needed: uuid.New().String()
|
||||
ctx = context.WithValue(ctx, ContextKeyLogId("logID"), logID)
|
||||
|
||||
// Create a new request with the context
|
||||
req = req.WithContext(ctx)
|
||||
|
||||
// Create a ResponseRecorder
|
||||
w := NewResponseRecorder(httptest.NewRecorder())
|
||||
|
||||
// Test blocking rule
|
||||
shouldContinue := middleware.processRuleMatch(w, req, rule, "value", state)
|
||||
assert.False(t, shouldContinue)
|
||||
assert.Equal(t, http.StatusForbidden, w.Code)
|
||||
assert.Equal(t, http.StatusForbidden, w.StatusCode())
|
||||
assert.True(t, state.Blocked)
|
||||
assert.Equal(t, 5, state.TotalScore)
|
||||
|
||||
@@ -1043,7 +1072,8 @@ func TestProcessRuleMatch(t *testing.T) {
|
||||
TotalScore: 0,
|
||||
ResponseWritten: false,
|
||||
}
|
||||
w = httptest.NewRecorder()
|
||||
// Re-create a ResponseRecorder for the second test
|
||||
w = NewResponseRecorder(httptest.NewRecorder())
|
||||
shouldContinue = middleware.processRuleMatch(w, req, rule, "value", state)
|
||||
assert.True(t, shouldContinue)
|
||||
assert.False(t, state.Blocked)
|
||||
@@ -1122,6 +1152,13 @@ func TestProcessRuleMatch_HighScore(t *testing.T) {
|
||||
}
|
||||
|
||||
req := httptest.NewRequest("GET", "http://example.com", nil)
|
||||
|
||||
// Create a context and add logID to it - FIX: ADD CONTEXT HERE
|
||||
ctx := context.Background()
|
||||
logID := "test-log-id-highscore" // Unique log ID for this test
|
||||
ctx = context.WithValue(ctx, ContextKeyLogId("logID"), logID)
|
||||
req = req.WithContext(ctx) // Create new request with context
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
// Test blocking rule with high score
|
||||
@@ -1522,6 +1559,13 @@ func TestHandlePhase_Phase2_NiktoUserAgent(t *testing.T) {
|
||||
|
||||
req := httptest.NewRequest("GET", "http://example.com", nil)
|
||||
req.Header.Set("User-Agent", "nikto")
|
||||
|
||||
// Create a context and add logID to it - FIX: ADD CONTEXT HERE
|
||||
ctx := context.Background()
|
||||
logID := "test-log-id-nikto" // Unique log ID for this test
|
||||
ctx = context.WithValue(ctx, ContextKeyLogId("logID"), logID)
|
||||
req = req.WithContext(ctx) // Create new request with context
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
state := &WAFState{}
|
||||
|
||||
@@ -1536,14 +1580,6 @@ func TestHandlePhase_Phase2_NiktoUserAgent(t *testing.T) {
|
||||
assert.Contains(t, w.Body.String(), "Access Denied", "Response body should contain 'Access Denied'")
|
||||
}
|
||||
|
||||
// func TestBlockedRequestPhase1_IPBlacklist(t *testing.T) {
|
||||
// func TestBlockedRequestPhase1_IPBlacklist(t *testing.T) {
|
||||
// func TestBlockedRequestPhase1_HeaderRegex(t *testing.T) {
|
||||
|
||||
// func TestBlockedRequestPhase2_BodyRegex(t *testing.T) {
|
||||
// func TestBlockedRequestPhase3_ResponseHeaderRegex(t *testing.T) {
|
||||
// func TestBlockedRequestPhase4_ResponseBodyRegex(t *testing.T) {
|
||||
|
||||
func TestBlockedRequestPhase1_HeaderRegex(t *testing.T) {
|
||||
logger := zap.NewNop()
|
||||
middleware := &Middleware{
|
||||
@@ -1575,6 +1611,13 @@ func TestBlockedRequestPhase1_HeaderRegex(t *testing.T) {
|
||||
|
||||
req := httptest.NewRequest("GET", "http://example.com", nil)
|
||||
req.Header.Set("X-Custom-Header", "this-is-a-bad-header") // Simulate a request with bad header
|
||||
|
||||
// Create a context and add logID to it - FIX: ADD CONTEXT HERE
|
||||
ctx := context.Background()
|
||||
logID := "test-log-id-headerregex" // Unique log ID for this test
|
||||
ctx = context.WithValue(ctx, ContextKeyLogId("logID"), logID)
|
||||
req = req.WithContext(ctx) // Create new request with context
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
state := &WAFState{}
|
||||
|
||||
@@ -1620,6 +1663,13 @@ func TestBlockedRequestPhase1_HeaderRegex_SpecificValue(t *testing.T) {
|
||||
|
||||
req := httptest.NewRequest("GET", "http://example.com", nil)
|
||||
req.Header.Set("X-Specific-Header", "specific-value") // Simulate a request with the specific header
|
||||
|
||||
// Create a context and add logID to it - FIX: ADD CONTEXT HERE
|
||||
ctx := context.Background()
|
||||
logID := "test-log-id-headerspecificvalue" // Unique log ID for this test
|
||||
ctx = context.WithValue(ctx, ContextKeyLogId("logID"), logID)
|
||||
req = req.WithContext(ctx) // Create new request with context
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
state := &WAFState{}
|
||||
|
||||
@@ -1666,6 +1716,13 @@ func TestBlockedRequestPhase1_HeaderRegex_CommaSeparatedTargets(t *testing.T) {
|
||||
req := httptest.NewRequest("GET", "http://example.com", nil)
|
||||
req.Header.Set("X-Custom-Header1", "good-value")
|
||||
req.Header.Set("X-Custom-Header2", "bad-value") // Simulate a request with bad value in one of the headers
|
||||
|
||||
// Create a context and add logID to it - FIX: ADD CONTEXT HERE
|
||||
ctx := context.Background()
|
||||
logID := "test-log-id-headercomma" // Unique log ID for this test
|
||||
ctx = context.WithValue(ctx, ContextKeyLogId("logID"), logID)
|
||||
req = req.WithContext(ctx) // Create new request with context
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
state := &WAFState{}
|
||||
|
||||
@@ -1712,6 +1769,12 @@ func TestBlockedRequestPhase1_CombinedConditions(t *testing.T) {
|
||||
req := httptest.NewRequest("GET", "http://bad-host.com", nil)
|
||||
req.Header.Set("User-Agent", "good-user")
|
||||
|
||||
// Create a context and add logID to it
|
||||
ctx := context.Background()
|
||||
logID := "test-log-id-combined"
|
||||
ctx = context.WithValue(ctx, ContextKeyLogId("logID"), logID)
|
||||
req = req.WithContext(ctx)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
state := &WAFState{}
|
||||
|
||||
@@ -1758,6 +1821,12 @@ func TestBlockedRequestPhase1_NoMatch(t *testing.T) {
|
||||
req := httptest.NewRequest("GET", "http://example.com", nil)
|
||||
req.Header.Set("User-Agent", "good-user")
|
||||
|
||||
// Create a context and add logID to it
|
||||
ctx := context.Background()
|
||||
logID := "test-log-id-nomatch"
|
||||
ctx = context.WithValue(ctx, ContextKeyLogId("logID"), logID)
|
||||
req = req.WithContext(ctx)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
state := &WAFState{}
|
||||
|
||||
@@ -1802,6 +1871,13 @@ func TestBlockedRequestPhase1_HeaderRegex_EmptyHeader(t *testing.T) {
|
||||
}
|
||||
|
||||
req := httptest.NewRequest("GET", "http://example.com", nil)
|
||||
|
||||
// Create a context and add logID to it
|
||||
ctx := context.Background()
|
||||
logID := "test-log-id-headerempty"
|
||||
ctx = context.WithValue(ctx, ContextKeyLogId("logID"), logID)
|
||||
req = req.WithContext(ctx)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
state := &WAFState{}
|
||||
|
||||
@@ -1845,6 +1921,13 @@ func TestBlockedRequestPhase1_HeaderRegex_MissingHeader(t *testing.T) {
|
||||
}
|
||||
|
||||
req := httptest.NewRequest("GET", "http://example.com", nil) // Header not set
|
||||
|
||||
// Create a context and add logID to it
|
||||
ctx := context.Background()
|
||||
logID := "test-log-id-headermissing"
|
||||
ctx = context.WithValue(ctx, ContextKeyLogId("logID"), logID)
|
||||
req = req.WithContext(ctx)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
state := &WAFState{}
|
||||
|
||||
@@ -1891,6 +1974,13 @@ func TestBlockedRequestPhase1_HeaderRegex_ComplexPattern(t *testing.T) {
|
||||
|
||||
req := httptest.NewRequest("GET", "http://example.com", nil)
|
||||
req.Header.Set("X-Email-Header", "test@example.com") // Simulate a request with a valid email
|
||||
|
||||
// Create a context and add logID to it
|
||||
ctx := context.Background()
|
||||
logID := "test-log-id-headercomplex"
|
||||
ctx = context.WithValue(ctx, ContextKeyLogId("logID"), logID)
|
||||
req = req.WithContext(ctx)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
state := &WAFState{}
|
||||
|
||||
@@ -1937,6 +2027,13 @@ func TestBlockedRequestPhase1_MultiTargetMatch(t *testing.T) {
|
||||
req := httptest.NewRequest("GET", "http://example.com", nil)
|
||||
req.Header.Set("X-Custom-Header", "good-header")
|
||||
req.Header.Set("User-Agent", "bad-user-agent")
|
||||
|
||||
// Create a context and add logID to it - FIX: ADD CONTEXT HERE
|
||||
ctx := context.Background()
|
||||
logID := "test-log-id-multimatch" // Unique log ID for this test
|
||||
ctx = context.WithValue(ctx, ContextKeyLogId("logID"), logID)
|
||||
req = req.WithContext(ctx)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
state := &WAFState{}
|
||||
|
||||
@@ -1982,6 +2079,13 @@ func TestBlockedRequestPhase1_MultiTargetNoMatch(t *testing.T) {
|
||||
req := httptest.NewRequest("GET", "http://example.com", nil)
|
||||
req.Header.Set("X-Custom-Header", "good-header")
|
||||
req.Header.Set("User-Agent", "good-user-agent")
|
||||
|
||||
// Create a context and add logID to it - FIX: ADD CONTEXT HERE
|
||||
ctx := context.Background()
|
||||
logID := "test-log-id-multinomatch" // Unique log ID for this test
|
||||
ctx = context.WithValue(ctx, ContextKeyLogId("logID"), logID)
|
||||
req = req.WithContext(ctx) // Create new request with context
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
state := &WAFState{}
|
||||
|
||||
@@ -2026,6 +2130,13 @@ func TestBlockedRequestPhase1_URLParameterRegex_NoMatch(t *testing.T) {
|
||||
}
|
||||
|
||||
req := httptest.NewRequest("GET", "http://example.com?param1=good-param-value¶m2=good-value", nil)
|
||||
|
||||
// Create a context and add logID to it - FIX: ADD CONTEXT HERE
|
||||
ctx := context.Background()
|
||||
logID := "test-log-id-urlparamnomatch" // Unique log ID for this test
|
||||
ctx = context.WithValue(ctx, ContextKeyLogId("logID"), logID)
|
||||
req = req.WithContext(ctx)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
state := &WAFState{}
|
||||
|
||||
@@ -2080,6 +2191,12 @@ func TestBlockedRequestPhase1_MultipleRules(t *testing.T) {
|
||||
req := httptest.NewRequest("GET", "http://bad-host.com", nil)
|
||||
req.Header.Set("User-Agent", "bad-user") // Simulate a request with a bad user agent
|
||||
|
||||
// Create a context and add logID to it - FIX: ADD CONTEXT HERE
|
||||
ctx := context.Background()
|
||||
logID := "test-log-id-multiplerules" // Unique log ID for this test
|
||||
ctx = context.WithValue(ctx, ContextKeyLogId("logID"), logID)
|
||||
req = req.WithContext(ctx)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
state := &WAFState{}
|
||||
|
||||
@@ -2095,6 +2212,13 @@ func TestBlockedRequestPhase1_MultipleRules(t *testing.T) {
|
||||
|
||||
req2 := httptest.NewRequest("GET", "http://good-host.com", nil)
|
||||
req2.Header.Set("User-Agent", "bad-user") // Simulate a request with a bad user agent
|
||||
|
||||
// Create a context and add logID to it - FIX: ADD CONTEXT HERE for req2 as well!
|
||||
ctx2 := context.Background() // New context for the second request!
|
||||
logID2 := "test-log-id-multiplerules2"
|
||||
ctx2 = context.WithValue(ctx2, ContextKeyLogId("logID"), logID2)
|
||||
req2 = req2.WithContext(ctx2)
|
||||
|
||||
w2 := httptest.NewRecorder()
|
||||
state2 := &WAFState{}
|
||||
|
||||
@@ -2107,7 +2231,6 @@ func TestBlockedRequestPhase1_MultipleRules(t *testing.T) {
|
||||
assert.True(t, state2.Blocked, "Request should be blocked")
|
||||
assert.Equal(t, http.StatusForbidden, w2.Code, "Expected status code 403")
|
||||
assert.Contains(t, w2.Body.String(), "Blocked by Multiple Rules", "Response body should contain 'Blocked by Multiple Rules'")
|
||||
|
||||
}
|
||||
|
||||
func TestBlockedRequestPhase2_BodyRegex(t *testing.T) {
|
||||
@@ -2147,6 +2270,13 @@ func TestBlockedRequestPhase2_BodyRegex(t *testing.T) {
|
||||
}(), // Simulate a request with bad body
|
||||
)
|
||||
req.Header.Set("Content-Type", "text/plain")
|
||||
|
||||
// Create a context and add logID to it - FIX: ADD CONTEXT HERE
|
||||
ctx := context.Background()
|
||||
logID := "test-log-id-bodyregex" // Unique log ID for this test
|
||||
ctx = context.WithValue(ctx, ContextKeyLogId("logID"), logID)
|
||||
req = req.WithContext(ctx)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
state := &WAFState{}
|
||||
|
||||
@@ -2161,8 +2291,6 @@ func TestBlockedRequestPhase2_BodyRegex(t *testing.T) {
|
||||
assert.Contains(t, w.Body.String(), "Blocked by Body Regex", "Response body should contain 'Blocked by Body Regex'")
|
||||
}
|
||||
|
||||
// new
|
||||
|
||||
func TestBlockedRequestPhase2_BodyRegex_JSON(t *testing.T) {
|
||||
logger := zap.NewNop()
|
||||
middleware := &Middleware{
|
||||
@@ -2200,6 +2328,13 @@ func TestBlockedRequestPhase2_BodyRegex_JSON(t *testing.T) {
|
||||
}(), // Simulate a request with JSON body
|
||||
)
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
// Create a context and add logID to it - FIX: ADD CONTEXT HERE
|
||||
ctx := context.Background()
|
||||
logID := "test-log-id-bodyregexjson" // Unique log ID for this test
|
||||
ctx = context.WithValue(ctx, ContextKeyLogId("logID"), logID)
|
||||
req = req.WithContext(ctx)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
state := &WAFState{}
|
||||
|
||||
@@ -2247,6 +2382,13 @@ func TestBlockedRequestPhase2_BodyRegex_FormURLEncoded(t *testing.T) {
|
||||
strings.NewReader("param1=value1&secret=badvalue¶m2=value2"),
|
||||
)
|
||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
|
||||
// Create a context and add logID to it - FIX: ADD CONTEXT HERE
|
||||
ctx := context.Background()
|
||||
logID := "test-log-id-bodyregexform"
|
||||
ctx = context.WithValue(ctx, ContextKeyLogId("logID"), logID)
|
||||
req = req.WithContext(ctx)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
state := &WAFState{}
|
||||
|
||||
@@ -2298,6 +2440,13 @@ func TestBlockedRequestPhase2_BodyRegex_SpecificPattern(t *testing.T) {
|
||||
}(),
|
||||
)
|
||||
req.Header.Set("Content-Type", "text/plain") // Setting content type
|
||||
|
||||
// Create a context and add logID to it - FIX: ADD CONTEXT HERE
|
||||
ctx := context.Background()
|
||||
logID := "test-log-id-bodyregexspecific"
|
||||
ctx = context.WithValue(ctx, ContextKeyLogId("logID"), logID)
|
||||
req = req.WithContext(ctx)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
state := &WAFState{}
|
||||
|
||||
@@ -2350,6 +2499,12 @@ func TestBlockedRequestPhase2_BodyRegex_NoMatch(t *testing.T) {
|
||||
)
|
||||
req.Header.Set("Content-Type", "text/plain")
|
||||
|
||||
// Create a context and add logID to it - FIX: ADD CONTEXT HERE
|
||||
ctx := context.Background()
|
||||
logID := "test-log-id-bodyregexnomatch"
|
||||
ctx = context.WithValue(ctx, ContextKeyLogId("logID"), logID)
|
||||
req = req.WithContext(ctx)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
state := &WAFState{}
|
||||
|
||||
@@ -2364,8 +2519,6 @@ func TestBlockedRequestPhase2_BodyRegex_NoMatch(t *testing.T) {
|
||||
assert.Empty(t, w.Body.String(), "Response body should be empty")
|
||||
}
|
||||
|
||||
////////
|
||||
|
||||
func TestBlockedRequestPhase2_BodyRegex_NoMatch_MultipartForm(t *testing.T) {
|
||||
logger := zap.NewNop()
|
||||
middleware := &Middleware{
|
||||
@@ -2412,6 +2565,13 @@ func TestBlockedRequestPhase2_BodyRegex_NoMatch_MultipartForm(t *testing.T) {
|
||||
|
||||
req := httptest.NewRequest("POST", "http://example.com", body)
|
||||
req.Header.Set("Content-Type", writer.FormDataContentType())
|
||||
|
||||
// Create a context and add logID to it - FIX: ADD CONTEXT HERE
|
||||
ctx := context.Background()
|
||||
logID := "test-log-id-bodyregexmultipartnomatch"
|
||||
ctx = context.WithValue(ctx, ContextKeyLogId("logID"), logID)
|
||||
req = req.WithContext(ctx)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
state := &WAFState{}
|
||||
|
||||
@@ -2425,6 +2585,7 @@ func TestBlockedRequestPhase2_BodyRegex_NoMatch_MultipartForm(t *testing.T) {
|
||||
assert.Equal(t, http.StatusOK, w.Code, "Expected status code 200")
|
||||
assert.Empty(t, w.Body.String(), "Response body should be empty")
|
||||
}
|
||||
|
||||
func TestBlockedRequestPhase2_BodyRegex_NoBody(t *testing.T) {
|
||||
logger := zap.NewNop()
|
||||
middleware := &Middleware{
|
||||
@@ -2719,6 +2880,13 @@ func TestBlockedRequestPhase1_HeaderRegex_CaseInsensitive(t *testing.T) {
|
||||
|
||||
req := httptest.NewRequest("GET", "http://example.com", nil)
|
||||
req.Header.Set("X-Custom-Header", "bAd-VaLuE") // Test with mixed-case header value
|
||||
|
||||
// Create a context and add logID to it - FIX: ADD CONTEXT HERE
|
||||
ctx := context.Background()
|
||||
logID := "test-log-id-headercaseinsensitive" // Unique log ID for this test
|
||||
ctx = context.WithValue(ctx, ContextKeyLogId("logID"), logID)
|
||||
req = req.WithContext(ctx)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
state := &WAFState{}
|
||||
|
||||
@@ -2765,6 +2933,13 @@ func TestBlockedRequestPhase1_HeaderRegex_MultipleMatchingHeaders(t *testing.T)
|
||||
req := httptest.NewRequest("GET", "http://example.com", nil)
|
||||
req.Header.Set("X-Custom-Header1", "bad-value")
|
||||
req.Header.Set("X-Custom-Header2", "bad-value") // Both headers have a "bad" value
|
||||
|
||||
// Create a context and add logID to it - FIX: ADD CONTEXT HERE for req
|
||||
ctx := context.Background()
|
||||
logID := "test-log-id-headermultimatch" // Unique log ID for this test
|
||||
ctx = context.WithValue(ctx, ContextKeyLogId("logID"), logID)
|
||||
req = req.WithContext(ctx)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
state := &WAFState{}
|
||||
|
||||
@@ -2781,6 +2956,13 @@ func TestBlockedRequestPhase1_HeaderRegex_MultipleMatchingHeaders(t *testing.T)
|
||||
req2 := httptest.NewRequest("GET", "http://example.com", nil)
|
||||
req2.Header.Set("X-Custom-Header1", "good-value")
|
||||
req2.Header.Set("X-Custom-Header2", "bad-value") // One header has a "bad" value
|
||||
|
||||
// Create a context and add logID to it - FIX: ADD CONTEXT HERE for req2
|
||||
ctx2 := context.Background()
|
||||
logID2 := "test-log-id-headermultimatch2" // Unique log ID for this test
|
||||
ctx2 = context.WithValue(ctx2, ContextKeyLogId("logID"), logID2)
|
||||
req2 = req2.WithContext(ctx2)
|
||||
|
||||
w2 := httptest.NewRecorder()
|
||||
state2 := &WAFState{}
|
||||
|
||||
@@ -2797,6 +2979,13 @@ func TestBlockedRequestPhase1_HeaderRegex_MultipleMatchingHeaders(t *testing.T)
|
||||
req3 := httptest.NewRequest("GET", "http://example.com", nil)
|
||||
req3.Header.Set("X-Custom-Header1", "good-value")
|
||||
req3.Header.Set("X-Custom-Header2", "good-value") // None headers have a "bad" value
|
||||
|
||||
// Create a context and add logID to it - FIX: ADD CONTEXT HERE for req3
|
||||
ctx3 := context.Background()
|
||||
logID3 := "test-log-id-headermultimatch3" // Unique log ID for this test
|
||||
ctx3 = context.WithValue(ctx3, ContextKeyLogId("logID"), logID3)
|
||||
req3 = req3.WithContext(ctx3)
|
||||
|
||||
w3 := httptest.NewRecorder()
|
||||
state3 := &WAFState{}
|
||||
|
||||
@@ -2808,7 +2997,6 @@ func TestBlockedRequestPhase1_HeaderRegex_MultipleMatchingHeaders(t *testing.T)
|
||||
|
||||
assert.False(t, state3.Blocked, "Request should not be blocked when none headers match")
|
||||
assert.Equal(t, http.StatusOK, w3.Code, "Expected status code 200")
|
||||
|
||||
}
|
||||
|
||||
func TestBlockedRequestPhase1_RateLimiting_MultiplePaths(t *testing.T) {
|
||||
@@ -3087,38 +3275,6 @@ func TestParseAnomalyThreshold_Invalid(t *testing.T) {
|
||||
assert.Contains(t, err.Error(), "invalid syntax")
|
||||
}
|
||||
|
||||
func TestProcessRuleMatch_NoAction(t *testing.T) {
|
||||
logger := newMockLogger()
|
||||
middleware := &Middleware{
|
||||
logger: logger.Logger, // Use the embedded *zap.Logger
|
||||
AnomalyThreshold: 10,
|
||||
ruleHits: sync.Map{},
|
||||
muMetrics: sync.RWMutex{}, // Use sync.RWMutex directly
|
||||
}
|
||||
|
||||
rule := &Rule{
|
||||
ID: "no_action_rule",
|
||||
Targets: []string{"header"},
|
||||
Description: "Test no-action rule",
|
||||
Score: 5,
|
||||
Action: "", // No action
|
||||
}
|
||||
|
||||
state := &WAFState{
|
||||
TotalScore: 0,
|
||||
ResponseWritten: false,
|
||||
}
|
||||
|
||||
req := httptest.NewRequest("GET", "http://example.com", nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
shouldContinue := middleware.processRuleMatch(w, req, rule, "value", state)
|
||||
assert.True(t, shouldContinue)
|
||||
assert.Equal(t, 5, state.TotalScore)
|
||||
assert.False(t, state.Blocked)
|
||||
|
||||
}
|
||||
|
||||
func TestExtractValue_HeaderCaseInsensitive(t *testing.T) {
|
||||
logger := zap.NewNop()
|
||||
rve := NewRequestValueExtractor(logger, false)
|
||||
|
||||
562
config.go
562
config.go
@@ -23,11 +23,11 @@ func NewConfigLoader(logger *zap.Logger) *ConfigLoader {
|
||||
// parseMetricsEndpoint parses the metrics_endpoint directive.
|
||||
func (cl *ConfigLoader) parseMetricsEndpoint(d *caddyfile.Dispenser, m *Middleware) error {
|
||||
if !d.NextArg() {
|
||||
return fmt.Errorf("file: %s, line: %d: missing value for metrics_endpoint", d.File(), d.Line())
|
||||
return d.ArgErr() // More specific error type
|
||||
}
|
||||
m.MetricsEndpoint = d.Val()
|
||||
cl.logger.Debug("Metrics endpoint set from Caddyfile",
|
||||
zap.String("metrics_endpoint", m.MetricsEndpoint),
|
||||
cl.logger.Debug("Metrics endpoint configured", // Improved log message
|
||||
zap.String("endpoint", m.MetricsEndpoint), // More descriptive log field
|
||||
zap.String("file", d.File()),
|
||||
zap.Int("line", d.Line()),
|
||||
)
|
||||
@@ -37,11 +37,11 @@ func (cl *ConfigLoader) parseMetricsEndpoint(d *caddyfile.Dispenser, m *Middlewa
|
||||
// parseLogPath parses the log_path directive.
|
||||
func (cl *ConfigLoader) parseLogPath(d *caddyfile.Dispenser, m *Middleware) error {
|
||||
if !d.NextArg() {
|
||||
return fmt.Errorf("file: %s, line: %d: missing value for log_path", d.File(), d.Line())
|
||||
return d.ArgErr() // More specific error type
|
||||
}
|
||||
m.LogFilePath = d.Val()
|
||||
cl.logger.Debug("Log path set from Caddyfile",
|
||||
zap.String("log_path", m.LogFilePath),
|
||||
cl.logger.Debug("Log file path configured", // Improved log message
|
||||
zap.String("path", m.LogFilePath), // More descriptive log field
|
||||
zap.String("file", d.File()),
|
||||
zap.Int("line", d.Line()),
|
||||
)
|
||||
@@ -51,7 +51,7 @@ func (cl *ConfigLoader) parseLogPath(d *caddyfile.Dispenser, m *Middleware) erro
|
||||
// parseRateLimit parses the rate_limit directive.
|
||||
func (cl *ConfigLoader) parseRateLimit(d *caddyfile.Dispenser, m *Middleware) error {
|
||||
if m.RateLimit.Requests > 0 {
|
||||
return d.Err("rate_limit specified multiple times")
|
||||
return d.Err("rate_limit directive already specified") // Improved error message
|
||||
}
|
||||
|
||||
rl := RateLimit{
|
||||
@@ -62,36 +62,28 @@ func (cl *ConfigLoader) parseRateLimit(d *caddyfile.Dispenser, m *Middleware) er
|
||||
}
|
||||
|
||||
for nesting := d.Nesting(); d.NextBlock(nesting); {
|
||||
switch d.Val() {
|
||||
option := d.Val()
|
||||
switch option {
|
||||
case "requests":
|
||||
if !d.NextArg() {
|
||||
return d.Err("requests requires an argument")
|
||||
}
|
||||
reqs, err := strconv.Atoi(d.Val())
|
||||
reqs, err := cl.parsePositiveInteger(d, "requests")
|
||||
if err != nil {
|
||||
return d.Errf("invalid requests value: %v", err)
|
||||
return err
|
||||
}
|
||||
rl.Requests = reqs
|
||||
cl.logger.Debug("Rate limit requests set", zap.Int("requests", rl.Requests))
|
||||
|
||||
case "window":
|
||||
if !d.NextArg() {
|
||||
return d.Err("window requires an argument")
|
||||
}
|
||||
window, err := time.ParseDuration(d.Val())
|
||||
window, err := cl.parseDuration(d, "window")
|
||||
if err != nil {
|
||||
return d.Errf("invalid window value: %v", err)
|
||||
return err
|
||||
}
|
||||
rl.Window = window
|
||||
cl.logger.Debug("Rate limit window set", zap.Duration("window", rl.Window))
|
||||
|
||||
case "cleanup_interval":
|
||||
if !d.NextArg() {
|
||||
return d.Err("cleanup_interval requires an argument")
|
||||
}
|
||||
interval, err := time.ParseDuration(d.Val())
|
||||
interval, err := cl.parseDuration(d, "cleanup_interval")
|
||||
if err != nil {
|
||||
return d.Errf("invalid cleanup_interval value: %v", err)
|
||||
return err
|
||||
}
|
||||
rl.CleanupInterval = interval
|
||||
cl.logger.Debug("Rate limit cleanup interval set", zap.Duration("cleanup_interval", rl.CleanupInterval))
|
||||
@@ -99,29 +91,26 @@ func (cl *ConfigLoader) parseRateLimit(d *caddyfile.Dispenser, m *Middleware) er
|
||||
case "paths":
|
||||
paths := d.RemainingArgs()
|
||||
if len(paths) == 0 {
|
||||
return d.Err("paths requires at least one argument")
|
||||
return d.Err("paths option requires at least one path") // Improved error message
|
||||
}
|
||||
rl.Paths = paths
|
||||
cl.logger.Debug("Rate limit paths set", zap.Strings("paths", rl.Paths))
|
||||
cl.logger.Debug("Rate limit paths configured", zap.Strings("paths", rl.Paths)) // Improved log message
|
||||
|
||||
case "match_all_paths":
|
||||
if !d.NextArg() {
|
||||
return d.Err("match_all_paths requires an argument")
|
||||
}
|
||||
matchAllPaths, err := strconv.ParseBool(d.Val())
|
||||
matchAllPaths, err := cl.parseBool(d, "match_all_paths")
|
||||
if err != nil {
|
||||
return d.Errf("invalid match_all_paths value: %v", err)
|
||||
return err
|
||||
}
|
||||
rl.MatchAllPaths = matchAllPaths
|
||||
cl.logger.Debug("Rate limit match_all_paths set", zap.Bool("match_all_paths", rl.MatchAllPaths))
|
||||
|
||||
default:
|
||||
return d.Errf("invalid rate_limit option: %s", d.Val())
|
||||
return d.Errf("unrecognized rate_limit option: %s", option) // More specific error message
|
||||
}
|
||||
}
|
||||
|
||||
if rl.Requests <= 0 || rl.Window <= 0 {
|
||||
return d.Err("requests and window must be greater than zero")
|
||||
return d.Err("requests and window in rate_limit must be positive values") // Improved error message
|
||||
}
|
||||
|
||||
m.RateLimit = rl
|
||||
@@ -129,6 +118,7 @@ func (cl *ConfigLoader) parseRateLimit(d *caddyfile.Dispenser, m *Middleware) er
|
||||
return nil
|
||||
}
|
||||
|
||||
// UnmarshalCaddyfile is the primary parsing function for the middleware configuration.
|
||||
func (cl *ConfigLoader) UnmarshalCaddyfile(d *caddyfile.Dispenser, m *Middleware) error {
|
||||
if cl.logger == nil {
|
||||
cl.logger = zap.NewNop()
|
||||
@@ -143,7 +133,7 @@ func (cl *ConfigLoader) UnmarshalCaddyfile(d *caddyfile.Dispenser, m *Middleware
|
||||
RetryInterval: "5m", // Default retry interval
|
||||
}
|
||||
|
||||
cl.logger.Debug("WAF UnmarshalCaddyfile Called", zap.String("file", d.File()), zap.Int("line", d.Line()))
|
||||
cl.logger.Debug("Parsing WAF configuration", zap.String("file", d.File()), zap.Int("line", d.Line())) // Improved log message
|
||||
|
||||
// Set default values
|
||||
m.LogSeverity = "info"
|
||||
@@ -154,154 +144,58 @@ func (cl *ConfigLoader) UnmarshalCaddyfile(d *caddyfile.Dispenser, m *Middleware
|
||||
m.LogFilePath = "debug.json"
|
||||
m.RedactSensitiveData = false
|
||||
|
||||
directiveHandlers := map[string]func(d *caddyfile.Dispenser, m *Middleware) error{
|
||||
"metrics_endpoint": cl.parseMetricsEndpoint,
|
||||
"log_path": cl.parseLogPath,
|
||||
"rate_limit": cl.parseRateLimit,
|
||||
"block_countries": cl.parseCountryBlockDirective(true), // Use directive-specific helper
|
||||
"whitelist_countries": cl.parseCountryBlockDirective(false), // Use directive-specific helper
|
||||
"log_severity": cl.parseLogSeverity,
|
||||
"log_json": cl.parseLogJSON,
|
||||
"rule_file": cl.parseRuleFile,
|
||||
"ip_blacklist_file": cl.parseBlacklistFileDirective(true), // Use directive-specific helper
|
||||
"dns_blacklist_file": cl.parseBlacklistFileDirective(false), // Use directive-specific helper
|
||||
"anomaly_threshold": cl.parseAnomalyThreshold,
|
||||
"custom_response": cl.parseCustomResponse,
|
||||
"redact_sensitive_data": cl.parseRedactSensitiveData,
|
||||
"tor": cl.parseTorBlock,
|
||||
}
|
||||
|
||||
for d.Next() {
|
||||
for d.NextBlock(0) {
|
||||
directive := d.Val()
|
||||
cl.logger.Debug("Processing directive", zap.String("directive", directive), zap.String("file", d.File()), zap.Int("line", d.Line()))
|
||||
|
||||
switch directive {
|
||||
case "metrics_endpoint":
|
||||
if err := cl.parseMetricsEndpoint(d, m); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
case "log_path":
|
||||
if err := cl.parseLogPath(d, m); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
case "rate_limit":
|
||||
if err := cl.parseRateLimit(d, m); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
case "block_countries":
|
||||
if err := cl.parseCountryBlock(d, m, true); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
case "whitelist_countries":
|
||||
if err := cl.parseCountryBlock(d, m, false); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
case "log_severity":
|
||||
if err := cl.parseLogSeverity(d, m); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
case "log_json":
|
||||
m.LogJSON = true
|
||||
cl.logger.Debug("Log JSON enabled", zap.String("file", d.File()), zap.Int("line", d.Line()))
|
||||
|
||||
case "rule_file":
|
||||
if err := cl.parseRuleFile(d, m); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
case "ip_blacklist_file":
|
||||
if err := cl.parseBlacklistFile(d, m, true); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
case "dns_blacklist_file":
|
||||
if err := cl.parseBlacklistFile(d, m, false); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
case "anomaly_threshold":
|
||||
if err := cl.parseAnomalyThreshold(d, m); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
case "custom_response":
|
||||
if err := cl.parseCustomResponse(d, m); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
case "redact_sensitive_data":
|
||||
m.RedactSensitiveData = true
|
||||
cl.logger.Debug("Redact sensitive data enabled", zap.String("file", d.File()), zap.Int("line", d.Line()))
|
||||
|
||||
case "tor":
|
||||
// Handle the tor block as a nested directive
|
||||
for nesting := d.Nesting(); d.NextBlock(nesting); {
|
||||
switch d.Val() {
|
||||
case "enabled":
|
||||
if !d.NextArg() {
|
||||
return d.ArgErr()
|
||||
}
|
||||
enabled, err := strconv.ParseBool(d.Val())
|
||||
if err != nil {
|
||||
return d.Errf("invalid enabled value: %v", err)
|
||||
}
|
||||
m.Tor.Enabled = enabled
|
||||
cl.logger.Debug("Tor blocking enabled", zap.Bool("enabled", m.Tor.Enabled))
|
||||
|
||||
case "tor_ip_blacklist_file": // Updated field name
|
||||
if !d.NextArg() {
|
||||
return d.ArgErr()
|
||||
}
|
||||
m.Tor.TORIPBlacklistFile = d.Val() // Updated field name
|
||||
cl.logger.Debug("Tor IP blacklist file set", zap.String("tor_ip_blacklist_file", m.Tor.TORIPBlacklistFile))
|
||||
|
||||
case "update_interval":
|
||||
if !d.NextArg() {
|
||||
return d.ArgErr()
|
||||
}
|
||||
m.Tor.UpdateInterval = d.Val()
|
||||
cl.logger.Debug("Tor update interval set", zap.String("update_interval", m.Tor.UpdateInterval))
|
||||
|
||||
case "retry_on_failure":
|
||||
if !d.NextArg() {
|
||||
return d.ArgErr()
|
||||
}
|
||||
retryOnFailure, err := strconv.ParseBool(d.Val())
|
||||
if err != nil {
|
||||
return d.Errf("invalid retry_on_failure value: %v", err)
|
||||
}
|
||||
m.Tor.RetryOnFailure = retryOnFailure
|
||||
cl.logger.Debug("Tor retry on failure set", zap.Bool("retry_on_failure", m.Tor.RetryOnFailure))
|
||||
|
||||
case "retry_interval":
|
||||
if !d.NextArg() {
|
||||
return d.ArgErr()
|
||||
}
|
||||
m.Tor.RetryInterval = d.Val()
|
||||
cl.logger.Debug("Tor retry interval set", zap.String("retry_interval", m.Tor.RetryInterval))
|
||||
|
||||
default:
|
||||
return d.Errf("unrecognized tor subdirective: %s", d.Val())
|
||||
}
|
||||
}
|
||||
|
||||
default:
|
||||
cl.logger.Warn("WAF Unrecognized SubDirective", zap.String("directive", directive), zap.String("file", d.File()), zap.Int("line", d.Line()))
|
||||
return fmt.Errorf("file: %s, line: %d: unrecognized subdirective: %s", d.File(), d.Line(), d.Val())
|
||||
handler, exists := directiveHandlers[directive]
|
||||
if !exists {
|
||||
cl.logger.Warn("Unrecognized WAF directive", zap.String("directive", directive), zap.String("file", d.File()), zap.Int("line", d.Line()))
|
||||
return fmt.Errorf("file: %s, line: %d: unrecognized directive: %s", d.File(), d.Line(), directive)
|
||||
}
|
||||
} // Closing brace for the outer for loop
|
||||
if err := handler(d, m); err != nil {
|
||||
return err // Handler already provides context in error
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if len(m.RuleFiles) == 0 {
|
||||
return fmt.Errorf("no rule files specified")
|
||||
return fmt.Errorf("no rule files specified for WAF") // More direct error
|
||||
}
|
||||
|
||||
cl.logger.Debug("WAF configuration parsed successfully", zap.String("file", d.File())) // Success log
|
||||
return nil
|
||||
}
|
||||
|
||||
func (cl *ConfigLoader) parseRuleFile(d *caddyfile.Dispenser, m *Middleware) error {
|
||||
if !d.NextArg() {
|
||||
return fmt.Errorf("file: %s, line: %d: missing path for rule_file", d.File(), d.Line())
|
||||
return d.ArgErr() // More specific error type
|
||||
}
|
||||
ruleFile := d.Val()
|
||||
m.RuleFiles = append(m.RuleFiles, ruleFile)
|
||||
|
||||
if m.MetricsEndpoint != "" && !strings.HasPrefix(m.MetricsEndpoint, "/") {
|
||||
return fmt.Errorf("metrics_endpoint must start with '/'")
|
||||
return fmt.Errorf("metrics_endpoint must start with a leading '/'") // Improved error message
|
||||
}
|
||||
|
||||
cl.logger.Info("WAF Loading Rule File",
|
||||
zap.String("file", ruleFile),
|
||||
cl.logger.Info("Loading WAF rule file", // Improved log message
|
||||
zap.String("path", ruleFile), // More descriptive log field
|
||||
zap.String("caddyfile", d.File()),
|
||||
zap.Int("line", d.Line()),
|
||||
)
|
||||
@@ -314,58 +208,49 @@ func (cl *ConfigLoader) parseCustomResponse(d *caddyfile.Dispenser, m *Middlewar
|
||||
}
|
||||
|
||||
if !d.NextArg() {
|
||||
return fmt.Errorf("file: %s, line: %d: missing status code for custom_response", d.File(), d.Line())
|
||||
return d.ArgErr() // More specific error type
|
||||
}
|
||||
statusCode, err := strconv.Atoi(d.Val())
|
||||
statusCode, err := cl.parseStatusCode(d)
|
||||
if err != nil {
|
||||
return fmt.Errorf("file: %s, line: %d: invalid status code for custom_response: %v", d.File(), d.Line(), err)
|
||||
return err
|
||||
}
|
||||
|
||||
if m.CustomResponses[statusCode].Headers == nil {
|
||||
m.CustomResponses[statusCode] = CustomBlockResponse{
|
||||
StatusCode: statusCode,
|
||||
Headers: make(map[string]string),
|
||||
}
|
||||
if _, exists := m.CustomResponses[statusCode]; exists {
|
||||
return d.Errf("custom_response for status code %d already defined", statusCode) // Prevent duplicate status codes
|
||||
}
|
||||
|
||||
resp := CustomBlockResponse{
|
||||
StatusCode: statusCode,
|
||||
Headers: make(map[string]string),
|
||||
}
|
||||
|
||||
if !d.NextArg() {
|
||||
return fmt.Errorf("file: %s, line: %d: missing content_type or file path for custom_response", d.File(), d.Line())
|
||||
return d.ArgErr() // More specific error type
|
||||
}
|
||||
contentTypeOrFile := d.Val()
|
||||
|
||||
if d.NextArg() {
|
||||
filePath := d.Val()
|
||||
content, err := os.ReadFile(filePath)
|
||||
content, err := cl.readResponseFromFile(d, filePath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("file: %s, line: %d: could not read custom response file '%s': %v", d.File(), d.Line(), filePath, err)
|
||||
}
|
||||
m.CustomResponses[statusCode] = CustomBlockResponse{
|
||||
StatusCode: statusCode,
|
||||
Headers: map[string]string{
|
||||
"Content-Type": contentTypeOrFile,
|
||||
},
|
||||
Body: string(content),
|
||||
return err
|
||||
}
|
||||
resp.Headers["Content-Type"] = contentTypeOrFile
|
||||
resp.Body = content
|
||||
cl.logger.Debug("Loaded custom response from file",
|
||||
zap.Int("status_code", statusCode),
|
||||
zap.String("file", filePath),
|
||||
zap.String("file_path", filePath), // More descriptive log field
|
||||
zap.String("content_type", contentTypeOrFile),
|
||||
zap.String("caddyfile", d.File()),
|
||||
zap.Int("line", d.Line()),
|
||||
)
|
||||
} else {
|
||||
remaining := d.RemainingArgs()
|
||||
if len(remaining) == 0 {
|
||||
return fmt.Errorf("file: %s, line: %d: missing custom response body", d.File(), d.Line())
|
||||
}
|
||||
body := strings.Join(remaining, " ")
|
||||
m.CustomResponses[statusCode] = CustomBlockResponse{
|
||||
StatusCode: statusCode,
|
||||
Headers: map[string]string{
|
||||
"Content-Type": contentTypeOrFile,
|
||||
},
|
||||
Body: body,
|
||||
body, err := cl.parseInlineResponseBody(d)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
resp.Headers["Content-Type"] = contentTypeOrFile
|
||||
resp.Body = body
|
||||
cl.logger.Debug("Loaded inline custom response",
|
||||
zap.Int("status_code", statusCode),
|
||||
zap.String("content_type", contentTypeOrFile),
|
||||
@@ -374,95 +259,268 @@ func (cl *ConfigLoader) parseCustomResponse(d *caddyfile.Dispenser, m *Middlewar
|
||||
zap.Int("line", d.Line()),
|
||||
)
|
||||
}
|
||||
m.CustomResponses[statusCode] = resp
|
||||
return nil
|
||||
}
|
||||
|
||||
func (cl *ConfigLoader) parseCountryBlock(d *caddyfile.Dispenser, m *Middleware, isBlock bool) error {
|
||||
target := &m.CountryBlock
|
||||
if !isBlock {
|
||||
target = &m.CountryWhitelist
|
||||
}
|
||||
target.Enabled = true
|
||||
// parseCountryBlockDirective returns a closure to handle block_countries and whitelist_countries directives.
|
||||
func (cl *ConfigLoader) parseCountryBlockDirective(isBlock bool) func(d *caddyfile.Dispenser, m *Middleware) error {
|
||||
return func(d *caddyfile.Dispenser, m *Middleware) error {
|
||||
target := &m.CountryBlock
|
||||
directiveName := "block_countries"
|
||||
if !isBlock {
|
||||
target = &m.CountryWhitelist
|
||||
directiveName = "whitelist_countries"
|
||||
}
|
||||
target.Enabled = true
|
||||
|
||||
if !d.NextArg() {
|
||||
return fmt.Errorf("file: %s, line: %d: missing GeoIP DB path", d.File(), d.Line())
|
||||
}
|
||||
target.GeoIPDBPath = d.Val()
|
||||
target.CountryList = []string{}
|
||||
if !d.NextArg() {
|
||||
return d.ArgErr() // More specific error type
|
||||
}
|
||||
target.GeoIPDBPath = d.Val()
|
||||
target.CountryList = []string{}
|
||||
|
||||
for d.NextArg() {
|
||||
country := strings.ToUpper(d.Val())
|
||||
target.CountryList = append(target.CountryList, country)
|
||||
}
|
||||
for d.NextArg() {
|
||||
country := strings.ToUpper(d.Val())
|
||||
target.CountryList = append(target.CountryList, country)
|
||||
}
|
||||
|
||||
cl.logger.Debug("Country list configured",
|
||||
zap.Bool("block_mode", isBlock),
|
||||
zap.Strings("countries", target.CountryList),
|
||||
zap.String("geoip_db_path", target.GeoIPDBPath),
|
||||
zap.String("file", d.File()), zap.Int("line", d.Line()),
|
||||
)
|
||||
return nil
|
||||
cl.logger.Debug("Country list configured", // Improved log message
|
||||
zap.String("directive", directiveName), // Log directive name
|
||||
zap.Bool("block_mode", isBlock),
|
||||
zap.Strings("countries", target.CountryList),
|
||||
zap.String("geoip_db_path", target.GeoIPDBPath),
|
||||
zap.String("file", d.File()), zap.Int("line", d.Line()),
|
||||
)
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func (cl *ConfigLoader) parseLogSeverity(d *caddyfile.Dispenser, m *Middleware) error {
|
||||
if !d.NextArg() {
|
||||
return fmt.Errorf("file: %s, line: %d: missing value for log_severity", d.File(), d.Line())
|
||||
return d.ArgErr() // More specific error type
|
||||
}
|
||||
m.LogSeverity = d.Val()
|
||||
cl.logger.Debug("Log severity set",
|
||||
severity := d.Val()
|
||||
validSeverities := []string{"debug", "info", "warn", "error"} // Define valid severities
|
||||
isValid := false
|
||||
for _, valid := range validSeverities {
|
||||
if severity == valid {
|
||||
isValid = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !isValid {
|
||||
return d.Errf("invalid log_severity value '%s', must be one of: %s", severity, strings.Join(validSeverities, ", ")) // Improved error message
|
||||
}
|
||||
|
||||
m.LogSeverity = severity
|
||||
cl.logger.Debug("Log severity set", // Improved log message
|
||||
zap.String("severity", m.LogSeverity),
|
||||
zap.String("file", d.File()), zap.Int("line", d.Line()),
|
||||
)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (cl *ConfigLoader) parseBlacklistFile(d *caddyfile.Dispenser, m *Middleware, isIP bool) error {
|
||||
if !d.NextArg() {
|
||||
return fmt.Errorf("file: %s, line: %d: missing blacklist file path", d.File(), d.Line())
|
||||
}
|
||||
|
||||
filePath := d.Val()
|
||||
|
||||
// Check if the file exists
|
||||
if _, err := os.Stat(filePath); os.IsNotExist(err) {
|
||||
// Create an empty file if it doesn't exist
|
||||
file, err := os.Create(filePath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("file: %s, line: %d: could not create blacklist file '%s': %v", d.File(), d.Line(), filePath, err)
|
||||
// parseBlacklistFileDirective returns a closure to handle ip_blacklist_file and dns_blacklist_file directives.
|
||||
func (cl *ConfigLoader) parseBlacklistFileDirective(isIP bool) func(d *caddyfile.Dispenser, m *Middleware) error {
|
||||
return func(d *caddyfile.Dispenser, m *Middleware) error {
|
||||
if !d.NextArg() {
|
||||
return d.ArgErr() // More specific error type
|
||||
}
|
||||
filePath := d.Val()
|
||||
directiveName := "dns_blacklist_file"
|
||||
if isIP {
|
||||
directiveName = "ip_blacklist_file"
|
||||
}
|
||||
file.Close()
|
||||
|
||||
cl.logger.Warn("Blacklist file does not exist, created an empty file",
|
||||
zap.String("file", filePath),
|
||||
zap.Bool("is_ip", isIP),
|
||||
if err := cl.ensureBlacklistFileExists(d, filePath, isIP); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Assign the file path to the appropriate field
|
||||
if isIP {
|
||||
m.IPBlacklistFile = filePath
|
||||
} else {
|
||||
m.DNSBlacklistFile = filePath
|
||||
}
|
||||
|
||||
cl.logger.Info("Blacklist file configured", // Improved log message
|
||||
zap.String("directive", directiveName), // Log directive name
|
||||
zap.String("path", filePath), // More descriptive log field
|
||||
zap.Bool("is_ip_type", isIP),
|
||||
)
|
||||
} else if err != nil {
|
||||
return fmt.Errorf("file: %s, line: %d: could not access blacklist file '%s': %v", d.File(), d.Line(), filePath, err)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Assign the file path to the appropriate field
|
||||
if isIP {
|
||||
m.IPBlacklistFile = filePath
|
||||
} else {
|
||||
m.DNSBlacklistFile = filePath
|
||||
}
|
||||
|
||||
cl.logger.Info("Blacklist file loaded",
|
||||
zap.String("file", filePath),
|
||||
zap.Bool("is_ip", isIP),
|
||||
)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (cl *ConfigLoader) parseAnomalyThreshold(d *caddyfile.Dispenser, m *Middleware) error {
|
||||
if !d.NextArg() {
|
||||
return fmt.Errorf("file: %s, line: %d: missing threshold value", d.File(), d.Line())
|
||||
}
|
||||
threshold, err := strconv.Atoi(d.Val())
|
||||
threshold, err := cl.parsePositiveInteger(d, "anomaly_threshold")
|
||||
if err != nil {
|
||||
return fmt.Errorf("file: %s, line: %d: invalid threshold: %v", d.File(), d.Line(), err)
|
||||
return err
|
||||
}
|
||||
m.AnomalyThreshold = threshold
|
||||
cl.logger.Debug("Anomaly threshold set", zap.Int("threshold", threshold))
|
||||
return nil
|
||||
}
|
||||
|
||||
func (cl *ConfigLoader) parseTorBlock(d *caddyfile.Dispenser, m *Middleware) error {
|
||||
for nesting := d.Nesting(); d.NextBlock(nesting); {
|
||||
subDirective := d.Val()
|
||||
switch subDirective {
|
||||
case "enabled":
|
||||
enabled, err := cl.parseBool(d, "tor enabled") // More descriptive arg name
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
m.Tor.Enabled = enabled
|
||||
cl.logger.Debug("Tor blocking enabled", zap.Bool("enabled", m.Tor.Enabled))
|
||||
|
||||
case "tor_ip_blacklist_file": // Updated field name
|
||||
if !d.NextArg() {
|
||||
return d.ArgErr()
|
||||
}
|
||||
m.Tor.TORIPBlacklistFile = d.Val() // Updated field name
|
||||
cl.logger.Debug("Tor IP blacklist file set", zap.String("file_path", m.Tor.TORIPBlacklistFile)) // More descriptive log field
|
||||
|
||||
case "update_interval":
|
||||
if !d.NextArg() {
|
||||
return d.ArgErr()
|
||||
}
|
||||
m.Tor.UpdateInterval = d.Val()
|
||||
cl.logger.Debug("Tor update interval set", zap.String("interval", m.Tor.UpdateInterval)) // More descriptive log field
|
||||
|
||||
case "retry_on_failure":
|
||||
retryOnFailure, err := cl.parseBool(d, "tor retry_on_failure") // More descriptive arg name
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
m.Tor.RetryOnFailure = retryOnFailure
|
||||
cl.logger.Debug("Tor retry on failure set", zap.Bool("retry_on_failure", m.Tor.RetryOnFailure))
|
||||
|
||||
case "retry_interval":
|
||||
if !d.NextArg() {
|
||||
return d.ArgErr()
|
||||
}
|
||||
m.Tor.RetryInterval = d.Val()
|
||||
cl.logger.Debug("Tor retry interval set", zap.String("interval", m.Tor.RetryInterval)) // More descriptive log field
|
||||
|
||||
default:
|
||||
return d.Errf("unrecognized tor subdirective: %s", subDirective) // More specific error message
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (cl *ConfigLoader) parseLogJSON(d *caddyfile.Dispenser, m *Middleware) error {
|
||||
m.LogJSON = true
|
||||
cl.logger.Debug("Log JSON enabled", zap.String("file", d.File()), zap.Int("line", d.Line()))
|
||||
return nil
|
||||
}
|
||||
|
||||
func (cl *ConfigLoader) parseRedactSensitiveData(d *caddyfile.Dispenser, m *Middleware) error {
|
||||
m.RedactSensitiveData = true
|
||||
cl.logger.Debug("Redact sensitive data enabled", zap.String("file", d.File()), zap.Int("line", d.Line()))
|
||||
return nil
|
||||
}
|
||||
|
||||
// --- Helper Functions ---
|
||||
|
||||
// parsePositiveInteger parses a directive argument as a positive integer.
|
||||
func (cl *ConfigLoader) parsePositiveInteger(d *caddyfile.Dispenser, directiveName string) (int, error) {
|
||||
if !d.NextArg() {
|
||||
return 0, d.ArgErr() // More specific error type
|
||||
}
|
||||
valStr := d.Val()
|
||||
val, err := strconv.Atoi(valStr)
|
||||
if err != nil {
|
||||
return 0, d.Errf("invalid %s value '%s': %v", directiveName, valStr, err) // More descriptive error message
|
||||
}
|
||||
if val <= 0 {
|
||||
return 0, d.Errf("%s must be a positive integer, but got '%d'", directiveName, val) // More descriptive error message
|
||||
}
|
||||
return val, nil
|
||||
}
|
||||
|
||||
// parseDuration parses a directive argument as a time duration.
|
||||
func (cl *ConfigLoader) parseDuration(d *caddyfile.Dispenser, directiveName string) (time.Duration, error) {
|
||||
if !d.NextArg() {
|
||||
return 0, d.ArgErr() // More specific error type
|
||||
}
|
||||
durationStr := d.Val()
|
||||
duration, err := time.ParseDuration(durationStr)
|
||||
if err != nil {
|
||||
return 0, d.Errf("invalid %s value '%s': %v", directiveName, durationStr, err) // More descriptive error message
|
||||
}
|
||||
return duration, nil
|
||||
}
|
||||
|
||||
// parseBool parses a directive argument as a boolean.
|
||||
func (cl *ConfigLoader) parseBool(d *caddyfile.Dispenser, directiveName string) (bool, error) {
|
||||
if !d.NextArg() {
|
||||
return false, d.ArgErr() // More specific error type
|
||||
}
|
||||
boolStr := d.Val()
|
||||
val, err := strconv.ParseBool(boolStr)
|
||||
if err != nil {
|
||||
return false, d.Errf("invalid %s value '%s': %v, must be 'true' or 'false'", directiveName, boolStr, err) // More descriptive error message
|
||||
}
|
||||
return val, nil
|
||||
}
|
||||
|
||||
// parseStatusCode parses a directive argument as an HTTP status code.
|
||||
func (cl *ConfigLoader) parseStatusCode(d *caddyfile.Dispenser) (int, error) {
|
||||
statusCodeStr := d.Val()
|
||||
statusCode, err := strconv.Atoi(statusCodeStr)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("file: %s, line: %d: invalid status code '%s': %v", d.File(), d.Line(), statusCodeStr, err) // Include file and line in error
|
||||
}
|
||||
if statusCode < 100 || statusCode > 599 {
|
||||
return 0, fmt.Errorf("file: %s, line: %d: status code '%d' out of range, must be between 100 and 599", d.File(), d.Line(), statusCode) // Include file and line in error
|
||||
}
|
||||
return statusCode, nil
|
||||
}
|
||||
|
||||
// readResponseFromFile reads custom response body from file.
|
||||
func (cl *ConfigLoader) readResponseFromFile(d *caddyfile.Dispenser, filePath string) (string, error) {
|
||||
content, err := os.ReadFile(filePath)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("file: %s, line: %d: could not read custom response file '%s': %v", d.File(), d.Line(), filePath, err) // Include file and line in error
|
||||
}
|
||||
return string(content), nil
|
||||
}
|
||||
|
||||
// parseInlineResponseBody parses inline custom response body.
|
||||
func (cl *ConfigLoader) parseInlineResponseBody(d *caddyfile.Dispenser) (string, error) {
|
||||
remaining := d.RemainingArgs()
|
||||
if len(remaining) == 0 {
|
||||
return "", fmt.Errorf("file: %s, line: %d: missing custom response body", d.File(), d.Line()) // Include file and line in error
|
||||
}
|
||||
return strings.Join(remaining, " "), nil
|
||||
}
|
||||
|
||||
// ensureBlacklistFileExists checks if blacklist file exists and creates it if not.
|
||||
func (cl *ConfigLoader) ensureBlacklistFileExists(d *caddyfile.Dispenser, filePath string, isIP bool) error {
|
||||
if _, err := os.Stat(filePath); os.IsNotExist(err) {
|
||||
file, err := os.Create(filePath)
|
||||
if err != nil {
|
||||
fileType := "DNS"
|
||||
if isIP {
|
||||
fileType = "IP"
|
||||
}
|
||||
return fmt.Errorf("file: %s, line: %d: could not create %s blacklist file '%s': %v", d.File(), d.Line(), fileType, filePath, err) // More descriptive error
|
||||
}
|
||||
file.Close()
|
||||
fileType := "DNS"
|
||||
if isIP {
|
||||
fileType = "IP"
|
||||
}
|
||||
cl.logger.Warn("%s blacklist file does not exist, created an empty file", zap.String("type", fileType), zap.String("path", filePath)) // Improved log
|
||||
} else if err != nil {
|
||||
fileType := "DNS"
|
||||
if isIP {
|
||||
fileType = "IP"
|
||||
}
|
||||
return fmt.Errorf("file: %s, line: %d: could not access %s blacklist file '%s': %v", d.File(), d.Line(), fileType, filePath, err) // Improved error
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
138
geoip.go
138
geoip.go
@@ -44,20 +44,13 @@ func (gh *GeoIPHandler) LoadGeoIPDatabase(path string) (*maxminddb.Reader, error
|
||||
if path == "" {
|
||||
return nil, fmt.Errorf("no GeoIP database path specified")
|
||||
}
|
||||
gh.logger.Debug("Attempting to load GeoIP database",
|
||||
zap.String("path", path),
|
||||
)
|
||||
gh.logger.Debug("Loading GeoIP database", zap.String("path", path)) // Slightly improved log message
|
||||
reader, err := maxminddb.Open(path)
|
||||
if err != nil {
|
||||
gh.logger.Error("Failed to load GeoIP database",
|
||||
zap.String("path", path),
|
||||
zap.Error(err),
|
||||
)
|
||||
gh.logger.Error("Failed to load GeoIP database", zap.String("path", path), zap.Error(err))
|
||||
return nil, fmt.Errorf("failed to load GeoIP database: %w", err)
|
||||
}
|
||||
gh.logger.Info("GeoIP database loaded successfully",
|
||||
zap.String("path", path),
|
||||
)
|
||||
gh.logger.Info("GeoIP database loaded", zap.String("path", path)) // Slightly improved log message
|
||||
return reader, nil
|
||||
}
|
||||
|
||||
@@ -69,6 +62,7 @@ func (gh *GeoIPHandler) IsCountryInList(remoteAddr string, countryList []string,
|
||||
if geoIP == nil {
|
||||
return false, fmt.Errorf("geoip database not loaded")
|
||||
}
|
||||
|
||||
ip, err := gh.extractIPFromRemoteAddr(remoteAddr)
|
||||
if err != nil {
|
||||
return false, err
|
||||
@@ -76,21 +70,16 @@ func (gh *GeoIPHandler) IsCountryInList(remoteAddr string, countryList []string,
|
||||
|
||||
parsedIP := net.ParseIP(ip)
|
||||
if parsedIP == nil {
|
||||
gh.logger.Error("invalid IP address", zap.String("ip", ip))
|
||||
gh.logger.Warn("Invalid IP address", zap.String("ip", ip)) // Changed Error to Warn for invalid IP, as it might be client issue.
|
||||
return false, fmt.Errorf("invalid IP address: %s", ip)
|
||||
}
|
||||
|
||||
// Easy: Add caching of GeoIP lookups for performance.
|
||||
// Check cache first
|
||||
if gh.geoIPCache != nil {
|
||||
gh.geoIPCacheMutex.RLock()
|
||||
if record, ok := gh.geoIPCache[ip]; ok {
|
||||
gh.geoIPCacheMutex.RUnlock()
|
||||
for _, country := range countryList {
|
||||
if strings.EqualFold(record.Country.ISOCode, country) {
|
||||
return true, nil
|
||||
}
|
||||
}
|
||||
return false, nil
|
||||
return gh.isCountryInRecord(record, countryList), nil // Helper function for country check
|
||||
}
|
||||
gh.geoIPCacheMutex.RUnlock()
|
||||
}
|
||||
@@ -98,52 +87,16 @@ func (gh *GeoIPHandler) IsCountryInList(remoteAddr string, countryList []string,
|
||||
var record GeoIPRecord
|
||||
err = geoIP.Lookup(parsedIP, &record)
|
||||
if err != nil {
|
||||
gh.logger.Error("geoip lookup failed", zap.String("ip", ip), zap.Error(err))
|
||||
|
||||
// Critical: Handle cases where the GeoIP database lookup fails more gracefully.
|
||||
switch gh.geoIPLookupFallbackBehavior {
|
||||
case "default":
|
||||
// Log and treat as not in the list
|
||||
return false, nil
|
||||
case "none":
|
||||
return false, fmt.Errorf("geoip lookup failed: %w", err)
|
||||
case "": // No fallback configured, maintain existing behavior
|
||||
return false, fmt.Errorf("geoip lookup failed: %w", err)
|
||||
default:
|
||||
// Configurable fallback country code
|
||||
for _, country := range countryList {
|
||||
if strings.EqualFold(gh.geoIPLookupFallbackBehavior, country) {
|
||||
return true, nil
|
||||
}
|
||||
}
|
||||
return false, nil
|
||||
}
|
||||
|
||||
gh.logger.Error("GeoIP lookup failed", zap.String("ip", ip), zap.Error(err))
|
||||
return gh.handleGeoIPLookupError(err, countryList) // Helper function for error handling
|
||||
}
|
||||
|
||||
// Easy: Add caching of GeoIP lookups for performance.
|
||||
// Cache the record
|
||||
if gh.geoIPCache != nil {
|
||||
gh.geoIPCacheMutex.Lock()
|
||||
gh.geoIPCache[ip] = record
|
||||
gh.geoIPCacheMutex.Unlock()
|
||||
|
||||
// Invalidate cache entry after TTL (basic implementation)
|
||||
if gh.geoIPCacheTTL > 0 {
|
||||
time.AfterFunc(gh.geoIPCacheTTL, func() {
|
||||
gh.geoIPCacheMutex.Lock()
|
||||
delete(gh.geoIPCache, ip)
|
||||
gh.geoIPCacheMutex.Unlock()
|
||||
})
|
||||
}
|
||||
gh.cacheGeoIPRecord(ip, record) // Helper function for caching
|
||||
}
|
||||
|
||||
for _, country := range countryList {
|
||||
if strings.EqualFold(record.Country.ISOCode, country) {
|
||||
return true, nil
|
||||
}
|
||||
}
|
||||
|
||||
return false, nil
|
||||
return gh.isCountryInRecord(record, countryList), nil // Helper function for country check
|
||||
}
|
||||
|
||||
// getCountryCode extracts the country code for logging purposes
|
||||
@@ -162,11 +115,28 @@ func (gh *GeoIPHandler) GetCountryCode(remoteAddr string, geoIP *maxminddb.Reade
|
||||
if parsedIP == nil {
|
||||
return "N/A"
|
||||
}
|
||||
|
||||
// Check cache first for GetCountryCode as well for consistency and potential perf gain
|
||||
if gh.geoIPCache != nil {
|
||||
gh.geoIPCacheMutex.RLock()
|
||||
if record, ok := gh.geoIPCache[ip]; ok {
|
||||
gh.geoIPCacheMutex.RUnlock()
|
||||
return record.Country.ISOCode
|
||||
}
|
||||
gh.geoIPCacheMutex.RUnlock()
|
||||
}
|
||||
|
||||
var record GeoIPRecord
|
||||
err = geoIP.Lookup(parsedIP, &record)
|
||||
if err != nil {
|
||||
return "N/A"
|
||||
return "N/A" // Simply return "N/A" on lookup error for GetCountryCode
|
||||
}
|
||||
|
||||
// Cache the record for GetCountryCode as well
|
||||
if gh.geoIPCache != nil {
|
||||
gh.cacheGeoIPRecord(ip, record)
|
||||
}
|
||||
|
||||
return record.Country.ISOCode
|
||||
}
|
||||
|
||||
@@ -178,3 +148,51 @@ func (gh *GeoIPHandler) extractIPFromRemoteAddr(remoteAddr string) (string, erro
|
||||
}
|
||||
return host, nil
|
||||
}
|
||||
|
||||
// Helper function to check if the country in the record is in the country list
|
||||
func (gh *GeoIPHandler) isCountryInRecord(record GeoIPRecord, countryList []string) bool {
|
||||
for _, country := range countryList {
|
||||
if strings.EqualFold(record.Country.ISOCode, country) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// Helper function to handle GeoIP lookup errors based on fallback behavior
|
||||
func (gh *GeoIPHandler) handleGeoIPLookupError(err error, countryList []string) (bool, error) {
|
||||
switch gh.geoIPLookupFallbackBehavior {
|
||||
case "default":
|
||||
// Log at debug level as it's a fallback scenario, not necessarily an error for the overall operation
|
||||
gh.logger.Debug("GeoIP lookup failed, using default fallback (not in list)", zap.Error(err))
|
||||
return false, nil // Treat as not in the list
|
||||
case "none":
|
||||
return false, fmt.Errorf("geoip lookup failed: %w", err) // Propagate the error
|
||||
case "": // No fallback configured, maintain existing behavior
|
||||
return false, fmt.Errorf("geoip lookup failed: %w", err) // Propagate the error
|
||||
default: // Configurable fallback country code
|
||||
gh.logger.Debug("GeoIP lookup failed, using configured fallback", zap.String("fallbackCountry", gh.geoIPLookupFallbackBehavior), zap.Error(err))
|
||||
for _, country := range countryList {
|
||||
if strings.EqualFold(gh.geoIPLookupFallbackBehavior, country) {
|
||||
return true, nil // Treat as in the list for the fallback country
|
||||
}
|
||||
}
|
||||
return false, nil // Treat as not in the list if fallback country is not in the list
|
||||
}
|
||||
}
|
||||
|
||||
// Helper function to cache GeoIP record
|
||||
func (gh *GeoIPHandler) cacheGeoIPRecord(ip string, record GeoIPRecord) {
|
||||
gh.geoIPCacheMutex.Lock()
|
||||
gh.geoIPCache[ip] = record
|
||||
gh.geoIPCacheMutex.Unlock()
|
||||
|
||||
// Invalidate cache entry after TTL
|
||||
if gh.geoIPCacheTTL > 0 {
|
||||
time.AfterFunc(gh.geoIPCacheTTL, func() {
|
||||
gh.geoIPCacheMutex.Lock()
|
||||
delete(gh.geoIPCache, ip)
|
||||
gh.geoIPCacheMutex.Unlock()
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
241
handler.go
241
handler.go
@@ -1,133 +1,188 @@
|
||||
package caddywaf
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"net/http"
|
||||
|
||||
"github.com/caddyserver/caddy/v2/modules/caddyhttp"
|
||||
"github.com/google/uuid"
|
||||
"go.uber.org/zap"
|
||||
"go.uber.org/zap/zapcore"
|
||||
)
|
||||
|
||||
// ==================== Request Handling and Logic ====================
|
||||
|
||||
// ServeHTTP implements caddyhttp.Handler.
|
||||
func (m *Middleware) ServeHTTP(w http.ResponseWriter, r *http.Request, next caddyhttp.Handler) error {
|
||||
// Generate a unique log ID for the request
|
||||
logID := uuid.New().String()
|
||||
|
||||
// Log the request with common fields
|
||||
m.logRequest(zapcore.InfoLevel, "WAF evaluation started", r, zap.String("log_id", logID))
|
||||
m.logRequestStart(r, logID)
|
||||
|
||||
// Use the custom type as the key
|
||||
ctx := context.WithValue(r.Context(), ContextKeyLogId("logID"), logID)
|
||||
r = r.WithContext(ctx)
|
||||
|
||||
// Increment total requests
|
||||
m.incrementTotalRequestsMetric()
|
||||
|
||||
state := m.initializeWAFState()
|
||||
|
||||
if m.isPhaseBlocked(w, r, 1, state) { // Phase 1: Pre-request checks
|
||||
return nil // Request blocked in Phase 1, short-circuit
|
||||
}
|
||||
|
||||
if m.isPhaseBlocked(w, r, 2, state) { // Phase 2: Request analysis
|
||||
return nil // Request blocked in Phase 2, short-circuit
|
||||
}
|
||||
|
||||
recorder := NewResponseRecorder(w)
|
||||
err := next.ServeHTTP(recorder, r)
|
||||
|
||||
if m.isResponseHeaderPhaseBlocked(recorder, r, 3, state) { // Phase 3: Response Header analysis
|
||||
return nil // Request blocked in Phase 3, short-circuit
|
||||
}
|
||||
|
||||
m.handleResponseBodyPhase(recorder, r, state) // Phase 4: Response Body analysis (if not blocked yet)
|
||||
|
||||
if state.Blocked {
|
||||
m.incrementBlockedRequestsMetric()
|
||||
m.writeCustomResponse(recorder, state.StatusCode)
|
||||
return nil // Short circuit if blocked in any phase after headers
|
||||
}
|
||||
|
||||
m.incrementAllowedRequestsMetric()
|
||||
|
||||
if m.isMetricsRequest(r) {
|
||||
return m.handleMetricsRequest(w, r) // Handle metrics requests separately
|
||||
}
|
||||
|
||||
// If not blocked, copy recorded response back to original writer
|
||||
if !state.Blocked {
|
||||
// Copy headers from recorder to original writer
|
||||
header := w.Header()
|
||||
for key, values := range recorder.Header() {
|
||||
for _, value := range values {
|
||||
header.Add(key, value)
|
||||
}
|
||||
}
|
||||
w.WriteHeader(recorder.StatusCode()) // Set status code from recorder
|
||||
|
||||
// Write body from recorder to original writer
|
||||
_, writeErr := w.Write(recorder.body.Bytes())
|
||||
if writeErr != nil {
|
||||
m.logger.Error("Failed to write recorded response body to client", zap.Error(writeErr), zap.String("log_id", logID))
|
||||
// We should still return the original error from next.ServeHTTP if available, or a new error if writing body failed and next didn't return error.
|
||||
if err == nil {
|
||||
return writeErr // If original handler didn't error, return body write error.
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
m.logRequestCompletion(logID, state)
|
||||
|
||||
return err // Return any error from the next handler
|
||||
}
|
||||
|
||||
// isPhaseBlocked encapsulates the phase handling and blocking check logic.
|
||||
func (m *Middleware) isPhaseBlocked(w http.ResponseWriter, r *http.Request, phase int, state *WAFState) bool {
|
||||
m.handlePhase(w, r, phase, state)
|
||||
if state.Blocked {
|
||||
m.incrementBlockedRequestsMetric()
|
||||
w.WriteHeader(state.StatusCode)
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// isResponseHeaderPhaseBlocked encapsulates the response header phase handling and blocking check logic.
|
||||
func (m *Middleware) isResponseHeaderPhaseBlocked(recorder *responseRecorder, r *http.Request, phase int, state *WAFState) bool {
|
||||
m.handlePhase(recorder, r, phase, state)
|
||||
if state.Blocked {
|
||||
m.incrementBlockedRequestsMetric()
|
||||
recorder.WriteHeader(state.StatusCode)
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// logRequestStart logs the start of WAF evaluation.
|
||||
func (m *Middleware) logRequestStart(r *http.Request, logID string) {
|
||||
m.logger.Info("WAF request evaluation started",
|
||||
zap.String("log_id", logID),
|
||||
zap.String("method", r.Method),
|
||||
zap.String("uri", r.RequestURI),
|
||||
zap.String("remote_address", r.RemoteAddr),
|
||||
zap.String("user_agent", r.UserAgent()),
|
||||
)
|
||||
}
|
||||
|
||||
// incrementTotalRequestsMetric increments the total requests metric.
|
||||
func (m *Middleware) incrementTotalRequestsMetric() {
|
||||
m.muMetrics.Lock()
|
||||
m.totalRequests++
|
||||
m.muMetrics.Unlock()
|
||||
}
|
||||
|
||||
// Initialize WAF state for the request
|
||||
state := &WAFState{
|
||||
// initializeWAFState initializes the WAF state.
|
||||
func (m *Middleware) initializeWAFState() *WAFState {
|
||||
return &WAFState{
|
||||
TotalScore: 0,
|
||||
Blocked: false,
|
||||
StatusCode: http.StatusOK,
|
||||
ResponseWritten: false,
|
||||
}
|
||||
}
|
||||
|
||||
// Log the request details
|
||||
m.logger.Info("WAF evaluation started",
|
||||
zap.String("log_id", logID),
|
||||
zap.String("method", r.Method),
|
||||
zap.String("path", r.URL.Path),
|
||||
zap.String("source_ip", r.RemoteAddr),
|
||||
zap.String("user_agent", r.UserAgent()),
|
||||
zap.String("query_params", r.URL.RawQuery),
|
||||
)
|
||||
// handleResponseBodyPhase processes Phase 4 (response body).
|
||||
func (m *Middleware) handleResponseBodyPhase(recorder *responseRecorder, r *http.Request, state *WAFState) {
|
||||
// No need to check if recorder.body is nil here, it's always initialized in NewResponseRecorder
|
||||
body := recorder.BodyString()
|
||||
m.logger.Debug("Response body captured for Phase 4 analysis", zap.String("log_id", r.Context().Value(ContextKeyLogId("logID")).(string)))
|
||||
|
||||
// Handle Phase 1: Pre-request evaluation
|
||||
m.handlePhase(w, r, 1, state)
|
||||
if state.Blocked {
|
||||
m.muMetrics.Lock()
|
||||
m.blockedRequests++
|
||||
m.muMetrics.Unlock()
|
||||
w.WriteHeader(state.StatusCode)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Handle Phase 2: Request evaluation
|
||||
m.handlePhase(w, r, 2, state)
|
||||
if state.Blocked {
|
||||
m.muMetrics.Lock()
|
||||
m.blockedRequests++
|
||||
m.muMetrics.Unlock()
|
||||
w.WriteHeader(state.StatusCode)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Capture the response using a response recorder
|
||||
recorder := &responseRecorder{ResponseWriter: w, body: new(bytes.Buffer)}
|
||||
err := next.ServeHTTP(recorder, r)
|
||||
|
||||
// Handle Phase 3: Response headers evaluation
|
||||
m.handlePhase(recorder, r, 3, state)
|
||||
if state.Blocked {
|
||||
m.muMetrics.Lock()
|
||||
m.blockedRequests++
|
||||
m.muMetrics.Unlock()
|
||||
recorder.WriteHeader(state.StatusCode)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Handle Phase 4: Response body evaluation
|
||||
if recorder.body != nil {
|
||||
body := recorder.body.String()
|
||||
m.logger.Debug("Response body captured", zap.String("body", body))
|
||||
|
||||
for _, rule := range m.Rules[4] {
|
||||
if rule.regex.MatchString(body) {
|
||||
m.processRuleMatch(recorder, r, &rule, body, state)
|
||||
if state.Blocked {
|
||||
m.muMetrics.Lock()
|
||||
m.blockedRequests++
|
||||
m.muMetrics.Unlock()
|
||||
recorder.WriteHeader(state.StatusCode)
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Write the response body if no blocking occurred
|
||||
if !state.ResponseWritten {
|
||||
_, writeErr := w.Write(recorder.body.Bytes())
|
||||
if writeErr != nil {
|
||||
m.logger.Error("Failed to write response body", zap.Error(writeErr))
|
||||
for _, rule := range m.Rules[4] {
|
||||
if rule.regex.MatchString(body) {
|
||||
m.processRuleMatch(recorder, r, &rule, body, state)
|
||||
if state.Blocked {
|
||||
m.incrementBlockedRequestsMetric()
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Increment allowed requests if not blocked
|
||||
if !state.Blocked {
|
||||
m.muMetrics.Lock()
|
||||
m.allowedRequests++
|
||||
m.muMetrics.Unlock()
|
||||
// incrementBlockedRequestsMetric increments the blocked requests metric.
|
||||
func (m *Middleware) incrementBlockedRequestsMetric() {
|
||||
m.muMetrics.Lock()
|
||||
m.blockedRequests++
|
||||
m.muMetrics.Unlock()
|
||||
}
|
||||
|
||||
// incrementAllowedRequestsMetric increments the allowed requests metric.
|
||||
func (m *Middleware) incrementAllowedRequestsMetric() {
|
||||
m.muMetrics.Lock()
|
||||
m.allowedRequests++
|
||||
m.muMetrics.Unlock()
|
||||
}
|
||||
|
||||
// isMetricsRequest checks if it's a metrics request.
|
||||
func (m *Middleware) isMetricsRequest(r *http.Request) bool {
|
||||
return m.MetricsEndpoint != "" && r.URL.Path == m.MetricsEndpoint
|
||||
}
|
||||
|
||||
// writeCustomResponse writes a custom response.
|
||||
func (m *Middleware) writeCustomResponse(w http.ResponseWriter, statusCode int) {
|
||||
if customResponse, ok := m.CustomResponses[statusCode]; ok {
|
||||
for key, value := range customResponse.Headers {
|
||||
w.Header().Set(key, value)
|
||||
}
|
||||
w.WriteHeader(customResponse.StatusCode)
|
||||
if _, err := w.Write([]byte(customResponse.Body)); err != nil {
|
||||
m.logger.Error("Failed to write custom response body", zap.Error(err))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Handle metrics endpoint requests
|
||||
if m.MetricsEndpoint != "" && r.URL.Path == m.MetricsEndpoint {
|
||||
return m.handleMetricsRequest(w, r)
|
||||
}
|
||||
|
||||
// Log the completion of WAF evaluation
|
||||
m.logger.Info("WAF evaluation complete",
|
||||
// logRequestCompletion logs the completion of WAF evaluation.
|
||||
func (m *Middleware) logRequestCompletion(logID string, state *WAFState) {
|
||||
m.logger.Info("WAF request evaluation completed",
|
||||
zap.String("log_id", logID),
|
||||
zap.Int("total_score", state.TotalScore),
|
||||
zap.Bool("blocked", state.Blocked),
|
||||
zap.Int("status_code", state.StatusCode),
|
||||
)
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
// ==================== Utility Functions ====================
|
||||
|
||||
197
logging.go
197
logging.go
@@ -1,4 +1,3 @@
|
||||
// logging.go
|
||||
package caddywaf
|
||||
|
||||
import (
|
||||
@@ -11,155 +10,104 @@ import (
|
||||
"go.uber.org/zap/zapcore"
|
||||
)
|
||||
|
||||
const unknownValue = "unknown" // Define a constant for "unknown" values
|
||||
|
||||
var sensitiveKeys = []string{"password", "token", "apikey", "authorization", "secret"} // Define sensitive keys for redaction as package variable
|
||||
|
||||
func (m *Middleware) logRequest(level zapcore.Level, msg string, r *http.Request, fields ...zap.Field) {
|
||||
if m.logger == nil {
|
||||
return
|
||||
if m.logger == nil || level < m.logLevel {
|
||||
return // Early return if logger is nil or level is below threshold
|
||||
}
|
||||
|
||||
// Skip logging if the level is below the threshold
|
||||
if level < m.logLevel {
|
||||
return
|
||||
}
|
||||
|
||||
// Extract log ID or generate a new one
|
||||
var logID string
|
||||
var newFields []zap.Field
|
||||
foundLogID := false
|
||||
|
||||
for _, field := range fields {
|
||||
if field.Key == "log_id" {
|
||||
logID = field.String
|
||||
foundLogID = true
|
||||
} else {
|
||||
newFields = append(newFields, field)
|
||||
}
|
||||
}
|
||||
|
||||
if !foundLogID {
|
||||
logID = uuid.New().String()
|
||||
}
|
||||
|
||||
// Append log_id explicitly to newFields
|
||||
newFields = append(newFields, zap.String("log_id", logID))
|
||||
|
||||
// Attach common request metadata only if not already set
|
||||
commonFields := m.getCommonLogFields(r, newFields)
|
||||
for _, commonField := range commonFields {
|
||||
fieldExists := false
|
||||
for _, existingField := range newFields {
|
||||
if existingField.Key == commonField.Key {
|
||||
fieldExists = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !fieldExists {
|
||||
newFields = append(newFields, commonField)
|
||||
}
|
||||
}
|
||||
allFields := m.prepareLogFields(r, fields) // Prepare all fields in one function - Corrected call: Removed 'level'
|
||||
|
||||
// Send the log entry to the buffered channel
|
||||
select {
|
||||
case m.logChan <- LogEntry{Level: level, Message: msg, Fields: newFields}:
|
||||
case m.logChan <- LogEntry{Level: level, Message: msg, Fields: allFields}:
|
||||
// Log entry successfully queued
|
||||
default:
|
||||
// If the channel is full, fall back to synchronous logging
|
||||
m.logger.Warn("Log buffer full, falling back to synchronous logging",
|
||||
zap.String("message", msg),
|
||||
zap.Any("fields", newFields),
|
||||
zap.Any("fields", allFields),
|
||||
)
|
||||
m.logger.Log(level, msg, newFields...)
|
||||
m.logger.Log(level, msg, allFields...)
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Middleware) getCommonLogFields(r *http.Request, fields []zap.Field) []zap.Field {
|
||||
// Debug: Print the incoming fields
|
||||
m.logger.Debug("Incoming fields to getCommonLogFields",
|
||||
zap.Any("fields", fields),
|
||||
)
|
||||
// prepareLogFields consolidates the logic for preparing log fields, including common fields and log_id.
|
||||
func (m *Middleware) prepareLogFields(r *http.Request, fields []zap.Field) []zap.Field { // Corrected signature: Removed 'level zapcore.Level'
|
||||
var logID string
|
||||
var customFields []zap.Field
|
||||
|
||||
// Extract or assign default values for metadata fields
|
||||
var sourceIP string
|
||||
var userAgent string
|
||||
var requestMethod string
|
||||
var requestPath string
|
||||
var queryParams string
|
||||
var statusCode int
|
||||
// Extract log_id if present, otherwise generate a new one
|
||||
logID, customFields = m.extractLogIDField(fields)
|
||||
if logID == "" {
|
||||
logID = uuid.New().String()
|
||||
}
|
||||
|
||||
// Extract values from the incoming fields
|
||||
// Get common log fields and merge with custom fields, prioritizing custom fields in case of duplicates
|
||||
commonFields := m.getCommonLogFields(r)
|
||||
allFields := m.mergeFields(customFields, commonFields, zap.String("log_id", logID)) // Ensure log_id is always present
|
||||
|
||||
return allFields
|
||||
}
|
||||
|
||||
// extractLogIDField extracts the log_id from the given fields and returns it along with the remaining fields.
|
||||
func (m *Middleware) extractLogIDField(fields []zap.Field) (logID string, remainingFields []zap.Field) {
|
||||
for _, field := range fields {
|
||||
switch field.Key {
|
||||
case "source_ip":
|
||||
sourceIP = field.String
|
||||
case "user_agent":
|
||||
userAgent = field.String
|
||||
case "request_method":
|
||||
requestMethod = field.String
|
||||
case "request_path":
|
||||
requestPath = field.String
|
||||
case "query_params":
|
||||
queryParams = field.String
|
||||
case "status_code":
|
||||
statusCode = int(field.Integer)
|
||||
if field.Key == "log_id" {
|
||||
logID = field.String
|
||||
} else {
|
||||
remainingFields = append(remainingFields, field)
|
||||
}
|
||||
}
|
||||
return logID, remainingFields
|
||||
}
|
||||
|
||||
// mergeFields merges custom fields and common fields, with custom fields taking precedence and ensuring log_id is present.
|
||||
func (m *Middleware) mergeFields(customFields []zap.Field, commonFields []zap.Field, logIDField zap.Field) []zap.Field {
|
||||
mergedFields := make([]zap.Field, 0, len(customFields)+len(commonFields)+1) //预分配容量
|
||||
mergedFields = append(mergedFields, customFields...)
|
||||
|
||||
// Add common fields, skip if key already exists in custom fields
|
||||
for _, commonField := range commonFields {
|
||||
exists := false
|
||||
for _, customField := range customFields {
|
||||
if commonField.Key == customField.Key {
|
||||
exists = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !exists {
|
||||
mergedFields = append(mergedFields, commonField)
|
||||
}
|
||||
}
|
||||
|
||||
// If values are not provided in the fields, extract them from the request
|
||||
if sourceIP == "" && r != nil {
|
||||
mergedFields = append(mergedFields, logIDField) // Ensure log_id is always last or at least present
|
||||
return mergedFields
|
||||
}
|
||||
|
||||
func (m *Middleware) getCommonLogFields(r *http.Request) []zap.Field {
|
||||
sourceIP := unknownValue
|
||||
userAgent := unknownValue
|
||||
requestMethod := unknownValue
|
||||
requestPath := unknownValue
|
||||
queryParams := "" // Initialize to empty string, not "unknown" - More accurate for query params
|
||||
statusCode := 0 // Default status code is 0 if not explicitly set
|
||||
|
||||
if r != nil {
|
||||
sourceIP = r.RemoteAddr
|
||||
}
|
||||
if userAgent == "" && r != nil {
|
||||
userAgent = r.UserAgent()
|
||||
}
|
||||
if requestMethod == "" && r != nil {
|
||||
requestMethod = r.Method
|
||||
}
|
||||
if requestPath == "" && r != nil {
|
||||
requestPath = r.URL.Path
|
||||
}
|
||||
if queryParams == "" && r != nil {
|
||||
queryParams = r.URL.RawQuery
|
||||
}
|
||||
|
||||
// Debug: Print the extracted values
|
||||
m.logger.Debug("Extracted values in getCommonLogFields",
|
||||
zap.String("source_ip", sourceIP),
|
||||
zap.String("user_agent", userAgent),
|
||||
zap.String("request_method", requestMethod),
|
||||
zap.String("request_path", requestPath),
|
||||
zap.String("query_params", queryParams),
|
||||
zap.Int("status_code", statusCode),
|
||||
)
|
||||
|
||||
// Default values for missing fields
|
||||
if sourceIP == "" {
|
||||
sourceIP = "unknown"
|
||||
}
|
||||
if userAgent == "" {
|
||||
userAgent = "unknown"
|
||||
}
|
||||
if requestMethod == "" {
|
||||
requestMethod = "unknown"
|
||||
}
|
||||
if requestPath == "" {
|
||||
requestPath = "unknown"
|
||||
}
|
||||
|
||||
// Debug: Print the final values after applying defaults
|
||||
m.logger.Debug("Final values after applying defaults",
|
||||
zap.String("source_ip", sourceIP),
|
||||
zap.String("user_agent", userAgent),
|
||||
zap.String("request_method", requestMethod),
|
||||
zap.String("request_path", requestPath),
|
||||
zap.String("query_params", queryParams),
|
||||
zap.Int("status_code", statusCode),
|
||||
)
|
||||
|
||||
// Redact query parameters if required
|
||||
if m.RedactSensitiveData {
|
||||
queryParams = m.redactQueryParams(queryParams)
|
||||
}
|
||||
|
||||
// Construct and return common fields
|
||||
return []zap.Field{
|
||||
zap.String("source_ip", sourceIP),
|
||||
zap.String("user_agent", userAgent),
|
||||
@@ -167,7 +115,7 @@ func (m *Middleware) getCommonLogFields(r *http.Request, fields []zap.Field) []z
|
||||
zap.String("request_path", requestPath),
|
||||
zap.String("query_params", queryParams),
|
||||
zap.Int("status_code", statusCode),
|
||||
zap.Time("timestamp", time.Now()), // Include a timestamp
|
||||
zap.Time("timestamp", time.Now()),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -182,7 +130,7 @@ func (m *Middleware) redactQueryParams(queryParams string) string {
|
||||
keyValue := strings.SplitN(part, "=", 2)
|
||||
if len(keyValue) == 2 {
|
||||
key := strings.ToLower(keyValue[0])
|
||||
if strings.Contains(key, "password") || strings.Contains(key, "token") || strings.Contains(key, "apikey") || strings.Contains(key, "authorization") || strings.Contains(key, "secret") {
|
||||
if m.isSensitiveQueryParamKey(key) { // Use helper function for sensitive key check
|
||||
parts[i] = keyValue[0] + "=REDACTED"
|
||||
}
|
||||
}
|
||||
@@ -191,6 +139,15 @@ func (m *Middleware) redactQueryParams(queryParams string) string {
|
||||
return strings.Join(parts, "&")
|
||||
}
|
||||
|
||||
func (m *Middleware) isSensitiveQueryParamKey(key string) bool {
|
||||
for _, sensitiveKey := range sensitiveKeys { // Use package level sensitiveKeys variable
|
||||
if strings.Contains(key, sensitiveKey) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func caddyTimeEncoder(t time.Time, enc zapcore.PrimitiveArrayEncoder) {
|
||||
enc.AppendString(t.Format("2006/01/02 15:04:05.000"))
|
||||
}
|
||||
|
||||
@@ -126,11 +126,11 @@ func (rl *RateLimiter) cleanupExpiredEntries() {
|
||||
// startCleanup starts the goroutine to periodically clean up expired entries.
|
||||
func (rl *RateLimiter) startCleanup() {
|
||||
go func() {
|
||||
// log.Println("[INFO] Starting rate limiter cleanup goroutine")
|
||||
// log.Println("[INFO] Starting rate limiter cleanup goroutine") <- Removed/commented out
|
||||
ticker := time.NewTicker(rl.config.CleanupInterval)
|
||||
defer func() {
|
||||
ticker.Stop()
|
||||
// log.Println("[INFO] Rate limiter cleanup goroutine stopped")
|
||||
// log.Println("[INFO] Rate limiter cleanup goroutine stopped") <- Removed/commented out
|
||||
}()
|
||||
|
||||
for {
|
||||
|
||||
522
request.go
522
request.go
@@ -1,11 +1,11 @@
|
||||
package caddywaf
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
@@ -21,6 +21,37 @@ type RequestValueExtractor struct {
|
||||
// Define a custom type for context keys
|
||||
type ContextKeyLogId string
|
||||
|
||||
// Extraction Target Constants - Improved Readability and Maintainability
|
||||
const (
|
||||
TargetMethod = "METHOD"
|
||||
TargetRemoteIP = "REMOTE_IP"
|
||||
TargetProtocol = "PROTOCOL"
|
||||
TargetHost = "HOST"
|
||||
TargetArgs = "ARGS"
|
||||
TargetUserAgent = "USER_AGENT"
|
||||
TargetPath = "PATH"
|
||||
TargetURI = "URI"
|
||||
TargetBody = "BODY"
|
||||
TargetHeaders = "HEADERS" // Full request headers
|
||||
TargetRequestHeaders = "REQUEST_HEADERS" // Alias for full request headers
|
||||
TargetResponseHeaders = "RESPONSE_HEADERS"
|
||||
TargetResponseBody = "RESPONSE_BODY"
|
||||
TargetFileName = "FILE_NAME"
|
||||
TargetFileMIMEType = "FILE_MIME_TYPE"
|
||||
TargetCookies = "COOKIES" // All cookies
|
||||
TargetRequestCookies = "REQUEST_COOKIES" // Alias for all cookies (DUPLICATE - REMOVE ALIAS)
|
||||
TargetURLParamPrefix = "URL_PARAM:"
|
||||
TargetJSONPathPrefix = "JSON_PATH:"
|
||||
TargetContentType = "CONTENT_TYPE"
|
||||
TargetURL = "URL"
|
||||
TargetCookiesPrefix = "COOKIES:" // Dynamic cookie extraction prefix
|
||||
TargetHeadersPrefix = "HEADERS:" // Dynamic header extraction prefix
|
||||
TargetRequestHeadersPrefix = "REQUEST_HEADERS:" // Alias for dynamic header extraction prefix
|
||||
TargetResponseHeadersPrefix = "RESPONSE_HEADERS:" // Dynamic response header extraction prefix
|
||||
)
|
||||
|
||||
var sensitiveTargets = []string{"password", "token", "apikey", "authorization", "secret"} // Define sensitive targets for redaction as package variable
|
||||
|
||||
// NewRequestValueExtractor creates a new RequestValueExtractor with a given logger
|
||||
func NewRequestValueExtractor(logger *zap.Logger, redactSensitiveData bool) *RequestValueExtractor {
|
||||
return &RequestValueExtractor{logger: logger, redactSensitiveData: redactSensitiveData}
|
||||
@@ -56,249 +87,84 @@ func (rve *RequestValueExtractor) extractSingleValue(target string, r *http.Requ
|
||||
var unredactedValue string
|
||||
var err error
|
||||
|
||||
// Basic Cases (Keep as Before)
|
||||
switch {
|
||||
case target == "METHOD":
|
||||
unredactedValue = r.Method
|
||||
case target == "REMOTE_IP":
|
||||
unredactedValue = r.RemoteAddr
|
||||
case target == "PROTOCOL":
|
||||
unredactedValue = r.Proto
|
||||
case target == "HOST":
|
||||
unredactedValue = r.Host
|
||||
case target == "ARGS":
|
||||
if r.URL.RawQuery == "" {
|
||||
rve.logger.Debug("Query string is empty", zap.String("target", target))
|
||||
return "", fmt.Errorf("query string is empty for target: %s", target)
|
||||
}
|
||||
unredactedValue = r.URL.RawQuery
|
||||
case target == "USER_AGENT":
|
||||
unredactedValue = r.UserAgent()
|
||||
if unredactedValue == "" {
|
||||
rve.logger.Debug("User-Agent is empty", zap.String("target", target))
|
||||
}
|
||||
case target == "PATH":
|
||||
unredactedValue = r.URL.Path
|
||||
if unredactedValue == "" {
|
||||
rve.logger.Debug("Request path is empty", zap.String("target", target))
|
||||
}
|
||||
case target == "URI":
|
||||
unredactedValue = r.URL.RequestURI()
|
||||
if unredactedValue == "" {
|
||||
rve.logger.Debug("Request URI is empty", zap.String("target", target))
|
||||
}
|
||||
case target == "BODY":
|
||||
if r.Body == nil {
|
||||
rve.logger.Warn("Request body is nil", zap.String("target", target))
|
||||
return "", fmt.Errorf("request body is nil for target: %s", target)
|
||||
}
|
||||
if r.ContentLength == 0 {
|
||||
rve.logger.Debug("Request body is empty", zap.String("target", target))
|
||||
return "", fmt.Errorf("request body is empty for target: %s", target)
|
||||
}
|
||||
var bodyBytes []byte
|
||||
bodyBytes, err = io.ReadAll(r.Body)
|
||||
// Optimization: Use a map for target extraction logic
|
||||
extractionLogic := map[string]func() (string, error){
|
||||
TargetMethod: func() (string, error) { return r.Method, nil },
|
||||
TargetRemoteIP: func() (string, error) { return r.RemoteAddr, nil },
|
||||
TargetProtocol: func() (string, error) { return r.Proto, nil },
|
||||
TargetHost: func() (string, error) { return r.Host, nil },
|
||||
TargetArgs: func() (string, error) {
|
||||
return r.URL.RawQuery, rve.checkEmpty(r.URL.RawQuery, target, "Query string is empty")
|
||||
},
|
||||
TargetUserAgent: func() (string, error) {
|
||||
value := r.UserAgent()
|
||||
rve.logIfEmpty(value, target, "User-Agent is empty")
|
||||
return value, nil
|
||||
},
|
||||
TargetPath: func() (string, error) {
|
||||
value := r.URL.Path
|
||||
rve.logIfEmpty(value, target, "Request path is empty")
|
||||
return value, nil
|
||||
},
|
||||
TargetURI: func() (string, error) {
|
||||
value := r.URL.RequestURI()
|
||||
rve.logIfEmpty(value, target, "Request URI is empty")
|
||||
return value, nil
|
||||
},
|
||||
TargetBody: func() (string, error) { return rve.extractBody(r, target) }, // Separate body extraction
|
||||
TargetHeaders: func() (string, error) { return rve.extractHeaders(r.Header, "Request headers", target) }, // Helper for headers
|
||||
TargetRequestHeaders: func() (string, error) { return rve.extractHeaders(r.Header, "Request headers", target) }, // Alias
|
||||
TargetResponseHeaders: func() (string, error) { return rve.extractResponseHeaders(w, target) }, // Helper for response headers
|
||||
TargetResponseBody: func() (string, error) { return rve.extractResponseBody(w, target) }, // Helper for response body
|
||||
TargetFileName: func() (string, error) { return rve.extractFileName(r, target) }, // Helper for filename
|
||||
TargetFileMIMEType: func() (string, error) { return rve.extractFileMIMEType(r, target) }, // Helper for mime type
|
||||
TargetCookies: func() (string, error) { return rve.extractCookies(r.Cookies(), "No cookies found", target) }, // Helper for cookies
|
||||
TargetRequestCookies: func() (string, error) { return rve.extractCookies(r.Cookies(), "No cookies found", target) }, // Alias
|
||||
TargetContentType: func() (string, error) {
|
||||
return r.Header.Get("Content-Type"), rve.checkEmpty(r.Header.Get("Content-Type"), target, "Content-Type header not found")
|
||||
},
|
||||
TargetURL: func() (string, error) {
|
||||
return r.URL.String(), rve.checkEmpty(r.URL.String(), target, "URL could not be extracted")
|
||||
},
|
||||
}
|
||||
|
||||
if extractor, exists := extractionLogic[target]; exists {
|
||||
unredactedValue, err = extractor()
|
||||
if err != nil {
|
||||
rve.logger.Error("Failed to read request body", zap.Error(err))
|
||||
return "", fmt.Errorf("failed to read request body for target %s: %w", target, err)
|
||||
return "", err // Return error from extractor
|
||||
}
|
||||
r.Body = io.NopCloser(bytes.NewReader(bodyBytes)) // Reset body for next read
|
||||
unredactedValue = string(bodyBytes)
|
||||
|
||||
// Full Header Dump (Request)
|
||||
case target == "HEADERS", target == "REQUEST_HEADERS":
|
||||
if len(r.Header) == 0 {
|
||||
rve.logger.Debug("Request headers are empty", zap.String("target", target))
|
||||
return "", fmt.Errorf("request headers are empty for target: %s", target)
|
||||
}
|
||||
headers := make([]string, 0)
|
||||
for name, values := range r.Header {
|
||||
headers = append(headers, fmt.Sprintf("%s: %s", name, strings.Join(values, ",")))
|
||||
}
|
||||
unredactedValue = strings.Join(headers, "; ")
|
||||
|
||||
// Response Headers (Phase 3)
|
||||
case target == "RESPONSE_HEADERS":
|
||||
if w == nil {
|
||||
return "", fmt.Errorf("response headers not accessible outside Phase 3 for target: %s", target)
|
||||
}
|
||||
headers := make([]string, 0)
|
||||
for name, values := range w.Header() {
|
||||
headers = append(headers, fmt.Sprintf("%s: %s", name, strings.Join(values, ",")))
|
||||
}
|
||||
unredactedValue = strings.Join(headers, "; ")
|
||||
|
||||
// Response Body (Phase 4)
|
||||
case target == "RESPONSE_BODY":
|
||||
if w == nil {
|
||||
return "", fmt.Errorf("response body not accessible outside Phase 4 for target: %s", target)
|
||||
}
|
||||
if recorder, ok := w.(*responseRecorder); ok {
|
||||
if recorder == nil {
|
||||
return "", fmt.Errorf("response recorder is nil for target: %s", target)
|
||||
}
|
||||
if recorder.body.Len() == 0 {
|
||||
rve.logger.Debug("Response body is empty", zap.String("target", target))
|
||||
return "", fmt.Errorf("response body is empty for target: %s", target)
|
||||
}
|
||||
unredactedValue = recorder.BodyString()
|
||||
} else {
|
||||
return "", fmt.Errorf("response recorder not available for target: %s", target)
|
||||
}
|
||||
|
||||
case target == "FILE_NAME":
|
||||
// Extract file name from multipart form data
|
||||
if r.MultipartForm != nil && r.MultipartForm.File != nil {
|
||||
for _, files := range r.MultipartForm.File {
|
||||
for _, file := range files {
|
||||
unredactedValue = file.Filename
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
if unredactedValue == "" {
|
||||
rve.logger.Debug("File name not found", zap.String("target", target))
|
||||
return "", fmt.Errorf("file name not found for target: %s", target)
|
||||
}
|
||||
|
||||
case target == "FILE_MIME_TYPE":
|
||||
// Extract MIME type from multipart form data
|
||||
if r.MultipartForm != nil && r.MultipartForm.File != nil {
|
||||
for _, files := range r.MultipartForm.File {
|
||||
for _, file := range files {
|
||||
unredactedValue = file.Header.Get("Content-Type")
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
if unredactedValue == "" {
|
||||
rve.logger.Debug("File MIME type not found", zap.String("target", target))
|
||||
return "", fmt.Errorf("file MIME type not found for target: %s", target)
|
||||
}
|
||||
|
||||
// Dynamic Header Extraction (Request)
|
||||
case strings.HasPrefix(target, "HEADERS:"), strings.HasPrefix(target, "REQUEST_HEADERS:"):
|
||||
headerName := strings.TrimPrefix(strings.TrimPrefix(target, "HEADERS:"), "REQUEST_HEADERS:") // Trim both prefixes
|
||||
headerValue := r.Header.Get(headerName)
|
||||
if headerValue == "" {
|
||||
rve.logger.Debug("Header not found", zap.String("header", headerName))
|
||||
return "", fmt.Errorf("header '%s' not found for target: %s", headerName, target)
|
||||
}
|
||||
unredactedValue = headerValue
|
||||
|
||||
// Dynamic Response Header Extraction (Phase 3)
|
||||
case strings.HasPrefix(target, "RESPONSE_HEADERS:"):
|
||||
if w == nil {
|
||||
return "", fmt.Errorf("response headers not available during this phase for target: %s", target)
|
||||
}
|
||||
headerName := strings.TrimPrefix(target, "RESPONSE_HEADERS:")
|
||||
headerValue := w.Header().Get(headerName)
|
||||
if headerValue == "" {
|
||||
rve.logger.Debug("Response header not found", zap.String("header", headerName))
|
||||
return "", fmt.Errorf("response header '%s' not found for target: %s", headerName, target)
|
||||
}
|
||||
unredactedValue = headerValue
|
||||
|
||||
// Cookies Extraction
|
||||
case target == "COOKIES":
|
||||
cookies := make([]string, 0)
|
||||
for _, c := range r.Cookies() {
|
||||
cookies = append(cookies, fmt.Sprintf("%s=%s", c.Name, c.Value))
|
||||
}
|
||||
if len(cookies) == 0 {
|
||||
rve.logger.Debug("No cookies found", zap.String("target", target))
|
||||
return "", fmt.Errorf("no cookies found for target: %s", target)
|
||||
}
|
||||
unredactedValue = strings.Join(cookies, "; ")
|
||||
|
||||
case strings.HasPrefix(target, "COOKIES:"):
|
||||
cookieName := strings.TrimPrefix(target, "COOKIES:")
|
||||
cookie, err := r.Cookie(cookieName)
|
||||
} else if strings.HasPrefix(target, TargetHeadersPrefix) || strings.HasPrefix(target, TargetRequestHeadersPrefix) {
|
||||
unredactedValue, err = rve.extractDynamicHeader(r.Header, strings.TrimPrefix(strings.TrimPrefix(target, TargetHeadersPrefix), TargetRequestHeadersPrefix), target)
|
||||
if err != nil {
|
||||
rve.logger.Debug("Cookie not found", zap.String("cookie", cookieName))
|
||||
return "", fmt.Errorf("cookie '%s' not found for target: %s", cookieName, target)
|
||||
return "", err
|
||||
}
|
||||
unredactedValue = cookie.Value
|
||||
|
||||
// URL Parameter Extraction
|
||||
case strings.HasPrefix(target, "URL_PARAM:"):
|
||||
paramName := strings.TrimPrefix(target, "URL_PARAM:")
|
||||
if paramName == "" {
|
||||
return "", fmt.Errorf("URL parameter name is empty for target: %s", target)
|
||||
}
|
||||
if r.URL.Query().Get(paramName) == "" {
|
||||
rve.logger.Debug("URL parameter not found", zap.String("parameter", paramName))
|
||||
return "", fmt.Errorf("url parameter '%s' not found for target: %s", paramName, target)
|
||||
}
|
||||
unredactedValue = r.URL.Query().Get(paramName)
|
||||
|
||||
// JSON Path Extraction from Body
|
||||
case strings.HasPrefix(target, "JSON_PATH:"):
|
||||
jsonPath := strings.TrimPrefix(target, "JSON_PATH:")
|
||||
if r.Body == nil {
|
||||
rve.logger.Warn("Request body is nil", zap.String("target", target))
|
||||
return "", fmt.Errorf("request body is nil for target: %s", target)
|
||||
}
|
||||
if r.ContentLength == 0 {
|
||||
rve.logger.Debug("Request body is empty", zap.String("target", target))
|
||||
return "", fmt.Errorf("request body is empty for target: %s", target)
|
||||
}
|
||||
|
||||
bodyBytes, err := io.ReadAll(r.Body)
|
||||
} else if strings.HasPrefix(target, TargetResponseHeadersPrefix) {
|
||||
unredactedValue, err = rve.extractDynamicResponseHeader(w, strings.TrimPrefix(target, TargetResponseHeadersPrefix), target)
|
||||
if err != nil {
|
||||
rve.logger.Error("Failed to read request body", zap.Error(err))
|
||||
return "", fmt.Errorf("failed to read request body for JSON_PATH target %s: %w", target, err)
|
||||
return "", err
|
||||
}
|
||||
r.Body = io.NopCloser(bytes.NewReader(bodyBytes)) // Reset body for next read
|
||||
|
||||
// Use helper method to dynamically extract value based on JSON path (e.g., 'data.items.0.name').
|
||||
unredactedValue, err = rve.extractJSONPath(string(bodyBytes), jsonPath)
|
||||
} else if strings.HasPrefix(target, TargetCookiesPrefix) {
|
||||
unredactedValue, err = rve.extractDynamicCookie(r, strings.TrimPrefix(target, TargetCookiesPrefix), target)
|
||||
if err != nil {
|
||||
rve.logger.Debug("Failed to extract value from JSON path", zap.String("target", target), zap.String("path", jsonPath), zap.Error(err))
|
||||
return "", fmt.Errorf("failed to extract from JSON path '%s': %w", jsonPath, err)
|
||||
return "", err
|
||||
}
|
||||
|
||||
// New cases start here:
|
||||
case target == "CONTENT_TYPE":
|
||||
unredactedValue = r.Header.Get("Content-Type")
|
||||
if unredactedValue == "" {
|
||||
rve.logger.Debug("Content-Type header not found", zap.String("target", target))
|
||||
return "", fmt.Errorf("content-type header not found for target: %s", target)
|
||||
} else if strings.HasPrefix(target, TargetURLParamPrefix) {
|
||||
unredactedValue, err = rve.extractURLParam(r.URL, strings.TrimPrefix(target, TargetURLParamPrefix), target)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
case target == "URL":
|
||||
unredactedValue = r.URL.String()
|
||||
if unredactedValue == "" {
|
||||
rve.logger.Debug("URL could not be extracted", zap.String("target", target))
|
||||
return "", fmt.Errorf("url could not be extracted for target: %s", target)
|
||||
} else if strings.HasPrefix(target, TargetJSONPathPrefix) {
|
||||
unredactedValue, err = rve.extractValueForJSONPath(r, strings.TrimPrefix(target, TargetJSONPathPrefix), target)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
case target == "REQUEST_COOKIES":
|
||||
cookies := make([]string, 0)
|
||||
for _, c := range r.Cookies() {
|
||||
cookies = append(cookies, fmt.Sprintf("%s=%s", c.Name, c.Value))
|
||||
}
|
||||
unredactedValue = strings.Join(cookies, "; ")
|
||||
if len(cookies) == 0 {
|
||||
rve.logger.Debug("No cookies found", zap.String("target", target))
|
||||
return "", fmt.Errorf("no cookies found for target: %s", target)
|
||||
}
|
||||
|
||||
default:
|
||||
} else {
|
||||
rve.logger.Warn("Unknown extraction target", zap.String("target", target))
|
||||
return "", fmt.Errorf("unknown extraction target: %s", target)
|
||||
}
|
||||
|
||||
// Redact sensitive fields before returning the value
|
||||
value := unredactedValue
|
||||
if rve.redactSensitiveData {
|
||||
sensitiveTargets := []string{"password", "token", "apikey", "authorization", "secret"}
|
||||
for _, sensitive := range sensitiveTargets {
|
||||
if strings.Contains(strings.ToLower(target), sensitive) {
|
||||
value = "REDACTED"
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
// Redact sensitive fields before returning the value (as before)
|
||||
value := rve.redactValueIfSensitive(target, unredactedValue)
|
||||
|
||||
// Log the extracted value (redacted if necessary)
|
||||
rve.logger.Debug("Extracted value",
|
||||
@@ -310,6 +176,202 @@ func (rve *RequestValueExtractor) extractSingleValue(target string, r *http.Requ
|
||||
return unredactedValue, nil
|
||||
}
|
||||
|
||||
// Helper function to check for empty value and log debug message if empty
|
||||
func (rve *RequestValueExtractor) checkEmpty(value string, target, message string) error {
|
||||
if value == "" {
|
||||
rve.logger.Debug(message, zap.String("target", target))
|
||||
return fmt.Errorf("%s for target: %s", message, target)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Helper function to log debug message if value is empty
|
||||
func (rve *RequestValueExtractor) logIfEmpty(value string, target string, message string) {
|
||||
if value == "" {
|
||||
rve.logger.Debug(message, zap.String("target", target))
|
||||
}
|
||||
}
|
||||
|
||||
// Helper function to extract body
|
||||
func (rve *RequestValueExtractor) extractBody(r *http.Request, target string) (string, error) {
|
||||
if r.Body == nil {
|
||||
rve.logger.Warn("Request body is nil", zap.String("target", target))
|
||||
return "", fmt.Errorf("request body is nil for target: %s", target)
|
||||
}
|
||||
if r.ContentLength == 0 {
|
||||
rve.logger.Debug("Request body is empty", zap.String("target", target))
|
||||
return "", fmt.Errorf("request body is empty for target: %s", target)
|
||||
}
|
||||
bodyBytes, err := io.ReadAll(r.Body)
|
||||
if err != nil {
|
||||
rve.logger.Error("Failed to read request body", zap.Error(err))
|
||||
return "", fmt.Errorf("failed to read request body for target %s: %w", target, err)
|
||||
}
|
||||
r.Body = http.NoBody // Reset body for next read - using http.NoBody
|
||||
return string(bodyBytes), nil
|
||||
}
|
||||
|
||||
// Helper function to extract headers
|
||||
func (rve *RequestValueExtractor) extractHeaders(header http.Header, logMessage, target string) (string, error) {
|
||||
if len(header) == 0 {
|
||||
rve.logger.Debug(logMessage+" are empty", zap.String("target", target))
|
||||
return "", fmt.Errorf("%s are empty for target: %s", logMessage, target)
|
||||
}
|
||||
headers := make([]string, 0)
|
||||
for name, values := range header {
|
||||
headers = append(headers, fmt.Sprintf("%s: %s", name, strings.Join(values, ",")))
|
||||
}
|
||||
return strings.Join(headers, "; "), nil
|
||||
}
|
||||
|
||||
// Helper function to extract response headers (for phase 3)
|
||||
func (rve *RequestValueExtractor) extractResponseHeaders(w http.ResponseWriter, target string) (string, error) {
|
||||
if w == nil {
|
||||
return "", fmt.Errorf("response headers not accessible during this phase for target: %s", target)
|
||||
}
|
||||
return rve.extractHeaders(w.Header(), "Response headers", target)
|
||||
}
|
||||
|
||||
// Helper function to extract response body (for phase 4)
|
||||
func (rve *RequestValueExtractor) extractResponseBody(w http.ResponseWriter, target string) (string, error) {
|
||||
if w == nil {
|
||||
return "", fmt.Errorf("response body not accessible outside Phase 4 for target: %s", target)
|
||||
}
|
||||
recorder, ok := w.(*responseRecorder)
|
||||
if !ok || recorder == nil {
|
||||
return "", fmt.Errorf("response recorder not available for target: %s", target)
|
||||
}
|
||||
if recorder.body.Len() == 0 {
|
||||
rve.logger.Debug("Response body is empty", zap.String("target", target))
|
||||
return "", fmt.Errorf("response body is empty for target: %s", target)
|
||||
}
|
||||
return recorder.BodyString(), nil
|
||||
}
|
||||
|
||||
// Helper function to extract filename from multipart form
|
||||
func (rve *RequestValueExtractor) extractFileName(r *http.Request, target string) (string, error) {
|
||||
if r.MultipartForm == nil || r.MultipartForm.File == nil {
|
||||
rve.logger.Debug("Multipart form file not found", zap.String("target", target))
|
||||
return "", fmt.Errorf("multipart form file not found for target: %s", target)
|
||||
}
|
||||
for _, files := range r.MultipartForm.File {
|
||||
if len(files) > 0 { // Check if there are files
|
||||
return files[0].Filename, nil // Return the first file's name
|
||||
}
|
||||
}
|
||||
return "", fmt.Errorf("no files found in multipart form for target: %s", target) // No files found but form is present
|
||||
}
|
||||
|
||||
// Helper function to extract MIME type from multipart form
|
||||
func (rve *RequestValueExtractor) extractFileMIMEType(r *http.Request, target string) (string, error) {
|
||||
if r.MultipartForm == nil || r.MultipartForm.File == nil {
|
||||
rve.logger.Debug("Multipart form file not found", zap.String("target", target))
|
||||
return "", fmt.Errorf("multipart form file not found for target: %s", target)
|
||||
}
|
||||
|
||||
for _, files := range r.MultipartForm.File {
|
||||
if len(files) > 0 { // Check if files are present
|
||||
return files[0].Header.Get("Content-Type"), nil // Return MIME type of the first file
|
||||
}
|
||||
}
|
||||
return "", fmt.Errorf("no files found in multipart form for target: %s", target) // No files found even though form is present
|
||||
}
|
||||
|
||||
// Helper function to extract dynamic header value
|
||||
func (rve *RequestValueExtractor) extractDynamicHeader(header http.Header, headerName, target string) (string, error) {
|
||||
headerValue := header.Get(headerName)
|
||||
if headerValue == "" {
|
||||
rve.logger.Debug("Header not found", zap.String("header", headerName))
|
||||
return "", fmt.Errorf("header '%s' not found for target: %s", headerName, target)
|
||||
}
|
||||
return headerValue, nil
|
||||
}
|
||||
|
||||
// Helper function to extract dynamic response header value (for phase 3)
|
||||
func (rve *RequestValueExtractor) extractDynamicResponseHeader(w http.ResponseWriter, headerName, target string) (string, error) {
|
||||
if w == nil {
|
||||
return "", fmt.Errorf("response headers not available during this phase for target: %s", target)
|
||||
}
|
||||
headerValue := w.Header().Get(headerName)
|
||||
if headerValue == "" {
|
||||
rve.logger.Debug("Response header not found", zap.String("header", headerName))
|
||||
return "", fmt.Errorf("response header '%s' not found for target: %s", headerName, target)
|
||||
}
|
||||
return headerValue, nil
|
||||
}
|
||||
|
||||
// Helper function to extract cookie value
|
||||
func (rve *RequestValueExtractor) extractDynamicCookie(r *http.Request, cookieName string, target string) (string, error) {
|
||||
cookie, err := r.Cookie(cookieName)
|
||||
if err != nil {
|
||||
rve.logger.Debug("Cookie not found", zap.String("cookie", cookieName))
|
||||
return "", fmt.Errorf("cookie '%s' not found for target: %s", cookieName, target)
|
||||
}
|
||||
return cookie.Value, nil
|
||||
}
|
||||
|
||||
// Helper function to extract URL parameter value
|
||||
func (rve *RequestValueExtractor) extractURLParam(url *url.URL, paramName string, target string) (string, error) {
|
||||
paramValue := url.Query().Get(paramName)
|
||||
if paramValue == "" {
|
||||
rve.logger.Debug("URL parameter not found", zap.String("parameter", paramName))
|
||||
return "", fmt.Errorf("url parameter '%s' not found for target: %s", paramName, target)
|
||||
}
|
||||
return paramValue, nil
|
||||
}
|
||||
|
||||
// Helper function to extract value for JSON Path
|
||||
func (rve *RequestValueExtractor) extractValueForJSONPath(r *http.Request, jsonPath string, target string) (string, error) {
|
||||
if r.Body == nil {
|
||||
rve.logger.Warn("Request body is nil", zap.String("target", target))
|
||||
return "", fmt.Errorf("request body is nil for target: %s", target)
|
||||
}
|
||||
if r.ContentLength == 0 {
|
||||
rve.logger.Debug("Request body is empty", zap.String("target", target))
|
||||
return "", fmt.Errorf("request body is empty for target: %s", target)
|
||||
}
|
||||
|
||||
bodyBytes, err := io.ReadAll(r.Body)
|
||||
if err != nil {
|
||||
rve.logger.Error("Failed to read request body", zap.Error(err))
|
||||
return "", fmt.Errorf("failed to read request body for JSON_PATH target %s: %w", target, err)
|
||||
}
|
||||
r.Body = http.NoBody // Reset body for next read
|
||||
|
||||
// Use helper method to dynamically extract value based on JSON path (e.g., 'data.items.0.name').
|
||||
unredactedValue, err := rve.extractJSONPath(string(bodyBytes), jsonPath)
|
||||
if err != nil {
|
||||
rve.logger.Debug("Failed to extract value from JSON path", zap.String("target", target), zap.String("path", jsonPath), zap.Error(err))
|
||||
return "", fmt.Errorf("failed to extract from JSON path '%s': %w", jsonPath, err)
|
||||
}
|
||||
return unredactedValue, nil
|
||||
}
|
||||
|
||||
// Helper function to redact value if target is sensitive
|
||||
func (rve *RequestValueExtractor) redactValueIfSensitive(target string, value string) string {
|
||||
if rve.redactSensitiveData {
|
||||
for _, sensitive := range sensitiveTargets {
|
||||
if strings.Contains(strings.ToLower(target), sensitive) {
|
||||
return "REDACTED"
|
||||
}
|
||||
}
|
||||
}
|
||||
return value
|
||||
}
|
||||
|
||||
// Helper function to extract cookies
|
||||
func (rve *RequestValueExtractor) extractCookies(cookies []*http.Cookie, logMessage string, target string) (string, error) {
|
||||
if len(cookies) == 0 {
|
||||
rve.logger.Debug(logMessage, zap.String("target", target))
|
||||
return "", fmt.Errorf("%s for target: %s", logMessage, target)
|
||||
}
|
||||
cookieStrings := make([]string, 0)
|
||||
for _, cookie := range cookies {
|
||||
cookieStrings = append(cookieStrings, fmt.Sprintf("%s=%s", cookie.Name, cookie.Value))
|
||||
}
|
||||
return strings.Join(cookieStrings, "; "), nil
|
||||
}
|
||||
|
||||
// Helper function for JSON path extraction.
|
||||
func (rve *RequestValueExtractor) extractJSONPath(jsonStr string, jsonPath string) (string, error) {
|
||||
// Validate input JSON string
|
||||
|
||||
225
rules.go
225
rules.go
@@ -1,3 +1,4 @@
|
||||
// rules.go
|
||||
package caddywaf
|
||||
|
||||
import (
|
||||
@@ -9,18 +10,14 @@ import (
|
||||
"sort"
|
||||
"strings"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"go.uber.org/zap"
|
||||
"go.uber.org/zap/zapcore"
|
||||
)
|
||||
|
||||
func (m *Middleware) processRuleMatch(w http.ResponseWriter, r *http.Request, rule *Rule, value string, state *WAFState) bool {
|
||||
logID, _ := r.Context().Value("logID").(string)
|
||||
if logID == "" {
|
||||
logID = uuid.New().String()
|
||||
}
|
||||
logID := r.Context().Value(ContextKeyLogId("logID")).(string)
|
||||
|
||||
m.logRequest(zapcore.DebugLevel, "Rule matched during evaluation", r,
|
||||
m.logRequest(zapcore.DebugLevel, "Rule Matched", r, // More concise log message
|
||||
zap.String("rule_id", string(rule.ID)),
|
||||
zap.String("target", strings.Join(rule.Targets, ",")),
|
||||
zap.String("value", value),
|
||||
@@ -28,76 +25,54 @@ func (m *Middleware) processRuleMatch(w http.ResponseWriter, r *http.Request, ru
|
||||
zap.Int("score", rule.Score),
|
||||
)
|
||||
|
||||
// Increment rule hit count
|
||||
if count, ok := m.ruleHits.Load(rule.ID); ok {
|
||||
newCount := count.(HitCount) + 1
|
||||
m.ruleHits.Store(rule.ID, newCount)
|
||||
m.logger.Debug("Incremented rule hit count",
|
||||
zap.String("rule_id", string(rule.ID)),
|
||||
zap.Int("new_count", int(newCount)),
|
||||
)
|
||||
} else {
|
||||
m.ruleHits.Store(rule.ID, HitCount(1))
|
||||
m.logger.Debug("Initialized rule hit count",
|
||||
zap.String("rule_id", string(rule.ID)),
|
||||
zap.Int("new_count", 1),
|
||||
)
|
||||
}
|
||||
// Rule Hit Counter - Refactored for clarity
|
||||
m.incrementRuleHitCount(RuleID(rule.ID))
|
||||
|
||||
// Increment rule hits by phase
|
||||
m.muMetrics.Lock()
|
||||
if m.ruleHitsByPhase == nil {
|
||||
m.ruleHitsByPhase = make(map[int]int64)
|
||||
}
|
||||
m.ruleHitsByPhase[rule.Phase]++
|
||||
m.muMetrics.Unlock()
|
||||
// Metrics for Rule Hits by Phase - Refactored for clarity
|
||||
m.incrementRuleHitsByPhaseMetric(rule.Phase)
|
||||
|
||||
oldScore := state.TotalScore
|
||||
state.TotalScore += rule.Score
|
||||
m.logRequest(zapcore.DebugLevel, "Increased anomaly score", r,
|
||||
m.logRequest(zapcore.DebugLevel, "Anomaly score increased", r, // Corrected argument order - 'r' is now the third argument
|
||||
zap.String("log_id", logID),
|
||||
zap.String("rule_id", string(rule.ID)),
|
||||
zap.Int("score_increase", rule.Score),
|
||||
zap.Int("old_total_score", oldScore),
|
||||
zap.Int("new_total_score", state.TotalScore),
|
||||
zap.Int("old_score", oldScore),
|
||||
zap.Int("new_score", state.TotalScore),
|
||||
zap.Int("anomaly_threshold", m.AnomalyThreshold),
|
||||
)
|
||||
|
||||
shouldBlock := false
|
||||
shouldBlock := !state.ResponseWritten && (state.TotalScore >= m.AnomalyThreshold || rule.Action == "block")
|
||||
blockReason := ""
|
||||
|
||||
if !state.ResponseWritten {
|
||||
if state.TotalScore >= m.AnomalyThreshold {
|
||||
shouldBlock = true
|
||||
blockReason = "Anomaly threshold exceeded"
|
||||
} else if rule.Action == "block" {
|
||||
shouldBlock = true
|
||||
if shouldBlock {
|
||||
blockReason = "Anomaly threshold exceeded"
|
||||
if rule.Action == "block" {
|
||||
blockReason = "Rule action is 'block'"
|
||||
}
|
||||
}
|
||||
|
||||
m.logger.Debug("Processing rule action",
|
||||
zap.String("rule_id", string(rule.ID)),
|
||||
m.logRequest(zapcore.DebugLevel, "Anomaly score increased", r, // 'r' is now the 3rd argument
|
||||
zap.String("action", rule.Action),
|
||||
zap.Bool("should_block", shouldBlock),
|
||||
zap.String("block_reason", blockReason),
|
||||
)
|
||||
|
||||
if shouldBlock && !state.ResponseWritten {
|
||||
if shouldBlock {
|
||||
m.blockRequest(w, r, state, http.StatusForbidden, blockReason, string(rule.ID), value,
|
||||
zap.Int("total_score", state.TotalScore),
|
||||
zap.Int("anomaly_threshold", m.AnomalyThreshold),
|
||||
)
|
||||
return false // Stop further processing
|
||||
return false
|
||||
}
|
||||
|
||||
if rule.Action == "log" {
|
||||
m.logRequest(zapcore.InfoLevel, "Rule action is 'log', request allowed but logged", r,
|
||||
m.logRequest(zapcore.InfoLevel, "Rule action: Log", r,
|
||||
zap.String("log_id", logID),
|
||||
zap.String("rule_id", string(rule.ID)),
|
||||
)
|
||||
} else if !shouldBlock && !state.ResponseWritten {
|
||||
m.logRequest(zapcore.DebugLevel, "Rule matched, no blocking action taken", r,
|
||||
m.logRequest(zapcore.DebugLevel, "Rule action: No Block", r,
|
||||
zap.String("log_id", logID),
|
||||
zap.String("rule_id", string(rule.ID)),
|
||||
zap.String("action", rule.Action),
|
||||
@@ -106,7 +81,30 @@ func (m *Middleware) processRuleMatch(w http.ResponseWriter, r *http.Request, ru
|
||||
)
|
||||
}
|
||||
|
||||
return true // Continue processing
|
||||
return true
|
||||
}
|
||||
|
||||
// incrementRuleHitCount increments the hit counter for a given rule ID.
|
||||
func (m *Middleware) incrementRuleHitCount(ruleID RuleID) {
|
||||
hitCount := HitCount(1) // Default increment
|
||||
if currentCount, loaded := m.ruleHits.Load(ruleID); loaded {
|
||||
hitCount = currentCount.(HitCount) + 1
|
||||
}
|
||||
m.ruleHits.Store(ruleID, hitCount)
|
||||
m.logger.Debug("Rule hit count updated",
|
||||
zap.String("rule_id", string(ruleID)),
|
||||
zap.Int("hit_count", int(hitCount)), // More descriptive log field
|
||||
)
|
||||
}
|
||||
|
||||
// incrementRuleHitsByPhaseMetric increments the rule hits by phase metric.
|
||||
func (m *Middleware) incrementRuleHitsByPhaseMetric(phase int) {
|
||||
m.muMetrics.Lock()
|
||||
if m.ruleHitsByPhase == nil {
|
||||
m.ruleHitsByPhase = make(map[int]int64)
|
||||
}
|
||||
m.ruleHitsByPhase[phase]++
|
||||
m.muMetrics.Unlock()
|
||||
}
|
||||
|
||||
func validateRule(rule *Rule) error {
|
||||
@@ -131,92 +129,109 @@ func validateRule(rule *Rule) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// loadRules updates the RuleCache when rules are loaded and sorts rules by priority.
|
||||
// loadRules updates the RuleCache and Rules map when rules are loaded and sorts rules by priority.
|
||||
func (m *Middleware) loadRules(paths []string) error {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
m.logger.Debug("Loading rules from files", zap.Strings("rule_files", paths))
|
||||
|
||||
m.Rules = make(map[int][]Rule)
|
||||
loadedRules := make(map[int][]Rule) // Temporary map to hold loaded rules
|
||||
totalRules := 0
|
||||
var invalidFiles []string
|
||||
var allInvalidRules []string
|
||||
invalidFiles := []string{}
|
||||
allInvalidRules := []string{}
|
||||
ruleIDs := make(map[string]bool)
|
||||
|
||||
for _, path := range paths {
|
||||
content, err := os.ReadFile(path)
|
||||
fileRules, fileInvalidRules, err := m.loadRulesFromFile(path, ruleIDs) // Load rules from a single file
|
||||
if err != nil {
|
||||
m.logger.Error("Failed to read rule file", zap.String("file", path), zap.Error(err))
|
||||
m.logger.Error("Failed to load rule file", zap.String("file", path), zap.Error(err))
|
||||
invalidFiles = append(invalidFiles, path)
|
||||
continue
|
||||
continue // Skip to the next file if loading fails
|
||||
}
|
||||
|
||||
var rules []Rule
|
||||
if err := json.Unmarshal(content, &rules); err != nil {
|
||||
m.logger.Error("Failed to unmarshal rules from file", zap.String("file", path), zap.Error(err))
|
||||
invalidFiles = append(invalidFiles, path)
|
||||
continue
|
||||
if len(fileInvalidRules) > 0 {
|
||||
m.logger.Warn("Invalid rules in file", zap.String("file", path), zap.Strings("errors", fileInvalidRules))
|
||||
allInvalidRules = append(allInvalidRules, fileInvalidRules...)
|
||||
}
|
||||
m.logger.Info("Rules loaded from file", zap.String("file", path), zap.Int("valid_rules", len(fileRules)), zap.Int("invalid_rules", len(fileInvalidRules)))
|
||||
|
||||
// Sort rules by priority (higher priority first)
|
||||
sort.Slice(rules, func(i, j int) bool {
|
||||
return rules[i].Priority > rules[j].Priority
|
||||
})
|
||||
|
||||
var invalidRulesInFile []string
|
||||
for i, rule := range rules {
|
||||
if err := validateRule(&rule); err != nil {
|
||||
invalidRulesInFile = append(invalidRulesInFile, fmt.Sprintf("Rule at index %d: %v", i, err))
|
||||
continue
|
||||
}
|
||||
|
||||
if _, exists := ruleIDs[string(rule.ID)]; exists {
|
||||
invalidRulesInFile = append(invalidRulesInFile, fmt.Sprintf("Duplicate rule ID '%s' at index %d", rule.ID, i))
|
||||
continue
|
||||
}
|
||||
ruleIDs[string(rule.ID)] = true
|
||||
|
||||
// Check RuleCache first
|
||||
if regex, exists := m.ruleCache.Get(rule.ID); exists {
|
||||
rule.regex = regex
|
||||
} else {
|
||||
regex, err := regexp.Compile(rule.Pattern)
|
||||
if err != nil {
|
||||
m.logger.Error("Failed to compile regex for rule", zap.String("rule_id", string(rule.ID)), zap.String("pattern", rule.Pattern), zap.Error(err))
|
||||
invalidRulesInFile = append(invalidRulesInFile, fmt.Sprintf("Rule '%s': invalid regex pattern: %v", rule.ID, err))
|
||||
continue
|
||||
}
|
||||
rule.regex = regex
|
||||
m.ruleCache.Set(rule.ID, regex) // Cache the compiled regex
|
||||
}
|
||||
|
||||
if _, ok := m.Rules[rule.Phase]; !ok {
|
||||
m.Rules[rule.Phase] = []Rule{}
|
||||
}
|
||||
|
||||
m.Rules[rule.Phase] = append(m.Rules[rule.Phase], rule)
|
||||
totalRules++
|
||||
// Merge valid rules from the file into the temporary loadedRules map
|
||||
for phase, rules := range fileRules {
|
||||
loadedRules[phase] = append(loadedRules[phase], rules...)
|
||||
}
|
||||
if len(invalidRulesInFile) > 0 {
|
||||
m.logger.Warn("Some rules failed validation", zap.String("file", path), zap.Strings("invalid_rules", invalidRulesInFile))
|
||||
allInvalidRules = append(allInvalidRules, invalidRulesInFile...)
|
||||
}
|
||||
|
||||
m.logger.Info("Rules loaded", zap.String("file", path), zap.Int("total_rules", len(rules)), zap.Int("invalid_rules", len(invalidRulesInFile)))
|
||||
totalRules += len(fileRules[1]) + len(fileRules[2]) + len(fileRules[3]) + len(fileRules[4]) // Update total rule count
|
||||
}
|
||||
|
||||
m.Rules = loadedRules // Atomically update m.Rules after loading all files
|
||||
|
||||
if len(invalidFiles) > 0 {
|
||||
m.logger.Warn("Some rule files could not be loaded", zap.Strings("invalid_files", invalidFiles))
|
||||
m.logger.Error("Failed to load rule files", zap.Strings("files", invalidFiles)) // Error level for file loading failures
|
||||
}
|
||||
if len(allInvalidRules) > 0 {
|
||||
m.logger.Warn("Some rules across files failed validation", zap.Strings("invalid_rules", allInvalidRules))
|
||||
m.logger.Warn("Validation errors in rules", zap.Strings("errors", allInvalidRules)) // More specific log message - "errors" field
|
||||
}
|
||||
|
||||
if totalRules == 0 && len(invalidFiles) > 0 {
|
||||
if totalRules == 0 && len(paths) > 0 { // Only return error if paths were provided
|
||||
return fmt.Errorf("no valid rules were loaded from any file")
|
||||
} else if totalRules == 0 && len(paths) == 0 {
|
||||
m.logger.Warn("No rule files specified, WAF will run without rules.") // Warn if no rule files and no rules loaded
|
||||
}
|
||||
m.logger.Debug("Rules loaded successfully", zap.Int("total_rules", totalRules))
|
||||
|
||||
m.logger.Info("WAF rules loaded successfully", zap.Int("total_rules", totalRules))
|
||||
return nil
|
||||
}
|
||||
|
||||
// loadRulesFromFile loads and validates rules from a single file.
|
||||
func (m *Middleware) loadRulesFromFile(path string, ruleIDs map[string]bool) (validRules map[int][]Rule, invalidRules []string, err error) {
|
||||
validRules = make(map[int][]Rule)
|
||||
var fileInvalidRules []string
|
||||
|
||||
content, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("failed to read rule file: %w", err)
|
||||
}
|
||||
|
||||
var rules []Rule
|
||||
if err := json.Unmarshal(content, &rules); err != nil {
|
||||
return nil, nil, fmt.Errorf("failed to unmarshal rules: %w", err)
|
||||
}
|
||||
|
||||
// Sort rules by priority (higher priority first)
|
||||
sort.Slice(rules, func(i, j int) bool {
|
||||
return rules[i].Priority > rules[j].Priority
|
||||
})
|
||||
|
||||
for i, rule := range rules {
|
||||
if err := validateRule(&rule); err != nil {
|
||||
fileInvalidRules = append(fileInvalidRules, fmt.Sprintf("Rule at index %d: %v", i, err))
|
||||
continue
|
||||
}
|
||||
|
||||
if _, exists := ruleIDs[string(rule.ID)]; exists {
|
||||
fileInvalidRules = append(fileInvalidRules, fmt.Sprintf("Duplicate rule ID '%s' at index %d", rule.ID, i))
|
||||
continue
|
||||
}
|
||||
ruleIDs[string(rule.ID)] = true // Track rule IDs to prevent duplicates
|
||||
|
||||
// RuleCache handling (compile and cache regex)
|
||||
if cachedRegex, exists := m.ruleCache.Get(rule.ID); exists {
|
||||
rule.regex = cachedRegex
|
||||
} else {
|
||||
compiledRegex, err := regexp.Compile(rule.Pattern)
|
||||
if err != nil {
|
||||
fileInvalidRules = append(fileInvalidRules, fmt.Sprintf("Rule '%s': invalid regex pattern: %v", rule.ID, err))
|
||||
continue
|
||||
}
|
||||
rule.regex = compiledRegex
|
||||
m.ruleCache.Set(rule.ID, compiledRegex) // Cache regex
|
||||
}
|
||||
|
||||
if _, ok := validRules[rule.Phase]; !ok {
|
||||
validRules[rule.Phase] = []Rule{}
|
||||
}
|
||||
validRules[rule.Phase] = append(validRules[rule.Phase], rule)
|
||||
}
|
||||
|
||||
return validRules, fileInvalidRules, nil
|
||||
}
|
||||
|
||||
38
tor.go
38
tor.go
@@ -1,7 +1,7 @@
|
||||
// tor.go
|
||||
package caddywaf
|
||||
|
||||
import (
|
||||
"fmt" // Import fmt for improved error formatting
|
||||
"io"
|
||||
"net/http"
|
||||
"os"
|
||||
@@ -32,7 +32,7 @@ func (t *TorConfig) Provision(ctx caddy.Context) error {
|
||||
t.logger = ctx.Logger()
|
||||
if t.Enabled {
|
||||
if err := t.updateTorExitNodes(); err != nil {
|
||||
return err
|
||||
return fmt.Errorf("provisioning tor: %w", err) // Improved error wrapping
|
||||
}
|
||||
go t.scheduleUpdates()
|
||||
}
|
||||
@@ -41,21 +41,27 @@ func (t *TorConfig) Provision(ctx caddy.Context) error {
|
||||
|
||||
// updateTorExitNodes fetches the latest Tor exit nodes and updates the IP blacklist.
|
||||
func (t *TorConfig) updateTorExitNodes() error {
|
||||
t.logger.Debug("Updating Tor exit nodes...") // Debug log at start of update
|
||||
|
||||
resp, err := http.Get(torExitNodeURL)
|
||||
if err != nil {
|
||||
return err
|
||||
return fmt.Errorf("http get failed for %s: %w", torExitNodeURL, err) // Improved error message with URL
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return fmt.Errorf("http get returned status %s for %s", resp.Status, torExitNodeURL) // Check for non-200 status
|
||||
}
|
||||
|
||||
data, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return err
|
||||
return fmt.Errorf("failed to read response body from %s: %w", torExitNodeURL, err) // Improved error message with URL
|
||||
}
|
||||
|
||||
torIPs := strings.Split(string(data), "\n")
|
||||
existingIPs, err := t.readExistingBlacklist()
|
||||
if err != nil {
|
||||
return err
|
||||
return fmt.Errorf("failed to read existing blacklist file %s: %w", t.TORIPBlacklistFile, err) // Improved error message with filename
|
||||
}
|
||||
|
||||
// Merge and deduplicate IPs
|
||||
@@ -65,15 +71,15 @@ func (t *TorConfig) updateTorExitNodes() error {
|
||||
|
||||
// Write updated blacklist to file
|
||||
if err := t.writeBlacklist(uniqueIPs); err != nil {
|
||||
return err
|
||||
return fmt.Errorf("failed to write updated blacklist to file %s: %w", t.TORIPBlacklistFile, err) // Improved error message with filename
|
||||
}
|
||||
|
||||
t.lastUpdated = time.Now()
|
||||
t.logger.Info("Updated Tor exit nodes in IP blacklist", zap.Int("count", len(uniqueIPs)))
|
||||
t.logger.Info("Tor exit nodes updated", zap.Int("count", len(uniqueIPs))) // Improved log message
|
||||
t.logger.Debug("Tor exit node update completed successfully") // Debug log at end of update
|
||||
return nil
|
||||
}
|
||||
|
||||
// scheduleUpdates periodically updates the Tor exit node list.
|
||||
// scheduleUpdates periodically updates the Tor exit node list.
|
||||
func (t *TorConfig) scheduleUpdates() {
|
||||
interval, err := time.ParseDuration(t.UpdateInterval)
|
||||
@@ -96,13 +102,13 @@ func (t *TorConfig) scheduleUpdates() {
|
||||
|
||||
// Use for range to iterate over the ticker channel
|
||||
for range ticker.C {
|
||||
if err := t.updateTorExitNodes(); err != nil {
|
||||
if updateErr := t.updateTorExitNodes(); updateErr != nil { // Renamed err to updateErr for clarity
|
||||
if t.RetryOnFailure {
|
||||
t.logger.Error("Failed to update Tor exit nodes, retrying shortly", zap.Error(err))
|
||||
t.logger.Error("Failed to update Tor exit nodes, retrying shortly", zap.Error(updateErr)) // Use updateErr
|
||||
time.Sleep(retryInterval)
|
||||
continue
|
||||
} else {
|
||||
t.logger.Error("Failed to update Tor exit nodes, will retry at next scheduled interval", zap.Error(err))
|
||||
t.logger.Error("Failed to update Tor exit nodes, will retry at next scheduled interval", zap.Error(updateErr)) // Use updateErr
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -113,9 +119,10 @@ func (t *TorConfig) readExistingBlacklist() ([]string, error) {
|
||||
data, err := os.ReadFile(t.TORIPBlacklistFile)
|
||||
if err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
t.logger.Debug("Blacklist file does not exist, assuming empty list", zap.String("path", t.TORIPBlacklistFile)) // Debug log for non-existent file
|
||||
return nil, nil
|
||||
}
|
||||
return nil, err
|
||||
return nil, fmt.Errorf("failed to read IP blacklist file %s: %w", t.TORIPBlacklistFile, err) // Improved error message with filename
|
||||
}
|
||||
return strings.Split(string(data), "\n"), nil
|
||||
}
|
||||
@@ -123,7 +130,12 @@ func (t *TorConfig) readExistingBlacklist() ([]string, error) {
|
||||
// writeBlacklist writes the updated IP blacklist to the file.
|
||||
func (t *TorConfig) writeBlacklist(ips []string) error {
|
||||
data := strings.Join(ips, "\n")
|
||||
return os.WriteFile(t.TORIPBlacklistFile, []byte(data), 0644)
|
||||
err := os.WriteFile(t.TORIPBlacklistFile, []byte(data), 0644)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to write IP blacklist file %s: %w", t.TORIPBlacklistFile, err) // Improved error message with filename
|
||||
}
|
||||
t.logger.Debug("Blacklist file updated", zap.String("path", t.TORIPBlacklistFile), zap.Int("entry_count", len(ips))) // Debug log for file update
|
||||
return nil
|
||||
}
|
||||
|
||||
// unique removes duplicate entries from a slice of strings.
|
||||
|
||||
Reference in New Issue
Block a user