// Mirror of https://github.com/fabriziosalmi/caddy-mib.git
// (synced 2025-12-23 14:07:44 -05:00; 702 lines, 19 KiB, Go).

package caddymib
|
|
|
|
import (
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/caddyserver/caddy/v2"
|
|
"github.com/caddyserver/caddy/v2/caddyconfig/caddyfile"
|
|
"github.com/caddyserver/caddy/v2/modules/caddyhttp"
|
|
)
|
|
|
|
func TestMiddleware_Provision(t *testing.T) {
|
|
m := Middleware{}
|
|
ctx := caddy.Context{}
|
|
|
|
err := m.Provision(ctx)
|
|
if err != nil {
|
|
t.Fatalf("Provision failed: %v", err)
|
|
}
|
|
|
|
if m.MaxErrorCount != 5 {
|
|
t.Errorf("Expected MaxErrorCount to be 5, got %d", m.MaxErrorCount)
|
|
}
|
|
|
|
if m.BanDuration != caddy.Duration(10*time.Minute) {
|
|
t.Errorf("Expected BanDuration to be 10m, got %v", m.BanDuration)
|
|
}
|
|
|
|
if m.BanDurationMultiplier != 1 {
|
|
t.Errorf("Expected BanDurationMultiplier to be 1, got %v", m.BanDurationMultiplier)
|
|
}
|
|
|
|
if m.BanStatusCode != http.StatusForbidden {
|
|
t.Errorf("Expected BanStatusCode to be 403, got %d", m.BanStatusCode)
|
|
}
|
|
}
|
|
|
|
func TestMiddleware_ServeHTTP_Whitelist(t *testing.T) {
|
|
m := Middleware{
|
|
Whitelist: []string{"127.0.0.1"},
|
|
}
|
|
ctx := caddy.Context{}
|
|
m.Provision(ctx)
|
|
|
|
req := httptest.NewRequest("GET", "http://example.com", nil)
|
|
req.RemoteAddr = "127.0.0.1:12345"
|
|
|
|
rec := httptest.NewRecorder()
|
|
|
|
next := caddyhttp.HandlerFunc(func(w http.ResponseWriter, r *http.Request) error {
|
|
return nil
|
|
})
|
|
|
|
err := m.ServeHTTP(rec, req, next)
|
|
if err != nil {
|
|
t.Fatalf("ServeHTTP failed: %v", err)
|
|
}
|
|
|
|
if rec.Code != http.StatusOK {
|
|
t.Errorf("Expected status code 200, got %d", rec.Code)
|
|
}
|
|
}
|
|
|
|
func TestMiddleware_ServeHTTP_Ban(t *testing.T) {
|
|
m := Middleware{
|
|
ErrorCodes: []int{500},
|
|
MaxErrorCount: 1,
|
|
BanDuration: caddy.Duration(1 * time.Minute),
|
|
}
|
|
ctx := caddy.Context{}
|
|
m.Provision(ctx)
|
|
|
|
req := httptest.NewRequest("GET", "http://example.com", nil)
|
|
req.RemoteAddr = "192.168.1.1:12345"
|
|
|
|
rec := httptest.NewRecorder()
|
|
|
|
next := caddyhttp.HandlerFunc(func(w http.ResponseWriter, r *http.Request) error {
|
|
w.WriteHeader(http.StatusInternalServerError)
|
|
return nil
|
|
})
|
|
|
|
// First request should trigger a ban
|
|
err := m.ServeHTTP(rec, req, next)
|
|
if err != nil {
|
|
t.Fatalf("ServeHTTP failed: %v", err)
|
|
}
|
|
|
|
if rec.Code != http.StatusInternalServerError {
|
|
t.Errorf("Expected status code 500, got %d", rec.Code)
|
|
}
|
|
|
|
// Second request should be banned
|
|
rec = httptest.NewRecorder()
|
|
err = m.ServeHTTP(rec, req, next)
|
|
if err != nil {
|
|
t.Fatalf("ServeHTTP failed: %v", err)
|
|
}
|
|
|
|
if rec.Code != http.StatusForbidden {
|
|
t.Errorf("Expected status code 403, got %d", rec.Code)
|
|
}
|
|
}
|
|
|
|
func TestMiddleware_ServeHTTP_CIDRBans(t *testing.T) {
|
|
m := Middleware{
|
|
CIDRBans: []string{"192.168.1.0/24"},
|
|
}
|
|
ctx := caddy.Context{}
|
|
m.Provision(ctx)
|
|
|
|
req := httptest.NewRequest("GET", "http://example.com", nil)
|
|
req.RemoteAddr = "192.168.1.1:12345"
|
|
|
|
rec := httptest.NewRecorder()
|
|
|
|
next := caddyhttp.HandlerFunc(func(w http.ResponseWriter, r *http.Request) error {
|
|
return nil
|
|
})
|
|
|
|
err := m.ServeHTTP(rec, req, next)
|
|
if err != nil {
|
|
t.Fatalf("ServeHTTP failed: %v", err)
|
|
}
|
|
|
|
if rec.Code != http.StatusForbidden {
|
|
t.Errorf("Expected status code 403, got %d", rec.Code)
|
|
}
|
|
}
|
|
|
|
func TestMiddleware_ServeHTTP_CustomResponseHeader(t *testing.T) {
|
|
m := Middleware{
|
|
CustomResponseHeader: "TestHeaderValue",
|
|
}
|
|
ctx := caddy.Context{}
|
|
m.Provision(ctx)
|
|
|
|
req := httptest.NewRequest("GET", "http://example.com", nil)
|
|
req.RemoteAddr = "192.168.1.1:12345"
|
|
|
|
rec := httptest.NewRecorder()
|
|
|
|
next := caddyhttp.HandlerFunc(func(w http.ResponseWriter, r *http.Request) error {
|
|
return nil
|
|
})
|
|
|
|
err := m.ServeHTTP(rec, req, next)
|
|
if err != nil {
|
|
t.Fatalf("ServeHTTP failed: %v", err)
|
|
}
|
|
|
|
if rec.Header().Get("X-Custom-MIB-Info") != "TestHeaderValue" {
|
|
t.Errorf("Expected custom header 'X-Custom-MIB-Info' to be 'TestHeaderValue', got '%s'", rec.Header().Get("X-Custom-MIB-Info"))
|
|
}
|
|
}
|
|
|
|
func TestMiddleware_CleanupExpiredBans(t *testing.T) {
|
|
m := Middleware{}
|
|
ctx := caddy.Context{}
|
|
m.Provision(ctx)
|
|
|
|
m.bannedIPs.Store("192.168.1.1", time.Now().Add(-1*time.Minute))
|
|
|
|
go m.cleanupExpiredBans()
|
|
time.Sleep(2 * time.Second) // Wait for cleanup to run
|
|
|
|
if _, banned := m.bannedIPs.Load("192.168.1.1"); banned {
|
|
t.Error("Expected ban to be cleaned up, but it still exists")
|
|
}
|
|
}
|
|
|
|
func TestMiddleware_UnmarshalCaddyfile(t *testing.T) {
|
|
m := Middleware{}
|
|
d := caddyfile.NewTestDispenser(`
|
|
caddy_mib {
|
|
error_codes 500 404
|
|
max_error_count 3
|
|
ban_duration 5m
|
|
ban_duration_multiplier 2
|
|
whitelist 127.0.0.1
|
|
custom_response_header "TestHeader"
|
|
log_request_headers User-Agent
|
|
log_level debug
|
|
cidr_bans 192.168.1.0/24
|
|
ban_response_body "Banned"
|
|
ban_status_code 429
|
|
per_path /test {
|
|
error_codes 400
|
|
max_error_count 2
|
|
ban_duration 10m
|
|
ban_duration_multiplier 3
|
|
}
|
|
}
|
|
`)
|
|
|
|
err := m.UnmarshalCaddyfile(d)
|
|
if err != nil {
|
|
t.Fatalf("UnmarshalCaddyfile failed: %v", err)
|
|
}
|
|
|
|
if len(m.ErrorCodes) != 2 || m.ErrorCodes[0] != 500 || m.ErrorCodes[1] != 404 {
|
|
t.Errorf("Expected error_codes to be [500, 404], got %v", m.ErrorCodes)
|
|
}
|
|
|
|
if m.MaxErrorCount != 3 {
|
|
t.Errorf("Expected max_error_count to be 3, got %d", m.MaxErrorCount)
|
|
}
|
|
|
|
if m.BanDuration != caddy.Duration(5*time.Minute) {
|
|
t.Errorf("Expected ban_duration to be 5m, got %v", m.BanDuration)
|
|
}
|
|
|
|
if m.BanDurationMultiplier != 2 {
|
|
t.Errorf("Expected ban_duration_multiplier to be 2, got %v", m.BanDurationMultiplier)
|
|
}
|
|
|
|
if len(m.Whitelist) != 1 || m.Whitelist[0] != "127.0.0.1" {
|
|
t.Errorf("Expected whitelist to be [127.0.0.1], got %v", m.Whitelist)
|
|
}
|
|
|
|
if m.CustomResponseHeader != "TestHeader" {
|
|
t.Errorf("Expected custom_response_header to be 'TestHeader', got '%s'", m.CustomResponseHeader)
|
|
}
|
|
|
|
if len(m.LogRequestHeaders) != 1 || m.LogRequestHeaders[0] != "User-Agent" {
|
|
t.Errorf("Expected log_request_headers to be [User-Agent], got %v", m.LogRequestHeaders)
|
|
}
|
|
|
|
if m.LogLevel != "debug" {
|
|
t.Errorf("Expected log_level to be 'debug', got '%s'", m.LogLevel)
|
|
}
|
|
|
|
if len(m.CIDRBans) != 1 || m.CIDRBans[0] != "192.168.1.0/24" {
|
|
t.Errorf("Expected cidr_bans to be [192.168.1.0/24], got %v", m.CIDRBans)
|
|
}
|
|
|
|
if m.BanResponseBody != "Banned" {
|
|
t.Errorf("Expected ban_response_body to be 'Banned', got '%s'", m.BanResponseBody)
|
|
}
|
|
|
|
if m.BanStatusCode != 429 {
|
|
t.Errorf("Expected ban_status_code to be 429, got %d", m.BanStatusCode)
|
|
}
|
|
|
|
if len(m.PerPathConfig) != 1 {
|
|
t.Errorf("Expected per_path config to have 1 entry, got %d", len(m.PerPathConfig))
|
|
}
|
|
|
|
pathConfig, ok := m.PerPathConfig["/test"]
|
|
if !ok {
|
|
t.Fatal("Expected per_path config for /test, but not found")
|
|
}
|
|
|
|
if len(pathConfig.ErrorCodes) != 1 || pathConfig.ErrorCodes[0] != 400 {
|
|
t.Errorf("Expected per_path error_codes to be [400], got %v", pathConfig.ErrorCodes)
|
|
}
|
|
|
|
if pathConfig.MaxErrorCount != 2 {
|
|
t.Errorf("Expected per_path max_error_count to be 2, got %d", pathConfig.MaxErrorCount)
|
|
}
|
|
|
|
if pathConfig.BanDuration != caddy.Duration(10*time.Minute) {
|
|
t.Errorf("Expected per_path ban_duration to be 10m, got %v", pathConfig.BanDuration)
|
|
}
|
|
|
|
if pathConfig.BanDurationMultiplier != 3 {
|
|
t.Errorf("Expected per_path ban_duration_multiplier to be 3, got %v", pathConfig.BanDurationMultiplier)
|
|
}
|
|
}
|
|
|
|
func TestMiddleware_ServeHTTP_EmptyCustomHeader(t *testing.T) {
|
|
m := Middleware{
|
|
CustomResponseHeader: "",
|
|
}
|
|
ctx := caddy.Context{}
|
|
m.Provision(ctx)
|
|
|
|
req := httptest.NewRequest("GET", "http://example.com", nil)
|
|
req.RemoteAddr = "192.168.1.1:12345"
|
|
|
|
rec := httptest.NewRecorder()
|
|
|
|
next := caddyhttp.HandlerFunc(func(w http.ResponseWriter, r *http.Request) error {
|
|
return nil
|
|
})
|
|
|
|
err := m.ServeHTTP(rec, req, next)
|
|
if err != nil {
|
|
t.Fatalf("ServeHTTP failed: %v", err)
|
|
}
|
|
|
|
if rec.Header().Get("X-Custom-MIB-Info") != "" {
|
|
t.Errorf("Expected no custom header, got '%s'", rec.Header().Get("X-Custom-MIB-Info"))
|
|
}
|
|
}
|
|
|
|
//
|
|
|
|
func TestMiddleware_ServeHTTP_MultipleCustomHeaders(t *testing.T) {
|
|
m := Middleware{
|
|
CustomResponseHeader: "TestHeaderValue1,TestHeaderValue2",
|
|
}
|
|
ctx := caddy.Context{}
|
|
m.Provision(ctx)
|
|
|
|
req := httptest.NewRequest("GET", "http://example.com", nil)
|
|
req.RemoteAddr = "192.168.1.1:12345"
|
|
|
|
rec := httptest.NewRecorder()
|
|
|
|
next := caddyhttp.HandlerFunc(func(w http.ResponseWriter, r *http.Request) error {
|
|
return nil
|
|
})
|
|
|
|
err := m.ServeHTTP(rec, req, next)
|
|
if err != nil {
|
|
t.Fatalf("ServeHTTP failed: %v", err)
|
|
}
|
|
|
|
headers := rec.Header().Values("X-Custom-MIB-Info")
|
|
if len(headers) != 2 {
|
|
t.Errorf("Expected 2 custom headers, got %d", len(headers))
|
|
}
|
|
if headers[0] != "TestHeaderValue1" || headers[1] != "TestHeaderValue2" {
|
|
t.Errorf("Expected custom headers 'TestHeaderValue1' and 'TestHeaderValue2', got %v", headers)
|
|
}
|
|
}
|
|
|
|
func TestMiddleware_ServeHTTP_LogRequestHeaders(t *testing.T) {
|
|
m := Middleware{
|
|
LogRequestHeaders: []string{"User-Agent", "X-Forwarded-For"},
|
|
}
|
|
ctx := caddy.Context{}
|
|
m.Provision(ctx)
|
|
|
|
req := httptest.NewRequest("GET", "http://example.com", nil)
|
|
req.RemoteAddr = "192.168.1.1:12345"
|
|
req.Header.Set("User-Agent", "TestAgent")
|
|
req.Header.Set("X-Forwarded-For", "192.168.1.2")
|
|
|
|
rec := httptest.NewRecorder()
|
|
|
|
next := caddyhttp.HandlerFunc(func(w http.ResponseWriter, r *http.Request) error {
|
|
return nil
|
|
})
|
|
|
|
err := m.ServeHTTP(rec, req, next)
|
|
if err != nil {
|
|
t.Fatalf("ServeHTTP failed: %v", err)
|
|
}
|
|
|
|
}
|
|
|
|
// TestMiddleware_ErrorCountTimeout tests the sliding window behavior
|
|
func TestMiddleware_ErrorCountTimeout(t *testing.T) {
|
|
m := Middleware{
|
|
ErrorCodes: []int{404},
|
|
MaxErrorCount: 3,
|
|
BanDuration: caddy.Duration(1 * time.Minute),
|
|
ErrorCountTimeout: caddy.Duration(2 * time.Second), // 2 second window
|
|
}
|
|
ctx := caddy.Context{}
|
|
m.Provision(ctx)
|
|
|
|
req := httptest.NewRequest("GET", "http://example.com/test", nil)
|
|
req.RemoteAddr = "192.168.1.1:12345"
|
|
|
|
next := caddyhttp.HandlerFunc(func(w http.ResponseWriter, r *http.Request) error {
|
|
w.WriteHeader(http.StatusNotFound)
|
|
return nil
|
|
})
|
|
|
|
// Make 2 errors within the timeout window
|
|
for i := 0; i < 2; i++ {
|
|
rec := httptest.NewRecorder()
|
|
err := m.ServeHTTP(rec, req, next)
|
|
if err != nil {
|
|
t.Fatalf("ServeHTTP failed: %v", err)
|
|
}
|
|
if rec.Code != http.StatusNotFound {
|
|
t.Errorf("Expected status code 404, got %d", rec.Code)
|
|
}
|
|
}
|
|
|
|
// Wait for timeout to expire
|
|
time.Sleep(3 * time.Second)
|
|
|
|
// Make another error - count should reset to 1
|
|
rec := httptest.NewRecorder()
|
|
err := m.ServeHTTP(rec, req, next)
|
|
if err != nil {
|
|
t.Fatalf("ServeHTTP failed: %v", err)
|
|
}
|
|
if rec.Code != http.StatusNotFound {
|
|
t.Errorf("Expected status code 404, got %d", rec.Code)
|
|
}
|
|
|
|
// Verify error count was reset by checking we can make 2 more errors without ban
|
|
for i := 0; i < 2; i++ {
|
|
rec := httptest.NewRecorder()
|
|
err := m.ServeHTTP(rec, req, next)
|
|
if err != nil {
|
|
t.Fatalf("ServeHTTP failed: %v", err)
|
|
}
|
|
if i < 1 && rec.Code != http.StatusNotFound {
|
|
t.Errorf("Expected status code 404, got %d", rec.Code)
|
|
}
|
|
}
|
|
|
|
// Now we should be banned (3rd error in this window)
|
|
rec = httptest.NewRecorder()
|
|
err = m.ServeHTTP(rec, req, next)
|
|
if err != nil {
|
|
t.Fatalf("ServeHTTP failed: %v", err)
|
|
}
|
|
if rec.Code != http.StatusForbidden {
|
|
t.Errorf("Expected status code 403 (banned), got %d", rec.Code)
|
|
}
|
|
}
|
|
|
|
// TestMiddleware_ErrorCountResetOnBanExpiry tests that error counts are cleared when ban expires
|
|
func TestMiddleware_ErrorCountResetOnBanExpiry(t *testing.T) {
|
|
m := Middleware{
|
|
ErrorCodes: []int{404},
|
|
MaxErrorCount: 2,
|
|
BanDuration: caddy.Duration(1 * time.Second), // Short ban for testing
|
|
}
|
|
ctx := caddy.Context{}
|
|
m.Provision(ctx)
|
|
|
|
req := httptest.NewRequest("GET", "http://example.com/path1", nil)
|
|
req.RemoteAddr = "192.168.1.1:12345"
|
|
|
|
next := caddyhttp.HandlerFunc(func(w http.ResponseWriter, r *http.Request) error {
|
|
w.WriteHeader(http.StatusNotFound)
|
|
return nil
|
|
})
|
|
|
|
// Trigger ban with 2 errors
|
|
for i := 0; i < 2; i++ {
|
|
rec := httptest.NewRecorder()
|
|
err := m.ServeHTTP(rec, req, next)
|
|
if err != nil {
|
|
t.Fatalf("ServeHTTP failed: %v", err)
|
|
}
|
|
}
|
|
|
|
// Verify banned
|
|
rec := httptest.NewRecorder()
|
|
err := m.ServeHTTP(rec, req, next)
|
|
if err != nil {
|
|
t.Fatalf("ServeHTTP failed: %v", err)
|
|
}
|
|
if rec.Code != http.StatusForbidden {
|
|
t.Errorf("Expected status code 403 (banned), got %d", rec.Code)
|
|
}
|
|
|
|
// Wait for ban to expire
|
|
time.Sleep(2 * time.Second)
|
|
|
|
// Make request - should unban and clear error counts
|
|
rec = httptest.NewRecorder()
|
|
err = m.ServeHTTP(rec, req, next)
|
|
if err != nil {
|
|
t.Fatalf("ServeHTTP failed: %v", err)
|
|
}
|
|
if rec.Code != http.StatusNotFound {
|
|
t.Errorf("Expected status code 404 (unbanned), got %d", rec.Code)
|
|
}
|
|
|
|
// Verify error count was reset - we should be able to make another error without immediate ban
|
|
rec = httptest.NewRecorder()
|
|
err = m.ServeHTTP(rec, req, next)
|
|
if err != nil {
|
|
t.Fatalf("ServeHTTP failed: %v", err)
|
|
}
|
|
if rec.Code != http.StatusNotFound {
|
|
t.Errorf("Expected status code 404 (not banned yet), got %d", rec.Code)
|
|
}
|
|
|
|
// Wait for cleanup goroutine to process the second ban expiry
|
|
time.Sleep(2 * time.Second)
|
|
|
|
// Verify the error counts for path1 were actually deleted
|
|
key := "192.168.1.1:/path1"
|
|
if _, ok := m.errorCounts.Load(key); ok {
|
|
t.Error("Expected error count to be deleted after ban expired, but it still exists")
|
|
}
|
|
}
|
|
|
|
// TestMiddleware_PerPathErrorCountTimeout tests per-path timeout configuration
|
|
func TestMiddleware_PerPathErrorCountTimeout(t *testing.T) {
|
|
m := Middleware{
|
|
ErrorCodes: []int{404},
|
|
MaxErrorCount: 3,
|
|
BanDuration: caddy.Duration(1 * time.Minute),
|
|
ErrorCountTimeout: caddy.Duration(5 * time.Second), // Global: 5 seconds
|
|
PerPathConfig: map[string]PathConfig{
|
|
"/api": {
|
|
ErrorCodes: []int{404},
|
|
MaxErrorCount: 2,
|
|
BanDuration: caddy.Duration(1 * time.Minute),
|
|
BanDurationMultiplier: 1, // Add multiplier
|
|
ErrorCountTimeout: caddy.Duration(1 * time.Second), // Override: 1 second for /api
|
|
},
|
|
},
|
|
}
|
|
ctx := caddy.Context{}
|
|
m.Provision(ctx)
|
|
|
|
// Test /api path with shorter timeout
|
|
reqAPI := httptest.NewRequest("GET", "http://example.com/api", nil)
|
|
reqAPI.RemoteAddr = "192.168.1.1:12345"
|
|
|
|
next := caddyhttp.HandlerFunc(func(w http.ResponseWriter, r *http.Request) error {
|
|
w.WriteHeader(http.StatusNotFound)
|
|
return nil
|
|
})
|
|
|
|
// Make 1 error on /api
|
|
rec := httptest.NewRecorder()
|
|
err := m.ServeHTTP(rec, reqAPI, next)
|
|
if err != nil {
|
|
t.Fatalf("ServeHTTP failed: %v", err)
|
|
}
|
|
|
|
// Wait for /api timeout to expire (1 second)
|
|
time.Sleep(2 * time.Second)
|
|
|
|
// Make another error - count should be reset to 1 (not banned)
|
|
rec = httptest.NewRecorder()
|
|
err = m.ServeHTTP(rec, reqAPI, next)
|
|
if err != nil {
|
|
t.Fatalf("ServeHTTP failed: %v", err)
|
|
}
|
|
if rec.Code != http.StatusNotFound {
|
|
t.Errorf("Expected status code 404, got %d", rec.Code)
|
|
}
|
|
|
|
// Make one more error - this should be the 2nd error in the new window, triggering ban
|
|
rec = httptest.NewRecorder()
|
|
err = m.ServeHTTP(rec, reqAPI, next)
|
|
if err != nil {
|
|
t.Fatalf("ServeHTTP failed: %v", err)
|
|
}
|
|
// This request should succeed (404) because it's the 2nd error which triggers the ban
|
|
if rec.Code != http.StatusNotFound {
|
|
t.Errorf("Expected status code 404 before ban, got %d", rec.Code)
|
|
}
|
|
|
|
// Now verify we are banned on the next request
|
|
rec = httptest.NewRecorder()
|
|
err = m.ServeHTTP(rec, reqAPI, next)
|
|
if err != nil {
|
|
t.Fatalf("ServeHTTP failed: %v", err)
|
|
}
|
|
if rec.Code != http.StatusForbidden {
|
|
t.Errorf("Expected status code 403 (banned), got %d", rec.Code)
|
|
}
|
|
}
|
|
|
|
// TestMiddleware_UnmarshalCaddyfile_WithErrorCountTimeout tests Caddyfile parsing with error_count_timeout
|
|
func TestMiddleware_UnmarshalCaddyfile_WithErrorCountTimeout(t *testing.T) {
|
|
m := Middleware{}
|
|
d := caddyfile.NewTestDispenser(`
|
|
caddy_mib {
|
|
error_codes 404 500
|
|
max_error_count 5
|
|
ban_duration 10m
|
|
error_count_timeout 1h
|
|
per_path /admin {
|
|
error_codes 401
|
|
max_error_count 3
|
|
ban_duration 30m
|
|
error_count_timeout 15m
|
|
}
|
|
}
|
|
`)
|
|
|
|
err := m.UnmarshalCaddyfile(d)
|
|
if err != nil {
|
|
t.Fatalf("UnmarshalCaddyfile failed: %v", err)
|
|
}
|
|
|
|
// Verify global error_count_timeout
|
|
if m.ErrorCountTimeout != caddy.Duration(1*time.Hour) {
|
|
t.Errorf("Expected error_count_timeout to be 1h, got %v", m.ErrorCountTimeout)
|
|
}
|
|
|
|
// Verify per-path error_count_timeout
|
|
pathConfig, ok := m.PerPathConfig["/admin"]
|
|
if !ok {
|
|
t.Fatal("Expected per_path config for /admin, but not found")
|
|
}
|
|
|
|
if pathConfig.ErrorCountTimeout != caddy.Duration(15*time.Minute) {
|
|
t.Errorf("Expected per_path error_count_timeout to be 15m, got %v", pathConfig.ErrorCountTimeout)
|
|
}
|
|
}
|
|
|
|
// TestMiddleware_NoErrorCountTimeout tests that without timeout, errors accumulate indefinitely
|
|
func TestMiddleware_NoErrorCountTimeout(t *testing.T) {
|
|
m := Middleware{
|
|
ErrorCodes: []int{404},
|
|
MaxErrorCount: 3,
|
|
BanDuration: caddy.Duration(1 * time.Minute),
|
|
ErrorCountTimeout: 0, // Disabled
|
|
}
|
|
ctx := caddy.Context{}
|
|
m.Provision(ctx)
|
|
|
|
req := httptest.NewRequest("GET", "http://example.com/test", nil)
|
|
req.RemoteAddr = "192.168.1.1:12345"
|
|
|
|
next := caddyhttp.HandlerFunc(func(w http.ResponseWriter, r *http.Request) error {
|
|
w.WriteHeader(http.StatusNotFound)
|
|
return nil
|
|
})
|
|
|
|
// Make 2 errors
|
|
for i := 0; i < 2; i++ {
|
|
rec := httptest.NewRecorder()
|
|
err := m.ServeHTTP(rec, req, next)
|
|
if err != nil {
|
|
t.Fatalf("ServeHTTP failed: %v", err)
|
|
}
|
|
}
|
|
|
|
// Wait long time (would expire if timeout was set)
|
|
time.Sleep(3 * time.Second)
|
|
|
|
// Make one more error - should trigger ban (count was not reset)
|
|
rec := httptest.NewRecorder()
|
|
err := m.ServeHTTP(rec, req, next)
|
|
if err != nil {
|
|
t.Fatalf("ServeHTTP failed: %v", err)
|
|
}
|
|
|
|
// Verify banned (3rd error total)
|
|
rec = httptest.NewRecorder()
|
|
err = m.ServeHTTP(rec, req, next)
|
|
if err != nil {
|
|
t.Fatalf("ServeHTTP failed: %v", err)
|
|
}
|
|
if rec.Code != http.StatusForbidden {
|
|
t.Errorf("Expected status code 403 (banned), got %d", rec.Code)
|
|
}
|
|
}
|
|
|
|
// TestMiddleware_CleanupDeletesAllPathsForIP tests the bug fix for cleanup
|
|
func TestMiddleware_CleanupDeletesAllPathsForIP(t *testing.T) {
|
|
m := Middleware{
|
|
ErrorCodes: []int{404},
|
|
MaxErrorCount: 5, // High enough to allow errors on all paths before ban
|
|
BanDuration: caddy.Duration(1 * time.Second),
|
|
}
|
|
ctx := caddy.Context{}
|
|
m.Provision(ctx)
|
|
|
|
// Make errors on multiple paths for the same IP
|
|
paths := []string{"/path1", "/path2", "/path3"}
|
|
for _, path := range paths {
|
|
req := httptest.NewRequest("GET", "http://example.com"+path, nil)
|
|
req.RemoteAddr = "192.168.1.1:12345"
|
|
rec := httptest.NewRecorder()
|
|
|
|
next := caddyhttp.HandlerFunc(func(w http.ResponseWriter, r *http.Request) error {
|
|
w.WriteHeader(http.StatusNotFound)
|
|
return nil
|
|
})
|
|
|
|
err := m.ServeHTTP(rec, req, next)
|
|
if err != nil {
|
|
t.Fatalf("ServeHTTP failed: %v", err)
|
|
}
|
|
}
|
|
|
|
// Manually trigger a ban to test cleanup
|
|
m.bannedIPs.Store("192.168.1.1", time.Now().Add(1*time.Second))
|
|
|
|
// Verify error counts exist for all paths
|
|
for _, path := range paths {
|
|
key := "192.168.1.1:" + path
|
|
if _, ok := m.errorCounts.Load(key); !ok {
|
|
t.Errorf("Expected error count for path %s to exist", path)
|
|
}
|
|
}
|
|
|
|
// Wait for ban to expire and cleanup to run
|
|
time.Sleep(3 * time.Second)
|
|
|
|
// Verify all error counts were deleted
|
|
for _, path := range paths {
|
|
key := "192.168.1.1:" + path
|
|
if _, ok := m.errorCounts.Load(key); ok {
|
|
t.Errorf("Expected error count for path %s to be deleted, but it still exists", path)
|
|
}
|
|
}
|
|
}
|