mirror of
https://github.com/fabriziosalmi/caddy-waf.git
synced 2025-12-23 22:27:46 -05:00
Merge pull request #65 from drev74/test/blocking
feat: impl country whitelisting
This commit is contained in:
36
caddywaf.go
36
caddywaf.go
@@ -27,7 +27,7 @@ import (
|
||||
|
||||
"github.com/fsnotify/fsnotify"
|
||||
"github.com/oschwald/maxminddb-golang"
|
||||
trie "github.com/phemmer/go-iptrie"
|
||||
"github.com/phemmer/go-iptrie"
|
||||
"go.uber.org/zap"
|
||||
"go.uber.org/zap/zapcore"
|
||||
|
||||
@@ -193,23 +193,23 @@ func (m *Middleware) Provision(ctx caddy.Context) error {
|
||||
// Initialize GeoIP stats
|
||||
m.geoIPStats = make(map[string]int64)
|
||||
|
||||
// Configure GeoIP-based country blocking/whitelisting
|
||||
if m.CountryBlock.Enabled || m.CountryWhitelist.Enabled {
|
||||
geoIPPath := m.CountryBlock.GeoIPDBPath
|
||||
// Configure GeoIP-based country blacklisting/whitelisting
|
||||
if m.CountryBlacklist.Enabled || m.CountryWhitelist.Enabled {
|
||||
geoIPPath := m.CountryBlacklist.GeoIPDBPath
|
||||
if m.CountryWhitelist.Enabled && m.CountryWhitelist.GeoIPDBPath != "" {
|
||||
geoIPPath = m.CountryWhitelist.GeoIPDBPath
|
||||
}
|
||||
|
||||
if !fileExists(geoIPPath) {
|
||||
m.logger.Warn("GeoIP database not found. Country blocking/whitelisting will be disabled", zap.String("path", geoIPPath))
|
||||
m.logger.Warn("GeoIP database not found. Country blacklisting/whitelisting will be disabled", zap.String("path", geoIPPath))
|
||||
} else {
|
||||
reader, err := maxminddb.Open(geoIPPath)
|
||||
if err != nil {
|
||||
m.logger.Error("Failed to load GeoIP database", zap.String("path", geoIPPath), zap.Error(err))
|
||||
} else {
|
||||
m.logger.Info("GeoIP database loaded successfully", zap.String("path", geoIPPath))
|
||||
if m.CountryBlock.Enabled {
|
||||
m.CountryBlock.geoIP = reader
|
||||
if m.CountryBlacklist.Enabled {
|
||||
m.CountryBlacklist.geoIP = reader
|
||||
}
|
||||
if m.CountryWhitelist.Enabled {
|
||||
m.CountryWhitelist.geoIP = reader
|
||||
@@ -237,7 +237,7 @@ func (m *Middleware) Provision(ctx caddy.Context) error {
|
||||
|
||||
// Load IP blacklist
|
||||
if m.IPBlacklistFile != "" {
|
||||
m.ipBlacklist = trie.NewTrie()
|
||||
m.ipBlacklist = iptrie.NewTrie()
|
||||
err = m.loadIPBlacklist(m.IPBlacklistFile, *m.ipBlacklist)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to load IP blacklist: %w", err)
|
||||
@@ -288,20 +288,20 @@ func (m *Middleware) Shutdown(ctx context.Context) error {
|
||||
var errorOccurred bool
|
||||
|
||||
// Close GeoIP databases
|
||||
if m.CountryBlock.geoIP != nil {
|
||||
m.logger.Debug("Closing country block GeoIP database...")
|
||||
if err := m.CountryBlock.geoIP.Close(); err != nil {
|
||||
m.logger.Error("Error encountered while closing country block GeoIP database", zap.Error(err))
|
||||
if m.CountryBlacklist.geoIP != nil {
|
||||
m.logger.Debug("Closing country blacklist GeoIP database...")
|
||||
if err := m.CountryBlacklist.geoIP.Close(); err != nil {
|
||||
m.logger.Error("Error encountered while closing country blacklist GeoIP database", zap.Error(err))
|
||||
if !errorOccurred {
|
||||
firstError = fmt.Errorf("error closing country block GeoIP: %w", err)
|
||||
firstError = fmt.Errorf("error closing country blacklist GeoIP: %w", err)
|
||||
errorOccurred = true
|
||||
}
|
||||
} else {
|
||||
m.logger.Debug("Country block GeoIP database closed successfully.")
|
||||
m.logger.Debug("Country blacklist GeoIP database closed successfully.")
|
||||
}
|
||||
m.CountryBlock.geoIP = nil
|
||||
m.CountryBlacklist.geoIP = nil
|
||||
} else {
|
||||
m.logger.Debug("Country block GeoIP database was not open, skipping close.")
|
||||
m.logger.Debug("Country blacklist GeoIP database was not open, skipping close.")
|
||||
}
|
||||
|
||||
if m.CountryWhitelist.geoIP != nil {
|
||||
@@ -426,7 +426,7 @@ func (m *Middleware) ReloadConfig() error {
|
||||
|
||||
m.logger.Info("Reloading WAF configuration")
|
||||
if m.IPBlacklistFile != "" {
|
||||
newIPBlacklist := trie.NewTrie()
|
||||
newIPBlacklist := iptrie.NewTrie()
|
||||
if err := m.loadIPBlacklist(m.IPBlacklistFile, *newIPBlacklist); err != nil {
|
||||
m.logger.Error("Failed to reload IP blacklist", zap.String("file", m.IPBlacklistFile), zap.Error(err))
|
||||
return fmt.Errorf("failed to reload IP blacklist: %v", err)
|
||||
@@ -452,7 +452,7 @@ func (m *Middleware) ReloadConfig() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *Middleware) loadIPBlacklist(path string, blacklistMap trie.Trie) error {
|
||||
func (m *Middleware) loadIPBlacklist(path string, blacklistMap iptrie.Trie) error {
|
||||
if _, err := os.Stat(path); os.IsNotExist(err) {
|
||||
m.logger.Warn("Skipping IP blacklist load, file does not exist", zap.String("file", path))
|
||||
return nil
|
||||
|
||||
@@ -32,7 +32,7 @@ func TestMiddleware_Provision(t *testing.T) {
|
||||
IPBlacklistFile: "testdata/ip_blacklist.txt",
|
||||
DNSBlacklistFile: "testdata/dns_blacklist.txt",
|
||||
AnomalyThreshold: 10,
|
||||
CountryBlock: CountryAccessFilter{
|
||||
CountryBlacklist: CountryAccessFilter{
|
||||
Enabled: true,
|
||||
CountryList: []string{"US"},
|
||||
GeoIPDBPath: "testdata/GeoIP2-Country-Test.mmdb",
|
||||
|
||||
@@ -1,8 +1,21 @@
|
||||
package caddywaf
|
||||
|
||||
import "net/http"
|
||||
|
||||
const (
|
||||
geoIPdata = "GeoLite2-Country.mmdb"
|
||||
googleUSIP = "74.125.131.105"
|
||||
localIP = "127.0.0.1"
|
||||
aliCNIP = "47.88.198.38"
|
||||
googleUSIP = "74.125.131.105"
|
||||
googleBRIP = "128.201.228.12"
|
||||
googleRUIP = "74.125.131.94"
|
||||
testURL = "http://example.com"
|
||||
torListURL = "https://cdn.nws.neurodyne.pro/nws-cdn-ut8hw561/waf/torbulkexitlist" // custom TOR list URL for testing
|
||||
)
|
||||
|
||||
var customResponse = map[int]CustomBlockResponse{
|
||||
403: {
|
||||
StatusCode: http.StatusForbidden,
|
||||
Body: "Access Denied",
|
||||
},
|
||||
}
|
||||
|
||||
@@ -140,7 +140,7 @@ func (cl *ConfigLoader) UnmarshalCaddyfile(d *caddyfile.Dispenser, m *Middleware
|
||||
m.LogSeverity = "info"
|
||||
m.LogJSON = false
|
||||
m.AnomalyThreshold = 5
|
||||
m.CountryBlock.Enabled = false
|
||||
m.CountryBlacklist.Enabled = false
|
||||
m.CountryWhitelist.Enabled = false
|
||||
m.LogFilePath = "debug.json"
|
||||
m.RedactSensitiveData = false
|
||||
@@ -269,7 +269,7 @@ func (cl *ConfigLoader) parseCustomResponse(d *caddyfile.Dispenser, m *Middlewar
|
||||
// parseCountryBlockDirective returns a closure to handle block_countries and whitelist_countries directives.
|
||||
func (cl *ConfigLoader) parseCountryBlockDirective(isBlock bool) func(d *caddyfile.Dispenser, m *Middleware) error {
|
||||
return func(d *caddyfile.Dispenser, m *Middleware) error {
|
||||
target := &m.CountryBlock
|
||||
target := &m.CountryBlacklist
|
||||
directiveName := "block_countries"
|
||||
if !isBlock {
|
||||
target = &m.CountryWhitelist
|
||||
|
||||
@@ -201,14 +201,14 @@ func TestParseCountryBlock(t *testing.T) {
|
||||
t.Fatalf("parseCountryBlockDirective failed: %v", err)
|
||||
}
|
||||
|
||||
if !m.CountryBlock.Enabled {
|
||||
t.Errorf("Expected country block to be enabled, got %v", m.CountryBlock.Enabled)
|
||||
if !m.CountryBlacklist.Enabled {
|
||||
t.Errorf("Expected country blacklist to be enabled, got %v", m.CountryBlacklist.Enabled)
|
||||
}
|
||||
if m.CountryBlock.GeoIPDBPath != "/etc/geoip/GeoIP.dat" {
|
||||
t.Errorf("Expected GeoIP DB path to be '/etc/geoip/GeoIP.dat', got '%s'", m.CountryBlock.GeoIPDBPath)
|
||||
if m.CountryBlacklist.GeoIPDBPath != "/etc/geoip/GeoIP.dat" {
|
||||
t.Errorf("Expected GeoIP DB path to be '/etc/geoip/GeoIP.dat', got '%s'", m.CountryBlacklist.GeoIPDBPath)
|
||||
}
|
||||
if len(m.CountryBlock.CountryList) != 2 || m.CountryBlock.CountryList[0] != "US" || m.CountryBlock.CountryList[1] != "CA" {
|
||||
t.Errorf("Expected country list to be ['US', 'CA'], got %v", m.CountryBlock.CountryList)
|
||||
if len(m.CountryBlacklist.CountryList) != 2 || m.CountryBlacklist.CountryList[0] != "US" || m.CountryBlacklist.CountryList[1] != "CA" {
|
||||
t.Errorf("Expected country list to be ['US', 'CA'], got %v", m.CountryBlacklist.CountryList)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
44
handler.go
44
handler.go
@@ -230,18 +230,18 @@ func (m *Middleware) handlePhase(w http.ResponseWriter, r *http.Request, phase i
|
||||
zap.String("user_agent", r.UserAgent()),
|
||||
)
|
||||
|
||||
if phase == 1 && m.CountryBlock.Enabled {
|
||||
m.logger.Debug("Starting country blocking phase")
|
||||
blocked, err := m.isCountryInList(r.RemoteAddr, m.CountryBlock.CountryList, m.CountryBlock.geoIP)
|
||||
if phase == 1 && m.CountryBlacklist.Enabled {
|
||||
m.logger.Debug("Starting country blacklisting phase")
|
||||
blocked, err := m.isCountryInList(r.RemoteAddr, m.CountryBlacklist.CountryList, m.CountryBlacklist.geoIP)
|
||||
if err != nil {
|
||||
m.logRequest(zapcore.ErrorLevel, "Failed to check country block",
|
||||
m.logRequest(zapcore.ErrorLevel, "Failed to check country blacklisting",
|
||||
r,
|
||||
zap.Error(err),
|
||||
)
|
||||
m.blockRequest(w, r, state, http.StatusForbidden, "internal_error", "country_block_rule", r.RemoteAddr,
|
||||
zap.String("message", "Request blocked due to internal error"),
|
||||
)
|
||||
m.logger.Debug("Country blocking phase completed - blocked due to error")
|
||||
m.logger.Debug("Country blacklisting phase completed - blocked due to error")
|
||||
m.incrementGeoIPRequestsMetric(false) // Increment with false for error
|
||||
return
|
||||
} else if blocked {
|
||||
@@ -254,7 +254,34 @@ func (m *Middleware) handlePhase(w http.ResponseWriter, r *http.Request, phase i
|
||||
}
|
||||
return
|
||||
}
|
||||
m.logger.Debug("Country blocking phase completed - not blocked")
|
||||
m.logger.Debug("Country blacklisting phase completed - not blocked")
|
||||
m.incrementGeoIPRequestsMetric(false) // Increment with false for no block
|
||||
}
|
||||
|
||||
if phase == 1 && m.CountryWhitelist.Enabled {
|
||||
m.logger.Debug("Starting country whitelisting phase")
|
||||
allowed, err := m.isCountryInList(r.RemoteAddr, m.CountryWhitelist.CountryList, m.CountryWhitelist.geoIP)
|
||||
if err != nil {
|
||||
m.logRequest(zapcore.ErrorLevel, "Failed to check country whitelist",
|
||||
r,
|
||||
zap.Error(err),
|
||||
)
|
||||
m.blockRequest(w, r, state, http.StatusForbidden, "internal_error", "country_block_rule", r.RemoteAddr,
|
||||
zap.String("message", "Request blocked due to internal error"),
|
||||
)
|
||||
m.logger.Debug("Country whitelisting phase completed - blocked due to error")
|
||||
m.incrementGeoIPRequestsMetric(false) // Increment with false for error
|
||||
return
|
||||
} else if !allowed {
|
||||
m.blockRequest(w, r, state, http.StatusForbidden, "country_block", "country_block_rule", r.RemoteAddr,
|
||||
zap.String("message", "Request blocked by country"))
|
||||
m.incrementGeoIPRequestsMetric(true) // Increment with true for blocked
|
||||
if m.CustomResponses != nil {
|
||||
m.writeCustomResponse(w, state.StatusCode)
|
||||
}
|
||||
return
|
||||
}
|
||||
m.logger.Debug("Country whitelisting phase completed - not blocked")
|
||||
m.incrementGeoIPRequestsMetric(false) // Increment with false for no block
|
||||
}
|
||||
|
||||
@@ -327,7 +354,8 @@ func (m *Middleware) handlePhase(w http.ResponseWriter, r *http.Request, phase i
|
||||
rules, ok := m.Rules[phase]
|
||||
if !ok {
|
||||
m.logger.Debug("No rules found for phase", zap.Int("phase", phase))
|
||||
return
|
||||
// Don't block on empty rules. There may be no rules specified
|
||||
// return
|
||||
}
|
||||
|
||||
m.logger.Debug("Starting rule evaluation for phase", zap.Int("phase", phase), zap.Int("rule_count", len(rules)))
|
||||
@@ -434,6 +462,8 @@ func (m *Middleware) handlePhase(w http.ResponseWriter, r *http.Request, phase i
|
||||
zap.Int("total_score", state.TotalScore),
|
||||
zap.Int("anomaly_threshold", m.AnomalyThreshold),
|
||||
)
|
||||
|
||||
m.allowRequest(state)
|
||||
}
|
||||
|
||||
// incrementRateLimiterBlockedRequestsMetric increments the blocked requests metric for the rate limiter.
|
||||
|
||||
305
handler_test.go
305
handler_test.go
@@ -12,7 +12,7 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
trie "github.com/phemmer/go-iptrie"
|
||||
"github.com/phemmer/go-iptrie"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"go.uber.org/zap"
|
||||
|
||||
@@ -26,32 +26,37 @@ func TestBlockedRequestPhase1_DNSBlacklist(t *testing.T) {
|
||||
dnsBlacklist: map[string]struct{}{
|
||||
"malicious.domain": {},
|
||||
},
|
||||
ipBlacklist: trie.NewTrie(),
|
||||
CustomResponses: map[int]CustomBlockResponse{
|
||||
403: {
|
||||
StatusCode: http.StatusForbidden,
|
||||
Body: "Access Denied",
|
||||
},
|
||||
},
|
||||
ipBlacklist: iptrie.NewTrie(),
|
||||
CustomResponses: customResponse,
|
||||
}
|
||||
|
||||
// Simulate a request to a blacklisted domain
|
||||
req := httptest.NewRequest("GET", "http://malicious.domain", nil)
|
||||
req.RemoteAddr = localIP
|
||||
w := httptest.NewRecorder()
|
||||
state := &WAFState{}
|
||||
|
||||
// Process the request in Phase 1
|
||||
middleware.handlePhase(w, req, 1, state)
|
||||
t.Run("Allow unblocked domain", func(t *testing.T) {
|
||||
// Simulate a request to a blacklisted domain
|
||||
req := httptest.NewRequest("GET", testURL, nil)
|
||||
req.RemoteAddr = localIP
|
||||
|
||||
// Debug: Print the response body and status code
|
||||
t.Logf("Response Body: %s", w.Body.String())
|
||||
t.Logf("Response Status Code: %d", w.Code)
|
||||
// Process the request in Phase 1
|
||||
middleware.handlePhase(w, req, 1, state)
|
||||
assert.False(t, state.Blocked, "Request should be allowed")
|
||||
assert.Equal(t, http.StatusOK, w.Code, "Expected status code 200")
|
||||
})
|
||||
|
||||
// Verify that the request was blocked
|
||||
assert.True(t, state.Blocked, "Request should be blocked")
|
||||
assert.Equal(t, http.StatusForbidden, w.Code, "Expected status code 403")
|
||||
assert.Contains(t, w.Body.String(), "Access Denied", "Response body should contain 'Access Denied'")
|
||||
t.Run("Block blacklisted domain", func(t *testing.T) {
|
||||
// Simulate a request to a blacklisted domain
|
||||
req := httptest.NewRequest("GET", "http://malicious.domain", nil)
|
||||
req.RemoteAddr = localIP
|
||||
|
||||
// Process the request in Phase 1
|
||||
middleware.handlePhase(w, req, 1, state)
|
||||
|
||||
// Verify that the request was blocked
|
||||
assert.True(t, state.Blocked, "Request should be blocked")
|
||||
assert.Equal(t, http.StatusForbidden, w.Code, "Expected status code 403")
|
||||
assert.Contains(t, w.Body.String(), "Access Denied", "Response body should contain 'Access Denied'")
|
||||
})
|
||||
}
|
||||
|
||||
func TestBlockedRequestPhase1_GeoIPBlocking(t *testing.T) {
|
||||
@@ -62,68 +67,135 @@ func TestBlockedRequestPhase1_GeoIPBlocking(t *testing.T) {
|
||||
geoIPBlock, err := geoIPHandler.LoadGeoIPDatabase(geoIPdata)
|
||||
assert.NoError(t, err)
|
||||
|
||||
middleware := &Middleware{
|
||||
blackListMiddleware := &Middleware{
|
||||
logger: logger,
|
||||
ipBlacklist: iptrie.NewTrie(),
|
||||
geoIPHandler: geoIPHandler,
|
||||
CountryBlock: CountryAccessFilter{
|
||||
CountryBlacklist: CountryAccessFilter{
|
||||
Enabled: true,
|
||||
CountryList: []string{"US"},
|
||||
CountryList: []string{"US", "RU"},
|
||||
GeoIPDBPath: geoIPdata, // Path to a test GeoIP database
|
||||
geoIP: geoIPBlock,
|
||||
},
|
||||
CustomResponses: map[int]CustomBlockResponse{
|
||||
403: {
|
||||
StatusCode: http.StatusForbidden,
|
||||
Body: "Access Denied",
|
||||
},
|
||||
},
|
||||
CustomResponses: customResponse,
|
||||
}
|
||||
|
||||
// Simulate a request from a blocked country (US)
|
||||
req := httptest.NewRequest("GET", "http://example.com", nil)
|
||||
req.RemoteAddr = googleUSIP
|
||||
w := httptest.NewRecorder()
|
||||
whiteListMiddleware := &Middleware{
|
||||
logger: logger,
|
||||
ipBlacklist: iptrie.NewTrie(),
|
||||
geoIPHandler: geoIPHandler,
|
||||
CountryWhitelist: CountryAccessFilter{
|
||||
Enabled: true,
|
||||
CountryList: []string{"BR"},
|
||||
GeoIPDBPath: geoIPdata, // Path to a test GeoIP database
|
||||
geoIP: geoIPBlock,
|
||||
},
|
||||
CustomResponses: customResponse,
|
||||
}
|
||||
|
||||
req := httptest.NewRequest("GET", testURL, nil)
|
||||
|
||||
state := &WAFState{}
|
||||
|
||||
// Process the request in Phase 1
|
||||
middleware.handlePhase(w, req, 1, state)
|
||||
t.Run("GeoIP Blacklist: Allow CN IP", func(t *testing.T) {
|
||||
w := httptest.NewRecorder()
|
||||
req.RemoteAddr = aliCNIP
|
||||
|
||||
// Verify that the request was blocked
|
||||
assert.True(t, state.Blocked, "Request should be blocked")
|
||||
assert.Equal(t, http.StatusForbidden, w.Code, "Expected status code 403")
|
||||
assert.Contains(t, w.Body.String(), "Access Denied", "Response body should contain 'Access Denied'")
|
||||
// Process the request in Phase 1
|
||||
blackListMiddleware.handlePhase(w, req, 1, state)
|
||||
assert.False(t, state.Blocked, "Request should be allowed")
|
||||
assert.Equal(t, http.StatusOK, w.Code, "Expected status code 200")
|
||||
})
|
||||
|
||||
t.Run("GeoIP Blacklist: Block US IP", func(t *testing.T) {
|
||||
w := httptest.NewRecorder()
|
||||
req.RemoteAddr = googleUSIP
|
||||
|
||||
// Process the request in Phase 1
|
||||
blackListMiddleware.handlePhase(w, req, 1, state)
|
||||
|
||||
// Verify that the request was blocked
|
||||
assert.True(t, state.Blocked, "Request should be blocked")
|
||||
assert.Equal(t, http.StatusForbidden, w.Code, "Expected status code 403")
|
||||
assert.Contains(t, w.Body.String(), "Access Denied", "Response body should contain 'Access Denied'")
|
||||
})
|
||||
|
||||
t.Run("GeoIP Whitelist: Allow BR IP", func(t *testing.T) {
|
||||
w := httptest.NewRecorder()
|
||||
req.RemoteAddr = googleBRIP
|
||||
|
||||
// Process the request in Phase 1
|
||||
whiteListMiddleware.handlePhase(w, req, 1, state)
|
||||
assert.False(t, state.Blocked, "Request should be allowed")
|
||||
assert.Equal(t, http.StatusOK, w.Code, "Expected status code 200")
|
||||
})
|
||||
|
||||
t.Run("GeoIP Whitelist: Block RU IP", func(t *testing.T) {
|
||||
w := httptest.NewRecorder()
|
||||
req.RemoteAddr = googleRUIP
|
||||
|
||||
// Process the request in Phase 1
|
||||
whiteListMiddleware.handlePhase(w, req, 1, state)
|
||||
|
||||
// Verify that the request was blocked
|
||||
assert.True(t, state.Blocked, "Request should be blocked")
|
||||
assert.Equal(t, http.StatusForbidden, w.Code, "Expected status code 403")
|
||||
assert.Contains(t, w.Body.String(), "Access Denied", "Response body should contain 'Access Denied'")
|
||||
})
|
||||
}
|
||||
|
||||
func TestBlockedRequestPhase1_IPBlocking(t *testing.T) {
|
||||
logger, err := zap.NewDevelopment()
|
||||
assert.NoError(t, err)
|
||||
|
||||
ipBlackList := trie.NewTrie()
|
||||
ipBlackList.Insert(netip.MustParsePrefix("127.0.0.1/24"), nil)
|
||||
blackList := iptrie.NewTrie()
|
||||
loader := iptrie.NewTrieLoader(blackList)
|
||||
|
||||
middleware := &Middleware{
|
||||
logger: logger,
|
||||
ipBlacklist: ipBlackList,
|
||||
CustomResponses: map[int]CustomBlockResponse{
|
||||
403: {
|
||||
StatusCode: http.StatusForbidden,
|
||||
Body: "Access Denied",
|
||||
},
|
||||
},
|
||||
for _, net := range []string{
|
||||
"192.168.0.0/24",
|
||||
"192.168.1.1/32",
|
||||
} {
|
||||
loader.Insert(netip.MustParsePrefix(net), "net="+net)
|
||||
}
|
||||
|
||||
req := httptest.NewRequest("GET", "http://example.com", nil)
|
||||
req.RemoteAddr = localIP
|
||||
w := httptest.NewRecorder()
|
||||
state := &WAFState{}
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
// Process the request in Phase 1
|
||||
middleware.handlePhase(w, req, 1, state)
|
||||
t.Run("Allow unblocked CIDR", func(t *testing.T) {
|
||||
middleware := &Middleware{
|
||||
logger: logger,
|
||||
ipBlacklist: blackList,
|
||||
CustomResponses: customResponse,
|
||||
}
|
||||
|
||||
// Verify that the request was blocked
|
||||
assert.True(t, state.Blocked, "Request should be blocked")
|
||||
assert.Equal(t, http.StatusForbidden, w.Code, "Expected status code 403")
|
||||
assert.Contains(t, w.Body.String(), "Access Denied", "Response body should contain 'Access Denied'")
|
||||
req := httptest.NewRequest("GET", testURL, nil)
|
||||
req.RemoteAddr = localIP
|
||||
|
||||
// Process the request in Phase 1
|
||||
middleware.handlePhase(w, req, 1, state)
|
||||
|
||||
assert.False(t, state.Blocked, "Request should be allowed")
|
||||
assert.Equal(t, http.StatusOK, w.Code, "Expected status code 200")
|
||||
})
|
||||
|
||||
t.Run("Blocks blacklisted CIDR", func(t *testing.T) {
|
||||
middleware := &Middleware{
|
||||
logger: logger,
|
||||
ipBlacklist: blackList,
|
||||
CustomResponses: customResponse,
|
||||
}
|
||||
|
||||
req := httptest.NewRequest("GET", testURL, nil)
|
||||
req.RemoteAddr = "192.168.1.1"
|
||||
|
||||
// Process the request in Phase 1
|
||||
middleware.handlePhase(w, req, 1, state)
|
||||
|
||||
// Verify that the request was blocked
|
||||
assert.True(t, state.Blocked, "Request should be blocked")
|
||||
assert.Equal(t, http.StatusForbidden, w.Code, "Expected status code 403")
|
||||
assert.Contains(t, w.Body.String(), "Access Denied", "Response body should contain 'Access Denied'")
|
||||
})
|
||||
}
|
||||
|
||||
func TestHandlePhase_Phase2_NiktoUserAgent(t *testing.T) {
|
||||
@@ -144,18 +216,13 @@ func TestHandlePhase_Phase2_NiktoUserAgent(t *testing.T) {
|
||||
},
|
||||
},
|
||||
ruleCache: NewRuleCache(),
|
||||
ipBlacklist: trie.NewTrie(),
|
||||
ipBlacklist: iptrie.NewTrie(),
|
||||
dnsBlacklist: map[string]struct{}{},
|
||||
requestValueExtractor: NewRequestValueExtractor(logger, false),
|
||||
CustomResponses: map[int]CustomBlockResponse{
|
||||
403: {
|
||||
StatusCode: http.StatusForbidden,
|
||||
Body: "Access Denied",
|
||||
},
|
||||
},
|
||||
CustomResponses: customResponse,
|
||||
}
|
||||
|
||||
req := httptest.NewRequest("GET", "http://example.com", nil)
|
||||
req := httptest.NewRequest("GET", testURL, nil)
|
||||
req.Header.Set("User-Agent", "nikto")
|
||||
|
||||
// Create a context and add logID to it - FIX: ADD CONTEXT HERE
|
||||
@@ -204,12 +271,12 @@ func TestBlockedRequestPhase1_HeaderRegex(t *testing.T) {
|
||||
},
|
||||
},
|
||||
ruleCache: NewRuleCache(),
|
||||
ipBlacklist: trie.NewTrie(),
|
||||
ipBlacklist: iptrie.NewTrie(),
|
||||
dnsBlacklist: map[string]struct{}{},
|
||||
requestValueExtractor: NewRequestValueExtractor(logger, false),
|
||||
}
|
||||
|
||||
req := httptest.NewRequest("GET", "http://example.com", nil)
|
||||
req := httptest.NewRequest("GET", testURL, nil)
|
||||
req.RemoteAddr = localIP
|
||||
req.Header.Set("X-Custom-Header", "this-is-a-bad-header") // Simulate a request with bad header
|
||||
|
||||
@@ -257,12 +324,12 @@ func TestBlockedRequestPhase1_HeaderRegex_SpecificValue(t *testing.T) {
|
||||
},
|
||||
},
|
||||
ruleCache: NewRuleCache(),
|
||||
ipBlacklist: trie.NewTrie(),
|
||||
ipBlacklist: iptrie.NewTrie(),
|
||||
dnsBlacklist: map[string]struct{}{},
|
||||
requestValueExtractor: NewRequestValueExtractor(logger, false),
|
||||
}
|
||||
|
||||
req := httptest.NewRequest("GET", "http://example.com", nil)
|
||||
req := httptest.NewRequest("GET", testURL, nil)
|
||||
req.RemoteAddr = localIP
|
||||
req.Header.Set("X-Specific-Header", "specific-value") // Simulate a request with the specific header
|
||||
|
||||
@@ -310,12 +377,12 @@ func TestBlockedRequestPhase1_HeaderRegex_CommaSeparatedTargets(t *testing.T) {
|
||||
},
|
||||
},
|
||||
ruleCache: NewRuleCache(),
|
||||
ipBlacklist: trie.NewTrie(),
|
||||
ipBlacklist: iptrie.NewTrie(),
|
||||
dnsBlacklist: map[string]struct{}{},
|
||||
requestValueExtractor: NewRequestValueExtractor(logger, false),
|
||||
}
|
||||
|
||||
req := httptest.NewRequest("GET", "http://example.com", nil)
|
||||
req := httptest.NewRequest("GET", testURL, nil)
|
||||
req.RemoteAddr = localIP
|
||||
req.Header.Set("X-Custom-Header1", "good-value")
|
||||
req.Header.Set("X-Custom-Header2", "bad-value") // Simulate a request with bad value in one of the headers
|
||||
@@ -364,7 +431,7 @@ func TestBlockedRequestPhase1_CombinedConditions(t *testing.T) {
|
||||
},
|
||||
},
|
||||
ruleCache: NewRuleCache(),
|
||||
ipBlacklist: trie.NewTrie(),
|
||||
ipBlacklist: iptrie.NewTrie(),
|
||||
dnsBlacklist: map[string]struct{}{},
|
||||
requestValueExtractor: NewRequestValueExtractor(logger, false),
|
||||
}
|
||||
@@ -417,12 +484,12 @@ func TestBlockedRequestPhase1_NoMatch(t *testing.T) {
|
||||
},
|
||||
},
|
||||
ruleCache: NewRuleCache(),
|
||||
ipBlacklist: trie.NewTrie(),
|
||||
ipBlacklist: iptrie.NewTrie(),
|
||||
dnsBlacklist: map[string]struct{}{},
|
||||
requestValueExtractor: NewRequestValueExtractor(logger, false),
|
||||
}
|
||||
|
||||
req := httptest.NewRequest("GET", "http://example.com", nil)
|
||||
req := httptest.NewRequest("GET", testURL, nil)
|
||||
req.RemoteAddr = localIP
|
||||
req.Header.Set("User-Agent", "good-user")
|
||||
|
||||
@@ -470,12 +537,12 @@ func TestBlockedRequestPhase1_HeaderRegex_EmptyHeader(t *testing.T) {
|
||||
},
|
||||
},
|
||||
ruleCache: NewRuleCache(),
|
||||
ipBlacklist: trie.NewTrie(),
|
||||
ipBlacklist: iptrie.NewTrie(),
|
||||
dnsBlacklist: map[string]struct{}{},
|
||||
requestValueExtractor: NewRequestValueExtractor(logger, false),
|
||||
}
|
||||
|
||||
req := httptest.NewRequest("GET", "http://example.com", nil)
|
||||
req := httptest.NewRequest("GET", testURL, nil)
|
||||
req.RemoteAddr = localIP
|
||||
|
||||
// Create a context and add logID to it
|
||||
@@ -522,12 +589,12 @@ func TestBlockedRequestPhase1_HeaderRegex_MissingHeader(t *testing.T) {
|
||||
},
|
||||
},
|
||||
ruleCache: NewRuleCache(),
|
||||
ipBlacklist: trie.NewTrie(),
|
||||
ipBlacklist: iptrie.NewTrie(),
|
||||
dnsBlacklist: map[string]struct{}{},
|
||||
requestValueExtractor: NewRequestValueExtractor(logger, false),
|
||||
}
|
||||
|
||||
req := httptest.NewRequest("GET", "http://example.com", nil) // Header not set
|
||||
req := httptest.NewRequest("GET", testURL, nil) // Header not set
|
||||
req.RemoteAddr = localIP
|
||||
|
||||
// Create a context and add logID to it
|
||||
@@ -574,12 +641,12 @@ func TestBlockedRequestPhase1_HeaderRegex_ComplexPattern(t *testing.T) {
|
||||
},
|
||||
},
|
||||
ruleCache: NewRuleCache(),
|
||||
ipBlacklist: trie.NewTrie(),
|
||||
ipBlacklist: iptrie.NewTrie(),
|
||||
dnsBlacklist: map[string]struct{}{},
|
||||
requestValueExtractor: NewRequestValueExtractor(logger, false),
|
||||
}
|
||||
|
||||
req := httptest.NewRequest("GET", "http://example.com", nil)
|
||||
req := httptest.NewRequest("GET", testURL, nil)
|
||||
req.RemoteAddr = localIP
|
||||
req.Header.Set("X-Email-Header", "test@example.com") // Simulate a request with a valid email
|
||||
|
||||
@@ -627,12 +694,12 @@ func TestBlockedRequestPhase1_MultiTargetMatch(t *testing.T) {
|
||||
},
|
||||
},
|
||||
ruleCache: NewRuleCache(),
|
||||
ipBlacklist: trie.NewTrie(),
|
||||
ipBlacklist: iptrie.NewTrie(),
|
||||
dnsBlacklist: map[string]struct{}{},
|
||||
requestValueExtractor: NewRequestValueExtractor(logger, false),
|
||||
}
|
||||
|
||||
req := httptest.NewRequest("GET", "http://example.com", nil)
|
||||
req := httptest.NewRequest("GET", testURL, nil)
|
||||
req.RemoteAddr = localIP
|
||||
req.Header.Set("X-Custom-Header", "good-header")
|
||||
req.Header.Set("User-Agent", "bad-user-agent")
|
||||
@@ -680,12 +747,12 @@ func TestBlockedRequestPhase1_MultiTargetNoMatch(t *testing.T) {
|
||||
},
|
||||
},
|
||||
ruleCache: NewRuleCache(),
|
||||
ipBlacklist: trie.NewTrie(),
|
||||
ipBlacklist: iptrie.NewTrie(),
|
||||
dnsBlacklist: map[string]struct{}{},
|
||||
requestValueExtractor: NewRequestValueExtractor(logger, false),
|
||||
}
|
||||
|
||||
req := httptest.NewRequest("GET", "http://example.com", nil)
|
||||
req := httptest.NewRequest("GET", testURL, nil)
|
||||
req.RemoteAddr = localIP
|
||||
req.Header.Set("X-Custom-Header", "good-header")
|
||||
req.Header.Set("User-Agent", "good-user-agent")
|
||||
@@ -734,7 +801,7 @@ func TestBlockedRequestPhase1_URLParameterRegex_NoMatch(t *testing.T) {
|
||||
},
|
||||
},
|
||||
ruleCache: NewRuleCache(),
|
||||
ipBlacklist: trie.NewTrie(),
|
||||
ipBlacklist: iptrie.NewTrie(),
|
||||
dnsBlacklist: map[string]struct{}{},
|
||||
requestValueExtractor: NewRequestValueExtractor(logger, false),
|
||||
}
|
||||
@@ -794,7 +861,7 @@ func TestBlockedRequestPhase1_MultipleRules(t *testing.T) {
|
||||
},
|
||||
},
|
||||
ruleCache: NewRuleCache(),
|
||||
ipBlacklist: trie.NewTrie(),
|
||||
ipBlacklist: iptrie.NewTrie(),
|
||||
dnsBlacklist: map[string]struct{}{},
|
||||
requestValueExtractor: NewRequestValueExtractor(logger, false),
|
||||
}
|
||||
@@ -870,12 +937,12 @@ func TestBlockedRequestPhase2_BodyRegex(t *testing.T) {
|
||||
},
|
||||
},
|
||||
ruleCache: NewRuleCache(),
|
||||
ipBlacklist: trie.NewTrie(),
|
||||
ipBlacklist: iptrie.NewTrie(),
|
||||
dnsBlacklist: map[string]struct{}{},
|
||||
requestValueExtractor: NewRequestValueExtractor(logger, false),
|
||||
}
|
||||
|
||||
req := httptest.NewRequest("POST", "http://example.com",
|
||||
req := httptest.NewRequest("POST", testURL,
|
||||
func() *bytes.Buffer {
|
||||
b := new(bytes.Buffer)
|
||||
b.WriteString("this-is-a-bad-body")
|
||||
@@ -929,12 +996,12 @@ func TestBlockedRequestPhase2_BodyRegex_JSON(t *testing.T) {
|
||||
},
|
||||
},
|
||||
ruleCache: NewRuleCache(),
|
||||
ipBlacklist: trie.NewTrie(),
|
||||
ipBlacklist: iptrie.NewTrie(),
|
||||
dnsBlacklist: map[string]struct{}{},
|
||||
requestValueExtractor: NewRequestValueExtractor(logger, false),
|
||||
}
|
||||
|
||||
req := httptest.NewRequest("POST", "http://example.com",
|
||||
req := httptest.NewRequest("POST", testURL,
|
||||
func() *bytes.Buffer {
|
||||
b := new(bytes.Buffer)
|
||||
b.WriteString(`{"data":{"malicious":true,"name":"test"}}`)
|
||||
@@ -988,12 +1055,12 @@ func TestBlockedRequestPhase2_BodyRegex_FormURLEncoded(t *testing.T) {
|
||||
},
|
||||
},
|
||||
ruleCache: NewRuleCache(),
|
||||
ipBlacklist: trie.NewTrie(),
|
||||
ipBlacklist: iptrie.NewTrie(),
|
||||
dnsBlacklist: map[string]struct{}{},
|
||||
requestValueExtractor: NewRequestValueExtractor(logger, false),
|
||||
}
|
||||
|
||||
req := httptest.NewRequest("POST", "http://example.com",
|
||||
req := httptest.NewRequest("POST", testURL,
|
||||
strings.NewReader("param1=value1&secret=badvalue¶m2=value2"),
|
||||
)
|
||||
req.RemoteAddr = localIP
|
||||
@@ -1043,12 +1110,12 @@ func TestBlockedRequestPhase2_BodyRegex_SpecificPattern(t *testing.T) {
|
||||
},
|
||||
},
|
||||
ruleCache: NewRuleCache(),
|
||||
ipBlacklist: trie.NewTrie(),
|
||||
ipBlacklist: iptrie.NewTrie(),
|
||||
dnsBlacklist: map[string]struct{}{},
|
||||
requestValueExtractor: NewRequestValueExtractor(logger, false),
|
||||
}
|
||||
|
||||
req := httptest.NewRequest("POST", "http://example.com",
|
||||
req := httptest.NewRequest("POST", testURL,
|
||||
func() *bytes.Buffer {
|
||||
b := new(bytes.Buffer)
|
||||
b.WriteString("User ID: 123-45-6789")
|
||||
@@ -1102,12 +1169,12 @@ func TestBlockedRequestPhase2_BodyRegex_NoMatch(t *testing.T) {
|
||||
},
|
||||
},
|
||||
ruleCache: NewRuleCache(),
|
||||
ipBlacklist: trie.NewTrie(),
|
||||
ipBlacklist: iptrie.NewTrie(),
|
||||
dnsBlacklist: map[string]struct{}{},
|
||||
requestValueExtractor: NewRequestValueExtractor(logger, false),
|
||||
}
|
||||
|
||||
req := httptest.NewRequest("POST", "http://example.com",
|
||||
req := httptest.NewRequest("POST", testURL,
|
||||
func() *bytes.Buffer {
|
||||
b := new(bytes.Buffer)
|
||||
b.WriteString("this-is-a-good-body")
|
||||
@@ -1161,7 +1228,7 @@ func TestBlockedRequestPhase2_BodyRegex_NoMatch_MultipartForm(t *testing.T) {
|
||||
},
|
||||
},
|
||||
ruleCache: NewRuleCache(),
|
||||
ipBlacklist: trie.NewTrie(),
|
||||
ipBlacklist: iptrie.NewTrie(),
|
||||
dnsBlacklist: map[string]struct{}{},
|
||||
requestValueExtractor: NewRequestValueExtractor(logger, false),
|
||||
}
|
||||
@@ -1181,7 +1248,7 @@ func TestBlockedRequestPhase2_BodyRegex_NoMatch_MultipartForm(t *testing.T) {
|
||||
t.Fatalf("Failed to close multipart writer: %v", err)
|
||||
}
|
||||
|
||||
req := httptest.NewRequest("POST", "http://example.com", body)
|
||||
req := httptest.NewRequest("POST", testURL, body)
|
||||
req.RemoteAddr = localIP
|
||||
req.Header.Set("Content-Type", writer.FormDataContentType())
|
||||
|
||||
@@ -1229,12 +1296,12 @@ func TestBlockedRequestPhase2_BodyRegex_NoBody(t *testing.T) {
|
||||
},
|
||||
},
|
||||
ruleCache: NewRuleCache(),
|
||||
ipBlacklist: trie.NewTrie(),
|
||||
ipBlacklist: iptrie.NewTrie(),
|
||||
dnsBlacklist: map[string]struct{}{},
|
||||
requestValueExtractor: NewRequestValueExtractor(logger, false),
|
||||
}
|
||||
|
||||
req := httptest.NewRequest("POST", "http://example.com", nil)
|
||||
req := httptest.NewRequest("POST", testURL, nil)
|
||||
req.RemoteAddr = localIP
|
||||
w := httptest.NewRecorder()
|
||||
state := &WAFState{}
|
||||
@@ -1275,7 +1342,7 @@ func TestBlockedRequestPhase3_ResponseHeaderRegex_NoMatch(t *testing.T) {
|
||||
},
|
||||
},
|
||||
ruleCache: NewRuleCache(),
|
||||
ipBlacklist: trie.NewTrie(),
|
||||
ipBlacklist: iptrie.NewTrie(),
|
||||
dnsBlacklist: map[string]struct{}{},
|
||||
requestValueExtractor: NewRequestValueExtractor(logger, false),
|
||||
}
|
||||
@@ -1288,7 +1355,7 @@ func TestBlockedRequestPhase3_ResponseHeaderRegex_NoMatch(t *testing.T) {
|
||||
})
|
||||
}()
|
||||
|
||||
req := httptest.NewRequest("GET", "http://example.com", nil)
|
||||
req := httptest.NewRequest("GET", testURL, nil)
|
||||
req.RemoteAddr = localIP
|
||||
w := httptest.NewRecorder()
|
||||
state := &WAFState{}
|
||||
@@ -1331,7 +1398,7 @@ func TestBlockedRequestPhase4_ResponseBodyRegex_EmptyBody(t *testing.T) {
|
||||
},
|
||||
},
|
||||
ruleCache: NewRuleCache(),
|
||||
ipBlacklist: trie.NewTrie(),
|
||||
ipBlacklist: iptrie.NewTrie(),
|
||||
dnsBlacklist: map[string]struct{}{},
|
||||
requestValueExtractor: NewRequestValueExtractor(logger, false),
|
||||
}
|
||||
@@ -1343,7 +1410,7 @@ func TestBlockedRequestPhase4_ResponseBodyRegex_EmptyBody(t *testing.T) {
|
||||
})
|
||||
}()
|
||||
|
||||
req := httptest.NewRequest("GET", "http://example.com", nil)
|
||||
req := httptest.NewRequest("GET", testURL, nil)
|
||||
req.RemoteAddr = localIP
|
||||
w := httptest.NewRecorder()
|
||||
state := &WAFState{}
|
||||
@@ -1387,7 +1454,7 @@ func TestBlockedRequestPhase4_ResponseBodyRegex_NoBody(t *testing.T) {
|
||||
},
|
||||
},
|
||||
ruleCache: NewRuleCache(),
|
||||
ipBlacklist: trie.NewTrie(),
|
||||
ipBlacklist: iptrie.NewTrie(),
|
||||
dnsBlacklist: map[string]struct{}{},
|
||||
requestValueExtractor: NewRequestValueExtractor(logger, false),
|
||||
}
|
||||
@@ -1399,7 +1466,7 @@ func TestBlockedRequestPhase4_ResponseBodyRegex_NoBody(t *testing.T) {
|
||||
})
|
||||
}()
|
||||
|
||||
req := httptest.NewRequest("GET", "http://example.com", nil)
|
||||
req := httptest.NewRequest("GET", testURL, nil)
|
||||
req.RemoteAddr = localIP
|
||||
w := httptest.NewRecorder()
|
||||
state := &WAFState{}
|
||||
@@ -1441,7 +1508,7 @@ func TestBlockedRequestPhase3_ResponseHeaderRegex_NoSetCookie(t *testing.T) {
|
||||
},
|
||||
},
|
||||
ruleCache: NewRuleCache(),
|
||||
ipBlacklist: trie.NewTrie(),
|
||||
ipBlacklist: iptrie.NewTrie(),
|
||||
dnsBlacklist: map[string]struct{}{},
|
||||
requestValueExtractor: NewRequestValueExtractor(logger, false),
|
||||
}
|
||||
@@ -1453,7 +1520,7 @@ func TestBlockedRequestPhase3_ResponseHeaderRegex_NoSetCookie(t *testing.T) {
|
||||
})
|
||||
}()
|
||||
|
||||
req := httptest.NewRequest("GET", "http://example.com", nil)
|
||||
req := httptest.NewRequest("GET", testURL, nil)
|
||||
req.RemoteAddr = localIP
|
||||
w := httptest.NewRecorder()
|
||||
state := &WAFState{}
|
||||
@@ -1497,12 +1564,12 @@ func TestBlockedRequestPhase1_HeaderRegex_CaseInsensitive(t *testing.T) {
|
||||
},
|
||||
},
|
||||
ruleCache: NewRuleCache(),
|
||||
ipBlacklist: trie.NewTrie(),
|
||||
ipBlacklist: iptrie.NewTrie(),
|
||||
dnsBlacklist: map[string]struct{}{},
|
||||
requestValueExtractor: NewRequestValueExtractor(logger, false),
|
||||
}
|
||||
|
||||
req := httptest.NewRequest("GET", "http://example.com", nil)
|
||||
req := httptest.NewRequest("GET", testURL, nil)
|
||||
req.RemoteAddr = localIP
|
||||
req.Header.Set("X-Custom-Header", "bAd-VaLuE") // Test with mixed-case header value
|
||||
|
||||
@@ -1550,12 +1617,12 @@ func TestBlockedRequestPhase1_HeaderRegex_MultipleMatchingHeaders(t *testing.T)
|
||||
},
|
||||
},
|
||||
ruleCache: NewRuleCache(),
|
||||
ipBlacklist: trie.NewTrie(),
|
||||
ipBlacklist: iptrie.NewTrie(),
|
||||
dnsBlacklist: map[string]struct{}{},
|
||||
requestValueExtractor: NewRequestValueExtractor(logger, false),
|
||||
}
|
||||
|
||||
req := httptest.NewRequest("GET", "http://example.com", nil)
|
||||
req := httptest.NewRequest("GET", testURL, nil)
|
||||
req.RemoteAddr = localIP
|
||||
req.Header.Set("X-Custom-Header1", "bad-value")
|
||||
req.Header.Set("X-Custom-Header2", "bad-value") // Both headers have a "bad" value
|
||||
@@ -1579,7 +1646,7 @@ func TestBlockedRequestPhase1_HeaderRegex_MultipleMatchingHeaders(t *testing.T)
|
||||
assert.Equal(t, http.StatusForbidden, w.Code, "Expected status code 403")
|
||||
assert.Contains(t, w.Body.String(), "Blocked by Multiple Matching Headers Regex", "Response body should contain 'Blocked by Multiple Matching Headers Regex'")
|
||||
|
||||
req2 := httptest.NewRequest("GET", "http://example.com", nil)
|
||||
req2 := httptest.NewRequest("GET", testURL, nil)
|
||||
req2.RemoteAddr = localIP
|
||||
req2.Header.Set("X-Custom-Header1", "good-value")
|
||||
req2.Header.Set("X-Custom-Header2", "bad-value") // One header has a "bad" value
|
||||
@@ -1603,7 +1670,7 @@ func TestBlockedRequestPhase1_HeaderRegex_MultipleMatchingHeaders(t *testing.T)
|
||||
assert.Equal(t, http.StatusForbidden, w2.Code, "Expected status code 403")
|
||||
assert.Contains(t, w2.Body.String(), "Blocked by Multiple Matching Headers Regex", "Response body should contain 'Blocked by Multiple Matching Headers Regex'")
|
||||
|
||||
req3 := httptest.NewRequest("GET", "http://example.com", nil)
|
||||
req3 := httptest.NewRequest("GET", testURL, nil)
|
||||
req3.RemoteAddr = localIP
|
||||
req3.Header.Set("X-Custom-Header1", "good-value")
|
||||
req3.Header.Set("X-Custom-Header2", "good-value") // None headers have a "bad" value
|
||||
@@ -1658,7 +1725,7 @@ func TestBlockedRequestPhase1_RateLimiting_MultiplePaths(t *testing.T) {
|
||||
Body: "Rate limit exceeded",
|
||||
},
|
||||
},
|
||||
ipBlacklist: trie.NewTrie(),
|
||||
ipBlacklist: iptrie.NewTrie(),
|
||||
dnsBlacklist: make(map[string]struct{}),
|
||||
}
|
||||
|
||||
@@ -1728,7 +1795,7 @@ func TestBlockedRequestPhase1_RateLimiting_DifferentIPs(t *testing.T) {
|
||||
Body: "Rate limit exceeded",
|
||||
},
|
||||
},
|
||||
ipBlacklist: trie.NewTrie(),
|
||||
ipBlacklist: iptrie.NewTrie(),
|
||||
dnsBlacklist: make(map[string]struct{}),
|
||||
}
|
||||
|
||||
@@ -1781,7 +1848,7 @@ func TestBlockedRequestPhase1_RateLimiting_MatchAllPaths(t *testing.T) {
|
||||
Body: "Rate limit exceeded",
|
||||
},
|
||||
},
|
||||
ipBlacklist: trie.NewTrie(),
|
||||
ipBlacklist: iptrie.NewTrie(),
|
||||
dnsBlacklist: make(map[string]struct{}),
|
||||
}
|
||||
|
||||
|
||||
@@ -7,7 +7,7 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
trie "github.com/phemmer/go-iptrie"
|
||||
"github.com/phemmer/go-iptrie"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
@@ -398,7 +398,7 @@ func TestBlockedRequestPhase1_RateLimiting(t *testing.T) {
|
||||
Body: "Rate limit exceeded",
|
||||
},
|
||||
},
|
||||
ipBlacklist: trie.NewTrie(), // Initialize ipBlacklist
|
||||
ipBlacklist: iptrie.NewTrie(), // Initialize ipBlacklist
|
||||
dnsBlacklist: make(map[string]struct{}), // Initialize dnsBlacklist
|
||||
}
|
||||
|
||||
|
||||
@@ -11,7 +11,7 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
trie "github.com/phemmer/go-iptrie"
|
||||
"github.com/phemmer/go-iptrie"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"go.uber.org/zap"
|
||||
"go.uber.org/zap/zapcore"
|
||||
@@ -34,7 +34,7 @@ func TestExtractValue(t *testing.T) {
|
||||
name: "Extract METHOD",
|
||||
target: "METHOD",
|
||||
setupRequest: func() (*http.Request, http.ResponseWriter) {
|
||||
req := httptest.NewRequest("POST", "http://example.com", nil)
|
||||
req := httptest.NewRequest("POST", testURL, nil)
|
||||
return req, httptest.NewRecorder()
|
||||
},
|
||||
expectedValue: "POST",
|
||||
@@ -54,7 +54,7 @@ func TestExtractValue(t *testing.T) {
|
||||
name: "Extract USER_AGENT",
|
||||
target: "USER_AGENT",
|
||||
setupRequest: func() (*http.Request, http.ResponseWriter) {
|
||||
req := httptest.NewRequest("GET", "http://example.com", nil)
|
||||
req := httptest.NewRequest("GET", testURL, nil)
|
||||
req.Header.Set("User-Agent", "test-agent")
|
||||
return req, httptest.NewRecorder()
|
||||
},
|
||||
@@ -65,7 +65,7 @@ func TestExtractValue(t *testing.T) {
|
||||
name: "Extract HEADERS prefix",
|
||||
target: "HEADERS:Content-Type",
|
||||
setupRequest: func() (*http.Request, http.ResponseWriter) {
|
||||
req := httptest.NewRequest("GET", "http://example.com", nil)
|
||||
req := httptest.NewRequest("GET", testURL, nil)
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
return req, httptest.NewRecorder()
|
||||
},
|
||||
@@ -86,7 +86,7 @@ func TestExtractValue(t *testing.T) {
|
||||
name: "Empty target",
|
||||
target: "",
|
||||
setupRequest: func() (*http.Request, http.ResponseWriter) {
|
||||
return httptest.NewRequest("GET", "http://example.com", nil), httptest.NewRecorder()
|
||||
return httptest.NewRequest("GET", testURL, nil), httptest.NewRecorder()
|
||||
},
|
||||
expectedError: true,
|
||||
},
|
||||
@@ -383,7 +383,7 @@ func TestProcessRuleMatch_HighScore(t *testing.T) {
|
||||
ResponseWritten: false,
|
||||
}
|
||||
|
||||
req := httptest.NewRequest("GET", "http://example.com", nil)
|
||||
req := httptest.NewRequest("GET", testURL, nil)
|
||||
|
||||
// Create a context and add logID to it - FIX: ADD CONTEXT HERE
|
||||
ctx := context.Background()
|
||||
@@ -445,7 +445,7 @@ func TestConcurrentRuleEvaluation(t *testing.T) {
|
||||
},
|
||||
},
|
||||
ruleCache: NewRuleCache(),
|
||||
ipBlacklist: trie.NewTrie(),
|
||||
ipBlacklist: iptrie.NewTrie(),
|
||||
dnsBlacklist: map[string]struct{}{},
|
||||
requestValueExtractor: NewRequestValueExtractor(logger, false),
|
||||
rateLimiter: func() *RateLimiter {
|
||||
@@ -459,12 +459,7 @@ func TestConcurrentRuleEvaluation(t *testing.T) {
|
||||
}
|
||||
return rl
|
||||
}(),
|
||||
CustomResponses: map[int]CustomBlockResponse{
|
||||
403: {
|
||||
StatusCode: http.StatusForbidden,
|
||||
Body: "Access Denied",
|
||||
},
|
||||
},
|
||||
CustomResponses: customResponse,
|
||||
}
|
||||
|
||||
// Add some IPs to the blacklist
|
||||
@@ -475,7 +470,7 @@ func TestConcurrentRuleEvaluation(t *testing.T) {
|
||||
wg.Add(1)
|
||||
go func(i int) {
|
||||
defer wg.Done()
|
||||
req := httptest.NewRequest("GET", "http://example.com", nil)
|
||||
req := httptest.NewRequest("GET", testURL, nil)
|
||||
req.RemoteAddr = fmt.Sprintf("192.168.1.%d", i%256) // Simulate different IPs
|
||||
req.Header.Set("User-Agent", "test-agent") // Add a header for rule evaluation
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
@@ -8,6 +8,15 @@ import (
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// allowRequest - handles request allowing
|
||||
func (m *Middleware) allowRequest(state *WAFState) {
|
||||
state.Blocked = false
|
||||
state.StatusCode = http.StatusOK
|
||||
state.ResponseWritten = false
|
||||
|
||||
m.incrementAllowedRequestsMetric()
|
||||
}
|
||||
|
||||
// blockRequest handles blocking a request and logging the details.
|
||||
func (m *Middleware) blockRequest(recorder http.ResponseWriter, r *http.Request, state *WAFState, statusCode int, reason, ruleID, matchedValue string, fields ...zap.Field) {
|
||||
// CRITICAL FIX: Set these flags before any other operations
|
||||
|
||||
6
types.go
6
types.go
@@ -6,7 +6,7 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/oschwald/maxminddb-golang"
|
||||
trie "github.com/phemmer/go-iptrie"
|
||||
"github.com/phemmer/go-iptrie"
|
||||
"go.uber.org/zap"
|
||||
"go.uber.org/zap/zapcore"
|
||||
|
||||
@@ -104,10 +104,10 @@ type Middleware struct {
|
||||
IPBlacklistFile string `json:"ip_blacklist_file"`
|
||||
DNSBlacklistFile string `json:"dns_blacklist_file"`
|
||||
AnomalyThreshold int `json:"anomaly_threshold"`
|
||||
CountryBlock CountryAccessFilter `json:"country_block"`
|
||||
CountryBlacklist CountryAccessFilter `json:"country_blacklist"`
|
||||
CountryWhitelist CountryAccessFilter `json:"country_whitelist"`
|
||||
Rules map[int][]Rule `json:"-"`
|
||||
ipBlacklist *trie.Trie `json:"-"`
|
||||
ipBlacklist *iptrie.Trie `json:"-"`
|
||||
dnsBlacklist map[string]struct{} `json:"-"` // Changed to map[string]struct{}
|
||||
logger *zap.Logger
|
||||
LogSeverity string `json:"log_severity,omitempty"`
|
||||
|
||||
Reference in New Issue
Block a user