mirror of
https://github.com/fabriziosalmi/caddy-mib.git
synced 2025-12-23 22:17:43 -05:00
Update caddymib.go
This commit is contained in:
208
caddymib.go
208
caddymib.go
@@ -35,16 +35,26 @@ type Middleware struct {
|
||||
LogLevel string `json:"log_level,omitempty"`
|
||||
CIDRBans []string `json:"cidr_bans,omitempty"`
|
||||
BanResponseBody string `json:"ban_response_body,omitempty"`
|
||||
BanStatusCode int `json:"ban_status_code,omitempty"` // New: Configurable status code for bans
|
||||
BanStatusCode int `json:"ban_status_code,omitempty"`
|
||||
|
||||
// Per-path configuration
|
||||
PerPathConfig map[string]PathConfig `json:"per_path,omitempty"`
|
||||
|
||||
logger *zap.Logger
|
||||
errorCounts map[string]int
|
||||
bannedIPs map[string]time.Time
|
||||
errorCounts sync.Map // Tracks errors per IP and path
|
||||
bannedIPs sync.Map // Tracks banned IPs and their expiration times
|
||||
bannedCIDRs []*net.IPNet
|
||||
mu sync.RWMutex
|
||||
whitelistedNets []*net.IPNet
|
||||
}
|
||||
|
||||
// PathConfig defines per-path configuration.
|
||||
type PathConfig struct {
|
||||
ErrorCodes []int `json:"error_codes,omitempty"`
|
||||
MaxErrorCount int `json:"max_error_count,omitempty"`
|
||||
BanDuration caddy.Duration `json:"ban_duration,omitempty"`
|
||||
BanDurationMultiplier float64 `json:"ban_duration_multiplier,omitempty"`
|
||||
}
|
||||
|
||||
// CaddyModule returns the Caddy module information.
|
||||
func (Middleware) CaddyModule() caddy.ModuleInfo {
|
||||
return caddy.ModuleInfo{
|
||||
@@ -55,8 +65,6 @@ func (Middleware) CaddyModule() caddy.ModuleInfo {
|
||||
|
||||
// Provision sets up the middleware.
|
||||
func (m *Middleware) Provision(ctx caddy.Context) error {
|
||||
m.errorCounts = make(map[string]int)
|
||||
m.bannedIPs = make(map[string]time.Time)
|
||||
m.logger = ctx.Logger(m)
|
||||
|
||||
// Set log level
|
||||
@@ -165,7 +173,7 @@ func (m *Middleware) ServeHTTP(w http.ResponseWriter, r *http.Request, next cadd
|
||||
zap.String("client_ip", clientIP),
|
||||
zap.String("cidr", ipNet.String()),
|
||||
)
|
||||
w.WriteHeader(m.BanStatusCode) // Use configured status code
|
||||
w.WriteHeader(m.BanStatusCode)
|
||||
if m.BanResponseBody != "" {
|
||||
w.Write([]byte(m.BanResponseBody))
|
||||
}
|
||||
@@ -184,25 +192,13 @@ func (m *Middleware) ServeHTTP(w http.ResponseWriter, r *http.Request, next cadd
|
||||
}
|
||||
|
||||
// Check if IP is banned
|
||||
m.mu.RLock()
|
||||
banTime, banned := m.bannedIPs[clientIP]
|
||||
m.mu.RUnlock()
|
||||
|
||||
if banned {
|
||||
if time.Now().Before(banTime) {
|
||||
m.mu.RLock()
|
||||
_, alreadyLogged := m.errorCounts[clientIP]
|
||||
m.mu.RUnlock()
|
||||
if !alreadyLogged {
|
||||
m.logger.Info("IP is currently banned",
|
||||
zap.String("client_ip", clientIP),
|
||||
zap.Time("ban_expires", banTime),
|
||||
)
|
||||
m.mu.Lock()
|
||||
m.errorCounts[clientIP] = -1 // Mark as logged
|
||||
m.mu.Unlock()
|
||||
}
|
||||
w.WriteHeader(m.BanStatusCode) // Use configured status code
|
||||
if banTime, banned := m.bannedIPs.Load(clientIP); banned {
|
||||
if time.Now().Before(banTime.(time.Time)) {
|
||||
m.logger.Info("IP is currently banned",
|
||||
zap.String("client_ip", clientIP),
|
||||
zap.Time("ban_expires", banTime.(time.Time)),
|
||||
)
|
||||
w.WriteHeader(m.BanStatusCode)
|
||||
if m.BanResponseBody != "" {
|
||||
w.Write([]byte(m.BanResponseBody))
|
||||
}
|
||||
@@ -211,10 +207,8 @@ func (m *Middleware) ServeHTTP(w http.ResponseWriter, r *http.Request, next cadd
|
||||
m.logger.Info("unbanning IP; ban expired",
|
||||
zap.String("client_ip", clientIP),
|
||||
)
|
||||
m.mu.Lock()
|
||||
delete(m.bannedIPs, clientIP)
|
||||
delete(m.errorCounts, clientIP)
|
||||
m.mu.Unlock()
|
||||
m.bannedIPs.Delete(clientIP)
|
||||
m.errorCounts.Delete(clientIP)
|
||||
}
|
||||
|
||||
// Skip middleware if no error codes are specified
|
||||
@@ -285,16 +279,29 @@ func (m *Middleware) trackErrorStatus(clientIP string, code int, path string, r
|
||||
zap.String("user_agent", r.Header.Get("User-Agent")),
|
||||
}
|
||||
|
||||
// Check if the path has a specific configuration
|
||||
pathConfig, hasPathConfig := m.PerPathConfig[path]
|
||||
if hasPathConfig {
|
||||
m.logger.Debug("using per-path configuration",
|
||||
zap.String("path", path),
|
||||
)
|
||||
m.trackErrorsForPath(clientIP, code, path, r, pathConfig)
|
||||
return
|
||||
}
|
||||
|
||||
// Use global configuration
|
||||
for _, errCode := range m.ErrorCodes {
|
||||
if code == errCode {
|
||||
m.mu.Lock()
|
||||
countBefore := m.errorCounts[clientIP]
|
||||
countBefore := 0
|
||||
if val, ok := m.errorCounts.Load(clientIP); ok {
|
||||
countBefore = val.(map[string]int)["global"]
|
||||
}
|
||||
m.logger.Debug("tracking error", append(commonFields,
|
||||
zap.Int("current_error_count", countBefore),
|
||||
zap.Int("max_error_count", m.MaxErrorCount),
|
||||
)...)
|
||||
m.errorCounts[clientIP] = countBefore + 1
|
||||
countNow := m.errorCounts[clientIP]
|
||||
countNow := countBefore + 1
|
||||
m.errorCounts.Store(clientIP, map[string]int{"global": countNow})
|
||||
m.logger.Debug("error count incremented", append(commonFields,
|
||||
zap.Int("new_error_count", countNow),
|
||||
zap.Int("max_error_count", m.MaxErrorCount),
|
||||
@@ -303,7 +310,7 @@ func (m *Middleware) trackErrorStatus(clientIP string, code int, path string, r
|
||||
offenses := countNow - m.MaxErrorCount + 1
|
||||
banDuration := time.Duration(m.BanDuration) * time.Duration(math.Pow(m.BanDurationMultiplier, float64(offenses)))
|
||||
expiration := time.Now().Add(banDuration)
|
||||
m.bannedIPs[clientIP] = expiration
|
||||
m.bannedIPs.Store(clientIP, expiration)
|
||||
logFields := append(commonFields,
|
||||
zap.Int("error_count", countNow),
|
||||
zap.Int("max_error_count", m.MaxErrorCount),
|
||||
@@ -320,7 +327,58 @@ func (m *Middleware) trackErrorStatus(clientIP string, code int, path string, r
|
||||
|
||||
m.logger.Info("IP banned", logFields...)
|
||||
}
|
||||
m.mu.Unlock()
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// trackErrorsForPath tracks errors for a specific path.
|
||||
func (m *Middleware) trackErrorsForPath(clientIP string, code int, path string, r *http.Request, config PathConfig) {
|
||||
commonFields := []zap.Field{
|
||||
zap.String("client_ip", clientIP),
|
||||
zap.Int("error_code", code),
|
||||
zap.String("request_path", path),
|
||||
zap.String("method", r.Method),
|
||||
zap.String("user_agent", r.Header.Get("User-Agent")),
|
||||
}
|
||||
|
||||
for _, errCode := range config.ErrorCodes {
|
||||
if code == errCode {
|
||||
countBefore := 0
|
||||
if val, ok := m.errorCounts.Load(clientIP); ok {
|
||||
countBefore = val.(map[string]int)[path]
|
||||
}
|
||||
m.logger.Debug("tracking error for path", append(commonFields,
|
||||
zap.Int("current_error_count", countBefore),
|
||||
zap.Int("max_error_count", config.MaxErrorCount),
|
||||
)...)
|
||||
countNow := countBefore + 1
|
||||
m.errorCounts.Store(clientIP, map[string]int{path: countNow})
|
||||
m.logger.Debug("error count incremented for path", append(commonFields,
|
||||
zap.Int("new_error_count", countNow),
|
||||
zap.Int("max_error_count", config.MaxErrorCount),
|
||||
)...)
|
||||
if countNow >= config.MaxErrorCount {
|
||||
offenses := countNow - config.MaxErrorCount + 1
|
||||
banDuration := time.Duration(config.BanDuration) * time.Duration(math.Pow(config.BanDurationMultiplier, float64(offenses)))
|
||||
expiration := time.Now().Add(banDuration)
|
||||
m.bannedIPs.Store(clientIP, expiration)
|
||||
logFields := append(commonFields,
|
||||
zap.Int("error_count", countNow),
|
||||
zap.Int("max_error_count", config.MaxErrorCount),
|
||||
zap.Duration("ban_duration", banDuration),
|
||||
zap.Time("ban_expires", expiration),
|
||||
)
|
||||
|
||||
// Add configured request headers to the log
|
||||
for _, headerName := range m.LogRequestHeaders {
|
||||
if value := r.Header.Get(headerName); value != "" {
|
||||
logFields = append(logFields, zap.String(strings.ToLower(headerName), value))
|
||||
}
|
||||
}
|
||||
|
||||
m.logger.Info("IP banned for path", logFields...)
|
||||
}
|
||||
break
|
||||
}
|
||||
}
|
||||
@@ -341,19 +399,18 @@ func extractStatusCodeFromError(err error) int {
|
||||
// cleanupExpiredBans periodically cleans up expired bans.
|
||||
func (m *Middleware) cleanupExpiredBans() {
|
||||
for {
|
||||
time.Sleep(time.Minute)
|
||||
m.mu.Lock()
|
||||
time.Sleep(time.Second) // Check bans every second
|
||||
now := time.Now()
|
||||
for ip, banTime := range m.bannedIPs {
|
||||
if now.After(banTime) {
|
||||
m.bannedIPs.Range(func(key, value interface{}) bool {
|
||||
if now.After(value.(time.Time)) {
|
||||
m.logger.Info("cleaned up expired ban",
|
||||
zap.String("client_ip", ip),
|
||||
zap.String("client_ip", key.(string)),
|
||||
)
|
||||
delete(m.bannedIPs, ip)
|
||||
delete(m.errorCounts, ip)
|
||||
m.bannedIPs.Delete(key)
|
||||
m.errorCounts.Delete(key)
|
||||
}
|
||||
}
|
||||
m.mu.Unlock()
|
||||
return true
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -458,6 +515,67 @@ func (m *Middleware) UnmarshalCaddyfile(d *caddyfile.Dispenser) error {
|
||||
}
|
||||
m.BanStatusCode = statusCode
|
||||
|
||||
case "per_path":
|
||||
if !d.NextArg() {
|
||||
return d.ArgErr()
|
||||
}
|
||||
path := d.Val()
|
||||
config := PathConfig{}
|
||||
for d.NextBlock(1) {
|
||||
switch d.Val() {
|
||||
case "error_codes":
|
||||
var codes []int
|
||||
for d.NextArg() {
|
||||
code, err := strconv.Atoi(d.Val())
|
||||
if err != nil {
|
||||
return d.Errf("invalid error code: %s", d.Val())
|
||||
}
|
||||
codes = append(codes, code)
|
||||
}
|
||||
if len(codes) == 0 {
|
||||
return d.Err("error_codes needs at least one argument")
|
||||
}
|
||||
config.ErrorCodes = codes
|
||||
|
||||
case "max_error_count":
|
||||
if !d.NextArg() {
|
||||
return d.ArgErr()
|
||||
}
|
||||
count, err := strconv.Atoi(d.Val())
|
||||
if err != nil {
|
||||
return d.Errf("invalid max_error_count: %s", d.Val())
|
||||
}
|
||||
config.MaxErrorCount = count
|
||||
|
||||
case "ban_duration":
|
||||
if !d.NextArg() {
|
||||
return d.ArgErr()
|
||||
}
|
||||
dur, err := time.ParseDuration(d.Val())
|
||||
if err != nil {
|
||||
return d.Errf("invalid ban_duration: %v", err)
|
||||
}
|
||||
config.BanDuration = caddy.Duration(dur)
|
||||
|
||||
case "ban_duration_multiplier":
|
||||
if !d.NextArg() {
|
||||
return d.ArgErr()
|
||||
}
|
||||
multiplier, err := strconv.ParseFloat(d.Val(), 64)
|
||||
if err != nil {
|
||||
return d.Errf("invalid ban_duration_multiplier: %s", d.Val())
|
||||
}
|
||||
config.BanDurationMultiplier = multiplier
|
||||
|
||||
default:
|
||||
return d.Errf("unrecognized option in per_path block: %s", d.Val())
|
||||
}
|
||||
}
|
||||
if m.PerPathConfig == nil {
|
||||
m.PerPathConfig = make(map[string]PathConfig)
|
||||
}
|
||||
m.PerPathConfig[path] = config
|
||||
|
||||
default:
|
||||
return d.Errf("unrecognized option: %s", d.Val())
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user