Merge pull request #65 from drev74/test/blocking

feat: impl country whitelisting
This commit is contained in:
fab
2025-10-12 19:33:38 +02:00
committed by GitHub
11 changed files with 287 additions and 173 deletions

View File

@@ -27,7 +27,7 @@ import (
"github.com/fsnotify/fsnotify"
"github.com/oschwald/maxminddb-golang"
trie "github.com/phemmer/go-iptrie"
"github.com/phemmer/go-iptrie"
"go.uber.org/zap"
"go.uber.org/zap/zapcore"
@@ -193,23 +193,23 @@ func (m *Middleware) Provision(ctx caddy.Context) error {
// Initialize GeoIP stats
m.geoIPStats = make(map[string]int64)
// Configure GeoIP-based country blocking/whitelisting
if m.CountryBlock.Enabled || m.CountryWhitelist.Enabled {
geoIPPath := m.CountryBlock.GeoIPDBPath
// Configure GeoIP-based country blacklisting/whitelisting
if m.CountryBlacklist.Enabled || m.CountryWhitelist.Enabled {
geoIPPath := m.CountryBlacklist.GeoIPDBPath
if m.CountryWhitelist.Enabled && m.CountryWhitelist.GeoIPDBPath != "" {
geoIPPath = m.CountryWhitelist.GeoIPDBPath
}
if !fileExists(geoIPPath) {
m.logger.Warn("GeoIP database not found. Country blocking/whitelisting will be disabled", zap.String("path", geoIPPath))
m.logger.Warn("GeoIP database not found. Country blacklisting/whitelisting will be disabled", zap.String("path", geoIPPath))
} else {
reader, err := maxminddb.Open(geoIPPath)
if err != nil {
m.logger.Error("Failed to load GeoIP database", zap.String("path", geoIPPath), zap.Error(err))
} else {
m.logger.Info("GeoIP database loaded successfully", zap.String("path", geoIPPath))
if m.CountryBlock.Enabled {
m.CountryBlock.geoIP = reader
if m.CountryBlacklist.Enabled {
m.CountryBlacklist.geoIP = reader
}
if m.CountryWhitelist.Enabled {
m.CountryWhitelist.geoIP = reader
@@ -237,7 +237,7 @@ func (m *Middleware) Provision(ctx caddy.Context) error {
// Load IP blacklist
if m.IPBlacklistFile != "" {
m.ipBlacklist = trie.NewTrie()
m.ipBlacklist = iptrie.NewTrie()
err = m.loadIPBlacklist(m.IPBlacklistFile, *m.ipBlacklist)
if err != nil {
return fmt.Errorf("failed to load IP blacklist: %w", err)
@@ -288,20 +288,20 @@ func (m *Middleware) Shutdown(ctx context.Context) error {
var errorOccurred bool
// Close GeoIP databases
if m.CountryBlock.geoIP != nil {
m.logger.Debug("Closing country block GeoIP database...")
if err := m.CountryBlock.geoIP.Close(); err != nil {
m.logger.Error("Error encountered while closing country block GeoIP database", zap.Error(err))
if m.CountryBlacklist.geoIP != nil {
m.logger.Debug("Closing country blacklist GeoIP database...")
if err := m.CountryBlacklist.geoIP.Close(); err != nil {
m.logger.Error("Error encountered while closing country blacklist GeoIP database", zap.Error(err))
if !errorOccurred {
firstError = fmt.Errorf("error closing country block GeoIP: %w", err)
firstError = fmt.Errorf("error closing country blacklist GeoIP: %w", err)
errorOccurred = true
}
} else {
m.logger.Debug("Country block GeoIP database closed successfully.")
m.logger.Debug("Country blacklist GeoIP database closed successfully.")
}
m.CountryBlock.geoIP = nil
m.CountryBlacklist.geoIP = nil
} else {
m.logger.Debug("Country block GeoIP database was not open, skipping close.")
m.logger.Debug("Country blacklist GeoIP database was not open, skipping close.")
}
if m.CountryWhitelist.geoIP != nil {
@@ -426,7 +426,7 @@ func (m *Middleware) ReloadConfig() error {
m.logger.Info("Reloading WAF configuration")
if m.IPBlacklistFile != "" {
newIPBlacklist := trie.NewTrie()
newIPBlacklist := iptrie.NewTrie()
if err := m.loadIPBlacklist(m.IPBlacklistFile, *newIPBlacklist); err != nil {
m.logger.Error("Failed to reload IP blacklist", zap.String("file", m.IPBlacklistFile), zap.Error(err))
return fmt.Errorf("failed to reload IP blacklist: %v", err)
@@ -452,7 +452,7 @@ func (m *Middleware) ReloadConfig() error {
return nil
}
func (m *Middleware) loadIPBlacklist(path string, blacklistMap trie.Trie) error {
func (m *Middleware) loadIPBlacklist(path string, blacklistMap iptrie.Trie) error {
if _, err := os.Stat(path); os.IsNotExist(err) {
m.logger.Warn("Skipping IP blacklist load, file does not exist", zap.String("file", path))
return nil

View File

@@ -32,7 +32,7 @@ func TestMiddleware_Provision(t *testing.T) {
IPBlacklistFile: "testdata/ip_blacklist.txt",
DNSBlacklistFile: "testdata/dns_blacklist.txt",
AnomalyThreshold: 10,
CountryBlock: CountryAccessFilter{
CountryBlacklist: CountryAccessFilter{
Enabled: true,
CountryList: []string{"US"},
GeoIPDBPath: "testdata/GeoIP2-Country-Test.mmdb",

View File

@@ -1,8 +1,21 @@
package caddywaf
import "net/http"
const (
geoIPdata = "GeoLite2-Country.mmdb"
googleUSIP = "74.125.131.105"
localIP = "127.0.0.1"
aliCNIP = "47.88.198.38"
googleUSIP = "74.125.131.105"
googleBRIP = "128.201.228.12"
googleRUIP = "74.125.131.94"
testURL = "http://example.com"
torListURL = "https://cdn.nws.neurodyne.pro/nws-cdn-ut8hw561/waf/torbulkexitlist" // custom TOR list URL for testing
)
var customResponse = map[int]CustomBlockResponse{
403: {
StatusCode: http.StatusForbidden,
Body: "Access Denied",
},
}

View File

@@ -140,7 +140,7 @@ func (cl *ConfigLoader) UnmarshalCaddyfile(d *caddyfile.Dispenser, m *Middleware
m.LogSeverity = "info"
m.LogJSON = false
m.AnomalyThreshold = 5
m.CountryBlock.Enabled = false
m.CountryBlacklist.Enabled = false
m.CountryWhitelist.Enabled = false
m.LogFilePath = "debug.json"
m.RedactSensitiveData = false
@@ -269,7 +269,7 @@ func (cl *ConfigLoader) parseCustomResponse(d *caddyfile.Dispenser, m *Middlewar
// parseCountryBlockDirective returns a closure to handle block_countries and whitelist_countries directives.
func (cl *ConfigLoader) parseCountryBlockDirective(isBlock bool) func(d *caddyfile.Dispenser, m *Middleware) error {
return func(d *caddyfile.Dispenser, m *Middleware) error {
target := &m.CountryBlock
target := &m.CountryBlacklist
directiveName := "block_countries"
if !isBlock {
target = &m.CountryWhitelist

View File

@@ -201,14 +201,14 @@ func TestParseCountryBlock(t *testing.T) {
t.Fatalf("parseCountryBlockDirective failed: %v", err)
}
if !m.CountryBlock.Enabled {
t.Errorf("Expected country block to be enabled, got %v", m.CountryBlock.Enabled)
if !m.CountryBlacklist.Enabled {
t.Errorf("Expected country blacklist to be enabled, got %v", m.CountryBlacklist.Enabled)
}
if m.CountryBlock.GeoIPDBPath != "/etc/geoip/GeoIP.dat" {
t.Errorf("Expected GeoIP DB path to be '/etc/geoip/GeoIP.dat', got '%s'", m.CountryBlock.GeoIPDBPath)
if m.CountryBlacklist.GeoIPDBPath != "/etc/geoip/GeoIP.dat" {
t.Errorf("Expected GeoIP DB path to be '/etc/geoip/GeoIP.dat', got '%s'", m.CountryBlacklist.GeoIPDBPath)
}
if len(m.CountryBlock.CountryList) != 2 || m.CountryBlock.CountryList[0] != "US" || m.CountryBlock.CountryList[1] != "CA" {
t.Errorf("Expected country list to be ['US', 'CA'], got %v", m.CountryBlock.CountryList)
if len(m.CountryBlacklist.CountryList) != 2 || m.CountryBlacklist.CountryList[0] != "US" || m.CountryBlacklist.CountryList[1] != "CA" {
t.Errorf("Expected country list to be ['US', 'CA'], got %v", m.CountryBlacklist.CountryList)
}
}

View File

@@ -230,18 +230,18 @@ func (m *Middleware) handlePhase(w http.ResponseWriter, r *http.Request, phase i
zap.String("user_agent", r.UserAgent()),
)
if phase == 1 && m.CountryBlock.Enabled {
m.logger.Debug("Starting country blocking phase")
blocked, err := m.isCountryInList(r.RemoteAddr, m.CountryBlock.CountryList, m.CountryBlock.geoIP)
if phase == 1 && m.CountryBlacklist.Enabled {
m.logger.Debug("Starting country blacklisting phase")
blocked, err := m.isCountryInList(r.RemoteAddr, m.CountryBlacklist.CountryList, m.CountryBlacklist.geoIP)
if err != nil {
m.logRequest(zapcore.ErrorLevel, "Failed to check country block",
m.logRequest(zapcore.ErrorLevel, "Failed to check country blacklisting",
r,
zap.Error(err),
)
m.blockRequest(w, r, state, http.StatusForbidden, "internal_error", "country_block_rule", r.RemoteAddr,
zap.String("message", "Request blocked due to internal error"),
)
m.logger.Debug("Country blocking phase completed - blocked due to error")
m.logger.Debug("Country blacklisting phase completed - blocked due to error")
m.incrementGeoIPRequestsMetric(false) // Increment with false for error
return
} else if blocked {
@@ -254,7 +254,34 @@ func (m *Middleware) handlePhase(w http.ResponseWriter, r *http.Request, phase i
}
return
}
m.logger.Debug("Country blocking phase completed - not blocked")
m.logger.Debug("Country blacklisting phase completed - not blocked")
m.incrementGeoIPRequestsMetric(false) // Increment with false for no block
}
if phase == 1 && m.CountryWhitelist.Enabled {
m.logger.Debug("Starting country whitelisting phase")
allowed, err := m.isCountryInList(r.RemoteAddr, m.CountryWhitelist.CountryList, m.CountryWhitelist.geoIP)
if err != nil {
m.logRequest(zapcore.ErrorLevel, "Failed to check country whitelist",
r,
zap.Error(err),
)
m.blockRequest(w, r, state, http.StatusForbidden, "internal_error", "country_block_rule", r.RemoteAddr,
zap.String("message", "Request blocked due to internal error"),
)
m.logger.Debug("Country whitelisting phase completed - blocked due to error")
m.incrementGeoIPRequestsMetric(false) // Increment with false for error
return
} else if !allowed {
m.blockRequest(w, r, state, http.StatusForbidden, "country_block", "country_block_rule", r.RemoteAddr,
zap.String("message", "Request blocked by country"))
m.incrementGeoIPRequestsMetric(true) // Increment with true for blocked
if m.CustomResponses != nil {
m.writeCustomResponse(w, state.StatusCode)
}
return
}
m.logger.Debug("Country whitelisting phase completed - not blocked")
m.incrementGeoIPRequestsMetric(false) // Increment with false for no block
}
@@ -327,7 +354,8 @@ func (m *Middleware) handlePhase(w http.ResponseWriter, r *http.Request, phase i
rules, ok := m.Rules[phase]
if !ok {
m.logger.Debug("No rules found for phase", zap.Int("phase", phase))
return
// Don't block on empty rules. There may be no rules specified
// return
}
m.logger.Debug("Starting rule evaluation for phase", zap.Int("phase", phase), zap.Int("rule_count", len(rules)))
@@ -434,6 +462,8 @@ func (m *Middleware) handlePhase(w http.ResponseWriter, r *http.Request, phase i
zap.Int("total_score", state.TotalScore),
zap.Int("anomaly_threshold", m.AnomalyThreshold),
)
m.allowRequest(state)
}
// incrementRateLimiterBlockedRequestsMetric increments the blocked requests metric for the rate limiter.

View File

@@ -12,7 +12,7 @@ import (
"testing"
"time"
trie "github.com/phemmer/go-iptrie"
"github.com/phemmer/go-iptrie"
"github.com/stretchr/testify/assert"
"go.uber.org/zap"
@@ -26,32 +26,37 @@ func TestBlockedRequestPhase1_DNSBlacklist(t *testing.T) {
dnsBlacklist: map[string]struct{}{
"malicious.domain": {},
},
ipBlacklist: trie.NewTrie(),
CustomResponses: map[int]CustomBlockResponse{
403: {
StatusCode: http.StatusForbidden,
Body: "Access Denied",
},
},
ipBlacklist: iptrie.NewTrie(),
CustomResponses: customResponse,
}
// Simulate a request to a blacklisted domain
req := httptest.NewRequest("GET", "http://malicious.domain", nil)
req.RemoteAddr = localIP
w := httptest.NewRecorder()
state := &WAFState{}
// Process the request in Phase 1
middleware.handlePhase(w, req, 1, state)
t.Run("Allow unblocked domain", func(t *testing.T) {
// Simulate a request to a blacklisted domain
req := httptest.NewRequest("GET", testURL, nil)
req.RemoteAddr = localIP
// Debug: Print the response body and status code
t.Logf("Response Body: %s", w.Body.String())
t.Logf("Response Status Code: %d", w.Code)
// Process the request in Phase 1
middleware.handlePhase(w, req, 1, state)
assert.False(t, state.Blocked, "Request should be allowed")
assert.Equal(t, http.StatusOK, w.Code, "Expected status code 200")
})
// Verify that the request was blocked
assert.True(t, state.Blocked, "Request should be blocked")
assert.Equal(t, http.StatusForbidden, w.Code, "Expected status code 403")
assert.Contains(t, w.Body.String(), "Access Denied", "Response body should contain 'Access Denied'")
t.Run("Block blacklisted domain", func(t *testing.T) {
// Simulate a request to a blacklisted domain
req := httptest.NewRequest("GET", "http://malicious.domain", nil)
req.RemoteAddr = localIP
// Process the request in Phase 1
middleware.handlePhase(w, req, 1, state)
// Verify that the request was blocked
assert.True(t, state.Blocked, "Request should be blocked")
assert.Equal(t, http.StatusForbidden, w.Code, "Expected status code 403")
assert.Contains(t, w.Body.String(), "Access Denied", "Response body should contain 'Access Denied'")
})
}
func TestBlockedRequestPhase1_GeoIPBlocking(t *testing.T) {
@@ -62,68 +67,135 @@ func TestBlockedRequestPhase1_GeoIPBlocking(t *testing.T) {
geoIPBlock, err := geoIPHandler.LoadGeoIPDatabase(geoIPdata)
assert.NoError(t, err)
middleware := &Middleware{
blackListMiddleware := &Middleware{
logger: logger,
ipBlacklist: iptrie.NewTrie(),
geoIPHandler: geoIPHandler,
CountryBlock: CountryAccessFilter{
CountryBlacklist: CountryAccessFilter{
Enabled: true,
CountryList: []string{"US"},
CountryList: []string{"US", "RU"},
GeoIPDBPath: geoIPdata, // Path to a test GeoIP database
geoIP: geoIPBlock,
},
CustomResponses: map[int]CustomBlockResponse{
403: {
StatusCode: http.StatusForbidden,
Body: "Access Denied",
},
},
CustomResponses: customResponse,
}
// Simulate a request from a blocked country (US)
req := httptest.NewRequest("GET", "http://example.com", nil)
req.RemoteAddr = googleUSIP
w := httptest.NewRecorder()
whiteListMiddleware := &Middleware{
logger: logger,
ipBlacklist: iptrie.NewTrie(),
geoIPHandler: geoIPHandler,
CountryWhitelist: CountryAccessFilter{
Enabled: true,
CountryList: []string{"BR"},
GeoIPDBPath: geoIPdata, // Path to a test GeoIP database
geoIP: geoIPBlock,
},
CustomResponses: customResponse,
}
req := httptest.NewRequest("GET", testURL, nil)
state := &WAFState{}
// Process the request in Phase 1
middleware.handlePhase(w, req, 1, state)
t.Run("GeoIP Blacklist: Allow CN IP", func(t *testing.T) {
w := httptest.NewRecorder()
req.RemoteAddr = aliCNIP
// Verify that the request was blocked
assert.True(t, state.Blocked, "Request should be blocked")
assert.Equal(t, http.StatusForbidden, w.Code, "Expected status code 403")
assert.Contains(t, w.Body.String(), "Access Denied", "Response body should contain 'Access Denied'")
// Process the request in Phase 1
blackListMiddleware.handlePhase(w, req, 1, state)
assert.False(t, state.Blocked, "Request should be allowed")
assert.Equal(t, http.StatusOK, w.Code, "Expected status code 200")
})
t.Run("GeoIP Blacklist: Block US IP", func(t *testing.T) {
w := httptest.NewRecorder()
req.RemoteAddr = googleUSIP
// Process the request in Phase 1
blackListMiddleware.handlePhase(w, req, 1, state)
// Verify that the request was blocked
assert.True(t, state.Blocked, "Request should be blocked")
assert.Equal(t, http.StatusForbidden, w.Code, "Expected status code 403")
assert.Contains(t, w.Body.String(), "Access Denied", "Response body should contain 'Access Denied'")
})
t.Run("GeoIP Whitelist: Allow BR IP", func(t *testing.T) {
w := httptest.NewRecorder()
req.RemoteAddr = googleBRIP
// Process the request in Phase 1
whiteListMiddleware.handlePhase(w, req, 1, state)
assert.False(t, state.Blocked, "Request should be allowed")
assert.Equal(t, http.StatusOK, w.Code, "Expected status code 200")
})
t.Run("GeoIP Whitelist: Block RU IP", func(t *testing.T) {
w := httptest.NewRecorder()
req.RemoteAddr = googleRUIP
// Process the request in Phase 1
whiteListMiddleware.handlePhase(w, req, 1, state)
// Verify that the request was blocked
assert.True(t, state.Blocked, "Request should be blocked")
assert.Equal(t, http.StatusForbidden, w.Code, "Expected status code 403")
assert.Contains(t, w.Body.String(), "Access Denied", "Response body should contain 'Access Denied'")
})
}
func TestBlockedRequestPhase1_IPBlocking(t *testing.T) {
logger, err := zap.NewDevelopment()
assert.NoError(t, err)
ipBlackList := trie.NewTrie()
ipBlackList.Insert(netip.MustParsePrefix("127.0.0.1/24"), nil)
blackList := iptrie.NewTrie()
loader := iptrie.NewTrieLoader(blackList)
middleware := &Middleware{
logger: logger,
ipBlacklist: ipBlackList,
CustomResponses: map[int]CustomBlockResponse{
403: {
StatusCode: http.StatusForbidden,
Body: "Access Denied",
},
},
for _, net := range []string{
"192.168.0.0/24",
"192.168.1.1/32",
} {
loader.Insert(netip.MustParsePrefix(net), "net="+net)
}
req := httptest.NewRequest("GET", "http://example.com", nil)
req.RemoteAddr = localIP
w := httptest.NewRecorder()
state := &WAFState{}
w := httptest.NewRecorder()
// Process the request in Phase 1
middleware.handlePhase(w, req, 1, state)
t.Run("Allow unblocked CIDR", func(t *testing.T) {
middleware := &Middleware{
logger: logger,
ipBlacklist: blackList,
CustomResponses: customResponse,
}
// Verify that the request was blocked
assert.True(t, state.Blocked, "Request should be blocked")
assert.Equal(t, http.StatusForbidden, w.Code, "Expected status code 403")
assert.Contains(t, w.Body.String(), "Access Denied", "Response body should contain 'Access Denied'")
req := httptest.NewRequest("GET", testURL, nil)
req.RemoteAddr = localIP
// Process the request in Phase 1
middleware.handlePhase(w, req, 1, state)
assert.False(t, state.Blocked, "Request should be allowed")
assert.Equal(t, http.StatusOK, w.Code, "Expected status code 200")
})
t.Run("Blocks blacklisted CIDR", func(t *testing.T) {
middleware := &Middleware{
logger: logger,
ipBlacklist: blackList,
CustomResponses: customResponse,
}
req := httptest.NewRequest("GET", testURL, nil)
req.RemoteAddr = "192.168.1.1"
// Process the request in Phase 1
middleware.handlePhase(w, req, 1, state)
// Verify that the request was blocked
assert.True(t, state.Blocked, "Request should be blocked")
assert.Equal(t, http.StatusForbidden, w.Code, "Expected status code 403")
assert.Contains(t, w.Body.String(), "Access Denied", "Response body should contain 'Access Denied'")
})
}
func TestHandlePhase_Phase2_NiktoUserAgent(t *testing.T) {
@@ -144,18 +216,13 @@ func TestHandlePhase_Phase2_NiktoUserAgent(t *testing.T) {
},
},
ruleCache: NewRuleCache(),
ipBlacklist: trie.NewTrie(),
ipBlacklist: iptrie.NewTrie(),
dnsBlacklist: map[string]struct{}{},
requestValueExtractor: NewRequestValueExtractor(logger, false),
CustomResponses: map[int]CustomBlockResponse{
403: {
StatusCode: http.StatusForbidden,
Body: "Access Denied",
},
},
CustomResponses: customResponse,
}
req := httptest.NewRequest("GET", "http://example.com", nil)
req := httptest.NewRequest("GET", testURL, nil)
req.Header.Set("User-Agent", "nikto")
// Create a context and add logID to it - FIX: ADD CONTEXT HERE
@@ -204,12 +271,12 @@ func TestBlockedRequestPhase1_HeaderRegex(t *testing.T) {
},
},
ruleCache: NewRuleCache(),
ipBlacklist: trie.NewTrie(),
ipBlacklist: iptrie.NewTrie(),
dnsBlacklist: map[string]struct{}{},
requestValueExtractor: NewRequestValueExtractor(logger, false),
}
req := httptest.NewRequest("GET", "http://example.com", nil)
req := httptest.NewRequest("GET", testURL, nil)
req.RemoteAddr = localIP
req.Header.Set("X-Custom-Header", "this-is-a-bad-header") // Simulate a request with bad header
@@ -257,12 +324,12 @@ func TestBlockedRequestPhase1_HeaderRegex_SpecificValue(t *testing.T) {
},
},
ruleCache: NewRuleCache(),
ipBlacklist: trie.NewTrie(),
ipBlacklist: iptrie.NewTrie(),
dnsBlacklist: map[string]struct{}{},
requestValueExtractor: NewRequestValueExtractor(logger, false),
}
req := httptest.NewRequest("GET", "http://example.com", nil)
req := httptest.NewRequest("GET", testURL, nil)
req.RemoteAddr = localIP
req.Header.Set("X-Specific-Header", "specific-value") // Simulate a request with the specific header
@@ -310,12 +377,12 @@ func TestBlockedRequestPhase1_HeaderRegex_CommaSeparatedTargets(t *testing.T) {
},
},
ruleCache: NewRuleCache(),
ipBlacklist: trie.NewTrie(),
ipBlacklist: iptrie.NewTrie(),
dnsBlacklist: map[string]struct{}{},
requestValueExtractor: NewRequestValueExtractor(logger, false),
}
req := httptest.NewRequest("GET", "http://example.com", nil)
req := httptest.NewRequest("GET", testURL, nil)
req.RemoteAddr = localIP
req.Header.Set("X-Custom-Header1", "good-value")
req.Header.Set("X-Custom-Header2", "bad-value") // Simulate a request with bad value in one of the headers
@@ -364,7 +431,7 @@ func TestBlockedRequestPhase1_CombinedConditions(t *testing.T) {
},
},
ruleCache: NewRuleCache(),
ipBlacklist: trie.NewTrie(),
ipBlacklist: iptrie.NewTrie(),
dnsBlacklist: map[string]struct{}{},
requestValueExtractor: NewRequestValueExtractor(logger, false),
}
@@ -417,12 +484,12 @@ func TestBlockedRequestPhase1_NoMatch(t *testing.T) {
},
},
ruleCache: NewRuleCache(),
ipBlacklist: trie.NewTrie(),
ipBlacklist: iptrie.NewTrie(),
dnsBlacklist: map[string]struct{}{},
requestValueExtractor: NewRequestValueExtractor(logger, false),
}
req := httptest.NewRequest("GET", "http://example.com", nil)
req := httptest.NewRequest("GET", testURL, nil)
req.RemoteAddr = localIP
req.Header.Set("User-Agent", "good-user")
@@ -470,12 +537,12 @@ func TestBlockedRequestPhase1_HeaderRegex_EmptyHeader(t *testing.T) {
},
},
ruleCache: NewRuleCache(),
ipBlacklist: trie.NewTrie(),
ipBlacklist: iptrie.NewTrie(),
dnsBlacklist: map[string]struct{}{},
requestValueExtractor: NewRequestValueExtractor(logger, false),
}
req := httptest.NewRequest("GET", "http://example.com", nil)
req := httptest.NewRequest("GET", testURL, nil)
req.RemoteAddr = localIP
// Create a context and add logID to it
@@ -522,12 +589,12 @@ func TestBlockedRequestPhase1_HeaderRegex_MissingHeader(t *testing.T) {
},
},
ruleCache: NewRuleCache(),
ipBlacklist: trie.NewTrie(),
ipBlacklist: iptrie.NewTrie(),
dnsBlacklist: map[string]struct{}{},
requestValueExtractor: NewRequestValueExtractor(logger, false),
}
req := httptest.NewRequest("GET", "http://example.com", nil) // Header not set
req := httptest.NewRequest("GET", testURL, nil) // Header not set
req.RemoteAddr = localIP
// Create a context and add logID to it
@@ -574,12 +641,12 @@ func TestBlockedRequestPhase1_HeaderRegex_ComplexPattern(t *testing.T) {
},
},
ruleCache: NewRuleCache(),
ipBlacklist: trie.NewTrie(),
ipBlacklist: iptrie.NewTrie(),
dnsBlacklist: map[string]struct{}{},
requestValueExtractor: NewRequestValueExtractor(logger, false),
}
req := httptest.NewRequest("GET", "http://example.com", nil)
req := httptest.NewRequest("GET", testURL, nil)
req.RemoteAddr = localIP
req.Header.Set("X-Email-Header", "test@example.com") // Simulate a request with a valid email
@@ -627,12 +694,12 @@ func TestBlockedRequestPhase1_MultiTargetMatch(t *testing.T) {
},
},
ruleCache: NewRuleCache(),
ipBlacklist: trie.NewTrie(),
ipBlacklist: iptrie.NewTrie(),
dnsBlacklist: map[string]struct{}{},
requestValueExtractor: NewRequestValueExtractor(logger, false),
}
req := httptest.NewRequest("GET", "http://example.com", nil)
req := httptest.NewRequest("GET", testURL, nil)
req.RemoteAddr = localIP
req.Header.Set("X-Custom-Header", "good-header")
req.Header.Set("User-Agent", "bad-user-agent")
@@ -680,12 +747,12 @@ func TestBlockedRequestPhase1_MultiTargetNoMatch(t *testing.T) {
},
},
ruleCache: NewRuleCache(),
ipBlacklist: trie.NewTrie(),
ipBlacklist: iptrie.NewTrie(),
dnsBlacklist: map[string]struct{}{},
requestValueExtractor: NewRequestValueExtractor(logger, false),
}
req := httptest.NewRequest("GET", "http://example.com", nil)
req := httptest.NewRequest("GET", testURL, nil)
req.RemoteAddr = localIP
req.Header.Set("X-Custom-Header", "good-header")
req.Header.Set("User-Agent", "good-user-agent")
@@ -734,7 +801,7 @@ func TestBlockedRequestPhase1_URLParameterRegex_NoMatch(t *testing.T) {
},
},
ruleCache: NewRuleCache(),
ipBlacklist: trie.NewTrie(),
ipBlacklist: iptrie.NewTrie(),
dnsBlacklist: map[string]struct{}{},
requestValueExtractor: NewRequestValueExtractor(logger, false),
}
@@ -794,7 +861,7 @@ func TestBlockedRequestPhase1_MultipleRules(t *testing.T) {
},
},
ruleCache: NewRuleCache(),
ipBlacklist: trie.NewTrie(),
ipBlacklist: iptrie.NewTrie(),
dnsBlacklist: map[string]struct{}{},
requestValueExtractor: NewRequestValueExtractor(logger, false),
}
@@ -870,12 +937,12 @@ func TestBlockedRequestPhase2_BodyRegex(t *testing.T) {
},
},
ruleCache: NewRuleCache(),
ipBlacklist: trie.NewTrie(),
ipBlacklist: iptrie.NewTrie(),
dnsBlacklist: map[string]struct{}{},
requestValueExtractor: NewRequestValueExtractor(logger, false),
}
req := httptest.NewRequest("POST", "http://example.com",
req := httptest.NewRequest("POST", testURL,
func() *bytes.Buffer {
b := new(bytes.Buffer)
b.WriteString("this-is-a-bad-body")
@@ -929,12 +996,12 @@ func TestBlockedRequestPhase2_BodyRegex_JSON(t *testing.T) {
},
},
ruleCache: NewRuleCache(),
ipBlacklist: trie.NewTrie(),
ipBlacklist: iptrie.NewTrie(),
dnsBlacklist: map[string]struct{}{},
requestValueExtractor: NewRequestValueExtractor(logger, false),
}
req := httptest.NewRequest("POST", "http://example.com",
req := httptest.NewRequest("POST", testURL,
func() *bytes.Buffer {
b := new(bytes.Buffer)
b.WriteString(`{"data":{"malicious":true,"name":"test"}}`)
@@ -988,12 +1055,12 @@ func TestBlockedRequestPhase2_BodyRegex_FormURLEncoded(t *testing.T) {
},
},
ruleCache: NewRuleCache(),
ipBlacklist: trie.NewTrie(),
ipBlacklist: iptrie.NewTrie(),
dnsBlacklist: map[string]struct{}{},
requestValueExtractor: NewRequestValueExtractor(logger, false),
}
req := httptest.NewRequest("POST", "http://example.com",
req := httptest.NewRequest("POST", testURL,
strings.NewReader("param1=value1&secret=badvalue¶m2=value2"),
)
req.RemoteAddr = localIP
@@ -1043,12 +1110,12 @@ func TestBlockedRequestPhase2_BodyRegex_SpecificPattern(t *testing.T) {
},
},
ruleCache: NewRuleCache(),
ipBlacklist: trie.NewTrie(),
ipBlacklist: iptrie.NewTrie(),
dnsBlacklist: map[string]struct{}{},
requestValueExtractor: NewRequestValueExtractor(logger, false),
}
req := httptest.NewRequest("POST", "http://example.com",
req := httptest.NewRequest("POST", testURL,
func() *bytes.Buffer {
b := new(bytes.Buffer)
b.WriteString("User ID: 123-45-6789")
@@ -1102,12 +1169,12 @@ func TestBlockedRequestPhase2_BodyRegex_NoMatch(t *testing.T) {
},
},
ruleCache: NewRuleCache(),
ipBlacklist: trie.NewTrie(),
ipBlacklist: iptrie.NewTrie(),
dnsBlacklist: map[string]struct{}{},
requestValueExtractor: NewRequestValueExtractor(logger, false),
}
req := httptest.NewRequest("POST", "http://example.com",
req := httptest.NewRequest("POST", testURL,
func() *bytes.Buffer {
b := new(bytes.Buffer)
b.WriteString("this-is-a-good-body")
@@ -1161,7 +1228,7 @@ func TestBlockedRequestPhase2_BodyRegex_NoMatch_MultipartForm(t *testing.T) {
},
},
ruleCache: NewRuleCache(),
ipBlacklist: trie.NewTrie(),
ipBlacklist: iptrie.NewTrie(),
dnsBlacklist: map[string]struct{}{},
requestValueExtractor: NewRequestValueExtractor(logger, false),
}
@@ -1181,7 +1248,7 @@ func TestBlockedRequestPhase2_BodyRegex_NoMatch_MultipartForm(t *testing.T) {
t.Fatalf("Failed to close multipart writer: %v", err)
}
req := httptest.NewRequest("POST", "http://example.com", body)
req := httptest.NewRequest("POST", testURL, body)
req.RemoteAddr = localIP
req.Header.Set("Content-Type", writer.FormDataContentType())
@@ -1229,12 +1296,12 @@ func TestBlockedRequestPhase2_BodyRegex_NoBody(t *testing.T) {
},
},
ruleCache: NewRuleCache(),
ipBlacklist: trie.NewTrie(),
ipBlacklist: iptrie.NewTrie(),
dnsBlacklist: map[string]struct{}{},
requestValueExtractor: NewRequestValueExtractor(logger, false),
}
req := httptest.NewRequest("POST", "http://example.com", nil)
req := httptest.NewRequest("POST", testURL, nil)
req.RemoteAddr = localIP
w := httptest.NewRecorder()
state := &WAFState{}
@@ -1275,7 +1342,7 @@ func TestBlockedRequestPhase3_ResponseHeaderRegex_NoMatch(t *testing.T) {
},
},
ruleCache: NewRuleCache(),
ipBlacklist: trie.NewTrie(),
ipBlacklist: iptrie.NewTrie(),
dnsBlacklist: map[string]struct{}{},
requestValueExtractor: NewRequestValueExtractor(logger, false),
}
@@ -1288,7 +1355,7 @@ func TestBlockedRequestPhase3_ResponseHeaderRegex_NoMatch(t *testing.T) {
})
}()
req := httptest.NewRequest("GET", "http://example.com", nil)
req := httptest.NewRequest("GET", testURL, nil)
req.RemoteAddr = localIP
w := httptest.NewRecorder()
state := &WAFState{}
@@ -1331,7 +1398,7 @@ func TestBlockedRequestPhase4_ResponseBodyRegex_EmptyBody(t *testing.T) {
},
},
ruleCache: NewRuleCache(),
ipBlacklist: trie.NewTrie(),
ipBlacklist: iptrie.NewTrie(),
dnsBlacklist: map[string]struct{}{},
requestValueExtractor: NewRequestValueExtractor(logger, false),
}
@@ -1343,7 +1410,7 @@ func TestBlockedRequestPhase4_ResponseBodyRegex_EmptyBody(t *testing.T) {
})
}()
req := httptest.NewRequest("GET", "http://example.com", nil)
req := httptest.NewRequest("GET", testURL, nil)
req.RemoteAddr = localIP
w := httptest.NewRecorder()
state := &WAFState{}
@@ -1387,7 +1454,7 @@ func TestBlockedRequestPhase4_ResponseBodyRegex_NoBody(t *testing.T) {
},
},
ruleCache: NewRuleCache(),
ipBlacklist: trie.NewTrie(),
ipBlacklist: iptrie.NewTrie(),
dnsBlacklist: map[string]struct{}{},
requestValueExtractor: NewRequestValueExtractor(logger, false),
}
@@ -1399,7 +1466,7 @@ func TestBlockedRequestPhase4_ResponseBodyRegex_NoBody(t *testing.T) {
})
}()
req := httptest.NewRequest("GET", "http://example.com", nil)
req := httptest.NewRequest("GET", testURL, nil)
req.RemoteAddr = localIP
w := httptest.NewRecorder()
state := &WAFState{}
@@ -1441,7 +1508,7 @@ func TestBlockedRequestPhase3_ResponseHeaderRegex_NoSetCookie(t *testing.T) {
},
},
ruleCache: NewRuleCache(),
ipBlacklist: trie.NewTrie(),
ipBlacklist: iptrie.NewTrie(),
dnsBlacklist: map[string]struct{}{},
requestValueExtractor: NewRequestValueExtractor(logger, false),
}
@@ -1453,7 +1520,7 @@ func TestBlockedRequestPhase3_ResponseHeaderRegex_NoSetCookie(t *testing.T) {
})
}()
req := httptest.NewRequest("GET", "http://example.com", nil)
req := httptest.NewRequest("GET", testURL, nil)
req.RemoteAddr = localIP
w := httptest.NewRecorder()
state := &WAFState{}
@@ -1497,12 +1564,12 @@ func TestBlockedRequestPhase1_HeaderRegex_CaseInsensitive(t *testing.T) {
},
},
ruleCache: NewRuleCache(),
ipBlacklist: trie.NewTrie(),
ipBlacklist: iptrie.NewTrie(),
dnsBlacklist: map[string]struct{}{},
requestValueExtractor: NewRequestValueExtractor(logger, false),
}
req := httptest.NewRequest("GET", "http://example.com", nil)
req := httptest.NewRequest("GET", testURL, nil)
req.RemoteAddr = localIP
req.Header.Set("X-Custom-Header", "bAd-VaLuE") // Test with mixed-case header value
@@ -1550,12 +1617,12 @@ func TestBlockedRequestPhase1_HeaderRegex_MultipleMatchingHeaders(t *testing.T)
},
},
ruleCache: NewRuleCache(),
ipBlacklist: trie.NewTrie(),
ipBlacklist: iptrie.NewTrie(),
dnsBlacklist: map[string]struct{}{},
requestValueExtractor: NewRequestValueExtractor(logger, false),
}
req := httptest.NewRequest("GET", "http://example.com", nil)
req := httptest.NewRequest("GET", testURL, nil)
req.RemoteAddr = localIP
req.Header.Set("X-Custom-Header1", "bad-value")
req.Header.Set("X-Custom-Header2", "bad-value") // Both headers have a "bad" value
@@ -1579,7 +1646,7 @@ func TestBlockedRequestPhase1_HeaderRegex_MultipleMatchingHeaders(t *testing.T)
assert.Equal(t, http.StatusForbidden, w.Code, "Expected status code 403")
assert.Contains(t, w.Body.String(), "Blocked by Multiple Matching Headers Regex", "Response body should contain 'Blocked by Multiple Matching Headers Regex'")
req2 := httptest.NewRequest("GET", "http://example.com", nil)
req2 := httptest.NewRequest("GET", testURL, nil)
req2.RemoteAddr = localIP
req2.Header.Set("X-Custom-Header1", "good-value")
req2.Header.Set("X-Custom-Header2", "bad-value") // One header has a "bad" value
@@ -1603,7 +1670,7 @@ func TestBlockedRequestPhase1_HeaderRegex_MultipleMatchingHeaders(t *testing.T)
assert.Equal(t, http.StatusForbidden, w2.Code, "Expected status code 403")
assert.Contains(t, w2.Body.String(), "Blocked by Multiple Matching Headers Regex", "Response body should contain 'Blocked by Multiple Matching Headers Regex'")
req3 := httptest.NewRequest("GET", "http://example.com", nil)
req3 := httptest.NewRequest("GET", testURL, nil)
req3.RemoteAddr = localIP
req3.Header.Set("X-Custom-Header1", "good-value")
req3.Header.Set("X-Custom-Header2", "good-value") // None headers have a "bad" value
@@ -1658,7 +1725,7 @@ func TestBlockedRequestPhase1_RateLimiting_MultiplePaths(t *testing.T) {
Body: "Rate limit exceeded",
},
},
ipBlacklist: trie.NewTrie(),
ipBlacklist: iptrie.NewTrie(),
dnsBlacklist: make(map[string]struct{}),
}
@@ -1728,7 +1795,7 @@ func TestBlockedRequestPhase1_RateLimiting_DifferentIPs(t *testing.T) {
Body: "Rate limit exceeded",
},
},
ipBlacklist: trie.NewTrie(),
ipBlacklist: iptrie.NewTrie(),
dnsBlacklist: make(map[string]struct{}),
}
@@ -1781,7 +1848,7 @@ func TestBlockedRequestPhase1_RateLimiting_MatchAllPaths(t *testing.T) {
Body: "Rate limit exceeded",
},
},
ipBlacklist: trie.NewTrie(),
ipBlacklist: iptrie.NewTrie(),
dnsBlacklist: make(map[string]struct{}),
}

View File

@@ -7,7 +7,7 @@ import (
"testing"
"time"
trie "github.com/phemmer/go-iptrie"
"github.com/phemmer/go-iptrie"
"github.com/stretchr/testify/assert"
"go.uber.org/zap"
)
@@ -398,7 +398,7 @@ func TestBlockedRequestPhase1_RateLimiting(t *testing.T) {
Body: "Rate limit exceeded",
},
},
ipBlacklist: trie.NewTrie(), // Initialize ipBlacklist
ipBlacklist: iptrie.NewTrie(), // Initialize ipBlacklist
dnsBlacklist: make(map[string]struct{}), // Initialize dnsBlacklist
}

View File

@@ -11,7 +11,7 @@ import (
"testing"
"time"
trie "github.com/phemmer/go-iptrie"
"github.com/phemmer/go-iptrie"
"github.com/stretchr/testify/assert"
"go.uber.org/zap"
"go.uber.org/zap/zapcore"
@@ -34,7 +34,7 @@ func TestExtractValue(t *testing.T) {
name: "Extract METHOD",
target: "METHOD",
setupRequest: func() (*http.Request, http.ResponseWriter) {
req := httptest.NewRequest("POST", "http://example.com", nil)
req := httptest.NewRequest("POST", testURL, nil)
return req, httptest.NewRecorder()
},
expectedValue: "POST",
@@ -54,7 +54,7 @@ func TestExtractValue(t *testing.T) {
name: "Extract USER_AGENT",
target: "USER_AGENT",
setupRequest: func() (*http.Request, http.ResponseWriter) {
req := httptest.NewRequest("GET", "http://example.com", nil)
req := httptest.NewRequest("GET", testURL, nil)
req.Header.Set("User-Agent", "test-agent")
return req, httptest.NewRecorder()
},
@@ -65,7 +65,7 @@ func TestExtractValue(t *testing.T) {
name: "Extract HEADERS prefix",
target: "HEADERS:Content-Type",
setupRequest: func() (*http.Request, http.ResponseWriter) {
req := httptest.NewRequest("GET", "http://example.com", nil)
req := httptest.NewRequest("GET", testURL, nil)
req.Header.Set("Content-Type", "application/json")
return req, httptest.NewRecorder()
},
@@ -86,7 +86,7 @@ func TestExtractValue(t *testing.T) {
name: "Empty target",
target: "",
setupRequest: func() (*http.Request, http.ResponseWriter) {
return httptest.NewRequest("GET", "http://example.com", nil), httptest.NewRecorder()
return httptest.NewRequest("GET", testURL, nil), httptest.NewRecorder()
},
expectedError: true,
},
@@ -383,7 +383,7 @@ func TestProcessRuleMatch_HighScore(t *testing.T) {
ResponseWritten: false,
}
req := httptest.NewRequest("GET", "http://example.com", nil)
req := httptest.NewRequest("GET", testURL, nil)
// Create a context and add logID to it - FIX: ADD CONTEXT HERE
ctx := context.Background()
@@ -445,7 +445,7 @@ func TestConcurrentRuleEvaluation(t *testing.T) {
},
},
ruleCache: NewRuleCache(),
ipBlacklist: trie.NewTrie(),
ipBlacklist: iptrie.NewTrie(),
dnsBlacklist: map[string]struct{}{},
requestValueExtractor: NewRequestValueExtractor(logger, false),
rateLimiter: func() *RateLimiter {
@@ -459,12 +459,7 @@ func TestConcurrentRuleEvaluation(t *testing.T) {
}
return rl
}(),
CustomResponses: map[int]CustomBlockResponse{
403: {
StatusCode: http.StatusForbidden,
Body: "Access Denied",
},
},
CustomResponses: customResponse,
}
// Add some IPs to the blacklist
@@ -475,7 +470,7 @@ func TestConcurrentRuleEvaluation(t *testing.T) {
wg.Add(1)
go func(i int) {
defer wg.Done()
req := httptest.NewRequest("GET", "http://example.com", nil)
req := httptest.NewRequest("GET", testURL, nil)
req.RemoteAddr = fmt.Sprintf("192.168.1.%d", i%256) // Simulate different IPs
req.Header.Set("User-Agent", "test-agent") // Add a header for rule evaluation
w := httptest.NewRecorder()

View File

@@ -8,6 +8,15 @@ import (
"go.uber.org/zap"
)
// allowRequest - handles request allowing
func (m *Middleware) allowRequest(state *WAFState) {
state.Blocked = false
state.StatusCode = http.StatusOK
state.ResponseWritten = false
m.incrementAllowedRequestsMetric()
}
// blockRequest handles blocking a request and logging the details.
func (m *Middleware) blockRequest(recorder http.ResponseWriter, r *http.Request, state *WAFState, statusCode int, reason, ruleID, matchedValue string, fields ...zap.Field) {
// CRITICAL FIX: Set these flags before any other operations

View File

@@ -6,7 +6,7 @@ import (
"time"
"github.com/oschwald/maxminddb-golang"
trie "github.com/phemmer/go-iptrie"
"github.com/phemmer/go-iptrie"
"go.uber.org/zap"
"go.uber.org/zap/zapcore"
@@ -104,10 +104,10 @@ type Middleware struct {
IPBlacklistFile string `json:"ip_blacklist_file"`
DNSBlacklistFile string `json:"dns_blacklist_file"`
AnomalyThreshold int `json:"anomaly_threshold"`
CountryBlock CountryAccessFilter `json:"country_block"`
CountryBlacklist CountryAccessFilter `json:"country_blacklist"`
CountryWhitelist CountryAccessFilter `json:"country_whitelist"`
Rules map[int][]Rule `json:"-"`
ipBlacklist *trie.Trie `json:"-"`
ipBlacklist *iptrie.Trie `json:"-"`
dnsBlacklist map[string]struct{} `json:"-"` // Changed to map[string]struct{}
logger *zap.Logger
LogSeverity string `json:"log_severity,omitempty"`