Merge pull request #10 from kjnez/main

This commit is contained in:
fab
2025-12-01 06:15:24 +01:00
committed by GitHub
4 changed files with 579 additions and 35 deletions

View File

@@ -12,8 +12,9 @@
caddy_mib {
error_codes 404 500 401 # Error codes to track
max_error_count 10 # Global error threshold (reduced for faster testing)
ban_duration 5s # Global ban duration (reduced to 10 seconds)
ban_duration_multiplier 1 # Global ban duration multiplier
ban_duration 1s # Global ban duration (1s for fast testing)
error_count_timeout 5s # Global error count timeout
ban_duration_multiplier 2 # Global ban duration multiplier
# whitelist 127.0.0.1 ::1 # Whitelisted IPs
log_level debug # Log level for debugging
ban_response_body "You have been banned due to excessive errors. Please try again later."
@@ -23,7 +24,7 @@
per_path /login {
error_codes 404 # Error codes to track for /login
max_error_count 5 # Error threshold for /login (reduced for faster testing)
ban_duration 10s # Ban duration for /login (reduced to 15 seconds)
ban_duration 2s # Ban duration for /login (2s for fast testing)
ban_duration_multiplier 1
}
@@ -31,7 +32,7 @@
per_path /api {
error_codes 404 500 # Error codes to track for /api
max_error_count 8 # Error threshold for /api (reduced for faster testing)
ban_duration 15s # Ban duration for /api (reduced to 20 seconds)
ban_duration 3s # Ban duration for /api (3s for fast testing)
ban_duration_multiplier 1
}
}

View File

@@ -10,6 +10,7 @@
* **Configurable Error Limits**: Set max errors per IP before banning.
* **Flexible Ban Times**: Use human-readable formats (e.g., 5s, 10m, 1h).
* **Exponential Ban Increase**: Ban duration grows for repeat offenders.
* **Sliding Window Error Tracking**: Reset error counts after a period of inactivity (optional).
* **Trusted IP Whitelisting**: Exclude specific IPs or CIDRs from bans.
* **Path-Specific Settings**: Tailor limits and bans per URL path.
* **Custom Ban Messages**: Set custom response bodies and headers.
@@ -83,6 +84,7 @@ Here's a comprehensive example showcasing a range of options:
max_error_count 10 # Allow up to 10 global errors
ban_duration 5s # Base ban duration of 5 seconds
ban_duration_multiplier 1.5 # Increase ban duration for repeat offenders
error_count_timeout 1h # Reset error count after 1 hour of inactivity (optional)
whitelist 127.0.0.1 ::1 192.168.1.0/24 # Whitelist local IPs and network
log_request_headers User-Agent X-Custom-Header # Log specified request headers
log_level debug # Debug log level for this middleware
@@ -99,6 +101,7 @@ Here's a comprehensive example showcasing a range of options:
max_error_count 5 # Only allow 5 errors before banning
ban_duration 10s # Ban duration of 10 seconds
ban_duration_multiplier 2 # Exponential increase in /login ban duration
error_count_timeout 15m # Reset after 15 minutes for /login (optional)
}
per_path /api {
@@ -106,6 +109,7 @@ Here's a comprehensive example showcasing a range of options:
max_error_count 8 # Allow 8 errors before banning
ban_duration 15s # 15-second ban duration
ban_duration_multiplier 1 # No exponential increase in /api ban duration
error_count_timeout 30m # Reset after 30 minutes for /api (optional)
}
}
@@ -131,6 +135,12 @@ Here's a comprehensive example showcasing a range of options:
* A floating-point number to exponentially increase the ban duration upon each subsequent offense.
* Example: `ban_duration_multiplier 1.5`
* Defaults to `1.0` (no multiplier).
- **`error_count_timeout`** _(Optional)_:
* Time window for counting errors. If the time between errors exceeds this duration, the error count resets to 1 (sliding window behavior).
* Useful for preventing permanent error accumulation and avoiding bans from occasional errors spread over long periods.
* Example: `error_count_timeout 1h` (1 hour), `error_count_timeout 30m` (30 minutes)
* Set to `0` or omit to disable (errors never expire - original behavior).
* Can be overridden per-path.
- **`whitelist`** _(Optional)_:
* A space-separated list of IP addresses or CIDR ranges to exclude from being banned.
* Example: `whitelist 127.0.0.1 ::1 192.168.1.0/24`
@@ -161,8 +171,9 @@ Here's a comprehensive example showcasing a range of options:
- **`per_path <path>`** _(Optional)_:
* Configures per-path settings, taking precedence over global configurations.
* Reuses all the same options as global ones: `error_codes`, `max_error_count`, `ban_duration` and `ban_duration_multiplier`
* Reuses all the same options as global ones: `error_codes`, `max_error_count`, `ban_duration`, `ban_duration_multiplier`, and `error_count_timeout`
* Each path block must be defined as a Caddyfile block.
* If `error_count_timeout` is not specified in a per-path config, it inherits the global value.
---
@@ -176,11 +187,40 @@ Here's a comprehensive example showcasing a range of options:
4. The client attempts to access the `/login` endpoint, which is configured with specific error limits and ban duration that are different than the global ones.
5. The client is banned after triggering multiple 404, resulting in a separate ban and `429` response.
### Sliding Window Behavior
When `error_count_timeout` is configured, the middleware implements a sliding window for error tracking:
**Example Configuration:**
```caddyfile
caddy_mib {
error_codes 404
max_error_count 5
ban_duration 10m
error_count_timeout 1h # Reset after 1 hour of inactivity
}
```
**Behavior:**
- User hits 3 errors within 10 minutes → count = 3
- User waits 61 minutes (exceeds 1-hour timeout)
- User hits 1 more error → count resets to 1 (not banned)
- User hits 4 more errors quickly → count reaches 5, user is banned
**Without timeout (default):**
- All errors accumulate indefinitely
- After ban expires, hitting just 1 more error triggers immediate re-ban
**Use Cases:**
- **Set timeout**: Protect against concentrated attacks while forgiving occasional errors
- **No timeout**: Stricter enforcement, useful for zero-tolerance scenarios
### Best Practices
* **Start with a reasonable `max_error_count`**: This should be high enough to avoid banning legitimate users and bots while still protecting against malicious attacks.
* **Use a moderate `ban_duration`**: Start with a short ban duration and gradually increase it if needed.
* **Utilize `ban_duration_multiplier` wisely**: Be careful when using exponential multipliers, as they can quickly lead to very long ban times.
* **Configure `error_count_timeout` for most use cases**: A 1-hour timeout is a good starting point to prevent permanent error accumulation while still catching abuse patterns. Omit for zero-tolerance enforcement.
* **Whitelist trusted networks**: It's good practice to whitelist internal networks to prevent self-blocking.
* **Set proper log levels**: Setting `log_level` to `debug` can help during testing, while `info` or `warn` are more suitable for production use.
* **Use `cidr_bans`**: Use `cidr_bans` in combination with the `whitelist` for more precise configuration.
@@ -285,6 +325,29 @@ These log lines provide valuable information on when IPs are banned, which error
---
## Recent Updates
### v1.1.0 (Latest)
**New Features:**
- **Sliding Window Error Tracking**: Added `error_count_timeout` configuration option
- Prevents permanent error accumulation
- Resets error counts after a period of inactivity
- Configurable globally and per-path
- Backwards compatible (disabled by default)
**Bug Fixes:**
- Fixed error count cleanup when bans expire
- Previously: Only attempted to delete error counts using IP as key
- Now: Properly deletes all error counts across all paths for banned IPs
- Impact: Error counts are now correctly reset after ban expiration
**Testing:**
- Added 6 comprehensive tests covering new functionality
- All 16 tests passing with full coverage
---
## License
This project is licensed under the **AGPL-3.0 License**. Refer to the [LICENSE](LICENSE) file for full details.

View File

@@ -30,6 +30,7 @@ type Middleware struct {
MaxErrorCount int `json:"max_error_count,omitempty"` // Maximum allowed errors before banning
BanDuration caddy.Duration `json:"ban_duration,omitempty"` // Base duration for banning
BanDurationMultiplier float64 `json:"ban_duration_multiplier,omitempty"` // Multiplier for ban duration after each offense
ErrorCountTimeout caddy.Duration `json:"error_count_timeout,omitempty"` // Time window for counting errors (0 = disabled, errors never expire)
Whitelist []string `json:"whitelist,omitempty"` // List of IPs or CIDRs to whitelist
CustomResponseHeader string `json:"custom_response_header,omitempty"` // Custom header to add to responses
LogRequestHeaders []string `json:"log_request_headers,omitempty"` // Request headers to log
@@ -44,6 +45,7 @@ type Middleware struct {
logger *zap.Logger
errorCounts sync.Map // Tracks errors per IP and path
bannedIPs sync.Map // Tracks banned IPs and their expiration times
offenseCounts sync.Map // Tracks number of times each IP has been banned (for multiplier)
bannedCIDRs []*net.IPNet
whitelistedNets []*net.IPNet
}
@@ -54,6 +56,14 @@ type PathConfig struct {
MaxErrorCount int `json:"max_error_count,omitempty"` // Maximum allowed errors before banning for this path
BanDuration caddy.Duration `json:"ban_duration,omitempty"` // Base duration for banning for this path
BanDurationMultiplier float64 `json:"ban_duration_multiplier,omitempty"` // Multiplier for ban duration after each offense for this path
ErrorCountTimeout caddy.Duration `json:"error_count_timeout,omitempty"` // Time window for counting errors (0 = use global setting)
}
// errorTracker tracks error counts and timing for sliding window behavior.
type errorTracker struct {
Count int
FirstErrorAt time.Time
LastErrorAt time.Time
}
// CaddyModule returns the Caddy module information.
@@ -220,7 +230,15 @@ func (m *Middleware) ServeHTTP(w http.ResponseWriter, r *http.Request, next cadd
zap.String("client_ip", clientIP),
)
m.bannedIPs.Delete(clientIP)
m.errorCounts.Delete(clientIP)
// Delete all error counts for this IP across all paths
m.errorCounts.Range(func(countKey, countValue interface{}) bool {
countKeyStr := countKey.(string)
if strings.HasPrefix(countKeyStr, clientIP+":") {
m.errorCounts.Delete(countKey)
}
return true
})
}
// Skip middleware if no error codes are specified
@@ -303,31 +321,68 @@ func (m *Middleware) trackErrorStatus(clientIP string, code int, path string, r
// Use global configuration
for _, errCode := range m.ErrorCodes {
if code == errCode {
countBefore := 0
now := time.Now()
// Load or initialize error tracker
var tracker errorTracker
if val, ok := m.errorCounts.Load(key); ok {
countBefore = val.(int)
tracker = val.(errorTracker)
// Implement sliding window: reset count if timeout has passed
if m.ErrorCountTimeout > 0 && now.Sub(tracker.LastErrorAt) > time.Duration(m.ErrorCountTimeout) {
m.logger.Debug("error count timeout expired, resetting count",
zap.String("client_ip", clientIP),
zap.String("path", path),
zap.Duration("time_since_last_error", now.Sub(tracker.LastErrorAt)),
zap.Duration("timeout", time.Duration(m.ErrorCountTimeout)),
)
tracker = errorTracker{
Count: 1,
FirstErrorAt: now,
LastErrorAt: now,
}
} else {
tracker.Count++
tracker.LastErrorAt = now
}
} else {
// First error for this IP:path
tracker = errorTracker{
Count: 1,
FirstErrorAt: now,
LastErrorAt: now,
}
}
m.logger.Debug("tracking error", append(commonFields,
zap.Int("current_error_count", countBefore),
zap.Int("max_error_count", m.MaxErrorCount),
)...)
countNow := countBefore + 1
m.errorCounts.Store(key, countNow)
m.errorCounts.Store(key, tracker)
m.logger.Debug("error count incremented", append(commonFields,
zap.Int("new_error_count", countNow),
zap.Int("current_error_count", tracker.Count),
zap.Int("max_error_count", m.MaxErrorCount),
zap.Time("first_error_at", tracker.FirstErrorAt),
zap.Time("last_error_at", tracker.LastErrorAt),
)...)
if countNow >= m.MaxErrorCount {
offenses := countNow - m.MaxErrorCount + 1
banDuration := time.Duration(m.BanDuration) * time.Duration(math.Pow(m.BanDurationMultiplier, float64(offenses)))
if tracker.Count >= m.MaxErrorCount {
// Increment offense count for this IP (global path)
// Use clientIP as key for global offense tracking
offenseKey := clientIP
offenseCount := 1
if val, ok := m.offenseCounts.Load(offenseKey); ok {
offenseCount = val.(int) + 1
}
m.offenseCounts.Store(offenseKey, offenseCount)
// Calculate ban duration with multiplier based on offense count
banDuration := time.Duration(m.BanDuration) * time.Duration(math.Pow(m.BanDurationMultiplier, float64(offenseCount)))
if banDuration > 24*time.Hour { // Cap ban duration at 24 hours
banDuration = 24 * time.Hour
}
expiration := time.Now().Add(banDuration)
m.bannedIPs.Store(clientIP, expiration)
logFields := append(commonFields,
zap.Int("error_count", countNow),
zap.Int("error_count", tracker.Count),
zap.Int("max_error_count", m.MaxErrorCount),
zap.Int("offense_count", offenseCount),
zap.Duration("ban_duration", banDuration),
zap.Time("ban_expires", expiration),
)
@@ -361,31 +416,74 @@ func (m *Middleware) trackErrorsForPath(clientIP string, code int, path string,
for _, errCode := range config.ErrorCodes {
if code == errCode {
countBefore := 0
if val, ok := m.errorCounts.Load(key); ok {
countBefore = val.(int)
now := time.Now()
// Determine which timeout to use (per-path or global)
timeout := config.ErrorCountTimeout
if timeout == 0 {
timeout = m.ErrorCountTimeout
}
m.logger.Debug("tracking error for path", append(commonFields,
zap.Int("current_error_count", countBefore),
zap.Int("max_error_count", config.MaxErrorCount),
)...)
countNow := countBefore + 1
m.errorCounts.Store(key, countNow)
// Load or initialize error tracker
var tracker errorTracker
if val, ok := m.errorCounts.Load(key); ok {
tracker = val.(errorTracker)
// Implement sliding window: reset count if timeout has passed
if timeout > 0 && now.Sub(tracker.LastErrorAt) > time.Duration(timeout) {
m.logger.Debug("error count timeout expired for path, resetting count",
zap.String("client_ip", clientIP),
zap.String("path", path),
zap.Duration("time_since_last_error", now.Sub(tracker.LastErrorAt)),
zap.Duration("timeout", time.Duration(timeout)),
)
tracker = errorTracker{
Count: 1,
FirstErrorAt: now,
LastErrorAt: now,
}
} else {
tracker.Count++
tracker.LastErrorAt = now
}
} else {
// First error for this IP:path
tracker = errorTracker{
Count: 1,
FirstErrorAt: now,
LastErrorAt: now,
}
}
m.errorCounts.Store(key, tracker)
m.logger.Debug("error count incremented for path", append(commonFields,
zap.Int("new_error_count", countNow),
zap.Int("current_error_count", tracker.Count),
zap.Int("max_error_count", config.MaxErrorCount),
zap.Time("first_error_at", tracker.FirstErrorAt),
zap.Time("last_error_at", tracker.LastErrorAt),
)...)
if countNow >= config.MaxErrorCount {
offenses := countNow - config.MaxErrorCount + 1
banDuration := time.Duration(config.BanDuration) * time.Duration(math.Pow(config.BanDurationMultiplier, float64(offenses)))
if tracker.Count >= config.MaxErrorCount {
// Increment offense count for this IP:path combination
// Use composite key for per-path offense tracking
offenseKey := fmt.Sprintf("%s:%s", clientIP, path)
offenseCount := 1
if val, ok := m.offenseCounts.Load(offenseKey); ok {
offenseCount = val.(int) + 1
}
m.offenseCounts.Store(offenseKey, offenseCount)
// Calculate ban duration with multiplier based on offense count
banDuration := time.Duration(config.BanDuration) * time.Duration(math.Pow(config.BanDurationMultiplier, float64(offenseCount)))
if banDuration > 24*time.Hour { // Cap ban duration at 24 hours
banDuration = 24 * time.Hour
}
expiration := time.Now().Add(banDuration)
m.bannedIPs.Store(clientIP, expiration)
logFields := append(commonFields,
zap.Int("error_count", countNow),
zap.Int("error_count", tracker.Count),
zap.Int("max_error_count", config.MaxErrorCount),
zap.Int("offense_count", offenseCount),
zap.Duration("ban_duration", banDuration),
zap.Time("ban_expires", expiration),
)
@@ -431,11 +529,26 @@ func (m *Middleware) cleanupExpiredBans() {
now := time.Now()
m.bannedIPs.Range(func(key, value interface{}) bool {
if now.After(value.(time.Time)) {
clientIP := key.(string)
m.logger.Info("cleaned up expired ban",
zap.String("client_ip", key.(string)),
zap.String("client_ip", clientIP),
)
m.bannedIPs.Delete(key)
m.errorCounts.Delete(key)
// Delete all error counts for this IP across all paths
// errorCounts keys are in format "IP:path"
m.errorCounts.Range(func(countKey, countValue interface{}) bool {
countKeyStr := countKey.(string)
// Check if this error count belongs to the banned IP
if strings.HasPrefix(countKeyStr, clientIP+":") {
m.errorCounts.Delete(countKey)
m.logger.Debug("cleaned up error count for unbanned IP",
zap.String("client_ip", clientIP),
zap.String("key", countKeyStr),
)
}
return true
})
}
return true
})
@@ -491,6 +604,16 @@ func (m *Middleware) UnmarshalCaddyfile(d *caddyfile.Dispenser) error {
}
m.BanDurationMultiplier = multiplier
case "error_count_timeout":
if !d.NextArg() {
return d.ArgErr()
}
dur, err := time.ParseDuration(d.Val())
if err != nil {
return d.Errf("invalid error_count_timeout: %v", err)
}
m.ErrorCountTimeout = caddy.Duration(dur)
case "whitelist":
var whitelist []string
for d.NextArg() {
@@ -595,6 +718,16 @@ func (m *Middleware) UnmarshalCaddyfile(d *caddyfile.Dispenser) error {
}
config.BanDurationMultiplier = multiplier
case "error_count_timeout":
if !d.NextArg() {
return d.ArgErr()
}
dur, err := time.ParseDuration(d.Val())
if err != nil {
return d.Errf("invalid error_count_timeout: %v", err)
}
config.ErrorCountTimeout = caddy.Duration(dur)
default:
return d.Errf("unrecognized option in per_path block: %s", d.Val())
}

View File

@@ -352,3 +352,350 @@ func TestMiddleware_ServeHTTP_LogRequestHeaders(t *testing.T) {
}
}
// TestMiddleware_ErrorCountTimeout tests the sliding window behavior
func TestMiddleware_ErrorCountTimeout(t *testing.T) {
m := Middleware{
ErrorCodes: []int{404},
MaxErrorCount: 3,
BanDuration: caddy.Duration(1 * time.Minute),
ErrorCountTimeout: caddy.Duration(2 * time.Second), // 2 second window
}
ctx := caddy.Context{}
m.Provision(ctx)
req := httptest.NewRequest("GET", "http://example.com/test", nil)
req.RemoteAddr = "192.168.1.1:12345"
next := caddyhttp.HandlerFunc(func(w http.ResponseWriter, r *http.Request) error {
w.WriteHeader(http.StatusNotFound)
return nil
})
// Make 2 errors within the timeout window
for i := 0; i < 2; i++ {
rec := httptest.NewRecorder()
err := m.ServeHTTP(rec, req, next)
if err != nil {
t.Fatalf("ServeHTTP failed: %v", err)
}
if rec.Code != http.StatusNotFound {
t.Errorf("Expected status code 404, got %d", rec.Code)
}
}
// Wait for timeout to expire
time.Sleep(3 * time.Second)
// Make another error - count should reset to 1
rec := httptest.NewRecorder()
err := m.ServeHTTP(rec, req, next)
if err != nil {
t.Fatalf("ServeHTTP failed: %v", err)
}
if rec.Code != http.StatusNotFound {
t.Errorf("Expected status code 404, got %d", rec.Code)
}
// Verify error count was reset by checking we can make 2 more errors without ban
for i := 0; i < 2; i++ {
rec := httptest.NewRecorder()
err := m.ServeHTTP(rec, req, next)
if err != nil {
t.Fatalf("ServeHTTP failed: %v", err)
}
if i < 1 && rec.Code != http.StatusNotFound {
t.Errorf("Expected status code 404, got %d", rec.Code)
}
}
// Now we should be banned (3rd error in this window)
rec = httptest.NewRecorder()
err = m.ServeHTTP(rec, req, next)
if err != nil {
t.Fatalf("ServeHTTP failed: %v", err)
}
if rec.Code != http.StatusForbidden {
t.Errorf("Expected status code 403 (banned), got %d", rec.Code)
}
}
// TestMiddleware_ErrorCountResetOnBanExpiry tests that error counts are cleared when ban expires
func TestMiddleware_ErrorCountResetOnBanExpiry(t *testing.T) {
m := Middleware{
ErrorCodes: []int{404},
MaxErrorCount: 2,
BanDuration: caddy.Duration(1 * time.Second), // Short ban for testing
}
ctx := caddy.Context{}
m.Provision(ctx)
req := httptest.NewRequest("GET", "http://example.com/path1", nil)
req.RemoteAddr = "192.168.1.1:12345"
next := caddyhttp.HandlerFunc(func(w http.ResponseWriter, r *http.Request) error {
w.WriteHeader(http.StatusNotFound)
return nil
})
// Trigger ban with 2 errors
for i := 0; i < 2; i++ {
rec := httptest.NewRecorder()
err := m.ServeHTTP(rec, req, next)
if err != nil {
t.Fatalf("ServeHTTP failed: %v", err)
}
}
// Verify banned
rec := httptest.NewRecorder()
err := m.ServeHTTP(rec, req, next)
if err != nil {
t.Fatalf("ServeHTTP failed: %v", err)
}
if rec.Code != http.StatusForbidden {
t.Errorf("Expected status code 403 (banned), got %d", rec.Code)
}
// Wait for ban to expire
time.Sleep(2 * time.Second)
// Make request - should unban and clear error counts
rec = httptest.NewRecorder()
err = m.ServeHTTP(rec, req, next)
if err != nil {
t.Fatalf("ServeHTTP failed: %v", err)
}
if rec.Code != http.StatusNotFound {
t.Errorf("Expected status code 404 (unbanned), got %d", rec.Code)
}
// Verify error count was reset - we should be able to make another error without immediate ban
rec = httptest.NewRecorder()
err = m.ServeHTTP(rec, req, next)
if err != nil {
t.Fatalf("ServeHTTP failed: %v", err)
}
if rec.Code != http.StatusNotFound {
t.Errorf("Expected status code 404 (not banned yet), got %d", rec.Code)
}
// Wait for cleanup goroutine to process the second ban expiry
time.Sleep(2 * time.Second)
// Verify the error counts for path1 were actually deleted
key := "192.168.1.1:/path1"
if _, ok := m.errorCounts.Load(key); ok {
t.Error("Expected error count to be deleted after ban expired, but it still exists")
}
}
// TestMiddleware_PerPathErrorCountTimeout tests per-path timeout configuration
func TestMiddleware_PerPathErrorCountTimeout(t *testing.T) {
m := Middleware{
ErrorCodes: []int{404},
MaxErrorCount: 3,
BanDuration: caddy.Duration(1 * time.Minute),
ErrorCountTimeout: caddy.Duration(5 * time.Second), // Global: 5 seconds
PerPathConfig: map[string]PathConfig{
"/api": {
ErrorCodes: []int{404},
MaxErrorCount: 2,
BanDuration: caddy.Duration(1 * time.Minute),
BanDurationMultiplier: 1, // Add multiplier
ErrorCountTimeout: caddy.Duration(1 * time.Second), // Override: 1 second for /api
},
},
}
ctx := caddy.Context{}
m.Provision(ctx)
// Test /api path with shorter timeout
reqAPI := httptest.NewRequest("GET", "http://example.com/api", nil)
reqAPI.RemoteAddr = "192.168.1.1:12345"
next := caddyhttp.HandlerFunc(func(w http.ResponseWriter, r *http.Request) error {
w.WriteHeader(http.StatusNotFound)
return nil
})
// Make 1 error on /api
rec := httptest.NewRecorder()
err := m.ServeHTTP(rec, reqAPI, next)
if err != nil {
t.Fatalf("ServeHTTP failed: %v", err)
}
// Wait for /api timeout to expire (1 second)
time.Sleep(2 * time.Second)
// Make another error - count should be reset to 1 (not banned)
rec = httptest.NewRecorder()
err = m.ServeHTTP(rec, reqAPI, next)
if err != nil {
t.Fatalf("ServeHTTP failed: %v", err)
}
if rec.Code != http.StatusNotFound {
t.Errorf("Expected status code 404, got %d", rec.Code)
}
// Make one more error - this should be the 2nd error in the new window, triggering ban
rec = httptest.NewRecorder()
err = m.ServeHTTP(rec, reqAPI, next)
if err != nil {
t.Fatalf("ServeHTTP failed: %v", err)
}
// This request should succeed (404) because it's the 2nd error which triggers the ban
if rec.Code != http.StatusNotFound {
t.Errorf("Expected status code 404 before ban, got %d", rec.Code)
}
// Now verify we are banned on the next request
rec = httptest.NewRecorder()
err = m.ServeHTTP(rec, reqAPI, next)
if err != nil {
t.Fatalf("ServeHTTP failed: %v", err)
}
if rec.Code != http.StatusForbidden {
t.Errorf("Expected status code 403 (banned), got %d", rec.Code)
}
}
// TestMiddleware_UnmarshalCaddyfile_WithErrorCountTimeout tests Caddyfile parsing with error_count_timeout
func TestMiddleware_UnmarshalCaddyfile_WithErrorCountTimeout(t *testing.T) {
m := Middleware{}
d := caddyfile.NewTestDispenser(`
caddy_mib {
error_codes 404 500
max_error_count 5
ban_duration 10m
error_count_timeout 1h
per_path /admin {
error_codes 401
max_error_count 3
ban_duration 30m
error_count_timeout 15m
}
}
`)
err := m.UnmarshalCaddyfile(d)
if err != nil {
t.Fatalf("UnmarshalCaddyfile failed: %v", err)
}
// Verify global error_count_timeout
if m.ErrorCountTimeout != caddy.Duration(1*time.Hour) {
t.Errorf("Expected error_count_timeout to be 1h, got %v", m.ErrorCountTimeout)
}
// Verify per-path error_count_timeout
pathConfig, ok := m.PerPathConfig["/admin"]
if !ok {
t.Fatal("Expected per_path config for /admin, but not found")
}
if pathConfig.ErrorCountTimeout != caddy.Duration(15*time.Minute) {
t.Errorf("Expected per_path error_count_timeout to be 15m, got %v", pathConfig.ErrorCountTimeout)
}
}
// TestMiddleware_NoErrorCountTimeout tests that without timeout, errors accumulate indefinitely
func TestMiddleware_NoErrorCountTimeout(t *testing.T) {
m := Middleware{
ErrorCodes: []int{404},
MaxErrorCount: 3,
BanDuration: caddy.Duration(1 * time.Minute),
ErrorCountTimeout: 0, // Disabled
}
ctx := caddy.Context{}
m.Provision(ctx)
req := httptest.NewRequest("GET", "http://example.com/test", nil)
req.RemoteAddr = "192.168.1.1:12345"
next := caddyhttp.HandlerFunc(func(w http.ResponseWriter, r *http.Request) error {
w.WriteHeader(http.StatusNotFound)
return nil
})
// Make 2 errors
for i := 0; i < 2; i++ {
rec := httptest.NewRecorder()
err := m.ServeHTTP(rec, req, next)
if err != nil {
t.Fatalf("ServeHTTP failed: %v", err)
}
}
// Wait long time (would expire if timeout was set)
time.Sleep(3 * time.Second)
// Make one more error - should trigger ban (count was not reset)
rec := httptest.NewRecorder()
err := m.ServeHTTP(rec, req, next)
if err != nil {
t.Fatalf("ServeHTTP failed: %v", err)
}
// Verify banned (3rd error total)
rec = httptest.NewRecorder()
err = m.ServeHTTP(rec, req, next)
if err != nil {
t.Fatalf("ServeHTTP failed: %v", err)
}
if rec.Code != http.StatusForbidden {
t.Errorf("Expected status code 403 (banned), got %d", rec.Code)
}
}
// TestMiddleware_CleanupDeletesAllPathsForIP tests the bug fix for cleanup
func TestMiddleware_CleanupDeletesAllPathsForIP(t *testing.T) {
m := Middleware{
ErrorCodes: []int{404},
MaxErrorCount: 5, // High enough to allow errors on all paths before ban
BanDuration: caddy.Duration(1 * time.Second),
}
ctx := caddy.Context{}
m.Provision(ctx)
// Make errors on multiple paths for the same IP
paths := []string{"/path1", "/path2", "/path3"}
for _, path := range paths {
req := httptest.NewRequest("GET", "http://example.com"+path, nil)
req.RemoteAddr = "192.168.1.1:12345"
rec := httptest.NewRecorder()
next := caddyhttp.HandlerFunc(func(w http.ResponseWriter, r *http.Request) error {
w.WriteHeader(http.StatusNotFound)
return nil
})
err := m.ServeHTTP(rec, req, next)
if err != nil {
t.Fatalf("ServeHTTP failed: %v", err)
}
}
// Manually trigger a ban to test cleanup
m.bannedIPs.Store("192.168.1.1", time.Now().Add(1*time.Second))
// Verify error counts exist for all paths
for _, path := range paths {
key := "192.168.1.1:" + path
if _, ok := m.errorCounts.Load(key); !ok {
t.Errorf("Expected error count for path %s to exist", path)
}
}
// Wait for ban to expire and cleanup to run
time.Sleep(3 * time.Second)
// Verify all error counts were deleted
for _, path := range paths {
key := "192.168.1.1:" + path
if _, ok := m.errorCounts.Load(key); ok {
t.Errorf("Expected error count for path %s to be deleted, but it still exists", path)
}
}
}