mirror of
https://github.com/fabriziosalmi/caddy-mib.git
synced 2025-12-23 22:17:43 -05:00
Update caddymib.go
This commit is contained in:
131
caddymib.go
131
caddymib.go
@@ -15,6 +15,7 @@ import (
|
||||
"github.com/caddyserver/caddy/v2/caddyconfig/httpcaddyfile"
|
||||
"github.com/caddyserver/caddy/v2/modules/caddyhttp"
|
||||
"go.uber.org/zap"
|
||||
"go.uber.org/zap/zapcore"
|
||||
)
|
||||
|
||||
func init() {
|
||||
@@ -30,11 +31,15 @@ type Middleware struct {
|
||||
BanDurationMultiplier float64 `json:"ban_duration_multiplier,omitempty"`
|
||||
Whitelist []string `json:"whitelist,omitempty"`
|
||||
CustomResponseHeader string `json:"custom_response_header,omitempty"`
|
||||
LogRequestHeaders []string `json:"log_request_headers,omitempty"` // New field to specify headers to log
|
||||
LogRequestHeaders []string `json:"log_request_headers,omitempty"`
|
||||
LogLevel string `json:"log_level,omitempty"` // New: Configurable log level
|
||||
CIDRBans []string `json:"cidr_bans,omitempty"` // New: Support for CIDR range bans
|
||||
BanResponseBody string `json:"ban_response_body,omitempty"` // New: Custom ban response body
|
||||
|
||||
logger *zap.Logger
|
||||
errorCounts map[string]int
|
||||
bannedIPs map[string]time.Time
|
||||
bannedCIDRs []*net.IPNet // New: Track banned CIDR ranges
|
||||
mu sync.RWMutex
|
||||
whitelistedNets []*net.IPNet
|
||||
}
|
||||
@@ -53,6 +58,15 @@ func (m *Middleware) Provision(ctx caddy.Context) error {
|
||||
m.bannedIPs = make(map[string]time.Time)
|
||||
m.logger = ctx.Logger(m)
|
||||
|
||||
// Set log level
|
||||
if m.LogLevel != "" {
|
||||
level, err := zapcore.ParseLevel(m.LogLevel)
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid log level: %s", m.LogLevel)
|
||||
}
|
||||
m.logger = m.logger.WithOptions(zap.IncreaseLevel(level))
|
||||
}
|
||||
|
||||
m.logger.Info("starting caddy mib middleware")
|
||||
|
||||
if m.MaxErrorCount == 0 {
|
||||
@@ -65,6 +79,7 @@ func (m *Middleware) Provision(ctx caddy.Context) error {
|
||||
m.BanDurationMultiplier = 1
|
||||
}
|
||||
|
||||
// Parse whitelist
|
||||
for _, cidr := range m.Whitelist {
|
||||
_, ipNet, err := net.ParseCIDR(cidr)
|
||||
if err != nil {
|
||||
@@ -83,13 +98,24 @@ func (m *Middleware) Provision(ctx caddy.Context) error {
|
||||
m.whitelistedNets = append(m.whitelistedNets, ipNet)
|
||||
}
|
||||
|
||||
// Parse CIDR bans
|
||||
for _, cidr := range m.CIDRBans {
|
||||
_, ipNet, err := net.ParseCIDR(cidr)
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid CIDR in CIDR bans: %s", cidr)
|
||||
}
|
||||
m.bannedCIDRs = append(m.bannedCIDRs, ipNet)
|
||||
}
|
||||
|
||||
m.logger.Info("caddy mib middleware provisioned",
|
||||
zap.Ints("error_codes", m.ErrorCodes),
|
||||
zap.Int("max_error_count", m.MaxErrorCount),
|
||||
zap.Duration("ban_duration", time.Duration(m.BanDuration)),
|
||||
zap.Strings("whitelist", m.Whitelist),
|
||||
zap.String("custom_response_header", m.CustomResponseHeader),
|
||||
zap.Strings("log_request_headers", m.LogRequestHeaders), // Log configured request headers
|
||||
zap.Strings("log_request_headers", m.LogRequestHeaders),
|
||||
zap.String("log_level", m.LogLevel),
|
||||
zap.Strings("cidr_bans", m.CIDRBans),
|
||||
)
|
||||
|
||||
go m.cleanupExpiredBans()
|
||||
@@ -119,41 +145,62 @@ func (m *Middleware) ServeHTTP(w http.ResponseWriter, r *http.Request, next cadd
|
||||
return next.ServeHTTP(w, r)
|
||||
}
|
||||
|
||||
// Normalize IP (treat IPv4 and IPv6 loopback as the same)
|
||||
clientIP = normalizeIP(clientIP)
|
||||
parsedIP := net.ParseIP(clientIP)
|
||||
|
||||
// Check if IP is in a banned CIDR range
|
||||
for _, ipNet := range m.bannedCIDRs {
|
||||
if ipNet.Contains(parsedIP) {
|
||||
m.logger.Info("IP is in a banned CIDR range",
|
||||
zap.String("client_ip", clientIP),
|
||||
zap.String("cidr", ipNet.String()),
|
||||
)
|
||||
w.WriteHeader(http.StatusForbidden)
|
||||
if m.BanResponseBody != "" {
|
||||
w.Write([]byte(m.BanResponseBody))
|
||||
}
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// Check if IP is whitelisted
|
||||
for _, ipNet := range m.whitelistedNets {
|
||||
if ipNet.Contains(parsedIP) {
|
||||
m.logger.Debug("client IP is whitelisted",
|
||||
zap.String("client_ip", clientIP),
|
||||
zap.String("request_path", r.URL.Path),
|
||||
zap.String("method", r.Method),
|
||||
zap.String("user_agent", r.Header.Get("User-Agent")),
|
||||
)
|
||||
return next.ServeHTTP(w, r)
|
||||
}
|
||||
}
|
||||
|
||||
// Check if IP is banned
|
||||
m.mu.RLock()
|
||||
banTime, banned := m.bannedIPs[clientIP]
|
||||
m.mu.RUnlock()
|
||||
|
||||
if banned {
|
||||
if time.Now().Before(banTime) {
|
||||
m.logger.Info("IP is currently banned",
|
||||
zap.String("client_ip", clientIP),
|
||||
zap.String("request_path", r.URL.Path),
|
||||
zap.Time("ban_expires", banTime),
|
||||
zap.String("method", r.Method),
|
||||
zap.String("user_agent", r.Header.Get("User-Agent")),
|
||||
)
|
||||
m.mu.RLock()
|
||||
_, alreadyLogged := m.errorCounts[clientIP]
|
||||
m.mu.RUnlock()
|
||||
if !alreadyLogged {
|
||||
m.logger.Info("IP is currently banned",
|
||||
zap.String("client_ip", clientIP),
|
||||
zap.Time("ban_expires", banTime),
|
||||
)
|
||||
m.mu.Lock()
|
||||
m.errorCounts[clientIP] = -1 // Mark as logged
|
||||
m.mu.Unlock()
|
||||
}
|
||||
w.WriteHeader(http.StatusForbidden)
|
||||
if m.BanResponseBody != "" {
|
||||
w.Write([]byte(m.BanResponseBody))
|
||||
}
|
||||
return nil
|
||||
}
|
||||
m.logger.Info("unbanning IP; ban expired",
|
||||
zap.String("client_ip", clientIP),
|
||||
zap.String("request_path", r.URL.Path),
|
||||
zap.Time("previous_ban_expiration", banTime),
|
||||
zap.String("method", r.Method),
|
||||
zap.String("user_agent", r.Header.Get("User-Agent")),
|
||||
)
|
||||
m.mu.Lock()
|
||||
delete(m.bannedIPs, clientIP)
|
||||
@@ -161,25 +208,21 @@ func (m *Middleware) ServeHTTP(w http.ResponseWriter, r *http.Request, next cadd
|
||||
m.mu.Unlock()
|
||||
}
|
||||
|
||||
// Skip middleware if no error codes are specified
|
||||
if len(m.ErrorCodes) == 0 {
|
||||
m.logger.Debug("no error codes specified; skipping middleware",
|
||||
zap.String("client_ip", clientIP),
|
||||
zap.String("request_path", r.URL.Path),
|
||||
zap.String("method", r.Method),
|
||||
zap.String("user_agent", r.Header.Get("User-Agent")),
|
||||
)
|
||||
return next.ServeHTTP(w, r)
|
||||
}
|
||||
|
||||
// Record the response from the next handler
|
||||
rec := caddyhttp.NewResponseRecorder(w, nil, nil)
|
||||
err = next.ServeHTTP(rec, r)
|
||||
if err != nil {
|
||||
m.logger.Error("error in next handler",
|
||||
zap.Error(err),
|
||||
zap.String("client_ip", clientIP),
|
||||
zap.String("request_path", r.URL.Path),
|
||||
zap.String("method", r.Method),
|
||||
zap.String("user_agent", r.Header.Get("User-Agent")),
|
||||
)
|
||||
statusCode := extractStatusCodeFromError(err)
|
||||
if statusCode == 0 {
|
||||
@@ -188,21 +231,16 @@ func (m *Middleware) ServeHTTP(w http.ResponseWriter, r *http.Request, next cadd
|
||||
m.logger.Debug("extracted status code from error",
|
||||
zap.Int("status_code", statusCode),
|
||||
zap.String("client_ip", clientIP),
|
||||
zap.String("request_path", r.URL.Path),
|
||||
zap.String("method", r.Method),
|
||||
zap.String("user_agent", r.Header.Get("User-Agent")),
|
||||
)
|
||||
m.trackErrorStatus(clientIP, statusCode, r.URL.Path, r)
|
||||
return err
|
||||
}
|
||||
|
||||
// Track the response status code
|
||||
statusCode := rec.Status()
|
||||
m.logger.Debug("response status code",
|
||||
zap.Int("status_code", statusCode),
|
||||
zap.String("client_ip", clientIP),
|
||||
zap.String("request_path", r.URL.Path),
|
||||
zap.String("method", r.Method),
|
||||
zap.String("user_agent", r.Header.Get("User-Agent")),
|
||||
)
|
||||
m.trackErrorStatus(clientIP, statusCode, r.URL.Path, r)
|
||||
|
||||
@@ -214,6 +252,21 @@ func (m *Middleware) ServeHTTP(w http.ResponseWriter, r *http.Request, next cadd
|
||||
return nil
|
||||
}
|
||||
|
||||
// normalizeIP normalizes IPv4 and IPv6 loopback addresses.
|
||||
func normalizeIP(ip string) string {
|
||||
parsedIP := net.ParseIP(ip)
|
||||
if parsedIP == nil {
|
||||
return ip
|
||||
}
|
||||
if parsedIP.IsLoopback() {
|
||||
if parsedIP.To4() != nil {
|
||||
return "127.0.0.1"
|
||||
}
|
||||
return "::1"
|
||||
}
|
||||
return ip
|
||||
}
|
||||
|
||||
func (m *Middleware) trackErrorStatus(clientIP string, code int, path string, r *http.Request) {
|
||||
commonFields := []zap.Field{
|
||||
zap.String("client_ip", clientIP),
|
||||
@@ -257,7 +310,6 @@ func (m *Middleware) trackErrorStatus(clientIP string, code int, path string, r
|
||||
}
|
||||
|
||||
m.logger.Info("IP banned", logFields...)
|
||||
// No need to write header here, it's done in ServeHTTP
|
||||
}
|
||||
m.mu.Unlock()
|
||||
break
|
||||
@@ -358,13 +410,32 @@ func (m *Middleware) UnmarshalCaddyfile(d *caddyfile.Dispenser) error {
|
||||
}
|
||||
m.CustomResponseHeader = d.Val()
|
||||
|
||||
case "log_request_headers": // New Caddyfile option
|
||||
case "log_request_headers":
|
||||
var headers []string
|
||||
for d.NextArg() {
|
||||
headers = append(headers, d.Val())
|
||||
}
|
||||
m.LogRequestHeaders = headers
|
||||
|
||||
case "log_level":
|
||||
if !d.NextArg() {
|
||||
return d.ArgErr()
|
||||
}
|
||||
m.LogLevel = d.Val()
|
||||
|
||||
case "cidr_bans":
|
||||
var cidrBans []string
|
||||
for d.NextArg() {
|
||||
cidrBans = append(cidrBans, d.Val())
|
||||
}
|
||||
m.CIDRBans = cidrBans
|
||||
|
||||
case "ban_response_body":
|
||||
if !d.NextArg() {
|
||||
return d.ArgErr()
|
||||
}
|
||||
m.BanResponseBody = d.Val()
|
||||
|
||||
default:
|
||||
return d.Errf("unrecognized option: %s", d.Val())
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user