test: upd ip blacklist test

This commit is contained in:
drev74
2025-10-12 13:12:13 +03:00
parent 485c86fdbc
commit 145feb4bf8
6 changed files with 122 additions and 93 deletions

View File

@@ -27,7 +27,7 @@ import (
"github.com/fsnotify/fsnotify"
"github.com/oschwald/maxminddb-golang"
trie "github.com/phemmer/go-iptrie"
"github.com/phemmer/go-iptrie"
"go.uber.org/zap"
"go.uber.org/zap/zapcore"
@@ -237,7 +237,7 @@ func (m *Middleware) Provision(ctx caddy.Context) error {
// Load IP blacklist
if m.IPBlacklistFile != "" {
m.ipBlacklist = trie.NewTrie()
m.ipBlacklist = iptrie.NewTrie()
err = m.loadIPBlacklist(m.IPBlacklistFile, *m.ipBlacklist)
if err != nil {
return fmt.Errorf("failed to load IP blacklist: %w", err)
@@ -426,7 +426,7 @@ func (m *Middleware) ReloadConfig() error {
m.logger.Info("Reloading WAF configuration")
if m.IPBlacklistFile != "" {
newIPBlacklist := trie.NewTrie()
newIPBlacklist := iptrie.NewTrie()
if err := m.loadIPBlacklist(m.IPBlacklistFile, *newIPBlacklist); err != nil {
m.logger.Error("Failed to reload IP blacklist", zap.String("file", m.IPBlacklistFile), zap.Error(err))
return fmt.Errorf("failed to reload IP blacklist: %v", err)
@@ -452,7 +452,7 @@ func (m *Middleware) ReloadConfig() error {
return nil
}
func (m *Middleware) loadIPBlacklist(path string, blacklistMap trie.Trie) error {
func (m *Middleware) loadIPBlacklist(path string, blacklistMap iptrie.Trie) error {
if _, err := os.Stat(path); os.IsNotExist(err) {
m.logger.Warn("Skipping IP blacklist load, file does not exist", zap.String("file", path))
return nil

View File

@@ -4,5 +4,6 @@ const (
geoIPdata = "GeoLite2-Country.mmdb"
googleUSIP = "74.125.131.105"
localIP = "127.0.0.1"
testURL = "http://example.com"
torListURL = "https://cdn.nws.neurodyne.pro/nws-cdn-ut8hw561/waf/torbulkexitlist" // custom TOR list URL for testing
)

View File

@@ -12,7 +12,7 @@ import (
"testing"
"time"
trie "github.com/phemmer/go-iptrie"
"github.com/phemmer/go-iptrie"
"github.com/stretchr/testify/assert"
"go.uber.org/zap"
@@ -26,7 +26,7 @@ func TestBlockedRequestPhase1_DNSBlacklist(t *testing.T) {
dnsBlacklist: map[string]struct{}{
"malicious.domain": {},
},
ipBlacklist: trie.NewTrie(),
ipBlacklist: iptrie.NewTrie(),
CustomResponses: map[int]CustomBlockResponse{
403: {
StatusCode: http.StatusForbidden,
@@ -80,7 +80,7 @@ func TestBlockedRequestPhase1_GeoIPBlocking(t *testing.T) {
}
// Simulate a request from a blocked country (US)
req := httptest.NewRequest("GET", "http://example.com", nil)
req := httptest.NewRequest("GET", testURL, nil)
req.RemoteAddr = googleUSIP
w := httptest.NewRecorder()
state := &WAFState{}
@@ -98,32 +98,60 @@ func TestBlockedRequestPhase1_IPBlocking(t *testing.T) {
logger, err := zap.NewDevelopment()
assert.NoError(t, err)
ipBlackList := trie.NewTrie()
ipBlackList.Insert(netip.MustParsePrefix("127.0.0.1/24"), nil)
blackList := iptrie.NewTrie()
loader := iptrie.NewTrieLoader(blackList)
middleware := &Middleware{
logger: logger,
ipBlacklist: ipBlackList,
CustomResponses: map[int]CustomBlockResponse{
for _, net := range []string{
"192.168.0.0/24",
"192.168.1.1/32",
} {
loader.Insert(netip.MustParsePrefix(net), "net="+net)
}
state := &WAFState{}
w := httptest.NewRecorder()
resp := map[int]CustomBlockResponse{
403: {
StatusCode: http.StatusForbidden,
Body: "Access Denied",
},
},
}
req := httptest.NewRequest("GET", "http://example.com", nil)
t.Run("Allow unblocked CIDR", func(t *testing.T) {
middleware := &Middleware{
logger: logger,
ipBlacklist: blackList,
CustomResponses: resp,
}
req := httptest.NewRequest("GET", testURL, nil)
req.RemoteAddr = localIP
w := httptest.NewRecorder()
state := &WAFState{}
// Process the request in Phase 1
middleware.handlePhase(w, req, 1, state)
assert.False(t, state.Blocked, "Request should be allowed")
})
t.Run("Blocks blacklisted CIDR", func(t *testing.T) {
middleware := &Middleware{
logger: logger,
ipBlacklist: blackList,
CustomResponses: resp,
}
req0 := httptest.NewRequest("GET", testURL, nil)
req0.RemoteAddr = "192.168.1.1"
// Process the request in Phase 1
middleware.handlePhase(w, req0, 1, state)
// Verify that the request was blocked
assert.True(t, state.Blocked, "Request should be blocked")
assert.Equal(t, http.StatusForbidden, w.Code, "Expected status code 403")
assert.Contains(t, w.Body.String(), "Access Denied", "Response body should contain 'Access Denied'")
})
}
func TestHandlePhase_Phase2_NiktoUserAgent(t *testing.T) {
@@ -144,7 +172,7 @@ func TestHandlePhase_Phase2_NiktoUserAgent(t *testing.T) {
},
},
ruleCache: NewRuleCache(),
ipBlacklist: trie.NewTrie(),
ipBlacklist: iptrie.NewTrie(),
dnsBlacklist: map[string]struct{}{},
requestValueExtractor: NewRequestValueExtractor(logger, false),
CustomResponses: map[int]CustomBlockResponse{
@@ -155,7 +183,7 @@ func TestHandlePhase_Phase2_NiktoUserAgent(t *testing.T) {
},
}
req := httptest.NewRequest("GET", "http://example.com", nil)
req := httptest.NewRequest("GET", testURL, nil)
req.Header.Set("User-Agent", "nikto")
// Create a context and add logID to it - FIX: ADD CONTEXT HERE
@@ -204,12 +232,12 @@ func TestBlockedRequestPhase1_HeaderRegex(t *testing.T) {
},
},
ruleCache: NewRuleCache(),
ipBlacklist: trie.NewTrie(),
ipBlacklist: iptrie.NewTrie(),
dnsBlacklist: map[string]struct{}{},
requestValueExtractor: NewRequestValueExtractor(logger, false),
}
req := httptest.NewRequest("GET", "http://example.com", nil)
req := httptest.NewRequest("GET", testURL, nil)
req.RemoteAddr = localIP
req.Header.Set("X-Custom-Header", "this-is-a-bad-header") // Simulate a request with bad header
@@ -257,12 +285,12 @@ func TestBlockedRequestPhase1_HeaderRegex_SpecificValue(t *testing.T) {
},
},
ruleCache: NewRuleCache(),
ipBlacklist: trie.NewTrie(),
ipBlacklist: iptrie.NewTrie(),
dnsBlacklist: map[string]struct{}{},
requestValueExtractor: NewRequestValueExtractor(logger, false),
}
req := httptest.NewRequest("GET", "http://example.com", nil)
req := httptest.NewRequest("GET", testURL, nil)
req.RemoteAddr = localIP
req.Header.Set("X-Specific-Header", "specific-value") // Simulate a request with the specific header
@@ -310,12 +338,12 @@ func TestBlockedRequestPhase1_HeaderRegex_CommaSeparatedTargets(t *testing.T) {
},
},
ruleCache: NewRuleCache(),
ipBlacklist: trie.NewTrie(),
ipBlacklist: iptrie.NewTrie(),
dnsBlacklist: map[string]struct{}{},
requestValueExtractor: NewRequestValueExtractor(logger, false),
}
req := httptest.NewRequest("GET", "http://example.com", nil)
req := httptest.NewRequest("GET", testURL, nil)
req.RemoteAddr = localIP
req.Header.Set("X-Custom-Header1", "good-value")
req.Header.Set("X-Custom-Header2", "bad-value") // Simulate a request with bad value in one of the headers
@@ -364,7 +392,7 @@ func TestBlockedRequestPhase1_CombinedConditions(t *testing.T) {
},
},
ruleCache: NewRuleCache(),
ipBlacklist: trie.NewTrie(),
ipBlacklist: iptrie.NewTrie(),
dnsBlacklist: map[string]struct{}{},
requestValueExtractor: NewRequestValueExtractor(logger, false),
}
@@ -417,12 +445,12 @@ func TestBlockedRequestPhase1_NoMatch(t *testing.T) {
},
},
ruleCache: NewRuleCache(),
ipBlacklist: trie.NewTrie(),
ipBlacklist: iptrie.NewTrie(),
dnsBlacklist: map[string]struct{}{},
requestValueExtractor: NewRequestValueExtractor(logger, false),
}
req := httptest.NewRequest("GET", "http://example.com", nil)
req := httptest.NewRequest("GET", testURL, nil)
req.RemoteAddr = localIP
req.Header.Set("User-Agent", "good-user")
@@ -470,12 +498,12 @@ func TestBlockedRequestPhase1_HeaderRegex_EmptyHeader(t *testing.T) {
},
},
ruleCache: NewRuleCache(),
ipBlacklist: trie.NewTrie(),
ipBlacklist: iptrie.NewTrie(),
dnsBlacklist: map[string]struct{}{},
requestValueExtractor: NewRequestValueExtractor(logger, false),
}
req := httptest.NewRequest("GET", "http://example.com", nil)
req := httptest.NewRequest("GET", testURL, nil)
req.RemoteAddr = localIP
// Create a context and add logID to it
@@ -522,12 +550,12 @@ func TestBlockedRequestPhase1_HeaderRegex_MissingHeader(t *testing.T) {
},
},
ruleCache: NewRuleCache(),
ipBlacklist: trie.NewTrie(),
ipBlacklist: iptrie.NewTrie(),
dnsBlacklist: map[string]struct{}{},
requestValueExtractor: NewRequestValueExtractor(logger, false),
}
req := httptest.NewRequest("GET", "http://example.com", nil) // Header not set
req := httptest.NewRequest("GET", testURL, nil) // Header not set
req.RemoteAddr = localIP
// Create a context and add logID to it
@@ -574,12 +602,12 @@ func TestBlockedRequestPhase1_HeaderRegex_ComplexPattern(t *testing.T) {
},
},
ruleCache: NewRuleCache(),
ipBlacklist: trie.NewTrie(),
ipBlacklist: iptrie.NewTrie(),
dnsBlacklist: map[string]struct{}{},
requestValueExtractor: NewRequestValueExtractor(logger, false),
}
req := httptest.NewRequest("GET", "http://example.com", nil)
req := httptest.NewRequest("GET", testURL, nil)
req.RemoteAddr = localIP
req.Header.Set("X-Email-Header", "test@example.com") // Simulate a request with a valid email
@@ -627,12 +655,12 @@ func TestBlockedRequestPhase1_MultiTargetMatch(t *testing.T) {
},
},
ruleCache: NewRuleCache(),
ipBlacklist: trie.NewTrie(),
ipBlacklist: iptrie.NewTrie(),
dnsBlacklist: map[string]struct{}{},
requestValueExtractor: NewRequestValueExtractor(logger, false),
}
req := httptest.NewRequest("GET", "http://example.com", nil)
req := httptest.NewRequest("GET", testURL, nil)
req.RemoteAddr = localIP
req.Header.Set("X-Custom-Header", "good-header")
req.Header.Set("User-Agent", "bad-user-agent")
@@ -680,12 +708,12 @@ func TestBlockedRequestPhase1_MultiTargetNoMatch(t *testing.T) {
},
},
ruleCache: NewRuleCache(),
ipBlacklist: trie.NewTrie(),
ipBlacklist: iptrie.NewTrie(),
dnsBlacklist: map[string]struct{}{},
requestValueExtractor: NewRequestValueExtractor(logger, false),
}
req := httptest.NewRequest("GET", "http://example.com", nil)
req := httptest.NewRequest("GET", testURL, nil)
req.RemoteAddr = localIP
req.Header.Set("X-Custom-Header", "good-header")
req.Header.Set("User-Agent", "good-user-agent")
@@ -734,7 +762,7 @@ func TestBlockedRequestPhase1_URLParameterRegex_NoMatch(t *testing.T) {
},
},
ruleCache: NewRuleCache(),
ipBlacklist: trie.NewTrie(),
ipBlacklist: iptrie.NewTrie(),
dnsBlacklist: map[string]struct{}{},
requestValueExtractor: NewRequestValueExtractor(logger, false),
}
@@ -794,7 +822,7 @@ func TestBlockedRequestPhase1_MultipleRules(t *testing.T) {
},
},
ruleCache: NewRuleCache(),
ipBlacklist: trie.NewTrie(),
ipBlacklist: iptrie.NewTrie(),
dnsBlacklist: map[string]struct{}{},
requestValueExtractor: NewRequestValueExtractor(logger, false),
}
@@ -870,12 +898,12 @@ func TestBlockedRequestPhase2_BodyRegex(t *testing.T) {
},
},
ruleCache: NewRuleCache(),
ipBlacklist: trie.NewTrie(),
ipBlacklist: iptrie.NewTrie(),
dnsBlacklist: map[string]struct{}{},
requestValueExtractor: NewRequestValueExtractor(logger, false),
}
req := httptest.NewRequest("POST", "http://example.com",
req := httptest.NewRequest("POST", testURL,
func() *bytes.Buffer {
b := new(bytes.Buffer)
b.WriteString("this-is-a-bad-body")
@@ -929,12 +957,12 @@ func TestBlockedRequestPhase2_BodyRegex_JSON(t *testing.T) {
},
},
ruleCache: NewRuleCache(),
ipBlacklist: trie.NewTrie(),
ipBlacklist: iptrie.NewTrie(),
dnsBlacklist: map[string]struct{}{},
requestValueExtractor: NewRequestValueExtractor(logger, false),
}
req := httptest.NewRequest("POST", "http://example.com",
req := httptest.NewRequest("POST", testURL,
func() *bytes.Buffer {
b := new(bytes.Buffer)
b.WriteString(`{"data":{"malicious":true,"name":"test"}}`)
@@ -988,12 +1016,12 @@ func TestBlockedRequestPhase2_BodyRegex_FormURLEncoded(t *testing.T) {
},
},
ruleCache: NewRuleCache(),
ipBlacklist: trie.NewTrie(),
ipBlacklist: iptrie.NewTrie(),
dnsBlacklist: map[string]struct{}{},
requestValueExtractor: NewRequestValueExtractor(logger, false),
}
req := httptest.NewRequest("POST", "http://example.com",
req := httptest.NewRequest("POST", testURL,
strings.NewReader("param1=value1&secret=badvalue¶m2=value2"),
)
req.RemoteAddr = localIP
@@ -1043,12 +1071,12 @@ func TestBlockedRequestPhase2_BodyRegex_SpecificPattern(t *testing.T) {
},
},
ruleCache: NewRuleCache(),
ipBlacklist: trie.NewTrie(),
ipBlacklist: iptrie.NewTrie(),
dnsBlacklist: map[string]struct{}{},
requestValueExtractor: NewRequestValueExtractor(logger, false),
}
req := httptest.NewRequest("POST", "http://example.com",
req := httptest.NewRequest("POST", testURL,
func() *bytes.Buffer {
b := new(bytes.Buffer)
b.WriteString("User ID: 123-45-6789")
@@ -1102,12 +1130,12 @@ func TestBlockedRequestPhase2_BodyRegex_NoMatch(t *testing.T) {
},
},
ruleCache: NewRuleCache(),
ipBlacklist: trie.NewTrie(),
ipBlacklist: iptrie.NewTrie(),
dnsBlacklist: map[string]struct{}{},
requestValueExtractor: NewRequestValueExtractor(logger, false),
}
req := httptest.NewRequest("POST", "http://example.com",
req := httptest.NewRequest("POST", testURL,
func() *bytes.Buffer {
b := new(bytes.Buffer)
b.WriteString("this-is-a-good-body")
@@ -1161,7 +1189,7 @@ func TestBlockedRequestPhase2_BodyRegex_NoMatch_MultipartForm(t *testing.T) {
},
},
ruleCache: NewRuleCache(),
ipBlacklist: trie.NewTrie(),
ipBlacklist: iptrie.NewTrie(),
dnsBlacklist: map[string]struct{}{},
requestValueExtractor: NewRequestValueExtractor(logger, false),
}
@@ -1181,7 +1209,7 @@ func TestBlockedRequestPhase2_BodyRegex_NoMatch_MultipartForm(t *testing.T) {
t.Fatalf("Failed to close multipart writer: %v", err)
}
req := httptest.NewRequest("POST", "http://example.com", body)
req := httptest.NewRequest("POST", testURL, body)
req.RemoteAddr = localIP
req.Header.Set("Content-Type", writer.FormDataContentType())
@@ -1229,12 +1257,12 @@ func TestBlockedRequestPhase2_BodyRegex_NoBody(t *testing.T) {
},
},
ruleCache: NewRuleCache(),
ipBlacklist: trie.NewTrie(),
ipBlacklist: iptrie.NewTrie(),
dnsBlacklist: map[string]struct{}{},
requestValueExtractor: NewRequestValueExtractor(logger, false),
}
req := httptest.NewRequest("POST", "http://example.com", nil)
req := httptest.NewRequest("POST", testURL, nil)
req.RemoteAddr = localIP
w := httptest.NewRecorder()
state := &WAFState{}
@@ -1275,7 +1303,7 @@ func TestBlockedRequestPhase3_ResponseHeaderRegex_NoMatch(t *testing.T) {
},
},
ruleCache: NewRuleCache(),
ipBlacklist: trie.NewTrie(),
ipBlacklist: iptrie.NewTrie(),
dnsBlacklist: map[string]struct{}{},
requestValueExtractor: NewRequestValueExtractor(logger, false),
}
@@ -1288,7 +1316,7 @@ func TestBlockedRequestPhase3_ResponseHeaderRegex_NoMatch(t *testing.T) {
})
}()
req := httptest.NewRequest("GET", "http://example.com", nil)
req := httptest.NewRequest("GET", testURL, nil)
req.RemoteAddr = localIP
w := httptest.NewRecorder()
state := &WAFState{}
@@ -1331,7 +1359,7 @@ func TestBlockedRequestPhase4_ResponseBodyRegex_EmptyBody(t *testing.T) {
},
},
ruleCache: NewRuleCache(),
ipBlacklist: trie.NewTrie(),
ipBlacklist: iptrie.NewTrie(),
dnsBlacklist: map[string]struct{}{},
requestValueExtractor: NewRequestValueExtractor(logger, false),
}
@@ -1343,7 +1371,7 @@ func TestBlockedRequestPhase4_ResponseBodyRegex_EmptyBody(t *testing.T) {
})
}()
req := httptest.NewRequest("GET", "http://example.com", nil)
req := httptest.NewRequest("GET", testURL, nil)
req.RemoteAddr = localIP
w := httptest.NewRecorder()
state := &WAFState{}
@@ -1387,7 +1415,7 @@ func TestBlockedRequestPhase4_ResponseBodyRegex_NoBody(t *testing.T) {
},
},
ruleCache: NewRuleCache(),
ipBlacklist: trie.NewTrie(),
ipBlacklist: iptrie.NewTrie(),
dnsBlacklist: map[string]struct{}{},
requestValueExtractor: NewRequestValueExtractor(logger, false),
}
@@ -1399,7 +1427,7 @@ func TestBlockedRequestPhase4_ResponseBodyRegex_NoBody(t *testing.T) {
})
}()
req := httptest.NewRequest("GET", "http://example.com", nil)
req := httptest.NewRequest("GET", testURL, nil)
req.RemoteAddr = localIP
w := httptest.NewRecorder()
state := &WAFState{}
@@ -1441,7 +1469,7 @@ func TestBlockedRequestPhase3_ResponseHeaderRegex_NoSetCookie(t *testing.T) {
},
},
ruleCache: NewRuleCache(),
ipBlacklist: trie.NewTrie(),
ipBlacklist: iptrie.NewTrie(),
dnsBlacklist: map[string]struct{}{},
requestValueExtractor: NewRequestValueExtractor(logger, false),
}
@@ -1453,7 +1481,7 @@ func TestBlockedRequestPhase3_ResponseHeaderRegex_NoSetCookie(t *testing.T) {
})
}()
req := httptest.NewRequest("GET", "http://example.com", nil)
req := httptest.NewRequest("GET", testURL, nil)
req.RemoteAddr = localIP
w := httptest.NewRecorder()
state := &WAFState{}
@@ -1497,12 +1525,12 @@ func TestBlockedRequestPhase1_HeaderRegex_CaseInsensitive(t *testing.T) {
},
},
ruleCache: NewRuleCache(),
ipBlacklist: trie.NewTrie(),
ipBlacklist: iptrie.NewTrie(),
dnsBlacklist: map[string]struct{}{},
requestValueExtractor: NewRequestValueExtractor(logger, false),
}
req := httptest.NewRequest("GET", "http://example.com", nil)
req := httptest.NewRequest("GET", testURL, nil)
req.RemoteAddr = localIP
req.Header.Set("X-Custom-Header", "bAd-VaLuE") // Test with mixed-case header value
@@ -1550,12 +1578,12 @@ func TestBlockedRequestPhase1_HeaderRegex_MultipleMatchingHeaders(t *testing.T)
},
},
ruleCache: NewRuleCache(),
ipBlacklist: trie.NewTrie(),
ipBlacklist: iptrie.NewTrie(),
dnsBlacklist: map[string]struct{}{},
requestValueExtractor: NewRequestValueExtractor(logger, false),
}
req := httptest.NewRequest("GET", "http://example.com", nil)
req := httptest.NewRequest("GET", testURL, nil)
req.RemoteAddr = localIP
req.Header.Set("X-Custom-Header1", "bad-value")
req.Header.Set("X-Custom-Header2", "bad-value") // Both headers have a "bad" value
@@ -1579,7 +1607,7 @@ func TestBlockedRequestPhase1_HeaderRegex_MultipleMatchingHeaders(t *testing.T)
assert.Equal(t, http.StatusForbidden, w.Code, "Expected status code 403")
assert.Contains(t, w.Body.String(), "Blocked by Multiple Matching Headers Regex", "Response body should contain 'Blocked by Multiple Matching Headers Regex'")
req2 := httptest.NewRequest("GET", "http://example.com", nil)
req2 := httptest.NewRequest("GET", testURL, nil)
req2.RemoteAddr = localIP
req2.Header.Set("X-Custom-Header1", "good-value")
req2.Header.Set("X-Custom-Header2", "bad-value") // One header has a "bad" value
@@ -1603,7 +1631,7 @@ func TestBlockedRequestPhase1_HeaderRegex_MultipleMatchingHeaders(t *testing.T)
assert.Equal(t, http.StatusForbidden, w2.Code, "Expected status code 403")
assert.Contains(t, w2.Body.String(), "Blocked by Multiple Matching Headers Regex", "Response body should contain 'Blocked by Multiple Matching Headers Regex'")
req3 := httptest.NewRequest("GET", "http://example.com", nil)
req3 := httptest.NewRequest("GET", testURL, nil)
req3.RemoteAddr = localIP
req3.Header.Set("X-Custom-Header1", "good-value")
req3.Header.Set("X-Custom-Header2", "good-value") // None headers have a "bad" value
@@ -1658,7 +1686,7 @@ func TestBlockedRequestPhase1_RateLimiting_MultiplePaths(t *testing.T) {
Body: "Rate limit exceeded",
},
},
ipBlacklist: trie.NewTrie(),
ipBlacklist: iptrie.NewTrie(),
dnsBlacklist: make(map[string]struct{}),
}
@@ -1728,7 +1756,7 @@ func TestBlockedRequestPhase1_RateLimiting_DifferentIPs(t *testing.T) {
Body: "Rate limit exceeded",
},
},
ipBlacklist: trie.NewTrie(),
ipBlacklist: iptrie.NewTrie(),
dnsBlacklist: make(map[string]struct{}),
}
@@ -1781,7 +1809,7 @@ func TestBlockedRequestPhase1_RateLimiting_MatchAllPaths(t *testing.T) {
Body: "Rate limit exceeded",
},
},
ipBlacklist: trie.NewTrie(),
ipBlacklist: iptrie.NewTrie(),
dnsBlacklist: make(map[string]struct{}),
}

View File

@@ -7,7 +7,7 @@ import (
"testing"
"time"
trie "github.com/phemmer/go-iptrie"
"github.com/phemmer/go-iptrie"
"github.com/stretchr/testify/assert"
"go.uber.org/zap"
)
@@ -398,7 +398,7 @@ func TestBlockedRequestPhase1_RateLimiting(t *testing.T) {
Body: "Rate limit exceeded",
},
},
ipBlacklist: trie.NewTrie(), // Initialize ipBlacklist
ipBlacklist: iptrie.NewTrie(), // Initialize ipBlacklist
dnsBlacklist: make(map[string]struct{}), // Initialize dnsBlacklist
}

View File

@@ -11,7 +11,7 @@ import (
"testing"
"time"
trie "github.com/phemmer/go-iptrie"
"github.com/phemmer/go-iptrie"
"github.com/stretchr/testify/assert"
"go.uber.org/zap"
"go.uber.org/zap/zapcore"
@@ -34,7 +34,7 @@ func TestExtractValue(t *testing.T) {
name: "Extract METHOD",
target: "METHOD",
setupRequest: func() (*http.Request, http.ResponseWriter) {
req := httptest.NewRequest("POST", "http://example.com", nil)
req := httptest.NewRequest("POST", testURL, nil)
return req, httptest.NewRecorder()
},
expectedValue: "POST",
@@ -54,7 +54,7 @@ func TestExtractValue(t *testing.T) {
name: "Extract USER_AGENT",
target: "USER_AGENT",
setupRequest: func() (*http.Request, http.ResponseWriter) {
req := httptest.NewRequest("GET", "http://example.com", nil)
req := httptest.NewRequest("GET", testURL, nil)
req.Header.Set("User-Agent", "test-agent")
return req, httptest.NewRecorder()
},
@@ -65,7 +65,7 @@ func TestExtractValue(t *testing.T) {
name: "Extract HEADERS prefix",
target: "HEADERS:Content-Type",
setupRequest: func() (*http.Request, http.ResponseWriter) {
req := httptest.NewRequest("GET", "http://example.com", nil)
req := httptest.NewRequest("GET", testURL, nil)
req.Header.Set("Content-Type", "application/json")
return req, httptest.NewRecorder()
},
@@ -86,7 +86,7 @@ func TestExtractValue(t *testing.T) {
name: "Empty target",
target: "",
setupRequest: func() (*http.Request, http.ResponseWriter) {
return httptest.NewRequest("GET", "http://example.com", nil), httptest.NewRecorder()
return httptest.NewRequest("GET", testURL, nil), httptest.NewRecorder()
},
expectedError: true,
},
@@ -383,7 +383,7 @@ func TestProcessRuleMatch_HighScore(t *testing.T) {
ResponseWritten: false,
}
req := httptest.NewRequest("GET", "http://example.com", nil)
req := httptest.NewRequest("GET", testURL, nil)
// Create a context and add logID to it - FIX: ADD CONTEXT HERE
ctx := context.Background()
@@ -445,7 +445,7 @@ func TestConcurrentRuleEvaluation(t *testing.T) {
},
},
ruleCache: NewRuleCache(),
ipBlacklist: trie.NewTrie(),
ipBlacklist: iptrie.NewTrie(),
dnsBlacklist: map[string]struct{}{},
requestValueExtractor: NewRequestValueExtractor(logger, false),
rateLimiter: func() *RateLimiter {
@@ -475,7 +475,7 @@ func TestConcurrentRuleEvaluation(t *testing.T) {
wg.Add(1)
go func(i int) {
defer wg.Done()
req := httptest.NewRequest("GET", "http://example.com", nil)
req := httptest.NewRequest("GET", testURL, nil)
req.RemoteAddr = fmt.Sprintf("192.168.1.%d", i%256) // Simulate different IPs
req.Header.Set("User-Agent", "test-agent") // Add a header for rule evaluation
w := httptest.NewRecorder()

View File

@@ -6,7 +6,7 @@ import (
"time"
"github.com/oschwald/maxminddb-golang"
trie "github.com/phemmer/go-iptrie"
"github.com/phemmer/go-iptrie"
"go.uber.org/zap"
"go.uber.org/zap/zapcore"
@@ -107,7 +107,7 @@ type Middleware struct {
CountryBlock CountryAccessFilter `json:"country_block"`
CountryWhitelist CountryAccessFilter `json:"country_whitelist"`
Rules map[int][]Rule `json:"-"`
ipBlacklist *trie.Trie `json:"-"`
ipBlacklist *iptrie.Trie `json:"-"`
dnsBlacklist map[string]struct{} `json:"-"` // Changed to map[string]struct{}
logger *zap.Logger
LogSeverity string `json:"log_severity,omitempty"`