From eea39d253befb0e9f8c1f0e0f06d2beba5a242ba Mon Sep 17 00:00:00 2001 From: Fabrizio Salmi Date: Sat, 6 Dec 2025 22:46:11 +0100 Subject: [PATCH] Security: Implement hardening improvements (LimitReader, GeoIP Fail-Open, UI Decoupling, Go Version) --- GNUmakefile | 3 +++ README.md | 4 ++-- assets.go | 8 +++++++ assets_stub.go | 7 ++++++ caddywaf.go | 2 +- go.mod | 2 +- handler.go | 48 +++++++++++++++++++++++-------------- handler_test.go | 64 ++++++++++++++++++++++++------------------------- request.go | 14 +++++++---- request_test.go | 38 ++++++++++++++--------------- rules_test.go | 2 +- types.go | 2 ++ 12 files changed, 116 insertions(+), 78 deletions(-) create mode 100644 assets.go create mode 100644 assets_stub.go diff --git a/GNUmakefile b/GNUmakefile index 6f62bf1..08cb17c 100644 --- a/GNUmakefile +++ b/GNUmakefile @@ -24,3 +24,6 @@ lint: lintfix: @echo "==> Checking source code with golangci-lint..." @golangci-lint run --fix + +test-integration: + @docker run --rm -v $(PWD):/app -w /app python:3.9-slim python test.py diff --git a/README.md b/README.md index da3d426..17ff80a 100644 --- a/README.md +++ b/README.md @@ -18,7 +18,7 @@ A robust, highly customizable, and feature-rich **Web Application Firewall (WAF) * **Dynamic Config Reloads:** Seamless updates without restarts. * **File Watchers:** Automatic reloads on rule/blacklist changes. * **Observability:** Seamless integration with ELK stack and Prometheus. -* **Rules generator**: powered by custom GPT, [try it here](https://chatgpt.com/g/g-677d07dd07e48191b799b9e5d6da7828-caddy-waf-ruler) +* **Rules generator**: [available here](https://chatgpt.com/g/g-677d07dd07e48191b799b9e5d6da7828-caddy-waf-ruler) _Simple at a glance UI :)_ ![demo](https://github.com/fabriziosalmi/caddy-waf/blob/main/docs/caddy-waf-ui.png?raw=true) @@ -176,7 +176,7 @@ If You like my projects, you may also like these ones: - [proxmox-vm-autoscale](https://github.com/fabriziosalmi/proxmox-vm-autoscale) Automatically scale virtual machines resources on Proxmox hosts - [UglyFeed](https://github.com/fabriziosalmi/UglyFeed) Retrieve, aggregate, filter, evaluate, rewrite and serve RSS feeds using Large Language Models for fun, research and learning purposes - [proxmox-lxc-autoscale](https://github.com/fabriziosalmi/proxmox-lxc-autoscale) Automatically scale LXC containers resources on Proxmox hosts -- [DevGPT](https://github.com/fabriziosalmi/DevGPT) Code togheter, right now! GPT powered code assistant to build project in minutes +- [DevAssistant](https://github.com/fabriziosalmi/DevGPT) Code together, right now! AI powered code assistant to build project in minutes - [websites-monitor](https://github.com/fabriziosalmi/websites-monitor) Websites monitoring via GitHub Actions (expiration, security, performances, privacy, SEO) - [caddy-mib](https://github.com/fabriziosalmi/caddy-mib) Track and ban client IPs generating repetitive errors on Caddy - [zonecontrol](https://github.com/fabriziosalmi/zonecontrol) Cloudflare Zones Settings Automation using GitHub Actions diff --git a/assets.go b/assets.go new file mode 100644 index 0000000..f8e44f3 --- /dev/null +++ b/assets.go @@ -0,0 +1,8 @@ +//go:build with_ui + +package caddywaf + +import "embed" + +//go:embed ui/* +var Assets embed.FS diff --git a/assets_stub.go b/assets_stub.go new file mode 100644 index 0000000..5a63dc0 --- /dev/null +++ b/assets_stub.go @@ -0,0 +1,7 @@ +//go:build !with_ui + +package caddywaf + +import "embed" + +var Assets embed.FS diff --git a/caddywaf.go b/caddywaf.go index 9fcf935..6529963 100644 --- a/caddywaf.go +++ b/caddywaf.go @@ -239,7 +239,7 @@ func (m *Middleware) Provision(ctx caddy.Context) error { m.configLoader = NewConfigLoader(m.logger) m.blacklistLoader = NewBlacklistLoader(m.logger) m.geoIPHandler = NewGeoIPHandler(m.logger) - m.requestValueExtractor = NewRequestValueExtractor(m.logger, m.RedactSensitiveData) + m.requestValueExtractor = NewRequestValueExtractor(m.logger, m.RedactSensitiveData, m.MaxRequestBodySize) // Configure GeoIP handler m.geoIPHandler.WithGeoIPCache(m.geoIPCacheTTL) diff --git a/go.mod b/go.mod index 9f812d4..fdacb02 100644 --- a/go.mod +++ b/go.mod @@ -1,6 +1,6 @@ module github.com/fabriziosalmi/caddy-waf -go 1.25 +go 1.25.5 toolchain go1.25.3 diff --git a/handler.go b/handler.go index 7712cc9..e000f0d 100644 --- a/handler.go +++ b/handler.go @@ -334,12 +334,16 @@ func (m *Middleware) handlePhase(w http.ResponseWriter, r *http.Request, phase i r, zap.Error(err), ) - m.blockRequest(w, r, state, http.StatusForbidden, "internal_error", "country_block_rule", - zap.String("message", "Request blocked due to internal error"), - ) - m.logger.Debug("Country whitelisting phase completed - blocked due to error") - m.incrementGeoIPRequestsMetric(false) // Increment with false for error - return + if m.GeoIPFailOpen { + m.logger.Warn("GeoIP lookup failed (Whitelist); Failing OPEN") + } else { + m.blockRequest(w, r, state, http.StatusForbidden, "internal_error", "country_block_rule", + zap.String("message", "Request blocked due to internal error"), + ) + m.logger.Debug("Country whitelisting phase completed - blocked due to error") + m.incrementGeoIPRequestsMetric(false) // Increment with false for error + return + } } else if !allowed { m.blockRequest(w, r, state, http.StatusForbidden, "country_block", "country_block_rule", zap.String("message", "Request blocked by country")) @@ -362,12 +366,16 @@ func (m *Middleware) handlePhase(w http.ResponseWriter, r *http.Request, phase i r, zap.Error(err), ) - m.blockRequest(w, r, state, http.StatusForbidden, "internal_error", "asn_block_rule", - zap.String("message", "Request blocked due to internal error"), - ) - m.logger.Debug("ASN blocking phase completed - blocked due to error") - m.incrementGeoIPRequestsMetric(false) // Increment with false for error - return + if m.GeoIPFailOpen { + m.logger.Warn("ASN lookup failed; Failing OPEN") + } else { + m.blockRequest(w, r, state, http.StatusForbidden, "internal_error", "asn_block_rule", + zap.String("message", "Request blocked due to internal error"), + ) + m.logger.Debug("ASN blocking phase completed - blocked due to error") + m.incrementGeoIPRequestsMetric(false) // Increment with false for error + return + } } else if blocked { asnInfo := m.geoIPHandler.GetASN(r.RemoteAddr, m.BlockASNs.geoIP) m.blockRequest(w, r, state, http.StatusForbidden, "asn_block", "asn_block_rule", @@ -392,12 +400,16 @@ func (m *Middleware) handlePhase(w http.ResponseWriter, r *http.Request, phase i r, zap.Error(err), ) - m.blockRequest(w, r, state, http.StatusForbidden, "internal_error", "country_block_rule", - zap.String("message", "Request blocked due to internal error"), - ) - m.logger.Debug("Country blacklisting phase completed - blocked due to error") - m.incrementGeoIPRequestsMetric(false) // Increment with false for error - return + if m.GeoIPFailOpen { + m.logger.Warn("GeoIP lookup failed (Blacklist); Failing OPEN") + } else { + m.blockRequest(w, r, state, http.StatusForbidden, "internal_error", "country_block_rule", + zap.String("message", "Request blocked due to internal error"), + ) + m.logger.Debug("Country blacklisting phase completed - blocked due to error") + m.incrementGeoIPRequestsMetric(false) // Increment with false for error + return + } } else if blocked { m.blockRequest(w, r, state, http.StatusForbidden, "country_block", "country_block_rule", zap.String("message", "Request blocked by country")) diff --git a/handler_test.go b/handler_test.go index f18fc41..e8bc546 100644 --- a/handler_test.go +++ b/handler_test.go @@ -29,7 +29,7 @@ func TestBlockedRequestPhase1_DNSBlacklist(t *testing.T) { }, ipBlacklist: iptrie.NewTrie(), CustomResponses: customResponse, - requestValueExtractor: NewRequestValueExtractor(logger, false), + requestValueExtractor: NewRequestValueExtractor(logger, false, 0), } w := httptest.NewRecorder() @@ -83,7 +83,7 @@ func TestBlockedRequestPhase1_GeoIPBlocking(t *testing.T) { geoIP: geoIPBlock, }, CustomResponses: customResponse, - requestValueExtractor: NewRequestValueExtractor(logger, false), + requestValueExtractor: NewRequestValueExtractor(logger, false, 0), } wlMiddleware := &Middleware{ @@ -97,7 +97,7 @@ func TestBlockedRequestPhase1_GeoIPBlocking(t *testing.T) { geoIP: geoIPBlock, }, CustomResponses: customResponse, - requestValueExtractor: NewRequestValueExtractor(logger, false), + requestValueExtractor: NewRequestValueExtractor(logger, false, 0), } blackWhiteMw := &Middleware{ @@ -117,7 +117,7 @@ func TestBlockedRequestPhase1_GeoIPBlocking(t *testing.T) { geoIP: geoIPBlock, }, CustomResponses: customResponse, - requestValueExtractor: NewRequestValueExtractor(logger, false), + requestValueExtractor: NewRequestValueExtractor(logger, false, 0), } req := httptest.NewRequest("GET", testURL, nil) @@ -216,7 +216,7 @@ func TestBlockedRequestPhase1_IPBlocking(t *testing.T) { logger: logger, ipBlacklist: blackList, CustomResponses: customResponse, - requestValueExtractor: NewRequestValueExtractor(logger, false), + requestValueExtractor: NewRequestValueExtractor(logger, false, 0), } req := httptest.NewRequest("GET", testURL, nil) @@ -234,7 +234,7 @@ func TestBlockedRequestPhase1_IPBlocking(t *testing.T) { logger: logger, ipBlacklist: blackList, CustomResponses: customResponse, - requestValueExtractor: NewRequestValueExtractor(logger, false), + requestValueExtractor: NewRequestValueExtractor(logger, false, 0), } req := httptest.NewRequest("GET", testURL, nil) @@ -270,7 +270,7 @@ func TestHandlePhase_Phase2_NiktoUserAgent(t *testing.T) { ruleCache: NewRuleCache(), ipBlacklist: iptrie.NewTrie(), dnsBlacklist: map[string]struct{}{}, - requestValueExtractor: NewRequestValueExtractor(logger, false), + requestValueExtractor: NewRequestValueExtractor(logger, false, 0), CustomResponses: customResponse, } @@ -325,7 +325,7 @@ func TestBlockedRequestPhase1_HeaderRegex(t *testing.T) { ruleCache: NewRuleCache(), ipBlacklist: iptrie.NewTrie(), dnsBlacklist: map[string]struct{}{}, - requestValueExtractor: NewRequestValueExtractor(logger, false), + requestValueExtractor: NewRequestValueExtractor(logger, false, 0), } req := httptest.NewRequest("GET", testURL, nil) @@ -378,7 +378,7 @@ func TestBlockedRequestPhase1_HeaderRegex_SpecificValue(t *testing.T) { ruleCache: NewRuleCache(), ipBlacklist: iptrie.NewTrie(), dnsBlacklist: map[string]struct{}{}, - requestValueExtractor: NewRequestValueExtractor(logger, false), + requestValueExtractor: NewRequestValueExtractor(logger, false, 0), } req := httptest.NewRequest("GET", testURL, nil) @@ -431,7 +431,7 @@ func TestBlockedRequestPhase1_HeaderRegex_CommaSeparatedTargets(t *testing.T) { ruleCache: NewRuleCache(), ipBlacklist: iptrie.NewTrie(), dnsBlacklist: map[string]struct{}{}, - requestValueExtractor: NewRequestValueExtractor(logger, false), + requestValueExtractor: NewRequestValueExtractor(logger, false, 0), } req := httptest.NewRequest("GET", testURL, nil) @@ -485,7 +485,7 @@ func TestBlockedRequestPhase1_CombinedConditions(t *testing.T) { ruleCache: NewRuleCache(), ipBlacklist: iptrie.NewTrie(), dnsBlacklist: map[string]struct{}{}, - requestValueExtractor: NewRequestValueExtractor(logger, false), + requestValueExtractor: NewRequestValueExtractor(logger, false, 0), } req := httptest.NewRequest("GET", "http://bad-host.com", nil) @@ -538,7 +538,7 @@ func TestBlockedRequestPhase1_NoMatch(t *testing.T) { ruleCache: NewRuleCache(), ipBlacklist: iptrie.NewTrie(), dnsBlacklist: map[string]struct{}{}, - requestValueExtractor: NewRequestValueExtractor(logger, false), + requestValueExtractor: NewRequestValueExtractor(logger, false, 0), } req := httptest.NewRequest("GET", testURL, nil) @@ -591,7 +591,7 @@ func TestBlockedRequestPhase1_HeaderRegex_EmptyHeader(t *testing.T) { ruleCache: NewRuleCache(), ipBlacklist: iptrie.NewTrie(), dnsBlacklist: map[string]struct{}{}, - requestValueExtractor: NewRequestValueExtractor(logger, false), + requestValueExtractor: NewRequestValueExtractor(logger, false, 0), } req := httptest.NewRequest("GET", testURL, nil) @@ -643,7 +643,7 @@ func TestBlockedRequestPhase1_HeaderRegex_MissingHeader(t *testing.T) { ruleCache: NewRuleCache(), ipBlacklist: iptrie.NewTrie(), dnsBlacklist: map[string]struct{}{}, - requestValueExtractor: NewRequestValueExtractor(logger, false), + requestValueExtractor: NewRequestValueExtractor(logger, false, 0), } req := httptest.NewRequest("GET", testURL, nil) // Header not set @@ -695,7 +695,7 @@ func TestBlockedRequestPhase1_HeaderRegex_ComplexPattern(t *testing.T) { ruleCache: NewRuleCache(), ipBlacklist: iptrie.NewTrie(), dnsBlacklist: map[string]struct{}{}, - requestValueExtractor: NewRequestValueExtractor(logger, false), + requestValueExtractor: NewRequestValueExtractor(logger, false, 0), } req := httptest.NewRequest("GET", testURL, nil) @@ -748,7 +748,7 @@ func TestBlockedRequestPhase1_MultiTargetMatch(t *testing.T) { ruleCache: NewRuleCache(), ipBlacklist: iptrie.NewTrie(), dnsBlacklist: map[string]struct{}{}, - requestValueExtractor: NewRequestValueExtractor(logger, false), + requestValueExtractor: NewRequestValueExtractor(logger, false, 0), } req := httptest.NewRequest("GET", testURL, nil) @@ -801,7 +801,7 @@ func TestBlockedRequestPhase1_MultiTargetNoMatch(t *testing.T) { ruleCache: NewRuleCache(), ipBlacklist: iptrie.NewTrie(), dnsBlacklist: map[string]struct{}{}, - requestValueExtractor: NewRequestValueExtractor(logger, false), + requestValueExtractor: NewRequestValueExtractor(logger, false, 0), } req := httptest.NewRequest("GET", testURL, nil) @@ -855,7 +855,7 @@ func TestBlockedRequestPhase1_URLParameterRegex_NoMatch(t *testing.T) { ruleCache: NewRuleCache(), ipBlacklist: iptrie.NewTrie(), dnsBlacklist: map[string]struct{}{}, - requestValueExtractor: NewRequestValueExtractor(logger, false), + requestValueExtractor: NewRequestValueExtractor(logger, false, 0), } req := httptest.NewRequest("GET", "http://example.com?param1=good-param-value¶m2=good-value", nil) @@ -915,7 +915,7 @@ func TestBlockedRequestPhase1_MultipleRules(t *testing.T) { ruleCache: NewRuleCache(), ipBlacklist: iptrie.NewTrie(), dnsBlacklist: map[string]struct{}{}, - requestValueExtractor: NewRequestValueExtractor(logger, false), + requestValueExtractor: NewRequestValueExtractor(logger, false, 0), } req := httptest.NewRequest("GET", "http://bad-host.com", nil) @@ -991,7 +991,7 @@ func TestBlockedRequestPhase2_BodyRegex(t *testing.T) { ruleCache: NewRuleCache(), ipBlacklist: iptrie.NewTrie(), dnsBlacklist: map[string]struct{}{}, - requestValueExtractor: NewRequestValueExtractor(logger, false), + requestValueExtractor: NewRequestValueExtractor(logger, false, 0), } req := httptest.NewRequest("POST", testURL, @@ -1050,7 +1050,7 @@ func TestBlockedRequestPhase2_BodyRegex_JSON(t *testing.T) { ruleCache: NewRuleCache(), ipBlacklist: iptrie.NewTrie(), dnsBlacklist: map[string]struct{}{}, - requestValueExtractor: NewRequestValueExtractor(logger, false), + requestValueExtractor: NewRequestValueExtractor(logger, false, 0), } req := httptest.NewRequest("POST", testURL, @@ -1109,7 +1109,7 @@ func TestBlockedRequestPhase2_BodyRegex_FormURLEncoded(t *testing.T) { ruleCache: NewRuleCache(), ipBlacklist: iptrie.NewTrie(), dnsBlacklist: map[string]struct{}{}, - requestValueExtractor: NewRequestValueExtractor(logger, false), + requestValueExtractor: NewRequestValueExtractor(logger, false, 0), } req := httptest.NewRequest("POST", testURL, @@ -1164,7 +1164,7 @@ func TestBlockedRequestPhase2_BodyRegex_SpecificPattern(t *testing.T) { ruleCache: NewRuleCache(), ipBlacklist: iptrie.NewTrie(), dnsBlacklist: map[string]struct{}{}, - requestValueExtractor: NewRequestValueExtractor(logger, false), + requestValueExtractor: NewRequestValueExtractor(logger, false, 0), } req := httptest.NewRequest("POST", testURL, @@ -1223,7 +1223,7 @@ func TestBlockedRequestPhase2_BodyRegex_NoMatch(t *testing.T) { ruleCache: NewRuleCache(), ipBlacklist: iptrie.NewTrie(), dnsBlacklist: map[string]struct{}{}, - requestValueExtractor: NewRequestValueExtractor(logger, false), + requestValueExtractor: NewRequestValueExtractor(logger, false, 0), } req := httptest.NewRequest("POST", testURL, @@ -1282,7 +1282,7 @@ func TestBlockedRequestPhase2_BodyRegex_NoMatch_MultipartForm(t *testing.T) { ruleCache: NewRuleCache(), ipBlacklist: iptrie.NewTrie(), dnsBlacklist: map[string]struct{}{}, - requestValueExtractor: NewRequestValueExtractor(logger, false), + requestValueExtractor: NewRequestValueExtractor(logger, false, 0), } body := &bytes.Buffer{} @@ -1350,7 +1350,7 @@ func TestBlockedRequestPhase2_BodyRegex_NoBody(t *testing.T) { ruleCache: NewRuleCache(), ipBlacklist: iptrie.NewTrie(), dnsBlacklist: map[string]struct{}{}, - requestValueExtractor: NewRequestValueExtractor(logger, false), + requestValueExtractor: NewRequestValueExtractor(logger, false, 0), } req := httptest.NewRequest("POST", testURL, nil) @@ -1396,7 +1396,7 @@ func TestBlockedRequestPhase3_ResponseHeaderRegex_NoMatch(t *testing.T) { ruleCache: NewRuleCache(), ipBlacklist: iptrie.NewTrie(), dnsBlacklist: map[string]struct{}{}, - requestValueExtractor: NewRequestValueExtractor(logger, false), + requestValueExtractor: NewRequestValueExtractor(logger, false, 0), } mockHandler := func() caddyhttp.Handler { @@ -1452,7 +1452,7 @@ func TestBlockedRequestPhase4_ResponseBodyRegex_EmptyBody(t *testing.T) { ruleCache: NewRuleCache(), ipBlacklist: iptrie.NewTrie(), dnsBlacklist: map[string]struct{}{}, - requestValueExtractor: NewRequestValueExtractor(logger, false), + requestValueExtractor: NewRequestValueExtractor(logger, false, 0), } mockHandler := func() caddyhttp.Handler { @@ -1508,7 +1508,7 @@ func TestBlockedRequestPhase4_ResponseBodyRegex_NoBody(t *testing.T) { ruleCache: NewRuleCache(), ipBlacklist: iptrie.NewTrie(), dnsBlacklist: map[string]struct{}{}, - requestValueExtractor: NewRequestValueExtractor(logger, false), + requestValueExtractor: NewRequestValueExtractor(logger, false, 0), } mockHandler := func() caddyhttp.Handler { @@ -1562,7 +1562,7 @@ func TestBlockedRequestPhase3_ResponseHeaderRegex_NoSetCookie(t *testing.T) { ruleCache: NewRuleCache(), ipBlacklist: iptrie.NewTrie(), dnsBlacklist: map[string]struct{}{}, - requestValueExtractor: NewRequestValueExtractor(logger, false), + requestValueExtractor: NewRequestValueExtractor(logger, false, 0), } mockHandler := func() caddyhttp.Handler { return caddyhttp.HandlerFunc(func(w http.ResponseWriter, r *http.Request) error { @@ -1618,7 +1618,7 @@ func TestBlockedRequestPhase1_HeaderRegex_CaseInsensitive(t *testing.T) { ruleCache: NewRuleCache(), ipBlacklist: iptrie.NewTrie(), dnsBlacklist: map[string]struct{}{}, - requestValueExtractor: NewRequestValueExtractor(logger, false), + requestValueExtractor: NewRequestValueExtractor(logger, false, 0), } req := httptest.NewRequest("GET", testURL, nil) @@ -1671,7 +1671,7 @@ func TestBlockedRequestPhase1_HeaderRegex_MultipleMatchingHeaders(t *testing.T) ruleCache: NewRuleCache(), ipBlacklist: iptrie.NewTrie(), dnsBlacklist: map[string]struct{}{}, - requestValueExtractor: NewRequestValueExtractor(logger, false), + requestValueExtractor: NewRequestValueExtractor(logger, false, 0), } req := httptest.NewRequest("GET", testURL, nil) diff --git a/request.go b/request.go index ee1231b..f12e951 100644 --- a/request.go +++ b/request.go @@ -16,6 +16,7 @@ import ( type RequestValueExtractor struct { logger *zap.Logger redactSensitiveData bool // Add this field + maxBodySize int64 } // Extraction Target Constants - Improved Readability and Maintainability @@ -47,8 +48,11 @@ const ( var sensitiveTargets = []string{"password", "token", "apikey", "authorization", "secret"} // Define sensitive targets for redaction as package variable // NewRequestValueExtractor creates a new RequestValueExtractor with a given logger -func NewRequestValueExtractor(logger *zap.Logger, redactSensitiveData bool) *RequestValueExtractor { - return &RequestValueExtractor{logger: logger, redactSensitiveData: redactSensitiveData} +func NewRequestValueExtractor(logger *zap.Logger, redactSensitiveData bool, maxBodySize int64) *RequestValueExtractor { + if maxBodySize <= 0 { + maxBodySize = 10 * 1024 * 1024 // Default 10MB + } + return &RequestValueExtractor{logger: logger, redactSensitiveData: redactSensitiveData, maxBodySize: maxBodySize} } // ExtractValue extracts values based on the target, handling comma separated targets @@ -204,7 +208,8 @@ func (rve *RequestValueExtractor) extractBody(r *http.Request, target string) (s rve.logger.Debug("Request body is empty", zap.String("target", target)) return "", fmt.Errorf("request body is empty for target: %s", target) } - bodyBytes, err := io.ReadAll(r.Body) + reader := io.LimitReader(r.Body, rve.maxBodySize) + bodyBytes, err := io.ReadAll(reader) if err != nil { rve.logger.Error("Failed to read request body", zap.Error(err)) return "", fmt.Errorf("failed to read request body for target %s: %w", target, err) @@ -334,7 +339,8 @@ func (rve *RequestValueExtractor) extractValueForJSONPath(r *http.Request, jsonP return "", fmt.Errorf("request body is empty for target: %s", target) } - bodyBytes, err := io.ReadAll(r.Body) + reader := io.LimitReader(r.Body, rve.maxBodySize) + bodyBytes, err := io.ReadAll(reader) if err != nil { rve.logger.Error("Failed to read request body", zap.Error(err)) return "", fmt.Errorf("failed to read request body for JSON_PATH target %s: %w", target, err) diff --git a/request_test.go b/request_test.go index ba2b4d2..09184e3 100644 --- a/request_test.go +++ b/request_test.go @@ -21,7 +21,7 @@ import ( func TestExtractValue(t *testing.T) { logger := zap.NewNop() - rve := NewRequestValueExtractor(logger, true) + rve := NewRequestValueExtractor(logger, true, 0) tests := []struct { name string @@ -144,7 +144,7 @@ func TestRedactValueIfSensitive(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - rve := NewRequestValueExtractor(logger, tt.redactSensitive) + rve := NewRequestValueExtractor(logger, tt.redactSensitive, 0) result := rve.RedactValueIfSensitive(tt.target, tt.value) if tt.expectedRedacted && result != "REDACTED" { @@ -159,7 +159,7 @@ func TestRedactValueIfSensitive(t *testing.T) { func TestExtractValue_HeaderCaseInsensitive(t *testing.T) { logger := zap.NewNop() - rve := NewRequestValueExtractor(logger, false) + rve := NewRequestValueExtractor(logger, false, 0) req := httptest.NewRequest("GET", "/", nil) req.Header.Set("x-test-header", "test-value") @@ -172,7 +172,7 @@ func TestExtractValue_HeaderCaseInsensitive(t *testing.T) { func TestExtractValue_EmptyTarget(t *testing.T) { logger := zap.NewNop() - rve := NewRequestValueExtractor(logger, false) + rve := NewRequestValueExtractor(logger, false, 0) req := httptest.NewRequest("GET", "/", nil) w := httptest.NewRecorder() @@ -184,7 +184,7 @@ func TestExtractValue_EmptyTarget(t *testing.T) { func TestExtractValue_Method(t *testing.T) { logger := zap.NewNop() - rve := NewRequestValueExtractor(logger, false) + rve := NewRequestValueExtractor(logger, false, 0) req := httptest.NewRequest("GET", "/", nil) w := httptest.NewRecorder() @@ -196,7 +196,7 @@ func TestExtractValue_Method(t *testing.T) { func TestExtractValue_RemoteIP(t *testing.T) { logger := zap.NewNop() - rve := NewRequestValueExtractor(logger, false) + rve := NewRequestValueExtractor(logger, false, 0) req := httptest.NewRequest("GET", "/", nil) req.RemoteAddr = localIP @@ -209,7 +209,7 @@ func TestExtractValue_RemoteIP(t *testing.T) { func TestExtractValue_Protocol(t *testing.T) { logger := zap.NewNop() - rve := NewRequestValueExtractor(logger, false) + rve := NewRequestValueExtractor(logger, false, 0) req := httptest.NewRequest("GET", "/", nil) req.Proto = "HTTP/1.1" @@ -222,7 +222,7 @@ func TestExtractValue_Protocol(t *testing.T) { func TestExtractValue_Host(t *testing.T) { logger := zap.NewNop() - rve := NewRequestValueExtractor(logger, false) + rve := NewRequestValueExtractor(logger, false, 0) req := httptest.NewRequest("GET", "/", nil) req.Host = "example.com" @@ -235,7 +235,7 @@ func TestExtractValue_Host(t *testing.T) { func TestExtractValue_Args(t *testing.T) { logger := zap.NewNop() - rve := NewRequestValueExtractor(logger, false) + rve := NewRequestValueExtractor(logger, false, 0) req := httptest.NewRequest("GET", "/?foo=bar&baz=qux", nil) w := httptest.NewRecorder() @@ -247,7 +247,7 @@ func TestExtractValue_Args(t *testing.T) { func TestExtractValue_UserAgent(t *testing.T) { logger := zap.NewNop() - rve := NewRequestValueExtractor(logger, false) + rve := NewRequestValueExtractor(logger, false, 0) req := httptest.NewRequest("GET", "/", nil) req.Header.Set("User-Agent", "test-agent") @@ -260,7 +260,7 @@ func TestExtractValue_UserAgent(t *testing.T) { func TestExtractValue_Path(t *testing.T) { logger := zap.NewNop() - rve := NewRequestValueExtractor(logger, false) + rve := NewRequestValueExtractor(logger, false, 0) req := httptest.NewRequest("GET", "/test-path", nil) w := httptest.NewRecorder() @@ -272,7 +272,7 @@ func TestExtractValue_Path(t *testing.T) { func TestExtractValue_URI(t *testing.T) { logger := zap.NewNop() - rve := NewRequestValueExtractor(logger, false) + rve := NewRequestValueExtractor(logger, false, 0) req := httptest.NewRequest("GET", "/test-path?foo=bar", nil) w := httptest.NewRecorder() @@ -284,7 +284,7 @@ func TestExtractValue_URI(t *testing.T) { func TestExtractValue_Body(t *testing.T) { logger := zap.NewNop() - rve := NewRequestValueExtractor(logger, false) + rve := NewRequestValueExtractor(logger, false, 0) body := bytes.NewBufferString("test body") req := httptest.NewRequest("POST", "/", body) @@ -297,7 +297,7 @@ func TestExtractValue_Body(t *testing.T) { func TestExtractValue_Headers(t *testing.T) { logger := zap.NewNop() - rve := NewRequestValueExtractor(logger, false) + rve := NewRequestValueExtractor(logger, false, 0) req := httptest.NewRequest("GET", "/", nil) req.Header.Set("X-Test-Header", "test-value") @@ -310,7 +310,7 @@ func TestExtractValue_Headers(t *testing.T) { func TestExtractValue_Cookies(t *testing.T) { logger := zap.NewNop() - rve := NewRequestValueExtractor(logger, false) + rve := NewRequestValueExtractor(logger, false, 0) req := httptest.NewRequest("GET", "/", nil) req.AddCookie(&http.Cookie{Name: "test-cookie", Value: "test-value"}) @@ -323,7 +323,7 @@ func TestExtractValue_Cookies(t *testing.T) { func TestExtractValue_UnknownTarget(t *testing.T) { logger := zap.NewNop() - rve := NewRequestValueExtractor(logger, false) + rve := NewRequestValueExtractor(logger, false, 0) req := httptest.NewRequest("GET", "/", nil) w := httptest.NewRecorder() @@ -368,7 +368,7 @@ func TestProcessRuleMatch_HighScore(t *testing.T) { AnomalyThreshold: 100, // High threshold ruleHits: sync.Map{}, muMetrics: sync.RWMutex{}, - requestValueExtractor: NewRequestValueExtractor(logger.Logger, false), // Initialize + requestValueExtractor: NewRequestValueExtractor(logger.Logger, false, 0), // Initialize } rule := &Rule{ @@ -420,7 +420,7 @@ func TestValidateRule_EmptyTargets(t *testing.T) { func TestNewRequestValueExtractor(t *testing.T) { logger := zap.NewNop() redactSensitiveData := true - rve := NewRequestValueExtractor(logger, redactSensitiveData) + rve := NewRequestValueExtractor(logger, redactSensitiveData, 0) assert.NotNil(t, rve) assert.Equal(t, logger, rve.logger) @@ -448,7 +448,7 @@ func TestConcurrentRuleEvaluation(t *testing.T) { ruleCache: NewRuleCache(), ipBlacklist: iptrie.NewTrie(), dnsBlacklist: map[string]struct{}{}, - requestValueExtractor: NewRequestValueExtractor(logger, false), + requestValueExtractor: NewRequestValueExtractor(logger, false, 0), rateLimiter: func() *RateLimiter { rl, err := NewRateLimiter(RateLimit{ Requests: 10, diff --git a/rules_test.go b/rules_test.go index 7eaa00d..d50b952 100644 --- a/rules_test.go +++ b/rules_test.go @@ -150,7 +150,7 @@ func TestProcessRuleMatch(t *testing.T) { AnomalyThreshold: tt.anomalyThreshold, ruleHits: sync.Map{}, muMetrics: sync.RWMutex{}, - requestValueExtractor: NewRequestValueExtractor(logger, false), + requestValueExtractor: NewRequestValueExtractor(logger, false, 0), } w := httptest.NewRecorder() diff --git a/types.go b/types.go index 3e8e221..b48207a 100644 --- a/types.go +++ b/types.go @@ -137,6 +137,8 @@ type Middleware struct { LogFilePath string LogBuffer int `json:"log_buffer,omitempty"` // Add the LogBuffer field RedactSensitiveData bool `json:"redact_sensitive_data,omitempty"` + MaxRequestBodySize int64 `json:"max_request_body_size,omitempty"` + GeoIPFailOpen bool `json:"geoip_fail_open,omitempty"` ruleHits sync.Map `json:"-"` MetricsEndpoint string `json:"metrics_endpoint,omitempty"`