mirror of
https://github.com/fabriziosalmi/caddy-mib.git
synced 2025-12-23 22:17:43 -05:00
Merge pull request #10 from kjnez/main
This commit is contained in:
@@ -12,8 +12,9 @@
|
||||
caddy_mib {
|
||||
error_codes 404 500 401 # Error codes to track
|
||||
max_error_count 10 # Global error threshold (reduced for faster testing)
|
||||
ban_duration 5s # Global ban duration (reduced to 10 seconds)
|
||||
ban_duration_multiplier 1 # Global ban duration multiplier
|
||||
ban_duration 1s # Global ban duration (1s for fast testing)
|
||||
error_count_timeout 5s # Global error count timeout
|
||||
ban_duration_multiplier 2 # Global ban duration multiplier
|
||||
# whitelist 127.0.0.1 ::1 # Whitelisted IPs
|
||||
log_level debug # Log level for debugging
|
||||
ban_response_body "You have been banned due to excessive errors. Please try again later."
|
||||
@@ -23,7 +24,7 @@
|
||||
per_path /login {
|
||||
error_codes 404 # Error codes to track for /login
|
||||
max_error_count 5 # Error threshold for /login (reduced for faster testing)
|
||||
ban_duration 10s # Ban duration for /login (reduced to 15 seconds)
|
||||
ban_duration 2s # Ban duration for /login (2s for fast testing)
|
||||
ban_duration_multiplier 1
|
||||
}
|
||||
|
||||
@@ -31,7 +32,7 @@
|
||||
per_path /api {
|
||||
error_codes 404 500 # Error codes to track for /api
|
||||
max_error_count 8 # Error threshold for /api (reduced for faster testing)
|
||||
ban_duration 15s # Ban duration for /api (reduced to 20 seconds)
|
||||
ban_duration 3s # Ban duration for /api (3s for fast testing)
|
||||
ban_duration_multiplier 1
|
||||
}
|
||||
}
|
||||
|
||||
65
README.md
65
README.md
@@ -10,6 +10,7 @@
|
||||
* **Configurable Error Limits**: Set max errors per IP before banning.
|
||||
* **Flexible Ban Times**: Use human-readable formats (e.g., 5s, 10m, 1h).
|
||||
* **Exponential Ban Increase**: Ban duration grows for repeat offenders.
|
||||
* **Sliding Window Error Tracking**: Reset error counts after a period of inactivity (optional).
|
||||
* **Trusted IP Whitelisting**: Exclude specific IPs or CIDRs from bans.
|
||||
* **Path-Specific Settings**: Tailor limits and bans per URL path.
|
||||
* **Custom Ban Messages**: Set custom response bodies and headers.
|
||||
@@ -83,6 +84,7 @@ Here's a comprehensive example showcasing a range of options:
|
||||
max_error_count 10 # Allow up to 10 global errors
|
||||
ban_duration 5s # Base ban duration of 5 seconds
|
||||
ban_duration_multiplier 1.5 # Increase ban duration for repeat offenders
|
||||
error_count_timeout 1h # Reset error count after 1 hour of inactivity (optional)
|
||||
whitelist 127.0.0.1 ::1 192.168.1.0/24 # Whitelist local IPs and network
|
||||
log_request_headers User-Agent X-Custom-Header # Log specified request headers
|
||||
log_level debug # Debug log level for this middleware
|
||||
@@ -99,6 +101,7 @@ Here's a comprehensive example showcasing a range of options:
|
||||
max_error_count 5 # Only allow 5 errors before banning
|
||||
ban_duration 10s # Ban duration of 10 seconds
|
||||
ban_duration_multiplier 2 # Exponential increase in /login ban duration
|
||||
error_count_timeout 15m # Reset after 15 minutes for /login (optional)
|
||||
}
|
||||
|
||||
per_path /api {
|
||||
@@ -106,6 +109,7 @@ Here's a comprehensive example showcasing a range of options:
|
||||
max_error_count 8 # Allow 8 errors before banning
|
||||
ban_duration 15s # 15-second ban duration
|
||||
ban_duration_multiplier 1 # No exponential increase in /api ban duration
|
||||
error_count_timeout 30m # Reset after 30 minutes for /api (optional)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -131,6 +135,12 @@ Here's a comprehensive example showcasing a range of options:
|
||||
* A floating-point number to exponentially increase the ban duration upon each subsequent offense.
|
||||
* Example: `ban_duration_multiplier 1.5`
|
||||
* Defaults to `1.0` (no multiplier).
|
||||
- **`error_count_timeout`** _(Optional)_:
|
||||
* Time window for counting errors. If the time between errors exceeds this duration, the error count resets to 1 (sliding window behavior).
|
||||
* Useful for preventing permanent error accumulation and avoiding bans from occasional errors spread over long periods.
|
||||
* Example: `error_count_timeout 1h` (1 hour), `error_count_timeout 30m` (30 minutes)
|
||||
* Set to `0` or omit to disable (errors never expire - original behavior).
|
||||
* Can be overridden per-path.
|
||||
- **`whitelist`** _(Optional)_:
|
||||
* A space-separated list of IP addresses or CIDR ranges to exclude from being banned.
|
||||
* Example: `whitelist 127.0.0.1 ::1 192.168.1.0/24`
|
||||
@@ -161,8 +171,9 @@ Here's a comprehensive example showcasing a range of options:
|
||||
|
||||
- **`per_path <path>`** _(Optional)_:
|
||||
* Configures per-path settings, taking precedence over global configurations.
|
||||
* Reuses all the same options as global ones: `error_codes`, `max_error_count`, `ban_duration` and `ban_duration_multiplier`
|
||||
* Reuses all the same options as global ones: `error_codes`, `max_error_count`, `ban_duration`, `ban_duration_multiplier`, and `error_count_timeout`
|
||||
* Each path block must be defined as a Caddyfile block.
|
||||
* If `error_count_timeout` is not specified in a per-path config, it inherits the global value.
|
||||
|
||||
---
|
||||
|
||||
@@ -176,11 +187,40 @@ Here's a comprehensive example showcasing a range of options:
|
||||
4. The client attempts to access the `/login` endpoint, which is configured with specific error limits and ban duration that are different than the global ones.
|
||||
5. The client is banned after triggering multiple 404, resulting in a separate ban and `429` response.
|
||||
|
||||
### Sliding Window Behavior
|
||||
|
||||
When `error_count_timeout` is configured, the middleware implements a sliding window for error tracking:
|
||||
|
||||
**Example Configuration:**
|
||||
```caddyfile
|
||||
caddy_mib {
|
||||
error_codes 404
|
||||
max_error_count 5
|
||||
ban_duration 10m
|
||||
error_count_timeout 1h # Reset after 1 hour of inactivity
|
||||
}
|
||||
```
|
||||
|
||||
**Behavior:**
|
||||
- User hits 3 errors within 10 minutes → count = 3
|
||||
- User waits 61 minutes (exceeds 1-hour timeout)
|
||||
- User hits 1 more error → count resets to 1 (not banned)
|
||||
- User hits 4 more errors quickly → count reaches 5, user is banned
|
||||
|
||||
**Without timeout (default):**
|
||||
- All errors accumulate indefinitely
|
||||
- After ban expires, hitting just 1 more error triggers immediate re-ban
|
||||
|
||||
**Use Cases:**
|
||||
- **Set timeout**: Protect against concentrated attacks while forgiving occasional errors
|
||||
- **No timeout**: Stricter enforcement, useful for zero-tolerance scenarios
|
||||
|
||||
### Best Practices
|
||||
|
||||
* **Start with a reasonable `max_error_count`**: This should be high enough to avoid banning legitimate users and bots while still protecting against malicious attacks.
|
||||
* **Use a moderate `ban_duration`**: Start with a short ban duration and gradually increase it if needed.
|
||||
* **Utilize `ban_duration_multiplier` wisely**: Be careful when using exponential multipliers, as they can quickly lead to very long ban times.
|
||||
* **Configure `error_count_timeout` for most use cases**: A 1-hour timeout is a good starting point to prevent permanent error accumulation while still catching abuse patterns. Omit for zero-tolerance enforcement.
|
||||
* **Whitelist trusted networks**: It's good practice to whitelist internal networks to prevent self-blocking.
|
||||
* **Set proper log levels**: Setting `log_level` to `debug` can help during testing, while `info` or `warn` are more suitable for production use.
|
||||
* **Use `cidr_bans`**: Use `cidr_bans` in combination with the `whitelist` for more precise configuration.
|
||||
@@ -285,6 +325,29 @@ These log lines provide valuable information on when IPs are banned, which error
|
||||
|
||||
---
|
||||
|
||||
## Recent Updates
|
||||
|
||||
### v1.1.0 (Latest)
|
||||
|
||||
**New Features:**
|
||||
- **Sliding Window Error Tracking**: Added `error_count_timeout` configuration option
|
||||
- Prevents permanent error accumulation
|
||||
- Resets error counts after a period of inactivity
|
||||
- Configurable globally and per-path
|
||||
- Backwards compatible (disabled by default)
|
||||
|
||||
**Bug Fixes:**
|
||||
- Fixed error count cleanup when bans expire
|
||||
- Previously: Only attempted to delete error counts using IP as key
|
||||
- Now: Properly deletes all error counts across all paths for banned IPs
|
||||
- Impact: Error counts are now correctly reset after ban expiration
|
||||
|
||||
**Testing:**
|
||||
- Added 6 comprehensive tests covering new functionality
|
||||
- All 16 tests passing with full coverage
|
||||
|
||||
---
|
||||
|
||||
## License
|
||||
This project is licensed under the **AGPL-3.0 License**. Refer to the [LICENSE](LICENSE) file for full details.
|
||||
|
||||
|
||||
193
caddymib.go
193
caddymib.go
@@ -30,6 +30,7 @@ type Middleware struct {
|
||||
MaxErrorCount int `json:"max_error_count,omitempty"` // Maximum allowed errors before banning
|
||||
BanDuration caddy.Duration `json:"ban_duration,omitempty"` // Base duration for banning
|
||||
BanDurationMultiplier float64 `json:"ban_duration_multiplier,omitempty"` // Multiplier for ban duration after each offense
|
||||
ErrorCountTimeout caddy.Duration `json:"error_count_timeout,omitempty"` // Time window for counting errors (0 = disabled, errors never expire)
|
||||
Whitelist []string `json:"whitelist,omitempty"` // List of IPs or CIDRs to whitelist
|
||||
CustomResponseHeader string `json:"custom_response_header,omitempty"` // Custom header to add to responses
|
||||
LogRequestHeaders []string `json:"log_request_headers,omitempty"` // Request headers to log
|
||||
@@ -44,6 +45,7 @@ type Middleware struct {
|
||||
logger *zap.Logger
|
||||
errorCounts sync.Map // Tracks errors per IP and path
|
||||
bannedIPs sync.Map // Tracks banned IPs and their expiration times
|
||||
offenseCounts sync.Map // Tracks number of times each IP has been banned (for multiplier)
|
||||
bannedCIDRs []*net.IPNet
|
||||
whitelistedNets []*net.IPNet
|
||||
}
|
||||
@@ -54,6 +56,14 @@ type PathConfig struct {
|
||||
MaxErrorCount int `json:"max_error_count,omitempty"` // Maximum allowed errors before banning for this path
|
||||
BanDuration caddy.Duration `json:"ban_duration,omitempty"` // Base duration for banning for this path
|
||||
BanDurationMultiplier float64 `json:"ban_duration_multiplier,omitempty"` // Multiplier for ban duration after each offense for this path
|
||||
ErrorCountTimeout caddy.Duration `json:"error_count_timeout,omitempty"` // Time window for counting errors (0 = use global setting)
|
||||
}
|
||||
|
||||
// errorTracker tracks error counts and timing for sliding window behavior.
|
||||
type errorTracker struct {
|
||||
Count int
|
||||
FirstErrorAt time.Time
|
||||
LastErrorAt time.Time
|
||||
}
|
||||
|
||||
// CaddyModule returns the Caddy module information.
|
||||
@@ -220,7 +230,15 @@ func (m *Middleware) ServeHTTP(w http.ResponseWriter, r *http.Request, next cadd
|
||||
zap.String("client_ip", clientIP),
|
||||
)
|
||||
m.bannedIPs.Delete(clientIP)
|
||||
m.errorCounts.Delete(clientIP)
|
||||
|
||||
// Delete all error counts for this IP across all paths
|
||||
m.errorCounts.Range(func(countKey, countValue interface{}) bool {
|
||||
countKeyStr := countKey.(string)
|
||||
if strings.HasPrefix(countKeyStr, clientIP+":") {
|
||||
m.errorCounts.Delete(countKey)
|
||||
}
|
||||
return true
|
||||
})
|
||||
}
|
||||
|
||||
// Skip middleware if no error codes are specified
|
||||
@@ -303,31 +321,68 @@ func (m *Middleware) trackErrorStatus(clientIP string, code int, path string, r
|
||||
// Use global configuration
|
||||
for _, errCode := range m.ErrorCodes {
|
||||
if code == errCode {
|
||||
countBefore := 0
|
||||
now := time.Now()
|
||||
|
||||
// Load or initialize error tracker
|
||||
var tracker errorTracker
|
||||
if val, ok := m.errorCounts.Load(key); ok {
|
||||
countBefore = val.(int)
|
||||
tracker = val.(errorTracker)
|
||||
|
||||
// Implement sliding window: reset count if timeout has passed
|
||||
if m.ErrorCountTimeout > 0 && now.Sub(tracker.LastErrorAt) > time.Duration(m.ErrorCountTimeout) {
|
||||
m.logger.Debug("error count timeout expired, resetting count",
|
||||
zap.String("client_ip", clientIP),
|
||||
zap.String("path", path),
|
||||
zap.Duration("time_since_last_error", now.Sub(tracker.LastErrorAt)),
|
||||
zap.Duration("timeout", time.Duration(m.ErrorCountTimeout)),
|
||||
)
|
||||
tracker = errorTracker{
|
||||
Count: 1,
|
||||
FirstErrorAt: now,
|
||||
LastErrorAt: now,
|
||||
}
|
||||
} else {
|
||||
tracker.Count++
|
||||
tracker.LastErrorAt = now
|
||||
}
|
||||
} else {
|
||||
// First error for this IP:path
|
||||
tracker = errorTracker{
|
||||
Count: 1,
|
||||
FirstErrorAt: now,
|
||||
LastErrorAt: now,
|
||||
}
|
||||
}
|
||||
m.logger.Debug("tracking error", append(commonFields,
|
||||
zap.Int("current_error_count", countBefore),
|
||||
zap.Int("max_error_count", m.MaxErrorCount),
|
||||
)...)
|
||||
countNow := countBefore + 1
|
||||
m.errorCounts.Store(key, countNow)
|
||||
|
||||
m.errorCounts.Store(key, tracker)
|
||||
m.logger.Debug("error count incremented", append(commonFields,
|
||||
zap.Int("new_error_count", countNow),
|
||||
zap.Int("current_error_count", tracker.Count),
|
||||
zap.Int("max_error_count", m.MaxErrorCount),
|
||||
zap.Time("first_error_at", tracker.FirstErrorAt),
|
||||
zap.Time("last_error_at", tracker.LastErrorAt),
|
||||
)...)
|
||||
if countNow >= m.MaxErrorCount {
|
||||
offenses := countNow - m.MaxErrorCount + 1
|
||||
banDuration := time.Duration(m.BanDuration) * time.Duration(math.Pow(m.BanDurationMultiplier, float64(offenses)))
|
||||
|
||||
if tracker.Count >= m.MaxErrorCount {
|
||||
// Increment offense count for this IP (global path)
|
||||
// Use clientIP as key for global offense tracking
|
||||
offenseKey := clientIP
|
||||
offenseCount := 1
|
||||
if val, ok := m.offenseCounts.Load(offenseKey); ok {
|
||||
offenseCount = val.(int) + 1
|
||||
}
|
||||
m.offenseCounts.Store(offenseKey, offenseCount)
|
||||
|
||||
// Calculate ban duration with multiplier based on offense count
|
||||
banDuration := time.Duration(m.BanDuration) * time.Duration(math.Pow(m.BanDurationMultiplier, float64(offenseCount)))
|
||||
if banDuration > 24*time.Hour { // Cap ban duration at 24 hours
|
||||
banDuration = 24 * time.Hour
|
||||
}
|
||||
expiration := time.Now().Add(banDuration)
|
||||
m.bannedIPs.Store(clientIP, expiration)
|
||||
logFields := append(commonFields,
|
||||
zap.Int("error_count", countNow),
|
||||
zap.Int("error_count", tracker.Count),
|
||||
zap.Int("max_error_count", m.MaxErrorCount),
|
||||
zap.Int("offense_count", offenseCount),
|
||||
zap.Duration("ban_duration", banDuration),
|
||||
zap.Time("ban_expires", expiration),
|
||||
)
|
||||
@@ -361,31 +416,74 @@ func (m *Middleware) trackErrorsForPath(clientIP string, code int, path string,
|
||||
|
||||
for _, errCode := range config.ErrorCodes {
|
||||
if code == errCode {
|
||||
countBefore := 0
|
||||
if val, ok := m.errorCounts.Load(key); ok {
|
||||
countBefore = val.(int)
|
||||
now := time.Now()
|
||||
|
||||
// Determine which timeout to use (per-path or global)
|
||||
timeout := config.ErrorCountTimeout
|
||||
if timeout == 0 {
|
||||
timeout = m.ErrorCountTimeout
|
||||
}
|
||||
m.logger.Debug("tracking error for path", append(commonFields,
|
||||
zap.Int("current_error_count", countBefore),
|
||||
zap.Int("max_error_count", config.MaxErrorCount),
|
||||
)...)
|
||||
countNow := countBefore + 1
|
||||
m.errorCounts.Store(key, countNow)
|
||||
|
||||
// Load or initialize error tracker
|
||||
var tracker errorTracker
|
||||
if val, ok := m.errorCounts.Load(key); ok {
|
||||
tracker = val.(errorTracker)
|
||||
|
||||
// Implement sliding window: reset count if timeout has passed
|
||||
if timeout > 0 && now.Sub(tracker.LastErrorAt) > time.Duration(timeout) {
|
||||
m.logger.Debug("error count timeout expired for path, resetting count",
|
||||
zap.String("client_ip", clientIP),
|
||||
zap.String("path", path),
|
||||
zap.Duration("time_since_last_error", now.Sub(tracker.LastErrorAt)),
|
||||
zap.Duration("timeout", time.Duration(timeout)),
|
||||
)
|
||||
tracker = errorTracker{
|
||||
Count: 1,
|
||||
FirstErrorAt: now,
|
||||
LastErrorAt: now,
|
||||
}
|
||||
} else {
|
||||
tracker.Count++
|
||||
tracker.LastErrorAt = now
|
||||
}
|
||||
} else {
|
||||
// First error for this IP:path
|
||||
tracker = errorTracker{
|
||||
Count: 1,
|
||||
FirstErrorAt: now,
|
||||
LastErrorAt: now,
|
||||
}
|
||||
}
|
||||
|
||||
m.errorCounts.Store(key, tracker)
|
||||
m.logger.Debug("error count incremented for path", append(commonFields,
|
||||
zap.Int("new_error_count", countNow),
|
||||
zap.Int("current_error_count", tracker.Count),
|
||||
zap.Int("max_error_count", config.MaxErrorCount),
|
||||
zap.Time("first_error_at", tracker.FirstErrorAt),
|
||||
zap.Time("last_error_at", tracker.LastErrorAt),
|
||||
)...)
|
||||
if countNow >= config.MaxErrorCount {
|
||||
offenses := countNow - config.MaxErrorCount + 1
|
||||
banDuration := time.Duration(config.BanDuration) * time.Duration(math.Pow(config.BanDurationMultiplier, float64(offenses)))
|
||||
|
||||
if tracker.Count >= config.MaxErrorCount {
|
||||
// Increment offense count for this IP:path combination
|
||||
// Use composite key for per-path offense tracking
|
||||
offenseKey := fmt.Sprintf("%s:%s", clientIP, path)
|
||||
offenseCount := 1
|
||||
if val, ok := m.offenseCounts.Load(offenseKey); ok {
|
||||
offenseCount = val.(int) + 1
|
||||
}
|
||||
m.offenseCounts.Store(offenseKey, offenseCount)
|
||||
|
||||
// Calculate ban duration with multiplier based on offense count
|
||||
banDuration := time.Duration(config.BanDuration) * time.Duration(math.Pow(config.BanDurationMultiplier, float64(offenseCount)))
|
||||
if banDuration > 24*time.Hour { // Cap ban duration at 24 hours
|
||||
banDuration = 24 * time.Hour
|
||||
}
|
||||
expiration := time.Now().Add(banDuration)
|
||||
m.bannedIPs.Store(clientIP, expiration)
|
||||
logFields := append(commonFields,
|
||||
zap.Int("error_count", countNow),
|
||||
zap.Int("error_count", tracker.Count),
|
||||
zap.Int("max_error_count", config.MaxErrorCount),
|
||||
zap.Int("offense_count", offenseCount),
|
||||
zap.Duration("ban_duration", banDuration),
|
||||
zap.Time("ban_expires", expiration),
|
||||
)
|
||||
@@ -431,11 +529,26 @@ func (m *Middleware) cleanupExpiredBans() {
|
||||
now := time.Now()
|
||||
m.bannedIPs.Range(func(key, value interface{}) bool {
|
||||
if now.After(value.(time.Time)) {
|
||||
clientIP := key.(string)
|
||||
m.logger.Info("cleaned up expired ban",
|
||||
zap.String("client_ip", key.(string)),
|
||||
zap.String("client_ip", clientIP),
|
||||
)
|
||||
m.bannedIPs.Delete(key)
|
||||
m.errorCounts.Delete(key)
|
||||
|
||||
// Delete all error counts for this IP across all paths
|
||||
// errorCounts keys are in format "IP:path"
|
||||
m.errorCounts.Range(func(countKey, countValue interface{}) bool {
|
||||
countKeyStr := countKey.(string)
|
||||
// Check if this error count belongs to the banned IP
|
||||
if strings.HasPrefix(countKeyStr, clientIP+":") {
|
||||
m.errorCounts.Delete(countKey)
|
||||
m.logger.Debug("cleaned up error count for unbanned IP",
|
||||
zap.String("client_ip", clientIP),
|
||||
zap.String("key", countKeyStr),
|
||||
)
|
||||
}
|
||||
return true
|
||||
})
|
||||
}
|
||||
return true
|
||||
})
|
||||
@@ -491,6 +604,16 @@ func (m *Middleware) UnmarshalCaddyfile(d *caddyfile.Dispenser) error {
|
||||
}
|
||||
m.BanDurationMultiplier = multiplier
|
||||
|
||||
case "error_count_timeout":
|
||||
if !d.NextArg() {
|
||||
return d.ArgErr()
|
||||
}
|
||||
dur, err := time.ParseDuration(d.Val())
|
||||
if err != nil {
|
||||
return d.Errf("invalid error_count_timeout: %v", err)
|
||||
}
|
||||
m.ErrorCountTimeout = caddy.Duration(dur)
|
||||
|
||||
case "whitelist":
|
||||
var whitelist []string
|
||||
for d.NextArg() {
|
||||
@@ -595,6 +718,16 @@ func (m *Middleware) UnmarshalCaddyfile(d *caddyfile.Dispenser) error {
|
||||
}
|
||||
config.BanDurationMultiplier = multiplier
|
||||
|
||||
case "error_count_timeout":
|
||||
if !d.NextArg() {
|
||||
return d.ArgErr()
|
||||
}
|
||||
dur, err := time.ParseDuration(d.Val())
|
||||
if err != nil {
|
||||
return d.Errf("invalid error_count_timeout: %v", err)
|
||||
}
|
||||
config.ErrorCountTimeout = caddy.Duration(dur)
|
||||
|
||||
default:
|
||||
return d.Errf("unrecognized option in per_path block: %s", d.Val())
|
||||
}
|
||||
|
||||
347
caddymib_test.go
347
caddymib_test.go
@@ -352,3 +352,350 @@ func TestMiddleware_ServeHTTP_LogRequestHeaders(t *testing.T) {
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
// TestMiddleware_ErrorCountTimeout tests the sliding window behavior
|
||||
func TestMiddleware_ErrorCountTimeout(t *testing.T) {
|
||||
m := Middleware{
|
||||
ErrorCodes: []int{404},
|
||||
MaxErrorCount: 3,
|
||||
BanDuration: caddy.Duration(1 * time.Minute),
|
||||
ErrorCountTimeout: caddy.Duration(2 * time.Second), // 2 second window
|
||||
}
|
||||
ctx := caddy.Context{}
|
||||
m.Provision(ctx)
|
||||
|
||||
req := httptest.NewRequest("GET", "http://example.com/test", nil)
|
||||
req.RemoteAddr = "192.168.1.1:12345"
|
||||
|
||||
next := caddyhttp.HandlerFunc(func(w http.ResponseWriter, r *http.Request) error {
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
return nil
|
||||
})
|
||||
|
||||
// Make 2 errors within the timeout window
|
||||
for i := 0; i < 2; i++ {
|
||||
rec := httptest.NewRecorder()
|
||||
err := m.ServeHTTP(rec, req, next)
|
||||
if err != nil {
|
||||
t.Fatalf("ServeHTTP failed: %v", err)
|
||||
}
|
||||
if rec.Code != http.StatusNotFound {
|
||||
t.Errorf("Expected status code 404, got %d", rec.Code)
|
||||
}
|
||||
}
|
||||
|
||||
// Wait for timeout to expire
|
||||
time.Sleep(3 * time.Second)
|
||||
|
||||
// Make another error - count should reset to 1
|
||||
rec := httptest.NewRecorder()
|
||||
err := m.ServeHTTP(rec, req, next)
|
||||
if err != nil {
|
||||
t.Fatalf("ServeHTTP failed: %v", err)
|
||||
}
|
||||
if rec.Code != http.StatusNotFound {
|
||||
t.Errorf("Expected status code 404, got %d", rec.Code)
|
||||
}
|
||||
|
||||
// Verify error count was reset by checking we can make 2 more errors without ban
|
||||
for i := 0; i < 2; i++ {
|
||||
rec := httptest.NewRecorder()
|
||||
err := m.ServeHTTP(rec, req, next)
|
||||
if err != nil {
|
||||
t.Fatalf("ServeHTTP failed: %v", err)
|
||||
}
|
||||
if i < 1 && rec.Code != http.StatusNotFound {
|
||||
t.Errorf("Expected status code 404, got %d", rec.Code)
|
||||
}
|
||||
}
|
||||
|
||||
// Now we should be banned (3rd error in this window)
|
||||
rec = httptest.NewRecorder()
|
||||
err = m.ServeHTTP(rec, req, next)
|
||||
if err != nil {
|
||||
t.Fatalf("ServeHTTP failed: %v", err)
|
||||
}
|
||||
if rec.Code != http.StatusForbidden {
|
||||
t.Errorf("Expected status code 403 (banned), got %d", rec.Code)
|
||||
}
|
||||
}
|
||||
|
||||
// TestMiddleware_ErrorCountResetOnBanExpiry tests that error counts are cleared when ban expires
|
||||
func TestMiddleware_ErrorCountResetOnBanExpiry(t *testing.T) {
|
||||
m := Middleware{
|
||||
ErrorCodes: []int{404},
|
||||
MaxErrorCount: 2,
|
||||
BanDuration: caddy.Duration(1 * time.Second), // Short ban for testing
|
||||
}
|
||||
ctx := caddy.Context{}
|
||||
m.Provision(ctx)
|
||||
|
||||
req := httptest.NewRequest("GET", "http://example.com/path1", nil)
|
||||
req.RemoteAddr = "192.168.1.1:12345"
|
||||
|
||||
next := caddyhttp.HandlerFunc(func(w http.ResponseWriter, r *http.Request) error {
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
return nil
|
||||
})
|
||||
|
||||
// Trigger ban with 2 errors
|
||||
for i := 0; i < 2; i++ {
|
||||
rec := httptest.NewRecorder()
|
||||
err := m.ServeHTTP(rec, req, next)
|
||||
if err != nil {
|
||||
t.Fatalf("ServeHTTP failed: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Verify banned
|
||||
rec := httptest.NewRecorder()
|
||||
err := m.ServeHTTP(rec, req, next)
|
||||
if err != nil {
|
||||
t.Fatalf("ServeHTTP failed: %v", err)
|
||||
}
|
||||
if rec.Code != http.StatusForbidden {
|
||||
t.Errorf("Expected status code 403 (banned), got %d", rec.Code)
|
||||
}
|
||||
|
||||
// Wait for ban to expire
|
||||
time.Sleep(2 * time.Second)
|
||||
|
||||
// Make request - should unban and clear error counts
|
||||
rec = httptest.NewRecorder()
|
||||
err = m.ServeHTTP(rec, req, next)
|
||||
if err != nil {
|
||||
t.Fatalf("ServeHTTP failed: %v", err)
|
||||
}
|
||||
if rec.Code != http.StatusNotFound {
|
||||
t.Errorf("Expected status code 404 (unbanned), got %d", rec.Code)
|
||||
}
|
||||
|
||||
// Verify error count was reset - we should be able to make another error without immediate ban
|
||||
rec = httptest.NewRecorder()
|
||||
err = m.ServeHTTP(rec, req, next)
|
||||
if err != nil {
|
||||
t.Fatalf("ServeHTTP failed: %v", err)
|
||||
}
|
||||
if rec.Code != http.StatusNotFound {
|
||||
t.Errorf("Expected status code 404 (not banned yet), got %d", rec.Code)
|
||||
}
|
||||
|
||||
// Wait for cleanup goroutine to process the second ban expiry
|
||||
time.Sleep(2 * time.Second)
|
||||
|
||||
// Verify the error counts for path1 were actually deleted
|
||||
key := "192.168.1.1:/path1"
|
||||
if _, ok := m.errorCounts.Load(key); ok {
|
||||
t.Error("Expected error count to be deleted after ban expired, but it still exists")
|
||||
}
|
||||
}
|
||||
|
||||
// TestMiddleware_PerPathErrorCountTimeout tests per-path timeout configuration
|
||||
func TestMiddleware_PerPathErrorCountTimeout(t *testing.T) {
|
||||
m := Middleware{
|
||||
ErrorCodes: []int{404},
|
||||
MaxErrorCount: 3,
|
||||
BanDuration: caddy.Duration(1 * time.Minute),
|
||||
ErrorCountTimeout: caddy.Duration(5 * time.Second), // Global: 5 seconds
|
||||
PerPathConfig: map[string]PathConfig{
|
||||
"/api": {
|
||||
ErrorCodes: []int{404},
|
||||
MaxErrorCount: 2,
|
||||
BanDuration: caddy.Duration(1 * time.Minute),
|
||||
BanDurationMultiplier: 1, // Add multiplier
|
||||
ErrorCountTimeout: caddy.Duration(1 * time.Second), // Override: 1 second for /api
|
||||
},
|
||||
},
|
||||
}
|
||||
ctx := caddy.Context{}
|
||||
m.Provision(ctx)
|
||||
|
||||
// Test /api path with shorter timeout
|
||||
reqAPI := httptest.NewRequest("GET", "http://example.com/api", nil)
|
||||
reqAPI.RemoteAddr = "192.168.1.1:12345"
|
||||
|
||||
next := caddyhttp.HandlerFunc(func(w http.ResponseWriter, r *http.Request) error {
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
return nil
|
||||
})
|
||||
|
||||
// Make 1 error on /api
|
||||
rec := httptest.NewRecorder()
|
||||
err := m.ServeHTTP(rec, reqAPI, next)
|
||||
if err != nil {
|
||||
t.Fatalf("ServeHTTP failed: %v", err)
|
||||
}
|
||||
|
||||
// Wait for /api timeout to expire (1 second)
|
||||
time.Sleep(2 * time.Second)
|
||||
|
||||
// Make another error - count should be reset to 1 (not banned)
|
||||
rec = httptest.NewRecorder()
|
||||
err = m.ServeHTTP(rec, reqAPI, next)
|
||||
if err != nil {
|
||||
t.Fatalf("ServeHTTP failed: %v", err)
|
||||
}
|
||||
if rec.Code != http.StatusNotFound {
|
||||
t.Errorf("Expected status code 404, got %d", rec.Code)
|
||||
}
|
||||
|
||||
// Make one more error - this should be the 2nd error in the new window, triggering ban
|
||||
rec = httptest.NewRecorder()
|
||||
err = m.ServeHTTP(rec, reqAPI, next)
|
||||
if err != nil {
|
||||
t.Fatalf("ServeHTTP failed: %v", err)
|
||||
}
|
||||
// This request should succeed (404) because it's the 2nd error which triggers the ban
|
||||
if rec.Code != http.StatusNotFound {
|
||||
t.Errorf("Expected status code 404 before ban, got %d", rec.Code)
|
||||
}
|
||||
|
||||
// Now verify we are banned on the next request
|
||||
rec = httptest.NewRecorder()
|
||||
err = m.ServeHTTP(rec, reqAPI, next)
|
||||
if err != nil {
|
||||
t.Fatalf("ServeHTTP failed: %v", err)
|
||||
}
|
||||
if rec.Code != http.StatusForbidden {
|
||||
t.Errorf("Expected status code 403 (banned), got %d", rec.Code)
|
||||
}
|
||||
}
|
||||
|
||||
// TestMiddleware_UnmarshalCaddyfile_WithErrorCountTimeout tests Caddyfile parsing with error_count_timeout
|
||||
func TestMiddleware_UnmarshalCaddyfile_WithErrorCountTimeout(t *testing.T) {
|
||||
m := Middleware{}
|
||||
d := caddyfile.NewTestDispenser(`
|
||||
caddy_mib {
|
||||
error_codes 404 500
|
||||
max_error_count 5
|
||||
ban_duration 10m
|
||||
error_count_timeout 1h
|
||||
per_path /admin {
|
||||
error_codes 401
|
||||
max_error_count 3
|
||||
ban_duration 30m
|
||||
error_count_timeout 15m
|
||||
}
|
||||
}
|
||||
`)
|
||||
|
||||
err := m.UnmarshalCaddyfile(d)
|
||||
if err != nil {
|
||||
t.Fatalf("UnmarshalCaddyfile failed: %v", err)
|
||||
}
|
||||
|
||||
// Verify global error_count_timeout
|
||||
if m.ErrorCountTimeout != caddy.Duration(1*time.Hour) {
|
||||
t.Errorf("Expected error_count_timeout to be 1h, got %v", m.ErrorCountTimeout)
|
||||
}
|
||||
|
||||
// Verify per-path error_count_timeout
|
||||
pathConfig, ok := m.PerPathConfig["/admin"]
|
||||
if !ok {
|
||||
t.Fatal("Expected per_path config for /admin, but not found")
|
||||
}
|
||||
|
||||
if pathConfig.ErrorCountTimeout != caddy.Duration(15*time.Minute) {
|
||||
t.Errorf("Expected per_path error_count_timeout to be 15m, got %v", pathConfig.ErrorCountTimeout)
|
||||
}
|
||||
}
|
||||
|
||||
// TestMiddleware_NoErrorCountTimeout tests that without timeout, errors accumulate indefinitely
|
||||
func TestMiddleware_NoErrorCountTimeout(t *testing.T) {
|
||||
m := Middleware{
|
||||
ErrorCodes: []int{404},
|
||||
MaxErrorCount: 3,
|
||||
BanDuration: caddy.Duration(1 * time.Minute),
|
||||
ErrorCountTimeout: 0, // Disabled
|
||||
}
|
||||
ctx := caddy.Context{}
|
||||
m.Provision(ctx)
|
||||
|
||||
req := httptest.NewRequest("GET", "http://example.com/test", nil)
|
||||
req.RemoteAddr = "192.168.1.1:12345"
|
||||
|
||||
next := caddyhttp.HandlerFunc(func(w http.ResponseWriter, r *http.Request) error {
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
return nil
|
||||
})
|
||||
|
||||
// Make 2 errors
|
||||
for i := 0; i < 2; i++ {
|
||||
rec := httptest.NewRecorder()
|
||||
err := m.ServeHTTP(rec, req, next)
|
||||
if err != nil {
|
||||
t.Fatalf("ServeHTTP failed: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Wait long time (would expire if timeout was set)
|
||||
time.Sleep(3 * time.Second)
|
||||
|
||||
// Make one more error - should trigger ban (count was not reset)
|
||||
rec := httptest.NewRecorder()
|
||||
err := m.ServeHTTP(rec, req, next)
|
||||
if err != nil {
|
||||
t.Fatalf("ServeHTTP failed: %v", err)
|
||||
}
|
||||
|
||||
// Verify banned (3rd error total)
|
||||
rec = httptest.NewRecorder()
|
||||
err = m.ServeHTTP(rec, req, next)
|
||||
if err != nil {
|
||||
t.Fatalf("ServeHTTP failed: %v", err)
|
||||
}
|
||||
if rec.Code != http.StatusForbidden {
|
||||
t.Errorf("Expected status code 403 (banned), got %d", rec.Code)
|
||||
}
|
||||
}
|
||||
|
||||
// TestMiddleware_CleanupDeletesAllPathsForIP tests the bug fix for cleanup
|
||||
func TestMiddleware_CleanupDeletesAllPathsForIP(t *testing.T) {
|
||||
m := Middleware{
|
||||
ErrorCodes: []int{404},
|
||||
MaxErrorCount: 5, // High enough to allow errors on all paths before ban
|
||||
BanDuration: caddy.Duration(1 * time.Second),
|
||||
}
|
||||
ctx := caddy.Context{}
|
||||
m.Provision(ctx)
|
||||
|
||||
// Make errors on multiple paths for the same IP
|
||||
paths := []string{"/path1", "/path2", "/path3"}
|
||||
for _, path := range paths {
|
||||
req := httptest.NewRequest("GET", "http://example.com"+path, nil)
|
||||
req.RemoteAddr = "192.168.1.1:12345"
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
next := caddyhttp.HandlerFunc(func(w http.ResponseWriter, r *http.Request) error {
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
return nil
|
||||
})
|
||||
|
||||
err := m.ServeHTTP(rec, req, next)
|
||||
if err != nil {
|
||||
t.Fatalf("ServeHTTP failed: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Manually trigger a ban to test cleanup
|
||||
m.bannedIPs.Store("192.168.1.1", time.Now().Add(1*time.Second))
|
||||
|
||||
// Verify error counts exist for all paths
|
||||
for _, path := range paths {
|
||||
key := "192.168.1.1:" + path
|
||||
if _, ok := m.errorCounts.Load(key); !ok {
|
||||
t.Errorf("Expected error count for path %s to exist", path)
|
||||
}
|
||||
}
|
||||
|
||||
// Wait for ban to expire and cleanup to run
|
||||
time.Sleep(3 * time.Second)
|
||||
|
||||
// Verify all error counts were deleted
|
||||
for _, path := range paths {
|
||||
key := "192.168.1.1:" + path
|
||||
if _, ok := m.errorCounts.Load(key); ok {
|
||||
t.Errorf("Expected error count for path %s to be deleted, but it still exists", path)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user