Files
caddy-waf/handler_test.go
2025-10-11 22:36:20 +03:00

1814 lines
55 KiB
Go

package caddywaf
import (
"bytes"
"context"
"mime/multipart"
"net/http"
"net/http/httptest"
"net/netip"
"regexp"
"strings"
"testing"
"time"
trie "github.com/phemmer/go-iptrie"
"github.com/stretchr/testify/assert"
"go.uber.org/zap"
"github.com/caddyserver/caddy/v2/modules/caddyhttp"
)
func TestBlockedRequestPhase1_DNSBlacklist(t *testing.T) {
logger := zap.NewNop()
middleware := &Middleware{
logger: logger,
dnsBlacklist: map[string]struct{}{
"malicious.domain": {},
},
ipBlacklist: trie.NewTrie(),
CustomResponses: map[int]CustomBlockResponse{
403: {
StatusCode: http.StatusForbidden,
Body: "Access Denied",
},
},
}
// Simulate a request to a blacklisted domain
req := httptest.NewRequest("GET", "http://malicious.domain", nil)
req.RemoteAddr = localIP
w := httptest.NewRecorder()
state := &WAFState{}
// Process the request in Phase 1
middleware.handlePhase(w, req, 1, state)
// Debug: Print the response body and status code
t.Logf("Response Body: %s", w.Body.String())
t.Logf("Response Status Code: %d", w.Code)
// Verify that the request was blocked
assert.True(t, state.Blocked, "Request should be blocked")
assert.Equal(t, http.StatusForbidden, w.Code, "Expected status code 403")
assert.Contains(t, w.Body.String(), "Access Denied", "Response body should contain 'Access Denied'")
}
func TestBlockedRequestPhase1_GeoIPBlocking(t *testing.T) {
logger, err := zap.NewDevelopment()
assert.NoError(t, err)
geoIPHandler := NewGeoIPHandler(logger)
geoIPBlock, err := geoIPHandler.LoadGeoIPDatabase(geoIPdata)
assert.NoError(t, err)
middleware := &Middleware{
logger: logger,
geoIPHandler: geoIPHandler,
CountryBlock: CountryAccessFilter{
Enabled: true,
CountryList: []string{"US"},
GeoIPDBPath: geoIPdata, // Path to a test GeoIP database
geoIP: geoIPBlock,
},
CustomResponses: map[int]CustomBlockResponse{
403: {
StatusCode: http.StatusForbidden,
Body: "Access Denied",
},
},
}
// Simulate a request from a blocked country (US)
req := httptest.NewRequest("GET", "http://example.com", nil)
req.RemoteAddr = googleUSIP
w := httptest.NewRecorder()
state := &WAFState{}
// Process the request in Phase 1
middleware.handlePhase(w, req, 1, state)
// Verify that the request was blocked
assert.True(t, state.Blocked, "Request should be blocked")
assert.Equal(t, http.StatusForbidden, w.Code, "Expected status code 403")
assert.Contains(t, w.Body.String(), "Access Denied", "Response body should contain 'Access Denied'")
}
func TestBlockedRequestPhase1_IPBlocking(t *testing.T) {
logger, err := zap.NewDevelopment()
assert.NoError(t, err)
ipBlackList := trie.NewTrie()
ipBlackList.Insert(netip.MustParsePrefix("127.0.0.1/24"), nil)
middleware := &Middleware{
logger: logger,
ipBlacklist: ipBlackList,
CustomResponses: map[int]CustomBlockResponse{
403: {
StatusCode: http.StatusForbidden,
Body: "Access Denied",
},
},
}
req := httptest.NewRequest("GET", "http://example.com", nil)
req.RemoteAddr = localIP
w := httptest.NewRecorder()
state := &WAFState{}
// Process the request in Phase 1
middleware.handlePhase(w, req, 1, state)
// Verify that the request was blocked
assert.True(t, state.Blocked, "Request should be blocked")
assert.Equal(t, http.StatusForbidden, w.Code, "Expected status code 403")
assert.Contains(t, w.Body.String(), "Access Denied", "Response body should contain 'Access Denied'")
}
func TestHandlePhase_Phase2_NiktoUserAgent(t *testing.T) {
logger := zap.NewNop()
middleware := &Middleware{
logger: logger,
Rules: map[int][]Rule{
2: {
{
ID: "rule2",
Pattern: "nikto",
Targets: []string{"USER_AGENT"},
Phase: 2,
Score: 5,
Action: "block",
regex: regexp.MustCompile("nikto"),
},
},
},
ruleCache: NewRuleCache(),
ipBlacklist: trie.NewTrie(),
dnsBlacklist: map[string]struct{}{},
requestValueExtractor: NewRequestValueExtractor(logger, false),
CustomResponses: map[int]CustomBlockResponse{
403: {
StatusCode: http.StatusForbidden,
Body: "Access Denied",
},
},
}
req := httptest.NewRequest("GET", "http://example.com", nil)
req.Header.Set("User-Agent", "nikto")
// Create a context and add logID to it - FIX: ADD CONTEXT HERE
ctx := context.Background()
logID := "test-log-id-nikto" // Unique log ID for this test
ctx = context.WithValue(ctx, ContextKeyLogId("logID"), logID)
req = req.WithContext(ctx) // Create new request with context
w := httptest.NewRecorder()
state := &WAFState{}
middleware.handlePhase(w, req, 2, state)
t.Logf("State Blocked: %v", state.Blocked)
t.Logf("Response Code: %d", w.Code)
t.Logf("Response Body: %s", w.Body.String())
assert.True(t, state.Blocked, "Request should be blocked")
assert.Equal(t, http.StatusForbidden, w.Code, "Expected status code 403")
assert.Contains(t, w.Body.String(), "Access Denied", "Response body should contain 'Access Denied'")
}
func TestBlockedRequestPhase1_HeaderRegex(t *testing.T) {
logger, err := zap.NewDevelopment()
assert.NoError(t, err)
middleware := &Middleware{
logger: logger,
Rules: map[int][]Rule{
1: {
{
ID: "rule1",
Pattern: "bad-header",
Targets: []string{"HEADERS:X-Custom-Header"},
Phase: 1,
Score: 5,
Action: "block",
regex: regexp.MustCompile("bad-header"),
},
},
},
CustomResponses: map[int]CustomBlockResponse{
403: {
StatusCode: http.StatusForbidden,
Body: "Blocked by Header Regex",
},
},
ruleCache: NewRuleCache(),
ipBlacklist: trie.NewTrie(),
dnsBlacklist: map[string]struct{}{},
requestValueExtractor: NewRequestValueExtractor(logger, false),
}
req := httptest.NewRequest("GET", "http://example.com", nil)
req.RemoteAddr = localIP
req.Header.Set("X-Custom-Header", "this-is-a-bad-header") // Simulate a request with bad header
// Create a context and add logID to it - FIX: ADD CONTEXT HERE
ctx := context.Background()
logID := "test-log-id-headerregex" // Unique log ID for this test
ctx = context.WithValue(ctx, ContextKeyLogId("logID"), logID)
req = req.WithContext(ctx) // Create new request with context
w := httptest.NewRecorder()
state := &WAFState{}
middleware.handlePhase(w, req, 1, state)
t.Logf("State Blocked: %v", state.Blocked)
t.Logf("Response Code: %d", w.Code)
t.Logf("Response Body: %s", w.Body.String())
assert.True(t, state.Blocked, "Request should be blocked")
assert.Equal(t, http.StatusForbidden, w.Code, "Expected status code 403")
assert.Contains(t, w.Body.String(), "Blocked by Header Regex", "Response body should contain 'Blocked by Header Regex'")
}
func TestBlockedRequestPhase1_HeaderRegex_SpecificValue(t *testing.T) {
logger := zap.NewNop()
middleware := &Middleware{
logger: logger,
Rules: map[int][]Rule{
1: {
{
ID: "rule_header_specific",
Pattern: "^specific-value$",
Targets: []string{"HEADERS:X-Specific-Header"},
Phase: 1,
Score: 5,
Action: "block",
regex: regexp.MustCompile("^specific-value$"),
},
},
},
CustomResponses: map[int]CustomBlockResponse{
403: {
StatusCode: http.StatusForbidden,
Body: "Blocked by Specific Header Regex",
},
},
ruleCache: NewRuleCache(),
ipBlacklist: trie.NewTrie(),
dnsBlacklist: map[string]struct{}{},
requestValueExtractor: NewRequestValueExtractor(logger, false),
}
req := httptest.NewRequest("GET", "http://example.com", nil)
req.RemoteAddr = localIP
req.Header.Set("X-Specific-Header", "specific-value") // Simulate a request with the specific header
// Create a context and add logID to it - FIX: ADD CONTEXT HERE
ctx := context.Background()
logID := "test-log-id-headerspecificvalue" // Unique log ID for this test
ctx = context.WithValue(ctx, ContextKeyLogId("logID"), logID)
req = req.WithContext(ctx) // Create new request with context
w := httptest.NewRecorder()
state := &WAFState{}
middleware.handlePhase(w, req, 1, state)
t.Logf("State Blocked: %v", state.Blocked)
t.Logf("Response Code: %d", w.Code)
t.Logf("Response Body: %s", w.Body.String())
assert.True(t, state.Blocked, "Request should be blocked")
assert.Equal(t, http.StatusForbidden, w.Code, "Expected status code 403")
assert.Contains(t, w.Body.String(), "Blocked by Specific Header Regex", "Response body should contain 'Blocked by Specific Header Regex'")
}
func TestBlockedRequestPhase1_HeaderRegex_CommaSeparatedTargets(t *testing.T) {
logger := zap.NewNop()
middleware := &Middleware{
logger: logger,
Rules: map[int][]Rule{
1: {
{
ID: "rule_header_comma",
Pattern: "bad-value",
Targets: []string{"HEADERS:X-Custom-Header1,HEADERS:X-Custom-Header2"},
Phase: 1,
Score: 5,
Action: "block",
regex: regexp.MustCompile("bad-value"),
},
},
},
CustomResponses: map[int]CustomBlockResponse{
403: {
StatusCode: http.StatusForbidden,
Body: "Blocked by Comma-Separated Header Regex",
},
},
ruleCache: NewRuleCache(),
ipBlacklist: trie.NewTrie(),
dnsBlacklist: map[string]struct{}{},
requestValueExtractor: NewRequestValueExtractor(logger, false),
}
req := httptest.NewRequest("GET", "http://example.com", nil)
req.RemoteAddr = localIP
req.Header.Set("X-Custom-Header1", "good-value")
req.Header.Set("X-Custom-Header2", "bad-value") // Simulate a request with bad value in one of the headers
// Create a context and add logID to it - FIX: ADD CONTEXT HERE
ctx := context.Background()
logID := "test-log-id-headercomma" // Unique log ID for this test
ctx = context.WithValue(ctx, ContextKeyLogId("logID"), logID)
req = req.WithContext(ctx) // Create new request with context
w := httptest.NewRecorder()
state := &WAFState{}
middleware.handlePhase(w, req, 1, state)
t.Logf("State Blocked: %v", state.Blocked)
t.Logf("Response Code: %d", w.Code)
t.Logf("Response Body: %s", w.Body.String())
assert.True(t, state.Blocked, "Request should be blocked")
assert.Equal(t, http.StatusForbidden, w.Code, "Expected status code 403")
assert.Contains(t, w.Body.String(), "Blocked by Comma-Separated Header Regex", "Response body should contain 'Blocked by Comma-Separated Header Regex'")
}
func TestBlockedRequestPhase1_CombinedConditions(t *testing.T) {
logger := zap.NewNop()
middleware := &Middleware{
logger: logger,
Rules: map[int][]Rule{
1: {
{
ID: "rule_combined",
Pattern: "bad-user|bad-host",
Targets: []string{"USER_AGENT", "HOST"},
Phase: 1,
Score: 5,
Action: "block",
regex: regexp.MustCompile("bad-user|bad-host"),
},
},
},
CustomResponses: map[int]CustomBlockResponse{
403: {
StatusCode: http.StatusForbidden,
Body: "Blocked by Combined Condition Regex",
},
},
ruleCache: NewRuleCache(),
ipBlacklist: trie.NewTrie(),
dnsBlacklist: map[string]struct{}{},
requestValueExtractor: NewRequestValueExtractor(logger, false),
}
req := httptest.NewRequest("GET", "http://bad-host.com", nil)
req.RemoteAddr = localIP
req.Header.Set("User-Agent", "good-user")
// Create a context and add logID to it
ctx := context.Background()
logID := "test-log-id-combined"
ctx = context.WithValue(ctx, ContextKeyLogId("logID"), logID)
req = req.WithContext(ctx)
w := httptest.NewRecorder()
state := &WAFState{}
middleware.handlePhase(w, req, 1, state)
t.Logf("State Blocked: %v", state.Blocked)
t.Logf("Response Code: %d", w.Code)
t.Logf("Response Body: %s", w.Body.String())
assert.True(t, state.Blocked, "Request should be blocked")
assert.Equal(t, http.StatusForbidden, w.Code, "Expected status code 403")
assert.Contains(t, w.Body.String(), "Blocked by Combined Condition Regex", "Response body should contain 'Blocked by Combined Condition Regex'")
}
func TestBlockedRequestPhase1_NoMatch(t *testing.T) {
logger := zap.NewNop()
middleware := &Middleware{
logger: logger,
Rules: map[int][]Rule{
1: {
{
ID: "rule_no_match",
Pattern: "nomatch",
Targets: []string{"USER_AGENT"},
Phase: 1,
Score: 5,
Action: "block",
regex: regexp.MustCompile("nomatch"),
},
},
},
CustomResponses: map[int]CustomBlockResponse{
403: {
StatusCode: http.StatusForbidden,
Body: "Blocked by Header Regex",
},
},
ruleCache: NewRuleCache(),
ipBlacklist: trie.NewTrie(),
dnsBlacklist: map[string]struct{}{},
requestValueExtractor: NewRequestValueExtractor(logger, false),
}
req := httptest.NewRequest("GET", "http://example.com", nil)
req.RemoteAddr = localIP
req.Header.Set("User-Agent", "good-user")
// Create a context and add logID to it
ctx := context.Background()
logID := "test-log-id-nomatch"
ctx = context.WithValue(ctx, ContextKeyLogId("logID"), logID)
req = req.WithContext(ctx)
w := httptest.NewRecorder()
state := &WAFState{}
middleware.handlePhase(w, req, 1, state)
t.Logf("State Blocked: %v", state.Blocked)
t.Logf("Response Code: %d", w.Code)
t.Logf("Response Body: %s", w.Body.String())
assert.False(t, state.Blocked, "Request should not be blocked")
assert.Equal(t, http.StatusOK, w.Code, "Expected status code 200")
assert.Empty(t, w.Body.String(), "Response body should be empty")
}
func TestBlockedRequestPhase1_HeaderRegex_EmptyHeader(t *testing.T) {
logger := zap.NewNop()
middleware := &Middleware{
logger: logger,
Rules: map[int][]Rule{
1: {
{
ID: "rule_header_empty",
Pattern: ".+", // Match anything (including empty)
Targets: []string{"HEADERS:X-Empty-Header"},
Phase: 1,
Score: 5,
Action: "block",
regex: regexp.MustCompile(".+"),
},
},
},
CustomResponses: map[int]CustomBlockResponse{
403: {
StatusCode: http.StatusForbidden,
Body: "Blocked by Empty Header Regex",
},
},
ruleCache: NewRuleCache(),
ipBlacklist: trie.NewTrie(),
dnsBlacklist: map[string]struct{}{},
requestValueExtractor: NewRequestValueExtractor(logger, false),
}
req := httptest.NewRequest("GET", "http://example.com", nil)
req.RemoteAddr = localIP
// Create a context and add logID to it
ctx := context.Background()
logID := "test-log-id-headerempty"
ctx = context.WithValue(ctx, ContextKeyLogId("logID"), logID)
req = req.WithContext(ctx)
w := httptest.NewRecorder()
state := &WAFState{}
middleware.handlePhase(w, req, 1, state)
t.Logf("State Blocked: %v", state.Blocked)
t.Logf("Response Code: %d", w.Code)
t.Logf("Response Body: %s", w.Body.String())
assert.False(t, state.Blocked, "Request should not be blocked because header is empty")
assert.Equal(t, http.StatusOK, w.Code, "Expected status code 200")
assert.Empty(t, w.Body.String(), "Response body should be empty")
}
func TestBlockedRequestPhase1_HeaderRegex_MissingHeader(t *testing.T) {
logger := zap.NewNop()
middleware := &Middleware{
logger: logger,
Rules: map[int][]Rule{
1: {
{
ID: "rule_header_missing",
Pattern: "test-value",
Targets: []string{"HEADERS:X-Missing-Header"},
Phase: 1,
Score: 5,
Action: "block",
regex: regexp.MustCompile("test-value"),
},
},
},
CustomResponses: map[int]CustomBlockResponse{
403: {
StatusCode: http.StatusForbidden,
Body: "Blocked by Missing Header Regex",
},
},
ruleCache: NewRuleCache(),
ipBlacklist: trie.NewTrie(),
dnsBlacklist: map[string]struct{}{},
requestValueExtractor: NewRequestValueExtractor(logger, false),
}
req := httptest.NewRequest("GET", "http://example.com", nil) // Header not set
req.RemoteAddr = localIP
// Create a context and add logID to it
ctx := context.Background()
logID := "test-log-id-headermissing"
ctx = context.WithValue(ctx, ContextKeyLogId("logID"), logID)
req = req.WithContext(ctx)
w := httptest.NewRecorder()
state := &WAFState{}
middleware.handlePhase(w, req, 1, state)
t.Logf("State Blocked: %v", state.Blocked)
t.Logf("Response Code: %d", w.Code)
t.Logf("Response Body: %s", w.Body.String())
assert.False(t, state.Blocked, "Request should not be blocked because header is missing")
assert.Equal(t, http.StatusOK, w.Code, "Expected status code 200")
assert.Empty(t, w.Body.String(), "Response body should be empty")
}
func TestBlockedRequestPhase1_HeaderRegex_ComplexPattern(t *testing.T) {
logger := zap.NewNop()
middleware := &Middleware{
logger: logger,
Rules: map[int][]Rule{
1: {
{
ID: "rule_header_complex",
Pattern: `^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}$`, // Email regex
Targets: []string{"HEADERS:X-Email-Header"},
Phase: 1,
Score: 5,
Action: "block",
regex: regexp.MustCompile(`^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}$`),
},
},
},
CustomResponses: map[int]CustomBlockResponse{
403: {
StatusCode: http.StatusForbidden,
Body: "Blocked by Complex Header Regex",
},
},
ruleCache: NewRuleCache(),
ipBlacklist: trie.NewTrie(),
dnsBlacklist: map[string]struct{}{},
requestValueExtractor: NewRequestValueExtractor(logger, false),
}
req := httptest.NewRequest("GET", "http://example.com", nil)
req.RemoteAddr = localIP
req.Header.Set("X-Email-Header", "test@example.com") // Simulate a request with a valid email
// Create a context and add logID to it
ctx := context.Background()
logID := "test-log-id-headercomplex"
ctx = context.WithValue(ctx, ContextKeyLogId("logID"), logID)
req = req.WithContext(ctx)
w := httptest.NewRecorder()
state := &WAFState{}
middleware.handlePhase(w, req, 1, state)
t.Logf("State Blocked: %v", state.Blocked)
t.Logf("Response Code: %d", w.Code)
t.Logf("Response Body: %s", w.Body.String())
assert.True(t, state.Blocked, "Request should be blocked")
assert.Equal(t, http.StatusForbidden, w.Code, "Expected status code 403")
assert.Contains(t, w.Body.String(), "Blocked by Complex Header Regex", "Response body should contain 'Blocked by Complex Header Regex'")
}
func TestBlockedRequestPhase1_MultiTargetMatch(t *testing.T) {
logger := zap.NewNop()
middleware := &Middleware{
logger: logger,
Rules: map[int][]Rule{
1: {
{
ID: "rule_multi_target",
Pattern: "bad",
Targets: []string{"HEADERS:X-Custom-Header", "USER_AGENT"},
Phase: 1,
Score: 5,
Action: "block",
regex: regexp.MustCompile("bad"),
},
},
},
CustomResponses: map[int]CustomBlockResponse{
403: {
StatusCode: http.StatusForbidden,
Body: "Blocked by Multi-Target Regex",
},
},
ruleCache: NewRuleCache(),
ipBlacklist: trie.NewTrie(),
dnsBlacklist: map[string]struct{}{},
requestValueExtractor: NewRequestValueExtractor(logger, false),
}
req := httptest.NewRequest("GET", "http://example.com", nil)
req.RemoteAddr = localIP
req.Header.Set("X-Custom-Header", "good-header")
req.Header.Set("User-Agent", "bad-user-agent")
// Create a context and add logID to it - FIX: ADD CONTEXT HERE
ctx := context.Background()
logID := "test-log-id-multimatch" // Unique log ID for this test
ctx = context.WithValue(ctx, ContextKeyLogId("logID"), logID)
req = req.WithContext(ctx)
w := httptest.NewRecorder()
state := &WAFState{}
middleware.handlePhase(w, req, 1, state)
t.Logf("State Blocked: %v", state.Blocked)
t.Logf("Response Code: %d", w.Code)
t.Logf("Response Body: %s", w.Body.String())
assert.True(t, state.Blocked, "Request should be blocked")
assert.Equal(t, http.StatusForbidden, w.Code, "Expected status code 403")
assert.Contains(t, w.Body.String(), "Blocked by Multi-Target Regex", "Response body should contain 'Blocked by Multi-Target Regex'")
}
func TestBlockedRequestPhase1_MultiTargetNoMatch(t *testing.T) {
logger := zap.NewNop()
middleware := &Middleware{
logger: logger,
Rules: map[int][]Rule{
1: {
{
ID: "rule_multi_target_no_match",
Pattern: "bad",
Targets: []string{"HEADERS:X-Custom-Header", "USER_AGENT"},
Phase: 1,
Score: 5,
Action: "block",
regex: regexp.MustCompile("bad"),
},
},
},
CustomResponses: map[int]CustomBlockResponse{
403: {
StatusCode: http.StatusForbidden,
Body: "Blocked by Multi-Target Regex",
},
},
ruleCache: NewRuleCache(),
ipBlacklist: trie.NewTrie(),
dnsBlacklist: map[string]struct{}{},
requestValueExtractor: NewRequestValueExtractor(logger, false),
}
req := httptest.NewRequest("GET", "http://example.com", nil)
req.RemoteAddr = localIP
req.Header.Set("X-Custom-Header", "good-header")
req.Header.Set("User-Agent", "good-user-agent")
// Create a context and add logID to it - FIX: ADD CONTEXT HERE
ctx := context.Background()
logID := "test-log-id-multinomatch" // Unique log ID for this test
ctx = context.WithValue(ctx, ContextKeyLogId("logID"), logID)
req = req.WithContext(ctx) // Create new request with context
w := httptest.NewRecorder()
state := &WAFState{}
middleware.handlePhase(w, req, 1, state)
t.Logf("State Blocked: %v", state.Blocked)
t.Logf("Response Code: %d", w.Code)
t.Logf("Response Body: %s", w.Body.String())
assert.False(t, state.Blocked, "Request should not be blocked")
assert.Equal(t, http.StatusOK, w.Code, "Expected status code 200")
assert.Empty(t, w.Body.String(), "Response body should be empty")
}
func TestBlockedRequestPhase1_URLParameterRegex_NoMatch(t *testing.T) {
logger := zap.NewNop()
middleware := &Middleware{
logger: logger,
Rules: map[int][]Rule{
1: {
{
ID: "rule_url_param_no_match",
Pattern: "nomatch",
Targets: []string{"URL_PARAM:param1"},
Phase: 1,
Score: 5,
Action: "block",
regex: regexp.MustCompile("nomatch"),
},
},
},
CustomResponses: map[int]CustomBlockResponse{
403: {
StatusCode: http.StatusForbidden,
Body: "Blocked by URL Parameter Regex",
},
},
ruleCache: NewRuleCache(),
ipBlacklist: trie.NewTrie(),
dnsBlacklist: map[string]struct{}{},
requestValueExtractor: NewRequestValueExtractor(logger, false),
}
req := httptest.NewRequest("GET", "http://example.com?param1=good-param-value¶m2=good-value", nil)
req.RemoteAddr = localIP
// Create a context and add logID to it - FIX: ADD CONTEXT HERE
ctx := context.Background()
logID := "test-log-id-urlparamnomatch" // Unique log ID for this test
ctx = context.WithValue(ctx, ContextKeyLogId("logID"), logID)
req = req.WithContext(ctx)
w := httptest.NewRecorder()
state := &WAFState{}
middleware.handlePhase(w, req, 1, state)
t.Logf("State Blocked: %v", state.Blocked)
t.Logf("Response Code: %d", w.Code)
t.Logf("Response Body: %s", w.Body.String())
assert.False(t, state.Blocked, "Request should not be blocked")
assert.Equal(t, http.StatusOK, w.Code, "Expected status code 200")
assert.Empty(t, w.Body.String(), "Response body should be empty")
}
func TestBlockedRequestPhase1_MultipleRules(t *testing.T) {
logger := zap.NewNop()
middleware := &Middleware{
logger: logger,
Rules: map[int][]Rule{
1: {
{
ID: "rule_multi1",
Pattern: "bad-user",
Targets: []string{"USER_AGENT"},
Phase: 1,
Score: 5,
Action: "block",
regex: regexp.MustCompile("bad-user"),
},
{
ID: "rule_multi2",
Pattern: "bad-host",
Targets: []string{"HOST"},
Phase: 1,
Score: 5,
Action: "block",
regex: regexp.MustCompile("bad-host"),
},
},
},
CustomResponses: map[int]CustomBlockResponse{
403: {
StatusCode: http.StatusForbidden,
Body: "Blocked by Multiple Rules",
},
},
ruleCache: NewRuleCache(),
ipBlacklist: trie.NewTrie(),
dnsBlacklist: map[string]struct{}{},
requestValueExtractor: NewRequestValueExtractor(logger, false),
}
req := httptest.NewRequest("GET", "http://bad-host.com", nil)
req.RemoteAddr = localIP
req.Header.Set("User-Agent", "bad-user") // Simulate a request with a bad user agent
// Create a context and add logID to it - FIX: ADD CONTEXT HERE
ctx := context.Background()
logID := "test-log-id-multiplerules" // Unique log ID for this test
ctx = context.WithValue(ctx, ContextKeyLogId("logID"), logID)
req = req.WithContext(ctx)
w := httptest.NewRecorder()
state := &WAFState{}
middleware.handlePhase(w, req, 1, state)
t.Logf("State Blocked: %v", state.Blocked)
t.Logf("Response Code: %d", w.Code)
t.Logf("Response Body: %s", w.Body.String())
assert.True(t, state.Blocked, "Request should be blocked")
assert.Equal(t, http.StatusForbidden, w.Code, "Expected status code 403")
assert.Contains(t, w.Body.String(), "Blocked by Multiple Rules", "Response body should contain 'Blocked by Multiple Rules'")
req2 := httptest.NewRequest("GET", "http://good-host.com", nil)
req2.RemoteAddr = localIP
req2.Header.Set("User-Agent", "bad-user") // Simulate a request with a bad user agent
// Create a context and add logID to it - FIX: ADD CONTEXT HERE for req2 as well!
ctx2 := context.Background() // New context for the second request!
logID2 := "test-log-id-multiplerules2"
ctx2 = context.WithValue(ctx2, ContextKeyLogId("logID"), logID2)
req2 = req2.WithContext(ctx2)
w2 := httptest.NewRecorder()
state2 := &WAFState{}
middleware.handlePhase(w2, req2, 1, state2)
t.Logf("State Blocked: %v", state2.Blocked)
t.Logf("Response Code: %d", w2.Code)
t.Logf("Response Body: %s", w2.Body.String())
assert.True(t, state2.Blocked, "Request should be blocked")
assert.Equal(t, http.StatusForbidden, w2.Code, "Expected status code 403")
assert.Contains(t, w2.Body.String(), "Blocked by Multiple Rules", "Response body should contain 'Blocked by Multiple Rules'")
}
func TestBlockedRequestPhase2_BodyRegex(t *testing.T) {
logger := zap.NewNop()
middleware := &Middleware{
logger: logger,
Rules: map[int][]Rule{
2: {
{
ID: "rule2",
Pattern: "bad-body",
Targets: []string{"BODY"},
Phase: 2,
Score: 5,
Action: "block",
regex: regexp.MustCompile("bad-body"),
},
},
},
CustomResponses: map[int]CustomBlockResponse{
403: {
StatusCode: http.StatusForbidden,
Body: "Blocked by Body Regex",
},
},
ruleCache: NewRuleCache(),
ipBlacklist: trie.NewTrie(),
dnsBlacklist: map[string]struct{}{},
requestValueExtractor: NewRequestValueExtractor(logger, false),
}
req := httptest.NewRequest("POST", "http://example.com",
func() *bytes.Buffer {
b := new(bytes.Buffer)
b.WriteString("this-is-a-bad-body")
return b
}(), // Simulate a request with bad body
)
req.RemoteAddr = localIP
req.Header.Set("Content-Type", "text/plain")
// Create a context and add logID to it - FIX: ADD CONTEXT HERE
ctx := context.Background()
logID := "test-log-id-bodyregex" // Unique log ID for this test
ctx = context.WithValue(ctx, ContextKeyLogId("logID"), logID)
req = req.WithContext(ctx)
w := httptest.NewRecorder()
state := &WAFState{}
middleware.handlePhase(w, req, 2, state)
t.Logf("State Blocked: %v", state.Blocked)
t.Logf("Response Code: %d", w.Code)
t.Logf("Response Body: %s", w.Body.String())
assert.True(t, state.Blocked, "Request should be blocked")
assert.Equal(t, http.StatusForbidden, w.Code, "Expected status code 403")
assert.Contains(t, w.Body.String(), "Blocked by Body Regex", "Response body should contain 'Blocked by Body Regex'")
}
func TestBlockedRequestPhase2_BodyRegex_JSON(t *testing.T) {
logger := zap.NewNop()
middleware := &Middleware{
logger: logger,
Rules: map[int][]Rule{
2: {
{
ID: "rule2_json",
Pattern: `"malicious":true`,
Targets: []string{"BODY"},
Phase: 2,
Score: 5,
Action: "block",
regex: regexp.MustCompile(`"malicious":true`),
},
},
},
CustomResponses: map[int]CustomBlockResponse{
403: {
StatusCode: http.StatusForbidden,
Body: "Blocked by JSON Body Regex",
},
},
ruleCache: NewRuleCache(),
ipBlacklist: trie.NewTrie(),
dnsBlacklist: map[string]struct{}{},
requestValueExtractor: NewRequestValueExtractor(logger, false),
}
req := httptest.NewRequest("POST", "http://example.com",
func() *bytes.Buffer {
b := new(bytes.Buffer)
b.WriteString(`{"data":{"malicious":true,"name":"test"}}`)
return b
}(), // Simulate a request with JSON body
)
req.RemoteAddr = localIP
req.Header.Set("Content-Type", "application/json")
// Create a context and add logID to it - FIX: ADD CONTEXT HERE
ctx := context.Background()
logID := "test-log-id-bodyregexjson" // Unique log ID for this test
ctx = context.WithValue(ctx, ContextKeyLogId("logID"), logID)
req = req.WithContext(ctx)
w := httptest.NewRecorder()
state := &WAFState{}
middleware.handlePhase(w, req, 2, state)
t.Logf("State Blocked: %v", state.Blocked)
t.Logf("Response Code: %d", w.Code)
t.Logf("Response Body: %s", w.Body.String())
assert.True(t, state.Blocked, "Request should be blocked")
assert.Equal(t, http.StatusForbidden, w.Code, "Expected status code 403")
assert.Contains(t, w.Body.String(), "Blocked by JSON Body Regex", "Response body should contain 'Blocked by JSON Body Regex'")
}
func TestBlockedRequestPhase2_BodyRegex_FormURLEncoded(t *testing.T) {
logger := zap.NewNop()
middleware := &Middleware{
logger: logger,
Rules: map[int][]Rule{
2: {
{
ID: "rule2_form",
Pattern: "secret=badvalue",
Targets: []string{"BODY"},
Phase: 2,
Score: 5,
Action: "block",
regex: regexp.MustCompile("secret=badvalue"),
},
},
},
CustomResponses: map[int]CustomBlockResponse{
403: {
StatusCode: http.StatusForbidden,
Body: "Blocked by Form URL Encoded Regex",
},
},
ruleCache: NewRuleCache(),
ipBlacklist: trie.NewTrie(),
dnsBlacklist: map[string]struct{}{},
requestValueExtractor: NewRequestValueExtractor(logger, false),
}
req := httptest.NewRequest("POST", "http://example.com",
strings.NewReader("param1=value1&secret=badvalue¶m2=value2"),
)
req.RemoteAddr = localIP
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
// Create a context and add logID to it - FIX: ADD CONTEXT HERE
ctx := context.Background()
logID := "test-log-id-bodyregexform"
ctx = context.WithValue(ctx, ContextKeyLogId("logID"), logID)
req = req.WithContext(ctx)
w := httptest.NewRecorder()
state := &WAFState{}
middleware.handlePhase(w, req, 2, state)
t.Logf("State Blocked: %v", state.Blocked)
t.Logf("Response Code: %d", w.Code)
t.Logf("Response Body: %s", w.Body.String())
assert.True(t, state.Blocked, "Request should be blocked")
assert.Equal(t, http.StatusForbidden, w.Code, "Expected status code 403")
assert.Contains(t, w.Body.String(), "Blocked by Form URL Encoded Regex", "Response body should contain 'Blocked by Form URL Encoded Regex'")
}
func TestBlockedRequestPhase2_BodyRegex_SpecificPattern(t *testing.T) {
logger := zap.NewNop()
middleware := &Middleware{
logger: logger,
Rules: map[int][]Rule{
2: {
{
ID: "rule2_specific",
Pattern: `\d{3}-\d{2}-\d{4}`,
Targets: []string{"BODY"},
Phase: 2,
Score: 5,
Action: "block",
regex: regexp.MustCompile(`\d{3}-\d{2}-\d{4}`),
},
},
},
CustomResponses: map[int]CustomBlockResponse{
403: {
StatusCode: http.StatusForbidden,
Body: "Blocked by Specific Body Regex",
},
},
ruleCache: NewRuleCache(),
ipBlacklist: trie.NewTrie(),
dnsBlacklist: map[string]struct{}{},
requestValueExtractor: NewRequestValueExtractor(logger, false),
}
req := httptest.NewRequest("POST", "http://example.com",
func() *bytes.Buffer {
b := new(bytes.Buffer)
b.WriteString("User ID: 123-45-6789")
return b
}(),
)
req.RemoteAddr = localIP
req.Header.Set("Content-Type", "text/plain") // Setting content type
// Create a context and add logID to it - FIX: ADD CONTEXT HERE
ctx := context.Background()
logID := "test-log-id-bodyregexspecific"
ctx = context.WithValue(ctx, ContextKeyLogId("logID"), logID)
req = req.WithContext(ctx)
w := httptest.NewRecorder()
state := &WAFState{}
middleware.handlePhase(w, req, 2, state)
t.Logf("State Blocked: %v", state.Blocked)
t.Logf("Response Code: %d", w.Code)
t.Logf("Response Body: %s", w.Body.String())
assert.True(t, state.Blocked, "Request should be blocked")
assert.Equal(t, http.StatusForbidden, w.Code, "Expected status code 403")
assert.Contains(t, w.Body.String(), "Blocked by Specific Body Regex", "Response body should contain 'Blocked by Specific Body Regex'")
}
func TestBlockedRequestPhase2_BodyRegex_NoMatch(t *testing.T) {
logger := zap.NewNop()
middleware := &Middleware{
logger: logger,
Rules: map[int][]Rule{
2: {
{
ID: "rule2_no_match",
Pattern: "nomatch",
Targets: []string{"BODY"},
Phase: 2,
Score: 5,
Action: "block",
regex: regexp.MustCompile("nomatch"),
},
},
},
CustomResponses: map[int]CustomBlockResponse{
403: {
StatusCode: http.StatusForbidden,
Body: "Blocked by Body Regex",
},
},
ruleCache: NewRuleCache(),
ipBlacklist: trie.NewTrie(),
dnsBlacklist: map[string]struct{}{},
requestValueExtractor: NewRequestValueExtractor(logger, false),
}
req := httptest.NewRequest("POST", "http://example.com",
func() *bytes.Buffer {
b := new(bytes.Buffer)
b.WriteString("this-is-a-good-body")
return b
}(),
)
req.RemoteAddr = localIP
req.Header.Set("Content-Type", "text/plain")
// Create a context and add logID to it - FIX: ADD CONTEXT HERE
ctx := context.Background()
logID := "test-log-id-bodyregexnomatch"
ctx = context.WithValue(ctx, ContextKeyLogId("logID"), logID)
req = req.WithContext(ctx)
w := httptest.NewRecorder()
state := &WAFState{}
middleware.handlePhase(w, req, 2, state)
t.Logf("State Blocked: %v", state.Blocked)
t.Logf("Response Code: %d", w.Code)
t.Logf("Response Body: %s", w.Body.String())
assert.False(t, state.Blocked, "Request should not be blocked")
assert.Equal(t, http.StatusOK, w.Code, "Expected status code 200")
assert.Empty(t, w.Body.String(), "Response body should be empty")
}
func TestBlockedRequestPhase2_BodyRegex_NoMatch_MultipartForm(t *testing.T) {
logger := zap.NewNop()
middleware := &Middleware{
logger: logger,
Rules: map[int][]Rule{
2: {
{
ID: "rule_multipart_no_match",
Pattern: "maliciousfile.txt",
Targets: []string{"FILE_NAME"},
Phase: 2,
Score: 5,
Action: "block",
regex: regexp.MustCompile("maliciousfile.txt"),
},
},
},
CustomResponses: map[int]CustomBlockResponse{
403: {
StatusCode: http.StatusForbidden,
Body: "Blocked by Multipart File Name Regex",
},
},
ruleCache: NewRuleCache(),
ipBlacklist: trie.NewTrie(),
dnsBlacklist: map[string]struct{}{},
requestValueExtractor: NewRequestValueExtractor(logger, false),
}
body := &bytes.Buffer{}
writer := multipart.NewWriter(body)
part, err := writer.CreateFormFile("file", "goodfile.txt")
if err != nil {
t.Fatalf("Failed to create multipart form file: %v", err)
}
_, err = part.Write([]byte("file content"))
if err != nil {
t.Fatalf("Failed to write multipart form file: %v", err)
}
err = writer.Close()
if err != nil {
t.Fatalf("Failed to close multipart writer: %v", err)
}
req := httptest.NewRequest("POST", "http://example.com", body)
req.RemoteAddr = localIP
req.Header.Set("Content-Type", writer.FormDataContentType())
// Create a context and add logID to it - FIX: ADD CONTEXT HERE
ctx := context.Background()
logID := "test-log-id-bodyregexmultipartnomatch"
ctx = context.WithValue(ctx, ContextKeyLogId("logID"), logID)
req = req.WithContext(ctx)
w := httptest.NewRecorder()
state := &WAFState{}
middleware.handlePhase(w, req, 2, state)
t.Logf("State Blocked: %v", state.Blocked)
t.Logf("Response Code: %d", w.Code)
t.Logf("Response Body: %s", w.Body.String())
assert.False(t, state.Blocked, "Request should not be blocked")
assert.Equal(t, http.StatusOK, w.Code, "Expected status code 200")
assert.Empty(t, w.Body.String(), "Response body should be empty")
}
func TestBlockedRequestPhase2_BodyRegex_NoBody(t *testing.T) {
logger := zap.NewNop()
middleware := &Middleware{
logger: logger,
Rules: map[int][]Rule{
2: {
{
ID: "rule_body_no_match",
Pattern: "some-pattern",
Targets: []string{"BODY"},
Phase: 2,
Score: 5,
Action: "block",
regex: regexp.MustCompile("some-pattern"),
},
},
},
CustomResponses: map[int]CustomBlockResponse{
403: {
StatusCode: http.StatusForbidden,
Body: "Blocked by Body Regex",
},
},
ruleCache: NewRuleCache(),
ipBlacklist: trie.NewTrie(),
dnsBlacklist: map[string]struct{}{},
requestValueExtractor: NewRequestValueExtractor(logger, false),
}
req := httptest.NewRequest("POST", "http://example.com", nil)
req.RemoteAddr = localIP
w := httptest.NewRecorder()
state := &WAFState{}
middleware.handlePhase(w, req, 2, state)
t.Logf("State Blocked: %v", state.Blocked)
t.Logf("Response Code: %d", w.Code)
t.Logf("Response Body: %s", w.Body.String())
assert.False(t, state.Blocked, "Request should not be blocked")
assert.Equal(t, http.StatusOK, w.Code, "Expected status code 200")
assert.Empty(t, w.Body.String(), "Response body should be empty")
}
/////
func TestBlockedRequestPhase3_ResponseHeaderRegex_NoMatch(t *testing.T) {
logger := zap.NewNop()
middleware := &Middleware{
logger: logger,
Rules: map[int][]Rule{
3: {
{
ID: "rule3_no_match",
Pattern: "nomatch",
Targets: []string{"RESPONSE_HEADERS:X-Response-Header"},
Phase: 3,
Score: 5,
Action: "block",
regex: regexp.MustCompile("nomatch"),
},
},
},
CustomResponses: map[int]CustomBlockResponse{
403: {
StatusCode: http.StatusForbidden,
Body: "Blocked by Response Header Regex",
},
},
ruleCache: NewRuleCache(),
ipBlacklist: trie.NewTrie(),
dnsBlacklist: map[string]struct{}{},
requestValueExtractor: NewRequestValueExtractor(logger, false),
}
mockHandler := func() caddyhttp.Handler {
return caddyhttp.HandlerFunc(func(w http.ResponseWriter, r *http.Request) error {
w.Header().Set("X-Response-Header", "good-header")
w.WriteHeader(http.StatusOK)
return nil
})
}()
req := httptest.NewRequest("GET", "http://example.com", nil)
req.RemoteAddr = localIP
w := httptest.NewRecorder()
state := &WAFState{}
err := middleware.ServeHTTP(w, req, mockHandler)
if err != nil {
t.Fatalf("ServeHTTP returned an error: %v", err)
}
t.Logf("State Blocked: %v", state.Blocked)
t.Logf("Response Code: %d", w.Code)
t.Logf("Response Body: %s", w.Body.String())
assert.False(t, state.Blocked, "Request should not be blocked")
assert.Equal(t, http.StatusOK, w.Code, "Expected status code 200")
assert.Empty(t, w.Body.String(), "Response body should be empty")
}
func TestBlockedRequestPhase4_ResponseBodyRegex_EmptyBody(t *testing.T) {
logger := zap.NewNop()
middleware := &Middleware{
logger: logger,
Rules: map[int][]Rule{
4: {
{
ID: "rule4_empty",
Pattern: "test",
Targets: []string{"RESPONSE_BODY"},
Phase: 4,
Score: 5,
Action: "block",
regex: regexp.MustCompile("test"),
},
},
},
CustomResponses: map[int]CustomBlockResponse{
403: {
StatusCode: http.StatusForbidden,
Body: "Blocked by Response Body Regex",
},
},
ruleCache: NewRuleCache(),
ipBlacklist: trie.NewTrie(),
dnsBlacklist: map[string]struct{}{},
requestValueExtractor: NewRequestValueExtractor(logger, false),
}
mockHandler := func() caddyhttp.Handler {
return caddyhttp.HandlerFunc(func(w http.ResponseWriter, r *http.Request) error {
w.WriteHeader(http.StatusOK)
return nil
})
}()
req := httptest.NewRequest("GET", "http://example.com", nil)
req.RemoteAddr = localIP
w := httptest.NewRecorder()
state := &WAFState{}
err := middleware.ServeHTTP(w, req, mockHandler)
if err != nil {
t.Fatalf("ServeHTTP returned an error: %v", err)
}
t.Logf("State Blocked: %v", state.Blocked)
t.Logf("Response Code: %d", w.Code)
t.Logf("Response Body: %s", w.Body.String())
assert.False(t, state.Blocked, "Request should not be blocked")
assert.Equal(t, http.StatusOK, w.Code, "Expected status code 200")
assert.Empty(t, w.Body.String(), "Response body should be empty")
}
////
func TestBlockedRequestPhase4_ResponseBodyRegex_NoBody(t *testing.T) {
logger := zap.NewNop()
middleware := &Middleware{
logger: logger,
Rules: map[int][]Rule{
4: {
{
ID: "rule4_no_body",
Pattern: "test",
Targets: []string{"RESPONSE_BODY"},
Phase: 4,
Score: 5,
Action: "block",
regex: regexp.MustCompile("test"),
},
},
},
CustomResponses: map[int]CustomBlockResponse{
403: {
StatusCode: http.StatusForbidden,
Body: "Blocked by Response Body Regex",
},
},
ruleCache: NewRuleCache(),
ipBlacklist: trie.NewTrie(),
dnsBlacklist: map[string]struct{}{},
requestValueExtractor: NewRequestValueExtractor(logger, false),
}
mockHandler := func() caddyhttp.Handler {
return caddyhttp.HandlerFunc(func(w http.ResponseWriter, r *http.Request) error {
w.WriteHeader(http.StatusOK)
return nil
})
}()
req := httptest.NewRequest("GET", "http://example.com", nil)
req.RemoteAddr = localIP
w := httptest.NewRecorder()
state := &WAFState{}
err := middleware.ServeHTTP(w, req, mockHandler)
if err != nil {
t.Fatalf("ServeHTTP returned an error: %v", err)
}
t.Logf("State Blocked: %v", state.Blocked)
t.Logf("Response Code: %d", w.Code)
t.Logf("Response Body: %s", w.Body.String())
assert.False(t, state.Blocked, "Request should not be blocked")
assert.Equal(t, http.StatusOK, w.Code, "Expected status code 200")
assert.Empty(t, w.Body.String(), "Response body should be empty")
}
func TestBlockedRequestPhase3_ResponseHeaderRegex_NoSetCookie(t *testing.T) {
logger := zap.NewNop()
middleware := &Middleware{
logger: logger,
Rules: map[int][]Rule{
3: {
{
ID: "rule_no_setcookie",
Pattern: "(?i)Set-Cookie:.*?(%0d|\\r)%0a",
Targets: []string{"RESPONSE_HEADERS"},
Phase: 3,
Score: 5,
Action: "block",
regex: regexp.MustCompile(`(?i)Set-Cookie:.*?(%0d|\r)%0a`),
},
},
},
CustomResponses: map[int]CustomBlockResponse{
403: {
StatusCode: http.StatusForbidden,
Body: "Blocked by Set-Cookie Header Regex",
},
},
ruleCache: NewRuleCache(),
ipBlacklist: trie.NewTrie(),
dnsBlacklist: map[string]struct{}{},
requestValueExtractor: NewRequestValueExtractor(logger, false),
}
mockHandler := func() caddyhttp.Handler {
return caddyhttp.HandlerFunc(func(w http.ResponseWriter, r *http.Request) error {
w.Header().Set("X-Custom-Header", "some-header-value") // Simulating a normal non-matching response
w.WriteHeader(http.StatusOK)
return nil
})
}()
req := httptest.NewRequest("GET", "http://example.com", nil)
req.RemoteAddr = localIP
w := httptest.NewRecorder()
state := &WAFState{}
err := middleware.ServeHTTP(w, req, mockHandler)
if err != nil {
t.Fatalf("ServeHTTP returned an error: %v", err)
}
t.Logf("State Blocked: %v", state.Blocked)
t.Logf("Response Code: %d", w.Code)
t.Logf("Response Body: %s", w.Body.String())
assert.False(t, state.Blocked, "Request should not be blocked")
assert.Equal(t, http.StatusOK, w.Code, "Expected status code 200")
assert.Empty(t, w.Body.String(), "Response body should be empty")
}
//
func TestBlockedRequestPhase1_HeaderRegex_CaseInsensitive(t *testing.T) {
logger := zap.NewNop()
middleware := &Middleware{
logger: logger,
Rules: map[int][]Rule{
1: {
{
ID: "rule_header_case_insensitive",
Pattern: "(?i)bad-value",
Targets: []string{"HEADERS:X-Custom-Header"},
Phase: 1,
Score: 5,
Action: "block",
regex: regexp.MustCompile("(?i)bad-value"),
},
},
},
CustomResponses: map[int]CustomBlockResponse{
403: {
StatusCode: http.StatusForbidden,
Body: "Blocked by Case-Insensitive Header Regex",
},
},
ruleCache: NewRuleCache(),
ipBlacklist: trie.NewTrie(),
dnsBlacklist: map[string]struct{}{},
requestValueExtractor: NewRequestValueExtractor(logger, false),
}
req := httptest.NewRequest("GET", "http://example.com", nil)
req.RemoteAddr = localIP
req.Header.Set("X-Custom-Header", "bAd-VaLuE") // Test with mixed-case header value
// Create a context and add logID to it - FIX: ADD CONTEXT HERE
ctx := context.Background()
logID := "test-log-id-headercaseinsensitive" // Unique log ID for this test
ctx = context.WithValue(ctx, ContextKeyLogId("logID"), logID)
req = req.WithContext(ctx)
w := httptest.NewRecorder()
state := &WAFState{}
middleware.handlePhase(w, req, 1, state)
t.Logf("State Blocked: %v", state.Blocked)
t.Logf("Response Code: %d", w.Code)
t.Logf("Response Body: %s", w.Body.String())
assert.True(t, state.Blocked, "Request should be blocked by case-insensitive regex")
assert.Equal(t, http.StatusForbidden, w.Code, "Expected status code 403")
assert.Contains(t, w.Body.String(), "Blocked by Case-Insensitive Header Regex", "Response body should contain 'Blocked by Case-Insensitive Header Regex'")
}
func TestBlockedRequestPhase1_HeaderRegex_MultipleMatchingHeaders(t *testing.T) {
logger := zap.NewNop()
middleware := &Middleware{
logger: logger,
Rules: map[int][]Rule{
1: {
{
ID: "rule_header_multi",
Pattern: "bad",
Targets: []string{"HEADERS:X-Custom-Header1,HEADERS:X-Custom-Header2"},
Phase: 1,
Score: 5,
Action: "block",
regex: regexp.MustCompile("bad"),
},
},
},
CustomResponses: map[int]CustomBlockResponse{
403: {
StatusCode: http.StatusForbidden,
Body: "Blocked by Multiple Matching Headers Regex",
},
},
ruleCache: NewRuleCache(),
ipBlacklist: trie.NewTrie(),
dnsBlacklist: map[string]struct{}{},
requestValueExtractor: NewRequestValueExtractor(logger, false),
}
req := httptest.NewRequest("GET", "http://example.com", nil)
req.RemoteAddr = localIP
req.Header.Set("X-Custom-Header1", "bad-value")
req.Header.Set("X-Custom-Header2", "bad-value") // Both headers have a "bad" value
// Create a context and add logID to it - FIX: ADD CONTEXT HERE for req
ctx := context.Background()
logID := "test-log-id-headermultimatch" // Unique log ID for this test
ctx = context.WithValue(ctx, ContextKeyLogId("logID"), logID)
req = req.WithContext(ctx)
w := httptest.NewRecorder()
state := &WAFState{}
middleware.handlePhase(w, req, 1, state)
t.Logf("State Blocked: %v", state.Blocked)
t.Logf("Response Code: %d", w.Code)
t.Logf("Response Body: %s", w.Body.String())
assert.True(t, state.Blocked, "Request should be blocked when both headers match")
assert.Equal(t, http.StatusForbidden, w.Code, "Expected status code 403")
assert.Contains(t, w.Body.String(), "Blocked by Multiple Matching Headers Regex", "Response body should contain 'Blocked by Multiple Matching Headers Regex'")
req2 := httptest.NewRequest("GET", "http://example.com", nil)
req2.RemoteAddr = localIP
req2.Header.Set("X-Custom-Header1", "good-value")
req2.Header.Set("X-Custom-Header2", "bad-value") // One header has a "bad" value
// Create a context and add logID to it - FIX: ADD CONTEXT HERE for req2
ctx2 := context.Background()
logID2 := "test-log-id-headermultimatch2" // Unique log ID for this test
ctx2 = context.WithValue(ctx2, ContextKeyLogId("logID"), logID2)
req2 = req2.WithContext(ctx2)
w2 := httptest.NewRecorder()
state2 := &WAFState{}
middleware.handlePhase(w2, req2, 1, state2)
t.Logf("State Blocked: %v", state2.Blocked)
t.Logf("Response Code: %d", w2.Code)
t.Logf("Response Body: %s", w2.Body.String())
assert.True(t, state2.Blocked, "Request should be blocked when one header match")
assert.Equal(t, http.StatusForbidden, w2.Code, "Expected status code 403")
assert.Contains(t, w2.Body.String(), "Blocked by Multiple Matching Headers Regex", "Response body should contain 'Blocked by Multiple Matching Headers Regex'")
req3 := httptest.NewRequest("GET", "http://example.com", nil)
req3.RemoteAddr = localIP
req3.Header.Set("X-Custom-Header1", "good-value")
req3.Header.Set("X-Custom-Header2", "good-value") // None headers have a "bad" value
// Create a context and add logID to it - FIX: ADD CONTEXT HERE for req3
ctx3 := context.Background()
logID3 := "test-log-id-headermultimatch3" // Unique log ID for this test
ctx3 = context.WithValue(ctx3, ContextKeyLogId("logID"), logID3)
req3 = req3.WithContext(ctx3)
w3 := httptest.NewRecorder()
state3 := &WAFState{}
middleware.handlePhase(w3, req3, 1, state3)
t.Logf("State Blocked: %v", state3.Blocked)
t.Logf("Response Code: %d", w3.Code)
t.Logf("Response Body: %s", w3.Body.String())
assert.False(t, state3.Blocked, "Request should not be blocked when none headers match")
assert.Equal(t, http.StatusOK, w3.Code, "Expected status code 200")
}
// RequestLimit represents the rate limit state for a specific request
type RequestLimit struct {
Count int
LastReset time.Time
}
func TestBlockedRequestPhase1_RateLimiting_MultiplePaths(t *testing.T) {
logger := zap.NewNop()
middleware := &Middleware{
logger: logger,
rateLimiter: func() *RateLimiter {
rl := &RateLimiter{
config: RateLimit{
Requests: 1,
Window: time.Minute,
CleanupInterval: time.Minute,
Paths: []string{"/api/v1/.*", "/admin/.*"},
MatchAllPaths: false,
},
requests: make(map[string]map[string]*requestCounter),
stopCleanup: make(chan struct{}),
}
rl.startCleanup()
return rl
}(),
CustomResponses: map[int]CustomBlockResponse{
429: {
StatusCode: http.StatusTooManyRequests,
Body: "Rate limit exceeded",
},
},
ipBlacklist: trie.NewTrie(),
dnsBlacklist: make(map[string]struct{}),
}
// Test path 1
req1 := httptest.NewRequest("GET", "/api/v1/users", nil)
req1.RemoteAddr = localIP
w1 := httptest.NewRecorder()
state1 := &WAFState{}
middleware.handlePhase(w1, req1, 1, state1)
assert.False(t, state1.Blocked, "First request to /api/v1 should be allowed")
assert.Equal(t, http.StatusOK, w1.Code, "Expected status code 200")
req2 := httptest.NewRequest("GET", "/api/v1/users", nil)
req2.RemoteAddr = localIP
w2 := httptest.NewRecorder()
state2 := &WAFState{}
middleware.handlePhase(w2, req2, 1, state2)
assert.True(t, state2.Blocked, "Second request to /api/v1 should be rate-limited")
assert.Equal(t, http.StatusTooManyRequests, w2.Code, "Expected status code 429")
// Test path 2
req3 := httptest.NewRequest("GET", "/admin/dashboard", nil)
req3.RemoteAddr = localIP
w3 := httptest.NewRecorder()
state3 := &WAFState{}
middleware.handlePhase(w3, req3, 1, state3)
assert.False(t, state3.Blocked, "First request to /admin should be allowed")
assert.Equal(t, http.StatusOK, w3.Code, "Expected status code 200")
req4 := httptest.NewRequest("GET", "/admin/dashboard", nil)
req4.RemoteAddr = localIP
w4 := httptest.NewRecorder()
state4 := &WAFState{}
middleware.handlePhase(w4, req4, 1, state4)
assert.True(t, state4.Blocked, "Second request to /admin should be rate-limited")
assert.Equal(t, http.StatusTooManyRequests, w4.Code, "Expected status code 429")
req5 := httptest.NewRequest("GET", "/not-rate-limited", nil)
req5.RemoteAddr = localIP
w5 := httptest.NewRecorder()
state5 := &WAFState{}
middleware.handlePhase(w5, req5, 1, state5)
assert.False(t, state5.Blocked, "Request not rate limited path should be allowed")
assert.Equal(t, http.StatusOK, w5.Code, "Expected status code 200")
}
func TestBlockedRequestPhase1_RateLimiting_DifferentIPs(t *testing.T) {
logger := zap.NewNop()
middleware := &Middleware{
logger: logger,
rateLimiter: func() *RateLimiter {
rl, err := NewRateLimiter(RateLimit{
Requests: 1,
Window: time.Minute,
CleanupInterval: time.Minute,
MatchAllPaths: true,
})
if err != nil {
t.Fatalf("Failed to create rate limiter: %v", err)
}
return rl
}(),
CustomResponses: map[int]CustomBlockResponse{
429: {
StatusCode: http.StatusTooManyRequests,
Body: "Rate limit exceeded",
},
},
ipBlacklist: trie.NewTrie(),
dnsBlacklist: make(map[string]struct{}),
}
// Test different IPs
req1 := httptest.NewRequest("GET", "/api/users", nil)
req1.RemoteAddr = localIP
w1 := httptest.NewRecorder()
state1 := &WAFState{}
middleware.handlePhase(w1, req1, 1, state1)
assert.False(t, state1.Blocked, "First request from 192.168.1.1 should be allowed")
assert.Equal(t, http.StatusOK, w1.Code, "Expected status code 200")
req2 := httptest.NewRequest("GET", "/api/users", nil)
req2.RemoteAddr = "192.168.1.2"
w2 := httptest.NewRecorder()
state2 := &WAFState{}
middleware.handlePhase(w2, req2, 1, state2)
assert.False(t, state2.Blocked, "First request from 192.168.1.2 should be allowed")
assert.Equal(t, http.StatusOK, w2.Code, "Expected status code 200")
req3 := httptest.NewRequest("GET", "/api/users", nil)
req3.RemoteAddr = localIP
w3 := httptest.NewRecorder()
state3 := &WAFState{}
middleware.handlePhase(w3, req3, 1, state3)
assert.True(t, state3.Blocked, "Second request from 192.168.1.1 should be blocked")
assert.Equal(t, http.StatusTooManyRequests, w3.Code, "Expected status code 429")
}
func TestBlockedRequestPhase1_RateLimiting_MatchAllPaths(t *testing.T) {
logger := zap.NewNop()
middleware := &Middleware{
logger: logger,
rateLimiter: func() *RateLimiter {
rl, err := NewRateLimiter(RateLimit{
Requests: 1,
Window: time.Minute,
CleanupInterval: time.Minute,
MatchAllPaths: true,
})
if err != nil {
t.Fatalf("Failed to create rate limiter: %v", err)
}
return rl
}(),
CustomResponses: map[int]CustomBlockResponse{
429: {
StatusCode: http.StatusTooManyRequests,
Body: "Rate limit exceeded",
},
},
ipBlacklist: trie.NewTrie(),
dnsBlacklist: make(map[string]struct{}),
}
// Test with match all paths
req1 := httptest.NewRequest("GET", "/api/users", nil)
req1.RemoteAddr = localIP
w1 := httptest.NewRecorder()
state1 := &WAFState{}
middleware.handlePhase(w1, req1, 1, state1)
assert.False(t, state1.Blocked, "First request to /api/users should be allowed")
assert.Equal(t, http.StatusOK, w1.Code, "Expected status code 200")
req2 := httptest.NewRequest("GET", "/api/users", nil)
req2.RemoteAddr = localIP
w2 := httptest.NewRecorder()
state2 := &WAFState{}
middleware.handlePhase(w2, req2, 1, state2)
assert.True(t, state2.Blocked, "Second request to /api/users should be rate-limited")
assert.Equal(t, http.StatusTooManyRequests, w2.Code, "Expected status code 429")
req3 := httptest.NewRequest("GET", "/some-other-path", nil)
req3.RemoteAddr = localIP
w3 := httptest.NewRecorder()
state3 := &WAFState{}
middleware.handlePhase(w3, req3, 1, state3)
assert.True(t, state3.Blocked, "Second request to /some-other-path should be rate-limited because MatchAllPaths=true")
assert.Equal(t, http.StatusTooManyRequests, w3.Code, "Expected status code 429")
}