mirror of
https://github.com/fabriziosalmi/caddy-waf.git
synced 2025-12-23 14:17:45 -05:00
test: upd ip blacklist test
This commit is contained in:
@@ -27,7 +27,7 @@ import (
|
||||
|
||||
"github.com/fsnotify/fsnotify"
|
||||
"github.com/oschwald/maxminddb-golang"
|
||||
trie "github.com/phemmer/go-iptrie"
|
||||
"github.com/phemmer/go-iptrie"
|
||||
"go.uber.org/zap"
|
||||
"go.uber.org/zap/zapcore"
|
||||
|
||||
@@ -237,7 +237,7 @@ func (m *Middleware) Provision(ctx caddy.Context) error {
|
||||
|
||||
// Load IP blacklist
|
||||
if m.IPBlacklistFile != "" {
|
||||
m.ipBlacklist = trie.NewTrie()
|
||||
m.ipBlacklist = iptrie.NewTrie()
|
||||
err = m.loadIPBlacklist(m.IPBlacklistFile, *m.ipBlacklist)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to load IP blacklist: %w", err)
|
||||
@@ -426,7 +426,7 @@ func (m *Middleware) ReloadConfig() error {
|
||||
|
||||
m.logger.Info("Reloading WAF configuration")
|
||||
if m.IPBlacklistFile != "" {
|
||||
newIPBlacklist := trie.NewTrie()
|
||||
newIPBlacklist := iptrie.NewTrie()
|
||||
if err := m.loadIPBlacklist(m.IPBlacklistFile, *newIPBlacklist); err != nil {
|
||||
m.logger.Error("Failed to reload IP blacklist", zap.String("file", m.IPBlacklistFile), zap.Error(err))
|
||||
return fmt.Errorf("failed to reload IP blacklist: %v", err)
|
||||
@@ -452,7 +452,7 @@ func (m *Middleware) ReloadConfig() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *Middleware) loadIPBlacklist(path string, blacklistMap trie.Trie) error {
|
||||
func (m *Middleware) loadIPBlacklist(path string, blacklistMap iptrie.Trie) error {
|
||||
if _, err := os.Stat(path); os.IsNotExist(err) {
|
||||
m.logger.Warn("Skipping IP blacklist load, file does not exist", zap.String("file", path))
|
||||
return nil
|
||||
|
||||
@@ -4,5 +4,6 @@ const (
|
||||
geoIPdata = "GeoLite2-Country.mmdb"
|
||||
googleUSIP = "74.125.131.105"
|
||||
localIP = "127.0.0.1"
|
||||
testURL = "http://example.com"
|
||||
torListURL = "https://cdn.nws.neurodyne.pro/nws-cdn-ut8hw561/waf/torbulkexitlist" // custom TOR list URL for testing
|
||||
)
|
||||
|
||||
182
handler_test.go
182
handler_test.go
@@ -12,7 +12,7 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
trie "github.com/phemmer/go-iptrie"
|
||||
"github.com/phemmer/go-iptrie"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"go.uber.org/zap"
|
||||
|
||||
@@ -26,7 +26,7 @@ func TestBlockedRequestPhase1_DNSBlacklist(t *testing.T) {
|
||||
dnsBlacklist: map[string]struct{}{
|
||||
"malicious.domain": {},
|
||||
},
|
||||
ipBlacklist: trie.NewTrie(),
|
||||
ipBlacklist: iptrie.NewTrie(),
|
||||
CustomResponses: map[int]CustomBlockResponse{
|
||||
403: {
|
||||
StatusCode: http.StatusForbidden,
|
||||
@@ -80,7 +80,7 @@ func TestBlockedRequestPhase1_GeoIPBlocking(t *testing.T) {
|
||||
}
|
||||
|
||||
// Simulate a request from a blocked country (US)
|
||||
req := httptest.NewRequest("GET", "http://example.com", nil)
|
||||
req := httptest.NewRequest("GET", testURL, nil)
|
||||
req.RemoteAddr = googleUSIP
|
||||
w := httptest.NewRecorder()
|
||||
state := &WAFState{}
|
||||
@@ -98,32 +98,60 @@ func TestBlockedRequestPhase1_IPBlocking(t *testing.T) {
|
||||
logger, err := zap.NewDevelopment()
|
||||
assert.NoError(t, err)
|
||||
|
||||
ipBlackList := trie.NewTrie()
|
||||
ipBlackList.Insert(netip.MustParsePrefix("127.0.0.1/24"), nil)
|
||||
blackList := iptrie.NewTrie()
|
||||
loader := iptrie.NewTrieLoader(blackList)
|
||||
|
||||
middleware := &Middleware{
|
||||
logger: logger,
|
||||
ipBlacklist: ipBlackList,
|
||||
CustomResponses: map[int]CustomBlockResponse{
|
||||
403: {
|
||||
StatusCode: http.StatusForbidden,
|
||||
Body: "Access Denied",
|
||||
},
|
||||
for _, net := range []string{
|
||||
"192.168.0.0/24",
|
||||
"192.168.1.1/32",
|
||||
} {
|
||||
loader.Insert(netip.MustParsePrefix(net), "net="+net)
|
||||
}
|
||||
|
||||
state := &WAFState{}
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
resp := map[int]CustomBlockResponse{
|
||||
403: {
|
||||
StatusCode: http.StatusForbidden,
|
||||
Body: "Access Denied",
|
||||
},
|
||||
}
|
||||
|
||||
req := httptest.NewRequest("GET", "http://example.com", nil)
|
||||
req.RemoteAddr = localIP
|
||||
w := httptest.NewRecorder()
|
||||
state := &WAFState{}
|
||||
t.Run("Allow unblocked CIDR", func(t *testing.T) {
|
||||
middleware := &Middleware{
|
||||
logger: logger,
|
||||
ipBlacklist: blackList,
|
||||
CustomResponses: resp,
|
||||
}
|
||||
|
||||
// Process the request in Phase 1
|
||||
middleware.handlePhase(w, req, 1, state)
|
||||
req := httptest.NewRequest("GET", testURL, nil)
|
||||
req.RemoteAddr = localIP
|
||||
|
||||
// Verify that the request was blocked
|
||||
assert.True(t, state.Blocked, "Request should be blocked")
|
||||
assert.Equal(t, http.StatusForbidden, w.Code, "Expected status code 403")
|
||||
assert.Contains(t, w.Body.String(), "Access Denied", "Response body should contain 'Access Denied'")
|
||||
// Process the request in Phase 1
|
||||
middleware.handlePhase(w, req, 1, state)
|
||||
|
||||
assert.False(t, state.Blocked, "Request should be allowed")
|
||||
})
|
||||
|
||||
t.Run("Blocks blacklisted CIDR", func(t *testing.T) {
|
||||
middleware := &Middleware{
|
||||
logger: logger,
|
||||
ipBlacklist: blackList,
|
||||
CustomResponses: resp,
|
||||
}
|
||||
|
||||
req0 := httptest.NewRequest("GET", testURL, nil)
|
||||
req0.RemoteAddr = "192.168.1.1"
|
||||
|
||||
// Process the request in Phase 1
|
||||
middleware.handlePhase(w, req0, 1, state)
|
||||
|
||||
// Verify that the request was blocked
|
||||
assert.True(t, state.Blocked, "Request should be blocked")
|
||||
assert.Equal(t, http.StatusForbidden, w.Code, "Expected status code 403")
|
||||
assert.Contains(t, w.Body.String(), "Access Denied", "Response body should contain 'Access Denied'")
|
||||
})
|
||||
}
|
||||
|
||||
func TestHandlePhase_Phase2_NiktoUserAgent(t *testing.T) {
|
||||
@@ -144,7 +172,7 @@ func TestHandlePhase_Phase2_NiktoUserAgent(t *testing.T) {
|
||||
},
|
||||
},
|
||||
ruleCache: NewRuleCache(),
|
||||
ipBlacklist: trie.NewTrie(),
|
||||
ipBlacklist: iptrie.NewTrie(),
|
||||
dnsBlacklist: map[string]struct{}{},
|
||||
requestValueExtractor: NewRequestValueExtractor(logger, false),
|
||||
CustomResponses: map[int]CustomBlockResponse{
|
||||
@@ -155,7 +183,7 @@ func TestHandlePhase_Phase2_NiktoUserAgent(t *testing.T) {
|
||||
},
|
||||
}
|
||||
|
||||
req := httptest.NewRequest("GET", "http://example.com", nil)
|
||||
req := httptest.NewRequest("GET", testURL, nil)
|
||||
req.Header.Set("User-Agent", "nikto")
|
||||
|
||||
// Create a context and add logID to it - FIX: ADD CONTEXT HERE
|
||||
@@ -204,12 +232,12 @@ func TestBlockedRequestPhase1_HeaderRegex(t *testing.T) {
|
||||
},
|
||||
},
|
||||
ruleCache: NewRuleCache(),
|
||||
ipBlacklist: trie.NewTrie(),
|
||||
ipBlacklist: iptrie.NewTrie(),
|
||||
dnsBlacklist: map[string]struct{}{},
|
||||
requestValueExtractor: NewRequestValueExtractor(logger, false),
|
||||
}
|
||||
|
||||
req := httptest.NewRequest("GET", "http://example.com", nil)
|
||||
req := httptest.NewRequest("GET", testURL, nil)
|
||||
req.RemoteAddr = localIP
|
||||
req.Header.Set("X-Custom-Header", "this-is-a-bad-header") // Simulate a request with bad header
|
||||
|
||||
@@ -257,12 +285,12 @@ func TestBlockedRequestPhase1_HeaderRegex_SpecificValue(t *testing.T) {
|
||||
},
|
||||
},
|
||||
ruleCache: NewRuleCache(),
|
||||
ipBlacklist: trie.NewTrie(),
|
||||
ipBlacklist: iptrie.NewTrie(),
|
||||
dnsBlacklist: map[string]struct{}{},
|
||||
requestValueExtractor: NewRequestValueExtractor(logger, false),
|
||||
}
|
||||
|
||||
req := httptest.NewRequest("GET", "http://example.com", nil)
|
||||
req := httptest.NewRequest("GET", testURL, nil)
|
||||
req.RemoteAddr = localIP
|
||||
req.Header.Set("X-Specific-Header", "specific-value") // Simulate a request with the specific header
|
||||
|
||||
@@ -310,12 +338,12 @@ func TestBlockedRequestPhase1_HeaderRegex_CommaSeparatedTargets(t *testing.T) {
|
||||
},
|
||||
},
|
||||
ruleCache: NewRuleCache(),
|
||||
ipBlacklist: trie.NewTrie(),
|
||||
ipBlacklist: iptrie.NewTrie(),
|
||||
dnsBlacklist: map[string]struct{}{},
|
||||
requestValueExtractor: NewRequestValueExtractor(logger, false),
|
||||
}
|
||||
|
||||
req := httptest.NewRequest("GET", "http://example.com", nil)
|
||||
req := httptest.NewRequest("GET", testURL, nil)
|
||||
req.RemoteAddr = localIP
|
||||
req.Header.Set("X-Custom-Header1", "good-value")
|
||||
req.Header.Set("X-Custom-Header2", "bad-value") // Simulate a request with bad value in one of the headers
|
||||
@@ -364,7 +392,7 @@ func TestBlockedRequestPhase1_CombinedConditions(t *testing.T) {
|
||||
},
|
||||
},
|
||||
ruleCache: NewRuleCache(),
|
||||
ipBlacklist: trie.NewTrie(),
|
||||
ipBlacklist: iptrie.NewTrie(),
|
||||
dnsBlacklist: map[string]struct{}{},
|
||||
requestValueExtractor: NewRequestValueExtractor(logger, false),
|
||||
}
|
||||
@@ -417,12 +445,12 @@ func TestBlockedRequestPhase1_NoMatch(t *testing.T) {
|
||||
},
|
||||
},
|
||||
ruleCache: NewRuleCache(),
|
||||
ipBlacklist: trie.NewTrie(),
|
||||
ipBlacklist: iptrie.NewTrie(),
|
||||
dnsBlacklist: map[string]struct{}{},
|
||||
requestValueExtractor: NewRequestValueExtractor(logger, false),
|
||||
}
|
||||
|
||||
req := httptest.NewRequest("GET", "http://example.com", nil)
|
||||
req := httptest.NewRequest("GET", testURL, nil)
|
||||
req.RemoteAddr = localIP
|
||||
req.Header.Set("User-Agent", "good-user")
|
||||
|
||||
@@ -470,12 +498,12 @@ func TestBlockedRequestPhase1_HeaderRegex_EmptyHeader(t *testing.T) {
|
||||
},
|
||||
},
|
||||
ruleCache: NewRuleCache(),
|
||||
ipBlacklist: trie.NewTrie(),
|
||||
ipBlacklist: iptrie.NewTrie(),
|
||||
dnsBlacklist: map[string]struct{}{},
|
||||
requestValueExtractor: NewRequestValueExtractor(logger, false),
|
||||
}
|
||||
|
||||
req := httptest.NewRequest("GET", "http://example.com", nil)
|
||||
req := httptest.NewRequest("GET", testURL, nil)
|
||||
req.RemoteAddr = localIP
|
||||
|
||||
// Create a context and add logID to it
|
||||
@@ -522,12 +550,12 @@ func TestBlockedRequestPhase1_HeaderRegex_MissingHeader(t *testing.T) {
|
||||
},
|
||||
},
|
||||
ruleCache: NewRuleCache(),
|
||||
ipBlacklist: trie.NewTrie(),
|
||||
ipBlacklist: iptrie.NewTrie(),
|
||||
dnsBlacklist: map[string]struct{}{},
|
||||
requestValueExtractor: NewRequestValueExtractor(logger, false),
|
||||
}
|
||||
|
||||
req := httptest.NewRequest("GET", "http://example.com", nil) // Header not set
|
||||
req := httptest.NewRequest("GET", testURL, nil) // Header not set
|
||||
req.RemoteAddr = localIP
|
||||
|
||||
// Create a context and add logID to it
|
||||
@@ -574,12 +602,12 @@ func TestBlockedRequestPhase1_HeaderRegex_ComplexPattern(t *testing.T) {
|
||||
},
|
||||
},
|
||||
ruleCache: NewRuleCache(),
|
||||
ipBlacklist: trie.NewTrie(),
|
||||
ipBlacklist: iptrie.NewTrie(),
|
||||
dnsBlacklist: map[string]struct{}{},
|
||||
requestValueExtractor: NewRequestValueExtractor(logger, false),
|
||||
}
|
||||
|
||||
req := httptest.NewRequest("GET", "http://example.com", nil)
|
||||
req := httptest.NewRequest("GET", testURL, nil)
|
||||
req.RemoteAddr = localIP
|
||||
req.Header.Set("X-Email-Header", "test@example.com") // Simulate a request with a valid email
|
||||
|
||||
@@ -627,12 +655,12 @@ func TestBlockedRequestPhase1_MultiTargetMatch(t *testing.T) {
|
||||
},
|
||||
},
|
||||
ruleCache: NewRuleCache(),
|
||||
ipBlacklist: trie.NewTrie(),
|
||||
ipBlacklist: iptrie.NewTrie(),
|
||||
dnsBlacklist: map[string]struct{}{},
|
||||
requestValueExtractor: NewRequestValueExtractor(logger, false),
|
||||
}
|
||||
|
||||
req := httptest.NewRequest("GET", "http://example.com", nil)
|
||||
req := httptest.NewRequest("GET", testURL, nil)
|
||||
req.RemoteAddr = localIP
|
||||
req.Header.Set("X-Custom-Header", "good-header")
|
||||
req.Header.Set("User-Agent", "bad-user-agent")
|
||||
@@ -680,12 +708,12 @@ func TestBlockedRequestPhase1_MultiTargetNoMatch(t *testing.T) {
|
||||
},
|
||||
},
|
||||
ruleCache: NewRuleCache(),
|
||||
ipBlacklist: trie.NewTrie(),
|
||||
ipBlacklist: iptrie.NewTrie(),
|
||||
dnsBlacklist: map[string]struct{}{},
|
||||
requestValueExtractor: NewRequestValueExtractor(logger, false),
|
||||
}
|
||||
|
||||
req := httptest.NewRequest("GET", "http://example.com", nil)
|
||||
req := httptest.NewRequest("GET", testURL, nil)
|
||||
req.RemoteAddr = localIP
|
||||
req.Header.Set("X-Custom-Header", "good-header")
|
||||
req.Header.Set("User-Agent", "good-user-agent")
|
||||
@@ -734,7 +762,7 @@ func TestBlockedRequestPhase1_URLParameterRegex_NoMatch(t *testing.T) {
|
||||
},
|
||||
},
|
||||
ruleCache: NewRuleCache(),
|
||||
ipBlacklist: trie.NewTrie(),
|
||||
ipBlacklist: iptrie.NewTrie(),
|
||||
dnsBlacklist: map[string]struct{}{},
|
||||
requestValueExtractor: NewRequestValueExtractor(logger, false),
|
||||
}
|
||||
@@ -794,7 +822,7 @@ func TestBlockedRequestPhase1_MultipleRules(t *testing.T) {
|
||||
},
|
||||
},
|
||||
ruleCache: NewRuleCache(),
|
||||
ipBlacklist: trie.NewTrie(),
|
||||
ipBlacklist: iptrie.NewTrie(),
|
||||
dnsBlacklist: map[string]struct{}{},
|
||||
requestValueExtractor: NewRequestValueExtractor(logger, false),
|
||||
}
|
||||
@@ -870,12 +898,12 @@ func TestBlockedRequestPhase2_BodyRegex(t *testing.T) {
|
||||
},
|
||||
},
|
||||
ruleCache: NewRuleCache(),
|
||||
ipBlacklist: trie.NewTrie(),
|
||||
ipBlacklist: iptrie.NewTrie(),
|
||||
dnsBlacklist: map[string]struct{}{},
|
||||
requestValueExtractor: NewRequestValueExtractor(logger, false),
|
||||
}
|
||||
|
||||
req := httptest.NewRequest("POST", "http://example.com",
|
||||
req := httptest.NewRequest("POST", testURL,
|
||||
func() *bytes.Buffer {
|
||||
b := new(bytes.Buffer)
|
||||
b.WriteString("this-is-a-bad-body")
|
||||
@@ -929,12 +957,12 @@ func TestBlockedRequestPhase2_BodyRegex_JSON(t *testing.T) {
|
||||
},
|
||||
},
|
||||
ruleCache: NewRuleCache(),
|
||||
ipBlacklist: trie.NewTrie(),
|
||||
ipBlacklist: iptrie.NewTrie(),
|
||||
dnsBlacklist: map[string]struct{}{},
|
||||
requestValueExtractor: NewRequestValueExtractor(logger, false),
|
||||
}
|
||||
|
||||
req := httptest.NewRequest("POST", "http://example.com",
|
||||
req := httptest.NewRequest("POST", testURL,
|
||||
func() *bytes.Buffer {
|
||||
b := new(bytes.Buffer)
|
||||
b.WriteString(`{"data":{"malicious":true,"name":"test"}}`)
|
||||
@@ -988,12 +1016,12 @@ func TestBlockedRequestPhase2_BodyRegex_FormURLEncoded(t *testing.T) {
|
||||
},
|
||||
},
|
||||
ruleCache: NewRuleCache(),
|
||||
ipBlacklist: trie.NewTrie(),
|
||||
ipBlacklist: iptrie.NewTrie(),
|
||||
dnsBlacklist: map[string]struct{}{},
|
||||
requestValueExtractor: NewRequestValueExtractor(logger, false),
|
||||
}
|
||||
|
||||
req := httptest.NewRequest("POST", "http://example.com",
|
||||
req := httptest.NewRequest("POST", testURL,
|
||||
strings.NewReader("param1=value1&secret=badvalue¶m2=value2"),
|
||||
)
|
||||
req.RemoteAddr = localIP
|
||||
@@ -1043,12 +1071,12 @@ func TestBlockedRequestPhase2_BodyRegex_SpecificPattern(t *testing.T) {
|
||||
},
|
||||
},
|
||||
ruleCache: NewRuleCache(),
|
||||
ipBlacklist: trie.NewTrie(),
|
||||
ipBlacklist: iptrie.NewTrie(),
|
||||
dnsBlacklist: map[string]struct{}{},
|
||||
requestValueExtractor: NewRequestValueExtractor(logger, false),
|
||||
}
|
||||
|
||||
req := httptest.NewRequest("POST", "http://example.com",
|
||||
req := httptest.NewRequest("POST", testURL,
|
||||
func() *bytes.Buffer {
|
||||
b := new(bytes.Buffer)
|
||||
b.WriteString("User ID: 123-45-6789")
|
||||
@@ -1102,12 +1130,12 @@ func TestBlockedRequestPhase2_BodyRegex_NoMatch(t *testing.T) {
|
||||
},
|
||||
},
|
||||
ruleCache: NewRuleCache(),
|
||||
ipBlacklist: trie.NewTrie(),
|
||||
ipBlacklist: iptrie.NewTrie(),
|
||||
dnsBlacklist: map[string]struct{}{},
|
||||
requestValueExtractor: NewRequestValueExtractor(logger, false),
|
||||
}
|
||||
|
||||
req := httptest.NewRequest("POST", "http://example.com",
|
||||
req := httptest.NewRequest("POST", testURL,
|
||||
func() *bytes.Buffer {
|
||||
b := new(bytes.Buffer)
|
||||
b.WriteString("this-is-a-good-body")
|
||||
@@ -1161,7 +1189,7 @@ func TestBlockedRequestPhase2_BodyRegex_NoMatch_MultipartForm(t *testing.T) {
|
||||
},
|
||||
},
|
||||
ruleCache: NewRuleCache(),
|
||||
ipBlacklist: trie.NewTrie(),
|
||||
ipBlacklist: iptrie.NewTrie(),
|
||||
dnsBlacklist: map[string]struct{}{},
|
||||
requestValueExtractor: NewRequestValueExtractor(logger, false),
|
||||
}
|
||||
@@ -1181,7 +1209,7 @@ func TestBlockedRequestPhase2_BodyRegex_NoMatch_MultipartForm(t *testing.T) {
|
||||
t.Fatalf("Failed to close multipart writer: %v", err)
|
||||
}
|
||||
|
||||
req := httptest.NewRequest("POST", "http://example.com", body)
|
||||
req := httptest.NewRequest("POST", testURL, body)
|
||||
req.RemoteAddr = localIP
|
||||
req.Header.Set("Content-Type", writer.FormDataContentType())
|
||||
|
||||
@@ -1229,12 +1257,12 @@ func TestBlockedRequestPhase2_BodyRegex_NoBody(t *testing.T) {
|
||||
},
|
||||
},
|
||||
ruleCache: NewRuleCache(),
|
||||
ipBlacklist: trie.NewTrie(),
|
||||
ipBlacklist: iptrie.NewTrie(),
|
||||
dnsBlacklist: map[string]struct{}{},
|
||||
requestValueExtractor: NewRequestValueExtractor(logger, false),
|
||||
}
|
||||
|
||||
req := httptest.NewRequest("POST", "http://example.com", nil)
|
||||
req := httptest.NewRequest("POST", testURL, nil)
|
||||
req.RemoteAddr = localIP
|
||||
w := httptest.NewRecorder()
|
||||
state := &WAFState{}
|
||||
@@ -1275,7 +1303,7 @@ func TestBlockedRequestPhase3_ResponseHeaderRegex_NoMatch(t *testing.T) {
|
||||
},
|
||||
},
|
||||
ruleCache: NewRuleCache(),
|
||||
ipBlacklist: trie.NewTrie(),
|
||||
ipBlacklist: iptrie.NewTrie(),
|
||||
dnsBlacklist: map[string]struct{}{},
|
||||
requestValueExtractor: NewRequestValueExtractor(logger, false),
|
||||
}
|
||||
@@ -1288,7 +1316,7 @@ func TestBlockedRequestPhase3_ResponseHeaderRegex_NoMatch(t *testing.T) {
|
||||
})
|
||||
}()
|
||||
|
||||
req := httptest.NewRequest("GET", "http://example.com", nil)
|
||||
req := httptest.NewRequest("GET", testURL, nil)
|
||||
req.RemoteAddr = localIP
|
||||
w := httptest.NewRecorder()
|
||||
state := &WAFState{}
|
||||
@@ -1331,7 +1359,7 @@ func TestBlockedRequestPhase4_ResponseBodyRegex_EmptyBody(t *testing.T) {
|
||||
},
|
||||
},
|
||||
ruleCache: NewRuleCache(),
|
||||
ipBlacklist: trie.NewTrie(),
|
||||
ipBlacklist: iptrie.NewTrie(),
|
||||
dnsBlacklist: map[string]struct{}{},
|
||||
requestValueExtractor: NewRequestValueExtractor(logger, false),
|
||||
}
|
||||
@@ -1343,7 +1371,7 @@ func TestBlockedRequestPhase4_ResponseBodyRegex_EmptyBody(t *testing.T) {
|
||||
})
|
||||
}()
|
||||
|
||||
req := httptest.NewRequest("GET", "http://example.com", nil)
|
||||
req := httptest.NewRequest("GET", testURL, nil)
|
||||
req.RemoteAddr = localIP
|
||||
w := httptest.NewRecorder()
|
||||
state := &WAFState{}
|
||||
@@ -1387,7 +1415,7 @@ func TestBlockedRequestPhase4_ResponseBodyRegex_NoBody(t *testing.T) {
|
||||
},
|
||||
},
|
||||
ruleCache: NewRuleCache(),
|
||||
ipBlacklist: trie.NewTrie(),
|
||||
ipBlacklist: iptrie.NewTrie(),
|
||||
dnsBlacklist: map[string]struct{}{},
|
||||
requestValueExtractor: NewRequestValueExtractor(logger, false),
|
||||
}
|
||||
@@ -1399,7 +1427,7 @@ func TestBlockedRequestPhase4_ResponseBodyRegex_NoBody(t *testing.T) {
|
||||
})
|
||||
}()
|
||||
|
||||
req := httptest.NewRequest("GET", "http://example.com", nil)
|
||||
req := httptest.NewRequest("GET", testURL, nil)
|
||||
req.RemoteAddr = localIP
|
||||
w := httptest.NewRecorder()
|
||||
state := &WAFState{}
|
||||
@@ -1441,7 +1469,7 @@ func TestBlockedRequestPhase3_ResponseHeaderRegex_NoSetCookie(t *testing.T) {
|
||||
},
|
||||
},
|
||||
ruleCache: NewRuleCache(),
|
||||
ipBlacklist: trie.NewTrie(),
|
||||
ipBlacklist: iptrie.NewTrie(),
|
||||
dnsBlacklist: map[string]struct{}{},
|
||||
requestValueExtractor: NewRequestValueExtractor(logger, false),
|
||||
}
|
||||
@@ -1453,7 +1481,7 @@ func TestBlockedRequestPhase3_ResponseHeaderRegex_NoSetCookie(t *testing.T) {
|
||||
})
|
||||
}()
|
||||
|
||||
req := httptest.NewRequest("GET", "http://example.com", nil)
|
||||
req := httptest.NewRequest("GET", testURL, nil)
|
||||
req.RemoteAddr = localIP
|
||||
w := httptest.NewRecorder()
|
||||
state := &WAFState{}
|
||||
@@ -1497,12 +1525,12 @@ func TestBlockedRequestPhase1_HeaderRegex_CaseInsensitive(t *testing.T) {
|
||||
},
|
||||
},
|
||||
ruleCache: NewRuleCache(),
|
||||
ipBlacklist: trie.NewTrie(),
|
||||
ipBlacklist: iptrie.NewTrie(),
|
||||
dnsBlacklist: map[string]struct{}{},
|
||||
requestValueExtractor: NewRequestValueExtractor(logger, false),
|
||||
}
|
||||
|
||||
req := httptest.NewRequest("GET", "http://example.com", nil)
|
||||
req := httptest.NewRequest("GET", testURL, nil)
|
||||
req.RemoteAddr = localIP
|
||||
req.Header.Set("X-Custom-Header", "bAd-VaLuE") // Test with mixed-case header value
|
||||
|
||||
@@ -1550,12 +1578,12 @@ func TestBlockedRequestPhase1_HeaderRegex_MultipleMatchingHeaders(t *testing.T)
|
||||
},
|
||||
},
|
||||
ruleCache: NewRuleCache(),
|
||||
ipBlacklist: trie.NewTrie(),
|
||||
ipBlacklist: iptrie.NewTrie(),
|
||||
dnsBlacklist: map[string]struct{}{},
|
||||
requestValueExtractor: NewRequestValueExtractor(logger, false),
|
||||
}
|
||||
|
||||
req := httptest.NewRequest("GET", "http://example.com", nil)
|
||||
req := httptest.NewRequest("GET", testURL, nil)
|
||||
req.RemoteAddr = localIP
|
||||
req.Header.Set("X-Custom-Header1", "bad-value")
|
||||
req.Header.Set("X-Custom-Header2", "bad-value") // Both headers have a "bad" value
|
||||
@@ -1579,7 +1607,7 @@ func TestBlockedRequestPhase1_HeaderRegex_MultipleMatchingHeaders(t *testing.T)
|
||||
assert.Equal(t, http.StatusForbidden, w.Code, "Expected status code 403")
|
||||
assert.Contains(t, w.Body.String(), "Blocked by Multiple Matching Headers Regex", "Response body should contain 'Blocked by Multiple Matching Headers Regex'")
|
||||
|
||||
req2 := httptest.NewRequest("GET", "http://example.com", nil)
|
||||
req2 := httptest.NewRequest("GET", testURL, nil)
|
||||
req2.RemoteAddr = localIP
|
||||
req2.Header.Set("X-Custom-Header1", "good-value")
|
||||
req2.Header.Set("X-Custom-Header2", "bad-value") // One header has a "bad" value
|
||||
@@ -1603,7 +1631,7 @@ func TestBlockedRequestPhase1_HeaderRegex_MultipleMatchingHeaders(t *testing.T)
|
||||
assert.Equal(t, http.StatusForbidden, w2.Code, "Expected status code 403")
|
||||
assert.Contains(t, w2.Body.String(), "Blocked by Multiple Matching Headers Regex", "Response body should contain 'Blocked by Multiple Matching Headers Regex'")
|
||||
|
||||
req3 := httptest.NewRequest("GET", "http://example.com", nil)
|
||||
req3 := httptest.NewRequest("GET", testURL, nil)
|
||||
req3.RemoteAddr = localIP
|
||||
req3.Header.Set("X-Custom-Header1", "good-value")
|
||||
req3.Header.Set("X-Custom-Header2", "good-value") // None headers have a "bad" value
|
||||
@@ -1658,7 +1686,7 @@ func TestBlockedRequestPhase1_RateLimiting_MultiplePaths(t *testing.T) {
|
||||
Body: "Rate limit exceeded",
|
||||
},
|
||||
},
|
||||
ipBlacklist: trie.NewTrie(),
|
||||
ipBlacklist: iptrie.NewTrie(),
|
||||
dnsBlacklist: make(map[string]struct{}),
|
||||
}
|
||||
|
||||
@@ -1728,7 +1756,7 @@ func TestBlockedRequestPhase1_RateLimiting_DifferentIPs(t *testing.T) {
|
||||
Body: "Rate limit exceeded",
|
||||
},
|
||||
},
|
||||
ipBlacklist: trie.NewTrie(),
|
||||
ipBlacklist: iptrie.NewTrie(),
|
||||
dnsBlacklist: make(map[string]struct{}),
|
||||
}
|
||||
|
||||
@@ -1781,7 +1809,7 @@ func TestBlockedRequestPhase1_RateLimiting_MatchAllPaths(t *testing.T) {
|
||||
Body: "Rate limit exceeded",
|
||||
},
|
||||
},
|
||||
ipBlacklist: trie.NewTrie(),
|
||||
ipBlacklist: iptrie.NewTrie(),
|
||||
dnsBlacklist: make(map[string]struct{}),
|
||||
}
|
||||
|
||||
|
||||
@@ -7,7 +7,7 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
trie "github.com/phemmer/go-iptrie"
|
||||
"github.com/phemmer/go-iptrie"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
@@ -398,7 +398,7 @@ func TestBlockedRequestPhase1_RateLimiting(t *testing.T) {
|
||||
Body: "Rate limit exceeded",
|
||||
},
|
||||
},
|
||||
ipBlacklist: trie.NewTrie(), // Initialize ipBlacklist
|
||||
ipBlacklist: iptrie.NewTrie(), // Initialize ipBlacklist
|
||||
dnsBlacklist: make(map[string]struct{}), // Initialize dnsBlacklist
|
||||
}
|
||||
|
||||
|
||||
@@ -11,7 +11,7 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
trie "github.com/phemmer/go-iptrie"
|
||||
"github.com/phemmer/go-iptrie"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"go.uber.org/zap"
|
||||
"go.uber.org/zap/zapcore"
|
||||
@@ -34,7 +34,7 @@ func TestExtractValue(t *testing.T) {
|
||||
name: "Extract METHOD",
|
||||
target: "METHOD",
|
||||
setupRequest: func() (*http.Request, http.ResponseWriter) {
|
||||
req := httptest.NewRequest("POST", "http://example.com", nil)
|
||||
req := httptest.NewRequest("POST", testURL, nil)
|
||||
return req, httptest.NewRecorder()
|
||||
},
|
||||
expectedValue: "POST",
|
||||
@@ -54,7 +54,7 @@ func TestExtractValue(t *testing.T) {
|
||||
name: "Extract USER_AGENT",
|
||||
target: "USER_AGENT",
|
||||
setupRequest: func() (*http.Request, http.ResponseWriter) {
|
||||
req := httptest.NewRequest("GET", "http://example.com", nil)
|
||||
req := httptest.NewRequest("GET", testURL, nil)
|
||||
req.Header.Set("User-Agent", "test-agent")
|
||||
return req, httptest.NewRecorder()
|
||||
},
|
||||
@@ -65,7 +65,7 @@ func TestExtractValue(t *testing.T) {
|
||||
name: "Extract HEADERS prefix",
|
||||
target: "HEADERS:Content-Type",
|
||||
setupRequest: func() (*http.Request, http.ResponseWriter) {
|
||||
req := httptest.NewRequest("GET", "http://example.com", nil)
|
||||
req := httptest.NewRequest("GET", testURL, nil)
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
return req, httptest.NewRecorder()
|
||||
},
|
||||
@@ -86,7 +86,7 @@ func TestExtractValue(t *testing.T) {
|
||||
name: "Empty target",
|
||||
target: "",
|
||||
setupRequest: func() (*http.Request, http.ResponseWriter) {
|
||||
return httptest.NewRequest("GET", "http://example.com", nil), httptest.NewRecorder()
|
||||
return httptest.NewRequest("GET", testURL, nil), httptest.NewRecorder()
|
||||
},
|
||||
expectedError: true,
|
||||
},
|
||||
@@ -383,7 +383,7 @@ func TestProcessRuleMatch_HighScore(t *testing.T) {
|
||||
ResponseWritten: false,
|
||||
}
|
||||
|
||||
req := httptest.NewRequest("GET", "http://example.com", nil)
|
||||
req := httptest.NewRequest("GET", testURL, nil)
|
||||
|
||||
// Create a context and add logID to it - FIX: ADD CONTEXT HERE
|
||||
ctx := context.Background()
|
||||
@@ -445,7 +445,7 @@ func TestConcurrentRuleEvaluation(t *testing.T) {
|
||||
},
|
||||
},
|
||||
ruleCache: NewRuleCache(),
|
||||
ipBlacklist: trie.NewTrie(),
|
||||
ipBlacklist: iptrie.NewTrie(),
|
||||
dnsBlacklist: map[string]struct{}{},
|
||||
requestValueExtractor: NewRequestValueExtractor(logger, false),
|
||||
rateLimiter: func() *RateLimiter {
|
||||
@@ -475,7 +475,7 @@ func TestConcurrentRuleEvaluation(t *testing.T) {
|
||||
wg.Add(1)
|
||||
go func(i int) {
|
||||
defer wg.Done()
|
||||
req := httptest.NewRequest("GET", "http://example.com", nil)
|
||||
req := httptest.NewRequest("GET", testURL, nil)
|
||||
req.RemoteAddr = fmt.Sprintf("192.168.1.%d", i%256) // Simulate different IPs
|
||||
req.Header.Set("User-Agent", "test-agent") // Add a header for rule evaluation
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
4
types.go
4
types.go
@@ -6,7 +6,7 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/oschwald/maxminddb-golang"
|
||||
trie "github.com/phemmer/go-iptrie"
|
||||
"github.com/phemmer/go-iptrie"
|
||||
"go.uber.org/zap"
|
||||
"go.uber.org/zap/zapcore"
|
||||
|
||||
@@ -107,7 +107,7 @@ type Middleware struct {
|
||||
CountryBlock CountryAccessFilter `json:"country_block"`
|
||||
CountryWhitelist CountryAccessFilter `json:"country_whitelist"`
|
||||
Rules map[int][]Rule `json:"-"`
|
||||
ipBlacklist *trie.Trie `json:"-"`
|
||||
ipBlacklist *iptrie.Trie `json:"-"`
|
||||
dnsBlacklist map[string]struct{} `json:"-"` // Changed to map[string]struct{}
|
||||
logger *zap.Logger
|
||||
LogSeverity string `json:"log_severity,omitempty"`
|
||||
|
||||
Reference in New Issue
Block a user