test script updated and minor improvements

This commit is contained in:
fabriziosalmi
2025-01-20 23:12:25 +01:00
parent 2b70c1f9a1
commit 7f77483e36
3 changed files with 164 additions and 34 deletions

View File

@@ -14,9 +14,9 @@
max_error_count 10 # Global error threshold (reduced for faster testing)
ban_duration 5s # Global ban duration (reduced to 10 seconds)
ban_duration_multiplier 1 # Global ban duration multiplier
whitelist 127.0.0.1 ::1 # Whitelisted IPs
# whitelist 127.0.0.1 ::1 # Whitelisted IPs
log_level debug # Log level for debugging
ban_response_body "You have been banned due to excessive errors. Please try again later."
ban_response_body "You have been banned due to excessive errors. Please try again later."
ban_status_code 429 # Custom status code for banned IPs
# Per-path configuration for /login
@@ -37,7 +37,7 @@
}
# All other requests, respond with "Hello World"
handle {
respond "Hello world!" 200
respond "Hello world!" 404
}
}
}

View File

@@ -26,20 +26,20 @@ func init() {
// Middleware implements the Caddy MIB middleware.
type Middleware struct {
ErrorCodes []int `json:"error_codes,omitempty"`
MaxErrorCount int `json:"max_error_count,omitempty"`
BanDuration caddy.Duration `json:"ban_duration,omitempty"`
BanDurationMultiplier float64 `json:"ban_duration_multiplier,omitempty"`
Whitelist []string `json:"whitelist,omitempty"`
CustomResponseHeader string `json:"custom_response_header,omitempty"`
LogRequestHeaders []string `json:"log_request_headers,omitempty"`
LogLevel string `json:"log_level,omitempty"`
CIDRBans []string `json:"cidr_bans,omitempty"`
BanResponseBody string `json:"ban_response_body,omitempty"`
BanStatusCode int `json:"ban_status_code,omitempty"`
ErrorCodes []int `json:"error_codes,omitempty"` // HTTP status codes to track as errors
MaxErrorCount int `json:"max_error_count,omitempty"` // Maximum allowed errors before banning
BanDuration caddy.Duration `json:"ban_duration,omitempty"` // Base duration for banning
BanDurationMultiplier float64 `json:"ban_duration_multiplier,omitempty"` // Multiplier for ban duration after each offense
Whitelist []string `json:"whitelist,omitempty"` // List of IPs or CIDRs to whitelist
CustomResponseHeader string `json:"custom_response_header,omitempty"` // Custom header to add to responses
LogRequestHeaders []string `json:"log_request_headers,omitempty"` // Request headers to log
LogLevel string `json:"log_level,omitempty"` // Log level for the middleware
CIDRBans []string `json:"cidr_bans,omitempty"` // List of CIDRs to ban
BanResponseBody string `json:"ban_response_body,omitempty"` // Response body for banned requests
BanStatusCode int `json:"ban_status_code,omitempty"` // HTTP status code for banned requests
// Per-path configuration
PerPathConfig map[string]PathConfig `json:"per_path,omitempty"`
PerPathConfig map[string]PathConfig `json:"per_path,omitempty"` // Configuration for specific paths
logger *zap.Logger
errorCounts sync.Map // Tracks errors per IP and path
@@ -50,10 +50,10 @@ type Middleware struct {
// PathConfig defines per-path configuration.
type PathConfig struct {
ErrorCodes []int `json:"error_codes,omitempty"`
MaxErrorCount int `json:"max_error_count,omitempty"`
BanDuration caddy.Duration `json:"ban_duration,omitempty"`
BanDurationMultiplier float64 `json:"ban_duration_multiplier,omitempty"`
ErrorCodes []int `json:"error_codes,omitempty"` // HTTP status codes to track as errors for this path
MaxErrorCount int `json:"max_error_count,omitempty"` // Maximum allowed errors before banning for this path
BanDuration caddy.Duration `json:"ban_duration,omitempty"` // Base duration for banning for this path
BanDurationMultiplier float64 `json:"ban_duration_multiplier,omitempty"` // Multiplier for ban duration after each offense for this path
}
// CaddyModule returns the Caddy module information.
@@ -164,7 +164,6 @@ func (m *Middleware) ServeHTTP(w http.ResponseWriter, r *http.Request, next cadd
}
}
// Rest of the ServeHTTP logic...
clientIP, _, err := net.SplitHostPort(r.RemoteAddr)
if err != nil {
m.logger.Error("failed to parse client IP",
@@ -278,6 +277,7 @@ func normalizeIP(ip string) string {
return ip
}
// trackErrorStatus tracks errors for a specific IP and path.
func (m *Middleware) trackErrorStatus(clientIP string, code int, path string, r *http.Request) {
commonFields := []zap.Field{
zap.String("client_ip", clientIP),
@@ -287,6 +287,9 @@ func (m *Middleware) trackErrorStatus(clientIP string, code int, path string, r
zap.String("user_agent", r.Header.Get("User-Agent")),
}
// Use a composite key for error counts
key := fmt.Sprintf("%s:%s", clientIP, path)
// Check if the path has a specific configuration
pathConfig, hasPathConfig := m.PerPathConfig[path]
if hasPathConfig {
@@ -301,15 +304,15 @@ func (m *Middleware) trackErrorStatus(clientIP string, code int, path string, r
for _, errCode := range m.ErrorCodes {
if code == errCode {
countBefore := 0
if val, ok := m.errorCounts.Load(clientIP); ok {
countBefore = val.(map[string]int)["global"]
if val, ok := m.errorCounts.Load(key); ok {
countBefore = val.(int)
}
m.logger.Debug("tracking error", append(commonFields,
zap.Int("current_error_count", countBefore),
zap.Int("max_error_count", m.MaxErrorCount),
)...)
countNow := countBefore + 1
m.errorCounts.Store(clientIP, map[string]int{"global": countNow})
m.errorCounts.Store(key, countNow)
m.logger.Debug("error count incremented", append(commonFields,
zap.Int("new_error_count", countNow),
zap.Int("max_error_count", m.MaxErrorCount),
@@ -317,6 +320,9 @@ func (m *Middleware) trackErrorStatus(clientIP string, code int, path string, r
if countNow >= m.MaxErrorCount {
offenses := countNow - m.MaxErrorCount + 1
banDuration := time.Duration(m.BanDuration) * time.Duration(math.Pow(m.BanDurationMultiplier, float64(offenses)))
if banDuration > 24*time.Hour { // Cap ban duration at 24 hours
banDuration = 24 * time.Hour
}
expiration := time.Now().Add(banDuration)
m.bannedIPs.Store(clientIP, expiration)
logFields := append(commonFields,
@@ -329,7 +335,7 @@ func (m *Middleware) trackErrorStatus(clientIP string, code int, path string, r
// Add configured request headers to the log
for _, headerName := range m.LogRequestHeaders {
if value := r.Header.Get(headerName); value != "" {
logFields = append(logFields, zap.String(strings.ToLower(headerName), value))
logFields = append(logFields, zap.String(headerName, value))
}
}
@@ -350,18 +356,21 @@ func (m *Middleware) trackErrorsForPath(clientIP string, code int, path string,
zap.String("user_agent", r.Header.Get("User-Agent")),
}
// Use a composite key for error counts
key := fmt.Sprintf("%s:%s", clientIP, path)
for _, errCode := range config.ErrorCodes {
if code == errCode {
countBefore := 0
if val, ok := m.errorCounts.Load(clientIP); ok {
countBefore = val.(map[string]int)[path]
if val, ok := m.errorCounts.Load(key); ok {
countBefore = val.(int)
}
m.logger.Debug("tracking error for path", append(commonFields,
zap.Int("current_error_count", countBefore),
zap.Int("max_error_count", config.MaxErrorCount),
)...)
countNow := countBefore + 1
m.errorCounts.Store(clientIP, map[string]int{path: countNow})
m.errorCounts.Store(key, countNow)
m.logger.Debug("error count incremented for path", append(commonFields,
zap.Int("new_error_count", countNow),
zap.Int("max_error_count", config.MaxErrorCount),
@@ -369,6 +378,9 @@ func (m *Middleware) trackErrorsForPath(clientIP string, code int, path string,
if countNow >= config.MaxErrorCount {
offenses := countNow - config.MaxErrorCount + 1
banDuration := time.Duration(config.BanDuration) * time.Duration(math.Pow(config.BanDurationMultiplier, float64(offenses)))
if banDuration > 24*time.Hour { // Cap ban duration at 24 hours
banDuration = 24 * time.Hour
}
expiration := time.Now().Add(banDuration)
m.bannedIPs.Store(clientIP, expiration)
logFields := append(commonFields,
@@ -381,7 +393,7 @@ func (m *Middleware) trackErrorsForPath(clientIP string, code int, path string,
// Add configured request headers to the log
for _, headerName := range m.LogRequestHeaders {
if value := r.Header.Get(headerName); value != "" {
logFields = append(logFields, zap.String(strings.ToLower(headerName), value))
logFields = append(logFields, zap.String(headerName, value))
}
}
@@ -392,7 +404,6 @@ func (m *Middleware) trackErrorsForPath(clientIP string, code int, path string,
}
}
// extractStatusCodeFromError extracts the HTTP status code from the error message.
// extractStatusCodeFromError extracts the HTTP status code from the error message.
func extractStatusCodeFromError(err error) int {
if err == nil {

129
test.py
View File

@@ -13,11 +13,11 @@ NONEXISTENT_URL_PATH = "/nonexistent"
LOGIN_URL_PATH = "/login"
API_URL_PATH = "/api"
# Global settings
# Global settings (defaults)
GLOBAL_MAX_ERRORS = 10 # Matches max_error_count in Caddyfile
GLOBAL_BAN_DURATION = 10 # Matches ban_duration in Caddyfile (10 seconds)
# Per-path settings
# Per-path settings (defaults)
LOGIN_MAX_ERRORS = 5 # Matches max_error_count for /login
LOGIN_BAN_DURATION = 15 # Matches ban_duration for /login (15 seconds)
API_MAX_ERRORS = 8 # Matches max_error_count for /api
@@ -308,6 +308,91 @@ def test_root_response_with_fab():
log(f"{Fore.RED}test_root_response_with_fab failed: Received unacceptable status code {colored_status} for root URL with 'fab' user-agent.{Style.RESET_ALL}")
return False
def test_custom_response_header():
"""Test that the custom response header is present in the response."""
log("Starting custom response header test...")
status_code, response_body, colored_status = send_request(f"{BASE_URL}{NONEXISTENT_URL_PATH}")
if "X-Custom-MIB-Info" in response_body:
log(f"{Fore.GREEN}Custom response header found in response.{Style.RESET_ALL}")
return True
else:
log(f"{Fore.RED}Custom response header not found in response.{Style.RESET_ALL}")
return False
def test_whitelist():
"""Test that whitelisted IPs are not banned."""
log("Starting whitelist test...")
# Simulate requests from a whitelisted IP
for i in range(GLOBAL_MAX_ERRORS + 2):
status_code, response_body, colored_status = send_request(f"{BASE_URL}{NONEXISTENT_URL_PATH}")
log(f"Request {i + 1}: Status Code = {colored_status}")
if status_code == 429:
log(f"{Fore.RED}Whitelist test failed: IP was banned despite being whitelisted.{Style.RESET_ALL}")
return False
log(f"{Fore.GREEN}Whitelist test passed: IP was not banned.{Style.RESET_ALL}")
return True
def test_cidr_ban():
"""Test that IPs within a banned CIDR range are blocked."""
log("Starting CIDR ban test...")
status_code, response_body, colored_status = send_request(f"{BASE_URL}{NONEXISTENT_URL_PATH}")
if status_code == 429:
log(f"{Fore.GREEN}CIDR ban test passed: IP within banned CIDR range was blocked.{Style.RESET_ALL}")
return True
else:
log(f"{Fore.RED}CIDR ban test failed: IP within banned CIDR range was not blocked.{Style.RESET_ALL}")
return False
def test_ban_duration_multiplier():
"""Test that the ban duration increases exponentially based on the multiplier."""
log("Starting ban duration multiplier test...")
# Trigger multiple bans and measure the duration
ban_durations = []
for i in range(3): # Trigger 3 bans
for _ in range(GLOBAL_MAX_ERRORS + 1):
send_request(f"{BASE_URL}{NONEXISTENT_URL_PATH}")
ban_start_time = datetime.now()
while True:
status_code, _, _ = send_request(f"{BASE_URL}{NONEXISTENT_URL_PATH}")
if status_code != 429:
ban_end_time = datetime.now()
ban_durations.append((ban_end_time - ban_start_time).total_seconds())
break
time.sleep(1)
# Verify that ban durations increase exponentially
if ban_durations[1] > ban_durations[0] and ban_durations[2] > ban_durations[1]:
log(f"{Fore.GREEN}Ban duration multiplier test passed: Ban durations increased exponentially.{Style.RESET_ALL}")
return True
else:
log(f"{Fore.RED}Ban duration multiplier test failed: Ban durations did not increase exponentially.{Style.RESET_ALL}")
return False
def test_log_request_headers():
"""Test that specified request headers are logged."""
log("Starting log request headers test...")
status_code, response_body, colored_status = send_request(f"{BASE_URL}{NONEXISTENT_URL_PATH}", user_agent="test-agent")
# Check logs for the presence of the "User-Agent" header
if "test-agent" in response_body: # Assuming the response body contains logged headers
log(f"{Fore.GREEN}Log request headers test passed: 'User-Agent' header was logged.{Style.RESET_ALL}")
return True
else:
log(f"{Fore.RED}Log request headers test failed: 'User-Agent' header was not logged.{Style.RESET_ALL}")
return False
def test_custom_ban_response_body():
"""Test that the custom ban response body is returned."""
log("Starting custom ban response body test...")
# Trigger a ban
for _ in range(GLOBAL_MAX_ERRORS + 1):
send_request(f"{BASE_URL}{NONEXISTENT_URL_PATH}")
status_code, response_body, colored_status = send_request(f"{BASE_URL}{NONEXISTENT_URL_PATH}")
if "custom ban response" in response_body: # Replace with the expected custom response
log(f"{Fore.GREEN}Custom ban response body test passed: Custom response body was returned.{Style.RESET_ALL}")
return True
else:
log(f"{Fore.RED}Custom ban response body test failed: Custom response body was not returned.{Style.RESET_ALL}")
return False
def print_summary(test_name, success):
"""Print a summary of the test result with colored output."""
if success:
@@ -319,12 +404,28 @@ if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Test script for rate limiting and banning.")
parser.add_argument("--base-url", dest="base_url", default=BASE_URL,
help=f"Base URL of the service (default: {BASE_URL})")
parser.add_argument("--global-max-errors", type=int, default=GLOBAL_MAX_ERRORS,
help=f"Global max errors before banning (default: {GLOBAL_MAX_ERRORS})")
parser.add_argument("--global-ban-duration", type=int, default=GLOBAL_BAN_DURATION,
help=f"Global ban duration in seconds (default: {GLOBAL_BAN_DURATION})")
parser.add_argument("--login-max-errors", type=int, default=LOGIN_MAX_ERRORS,
help=f"Max errors for /login before banning (default: {LOGIN_MAX_ERRORS})")
parser.add_argument("--login-ban-duration", type=int, default=LOGIN_BAN_DURATION,
help=f"Ban duration for /login in seconds (default: {LOGIN_BAN_DURATION})")
parser.add_argument("--api-max-errors", type=int, default=API_MAX_ERRORS,
help=f"Max errors for /api before banning (default: {API_MAX_ERRORS})")
parser.add_argument("--api-ban-duration", type=int, default=API_BAN_DURATION,
help=f"Ban duration for /api in seconds (default: {API_BAN_DURATION})")
args = parser.parse_args()
# Update configuration with command-line arguments
BASE_URL = args.base_url
NONEXISTENT_URL = f"{BASE_URL}{NONEXISTENT_URL_PATH}"
LOGIN_URL = f"{BASE_URL}{LOGIN_URL_PATH}"
API_URL = f"{BASE_URL}{API_URL_PATH}"
GLOBAL_MAX_ERRORS = args.global_max_errors
GLOBAL_BAN_DURATION = args.global_ban_duration
LOGIN_MAX_ERRORS = args.login_max_errors
LOGIN_BAN_DURATION = args.login_ban_duration
API_MAX_ERRORS = args.api_max_errors
API_BAN_DURATION = args.api_ban_duration
# Run all tests and collect results
results = {
@@ -333,6 +434,12 @@ if __name__ == "__main__":
"API Ban Test": test_api_ban(),
"Specific 404 Test": test_specific_404(),
"Root Response with fab Test": test_root_response_with_fab(),
"Custom Response Header Test": test_custom_response_header(),
"Whitelist Test": test_whitelist(),
"CIDR Ban Test": test_cidr_ban(),
"Ban Duration Multiplier Test": test_ban_duration_multiplier(),
"Log Request Headers Test": test_log_request_headers(),
"Custom Ban Response Body Test": test_custom_ban_response_body(),
}
# Print summary
@@ -378,3 +485,15 @@ if __name__ == "__main__":
print(f"{Fore.YELLOW}- The server is not returning the expected 404 status for nonexistent URLs, which could indicate a routing issue.{Style.RESET_ALL}")
if test_details.get("Root Response with fab Test") == "FAIL":
print(f"{Fore.YELLOW}- The server is not returning an acceptable status code (2xx, 3xx, or 400, excluding 401, 402, 403, and above 404) for the root URL with the 'fab' user-agent.{Style.RESET_ALL}")
if test_details.get("Custom Response Header Test") == "FAIL":
print(f"{Fore.YELLOW}- The custom response header is not being added to the response.{Style.RESET_ALL}")
if test_details.get("Whitelist Test") == "FAIL":
print(f"{Fore.YELLOW}- Whitelisted IPs are being banned despite being whitelisted.{Style.RESET_ALL}")
if test_details.get("CIDR Ban Test") == "FAIL":
print(f"{Fore.YELLOW}- IPs within banned CIDR ranges are not being blocked.{Style.RESET_ALL}")
if test_details.get("Ban Duration Multiplier Test") == "FAIL":
print(f"{Fore.YELLOW}- The ban duration is not increasing exponentially based on the multiplier.{Style.RESET_ALL}")
if test_details.get("Log Request Headers Test") == "FAIL":
print(f"{Fore.YELLOW}- Specified request headers are not being logged.{Style.RESET_ALL}")
if test_details.get("Custom Ban Response Body Test") == "FAIL":
print(f"{Fore.YELLOW}- The custom ban response body is not being returned.{Style.RESET_ALL}")