From 7f77483e36ca073c4d0bebaba88ca8374fcd5421 Mon Sep 17 00:00:00 2001 From: fabriziosalmi Date: Mon, 20 Jan 2025 23:12:25 +0100 Subject: [PATCH] test script updated and minor improvements --- Caddyfile | 6 +-- caddymib.go | 63 ++++++++++++++----------- test.py | 129 ++++++++++++++++++++++++++++++++++++++++++++++++++-- 3 files changed, 164 insertions(+), 34 deletions(-) diff --git a/Caddyfile b/Caddyfile index 240008d..968dcb2 100644 --- a/Caddyfile +++ b/Caddyfile @@ -14,9 +14,9 @@ max_error_count 10 # Global error threshold (reduced for faster testing) ban_duration 5s # Global ban duration (reduced to 10 seconds) ban_duration_multiplier 1 # Global ban duration multiplier - whitelist 127.0.0.1 ::1 # Whitelisted IPs + # whitelist 127.0.0.1 ::1 # Whitelisted IPs log_level debug # Log level for debugging - ban_response_body "You have been banned due to excessive errors. Please try again later." + ban_response_body "You have been banned due to excessive errors. Please try again later." ban_status_code 429 # Custom status code for banned IPs # Per-path configuration for /login @@ -37,7 +37,7 @@ } # All other requests, respond with "Hello World" handle { - respond "Hello world!" 200 + respond "Hello world!" 404 } } } diff --git a/caddymib.go b/caddymib.go index 44ad8a9..319d634 100644 --- a/caddymib.go +++ b/caddymib.go @@ -26,20 +26,20 @@ func init() { // Middleware implements the Caddy MIB middleware. type Middleware struct { - ErrorCodes []int `json:"error_codes,omitempty"` - MaxErrorCount int `json:"max_error_count,omitempty"` - BanDuration caddy.Duration `json:"ban_duration,omitempty"` - BanDurationMultiplier float64 `json:"ban_duration_multiplier,omitempty"` - Whitelist []string `json:"whitelist,omitempty"` - CustomResponseHeader string `json:"custom_response_header,omitempty"` - LogRequestHeaders []string `json:"log_request_headers,omitempty"` - LogLevel string `json:"log_level,omitempty"` - CIDRBans []string `json:"cidr_bans,omitempty"` - BanResponseBody string `json:"ban_response_body,omitempty"` - BanStatusCode int `json:"ban_status_code,omitempty"` + ErrorCodes []int `json:"error_codes,omitempty"` // HTTP status codes to track as errors + MaxErrorCount int `json:"max_error_count,omitempty"` // Maximum allowed errors before banning + BanDuration caddy.Duration `json:"ban_duration,omitempty"` // Base duration for banning + BanDurationMultiplier float64 `json:"ban_duration_multiplier,omitempty"` // Multiplier for ban duration after each offense + Whitelist []string `json:"whitelist,omitempty"` // List of IPs or CIDRs to whitelist + CustomResponseHeader string `json:"custom_response_header,omitempty"` // Custom header to add to responses + LogRequestHeaders []string `json:"log_request_headers,omitempty"` // Request headers to log + LogLevel string `json:"log_level,omitempty"` // Log level for the middleware + CIDRBans []string `json:"cidr_bans,omitempty"` // List of CIDRs to ban + BanResponseBody string `json:"ban_response_body,omitempty"` // Response body for banned requests + BanStatusCode int `json:"ban_status_code,omitempty"` // HTTP status code for banned requests // Per-path configuration - PerPathConfig map[string]PathConfig `json:"per_path,omitempty"` + PerPathConfig map[string]PathConfig `json:"per_path,omitempty"` // Configuration for specific paths logger *zap.Logger errorCounts sync.Map // Tracks errors per IP and path @@ -50,10 +50,10 @@ type Middleware struct { // PathConfig defines per-path configuration. type PathConfig struct { - ErrorCodes []int `json:"error_codes,omitempty"` - MaxErrorCount int `json:"max_error_count,omitempty"` - BanDuration caddy.Duration `json:"ban_duration,omitempty"` - BanDurationMultiplier float64 `json:"ban_duration_multiplier,omitempty"` + ErrorCodes []int `json:"error_codes,omitempty"` // HTTP status codes to track as errors for this path + MaxErrorCount int `json:"max_error_count,omitempty"` // Maximum allowed errors before banning for this path + BanDuration caddy.Duration `json:"ban_duration,omitempty"` // Base duration for banning for this path + BanDurationMultiplier float64 `json:"ban_duration_multiplier,omitempty"` // Multiplier for ban duration after each offense for this path } // CaddyModule returns the Caddy module information. @@ -164,7 +164,6 @@ func (m *Middleware) ServeHTTP(w http.ResponseWriter, r *http.Request, next cadd } } - // Rest of the ServeHTTP logic... clientIP, _, err := net.SplitHostPort(r.RemoteAddr) if err != nil { m.logger.Error("failed to parse client IP", @@ -278,6 +277,7 @@ func normalizeIP(ip string) string { return ip } +// trackErrorStatus tracks errors for a specific IP and path. func (m *Middleware) trackErrorStatus(clientIP string, code int, path string, r *http.Request) { commonFields := []zap.Field{ zap.String("client_ip", clientIP), @@ -287,6 +287,9 @@ func (m *Middleware) trackErrorStatus(clientIP string, code int, path string, r zap.String("user_agent", r.Header.Get("User-Agent")), } + // Use a composite key for error counts + key := fmt.Sprintf("%s:%s", clientIP, path) + // Check if the path has a specific configuration pathConfig, hasPathConfig := m.PerPathConfig[path] if hasPathConfig { @@ -301,15 +304,15 @@ func (m *Middleware) trackErrorStatus(clientIP string, code int, path string, r for _, errCode := range m.ErrorCodes { if code == errCode { countBefore := 0 - if val, ok := m.errorCounts.Load(clientIP); ok { - countBefore = val.(map[string]int)["global"] + if val, ok := m.errorCounts.Load(key); ok { + countBefore = val.(int) } m.logger.Debug("tracking error", append(commonFields, zap.Int("current_error_count", countBefore), zap.Int("max_error_count", m.MaxErrorCount), )...) countNow := countBefore + 1 - m.errorCounts.Store(clientIP, map[string]int{"global": countNow}) + m.errorCounts.Store(key, countNow) m.logger.Debug("error count incremented", append(commonFields, zap.Int("new_error_count", countNow), zap.Int("max_error_count", m.MaxErrorCount), @@ -317,6 +320,9 @@ func (m *Middleware) trackErrorStatus(clientIP string, code int, path string, r if countNow >= m.MaxErrorCount { offenses := countNow - m.MaxErrorCount + 1 banDuration := time.Duration(m.BanDuration) * time.Duration(math.Pow(m.BanDurationMultiplier, float64(offenses))) + if banDuration > 24*time.Hour { // Cap ban duration at 24 hours + banDuration = 24 * time.Hour + } expiration := time.Now().Add(banDuration) m.bannedIPs.Store(clientIP, expiration) logFields := append(commonFields, @@ -329,7 +335,7 @@ func (m *Middleware) trackErrorStatus(clientIP string, code int, path string, r // Add configured request headers to the log for _, headerName := range m.LogRequestHeaders { if value := r.Header.Get(headerName); value != "" { - logFields = append(logFields, zap.String(strings.ToLower(headerName), value)) + logFields = append(logFields, zap.String(headerName, value)) } } @@ -350,18 +356,21 @@ func (m *Middleware) trackErrorsForPath(clientIP string, code int, path string, zap.String("user_agent", r.Header.Get("User-Agent")), } + // Use a composite key for error counts + key := fmt.Sprintf("%s:%s", clientIP, path) + for _, errCode := range config.ErrorCodes { if code == errCode { countBefore := 0 - if val, ok := m.errorCounts.Load(clientIP); ok { - countBefore = val.(map[string]int)[path] + if val, ok := m.errorCounts.Load(key); ok { + countBefore = val.(int) } m.logger.Debug("tracking error for path", append(commonFields, zap.Int("current_error_count", countBefore), zap.Int("max_error_count", config.MaxErrorCount), )...) countNow := countBefore + 1 - m.errorCounts.Store(clientIP, map[string]int{path: countNow}) + m.errorCounts.Store(key, countNow) m.logger.Debug("error count incremented for path", append(commonFields, zap.Int("new_error_count", countNow), zap.Int("max_error_count", config.MaxErrorCount), @@ -369,6 +378,9 @@ func (m *Middleware) trackErrorsForPath(clientIP string, code int, path string, if countNow >= config.MaxErrorCount { offenses := countNow - config.MaxErrorCount + 1 banDuration := time.Duration(config.BanDuration) * time.Duration(math.Pow(config.BanDurationMultiplier, float64(offenses))) + if banDuration > 24*time.Hour { // Cap ban duration at 24 hours + banDuration = 24 * time.Hour + } expiration := time.Now().Add(banDuration) m.bannedIPs.Store(clientIP, expiration) logFields := append(commonFields, @@ -381,7 +393,7 @@ func (m *Middleware) trackErrorsForPath(clientIP string, code int, path string, // Add configured request headers to the log for _, headerName := range m.LogRequestHeaders { if value := r.Header.Get(headerName); value != "" { - logFields = append(logFields, zap.String(strings.ToLower(headerName), value)) + logFields = append(logFields, zap.String(headerName, value)) } } @@ -392,7 +404,6 @@ func (m *Middleware) trackErrorsForPath(clientIP string, code int, path string, } } -// extractStatusCodeFromError extracts the HTTP status code from the error message. // extractStatusCodeFromError extracts the HTTP status code from the error message. func extractStatusCodeFromError(err error) int { if err == nil { diff --git a/test.py b/test.py index a0757c3..e1a1b29 100644 --- a/test.py +++ b/test.py @@ -13,11 +13,11 @@ NONEXISTENT_URL_PATH = "/nonexistent" LOGIN_URL_PATH = "/login" API_URL_PATH = "/api" -# Global settings +# Global settings (defaults) GLOBAL_MAX_ERRORS = 10 # Matches max_error_count in Caddyfile GLOBAL_BAN_DURATION = 10 # Matches ban_duration in Caddyfile (10 seconds) -# Per-path settings +# Per-path settings (defaults) LOGIN_MAX_ERRORS = 5 # Matches max_error_count for /login LOGIN_BAN_DURATION = 15 # Matches ban_duration for /login (15 seconds) API_MAX_ERRORS = 8 # Matches max_error_count for /api @@ -308,6 +308,91 @@ def test_root_response_with_fab(): log(f"{Fore.RED}test_root_response_with_fab failed: Received unacceptable status code {colored_status} for root URL with 'fab' user-agent.{Style.RESET_ALL}") return False +def test_custom_response_header(): + """Test that the custom response header is present in the response.""" + log("Starting custom response header test...") + status_code, response_body, colored_status = send_request(f"{BASE_URL}{NONEXISTENT_URL_PATH}") + if "X-Custom-MIB-Info" in response_body: + log(f"{Fore.GREEN}Custom response header found in response.{Style.RESET_ALL}") + return True + else: + log(f"{Fore.RED}Custom response header not found in response.{Style.RESET_ALL}") + return False + +def test_whitelist(): + """Test that whitelisted IPs are not banned.""" + log("Starting whitelist test...") + # Simulate requests from a whitelisted IP + for i in range(GLOBAL_MAX_ERRORS + 2): + status_code, response_body, colored_status = send_request(f"{BASE_URL}{NONEXISTENT_URL_PATH}") + log(f"Request {i + 1}: Status Code = {colored_status}") + if status_code == 429: + log(f"{Fore.RED}Whitelist test failed: IP was banned despite being whitelisted.{Style.RESET_ALL}") + return False + log(f"{Fore.GREEN}Whitelist test passed: IP was not banned.{Style.RESET_ALL}") + return True + +def test_cidr_ban(): + """Test that IPs within a banned CIDR range are blocked.""" + log("Starting CIDR ban test...") + status_code, response_body, colored_status = send_request(f"{BASE_URL}{NONEXISTENT_URL_PATH}") + if status_code == 429: + log(f"{Fore.GREEN}CIDR ban test passed: IP within banned CIDR range was blocked.{Style.RESET_ALL}") + return True + else: + log(f"{Fore.RED}CIDR ban test failed: IP within banned CIDR range was not blocked.{Style.RESET_ALL}") + return False + +def test_ban_duration_multiplier(): + """Test that the ban duration increases exponentially based on the multiplier.""" + log("Starting ban duration multiplier test...") + # Trigger multiple bans and measure the duration + ban_durations = [] + for i in range(3): # Trigger 3 bans + for _ in range(GLOBAL_MAX_ERRORS + 1): + send_request(f"{BASE_URL}{NONEXISTENT_URL_PATH}") + ban_start_time = datetime.now() + while True: + status_code, _, _ = send_request(f"{BASE_URL}{NONEXISTENT_URL_PATH}") + if status_code != 429: + ban_end_time = datetime.now() + ban_durations.append((ban_end_time - ban_start_time).total_seconds()) + break + time.sleep(1) + # Verify that ban durations increase exponentially + if ban_durations[1] > ban_durations[0] and ban_durations[2] > ban_durations[1]: + log(f"{Fore.GREEN}Ban duration multiplier test passed: Ban durations increased exponentially.{Style.RESET_ALL}") + return True + else: + log(f"{Fore.RED}Ban duration multiplier test failed: Ban durations did not increase exponentially.{Style.RESET_ALL}") + return False + +def test_log_request_headers(): + """Test that specified request headers are logged.""" + log("Starting log request headers test...") + status_code, response_body, colored_status = send_request(f"{BASE_URL}{NONEXISTENT_URL_PATH}", user_agent="test-agent") + # Check logs for the presence of the "User-Agent" header + if "test-agent" in response_body: # Assuming the response body contains logged headers + log(f"{Fore.GREEN}Log request headers test passed: 'User-Agent' header was logged.{Style.RESET_ALL}") + return True + else: + log(f"{Fore.RED}Log request headers test failed: 'User-Agent' header was not logged.{Style.RESET_ALL}") + return False + +def test_custom_ban_response_body(): + """Test that the custom ban response body is returned.""" + log("Starting custom ban response body test...") + # Trigger a ban + for _ in range(GLOBAL_MAX_ERRORS + 1): + send_request(f"{BASE_URL}{NONEXISTENT_URL_PATH}") + status_code, response_body, colored_status = send_request(f"{BASE_URL}{NONEXISTENT_URL_PATH}") + if "custom ban response" in response_body: # Replace with the expected custom response + log(f"{Fore.GREEN}Custom ban response body test passed: Custom response body was returned.{Style.RESET_ALL}") + return True + else: + log(f"{Fore.RED}Custom ban response body test failed: Custom response body was not returned.{Style.RESET_ALL}") + return False + def print_summary(test_name, success): """Print a summary of the test result with colored output.""" if success: @@ -319,12 +404,28 @@ if __name__ == "__main__": parser = argparse.ArgumentParser(description="Test script for rate limiting and banning.") parser.add_argument("--base-url", dest="base_url", default=BASE_URL, help=f"Base URL of the service (default: {BASE_URL})") + parser.add_argument("--global-max-errors", type=int, default=GLOBAL_MAX_ERRORS, + help=f"Global max errors before banning (default: {GLOBAL_MAX_ERRORS})") + parser.add_argument("--global-ban-duration", type=int, default=GLOBAL_BAN_DURATION, + help=f"Global ban duration in seconds (default: {GLOBAL_BAN_DURATION})") + parser.add_argument("--login-max-errors", type=int, default=LOGIN_MAX_ERRORS, + help=f"Max errors for /login before banning (default: {LOGIN_MAX_ERRORS})") + parser.add_argument("--login-ban-duration", type=int, default=LOGIN_BAN_DURATION, + help=f"Ban duration for /login in seconds (default: {LOGIN_BAN_DURATION})") + parser.add_argument("--api-max-errors", type=int, default=API_MAX_ERRORS, + help=f"Max errors for /api before banning (default: {API_MAX_ERRORS})") + parser.add_argument("--api-ban-duration", type=int, default=API_BAN_DURATION, + help=f"Ban duration for /api in seconds (default: {API_BAN_DURATION})") args = parser.parse_args() + # Update configuration with command-line arguments BASE_URL = args.base_url - NONEXISTENT_URL = f"{BASE_URL}{NONEXISTENT_URL_PATH}" - LOGIN_URL = f"{BASE_URL}{LOGIN_URL_PATH}" - API_URL = f"{BASE_URL}{API_URL_PATH}" + GLOBAL_MAX_ERRORS = args.global_max_errors + GLOBAL_BAN_DURATION = args.global_ban_duration + LOGIN_MAX_ERRORS = args.login_max_errors + LOGIN_BAN_DURATION = args.login_ban_duration + API_MAX_ERRORS = args.api_max_errors + API_BAN_DURATION = args.api_ban_duration # Run all tests and collect results results = { @@ -333,6 +434,12 @@ if __name__ == "__main__": "API Ban Test": test_api_ban(), "Specific 404 Test": test_specific_404(), "Root Response with fab Test": test_root_response_with_fab(), + "Custom Response Header Test": test_custom_response_header(), + "Whitelist Test": test_whitelist(), + "CIDR Ban Test": test_cidr_ban(), + "Ban Duration Multiplier Test": test_ban_duration_multiplier(), + "Log Request Headers Test": test_log_request_headers(), + "Custom Ban Response Body Test": test_custom_ban_response_body(), } # Print summary @@ -378,3 +485,15 @@ if __name__ == "__main__": print(f"{Fore.YELLOW}- The server is not returning the expected 404 status for nonexistent URLs, which could indicate a routing issue.{Style.RESET_ALL}") if test_details.get("Root Response with fab Test") == "FAIL": print(f"{Fore.YELLOW}- The server is not returning an acceptable status code (2xx, 3xx, or 400, excluding 401, 402, 403, and above 404) for the root URL with the 'fab' user-agent.{Style.RESET_ALL}") + if test_details.get("Custom Response Header Test") == "FAIL": + print(f"{Fore.YELLOW}- The custom response header is not being added to the response.{Style.RESET_ALL}") + if test_details.get("Whitelist Test") == "FAIL": + print(f"{Fore.YELLOW}- Whitelisted IPs are being banned despite being whitelisted.{Style.RESET_ALL}") + if test_details.get("CIDR Ban Test") == "FAIL": + print(f"{Fore.YELLOW}- IPs within banned CIDR ranges are not being blocked.{Style.RESET_ALL}") + if test_details.get("Ban Duration Multiplier Test") == "FAIL": + print(f"{Fore.YELLOW}- The ban duration is not increasing exponentially based on the multiplier.{Style.RESET_ALL}") + if test_details.get("Log Request Headers Test") == "FAIL": + print(f"{Fore.YELLOW}- Specified request headers are not being logged.{Style.RESET_ALL}") + if test_details.get("Custom Ban Response Body Test") == "FAIL": + print(f"{Fore.YELLOW}- The custom ban response body is not being returned.{Style.RESET_ALL}") \ No newline at end of file