From 145feb4bf8817af6b8c3d2be41ddf53f7472a204 Mon Sep 17 00:00:00 2001 From: drev74 Date: Sun, 12 Oct 2025 13:12:13 +0300 Subject: [PATCH] test: upd ip blacklist test --- caddywaf.go | 8 +- common_test.go | 1 + handler_test.go | 182 +++++++++++++++++++++++++------------------- ratelimiter_test.go | 4 +- request_test.go | 16 ++-- types.go | 4 +- 6 files changed, 122 insertions(+), 93 deletions(-) diff --git a/caddywaf.go b/caddywaf.go index 5fde7d5..9a496bc 100644 --- a/caddywaf.go +++ b/caddywaf.go @@ -27,7 +27,7 @@ import ( "github.com/fsnotify/fsnotify" "github.com/oschwald/maxminddb-golang" - trie "github.com/phemmer/go-iptrie" + "github.com/phemmer/go-iptrie" "go.uber.org/zap" "go.uber.org/zap/zapcore" @@ -237,7 +237,7 @@ func (m *Middleware) Provision(ctx caddy.Context) error { // Load IP blacklist if m.IPBlacklistFile != "" { - m.ipBlacklist = trie.NewTrie() + m.ipBlacklist = iptrie.NewTrie() err = m.loadIPBlacklist(m.IPBlacklistFile, *m.ipBlacklist) if err != nil { return fmt.Errorf("failed to load IP blacklist: %w", err) @@ -426,7 +426,7 @@ func (m *Middleware) ReloadConfig() error { m.logger.Info("Reloading WAF configuration") if m.IPBlacklistFile != "" { - newIPBlacklist := trie.NewTrie() + newIPBlacklist := iptrie.NewTrie() if err := m.loadIPBlacklist(m.IPBlacklistFile, *newIPBlacklist); err != nil { m.logger.Error("Failed to reload IP blacklist", zap.String("file", m.IPBlacklistFile), zap.Error(err)) return fmt.Errorf("failed to reload IP blacklist: %v", err) @@ -452,7 +452,7 @@ func (m *Middleware) ReloadConfig() error { return nil } -func (m *Middleware) loadIPBlacklist(path string, blacklistMap trie.Trie) error { +func (m *Middleware) loadIPBlacklist(path string, blacklistMap iptrie.Trie) error { if _, err := os.Stat(path); os.IsNotExist(err) { m.logger.Warn("Skipping IP blacklist load, file does not exist", zap.String("file", path)) return nil diff --git a/common_test.go b/common_test.go index 707b24f..b6fcf04 100644 --- a/common_test.go +++ b/common_test.go @@ -4,5 +4,6 @@ const ( geoIPdata = "GeoLite2-Country.mmdb" googleUSIP = "74.125.131.105" localIP = "127.0.0.1" + testURL = "http://example.com" torListURL = "https://cdn.nws.neurodyne.pro/nws-cdn-ut8hw561/waf/torbulkexitlist" // custom TOR list URL for testing ) diff --git a/handler_test.go b/handler_test.go index 855e240..b304798 100644 --- a/handler_test.go +++ b/handler_test.go @@ -12,7 +12,7 @@ import ( "testing" "time" - trie "github.com/phemmer/go-iptrie" + "github.com/phemmer/go-iptrie" "github.com/stretchr/testify/assert" "go.uber.org/zap" @@ -26,7 +26,7 @@ func TestBlockedRequestPhase1_DNSBlacklist(t *testing.T) { dnsBlacklist: map[string]struct{}{ "malicious.domain": {}, }, - ipBlacklist: trie.NewTrie(), + ipBlacklist: iptrie.NewTrie(), CustomResponses: map[int]CustomBlockResponse{ 403: { StatusCode: http.StatusForbidden, @@ -80,7 +80,7 @@ func TestBlockedRequestPhase1_GeoIPBlocking(t *testing.T) { } // Simulate a request from a blocked country (US) - req := httptest.NewRequest("GET", "http://example.com", nil) + req := httptest.NewRequest("GET", testURL, nil) req.RemoteAddr = googleUSIP w := httptest.NewRecorder() state := &WAFState{} @@ -98,32 +98,60 @@ func TestBlockedRequestPhase1_IPBlocking(t *testing.T) { logger, err := zap.NewDevelopment() assert.NoError(t, err) - ipBlackList := trie.NewTrie() - ipBlackList.Insert(netip.MustParsePrefix("127.0.0.1/24"), nil) + blackList := iptrie.NewTrie() + loader := iptrie.NewTrieLoader(blackList) - middleware := &Middleware{ - logger: logger, - ipBlacklist: ipBlackList, - CustomResponses: map[int]CustomBlockResponse{ - 403: { - StatusCode: http.StatusForbidden, - Body: "Access Denied", - }, + for _, net := range []string{ + "192.168.0.0/24", + "192.168.1.1/32", + } { + loader.Insert(netip.MustParsePrefix(net), "net="+net) + } + + state := &WAFState{} + w := httptest.NewRecorder() + + resp := map[int]CustomBlockResponse{ + 403: { + StatusCode: http.StatusForbidden, + Body: "Access Denied", }, } - req := httptest.NewRequest("GET", "http://example.com", nil) - req.RemoteAddr = localIP - w := httptest.NewRecorder() - state := &WAFState{} + t.Run("Allow unblocked CIDR", func(t *testing.T) { + middleware := &Middleware{ + logger: logger, + ipBlacklist: blackList, + CustomResponses: resp, + } - // Process the request in Phase 1 - middleware.handlePhase(w, req, 1, state) + req := httptest.NewRequest("GET", testURL, nil) + req.RemoteAddr = localIP - // Verify that the request was blocked - assert.True(t, state.Blocked, "Request should be blocked") - assert.Equal(t, http.StatusForbidden, w.Code, "Expected status code 403") - assert.Contains(t, w.Body.String(), "Access Denied", "Response body should contain 'Access Denied'") + // Process the request in Phase 1 + middleware.handlePhase(w, req, 1, state) + + assert.False(t, state.Blocked, "Request should be allowed") + }) + + t.Run("Blocks blacklisted CIDR", func(t *testing.T) { + middleware := &Middleware{ + logger: logger, + ipBlacklist: blackList, + CustomResponses: resp, + } + + req0 := httptest.NewRequest("GET", testURL, nil) + req0.RemoteAddr = "192.168.1.1" + + // Process the request in Phase 1 + middleware.handlePhase(w, req0, 1, state) + + // Verify that the request was blocked + assert.True(t, state.Blocked, "Request should be blocked") + assert.Equal(t, http.StatusForbidden, w.Code, "Expected status code 403") + assert.Contains(t, w.Body.String(), "Access Denied", "Response body should contain 'Access Denied'") + }) } func TestHandlePhase_Phase2_NiktoUserAgent(t *testing.T) { @@ -144,7 +172,7 @@ func TestHandlePhase_Phase2_NiktoUserAgent(t *testing.T) { }, }, ruleCache: NewRuleCache(), - ipBlacklist: trie.NewTrie(), + ipBlacklist: iptrie.NewTrie(), dnsBlacklist: map[string]struct{}{}, requestValueExtractor: NewRequestValueExtractor(logger, false), CustomResponses: map[int]CustomBlockResponse{ @@ -155,7 +183,7 @@ func TestHandlePhase_Phase2_NiktoUserAgent(t *testing.T) { }, } - req := httptest.NewRequest("GET", "http://example.com", nil) + req := httptest.NewRequest("GET", testURL, nil) req.Header.Set("User-Agent", "nikto") // Create a context and add logID to it - FIX: ADD CONTEXT HERE @@ -204,12 +232,12 @@ func TestBlockedRequestPhase1_HeaderRegex(t *testing.T) { }, }, ruleCache: NewRuleCache(), - ipBlacklist: trie.NewTrie(), + ipBlacklist: iptrie.NewTrie(), dnsBlacklist: map[string]struct{}{}, requestValueExtractor: NewRequestValueExtractor(logger, false), } - req := httptest.NewRequest("GET", "http://example.com", nil) + req := httptest.NewRequest("GET", testURL, nil) req.RemoteAddr = localIP req.Header.Set("X-Custom-Header", "this-is-a-bad-header") // Simulate a request with bad header @@ -257,12 +285,12 @@ func TestBlockedRequestPhase1_HeaderRegex_SpecificValue(t *testing.T) { }, }, ruleCache: NewRuleCache(), - ipBlacklist: trie.NewTrie(), + ipBlacklist: iptrie.NewTrie(), dnsBlacklist: map[string]struct{}{}, requestValueExtractor: NewRequestValueExtractor(logger, false), } - req := httptest.NewRequest("GET", "http://example.com", nil) + req := httptest.NewRequest("GET", testURL, nil) req.RemoteAddr = localIP req.Header.Set("X-Specific-Header", "specific-value") // Simulate a request with the specific header @@ -310,12 +338,12 @@ func TestBlockedRequestPhase1_HeaderRegex_CommaSeparatedTargets(t *testing.T) { }, }, ruleCache: NewRuleCache(), - ipBlacklist: trie.NewTrie(), + ipBlacklist: iptrie.NewTrie(), dnsBlacklist: map[string]struct{}{}, requestValueExtractor: NewRequestValueExtractor(logger, false), } - req := httptest.NewRequest("GET", "http://example.com", nil) + req := httptest.NewRequest("GET", testURL, nil) req.RemoteAddr = localIP req.Header.Set("X-Custom-Header1", "good-value") req.Header.Set("X-Custom-Header2", "bad-value") // Simulate a request with bad value in one of the headers @@ -364,7 +392,7 @@ func TestBlockedRequestPhase1_CombinedConditions(t *testing.T) { }, }, ruleCache: NewRuleCache(), - ipBlacklist: trie.NewTrie(), + ipBlacklist: iptrie.NewTrie(), dnsBlacklist: map[string]struct{}{}, requestValueExtractor: NewRequestValueExtractor(logger, false), } @@ -417,12 +445,12 @@ func TestBlockedRequestPhase1_NoMatch(t *testing.T) { }, }, ruleCache: NewRuleCache(), - ipBlacklist: trie.NewTrie(), + ipBlacklist: iptrie.NewTrie(), dnsBlacklist: map[string]struct{}{}, requestValueExtractor: NewRequestValueExtractor(logger, false), } - req := httptest.NewRequest("GET", "http://example.com", nil) + req := httptest.NewRequest("GET", testURL, nil) req.RemoteAddr = localIP req.Header.Set("User-Agent", "good-user") @@ -470,12 +498,12 @@ func TestBlockedRequestPhase1_HeaderRegex_EmptyHeader(t *testing.T) { }, }, ruleCache: NewRuleCache(), - ipBlacklist: trie.NewTrie(), + ipBlacklist: iptrie.NewTrie(), dnsBlacklist: map[string]struct{}{}, requestValueExtractor: NewRequestValueExtractor(logger, false), } - req := httptest.NewRequest("GET", "http://example.com", nil) + req := httptest.NewRequest("GET", testURL, nil) req.RemoteAddr = localIP // Create a context and add logID to it @@ -522,12 +550,12 @@ func TestBlockedRequestPhase1_HeaderRegex_MissingHeader(t *testing.T) { }, }, ruleCache: NewRuleCache(), - ipBlacklist: trie.NewTrie(), + ipBlacklist: iptrie.NewTrie(), dnsBlacklist: map[string]struct{}{}, requestValueExtractor: NewRequestValueExtractor(logger, false), } - req := httptest.NewRequest("GET", "http://example.com", nil) // Header not set + req := httptest.NewRequest("GET", testURL, nil) // Header not set req.RemoteAddr = localIP // Create a context and add logID to it @@ -574,12 +602,12 @@ func TestBlockedRequestPhase1_HeaderRegex_ComplexPattern(t *testing.T) { }, }, ruleCache: NewRuleCache(), - ipBlacklist: trie.NewTrie(), + ipBlacklist: iptrie.NewTrie(), dnsBlacklist: map[string]struct{}{}, requestValueExtractor: NewRequestValueExtractor(logger, false), } - req := httptest.NewRequest("GET", "http://example.com", nil) + req := httptest.NewRequest("GET", testURL, nil) req.RemoteAddr = localIP req.Header.Set("X-Email-Header", "test@example.com") // Simulate a request with a valid email @@ -627,12 +655,12 @@ func TestBlockedRequestPhase1_MultiTargetMatch(t *testing.T) { }, }, ruleCache: NewRuleCache(), - ipBlacklist: trie.NewTrie(), + ipBlacklist: iptrie.NewTrie(), dnsBlacklist: map[string]struct{}{}, requestValueExtractor: NewRequestValueExtractor(logger, false), } - req := httptest.NewRequest("GET", "http://example.com", nil) + req := httptest.NewRequest("GET", testURL, nil) req.RemoteAddr = localIP req.Header.Set("X-Custom-Header", "good-header") req.Header.Set("User-Agent", "bad-user-agent") @@ -680,12 +708,12 @@ func TestBlockedRequestPhase1_MultiTargetNoMatch(t *testing.T) { }, }, ruleCache: NewRuleCache(), - ipBlacklist: trie.NewTrie(), + ipBlacklist: iptrie.NewTrie(), dnsBlacklist: map[string]struct{}{}, requestValueExtractor: NewRequestValueExtractor(logger, false), } - req := httptest.NewRequest("GET", "http://example.com", nil) + req := httptest.NewRequest("GET", testURL, nil) req.RemoteAddr = localIP req.Header.Set("X-Custom-Header", "good-header") req.Header.Set("User-Agent", "good-user-agent") @@ -734,7 +762,7 @@ func TestBlockedRequestPhase1_URLParameterRegex_NoMatch(t *testing.T) { }, }, ruleCache: NewRuleCache(), - ipBlacklist: trie.NewTrie(), + ipBlacklist: iptrie.NewTrie(), dnsBlacklist: map[string]struct{}{}, requestValueExtractor: NewRequestValueExtractor(logger, false), } @@ -794,7 +822,7 @@ func TestBlockedRequestPhase1_MultipleRules(t *testing.T) { }, }, ruleCache: NewRuleCache(), - ipBlacklist: trie.NewTrie(), + ipBlacklist: iptrie.NewTrie(), dnsBlacklist: map[string]struct{}{}, requestValueExtractor: NewRequestValueExtractor(logger, false), } @@ -870,12 +898,12 @@ func TestBlockedRequestPhase2_BodyRegex(t *testing.T) { }, }, ruleCache: NewRuleCache(), - ipBlacklist: trie.NewTrie(), + ipBlacklist: iptrie.NewTrie(), dnsBlacklist: map[string]struct{}{}, requestValueExtractor: NewRequestValueExtractor(logger, false), } - req := httptest.NewRequest("POST", "http://example.com", + req := httptest.NewRequest("POST", testURL, func() *bytes.Buffer { b := new(bytes.Buffer) b.WriteString("this-is-a-bad-body") @@ -929,12 +957,12 @@ func TestBlockedRequestPhase2_BodyRegex_JSON(t *testing.T) { }, }, ruleCache: NewRuleCache(), - ipBlacklist: trie.NewTrie(), + ipBlacklist: iptrie.NewTrie(), dnsBlacklist: map[string]struct{}{}, requestValueExtractor: NewRequestValueExtractor(logger, false), } - req := httptest.NewRequest("POST", "http://example.com", + req := httptest.NewRequest("POST", testURL, func() *bytes.Buffer { b := new(bytes.Buffer) b.WriteString(`{"data":{"malicious":true,"name":"test"}}`) @@ -988,12 +1016,12 @@ func TestBlockedRequestPhase2_BodyRegex_FormURLEncoded(t *testing.T) { }, }, ruleCache: NewRuleCache(), - ipBlacklist: trie.NewTrie(), + ipBlacklist: iptrie.NewTrie(), dnsBlacklist: map[string]struct{}{}, requestValueExtractor: NewRequestValueExtractor(logger, false), } - req := httptest.NewRequest("POST", "http://example.com", + req := httptest.NewRequest("POST", testURL, strings.NewReader("param1=value1&secret=badvalue¶m2=value2"), ) req.RemoteAddr = localIP @@ -1043,12 +1071,12 @@ func TestBlockedRequestPhase2_BodyRegex_SpecificPattern(t *testing.T) { }, }, ruleCache: NewRuleCache(), - ipBlacklist: trie.NewTrie(), + ipBlacklist: iptrie.NewTrie(), dnsBlacklist: map[string]struct{}{}, requestValueExtractor: NewRequestValueExtractor(logger, false), } - req := httptest.NewRequest("POST", "http://example.com", + req := httptest.NewRequest("POST", testURL, func() *bytes.Buffer { b := new(bytes.Buffer) b.WriteString("User ID: 123-45-6789") @@ -1102,12 +1130,12 @@ func TestBlockedRequestPhase2_BodyRegex_NoMatch(t *testing.T) { }, }, ruleCache: NewRuleCache(), - ipBlacklist: trie.NewTrie(), + ipBlacklist: iptrie.NewTrie(), dnsBlacklist: map[string]struct{}{}, requestValueExtractor: NewRequestValueExtractor(logger, false), } - req := httptest.NewRequest("POST", "http://example.com", + req := httptest.NewRequest("POST", testURL, func() *bytes.Buffer { b := new(bytes.Buffer) b.WriteString("this-is-a-good-body") @@ -1161,7 +1189,7 @@ func TestBlockedRequestPhase2_BodyRegex_NoMatch_MultipartForm(t *testing.T) { }, }, ruleCache: NewRuleCache(), - ipBlacklist: trie.NewTrie(), + ipBlacklist: iptrie.NewTrie(), dnsBlacklist: map[string]struct{}{}, requestValueExtractor: NewRequestValueExtractor(logger, false), } @@ -1181,7 +1209,7 @@ func TestBlockedRequestPhase2_BodyRegex_NoMatch_MultipartForm(t *testing.T) { t.Fatalf("Failed to close multipart writer: %v", err) } - req := httptest.NewRequest("POST", "http://example.com", body) + req := httptest.NewRequest("POST", testURL, body) req.RemoteAddr = localIP req.Header.Set("Content-Type", writer.FormDataContentType()) @@ -1229,12 +1257,12 @@ func TestBlockedRequestPhase2_BodyRegex_NoBody(t *testing.T) { }, }, ruleCache: NewRuleCache(), - ipBlacklist: trie.NewTrie(), + ipBlacklist: iptrie.NewTrie(), dnsBlacklist: map[string]struct{}{}, requestValueExtractor: NewRequestValueExtractor(logger, false), } - req := httptest.NewRequest("POST", "http://example.com", nil) + req := httptest.NewRequest("POST", testURL, nil) req.RemoteAddr = localIP w := httptest.NewRecorder() state := &WAFState{} @@ -1275,7 +1303,7 @@ func TestBlockedRequestPhase3_ResponseHeaderRegex_NoMatch(t *testing.T) { }, }, ruleCache: NewRuleCache(), - ipBlacklist: trie.NewTrie(), + ipBlacklist: iptrie.NewTrie(), dnsBlacklist: map[string]struct{}{}, requestValueExtractor: NewRequestValueExtractor(logger, false), } @@ -1288,7 +1316,7 @@ func TestBlockedRequestPhase3_ResponseHeaderRegex_NoMatch(t *testing.T) { }) }() - req := httptest.NewRequest("GET", "http://example.com", nil) + req := httptest.NewRequest("GET", testURL, nil) req.RemoteAddr = localIP w := httptest.NewRecorder() state := &WAFState{} @@ -1331,7 +1359,7 @@ func TestBlockedRequestPhase4_ResponseBodyRegex_EmptyBody(t *testing.T) { }, }, ruleCache: NewRuleCache(), - ipBlacklist: trie.NewTrie(), + ipBlacklist: iptrie.NewTrie(), dnsBlacklist: map[string]struct{}{}, requestValueExtractor: NewRequestValueExtractor(logger, false), } @@ -1343,7 +1371,7 @@ func TestBlockedRequestPhase4_ResponseBodyRegex_EmptyBody(t *testing.T) { }) }() - req := httptest.NewRequest("GET", "http://example.com", nil) + req := httptest.NewRequest("GET", testURL, nil) req.RemoteAddr = localIP w := httptest.NewRecorder() state := &WAFState{} @@ -1387,7 +1415,7 @@ func TestBlockedRequestPhase4_ResponseBodyRegex_NoBody(t *testing.T) { }, }, ruleCache: NewRuleCache(), - ipBlacklist: trie.NewTrie(), + ipBlacklist: iptrie.NewTrie(), dnsBlacklist: map[string]struct{}{}, requestValueExtractor: NewRequestValueExtractor(logger, false), } @@ -1399,7 +1427,7 @@ func TestBlockedRequestPhase4_ResponseBodyRegex_NoBody(t *testing.T) { }) }() - req := httptest.NewRequest("GET", "http://example.com", nil) + req := httptest.NewRequest("GET", testURL, nil) req.RemoteAddr = localIP w := httptest.NewRecorder() state := &WAFState{} @@ -1441,7 +1469,7 @@ func TestBlockedRequestPhase3_ResponseHeaderRegex_NoSetCookie(t *testing.T) { }, }, ruleCache: NewRuleCache(), - ipBlacklist: trie.NewTrie(), + ipBlacklist: iptrie.NewTrie(), dnsBlacklist: map[string]struct{}{}, requestValueExtractor: NewRequestValueExtractor(logger, false), } @@ -1453,7 +1481,7 @@ func TestBlockedRequestPhase3_ResponseHeaderRegex_NoSetCookie(t *testing.T) { }) }() - req := httptest.NewRequest("GET", "http://example.com", nil) + req := httptest.NewRequest("GET", testURL, nil) req.RemoteAddr = localIP w := httptest.NewRecorder() state := &WAFState{} @@ -1497,12 +1525,12 @@ func TestBlockedRequestPhase1_HeaderRegex_CaseInsensitive(t *testing.T) { }, }, ruleCache: NewRuleCache(), - ipBlacklist: trie.NewTrie(), + ipBlacklist: iptrie.NewTrie(), dnsBlacklist: map[string]struct{}{}, requestValueExtractor: NewRequestValueExtractor(logger, false), } - req := httptest.NewRequest("GET", "http://example.com", nil) + req := httptest.NewRequest("GET", testURL, nil) req.RemoteAddr = localIP req.Header.Set("X-Custom-Header", "bAd-VaLuE") // Test with mixed-case header value @@ -1550,12 +1578,12 @@ func TestBlockedRequestPhase1_HeaderRegex_MultipleMatchingHeaders(t *testing.T) }, }, ruleCache: NewRuleCache(), - ipBlacklist: trie.NewTrie(), + ipBlacklist: iptrie.NewTrie(), dnsBlacklist: map[string]struct{}{}, requestValueExtractor: NewRequestValueExtractor(logger, false), } - req := httptest.NewRequest("GET", "http://example.com", nil) + req := httptest.NewRequest("GET", testURL, nil) req.RemoteAddr = localIP req.Header.Set("X-Custom-Header1", "bad-value") req.Header.Set("X-Custom-Header2", "bad-value") // Both headers have a "bad" value @@ -1579,7 +1607,7 @@ func TestBlockedRequestPhase1_HeaderRegex_MultipleMatchingHeaders(t *testing.T) assert.Equal(t, http.StatusForbidden, w.Code, "Expected status code 403") assert.Contains(t, w.Body.String(), "Blocked by Multiple Matching Headers Regex", "Response body should contain 'Blocked by Multiple Matching Headers Regex'") - req2 := httptest.NewRequest("GET", "http://example.com", nil) + req2 := httptest.NewRequest("GET", testURL, nil) req2.RemoteAddr = localIP req2.Header.Set("X-Custom-Header1", "good-value") req2.Header.Set("X-Custom-Header2", "bad-value") // One header has a "bad" value @@ -1603,7 +1631,7 @@ func TestBlockedRequestPhase1_HeaderRegex_MultipleMatchingHeaders(t *testing.T) assert.Equal(t, http.StatusForbidden, w2.Code, "Expected status code 403") assert.Contains(t, w2.Body.String(), "Blocked by Multiple Matching Headers Regex", "Response body should contain 'Blocked by Multiple Matching Headers Regex'") - req3 := httptest.NewRequest("GET", "http://example.com", nil) + req3 := httptest.NewRequest("GET", testURL, nil) req3.RemoteAddr = localIP req3.Header.Set("X-Custom-Header1", "good-value") req3.Header.Set("X-Custom-Header2", "good-value") // None headers have a "bad" value @@ -1658,7 +1686,7 @@ func TestBlockedRequestPhase1_RateLimiting_MultiplePaths(t *testing.T) { Body: "Rate limit exceeded", }, }, - ipBlacklist: trie.NewTrie(), + ipBlacklist: iptrie.NewTrie(), dnsBlacklist: make(map[string]struct{}), } @@ -1728,7 +1756,7 @@ func TestBlockedRequestPhase1_RateLimiting_DifferentIPs(t *testing.T) { Body: "Rate limit exceeded", }, }, - ipBlacklist: trie.NewTrie(), + ipBlacklist: iptrie.NewTrie(), dnsBlacklist: make(map[string]struct{}), } @@ -1781,7 +1809,7 @@ func TestBlockedRequestPhase1_RateLimiting_MatchAllPaths(t *testing.T) { Body: "Rate limit exceeded", }, }, - ipBlacklist: trie.NewTrie(), + ipBlacklist: iptrie.NewTrie(), dnsBlacklist: make(map[string]struct{}), } diff --git a/ratelimiter_test.go b/ratelimiter_test.go index e3c9c7d..494f788 100644 --- a/ratelimiter_test.go +++ b/ratelimiter_test.go @@ -7,7 +7,7 @@ import ( "testing" "time" - trie "github.com/phemmer/go-iptrie" + "github.com/phemmer/go-iptrie" "github.com/stretchr/testify/assert" "go.uber.org/zap" ) @@ -398,7 +398,7 @@ func TestBlockedRequestPhase1_RateLimiting(t *testing.T) { Body: "Rate limit exceeded", }, }, - ipBlacklist: trie.NewTrie(), // Initialize ipBlacklist + ipBlacklist: iptrie.NewTrie(), // Initialize ipBlacklist dnsBlacklist: make(map[string]struct{}), // Initialize dnsBlacklist } diff --git a/request_test.go b/request_test.go index f06d0d3..d9bd7d0 100644 --- a/request_test.go +++ b/request_test.go @@ -11,7 +11,7 @@ import ( "testing" "time" - trie "github.com/phemmer/go-iptrie" + "github.com/phemmer/go-iptrie" "github.com/stretchr/testify/assert" "go.uber.org/zap" "go.uber.org/zap/zapcore" @@ -34,7 +34,7 @@ func TestExtractValue(t *testing.T) { name: "Extract METHOD", target: "METHOD", setupRequest: func() (*http.Request, http.ResponseWriter) { - req := httptest.NewRequest("POST", "http://example.com", nil) + req := httptest.NewRequest("POST", testURL, nil) return req, httptest.NewRecorder() }, expectedValue: "POST", @@ -54,7 +54,7 @@ func TestExtractValue(t *testing.T) { name: "Extract USER_AGENT", target: "USER_AGENT", setupRequest: func() (*http.Request, http.ResponseWriter) { - req := httptest.NewRequest("GET", "http://example.com", nil) + req := httptest.NewRequest("GET", testURL, nil) req.Header.Set("User-Agent", "test-agent") return req, httptest.NewRecorder() }, @@ -65,7 +65,7 @@ func TestExtractValue(t *testing.T) { name: "Extract HEADERS prefix", target: "HEADERS:Content-Type", setupRequest: func() (*http.Request, http.ResponseWriter) { - req := httptest.NewRequest("GET", "http://example.com", nil) + req := httptest.NewRequest("GET", testURL, nil) req.Header.Set("Content-Type", "application/json") return req, httptest.NewRecorder() }, @@ -86,7 +86,7 @@ func TestExtractValue(t *testing.T) { name: "Empty target", target: "", setupRequest: func() (*http.Request, http.ResponseWriter) { - return httptest.NewRequest("GET", "http://example.com", nil), httptest.NewRecorder() + return httptest.NewRequest("GET", testURL, nil), httptest.NewRecorder() }, expectedError: true, }, @@ -383,7 +383,7 @@ func TestProcessRuleMatch_HighScore(t *testing.T) { ResponseWritten: false, } - req := httptest.NewRequest("GET", "http://example.com", nil) + req := httptest.NewRequest("GET", testURL, nil) // Create a context and add logID to it - FIX: ADD CONTEXT HERE ctx := context.Background() @@ -445,7 +445,7 @@ func TestConcurrentRuleEvaluation(t *testing.T) { }, }, ruleCache: NewRuleCache(), - ipBlacklist: trie.NewTrie(), + ipBlacklist: iptrie.NewTrie(), dnsBlacklist: map[string]struct{}{}, requestValueExtractor: NewRequestValueExtractor(logger, false), rateLimiter: func() *RateLimiter { @@ -475,7 +475,7 @@ func TestConcurrentRuleEvaluation(t *testing.T) { wg.Add(1) go func(i int) { defer wg.Done() - req := httptest.NewRequest("GET", "http://example.com", nil) + req := httptest.NewRequest("GET", testURL, nil) req.RemoteAddr = fmt.Sprintf("192.168.1.%d", i%256) // Simulate different IPs req.Header.Set("User-Agent", "test-agent") // Add a header for rule evaluation w := httptest.NewRecorder() diff --git a/types.go b/types.go index c26c9bc..16496ac 100644 --- a/types.go +++ b/types.go @@ -6,7 +6,7 @@ import ( "time" "github.com/oschwald/maxminddb-golang" - trie "github.com/phemmer/go-iptrie" + "github.com/phemmer/go-iptrie" "go.uber.org/zap" "go.uber.org/zap/zapcore" @@ -107,7 +107,7 @@ type Middleware struct { CountryBlock CountryAccessFilter `json:"country_block"` CountryWhitelist CountryAccessFilter `json:"country_whitelist"` Rules map[int][]Rule `json:"-"` - ipBlacklist *trie.Trie `json:"-"` + ipBlacklist *iptrie.Trie `json:"-"` dnsBlacklist map[string]struct{} `json:"-"` // Changed to map[string]struct{} logger *zap.Logger LogSeverity string `json:"log_severity,omitempty"`