Update caddymib.go

This commit is contained in:
fab
2025-01-11 13:45:11 +01:00
committed by GitHub
parent 754e7b911b
commit 0f6c168a98

View File

@@ -35,16 +35,26 @@ type Middleware struct {
LogLevel string `json:"log_level,omitempty"`
CIDRBans []string `json:"cidr_bans,omitempty"`
BanResponseBody string `json:"ban_response_body,omitempty"`
BanStatusCode int `json:"ban_status_code,omitempty"` // New: Configurable status code for bans
BanStatusCode int `json:"ban_status_code,omitempty"`
// Per-path configuration
PerPathConfig map[string]PathConfig `json:"per_path,omitempty"`
logger *zap.Logger
errorCounts map[string]int
bannedIPs map[string]time.Time
errorCounts sync.Map // Tracks errors per IP and path
bannedIPs sync.Map // Tracks banned IPs and their expiration times
bannedCIDRs []*net.IPNet
mu sync.RWMutex
whitelistedNets []*net.IPNet
}
// PathConfig defines per-path configuration.
type PathConfig struct {
ErrorCodes []int `json:"error_codes,omitempty"`
MaxErrorCount int `json:"max_error_count,omitempty"`
BanDuration caddy.Duration `json:"ban_duration,omitempty"`
BanDurationMultiplier float64 `json:"ban_duration_multiplier,omitempty"`
}
// CaddyModule returns the Caddy module information.
func (Middleware) CaddyModule() caddy.ModuleInfo {
return caddy.ModuleInfo{
@@ -55,8 +65,6 @@ func (Middleware) CaddyModule() caddy.ModuleInfo {
// Provision sets up the middleware.
func (m *Middleware) Provision(ctx caddy.Context) error {
m.errorCounts = make(map[string]int)
m.bannedIPs = make(map[string]time.Time)
m.logger = ctx.Logger(m)
// Set log level
@@ -165,7 +173,7 @@ func (m *Middleware) ServeHTTP(w http.ResponseWriter, r *http.Request, next cadd
zap.String("client_ip", clientIP),
zap.String("cidr", ipNet.String()),
)
w.WriteHeader(m.BanStatusCode) // Use configured status code
w.WriteHeader(m.BanStatusCode)
if m.BanResponseBody != "" {
w.Write([]byte(m.BanResponseBody))
}
@@ -184,25 +192,13 @@ func (m *Middleware) ServeHTTP(w http.ResponseWriter, r *http.Request, next cadd
}
// Check if IP is banned
m.mu.RLock()
banTime, banned := m.bannedIPs[clientIP]
m.mu.RUnlock()
if banned {
if time.Now().Before(banTime) {
m.mu.RLock()
_, alreadyLogged := m.errorCounts[clientIP]
m.mu.RUnlock()
if !alreadyLogged {
m.logger.Info("IP is currently banned",
zap.String("client_ip", clientIP),
zap.Time("ban_expires", banTime),
)
m.mu.Lock()
m.errorCounts[clientIP] = -1 // Mark as logged
m.mu.Unlock()
}
w.WriteHeader(m.BanStatusCode) // Use configured status code
if banTime, banned := m.bannedIPs.Load(clientIP); banned {
if time.Now().Before(banTime.(time.Time)) {
m.logger.Info("IP is currently banned",
zap.String("client_ip", clientIP),
zap.Time("ban_expires", banTime.(time.Time)),
)
w.WriteHeader(m.BanStatusCode)
if m.BanResponseBody != "" {
w.Write([]byte(m.BanResponseBody))
}
@@ -211,10 +207,8 @@ func (m *Middleware) ServeHTTP(w http.ResponseWriter, r *http.Request, next cadd
m.logger.Info("unbanning IP; ban expired",
zap.String("client_ip", clientIP),
)
m.mu.Lock()
delete(m.bannedIPs, clientIP)
delete(m.errorCounts, clientIP)
m.mu.Unlock()
m.bannedIPs.Delete(clientIP)
m.errorCounts.Delete(clientIP)
}
// Skip middleware if no error codes are specified
@@ -285,16 +279,29 @@ func (m *Middleware) trackErrorStatus(clientIP string, code int, path string, r
zap.String("user_agent", r.Header.Get("User-Agent")),
}
// Check if the path has a specific configuration
pathConfig, hasPathConfig := m.PerPathConfig[path]
if hasPathConfig {
m.logger.Debug("using per-path configuration",
zap.String("path", path),
)
m.trackErrorsForPath(clientIP, code, path, r, pathConfig)
return
}
// Use global configuration
for _, errCode := range m.ErrorCodes {
if code == errCode {
m.mu.Lock()
countBefore := m.errorCounts[clientIP]
countBefore := 0
if val, ok := m.errorCounts.Load(clientIP); ok {
countBefore = val.(map[string]int)["global"]
}
m.logger.Debug("tracking error", append(commonFields,
zap.Int("current_error_count", countBefore),
zap.Int("max_error_count", m.MaxErrorCount),
)...)
m.errorCounts[clientIP] = countBefore + 1
countNow := m.errorCounts[clientIP]
countNow := countBefore + 1
m.errorCounts.Store(clientIP, map[string]int{"global": countNow})
m.logger.Debug("error count incremented", append(commonFields,
zap.Int("new_error_count", countNow),
zap.Int("max_error_count", m.MaxErrorCount),
@@ -303,7 +310,7 @@ func (m *Middleware) trackErrorStatus(clientIP string, code int, path string, r
offenses := countNow - m.MaxErrorCount + 1
banDuration := time.Duration(m.BanDuration) * time.Duration(math.Pow(m.BanDurationMultiplier, float64(offenses)))
expiration := time.Now().Add(banDuration)
m.bannedIPs[clientIP] = expiration
m.bannedIPs.Store(clientIP, expiration)
logFields := append(commonFields,
zap.Int("error_count", countNow),
zap.Int("max_error_count", m.MaxErrorCount),
@@ -320,7 +327,58 @@ func (m *Middleware) trackErrorStatus(clientIP string, code int, path string, r
m.logger.Info("IP banned", logFields...)
}
m.mu.Unlock()
break
}
}
}
// trackErrorsForPath tracks errors for a specific path.
func (m *Middleware) trackErrorsForPath(clientIP string, code int, path string, r *http.Request, config PathConfig) {
commonFields := []zap.Field{
zap.String("client_ip", clientIP),
zap.Int("error_code", code),
zap.String("request_path", path),
zap.String("method", r.Method),
zap.String("user_agent", r.Header.Get("User-Agent")),
}
for _, errCode := range config.ErrorCodes {
if code == errCode {
countBefore := 0
if val, ok := m.errorCounts.Load(clientIP); ok {
countBefore = val.(map[string]int)[path]
}
m.logger.Debug("tracking error for path", append(commonFields,
zap.Int("current_error_count", countBefore),
zap.Int("max_error_count", config.MaxErrorCount),
)...)
countNow := countBefore + 1
m.errorCounts.Store(clientIP, map[string]int{path: countNow})
m.logger.Debug("error count incremented for path", append(commonFields,
zap.Int("new_error_count", countNow),
zap.Int("max_error_count", config.MaxErrorCount),
)...)
if countNow >= config.MaxErrorCount {
offenses := countNow - config.MaxErrorCount + 1
banDuration := time.Duration(config.BanDuration) * time.Duration(math.Pow(config.BanDurationMultiplier, float64(offenses)))
expiration := time.Now().Add(banDuration)
m.bannedIPs.Store(clientIP, expiration)
logFields := append(commonFields,
zap.Int("error_count", countNow),
zap.Int("max_error_count", config.MaxErrorCount),
zap.Duration("ban_duration", banDuration),
zap.Time("ban_expires", expiration),
)
// Add configured request headers to the log
for _, headerName := range m.LogRequestHeaders {
if value := r.Header.Get(headerName); value != "" {
logFields = append(logFields, zap.String(strings.ToLower(headerName), value))
}
}
m.logger.Info("IP banned for path", logFields...)
}
break
}
}
@@ -341,19 +399,18 @@ func extractStatusCodeFromError(err error) int {
// cleanupExpiredBans periodically cleans up expired bans.
func (m *Middleware) cleanupExpiredBans() {
for {
time.Sleep(time.Minute)
m.mu.Lock()
time.Sleep(time.Second) // Check bans every second
now := time.Now()
for ip, banTime := range m.bannedIPs {
if now.After(banTime) {
m.bannedIPs.Range(func(key, value interface{}) bool {
if now.After(value.(time.Time)) {
m.logger.Info("cleaned up expired ban",
zap.String("client_ip", ip),
zap.String("client_ip", key.(string)),
)
delete(m.bannedIPs, ip)
delete(m.errorCounts, ip)
m.bannedIPs.Delete(key)
m.errorCounts.Delete(key)
}
}
m.mu.Unlock()
return true
})
}
}
@@ -458,6 +515,67 @@ func (m *Middleware) UnmarshalCaddyfile(d *caddyfile.Dispenser) error {
}
m.BanStatusCode = statusCode
case "per_path":
if !d.NextArg() {
return d.ArgErr()
}
path := d.Val()
config := PathConfig{}
for d.NextBlock(1) {
switch d.Val() {
case "error_codes":
var codes []int
for d.NextArg() {
code, err := strconv.Atoi(d.Val())
if err != nil {
return d.Errf("invalid error code: %s", d.Val())
}
codes = append(codes, code)
}
if len(codes) == 0 {
return d.Err("error_codes needs at least one argument")
}
config.ErrorCodes = codes
case "max_error_count":
if !d.NextArg() {
return d.ArgErr()
}
count, err := strconv.Atoi(d.Val())
if err != nil {
return d.Errf("invalid max_error_count: %s", d.Val())
}
config.MaxErrorCount = count
case "ban_duration":
if !d.NextArg() {
return d.ArgErr()
}
dur, err := time.ParseDuration(d.Val())
if err != nil {
return d.Errf("invalid ban_duration: %v", err)
}
config.BanDuration = caddy.Duration(dur)
case "ban_duration_multiplier":
if !d.NextArg() {
return d.ArgErr()
}
multiplier, err := strconv.ParseFloat(d.Val(), 64)
if err != nil {
return d.Errf("invalid ban_duration_multiplier: %s", d.Val())
}
config.BanDurationMultiplier = multiplier
default:
return d.Errf("unrecognized option in per_path block: %s", d.Val())
}
}
if m.PerPathConfig == nil {
m.PerPathConfig = make(map[string]PathConfig)
}
m.PerPathConfig[path] = config
default:
return d.Errf("unrecognized option: %s", d.Val())
}