Update caddymib.go

This commit is contained in:
fab
2025-01-11 12:36:37 +01:00
committed by GitHub
parent 0461d9ea78
commit 4b9477d2c5

View File

@@ -15,6 +15,7 @@ import (
"github.com/caddyserver/caddy/v2/caddyconfig/httpcaddyfile"
"github.com/caddyserver/caddy/v2/modules/caddyhttp"
"go.uber.org/zap"
"go.uber.org/zap/zapcore"
)
func init() {
@@ -30,11 +31,15 @@ type Middleware struct {
BanDurationMultiplier float64 `json:"ban_duration_multiplier,omitempty"`
Whitelist []string `json:"whitelist,omitempty"`
CustomResponseHeader string `json:"custom_response_header,omitempty"`
LogRequestHeaders []string `json:"log_request_headers,omitempty"` // New field to specify headers to log
LogRequestHeaders []string `json:"log_request_headers,omitempty"`
LogLevel string `json:"log_level,omitempty"` // New: Configurable log level
CIDRBans []string `json:"cidr_bans,omitempty"` // New: Support for CIDR range bans
BanResponseBody string `json:"ban_response_body,omitempty"` // New: Custom ban response body
logger *zap.Logger
errorCounts map[string]int
bannedIPs map[string]time.Time
bannedCIDRs []*net.IPNet // New: Track banned CIDR ranges
mu sync.RWMutex
whitelistedNets []*net.IPNet
}
@@ -53,6 +58,15 @@ func (m *Middleware) Provision(ctx caddy.Context) error {
m.bannedIPs = make(map[string]time.Time)
m.logger = ctx.Logger(m)
// Set log level
if m.LogLevel != "" {
level, err := zapcore.ParseLevel(m.LogLevel)
if err != nil {
return fmt.Errorf("invalid log level: %s", m.LogLevel)
}
m.logger = m.logger.WithOptions(zap.IncreaseLevel(level))
}
m.logger.Info("starting caddy mib middleware")
if m.MaxErrorCount == 0 {
@@ -65,6 +79,7 @@ func (m *Middleware) Provision(ctx caddy.Context) error {
m.BanDurationMultiplier = 1
}
// Parse whitelist
for _, cidr := range m.Whitelist {
_, ipNet, err := net.ParseCIDR(cidr)
if err != nil {
@@ -83,13 +98,24 @@ func (m *Middleware) Provision(ctx caddy.Context) error {
m.whitelistedNets = append(m.whitelistedNets, ipNet)
}
// Parse CIDR bans
for _, cidr := range m.CIDRBans {
_, ipNet, err := net.ParseCIDR(cidr)
if err != nil {
return fmt.Errorf("invalid CIDR in CIDR bans: %s", cidr)
}
m.bannedCIDRs = append(m.bannedCIDRs, ipNet)
}
m.logger.Info("caddy mib middleware provisioned",
zap.Ints("error_codes", m.ErrorCodes),
zap.Int("max_error_count", m.MaxErrorCount),
zap.Duration("ban_duration", time.Duration(m.BanDuration)),
zap.Strings("whitelist", m.Whitelist),
zap.String("custom_response_header", m.CustomResponseHeader),
zap.Strings("log_request_headers", m.LogRequestHeaders), // Log configured request headers
zap.Strings("log_request_headers", m.LogRequestHeaders),
zap.String("log_level", m.LogLevel),
zap.Strings("cidr_bans", m.CIDRBans),
)
go m.cleanupExpiredBans()
@@ -119,41 +145,62 @@ func (m *Middleware) ServeHTTP(w http.ResponseWriter, r *http.Request, next cadd
return next.ServeHTTP(w, r)
}
// Normalize IP (treat IPv4 and IPv6 loopback as the same)
clientIP = normalizeIP(clientIP)
parsedIP := net.ParseIP(clientIP)
// Check if IP is in a banned CIDR range
for _, ipNet := range m.bannedCIDRs {
if ipNet.Contains(parsedIP) {
m.logger.Info("IP is in a banned CIDR range",
zap.String("client_ip", clientIP),
zap.String("cidr", ipNet.String()),
)
w.WriteHeader(http.StatusForbidden)
if m.BanResponseBody != "" {
w.Write([]byte(m.BanResponseBody))
}
return nil
}
}
// Check if IP is whitelisted
for _, ipNet := range m.whitelistedNets {
if ipNet.Contains(parsedIP) {
m.logger.Debug("client IP is whitelisted",
zap.String("client_ip", clientIP),
zap.String("request_path", r.URL.Path),
zap.String("method", r.Method),
zap.String("user_agent", r.Header.Get("User-Agent")),
)
return next.ServeHTTP(w, r)
}
}
// Check if IP is banned
m.mu.RLock()
banTime, banned := m.bannedIPs[clientIP]
m.mu.RUnlock()
if banned {
if time.Now().Before(banTime) {
m.logger.Info("IP is currently banned",
zap.String("client_ip", clientIP),
zap.String("request_path", r.URL.Path),
zap.Time("ban_expires", banTime),
zap.String("method", r.Method),
zap.String("user_agent", r.Header.Get("User-Agent")),
)
m.mu.RLock()
_, alreadyLogged := m.errorCounts[clientIP]
m.mu.RUnlock()
if !alreadyLogged {
m.logger.Info("IP is currently banned",
zap.String("client_ip", clientIP),
zap.Time("ban_expires", banTime),
)
m.mu.Lock()
m.errorCounts[clientIP] = -1 // Mark as logged
m.mu.Unlock()
}
w.WriteHeader(http.StatusForbidden)
if m.BanResponseBody != "" {
w.Write([]byte(m.BanResponseBody))
}
return nil
}
m.logger.Info("unbanning IP; ban expired",
zap.String("client_ip", clientIP),
zap.String("request_path", r.URL.Path),
zap.Time("previous_ban_expiration", banTime),
zap.String("method", r.Method),
zap.String("user_agent", r.Header.Get("User-Agent")),
)
m.mu.Lock()
delete(m.bannedIPs, clientIP)
@@ -161,25 +208,21 @@ func (m *Middleware) ServeHTTP(w http.ResponseWriter, r *http.Request, next cadd
m.mu.Unlock()
}
// Skip middleware if no error codes are specified
if len(m.ErrorCodes) == 0 {
m.logger.Debug("no error codes specified; skipping middleware",
zap.String("client_ip", clientIP),
zap.String("request_path", r.URL.Path),
zap.String("method", r.Method),
zap.String("user_agent", r.Header.Get("User-Agent")),
)
return next.ServeHTTP(w, r)
}
// Record the response from the next handler
rec := caddyhttp.NewResponseRecorder(w, nil, nil)
err = next.ServeHTTP(rec, r)
if err != nil {
m.logger.Error("error in next handler",
zap.Error(err),
zap.String("client_ip", clientIP),
zap.String("request_path", r.URL.Path),
zap.String("method", r.Method),
zap.String("user_agent", r.Header.Get("User-Agent")),
)
statusCode := extractStatusCodeFromError(err)
if statusCode == 0 {
@@ -188,21 +231,16 @@ func (m *Middleware) ServeHTTP(w http.ResponseWriter, r *http.Request, next cadd
m.logger.Debug("extracted status code from error",
zap.Int("status_code", statusCode),
zap.String("client_ip", clientIP),
zap.String("request_path", r.URL.Path),
zap.String("method", r.Method),
zap.String("user_agent", r.Header.Get("User-Agent")),
)
m.trackErrorStatus(clientIP, statusCode, r.URL.Path, r)
return err
}
// Track the response status code
statusCode := rec.Status()
m.logger.Debug("response status code",
zap.Int("status_code", statusCode),
zap.String("client_ip", clientIP),
zap.String("request_path", r.URL.Path),
zap.String("method", r.Method),
zap.String("user_agent", r.Header.Get("User-Agent")),
)
m.trackErrorStatus(clientIP, statusCode, r.URL.Path, r)
@@ -214,6 +252,21 @@ func (m *Middleware) ServeHTTP(w http.ResponseWriter, r *http.Request, next cadd
return nil
}
// normalizeIP normalizes IPv4 and IPv6 loopback addresses.
func normalizeIP(ip string) string {
parsedIP := net.ParseIP(ip)
if parsedIP == nil {
return ip
}
if parsedIP.IsLoopback() {
if parsedIP.To4() != nil {
return "127.0.0.1"
}
return "::1"
}
return ip
}
func (m *Middleware) trackErrorStatus(clientIP string, code int, path string, r *http.Request) {
commonFields := []zap.Field{
zap.String("client_ip", clientIP),
@@ -257,7 +310,6 @@ func (m *Middleware) trackErrorStatus(clientIP string, code int, path string, r
}
m.logger.Info("IP banned", logFields...)
// No need to write header here, it's done in ServeHTTP
}
m.mu.Unlock()
break
@@ -358,13 +410,32 @@ func (m *Middleware) UnmarshalCaddyfile(d *caddyfile.Dispenser) error {
}
m.CustomResponseHeader = d.Val()
case "log_request_headers": // New Caddyfile option
case "log_request_headers":
var headers []string
for d.NextArg() {
headers = append(headers, d.Val())
}
m.LogRequestHeaders = headers
case "log_level":
if !d.NextArg() {
return d.ArgErr()
}
m.LogLevel = d.Val()
case "cidr_bans":
var cidrBans []string
for d.NextArg() {
cidrBans = append(cidrBans, d.Val())
}
m.CIDRBans = cidrBans
case "ban_response_body":
if !d.NextArg() {
return d.ArgErr()
}
m.BanResponseBody = d.Val()
default:
return d.Errf("unrecognized option: %s", d.Val())
}