From 604a2e50b8cd0c3fe61d4ae4c23b2fb1b74099fe Mon Sep 17 00:00:00 2001 From: fab Date: Sat, 11 Jan 2025 12:55:01 +0100 Subject: [PATCH] Update caddymib.go --- caddymib.go | 34 ++++++++++++++++++++++++++++------ 1 file changed, 28 insertions(+), 6 deletions(-) diff --git a/caddymib.go b/caddymib.go index 65d01e8..d51b5d3 100644 --- a/caddymib.go +++ b/caddymib.go @@ -32,14 +32,15 @@ type Middleware struct { Whitelist []string `json:"whitelist,omitempty"` CustomResponseHeader string `json:"custom_response_header,omitempty"` LogRequestHeaders []string `json:"log_request_headers,omitempty"` - LogLevel string `json:"log_level,omitempty"` // New: Configurable log level - CIDRBans []string `json:"cidr_bans,omitempty"` // New: Support for CIDR range bans - BanResponseBody string `json:"ban_response_body,omitempty"` // New: Custom ban response body + LogLevel string `json:"log_level,omitempty"` + CIDRBans []string `json:"cidr_bans,omitempty"` + BanResponseBody string `json:"ban_response_body,omitempty"` + BanStatusCode int `json:"ban_status_code,omitempty"` // New: Configurable status code for bans logger *zap.Logger errorCounts map[string]int bannedIPs map[string]time.Time - bannedCIDRs []*net.IPNet // New: Track banned CIDR ranges + bannedCIDRs []*net.IPNet mu sync.RWMutex whitelistedNets []*net.IPNet } @@ -69,6 +70,7 @@ func (m *Middleware) Provision(ctx caddy.Context) error { m.logger.Info("starting caddy mib middleware") + // Set default values if m.MaxErrorCount == 0 { m.MaxErrorCount = 5 } @@ -78,6 +80,9 @@ func (m *Middleware) Provision(ctx caddy.Context) error { if m.BanDurationMultiplier == 0 { m.BanDurationMultiplier = 1 } + if m.BanStatusCode == 0 { + m.BanStatusCode = http.StatusForbidden // Default to 403 + } // Parse whitelist for _, cidr := range m.Whitelist { @@ -116,6 +121,7 @@ func (m *Middleware) Provision(ctx caddy.Context) error { zap.Strings("log_request_headers", m.LogRequestHeaders), zap.String("log_level", m.LogLevel), zap.Strings("cidr_bans", m.CIDRBans), + zap.Int("ban_status_code", m.BanStatusCode), ) go m.cleanupExpiredBans() @@ -130,6 +136,9 @@ func (m *Middleware) Validate() error { if m.BanDuration <= 0 { return fmt.Errorf("ban_duration must be greater than 0") } + if m.BanStatusCode != http.StatusForbidden && m.BanStatusCode != http.StatusTooManyRequests { + return fmt.Errorf("ban_status_code must be 403 (Forbidden) or 429 (Too Many Requests)") + } return nil } @@ -156,7 +165,7 @@ func (m *Middleware) ServeHTTP(w http.ResponseWriter, r *http.Request, next cadd zap.String("client_ip", clientIP), zap.String("cidr", ipNet.String()), ) - w.WriteHeader(http.StatusForbidden) + w.WriteHeader(m.BanStatusCode) // Use configured status code if m.BanResponseBody != "" { w.Write([]byte(m.BanResponseBody)) } @@ -193,7 +202,7 @@ func (m *Middleware) ServeHTTP(w http.ResponseWriter, r *http.Request, next cadd m.errorCounts[clientIP] = -1 // Mark as logged m.mu.Unlock() } - w.WriteHeader(http.StatusForbidden) + w.WriteHeader(m.BanStatusCode) // Use configured status code if m.BanResponseBody != "" { w.Write([]byte(m.BanResponseBody)) } @@ -436,6 +445,19 @@ func (m *Middleware) UnmarshalCaddyfile(d *caddyfile.Dispenser) error { } m.BanResponseBody = d.Val() + case "ban_status_code": + if !d.NextArg() { + return d.ArgErr() + } + statusCode, err := strconv.Atoi(d.Val()) + if err != nil { + return d.Errf("invalid ban_status_code: %s", d.Val()) + } + if statusCode != http.StatusForbidden && statusCode != http.StatusTooManyRequests { + return d.Errf("ban_status_code must be 403 (Forbidden) or 429 (Too Many Requests)") + } + m.BanStatusCode = statusCode + default: return d.Errf("unrecognized option: %s", d.Val()) }