mirror of
https://github.com/fabriziosalmi/caddy-waf.git
synced 2025-12-23 14:17:45 -05:00
Security: Implement hardening improvements (LimitReader, GeoIP Fail-Open, UI Decoupling, Go Version)
This commit is contained in:
@@ -24,3 +24,6 @@ lint:
|
||||
lintfix:
|
||||
@echo "==> Checking source code with golangci-lint..."
|
||||
@golangci-lint run --fix
|
||||
|
||||
test-integration:
|
||||
@docker run --rm -v $(PWD):/app -w /app python:3.9-slim python test.py
|
||||
|
||||
@@ -18,7 +18,7 @@ A robust, highly customizable, and feature-rich **Web Application Firewall (WAF)
|
||||
* **Dynamic Config Reloads:** Seamless updates without restarts.
|
||||
* **File Watchers:** Automatic reloads on rule/blacklist changes.
|
||||
* **Observability:** Seamless integration with ELK stack and Prometheus.
|
||||
* **Rules generator**: powered by custom GPT, [try it here](https://chatgpt.com/g/g-677d07dd07e48191b799b9e5d6da7828-caddy-waf-ruler)
|
||||
* **Rules generator**: [available here](https://chatgpt.com/g/g-677d07dd07e48191b799b9e5d6da7828-caddy-waf-ruler)
|
||||
|
||||
_Simple at a glance UI :)_
|
||||

|
||||
@@ -176,7 +176,7 @@ If You like my projects, you may also like these ones:
|
||||
- [proxmox-vm-autoscale](https://github.com/fabriziosalmi/proxmox-vm-autoscale) Automatically scale virtual machines resources on Proxmox hosts
|
||||
- [UglyFeed](https://github.com/fabriziosalmi/UglyFeed) Retrieve, aggregate, filter, evaluate, rewrite and serve RSS feeds using Large Language Models for fun, research and learning purposes
|
||||
- [proxmox-lxc-autoscale](https://github.com/fabriziosalmi/proxmox-lxc-autoscale) Automatically scale LXC containers resources on Proxmox hosts
|
||||
- [DevGPT](https://github.com/fabriziosalmi/DevGPT) Code togheter, right now! GPT powered code assistant to build project in minutes
|
||||
- [DevAssistant](https://github.com/fabriziosalmi/DevGPT) Code together, right now! AI powered code assistant to build project in minutes
|
||||
- [websites-monitor](https://github.com/fabriziosalmi/websites-monitor) Websites monitoring via GitHub Actions (expiration, security, performances, privacy, SEO)
|
||||
- [caddy-mib](https://github.com/fabriziosalmi/caddy-mib) Track and ban client IPs generating repetitive errors on Caddy
|
||||
- [zonecontrol](https://github.com/fabriziosalmi/zonecontrol) Cloudflare Zones Settings Automation using GitHub Actions
|
||||
|
||||
8
assets.go
Normal file
8
assets.go
Normal file
@@ -0,0 +1,8 @@
|
||||
//go:build with_ui
|
||||
|
||||
package caddywaf
|
||||
|
||||
import "embed"
|
||||
|
||||
//go:embed ui/*
|
||||
var Assets embed.FS
|
||||
7
assets_stub.go
Normal file
7
assets_stub.go
Normal file
@@ -0,0 +1,7 @@
|
||||
//go:build !with_ui
|
||||
|
||||
package caddywaf
|
||||
|
||||
import "embed"
|
||||
|
||||
var Assets embed.FS
|
||||
@@ -239,7 +239,7 @@ func (m *Middleware) Provision(ctx caddy.Context) error {
|
||||
m.configLoader = NewConfigLoader(m.logger)
|
||||
m.blacklistLoader = NewBlacklistLoader(m.logger)
|
||||
m.geoIPHandler = NewGeoIPHandler(m.logger)
|
||||
m.requestValueExtractor = NewRequestValueExtractor(m.logger, m.RedactSensitiveData)
|
||||
m.requestValueExtractor = NewRequestValueExtractor(m.logger, m.RedactSensitiveData, m.MaxRequestBodySize)
|
||||
|
||||
// Configure GeoIP handler
|
||||
m.geoIPHandler.WithGeoIPCache(m.geoIPCacheTTL)
|
||||
|
||||
2
go.mod
2
go.mod
@@ -1,6 +1,6 @@
|
||||
module github.com/fabriziosalmi/caddy-waf
|
||||
|
||||
go 1.25
|
||||
go 1.25.5
|
||||
|
||||
toolchain go1.25.3
|
||||
|
||||
|
||||
48
handler.go
48
handler.go
@@ -334,12 +334,16 @@ func (m *Middleware) handlePhase(w http.ResponseWriter, r *http.Request, phase i
|
||||
r,
|
||||
zap.Error(err),
|
||||
)
|
||||
m.blockRequest(w, r, state, http.StatusForbidden, "internal_error", "country_block_rule",
|
||||
zap.String("message", "Request blocked due to internal error"),
|
||||
)
|
||||
m.logger.Debug("Country whitelisting phase completed - blocked due to error")
|
||||
m.incrementGeoIPRequestsMetric(false) // Increment with false for error
|
||||
return
|
||||
if m.GeoIPFailOpen {
|
||||
m.logger.Warn("GeoIP lookup failed (Whitelist); Failing OPEN")
|
||||
} else {
|
||||
m.blockRequest(w, r, state, http.StatusForbidden, "internal_error", "country_block_rule",
|
||||
zap.String("message", "Request blocked due to internal error"),
|
||||
)
|
||||
m.logger.Debug("Country whitelisting phase completed - blocked due to error")
|
||||
m.incrementGeoIPRequestsMetric(false) // Increment with false for error
|
||||
return
|
||||
}
|
||||
} else if !allowed {
|
||||
m.blockRequest(w, r, state, http.StatusForbidden, "country_block", "country_block_rule",
|
||||
zap.String("message", "Request blocked by country"))
|
||||
@@ -362,12 +366,16 @@ func (m *Middleware) handlePhase(w http.ResponseWriter, r *http.Request, phase i
|
||||
r,
|
||||
zap.Error(err),
|
||||
)
|
||||
m.blockRequest(w, r, state, http.StatusForbidden, "internal_error", "asn_block_rule",
|
||||
zap.String("message", "Request blocked due to internal error"),
|
||||
)
|
||||
m.logger.Debug("ASN blocking phase completed - blocked due to error")
|
||||
m.incrementGeoIPRequestsMetric(false) // Increment with false for error
|
||||
return
|
||||
if m.GeoIPFailOpen {
|
||||
m.logger.Warn("ASN lookup failed; Failing OPEN")
|
||||
} else {
|
||||
m.blockRequest(w, r, state, http.StatusForbidden, "internal_error", "asn_block_rule",
|
||||
zap.String("message", "Request blocked due to internal error"),
|
||||
)
|
||||
m.logger.Debug("ASN blocking phase completed - blocked due to error")
|
||||
m.incrementGeoIPRequestsMetric(false) // Increment with false for error
|
||||
return
|
||||
}
|
||||
} else if blocked {
|
||||
asnInfo := m.geoIPHandler.GetASN(r.RemoteAddr, m.BlockASNs.geoIP)
|
||||
m.blockRequest(w, r, state, http.StatusForbidden, "asn_block", "asn_block_rule",
|
||||
@@ -392,12 +400,16 @@ func (m *Middleware) handlePhase(w http.ResponseWriter, r *http.Request, phase i
|
||||
r,
|
||||
zap.Error(err),
|
||||
)
|
||||
m.blockRequest(w, r, state, http.StatusForbidden, "internal_error", "country_block_rule",
|
||||
zap.String("message", "Request blocked due to internal error"),
|
||||
)
|
||||
m.logger.Debug("Country blacklisting phase completed - blocked due to error")
|
||||
m.incrementGeoIPRequestsMetric(false) // Increment with false for error
|
||||
return
|
||||
if m.GeoIPFailOpen {
|
||||
m.logger.Warn("GeoIP lookup failed (Blacklist); Failing OPEN")
|
||||
} else {
|
||||
m.blockRequest(w, r, state, http.StatusForbidden, "internal_error", "country_block_rule",
|
||||
zap.String("message", "Request blocked due to internal error"),
|
||||
)
|
||||
m.logger.Debug("Country blacklisting phase completed - blocked due to error")
|
||||
m.incrementGeoIPRequestsMetric(false) // Increment with false for error
|
||||
return
|
||||
}
|
||||
} else if blocked {
|
||||
m.blockRequest(w, r, state, http.StatusForbidden, "country_block", "country_block_rule",
|
||||
zap.String("message", "Request blocked by country"))
|
||||
|
||||
@@ -29,7 +29,7 @@ func TestBlockedRequestPhase1_DNSBlacklist(t *testing.T) {
|
||||
},
|
||||
ipBlacklist: iptrie.NewTrie(),
|
||||
CustomResponses: customResponse,
|
||||
requestValueExtractor: NewRequestValueExtractor(logger, false),
|
||||
requestValueExtractor: NewRequestValueExtractor(logger, false, 0),
|
||||
}
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
@@ -83,7 +83,7 @@ func TestBlockedRequestPhase1_GeoIPBlocking(t *testing.T) {
|
||||
geoIP: geoIPBlock,
|
||||
},
|
||||
CustomResponses: customResponse,
|
||||
requestValueExtractor: NewRequestValueExtractor(logger, false),
|
||||
requestValueExtractor: NewRequestValueExtractor(logger, false, 0),
|
||||
}
|
||||
|
||||
wlMiddleware := &Middleware{
|
||||
@@ -97,7 +97,7 @@ func TestBlockedRequestPhase1_GeoIPBlocking(t *testing.T) {
|
||||
geoIP: geoIPBlock,
|
||||
},
|
||||
CustomResponses: customResponse,
|
||||
requestValueExtractor: NewRequestValueExtractor(logger, false),
|
||||
requestValueExtractor: NewRequestValueExtractor(logger, false, 0),
|
||||
}
|
||||
|
||||
blackWhiteMw := &Middleware{
|
||||
@@ -117,7 +117,7 @@ func TestBlockedRequestPhase1_GeoIPBlocking(t *testing.T) {
|
||||
geoIP: geoIPBlock,
|
||||
},
|
||||
CustomResponses: customResponse,
|
||||
requestValueExtractor: NewRequestValueExtractor(logger, false),
|
||||
requestValueExtractor: NewRequestValueExtractor(logger, false, 0),
|
||||
}
|
||||
|
||||
req := httptest.NewRequest("GET", testURL, nil)
|
||||
@@ -216,7 +216,7 @@ func TestBlockedRequestPhase1_IPBlocking(t *testing.T) {
|
||||
logger: logger,
|
||||
ipBlacklist: blackList,
|
||||
CustomResponses: customResponse,
|
||||
requestValueExtractor: NewRequestValueExtractor(logger, false),
|
||||
requestValueExtractor: NewRequestValueExtractor(logger, false, 0),
|
||||
}
|
||||
|
||||
req := httptest.NewRequest("GET", testURL, nil)
|
||||
@@ -234,7 +234,7 @@ func TestBlockedRequestPhase1_IPBlocking(t *testing.T) {
|
||||
logger: logger,
|
||||
ipBlacklist: blackList,
|
||||
CustomResponses: customResponse,
|
||||
requestValueExtractor: NewRequestValueExtractor(logger, false),
|
||||
requestValueExtractor: NewRequestValueExtractor(logger, false, 0),
|
||||
}
|
||||
|
||||
req := httptest.NewRequest("GET", testURL, nil)
|
||||
@@ -270,7 +270,7 @@ func TestHandlePhase_Phase2_NiktoUserAgent(t *testing.T) {
|
||||
ruleCache: NewRuleCache(),
|
||||
ipBlacklist: iptrie.NewTrie(),
|
||||
dnsBlacklist: map[string]struct{}{},
|
||||
requestValueExtractor: NewRequestValueExtractor(logger, false),
|
||||
requestValueExtractor: NewRequestValueExtractor(logger, false, 0),
|
||||
CustomResponses: customResponse,
|
||||
}
|
||||
|
||||
@@ -325,7 +325,7 @@ func TestBlockedRequestPhase1_HeaderRegex(t *testing.T) {
|
||||
ruleCache: NewRuleCache(),
|
||||
ipBlacklist: iptrie.NewTrie(),
|
||||
dnsBlacklist: map[string]struct{}{},
|
||||
requestValueExtractor: NewRequestValueExtractor(logger, false),
|
||||
requestValueExtractor: NewRequestValueExtractor(logger, false, 0),
|
||||
}
|
||||
|
||||
req := httptest.NewRequest("GET", testURL, nil)
|
||||
@@ -378,7 +378,7 @@ func TestBlockedRequestPhase1_HeaderRegex_SpecificValue(t *testing.T) {
|
||||
ruleCache: NewRuleCache(),
|
||||
ipBlacklist: iptrie.NewTrie(),
|
||||
dnsBlacklist: map[string]struct{}{},
|
||||
requestValueExtractor: NewRequestValueExtractor(logger, false),
|
||||
requestValueExtractor: NewRequestValueExtractor(logger, false, 0),
|
||||
}
|
||||
|
||||
req := httptest.NewRequest("GET", testURL, nil)
|
||||
@@ -431,7 +431,7 @@ func TestBlockedRequestPhase1_HeaderRegex_CommaSeparatedTargets(t *testing.T) {
|
||||
ruleCache: NewRuleCache(),
|
||||
ipBlacklist: iptrie.NewTrie(),
|
||||
dnsBlacklist: map[string]struct{}{},
|
||||
requestValueExtractor: NewRequestValueExtractor(logger, false),
|
||||
requestValueExtractor: NewRequestValueExtractor(logger, false, 0),
|
||||
}
|
||||
|
||||
req := httptest.NewRequest("GET", testURL, nil)
|
||||
@@ -485,7 +485,7 @@ func TestBlockedRequestPhase1_CombinedConditions(t *testing.T) {
|
||||
ruleCache: NewRuleCache(),
|
||||
ipBlacklist: iptrie.NewTrie(),
|
||||
dnsBlacklist: map[string]struct{}{},
|
||||
requestValueExtractor: NewRequestValueExtractor(logger, false),
|
||||
requestValueExtractor: NewRequestValueExtractor(logger, false, 0),
|
||||
}
|
||||
|
||||
req := httptest.NewRequest("GET", "http://bad-host.com", nil)
|
||||
@@ -538,7 +538,7 @@ func TestBlockedRequestPhase1_NoMatch(t *testing.T) {
|
||||
ruleCache: NewRuleCache(),
|
||||
ipBlacklist: iptrie.NewTrie(),
|
||||
dnsBlacklist: map[string]struct{}{},
|
||||
requestValueExtractor: NewRequestValueExtractor(logger, false),
|
||||
requestValueExtractor: NewRequestValueExtractor(logger, false, 0),
|
||||
}
|
||||
|
||||
req := httptest.NewRequest("GET", testURL, nil)
|
||||
@@ -591,7 +591,7 @@ func TestBlockedRequestPhase1_HeaderRegex_EmptyHeader(t *testing.T) {
|
||||
ruleCache: NewRuleCache(),
|
||||
ipBlacklist: iptrie.NewTrie(),
|
||||
dnsBlacklist: map[string]struct{}{},
|
||||
requestValueExtractor: NewRequestValueExtractor(logger, false),
|
||||
requestValueExtractor: NewRequestValueExtractor(logger, false, 0),
|
||||
}
|
||||
|
||||
req := httptest.NewRequest("GET", testURL, nil)
|
||||
@@ -643,7 +643,7 @@ func TestBlockedRequestPhase1_HeaderRegex_MissingHeader(t *testing.T) {
|
||||
ruleCache: NewRuleCache(),
|
||||
ipBlacklist: iptrie.NewTrie(),
|
||||
dnsBlacklist: map[string]struct{}{},
|
||||
requestValueExtractor: NewRequestValueExtractor(logger, false),
|
||||
requestValueExtractor: NewRequestValueExtractor(logger, false, 0),
|
||||
}
|
||||
|
||||
req := httptest.NewRequest("GET", testURL, nil) // Header not set
|
||||
@@ -695,7 +695,7 @@ func TestBlockedRequestPhase1_HeaderRegex_ComplexPattern(t *testing.T) {
|
||||
ruleCache: NewRuleCache(),
|
||||
ipBlacklist: iptrie.NewTrie(),
|
||||
dnsBlacklist: map[string]struct{}{},
|
||||
requestValueExtractor: NewRequestValueExtractor(logger, false),
|
||||
requestValueExtractor: NewRequestValueExtractor(logger, false, 0),
|
||||
}
|
||||
|
||||
req := httptest.NewRequest("GET", testURL, nil)
|
||||
@@ -748,7 +748,7 @@ func TestBlockedRequestPhase1_MultiTargetMatch(t *testing.T) {
|
||||
ruleCache: NewRuleCache(),
|
||||
ipBlacklist: iptrie.NewTrie(),
|
||||
dnsBlacklist: map[string]struct{}{},
|
||||
requestValueExtractor: NewRequestValueExtractor(logger, false),
|
||||
requestValueExtractor: NewRequestValueExtractor(logger, false, 0),
|
||||
}
|
||||
|
||||
req := httptest.NewRequest("GET", testURL, nil)
|
||||
@@ -801,7 +801,7 @@ func TestBlockedRequestPhase1_MultiTargetNoMatch(t *testing.T) {
|
||||
ruleCache: NewRuleCache(),
|
||||
ipBlacklist: iptrie.NewTrie(),
|
||||
dnsBlacklist: map[string]struct{}{},
|
||||
requestValueExtractor: NewRequestValueExtractor(logger, false),
|
||||
requestValueExtractor: NewRequestValueExtractor(logger, false, 0),
|
||||
}
|
||||
|
||||
req := httptest.NewRequest("GET", testURL, nil)
|
||||
@@ -855,7 +855,7 @@ func TestBlockedRequestPhase1_URLParameterRegex_NoMatch(t *testing.T) {
|
||||
ruleCache: NewRuleCache(),
|
||||
ipBlacklist: iptrie.NewTrie(),
|
||||
dnsBlacklist: map[string]struct{}{},
|
||||
requestValueExtractor: NewRequestValueExtractor(logger, false),
|
||||
requestValueExtractor: NewRequestValueExtractor(logger, false, 0),
|
||||
}
|
||||
|
||||
req := httptest.NewRequest("GET", "http://example.com?param1=good-param-value¶m2=good-value", nil)
|
||||
@@ -915,7 +915,7 @@ func TestBlockedRequestPhase1_MultipleRules(t *testing.T) {
|
||||
ruleCache: NewRuleCache(),
|
||||
ipBlacklist: iptrie.NewTrie(),
|
||||
dnsBlacklist: map[string]struct{}{},
|
||||
requestValueExtractor: NewRequestValueExtractor(logger, false),
|
||||
requestValueExtractor: NewRequestValueExtractor(logger, false, 0),
|
||||
}
|
||||
|
||||
req := httptest.NewRequest("GET", "http://bad-host.com", nil)
|
||||
@@ -991,7 +991,7 @@ func TestBlockedRequestPhase2_BodyRegex(t *testing.T) {
|
||||
ruleCache: NewRuleCache(),
|
||||
ipBlacklist: iptrie.NewTrie(),
|
||||
dnsBlacklist: map[string]struct{}{},
|
||||
requestValueExtractor: NewRequestValueExtractor(logger, false),
|
||||
requestValueExtractor: NewRequestValueExtractor(logger, false, 0),
|
||||
}
|
||||
|
||||
req := httptest.NewRequest("POST", testURL,
|
||||
@@ -1050,7 +1050,7 @@ func TestBlockedRequestPhase2_BodyRegex_JSON(t *testing.T) {
|
||||
ruleCache: NewRuleCache(),
|
||||
ipBlacklist: iptrie.NewTrie(),
|
||||
dnsBlacklist: map[string]struct{}{},
|
||||
requestValueExtractor: NewRequestValueExtractor(logger, false),
|
||||
requestValueExtractor: NewRequestValueExtractor(logger, false, 0),
|
||||
}
|
||||
|
||||
req := httptest.NewRequest("POST", testURL,
|
||||
@@ -1109,7 +1109,7 @@ func TestBlockedRequestPhase2_BodyRegex_FormURLEncoded(t *testing.T) {
|
||||
ruleCache: NewRuleCache(),
|
||||
ipBlacklist: iptrie.NewTrie(),
|
||||
dnsBlacklist: map[string]struct{}{},
|
||||
requestValueExtractor: NewRequestValueExtractor(logger, false),
|
||||
requestValueExtractor: NewRequestValueExtractor(logger, false, 0),
|
||||
}
|
||||
|
||||
req := httptest.NewRequest("POST", testURL,
|
||||
@@ -1164,7 +1164,7 @@ func TestBlockedRequestPhase2_BodyRegex_SpecificPattern(t *testing.T) {
|
||||
ruleCache: NewRuleCache(),
|
||||
ipBlacklist: iptrie.NewTrie(),
|
||||
dnsBlacklist: map[string]struct{}{},
|
||||
requestValueExtractor: NewRequestValueExtractor(logger, false),
|
||||
requestValueExtractor: NewRequestValueExtractor(logger, false, 0),
|
||||
}
|
||||
|
||||
req := httptest.NewRequest("POST", testURL,
|
||||
@@ -1223,7 +1223,7 @@ func TestBlockedRequestPhase2_BodyRegex_NoMatch(t *testing.T) {
|
||||
ruleCache: NewRuleCache(),
|
||||
ipBlacklist: iptrie.NewTrie(),
|
||||
dnsBlacklist: map[string]struct{}{},
|
||||
requestValueExtractor: NewRequestValueExtractor(logger, false),
|
||||
requestValueExtractor: NewRequestValueExtractor(logger, false, 0),
|
||||
}
|
||||
|
||||
req := httptest.NewRequest("POST", testURL,
|
||||
@@ -1282,7 +1282,7 @@ func TestBlockedRequestPhase2_BodyRegex_NoMatch_MultipartForm(t *testing.T) {
|
||||
ruleCache: NewRuleCache(),
|
||||
ipBlacklist: iptrie.NewTrie(),
|
||||
dnsBlacklist: map[string]struct{}{},
|
||||
requestValueExtractor: NewRequestValueExtractor(logger, false),
|
||||
requestValueExtractor: NewRequestValueExtractor(logger, false, 0),
|
||||
}
|
||||
|
||||
body := &bytes.Buffer{}
|
||||
@@ -1350,7 +1350,7 @@ func TestBlockedRequestPhase2_BodyRegex_NoBody(t *testing.T) {
|
||||
ruleCache: NewRuleCache(),
|
||||
ipBlacklist: iptrie.NewTrie(),
|
||||
dnsBlacklist: map[string]struct{}{},
|
||||
requestValueExtractor: NewRequestValueExtractor(logger, false),
|
||||
requestValueExtractor: NewRequestValueExtractor(logger, false, 0),
|
||||
}
|
||||
|
||||
req := httptest.NewRequest("POST", testURL, nil)
|
||||
@@ -1396,7 +1396,7 @@ func TestBlockedRequestPhase3_ResponseHeaderRegex_NoMatch(t *testing.T) {
|
||||
ruleCache: NewRuleCache(),
|
||||
ipBlacklist: iptrie.NewTrie(),
|
||||
dnsBlacklist: map[string]struct{}{},
|
||||
requestValueExtractor: NewRequestValueExtractor(logger, false),
|
||||
requestValueExtractor: NewRequestValueExtractor(logger, false, 0),
|
||||
}
|
||||
|
||||
mockHandler := func() caddyhttp.Handler {
|
||||
@@ -1452,7 +1452,7 @@ func TestBlockedRequestPhase4_ResponseBodyRegex_EmptyBody(t *testing.T) {
|
||||
ruleCache: NewRuleCache(),
|
||||
ipBlacklist: iptrie.NewTrie(),
|
||||
dnsBlacklist: map[string]struct{}{},
|
||||
requestValueExtractor: NewRequestValueExtractor(logger, false),
|
||||
requestValueExtractor: NewRequestValueExtractor(logger, false, 0),
|
||||
}
|
||||
|
||||
mockHandler := func() caddyhttp.Handler {
|
||||
@@ -1508,7 +1508,7 @@ func TestBlockedRequestPhase4_ResponseBodyRegex_NoBody(t *testing.T) {
|
||||
ruleCache: NewRuleCache(),
|
||||
ipBlacklist: iptrie.NewTrie(),
|
||||
dnsBlacklist: map[string]struct{}{},
|
||||
requestValueExtractor: NewRequestValueExtractor(logger, false),
|
||||
requestValueExtractor: NewRequestValueExtractor(logger, false, 0),
|
||||
}
|
||||
|
||||
mockHandler := func() caddyhttp.Handler {
|
||||
@@ -1562,7 +1562,7 @@ func TestBlockedRequestPhase3_ResponseHeaderRegex_NoSetCookie(t *testing.T) {
|
||||
ruleCache: NewRuleCache(),
|
||||
ipBlacklist: iptrie.NewTrie(),
|
||||
dnsBlacklist: map[string]struct{}{},
|
||||
requestValueExtractor: NewRequestValueExtractor(logger, false),
|
||||
requestValueExtractor: NewRequestValueExtractor(logger, false, 0),
|
||||
}
|
||||
mockHandler := func() caddyhttp.Handler {
|
||||
return caddyhttp.HandlerFunc(func(w http.ResponseWriter, r *http.Request) error {
|
||||
@@ -1618,7 +1618,7 @@ func TestBlockedRequestPhase1_HeaderRegex_CaseInsensitive(t *testing.T) {
|
||||
ruleCache: NewRuleCache(),
|
||||
ipBlacklist: iptrie.NewTrie(),
|
||||
dnsBlacklist: map[string]struct{}{},
|
||||
requestValueExtractor: NewRequestValueExtractor(logger, false),
|
||||
requestValueExtractor: NewRequestValueExtractor(logger, false, 0),
|
||||
}
|
||||
|
||||
req := httptest.NewRequest("GET", testURL, nil)
|
||||
@@ -1671,7 +1671,7 @@ func TestBlockedRequestPhase1_HeaderRegex_MultipleMatchingHeaders(t *testing.T)
|
||||
ruleCache: NewRuleCache(),
|
||||
ipBlacklist: iptrie.NewTrie(),
|
||||
dnsBlacklist: map[string]struct{}{},
|
||||
requestValueExtractor: NewRequestValueExtractor(logger, false),
|
||||
requestValueExtractor: NewRequestValueExtractor(logger, false, 0),
|
||||
}
|
||||
|
||||
req := httptest.NewRequest("GET", testURL, nil)
|
||||
|
||||
14
request.go
14
request.go
@@ -16,6 +16,7 @@ import (
|
||||
type RequestValueExtractor struct {
|
||||
logger *zap.Logger
|
||||
redactSensitiveData bool // Add this field
|
||||
maxBodySize int64
|
||||
}
|
||||
|
||||
// Extraction Target Constants - Improved Readability and Maintainability
|
||||
@@ -47,8 +48,11 @@ const (
|
||||
var sensitiveTargets = []string{"password", "token", "apikey", "authorization", "secret"} // Define sensitive targets for redaction as package variable
|
||||
|
||||
// NewRequestValueExtractor creates a new RequestValueExtractor with a given logger
|
||||
func NewRequestValueExtractor(logger *zap.Logger, redactSensitiveData bool) *RequestValueExtractor {
|
||||
return &RequestValueExtractor{logger: logger, redactSensitiveData: redactSensitiveData}
|
||||
func NewRequestValueExtractor(logger *zap.Logger, redactSensitiveData bool, maxBodySize int64) *RequestValueExtractor {
|
||||
if maxBodySize <= 0 {
|
||||
maxBodySize = 10 * 1024 * 1024 // Default 10MB
|
||||
}
|
||||
return &RequestValueExtractor{logger: logger, redactSensitiveData: redactSensitiveData, maxBodySize: maxBodySize}
|
||||
}
|
||||
|
||||
// ExtractValue extracts values based on the target, handling comma separated targets
|
||||
@@ -204,7 +208,8 @@ func (rve *RequestValueExtractor) extractBody(r *http.Request, target string) (s
|
||||
rve.logger.Debug("Request body is empty", zap.String("target", target))
|
||||
return "", fmt.Errorf("request body is empty for target: %s", target)
|
||||
}
|
||||
bodyBytes, err := io.ReadAll(r.Body)
|
||||
reader := io.LimitReader(r.Body, rve.maxBodySize)
|
||||
bodyBytes, err := io.ReadAll(reader)
|
||||
if err != nil {
|
||||
rve.logger.Error("Failed to read request body", zap.Error(err))
|
||||
return "", fmt.Errorf("failed to read request body for target %s: %w", target, err)
|
||||
@@ -334,7 +339,8 @@ func (rve *RequestValueExtractor) extractValueForJSONPath(r *http.Request, jsonP
|
||||
return "", fmt.Errorf("request body is empty for target: %s", target)
|
||||
}
|
||||
|
||||
bodyBytes, err := io.ReadAll(r.Body)
|
||||
reader := io.LimitReader(r.Body, rve.maxBodySize)
|
||||
bodyBytes, err := io.ReadAll(reader)
|
||||
if err != nil {
|
||||
rve.logger.Error("Failed to read request body", zap.Error(err))
|
||||
return "", fmt.Errorf("failed to read request body for JSON_PATH target %s: %w", target, err)
|
||||
|
||||
@@ -21,7 +21,7 @@ import (
|
||||
|
||||
func TestExtractValue(t *testing.T) {
|
||||
logger := zap.NewNop()
|
||||
rve := NewRequestValueExtractor(logger, true)
|
||||
rve := NewRequestValueExtractor(logger, true, 0)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
@@ -144,7 +144,7 @@ func TestRedactValueIfSensitive(t *testing.T) {
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
rve := NewRequestValueExtractor(logger, tt.redactSensitive)
|
||||
rve := NewRequestValueExtractor(logger, tt.redactSensitive, 0)
|
||||
result := rve.RedactValueIfSensitive(tt.target, tt.value)
|
||||
|
||||
if tt.expectedRedacted && result != "REDACTED" {
|
||||
@@ -159,7 +159,7 @@ func TestRedactValueIfSensitive(t *testing.T) {
|
||||
|
||||
func TestExtractValue_HeaderCaseInsensitive(t *testing.T) {
|
||||
logger := zap.NewNop()
|
||||
rve := NewRequestValueExtractor(logger, false)
|
||||
rve := NewRequestValueExtractor(logger, false, 0)
|
||||
|
||||
req := httptest.NewRequest("GET", "/", nil)
|
||||
req.Header.Set("x-test-header", "test-value")
|
||||
@@ -172,7 +172,7 @@ func TestExtractValue_HeaderCaseInsensitive(t *testing.T) {
|
||||
|
||||
func TestExtractValue_EmptyTarget(t *testing.T) {
|
||||
logger := zap.NewNop()
|
||||
rve := NewRequestValueExtractor(logger, false)
|
||||
rve := NewRequestValueExtractor(logger, false, 0)
|
||||
|
||||
req := httptest.NewRequest("GET", "/", nil)
|
||||
w := httptest.NewRecorder()
|
||||
@@ -184,7 +184,7 @@ func TestExtractValue_EmptyTarget(t *testing.T) {
|
||||
|
||||
func TestExtractValue_Method(t *testing.T) {
|
||||
logger := zap.NewNop()
|
||||
rve := NewRequestValueExtractor(logger, false)
|
||||
rve := NewRequestValueExtractor(logger, false, 0)
|
||||
|
||||
req := httptest.NewRequest("GET", "/", nil)
|
||||
w := httptest.NewRecorder()
|
||||
@@ -196,7 +196,7 @@ func TestExtractValue_Method(t *testing.T) {
|
||||
|
||||
func TestExtractValue_RemoteIP(t *testing.T) {
|
||||
logger := zap.NewNop()
|
||||
rve := NewRequestValueExtractor(logger, false)
|
||||
rve := NewRequestValueExtractor(logger, false, 0)
|
||||
|
||||
req := httptest.NewRequest("GET", "/", nil)
|
||||
req.RemoteAddr = localIP
|
||||
@@ -209,7 +209,7 @@ func TestExtractValue_RemoteIP(t *testing.T) {
|
||||
|
||||
func TestExtractValue_Protocol(t *testing.T) {
|
||||
logger := zap.NewNop()
|
||||
rve := NewRequestValueExtractor(logger, false)
|
||||
rve := NewRequestValueExtractor(logger, false, 0)
|
||||
|
||||
req := httptest.NewRequest("GET", "/", nil)
|
||||
req.Proto = "HTTP/1.1"
|
||||
@@ -222,7 +222,7 @@ func TestExtractValue_Protocol(t *testing.T) {
|
||||
|
||||
func TestExtractValue_Host(t *testing.T) {
|
||||
logger := zap.NewNop()
|
||||
rve := NewRequestValueExtractor(logger, false)
|
||||
rve := NewRequestValueExtractor(logger, false, 0)
|
||||
|
||||
req := httptest.NewRequest("GET", "/", nil)
|
||||
req.Host = "example.com"
|
||||
@@ -235,7 +235,7 @@ func TestExtractValue_Host(t *testing.T) {
|
||||
|
||||
func TestExtractValue_Args(t *testing.T) {
|
||||
logger := zap.NewNop()
|
||||
rve := NewRequestValueExtractor(logger, false)
|
||||
rve := NewRequestValueExtractor(logger, false, 0)
|
||||
|
||||
req := httptest.NewRequest("GET", "/?foo=bar&baz=qux", nil)
|
||||
w := httptest.NewRecorder()
|
||||
@@ -247,7 +247,7 @@ func TestExtractValue_Args(t *testing.T) {
|
||||
|
||||
func TestExtractValue_UserAgent(t *testing.T) {
|
||||
logger := zap.NewNop()
|
||||
rve := NewRequestValueExtractor(logger, false)
|
||||
rve := NewRequestValueExtractor(logger, false, 0)
|
||||
|
||||
req := httptest.NewRequest("GET", "/", nil)
|
||||
req.Header.Set("User-Agent", "test-agent")
|
||||
@@ -260,7 +260,7 @@ func TestExtractValue_UserAgent(t *testing.T) {
|
||||
|
||||
func TestExtractValue_Path(t *testing.T) {
|
||||
logger := zap.NewNop()
|
||||
rve := NewRequestValueExtractor(logger, false)
|
||||
rve := NewRequestValueExtractor(logger, false, 0)
|
||||
|
||||
req := httptest.NewRequest("GET", "/test-path", nil)
|
||||
w := httptest.NewRecorder()
|
||||
@@ -272,7 +272,7 @@ func TestExtractValue_Path(t *testing.T) {
|
||||
|
||||
func TestExtractValue_URI(t *testing.T) {
|
||||
logger := zap.NewNop()
|
||||
rve := NewRequestValueExtractor(logger, false)
|
||||
rve := NewRequestValueExtractor(logger, false, 0)
|
||||
|
||||
req := httptest.NewRequest("GET", "/test-path?foo=bar", nil)
|
||||
w := httptest.NewRecorder()
|
||||
@@ -284,7 +284,7 @@ func TestExtractValue_URI(t *testing.T) {
|
||||
|
||||
func TestExtractValue_Body(t *testing.T) {
|
||||
logger := zap.NewNop()
|
||||
rve := NewRequestValueExtractor(logger, false)
|
||||
rve := NewRequestValueExtractor(logger, false, 0)
|
||||
|
||||
body := bytes.NewBufferString("test body")
|
||||
req := httptest.NewRequest("POST", "/", body)
|
||||
@@ -297,7 +297,7 @@ func TestExtractValue_Body(t *testing.T) {
|
||||
|
||||
func TestExtractValue_Headers(t *testing.T) {
|
||||
logger := zap.NewNop()
|
||||
rve := NewRequestValueExtractor(logger, false)
|
||||
rve := NewRequestValueExtractor(logger, false, 0)
|
||||
|
||||
req := httptest.NewRequest("GET", "/", nil)
|
||||
req.Header.Set("X-Test-Header", "test-value")
|
||||
@@ -310,7 +310,7 @@ func TestExtractValue_Headers(t *testing.T) {
|
||||
|
||||
func TestExtractValue_Cookies(t *testing.T) {
|
||||
logger := zap.NewNop()
|
||||
rve := NewRequestValueExtractor(logger, false)
|
||||
rve := NewRequestValueExtractor(logger, false, 0)
|
||||
|
||||
req := httptest.NewRequest("GET", "/", nil)
|
||||
req.AddCookie(&http.Cookie{Name: "test-cookie", Value: "test-value"})
|
||||
@@ -323,7 +323,7 @@ func TestExtractValue_Cookies(t *testing.T) {
|
||||
|
||||
func TestExtractValue_UnknownTarget(t *testing.T) {
|
||||
logger := zap.NewNop()
|
||||
rve := NewRequestValueExtractor(logger, false)
|
||||
rve := NewRequestValueExtractor(logger, false, 0)
|
||||
|
||||
req := httptest.NewRequest("GET", "/", nil)
|
||||
w := httptest.NewRecorder()
|
||||
@@ -368,7 +368,7 @@ func TestProcessRuleMatch_HighScore(t *testing.T) {
|
||||
AnomalyThreshold: 100, // High threshold
|
||||
ruleHits: sync.Map{},
|
||||
muMetrics: sync.RWMutex{},
|
||||
requestValueExtractor: NewRequestValueExtractor(logger.Logger, false), // Initialize
|
||||
requestValueExtractor: NewRequestValueExtractor(logger.Logger, false, 0), // Initialize
|
||||
}
|
||||
|
||||
rule := &Rule{
|
||||
@@ -420,7 +420,7 @@ func TestValidateRule_EmptyTargets(t *testing.T) {
|
||||
func TestNewRequestValueExtractor(t *testing.T) {
|
||||
logger := zap.NewNop()
|
||||
redactSensitiveData := true
|
||||
rve := NewRequestValueExtractor(logger, redactSensitiveData)
|
||||
rve := NewRequestValueExtractor(logger, redactSensitiveData, 0)
|
||||
|
||||
assert.NotNil(t, rve)
|
||||
assert.Equal(t, logger, rve.logger)
|
||||
@@ -448,7 +448,7 @@ func TestConcurrentRuleEvaluation(t *testing.T) {
|
||||
ruleCache: NewRuleCache(),
|
||||
ipBlacklist: iptrie.NewTrie(),
|
||||
dnsBlacklist: map[string]struct{}{},
|
||||
requestValueExtractor: NewRequestValueExtractor(logger, false),
|
||||
requestValueExtractor: NewRequestValueExtractor(logger, false, 0),
|
||||
rateLimiter: func() *RateLimiter {
|
||||
rl, err := NewRateLimiter(RateLimit{
|
||||
Requests: 10,
|
||||
|
||||
@@ -150,7 +150,7 @@ func TestProcessRuleMatch(t *testing.T) {
|
||||
AnomalyThreshold: tt.anomalyThreshold,
|
||||
ruleHits: sync.Map{},
|
||||
muMetrics: sync.RWMutex{},
|
||||
requestValueExtractor: NewRequestValueExtractor(logger, false),
|
||||
requestValueExtractor: NewRequestValueExtractor(logger, false, 0),
|
||||
}
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
2
types.go
2
types.go
@@ -137,6 +137,8 @@ type Middleware struct {
|
||||
LogFilePath string
|
||||
LogBuffer int `json:"log_buffer,omitempty"` // Add the LogBuffer field
|
||||
RedactSensitiveData bool `json:"redact_sensitive_data,omitempty"`
|
||||
MaxRequestBodySize int64 `json:"max_request_body_size,omitempty"`
|
||||
GeoIPFailOpen bool `json:"geoip_fail_open,omitempty"`
|
||||
|
||||
ruleHits sync.Map `json:"-"`
|
||||
MetricsEndpoint string `json:"metrics_endpoint,omitempty"`
|
||||
|
||||
Reference in New Issue
Block a user