diff --git a/caddywaf.go b/caddywaf.go index 5fde7d5..4427ae7 100644 --- a/caddywaf.go +++ b/caddywaf.go @@ -27,7 +27,7 @@ import ( "github.com/fsnotify/fsnotify" "github.com/oschwald/maxminddb-golang" - trie "github.com/phemmer/go-iptrie" + "github.com/phemmer/go-iptrie" "go.uber.org/zap" "go.uber.org/zap/zapcore" @@ -193,23 +193,23 @@ func (m *Middleware) Provision(ctx caddy.Context) error { // Initialize GeoIP stats m.geoIPStats = make(map[string]int64) - // Configure GeoIP-based country blocking/whitelisting - if m.CountryBlock.Enabled || m.CountryWhitelist.Enabled { - geoIPPath := m.CountryBlock.GeoIPDBPath + // Configure GeoIP-based country blacklisting/whitelisting + if m.CountryBlacklist.Enabled || m.CountryWhitelist.Enabled { + geoIPPath := m.CountryBlacklist.GeoIPDBPath if m.CountryWhitelist.Enabled && m.CountryWhitelist.GeoIPDBPath != "" { geoIPPath = m.CountryWhitelist.GeoIPDBPath } if !fileExists(geoIPPath) { - m.logger.Warn("GeoIP database not found. Country blocking/whitelisting will be disabled", zap.String("path", geoIPPath)) + m.logger.Warn("GeoIP database not found. Country blacklisting/whitelisting will be disabled", zap.String("path", geoIPPath)) } else { reader, err := maxminddb.Open(geoIPPath) if err != nil { m.logger.Error("Failed to load GeoIP database", zap.String("path", geoIPPath), zap.Error(err)) } else { m.logger.Info("GeoIP database loaded successfully", zap.String("path", geoIPPath)) - if m.CountryBlock.Enabled { - m.CountryBlock.geoIP = reader + if m.CountryBlacklist.Enabled { + m.CountryBlacklist.geoIP = reader } if m.CountryWhitelist.Enabled { m.CountryWhitelist.geoIP = reader @@ -237,7 +237,7 @@ func (m *Middleware) Provision(ctx caddy.Context) error { // Load IP blacklist if m.IPBlacklistFile != "" { - m.ipBlacklist = trie.NewTrie() + m.ipBlacklist = iptrie.NewTrie() err = m.loadIPBlacklist(m.IPBlacklistFile, *m.ipBlacklist) if err != nil { return fmt.Errorf("failed to load IP blacklist: %w", err) @@ -288,20 +288,20 @@ func (m *Middleware) Shutdown(ctx context.Context) error { var errorOccurred bool // Close GeoIP databases - if m.CountryBlock.geoIP != nil { - m.logger.Debug("Closing country block GeoIP database...") - if err := m.CountryBlock.geoIP.Close(); err != nil { - m.logger.Error("Error encountered while closing country block GeoIP database", zap.Error(err)) + if m.CountryBlacklist.geoIP != nil { + m.logger.Debug("Closing country blacklist GeoIP database...") + if err := m.CountryBlacklist.geoIP.Close(); err != nil { + m.logger.Error("Error encountered while closing country blacklist GeoIP database", zap.Error(err)) if !errorOccurred { - firstError = fmt.Errorf("error closing country block GeoIP: %w", err) + firstError = fmt.Errorf("error closing country blacklist GeoIP: %w", err) errorOccurred = true } } else { - m.logger.Debug("Country block GeoIP database closed successfully.") + m.logger.Debug("Country blacklist GeoIP database closed successfully.") } - m.CountryBlock.geoIP = nil + m.CountryBlacklist.geoIP = nil } else { - m.logger.Debug("Country block GeoIP database was not open, skipping close.") + m.logger.Debug("Country blacklist GeoIP database was not open, skipping close.") } if m.CountryWhitelist.geoIP != nil { @@ -426,7 +426,7 @@ func (m *Middleware) ReloadConfig() error { m.logger.Info("Reloading WAF configuration") if m.IPBlacklistFile != "" { - newIPBlacklist := trie.NewTrie() + newIPBlacklist := iptrie.NewTrie() if err := m.loadIPBlacklist(m.IPBlacklistFile, *newIPBlacklist); err != nil { m.logger.Error("Failed to reload IP blacklist", zap.String("file", m.IPBlacklistFile), zap.Error(err)) return fmt.Errorf("failed to reload IP blacklist: %v", err) @@ -452,7 +452,7 @@ func (m *Middleware) ReloadConfig() error { return nil } -func (m *Middleware) loadIPBlacklist(path string, blacklistMap trie.Trie) error { +func (m *Middleware) loadIPBlacklist(path string, blacklistMap iptrie.Trie) error { if _, err := os.Stat(path); os.IsNotExist(err) { m.logger.Warn("Skipping IP blacklist load, file does not exist", zap.String("file", path)) return nil diff --git a/caddywaf_test.go b/caddywaf_test.go index 0294778..1c78e88 100644 --- a/caddywaf_test.go +++ b/caddywaf_test.go @@ -32,7 +32,7 @@ func TestMiddleware_Provision(t *testing.T) { IPBlacklistFile: "testdata/ip_blacklist.txt", DNSBlacklistFile: "testdata/dns_blacklist.txt", AnomalyThreshold: 10, - CountryBlock: CountryAccessFilter{ + CountryBlacklist: CountryAccessFilter{ Enabled: true, CountryList: []string{"US"}, GeoIPDBPath: "testdata/GeoIP2-Country-Test.mmdb", diff --git a/common_test.go b/common_test.go index 707b24f..a23439c 100644 --- a/common_test.go +++ b/common_test.go @@ -1,8 +1,21 @@ package caddywaf +import "net/http" + const ( geoIPdata = "GeoLite2-Country.mmdb" - googleUSIP = "74.125.131.105" localIP = "127.0.0.1" + aliCNIP = "47.88.198.38" + googleUSIP = "74.125.131.105" + googleBRIP = "128.201.228.12" + googleRUIP = "74.125.131.94" + testURL = "http://example.com" torListURL = "https://cdn.nws.neurodyne.pro/nws-cdn-ut8hw561/waf/torbulkexitlist" // custom TOR list URL for testing ) + +var customResponse = map[int]CustomBlockResponse{ + 403: { + StatusCode: http.StatusForbidden, + Body: "Access Denied", + }, +} diff --git a/config.go b/config.go index 9944df8..231c3c3 100644 --- a/config.go +++ b/config.go @@ -140,7 +140,7 @@ func (cl *ConfigLoader) UnmarshalCaddyfile(d *caddyfile.Dispenser, m *Middleware m.LogSeverity = "info" m.LogJSON = false m.AnomalyThreshold = 5 - m.CountryBlock.Enabled = false + m.CountryBlacklist.Enabled = false m.CountryWhitelist.Enabled = false m.LogFilePath = "debug.json" m.RedactSensitiveData = false @@ -269,7 +269,7 @@ func (cl *ConfigLoader) parseCustomResponse(d *caddyfile.Dispenser, m *Middlewar // parseCountryBlockDirective returns a closure to handle block_countries and whitelist_countries directives. func (cl *ConfigLoader) parseCountryBlockDirective(isBlock bool) func(d *caddyfile.Dispenser, m *Middleware) error { return func(d *caddyfile.Dispenser, m *Middleware) error { - target := &m.CountryBlock + target := &m.CountryBlacklist directiveName := "block_countries" if !isBlock { target = &m.CountryWhitelist diff --git a/config_test.go b/config_test.go index fbe7911..bf8084c 100644 --- a/config_test.go +++ b/config_test.go @@ -201,14 +201,14 @@ func TestParseCountryBlock(t *testing.T) { t.Fatalf("parseCountryBlockDirective failed: %v", err) } - if !m.CountryBlock.Enabled { - t.Errorf("Expected country block to be enabled, got %v", m.CountryBlock.Enabled) + if !m.CountryBlacklist.Enabled { + t.Errorf("Expected country blacklist to be enabled, got %v", m.CountryBlacklist.Enabled) } - if m.CountryBlock.GeoIPDBPath != "/etc/geoip/GeoIP.dat" { - t.Errorf("Expected GeoIP DB path to be '/etc/geoip/GeoIP.dat', got '%s'", m.CountryBlock.GeoIPDBPath) + if m.CountryBlacklist.GeoIPDBPath != "/etc/geoip/GeoIP.dat" { + t.Errorf("Expected GeoIP DB path to be '/etc/geoip/GeoIP.dat', got '%s'", m.CountryBlacklist.GeoIPDBPath) } - if len(m.CountryBlock.CountryList) != 2 || m.CountryBlock.CountryList[0] != "US" || m.CountryBlock.CountryList[1] != "CA" { - t.Errorf("Expected country list to be ['US', 'CA'], got %v", m.CountryBlock.CountryList) + if len(m.CountryBlacklist.CountryList) != 2 || m.CountryBlacklist.CountryList[0] != "US" || m.CountryBlacklist.CountryList[1] != "CA" { + t.Errorf("Expected country list to be ['US', 'CA'], got %v", m.CountryBlacklist.CountryList) } } diff --git a/handler.go b/handler.go index 3049a18..badbd17 100644 --- a/handler.go +++ b/handler.go @@ -230,18 +230,18 @@ func (m *Middleware) handlePhase(w http.ResponseWriter, r *http.Request, phase i zap.String("user_agent", r.UserAgent()), ) - if phase == 1 && m.CountryBlock.Enabled { - m.logger.Debug("Starting country blocking phase") - blocked, err := m.isCountryInList(r.RemoteAddr, m.CountryBlock.CountryList, m.CountryBlock.geoIP) + if phase == 1 && m.CountryBlacklist.Enabled { + m.logger.Debug("Starting country blacklisting phase") + blocked, err := m.isCountryInList(r.RemoteAddr, m.CountryBlacklist.CountryList, m.CountryBlacklist.geoIP) if err != nil { - m.logRequest(zapcore.ErrorLevel, "Failed to check country block", + m.logRequest(zapcore.ErrorLevel, "Failed to check country blacklisting", r, zap.Error(err), ) m.blockRequest(w, r, state, http.StatusForbidden, "internal_error", "country_block_rule", r.RemoteAddr, zap.String("message", "Request blocked due to internal error"), ) - m.logger.Debug("Country blocking phase completed - blocked due to error") + m.logger.Debug("Country blacklisting phase completed - blocked due to error") m.incrementGeoIPRequestsMetric(false) // Increment with false for error return } else if blocked { @@ -254,7 +254,34 @@ func (m *Middleware) handlePhase(w http.ResponseWriter, r *http.Request, phase i } return } - m.logger.Debug("Country blocking phase completed - not blocked") + m.logger.Debug("Country blacklisting phase completed - not blocked") + m.incrementGeoIPRequestsMetric(false) // Increment with false for no block + } + + if phase == 1 && m.CountryWhitelist.Enabled { + m.logger.Debug("Starting country whitelisting phase") + allowed, err := m.isCountryInList(r.RemoteAddr, m.CountryWhitelist.CountryList, m.CountryWhitelist.geoIP) + if err != nil { + m.logRequest(zapcore.ErrorLevel, "Failed to check country whitelist", + r, + zap.Error(err), + ) + m.blockRequest(w, r, state, http.StatusForbidden, "internal_error", "country_block_rule", r.RemoteAddr, + zap.String("message", "Request blocked due to internal error"), + ) + m.logger.Debug("Country whitelisting phase completed - blocked due to error") + m.incrementGeoIPRequestsMetric(false) // Increment with false for error + return + } else if !allowed { + m.blockRequest(w, r, state, http.StatusForbidden, "country_block", "country_block_rule", r.RemoteAddr, + zap.String("message", "Request blocked by country")) + m.incrementGeoIPRequestsMetric(true) // Increment with true for blocked + if m.CustomResponses != nil { + m.writeCustomResponse(w, state.StatusCode) + } + return + } + m.logger.Debug("Country whitelisting phase completed - not blocked") m.incrementGeoIPRequestsMetric(false) // Increment with false for no block } @@ -327,7 +354,8 @@ func (m *Middleware) handlePhase(w http.ResponseWriter, r *http.Request, phase i rules, ok := m.Rules[phase] if !ok { m.logger.Debug("No rules found for phase", zap.Int("phase", phase)) - return + // Don't block on empty rules. There may be no rules specified + // return } m.logger.Debug("Starting rule evaluation for phase", zap.Int("phase", phase), zap.Int("rule_count", len(rules))) @@ -434,6 +462,8 @@ func (m *Middleware) handlePhase(w http.ResponseWriter, r *http.Request, phase i zap.Int("total_score", state.TotalScore), zap.Int("anomaly_threshold", m.AnomalyThreshold), ) + + m.allowRequest(state) } // incrementRateLimiterBlockedRequestsMetric increments the blocked requests metric for the rate limiter. diff --git a/handler_test.go b/handler_test.go index 855e240..f0d1178 100644 --- a/handler_test.go +++ b/handler_test.go @@ -12,7 +12,7 @@ import ( "testing" "time" - trie "github.com/phemmer/go-iptrie" + "github.com/phemmer/go-iptrie" "github.com/stretchr/testify/assert" "go.uber.org/zap" @@ -26,32 +26,37 @@ func TestBlockedRequestPhase1_DNSBlacklist(t *testing.T) { dnsBlacklist: map[string]struct{}{ "malicious.domain": {}, }, - ipBlacklist: trie.NewTrie(), - CustomResponses: map[int]CustomBlockResponse{ - 403: { - StatusCode: http.StatusForbidden, - Body: "Access Denied", - }, - }, + ipBlacklist: iptrie.NewTrie(), + CustomResponses: customResponse, } - // Simulate a request to a blacklisted domain - req := httptest.NewRequest("GET", "http://malicious.domain", nil) - req.RemoteAddr = localIP w := httptest.NewRecorder() state := &WAFState{} - // Process the request in Phase 1 - middleware.handlePhase(w, req, 1, state) + t.Run("Allow unblocked domain", func(t *testing.T) { + // Simulate a request to a blacklisted domain + req := httptest.NewRequest("GET", testURL, nil) + req.RemoteAddr = localIP - // Debug: Print the response body and status code - t.Logf("Response Body: %s", w.Body.String()) - t.Logf("Response Status Code: %d", w.Code) + // Process the request in Phase 1 + middleware.handlePhase(w, req, 1, state) + assert.False(t, state.Blocked, "Request should be allowed") + assert.Equal(t, http.StatusOK, w.Code, "Expected status code 200") + }) - // Verify that the request was blocked - assert.True(t, state.Blocked, "Request should be blocked") - assert.Equal(t, http.StatusForbidden, w.Code, "Expected status code 403") - assert.Contains(t, w.Body.String(), "Access Denied", "Response body should contain 'Access Denied'") + t.Run("Block blacklisted domain", func(t *testing.T) { + // Simulate a request to a blacklisted domain + req := httptest.NewRequest("GET", "http://malicious.domain", nil) + req.RemoteAddr = localIP + + // Process the request in Phase 1 + middleware.handlePhase(w, req, 1, state) + + // Verify that the request was blocked + assert.True(t, state.Blocked, "Request should be blocked") + assert.Equal(t, http.StatusForbidden, w.Code, "Expected status code 403") + assert.Contains(t, w.Body.String(), "Access Denied", "Response body should contain 'Access Denied'") + }) } func TestBlockedRequestPhase1_GeoIPBlocking(t *testing.T) { @@ -62,68 +67,135 @@ func TestBlockedRequestPhase1_GeoIPBlocking(t *testing.T) { geoIPBlock, err := geoIPHandler.LoadGeoIPDatabase(geoIPdata) assert.NoError(t, err) - middleware := &Middleware{ + blackListMiddleware := &Middleware{ logger: logger, + ipBlacklist: iptrie.NewTrie(), geoIPHandler: geoIPHandler, - CountryBlock: CountryAccessFilter{ + CountryBlacklist: CountryAccessFilter{ Enabled: true, - CountryList: []string{"US"}, + CountryList: []string{"US", "RU"}, GeoIPDBPath: geoIPdata, // Path to a test GeoIP database geoIP: geoIPBlock, }, - CustomResponses: map[int]CustomBlockResponse{ - 403: { - StatusCode: http.StatusForbidden, - Body: "Access Denied", - }, - }, + CustomResponses: customResponse, } - // Simulate a request from a blocked country (US) - req := httptest.NewRequest("GET", "http://example.com", nil) - req.RemoteAddr = googleUSIP - w := httptest.NewRecorder() + whiteListMiddleware := &Middleware{ + logger: logger, + ipBlacklist: iptrie.NewTrie(), + geoIPHandler: geoIPHandler, + CountryWhitelist: CountryAccessFilter{ + Enabled: true, + CountryList: []string{"BR"}, + GeoIPDBPath: geoIPdata, // Path to a test GeoIP database + geoIP: geoIPBlock, + }, + CustomResponses: customResponse, + } + + req := httptest.NewRequest("GET", testURL, nil) + state := &WAFState{} - // Process the request in Phase 1 - middleware.handlePhase(w, req, 1, state) + t.Run("GeoIP Blacklist: Allow CN IP", func(t *testing.T) { + w := httptest.NewRecorder() + req.RemoteAddr = aliCNIP - // Verify that the request was blocked - assert.True(t, state.Blocked, "Request should be blocked") - assert.Equal(t, http.StatusForbidden, w.Code, "Expected status code 403") - assert.Contains(t, w.Body.String(), "Access Denied", "Response body should contain 'Access Denied'") + // Process the request in Phase 1 + blackListMiddleware.handlePhase(w, req, 1, state) + assert.False(t, state.Blocked, "Request should be allowed") + assert.Equal(t, http.StatusOK, w.Code, "Expected status code 200") + }) + + t.Run("GeoIP Blacklist: Block US IP", func(t *testing.T) { + w := httptest.NewRecorder() + req.RemoteAddr = googleUSIP + + // Process the request in Phase 1 + blackListMiddleware.handlePhase(w, req, 1, state) + + // Verify that the request was blocked + assert.True(t, state.Blocked, "Request should be blocked") + assert.Equal(t, http.StatusForbidden, w.Code, "Expected status code 403") + assert.Contains(t, w.Body.String(), "Access Denied", "Response body should contain 'Access Denied'") + }) + + t.Run("GeoIP Whitelist: Allow BR IP", func(t *testing.T) { + w := httptest.NewRecorder() + req.RemoteAddr = googleBRIP + + // Process the request in Phase 1 + whiteListMiddleware.handlePhase(w, req, 1, state) + assert.False(t, state.Blocked, "Request should be allowed") + assert.Equal(t, http.StatusOK, w.Code, "Expected status code 200") + }) + + t.Run("GeoIP Whitelist: Block RU IP", func(t *testing.T) { + w := httptest.NewRecorder() + req.RemoteAddr = googleRUIP + + // Process the request in Phase 1 + whiteListMiddleware.handlePhase(w, req, 1, state) + + // Verify that the request was blocked + assert.True(t, state.Blocked, "Request should be blocked") + assert.Equal(t, http.StatusForbidden, w.Code, "Expected status code 403") + assert.Contains(t, w.Body.String(), "Access Denied", "Response body should contain 'Access Denied'") + }) } func TestBlockedRequestPhase1_IPBlocking(t *testing.T) { logger, err := zap.NewDevelopment() assert.NoError(t, err) - ipBlackList := trie.NewTrie() - ipBlackList.Insert(netip.MustParsePrefix("127.0.0.1/24"), nil) + blackList := iptrie.NewTrie() + loader := iptrie.NewTrieLoader(blackList) - middleware := &Middleware{ - logger: logger, - ipBlacklist: ipBlackList, - CustomResponses: map[int]CustomBlockResponse{ - 403: { - StatusCode: http.StatusForbidden, - Body: "Access Denied", - }, - }, + for _, net := range []string{ + "192.168.0.0/24", + "192.168.1.1/32", + } { + loader.Insert(netip.MustParsePrefix(net), "net="+net) } - req := httptest.NewRequest("GET", "http://example.com", nil) - req.RemoteAddr = localIP - w := httptest.NewRecorder() state := &WAFState{} + w := httptest.NewRecorder() - // Process the request in Phase 1 - middleware.handlePhase(w, req, 1, state) + t.Run("Allow unblocked CIDR", func(t *testing.T) { + middleware := &Middleware{ + logger: logger, + ipBlacklist: blackList, + CustomResponses: customResponse, + } - // Verify that the request was blocked - assert.True(t, state.Blocked, "Request should be blocked") - assert.Equal(t, http.StatusForbidden, w.Code, "Expected status code 403") - assert.Contains(t, w.Body.String(), "Access Denied", "Response body should contain 'Access Denied'") + req := httptest.NewRequest("GET", testURL, nil) + req.RemoteAddr = localIP + + // Process the request in Phase 1 + middleware.handlePhase(w, req, 1, state) + + assert.False(t, state.Blocked, "Request should be allowed") + assert.Equal(t, http.StatusOK, w.Code, "Expected status code 200") + }) + + t.Run("Blocks blacklisted CIDR", func(t *testing.T) { + middleware := &Middleware{ + logger: logger, + ipBlacklist: blackList, + CustomResponses: customResponse, + } + + req := httptest.NewRequest("GET", testURL, nil) + req.RemoteAddr = "192.168.1.1" + + // Process the request in Phase 1 + middleware.handlePhase(w, req, 1, state) + + // Verify that the request was blocked + assert.True(t, state.Blocked, "Request should be blocked") + assert.Equal(t, http.StatusForbidden, w.Code, "Expected status code 403") + assert.Contains(t, w.Body.String(), "Access Denied", "Response body should contain 'Access Denied'") + }) } func TestHandlePhase_Phase2_NiktoUserAgent(t *testing.T) { @@ -144,18 +216,13 @@ func TestHandlePhase_Phase2_NiktoUserAgent(t *testing.T) { }, }, ruleCache: NewRuleCache(), - ipBlacklist: trie.NewTrie(), + ipBlacklist: iptrie.NewTrie(), dnsBlacklist: map[string]struct{}{}, requestValueExtractor: NewRequestValueExtractor(logger, false), - CustomResponses: map[int]CustomBlockResponse{ - 403: { - StatusCode: http.StatusForbidden, - Body: "Access Denied", - }, - }, + CustomResponses: customResponse, } - req := httptest.NewRequest("GET", "http://example.com", nil) + req := httptest.NewRequest("GET", testURL, nil) req.Header.Set("User-Agent", "nikto") // Create a context and add logID to it - FIX: ADD CONTEXT HERE @@ -204,12 +271,12 @@ func TestBlockedRequestPhase1_HeaderRegex(t *testing.T) { }, }, ruleCache: NewRuleCache(), - ipBlacklist: trie.NewTrie(), + ipBlacklist: iptrie.NewTrie(), dnsBlacklist: map[string]struct{}{}, requestValueExtractor: NewRequestValueExtractor(logger, false), } - req := httptest.NewRequest("GET", "http://example.com", nil) + req := httptest.NewRequest("GET", testURL, nil) req.RemoteAddr = localIP req.Header.Set("X-Custom-Header", "this-is-a-bad-header") // Simulate a request with bad header @@ -257,12 +324,12 @@ func TestBlockedRequestPhase1_HeaderRegex_SpecificValue(t *testing.T) { }, }, ruleCache: NewRuleCache(), - ipBlacklist: trie.NewTrie(), + ipBlacklist: iptrie.NewTrie(), dnsBlacklist: map[string]struct{}{}, requestValueExtractor: NewRequestValueExtractor(logger, false), } - req := httptest.NewRequest("GET", "http://example.com", nil) + req := httptest.NewRequest("GET", testURL, nil) req.RemoteAddr = localIP req.Header.Set("X-Specific-Header", "specific-value") // Simulate a request with the specific header @@ -310,12 +377,12 @@ func TestBlockedRequestPhase1_HeaderRegex_CommaSeparatedTargets(t *testing.T) { }, }, ruleCache: NewRuleCache(), - ipBlacklist: trie.NewTrie(), + ipBlacklist: iptrie.NewTrie(), dnsBlacklist: map[string]struct{}{}, requestValueExtractor: NewRequestValueExtractor(logger, false), } - req := httptest.NewRequest("GET", "http://example.com", nil) + req := httptest.NewRequest("GET", testURL, nil) req.RemoteAddr = localIP req.Header.Set("X-Custom-Header1", "good-value") req.Header.Set("X-Custom-Header2", "bad-value") // Simulate a request with bad value in one of the headers @@ -364,7 +431,7 @@ func TestBlockedRequestPhase1_CombinedConditions(t *testing.T) { }, }, ruleCache: NewRuleCache(), - ipBlacklist: trie.NewTrie(), + ipBlacklist: iptrie.NewTrie(), dnsBlacklist: map[string]struct{}{}, requestValueExtractor: NewRequestValueExtractor(logger, false), } @@ -417,12 +484,12 @@ func TestBlockedRequestPhase1_NoMatch(t *testing.T) { }, }, ruleCache: NewRuleCache(), - ipBlacklist: trie.NewTrie(), + ipBlacklist: iptrie.NewTrie(), dnsBlacklist: map[string]struct{}{}, requestValueExtractor: NewRequestValueExtractor(logger, false), } - req := httptest.NewRequest("GET", "http://example.com", nil) + req := httptest.NewRequest("GET", testURL, nil) req.RemoteAddr = localIP req.Header.Set("User-Agent", "good-user") @@ -470,12 +537,12 @@ func TestBlockedRequestPhase1_HeaderRegex_EmptyHeader(t *testing.T) { }, }, ruleCache: NewRuleCache(), - ipBlacklist: trie.NewTrie(), + ipBlacklist: iptrie.NewTrie(), dnsBlacklist: map[string]struct{}{}, requestValueExtractor: NewRequestValueExtractor(logger, false), } - req := httptest.NewRequest("GET", "http://example.com", nil) + req := httptest.NewRequest("GET", testURL, nil) req.RemoteAddr = localIP // Create a context and add logID to it @@ -522,12 +589,12 @@ func TestBlockedRequestPhase1_HeaderRegex_MissingHeader(t *testing.T) { }, }, ruleCache: NewRuleCache(), - ipBlacklist: trie.NewTrie(), + ipBlacklist: iptrie.NewTrie(), dnsBlacklist: map[string]struct{}{}, requestValueExtractor: NewRequestValueExtractor(logger, false), } - req := httptest.NewRequest("GET", "http://example.com", nil) // Header not set + req := httptest.NewRequest("GET", testURL, nil) // Header not set req.RemoteAddr = localIP // Create a context and add logID to it @@ -574,12 +641,12 @@ func TestBlockedRequestPhase1_HeaderRegex_ComplexPattern(t *testing.T) { }, }, ruleCache: NewRuleCache(), - ipBlacklist: trie.NewTrie(), + ipBlacklist: iptrie.NewTrie(), dnsBlacklist: map[string]struct{}{}, requestValueExtractor: NewRequestValueExtractor(logger, false), } - req := httptest.NewRequest("GET", "http://example.com", nil) + req := httptest.NewRequest("GET", testURL, nil) req.RemoteAddr = localIP req.Header.Set("X-Email-Header", "test@example.com") // Simulate a request with a valid email @@ -627,12 +694,12 @@ func TestBlockedRequestPhase1_MultiTargetMatch(t *testing.T) { }, }, ruleCache: NewRuleCache(), - ipBlacklist: trie.NewTrie(), + ipBlacklist: iptrie.NewTrie(), dnsBlacklist: map[string]struct{}{}, requestValueExtractor: NewRequestValueExtractor(logger, false), } - req := httptest.NewRequest("GET", "http://example.com", nil) + req := httptest.NewRequest("GET", testURL, nil) req.RemoteAddr = localIP req.Header.Set("X-Custom-Header", "good-header") req.Header.Set("User-Agent", "bad-user-agent") @@ -680,12 +747,12 @@ func TestBlockedRequestPhase1_MultiTargetNoMatch(t *testing.T) { }, }, ruleCache: NewRuleCache(), - ipBlacklist: trie.NewTrie(), + ipBlacklist: iptrie.NewTrie(), dnsBlacklist: map[string]struct{}{}, requestValueExtractor: NewRequestValueExtractor(logger, false), } - req := httptest.NewRequest("GET", "http://example.com", nil) + req := httptest.NewRequest("GET", testURL, nil) req.RemoteAddr = localIP req.Header.Set("X-Custom-Header", "good-header") req.Header.Set("User-Agent", "good-user-agent") @@ -734,7 +801,7 @@ func TestBlockedRequestPhase1_URLParameterRegex_NoMatch(t *testing.T) { }, }, ruleCache: NewRuleCache(), - ipBlacklist: trie.NewTrie(), + ipBlacklist: iptrie.NewTrie(), dnsBlacklist: map[string]struct{}{}, requestValueExtractor: NewRequestValueExtractor(logger, false), } @@ -794,7 +861,7 @@ func TestBlockedRequestPhase1_MultipleRules(t *testing.T) { }, }, ruleCache: NewRuleCache(), - ipBlacklist: trie.NewTrie(), + ipBlacklist: iptrie.NewTrie(), dnsBlacklist: map[string]struct{}{}, requestValueExtractor: NewRequestValueExtractor(logger, false), } @@ -870,12 +937,12 @@ func TestBlockedRequestPhase2_BodyRegex(t *testing.T) { }, }, ruleCache: NewRuleCache(), - ipBlacklist: trie.NewTrie(), + ipBlacklist: iptrie.NewTrie(), dnsBlacklist: map[string]struct{}{}, requestValueExtractor: NewRequestValueExtractor(logger, false), } - req := httptest.NewRequest("POST", "http://example.com", + req := httptest.NewRequest("POST", testURL, func() *bytes.Buffer { b := new(bytes.Buffer) b.WriteString("this-is-a-bad-body") @@ -929,12 +996,12 @@ func TestBlockedRequestPhase2_BodyRegex_JSON(t *testing.T) { }, }, ruleCache: NewRuleCache(), - ipBlacklist: trie.NewTrie(), + ipBlacklist: iptrie.NewTrie(), dnsBlacklist: map[string]struct{}{}, requestValueExtractor: NewRequestValueExtractor(logger, false), } - req := httptest.NewRequest("POST", "http://example.com", + req := httptest.NewRequest("POST", testURL, func() *bytes.Buffer { b := new(bytes.Buffer) b.WriteString(`{"data":{"malicious":true,"name":"test"}}`) @@ -988,12 +1055,12 @@ func TestBlockedRequestPhase2_BodyRegex_FormURLEncoded(t *testing.T) { }, }, ruleCache: NewRuleCache(), - ipBlacklist: trie.NewTrie(), + ipBlacklist: iptrie.NewTrie(), dnsBlacklist: map[string]struct{}{}, requestValueExtractor: NewRequestValueExtractor(logger, false), } - req := httptest.NewRequest("POST", "http://example.com", + req := httptest.NewRequest("POST", testURL, strings.NewReader("param1=value1&secret=badvalue¶m2=value2"), ) req.RemoteAddr = localIP @@ -1043,12 +1110,12 @@ func TestBlockedRequestPhase2_BodyRegex_SpecificPattern(t *testing.T) { }, }, ruleCache: NewRuleCache(), - ipBlacklist: trie.NewTrie(), + ipBlacklist: iptrie.NewTrie(), dnsBlacklist: map[string]struct{}{}, requestValueExtractor: NewRequestValueExtractor(logger, false), } - req := httptest.NewRequest("POST", "http://example.com", + req := httptest.NewRequest("POST", testURL, func() *bytes.Buffer { b := new(bytes.Buffer) b.WriteString("User ID: 123-45-6789") @@ -1102,12 +1169,12 @@ func TestBlockedRequestPhase2_BodyRegex_NoMatch(t *testing.T) { }, }, ruleCache: NewRuleCache(), - ipBlacklist: trie.NewTrie(), + ipBlacklist: iptrie.NewTrie(), dnsBlacklist: map[string]struct{}{}, requestValueExtractor: NewRequestValueExtractor(logger, false), } - req := httptest.NewRequest("POST", "http://example.com", + req := httptest.NewRequest("POST", testURL, func() *bytes.Buffer { b := new(bytes.Buffer) b.WriteString("this-is-a-good-body") @@ -1161,7 +1228,7 @@ func TestBlockedRequestPhase2_BodyRegex_NoMatch_MultipartForm(t *testing.T) { }, }, ruleCache: NewRuleCache(), - ipBlacklist: trie.NewTrie(), + ipBlacklist: iptrie.NewTrie(), dnsBlacklist: map[string]struct{}{}, requestValueExtractor: NewRequestValueExtractor(logger, false), } @@ -1181,7 +1248,7 @@ func TestBlockedRequestPhase2_BodyRegex_NoMatch_MultipartForm(t *testing.T) { t.Fatalf("Failed to close multipart writer: %v", err) } - req := httptest.NewRequest("POST", "http://example.com", body) + req := httptest.NewRequest("POST", testURL, body) req.RemoteAddr = localIP req.Header.Set("Content-Type", writer.FormDataContentType()) @@ -1229,12 +1296,12 @@ func TestBlockedRequestPhase2_BodyRegex_NoBody(t *testing.T) { }, }, ruleCache: NewRuleCache(), - ipBlacklist: trie.NewTrie(), + ipBlacklist: iptrie.NewTrie(), dnsBlacklist: map[string]struct{}{}, requestValueExtractor: NewRequestValueExtractor(logger, false), } - req := httptest.NewRequest("POST", "http://example.com", nil) + req := httptest.NewRequest("POST", testURL, nil) req.RemoteAddr = localIP w := httptest.NewRecorder() state := &WAFState{} @@ -1275,7 +1342,7 @@ func TestBlockedRequestPhase3_ResponseHeaderRegex_NoMatch(t *testing.T) { }, }, ruleCache: NewRuleCache(), - ipBlacklist: trie.NewTrie(), + ipBlacklist: iptrie.NewTrie(), dnsBlacklist: map[string]struct{}{}, requestValueExtractor: NewRequestValueExtractor(logger, false), } @@ -1288,7 +1355,7 @@ func TestBlockedRequestPhase3_ResponseHeaderRegex_NoMatch(t *testing.T) { }) }() - req := httptest.NewRequest("GET", "http://example.com", nil) + req := httptest.NewRequest("GET", testURL, nil) req.RemoteAddr = localIP w := httptest.NewRecorder() state := &WAFState{} @@ -1331,7 +1398,7 @@ func TestBlockedRequestPhase4_ResponseBodyRegex_EmptyBody(t *testing.T) { }, }, ruleCache: NewRuleCache(), - ipBlacklist: trie.NewTrie(), + ipBlacklist: iptrie.NewTrie(), dnsBlacklist: map[string]struct{}{}, requestValueExtractor: NewRequestValueExtractor(logger, false), } @@ -1343,7 +1410,7 @@ func TestBlockedRequestPhase4_ResponseBodyRegex_EmptyBody(t *testing.T) { }) }() - req := httptest.NewRequest("GET", "http://example.com", nil) + req := httptest.NewRequest("GET", testURL, nil) req.RemoteAddr = localIP w := httptest.NewRecorder() state := &WAFState{} @@ -1387,7 +1454,7 @@ func TestBlockedRequestPhase4_ResponseBodyRegex_NoBody(t *testing.T) { }, }, ruleCache: NewRuleCache(), - ipBlacklist: trie.NewTrie(), + ipBlacklist: iptrie.NewTrie(), dnsBlacklist: map[string]struct{}{}, requestValueExtractor: NewRequestValueExtractor(logger, false), } @@ -1399,7 +1466,7 @@ func TestBlockedRequestPhase4_ResponseBodyRegex_NoBody(t *testing.T) { }) }() - req := httptest.NewRequest("GET", "http://example.com", nil) + req := httptest.NewRequest("GET", testURL, nil) req.RemoteAddr = localIP w := httptest.NewRecorder() state := &WAFState{} @@ -1441,7 +1508,7 @@ func TestBlockedRequestPhase3_ResponseHeaderRegex_NoSetCookie(t *testing.T) { }, }, ruleCache: NewRuleCache(), - ipBlacklist: trie.NewTrie(), + ipBlacklist: iptrie.NewTrie(), dnsBlacklist: map[string]struct{}{}, requestValueExtractor: NewRequestValueExtractor(logger, false), } @@ -1453,7 +1520,7 @@ func TestBlockedRequestPhase3_ResponseHeaderRegex_NoSetCookie(t *testing.T) { }) }() - req := httptest.NewRequest("GET", "http://example.com", nil) + req := httptest.NewRequest("GET", testURL, nil) req.RemoteAddr = localIP w := httptest.NewRecorder() state := &WAFState{} @@ -1497,12 +1564,12 @@ func TestBlockedRequestPhase1_HeaderRegex_CaseInsensitive(t *testing.T) { }, }, ruleCache: NewRuleCache(), - ipBlacklist: trie.NewTrie(), + ipBlacklist: iptrie.NewTrie(), dnsBlacklist: map[string]struct{}{}, requestValueExtractor: NewRequestValueExtractor(logger, false), } - req := httptest.NewRequest("GET", "http://example.com", nil) + req := httptest.NewRequest("GET", testURL, nil) req.RemoteAddr = localIP req.Header.Set("X-Custom-Header", "bAd-VaLuE") // Test with mixed-case header value @@ -1550,12 +1617,12 @@ func TestBlockedRequestPhase1_HeaderRegex_MultipleMatchingHeaders(t *testing.T) }, }, ruleCache: NewRuleCache(), - ipBlacklist: trie.NewTrie(), + ipBlacklist: iptrie.NewTrie(), dnsBlacklist: map[string]struct{}{}, requestValueExtractor: NewRequestValueExtractor(logger, false), } - req := httptest.NewRequest("GET", "http://example.com", nil) + req := httptest.NewRequest("GET", testURL, nil) req.RemoteAddr = localIP req.Header.Set("X-Custom-Header1", "bad-value") req.Header.Set("X-Custom-Header2", "bad-value") // Both headers have a "bad" value @@ -1579,7 +1646,7 @@ func TestBlockedRequestPhase1_HeaderRegex_MultipleMatchingHeaders(t *testing.T) assert.Equal(t, http.StatusForbidden, w.Code, "Expected status code 403") assert.Contains(t, w.Body.String(), "Blocked by Multiple Matching Headers Regex", "Response body should contain 'Blocked by Multiple Matching Headers Regex'") - req2 := httptest.NewRequest("GET", "http://example.com", nil) + req2 := httptest.NewRequest("GET", testURL, nil) req2.RemoteAddr = localIP req2.Header.Set("X-Custom-Header1", "good-value") req2.Header.Set("X-Custom-Header2", "bad-value") // One header has a "bad" value @@ -1603,7 +1670,7 @@ func TestBlockedRequestPhase1_HeaderRegex_MultipleMatchingHeaders(t *testing.T) assert.Equal(t, http.StatusForbidden, w2.Code, "Expected status code 403") assert.Contains(t, w2.Body.String(), "Blocked by Multiple Matching Headers Regex", "Response body should contain 'Blocked by Multiple Matching Headers Regex'") - req3 := httptest.NewRequest("GET", "http://example.com", nil) + req3 := httptest.NewRequest("GET", testURL, nil) req3.RemoteAddr = localIP req3.Header.Set("X-Custom-Header1", "good-value") req3.Header.Set("X-Custom-Header2", "good-value") // None headers have a "bad" value @@ -1658,7 +1725,7 @@ func TestBlockedRequestPhase1_RateLimiting_MultiplePaths(t *testing.T) { Body: "Rate limit exceeded", }, }, - ipBlacklist: trie.NewTrie(), + ipBlacklist: iptrie.NewTrie(), dnsBlacklist: make(map[string]struct{}), } @@ -1728,7 +1795,7 @@ func TestBlockedRequestPhase1_RateLimiting_DifferentIPs(t *testing.T) { Body: "Rate limit exceeded", }, }, - ipBlacklist: trie.NewTrie(), + ipBlacklist: iptrie.NewTrie(), dnsBlacklist: make(map[string]struct{}), } @@ -1781,7 +1848,7 @@ func TestBlockedRequestPhase1_RateLimiting_MatchAllPaths(t *testing.T) { Body: "Rate limit exceeded", }, }, - ipBlacklist: trie.NewTrie(), + ipBlacklist: iptrie.NewTrie(), dnsBlacklist: make(map[string]struct{}), } diff --git a/ratelimiter_test.go b/ratelimiter_test.go index e3c9c7d..494f788 100644 --- a/ratelimiter_test.go +++ b/ratelimiter_test.go @@ -7,7 +7,7 @@ import ( "testing" "time" - trie "github.com/phemmer/go-iptrie" + "github.com/phemmer/go-iptrie" "github.com/stretchr/testify/assert" "go.uber.org/zap" ) @@ -398,7 +398,7 @@ func TestBlockedRequestPhase1_RateLimiting(t *testing.T) { Body: "Rate limit exceeded", }, }, - ipBlacklist: trie.NewTrie(), // Initialize ipBlacklist + ipBlacklist: iptrie.NewTrie(), // Initialize ipBlacklist dnsBlacklist: make(map[string]struct{}), // Initialize dnsBlacklist } diff --git a/request_test.go b/request_test.go index f06d0d3..cc6d353 100644 --- a/request_test.go +++ b/request_test.go @@ -11,7 +11,7 @@ import ( "testing" "time" - trie "github.com/phemmer/go-iptrie" + "github.com/phemmer/go-iptrie" "github.com/stretchr/testify/assert" "go.uber.org/zap" "go.uber.org/zap/zapcore" @@ -34,7 +34,7 @@ func TestExtractValue(t *testing.T) { name: "Extract METHOD", target: "METHOD", setupRequest: func() (*http.Request, http.ResponseWriter) { - req := httptest.NewRequest("POST", "http://example.com", nil) + req := httptest.NewRequest("POST", testURL, nil) return req, httptest.NewRecorder() }, expectedValue: "POST", @@ -54,7 +54,7 @@ func TestExtractValue(t *testing.T) { name: "Extract USER_AGENT", target: "USER_AGENT", setupRequest: func() (*http.Request, http.ResponseWriter) { - req := httptest.NewRequest("GET", "http://example.com", nil) + req := httptest.NewRequest("GET", testURL, nil) req.Header.Set("User-Agent", "test-agent") return req, httptest.NewRecorder() }, @@ -65,7 +65,7 @@ func TestExtractValue(t *testing.T) { name: "Extract HEADERS prefix", target: "HEADERS:Content-Type", setupRequest: func() (*http.Request, http.ResponseWriter) { - req := httptest.NewRequest("GET", "http://example.com", nil) + req := httptest.NewRequest("GET", testURL, nil) req.Header.Set("Content-Type", "application/json") return req, httptest.NewRecorder() }, @@ -86,7 +86,7 @@ func TestExtractValue(t *testing.T) { name: "Empty target", target: "", setupRequest: func() (*http.Request, http.ResponseWriter) { - return httptest.NewRequest("GET", "http://example.com", nil), httptest.NewRecorder() + return httptest.NewRequest("GET", testURL, nil), httptest.NewRecorder() }, expectedError: true, }, @@ -383,7 +383,7 @@ func TestProcessRuleMatch_HighScore(t *testing.T) { ResponseWritten: false, } - req := httptest.NewRequest("GET", "http://example.com", nil) + req := httptest.NewRequest("GET", testURL, nil) // Create a context and add logID to it - FIX: ADD CONTEXT HERE ctx := context.Background() @@ -445,7 +445,7 @@ func TestConcurrentRuleEvaluation(t *testing.T) { }, }, ruleCache: NewRuleCache(), - ipBlacklist: trie.NewTrie(), + ipBlacklist: iptrie.NewTrie(), dnsBlacklist: map[string]struct{}{}, requestValueExtractor: NewRequestValueExtractor(logger, false), rateLimiter: func() *RateLimiter { @@ -459,12 +459,7 @@ func TestConcurrentRuleEvaluation(t *testing.T) { } return rl }(), - CustomResponses: map[int]CustomBlockResponse{ - 403: { - StatusCode: http.StatusForbidden, - Body: "Access Denied", - }, - }, + CustomResponses: customResponse, } // Add some IPs to the blacklist @@ -475,7 +470,7 @@ func TestConcurrentRuleEvaluation(t *testing.T) { wg.Add(1) go func(i int) { defer wg.Done() - req := httptest.NewRequest("GET", "http://example.com", nil) + req := httptest.NewRequest("GET", testURL, nil) req.RemoteAddr = fmt.Sprintf("192.168.1.%d", i%256) // Simulate different IPs req.Header.Set("User-Agent", "test-agent") // Add a header for rule evaluation w := httptest.NewRecorder() diff --git a/response.go b/response.go index 6c88f3b..2a337ec 100644 --- a/response.go +++ b/response.go @@ -8,6 +8,15 @@ import ( "go.uber.org/zap" ) +// allowRequest - handles request allowing +func (m *Middleware) allowRequest(state *WAFState) { + state.Blocked = false + state.StatusCode = http.StatusOK + state.ResponseWritten = false + + m.incrementAllowedRequestsMetric() +} + // blockRequest handles blocking a request and logging the details. func (m *Middleware) blockRequest(recorder http.ResponseWriter, r *http.Request, state *WAFState, statusCode int, reason, ruleID, matchedValue string, fields ...zap.Field) { // CRITICAL FIX: Set these flags before any other operations diff --git a/types.go b/types.go index c26c9bc..b4024f5 100644 --- a/types.go +++ b/types.go @@ -6,7 +6,7 @@ import ( "time" "github.com/oschwald/maxminddb-golang" - trie "github.com/phemmer/go-iptrie" + "github.com/phemmer/go-iptrie" "go.uber.org/zap" "go.uber.org/zap/zapcore" @@ -104,10 +104,10 @@ type Middleware struct { IPBlacklistFile string `json:"ip_blacklist_file"` DNSBlacklistFile string `json:"dns_blacklist_file"` AnomalyThreshold int `json:"anomaly_threshold"` - CountryBlock CountryAccessFilter `json:"country_block"` + CountryBlacklist CountryAccessFilter `json:"country_blacklist"` CountryWhitelist CountryAccessFilter `json:"country_whitelist"` Rules map[int][]Rule `json:"-"` - ipBlacklist *trie.Trie `json:"-"` + ipBlacklist *iptrie.Trie `json:"-"` dnsBlacklist map[string]struct{} `json:"-"` // Changed to map[string]struct{} logger *zap.Logger LogSeverity string `json:"log_severity,omitempty"`