mirror of
https://github.com/fabriziosalmi/caddy-mib.git
synced 2025-12-23 14:07:44 -05:00
test script updated and minor improvements
This commit is contained in:
@@ -14,9 +14,9 @@
|
||||
max_error_count 10 # Global error threshold (reduced for faster testing)
|
||||
ban_duration 5s # Global ban duration (reduced to 10 seconds)
|
||||
ban_duration_multiplier 1 # Global ban duration multiplier
|
||||
whitelist 127.0.0.1 ::1 # Whitelisted IPs
|
||||
# whitelist 127.0.0.1 ::1 # Whitelisted IPs
|
||||
log_level debug # Log level for debugging
|
||||
ban_response_body "You have been banned due to excessive errors. Please try again later."
|
||||
ban_response_body "You have been banned due to excessive errors. Please try again later."
|
||||
ban_status_code 429 # Custom status code for banned IPs
|
||||
|
||||
# Per-path configuration for /login
|
||||
@@ -37,7 +37,7 @@
|
||||
}
|
||||
# All other requests, respond with "Hello World"
|
||||
handle {
|
||||
respond "Hello world!" 200
|
||||
respond "Hello world!" 404
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
63
caddymib.go
63
caddymib.go
@@ -26,20 +26,20 @@ func init() {
|
||||
|
||||
// Middleware implements the Caddy MIB middleware.
|
||||
type Middleware struct {
|
||||
ErrorCodes []int `json:"error_codes,omitempty"`
|
||||
MaxErrorCount int `json:"max_error_count,omitempty"`
|
||||
BanDuration caddy.Duration `json:"ban_duration,omitempty"`
|
||||
BanDurationMultiplier float64 `json:"ban_duration_multiplier,omitempty"`
|
||||
Whitelist []string `json:"whitelist,omitempty"`
|
||||
CustomResponseHeader string `json:"custom_response_header,omitempty"`
|
||||
LogRequestHeaders []string `json:"log_request_headers,omitempty"`
|
||||
LogLevel string `json:"log_level,omitempty"`
|
||||
CIDRBans []string `json:"cidr_bans,omitempty"`
|
||||
BanResponseBody string `json:"ban_response_body,omitempty"`
|
||||
BanStatusCode int `json:"ban_status_code,omitempty"`
|
||||
ErrorCodes []int `json:"error_codes,omitempty"` // HTTP status codes to track as errors
|
||||
MaxErrorCount int `json:"max_error_count,omitempty"` // Maximum allowed errors before banning
|
||||
BanDuration caddy.Duration `json:"ban_duration,omitempty"` // Base duration for banning
|
||||
BanDurationMultiplier float64 `json:"ban_duration_multiplier,omitempty"` // Multiplier for ban duration after each offense
|
||||
Whitelist []string `json:"whitelist,omitempty"` // List of IPs or CIDRs to whitelist
|
||||
CustomResponseHeader string `json:"custom_response_header,omitempty"` // Custom header to add to responses
|
||||
LogRequestHeaders []string `json:"log_request_headers,omitempty"` // Request headers to log
|
||||
LogLevel string `json:"log_level,omitempty"` // Log level for the middleware
|
||||
CIDRBans []string `json:"cidr_bans,omitempty"` // List of CIDRs to ban
|
||||
BanResponseBody string `json:"ban_response_body,omitempty"` // Response body for banned requests
|
||||
BanStatusCode int `json:"ban_status_code,omitempty"` // HTTP status code for banned requests
|
||||
|
||||
// Per-path configuration
|
||||
PerPathConfig map[string]PathConfig `json:"per_path,omitempty"`
|
||||
PerPathConfig map[string]PathConfig `json:"per_path,omitempty"` // Configuration for specific paths
|
||||
|
||||
logger *zap.Logger
|
||||
errorCounts sync.Map // Tracks errors per IP and path
|
||||
@@ -50,10 +50,10 @@ type Middleware struct {
|
||||
|
||||
// PathConfig defines per-path configuration.
|
||||
type PathConfig struct {
|
||||
ErrorCodes []int `json:"error_codes,omitempty"`
|
||||
MaxErrorCount int `json:"max_error_count,omitempty"`
|
||||
BanDuration caddy.Duration `json:"ban_duration,omitempty"`
|
||||
BanDurationMultiplier float64 `json:"ban_duration_multiplier,omitempty"`
|
||||
ErrorCodes []int `json:"error_codes,omitempty"` // HTTP status codes to track as errors for this path
|
||||
MaxErrorCount int `json:"max_error_count,omitempty"` // Maximum allowed errors before banning for this path
|
||||
BanDuration caddy.Duration `json:"ban_duration,omitempty"` // Base duration for banning for this path
|
||||
BanDurationMultiplier float64 `json:"ban_duration_multiplier,omitempty"` // Multiplier for ban duration after each offense for this path
|
||||
}
|
||||
|
||||
// CaddyModule returns the Caddy module information.
|
||||
@@ -164,7 +164,6 @@ func (m *Middleware) ServeHTTP(w http.ResponseWriter, r *http.Request, next cadd
|
||||
}
|
||||
}
|
||||
|
||||
// Rest of the ServeHTTP logic...
|
||||
clientIP, _, err := net.SplitHostPort(r.RemoteAddr)
|
||||
if err != nil {
|
||||
m.logger.Error("failed to parse client IP",
|
||||
@@ -278,6 +277,7 @@ func normalizeIP(ip string) string {
|
||||
return ip
|
||||
}
|
||||
|
||||
// trackErrorStatus tracks errors for a specific IP and path.
|
||||
func (m *Middleware) trackErrorStatus(clientIP string, code int, path string, r *http.Request) {
|
||||
commonFields := []zap.Field{
|
||||
zap.String("client_ip", clientIP),
|
||||
@@ -287,6 +287,9 @@ func (m *Middleware) trackErrorStatus(clientIP string, code int, path string, r
|
||||
zap.String("user_agent", r.Header.Get("User-Agent")),
|
||||
}
|
||||
|
||||
// Use a composite key for error counts
|
||||
key := fmt.Sprintf("%s:%s", clientIP, path)
|
||||
|
||||
// Check if the path has a specific configuration
|
||||
pathConfig, hasPathConfig := m.PerPathConfig[path]
|
||||
if hasPathConfig {
|
||||
@@ -301,15 +304,15 @@ func (m *Middleware) trackErrorStatus(clientIP string, code int, path string, r
|
||||
for _, errCode := range m.ErrorCodes {
|
||||
if code == errCode {
|
||||
countBefore := 0
|
||||
if val, ok := m.errorCounts.Load(clientIP); ok {
|
||||
countBefore = val.(map[string]int)["global"]
|
||||
if val, ok := m.errorCounts.Load(key); ok {
|
||||
countBefore = val.(int)
|
||||
}
|
||||
m.logger.Debug("tracking error", append(commonFields,
|
||||
zap.Int("current_error_count", countBefore),
|
||||
zap.Int("max_error_count", m.MaxErrorCount),
|
||||
)...)
|
||||
countNow := countBefore + 1
|
||||
m.errorCounts.Store(clientIP, map[string]int{"global": countNow})
|
||||
m.errorCounts.Store(key, countNow)
|
||||
m.logger.Debug("error count incremented", append(commonFields,
|
||||
zap.Int("new_error_count", countNow),
|
||||
zap.Int("max_error_count", m.MaxErrorCount),
|
||||
@@ -317,6 +320,9 @@ func (m *Middleware) trackErrorStatus(clientIP string, code int, path string, r
|
||||
if countNow >= m.MaxErrorCount {
|
||||
offenses := countNow - m.MaxErrorCount + 1
|
||||
banDuration := time.Duration(m.BanDuration) * time.Duration(math.Pow(m.BanDurationMultiplier, float64(offenses)))
|
||||
if banDuration > 24*time.Hour { // Cap ban duration at 24 hours
|
||||
banDuration = 24 * time.Hour
|
||||
}
|
||||
expiration := time.Now().Add(banDuration)
|
||||
m.bannedIPs.Store(clientIP, expiration)
|
||||
logFields := append(commonFields,
|
||||
@@ -329,7 +335,7 @@ func (m *Middleware) trackErrorStatus(clientIP string, code int, path string, r
|
||||
// Add configured request headers to the log
|
||||
for _, headerName := range m.LogRequestHeaders {
|
||||
if value := r.Header.Get(headerName); value != "" {
|
||||
logFields = append(logFields, zap.String(strings.ToLower(headerName), value))
|
||||
logFields = append(logFields, zap.String(headerName, value))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -350,18 +356,21 @@ func (m *Middleware) trackErrorsForPath(clientIP string, code int, path string,
|
||||
zap.String("user_agent", r.Header.Get("User-Agent")),
|
||||
}
|
||||
|
||||
// Use a composite key for error counts
|
||||
key := fmt.Sprintf("%s:%s", clientIP, path)
|
||||
|
||||
for _, errCode := range config.ErrorCodes {
|
||||
if code == errCode {
|
||||
countBefore := 0
|
||||
if val, ok := m.errorCounts.Load(clientIP); ok {
|
||||
countBefore = val.(map[string]int)[path]
|
||||
if val, ok := m.errorCounts.Load(key); ok {
|
||||
countBefore = val.(int)
|
||||
}
|
||||
m.logger.Debug("tracking error for path", append(commonFields,
|
||||
zap.Int("current_error_count", countBefore),
|
||||
zap.Int("max_error_count", config.MaxErrorCount),
|
||||
)...)
|
||||
countNow := countBefore + 1
|
||||
m.errorCounts.Store(clientIP, map[string]int{path: countNow})
|
||||
m.errorCounts.Store(key, countNow)
|
||||
m.logger.Debug("error count incremented for path", append(commonFields,
|
||||
zap.Int("new_error_count", countNow),
|
||||
zap.Int("max_error_count", config.MaxErrorCount),
|
||||
@@ -369,6 +378,9 @@ func (m *Middleware) trackErrorsForPath(clientIP string, code int, path string,
|
||||
if countNow >= config.MaxErrorCount {
|
||||
offenses := countNow - config.MaxErrorCount + 1
|
||||
banDuration := time.Duration(config.BanDuration) * time.Duration(math.Pow(config.BanDurationMultiplier, float64(offenses)))
|
||||
if banDuration > 24*time.Hour { // Cap ban duration at 24 hours
|
||||
banDuration = 24 * time.Hour
|
||||
}
|
||||
expiration := time.Now().Add(banDuration)
|
||||
m.bannedIPs.Store(clientIP, expiration)
|
||||
logFields := append(commonFields,
|
||||
@@ -381,7 +393,7 @@ func (m *Middleware) trackErrorsForPath(clientIP string, code int, path string,
|
||||
// Add configured request headers to the log
|
||||
for _, headerName := range m.LogRequestHeaders {
|
||||
if value := r.Header.Get(headerName); value != "" {
|
||||
logFields = append(logFields, zap.String(strings.ToLower(headerName), value))
|
||||
logFields = append(logFields, zap.String(headerName, value))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -392,7 +404,6 @@ func (m *Middleware) trackErrorsForPath(clientIP string, code int, path string,
|
||||
}
|
||||
}
|
||||
|
||||
// extractStatusCodeFromError extracts the HTTP status code from the error message.
|
||||
// extractStatusCodeFromError extracts the HTTP status code from the error message.
|
||||
func extractStatusCodeFromError(err error) int {
|
||||
if err == nil {
|
||||
|
||||
129
test.py
129
test.py
@@ -13,11 +13,11 @@ NONEXISTENT_URL_PATH = "/nonexistent"
|
||||
LOGIN_URL_PATH = "/login"
|
||||
API_URL_PATH = "/api"
|
||||
|
||||
# Global settings
|
||||
# Global settings (defaults)
|
||||
GLOBAL_MAX_ERRORS = 10 # Matches max_error_count in Caddyfile
|
||||
GLOBAL_BAN_DURATION = 10 # Matches ban_duration in Caddyfile (10 seconds)
|
||||
|
||||
# Per-path settings
|
||||
# Per-path settings (defaults)
|
||||
LOGIN_MAX_ERRORS = 5 # Matches max_error_count for /login
|
||||
LOGIN_BAN_DURATION = 15 # Matches ban_duration for /login (15 seconds)
|
||||
API_MAX_ERRORS = 8 # Matches max_error_count for /api
|
||||
@@ -308,6 +308,91 @@ def test_root_response_with_fab():
|
||||
log(f"{Fore.RED}test_root_response_with_fab failed: Received unacceptable status code {colored_status} for root URL with 'fab' user-agent.{Style.RESET_ALL}")
|
||||
return False
|
||||
|
||||
def test_custom_response_header():
|
||||
"""Test that the custom response header is present in the response."""
|
||||
log("Starting custom response header test...")
|
||||
status_code, response_body, colored_status = send_request(f"{BASE_URL}{NONEXISTENT_URL_PATH}")
|
||||
if "X-Custom-MIB-Info" in response_body:
|
||||
log(f"{Fore.GREEN}Custom response header found in response.{Style.RESET_ALL}")
|
||||
return True
|
||||
else:
|
||||
log(f"{Fore.RED}Custom response header not found in response.{Style.RESET_ALL}")
|
||||
return False
|
||||
|
||||
def test_whitelist():
|
||||
"""Test that whitelisted IPs are not banned."""
|
||||
log("Starting whitelist test...")
|
||||
# Simulate requests from a whitelisted IP
|
||||
for i in range(GLOBAL_MAX_ERRORS + 2):
|
||||
status_code, response_body, colored_status = send_request(f"{BASE_URL}{NONEXISTENT_URL_PATH}")
|
||||
log(f"Request {i + 1}: Status Code = {colored_status}")
|
||||
if status_code == 429:
|
||||
log(f"{Fore.RED}Whitelist test failed: IP was banned despite being whitelisted.{Style.RESET_ALL}")
|
||||
return False
|
||||
log(f"{Fore.GREEN}Whitelist test passed: IP was not banned.{Style.RESET_ALL}")
|
||||
return True
|
||||
|
||||
def test_cidr_ban():
|
||||
"""Test that IPs within a banned CIDR range are blocked."""
|
||||
log("Starting CIDR ban test...")
|
||||
status_code, response_body, colored_status = send_request(f"{BASE_URL}{NONEXISTENT_URL_PATH}")
|
||||
if status_code == 429:
|
||||
log(f"{Fore.GREEN}CIDR ban test passed: IP within banned CIDR range was blocked.{Style.RESET_ALL}")
|
||||
return True
|
||||
else:
|
||||
log(f"{Fore.RED}CIDR ban test failed: IP within banned CIDR range was not blocked.{Style.RESET_ALL}")
|
||||
return False
|
||||
|
||||
def test_ban_duration_multiplier():
|
||||
"""Test that the ban duration increases exponentially based on the multiplier."""
|
||||
log("Starting ban duration multiplier test...")
|
||||
# Trigger multiple bans and measure the duration
|
||||
ban_durations = []
|
||||
for i in range(3): # Trigger 3 bans
|
||||
for _ in range(GLOBAL_MAX_ERRORS + 1):
|
||||
send_request(f"{BASE_URL}{NONEXISTENT_URL_PATH}")
|
||||
ban_start_time = datetime.now()
|
||||
while True:
|
||||
status_code, _, _ = send_request(f"{BASE_URL}{NONEXISTENT_URL_PATH}")
|
||||
if status_code != 429:
|
||||
ban_end_time = datetime.now()
|
||||
ban_durations.append((ban_end_time - ban_start_time).total_seconds())
|
||||
break
|
||||
time.sleep(1)
|
||||
# Verify that ban durations increase exponentially
|
||||
if ban_durations[1] > ban_durations[0] and ban_durations[2] > ban_durations[1]:
|
||||
log(f"{Fore.GREEN}Ban duration multiplier test passed: Ban durations increased exponentially.{Style.RESET_ALL}")
|
||||
return True
|
||||
else:
|
||||
log(f"{Fore.RED}Ban duration multiplier test failed: Ban durations did not increase exponentially.{Style.RESET_ALL}")
|
||||
return False
|
||||
|
||||
def test_log_request_headers():
|
||||
"""Test that specified request headers are logged."""
|
||||
log("Starting log request headers test...")
|
||||
status_code, response_body, colored_status = send_request(f"{BASE_URL}{NONEXISTENT_URL_PATH}", user_agent="test-agent")
|
||||
# Check logs for the presence of the "User-Agent" header
|
||||
if "test-agent" in response_body: # Assuming the response body contains logged headers
|
||||
log(f"{Fore.GREEN}Log request headers test passed: 'User-Agent' header was logged.{Style.RESET_ALL}")
|
||||
return True
|
||||
else:
|
||||
log(f"{Fore.RED}Log request headers test failed: 'User-Agent' header was not logged.{Style.RESET_ALL}")
|
||||
return False
|
||||
|
||||
def test_custom_ban_response_body():
|
||||
"""Test that the custom ban response body is returned."""
|
||||
log("Starting custom ban response body test...")
|
||||
# Trigger a ban
|
||||
for _ in range(GLOBAL_MAX_ERRORS + 1):
|
||||
send_request(f"{BASE_URL}{NONEXISTENT_URL_PATH}")
|
||||
status_code, response_body, colored_status = send_request(f"{BASE_URL}{NONEXISTENT_URL_PATH}")
|
||||
if "custom ban response" in response_body: # Replace with the expected custom response
|
||||
log(f"{Fore.GREEN}Custom ban response body test passed: Custom response body was returned.{Style.RESET_ALL}")
|
||||
return True
|
||||
else:
|
||||
log(f"{Fore.RED}Custom ban response body test failed: Custom response body was not returned.{Style.RESET_ALL}")
|
||||
return False
|
||||
|
||||
def print_summary(test_name, success):
|
||||
"""Print a summary of the test result with colored output."""
|
||||
if success:
|
||||
@@ -319,12 +404,28 @@ if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description="Test script for rate limiting and banning.")
|
||||
parser.add_argument("--base-url", dest="base_url", default=BASE_URL,
|
||||
help=f"Base URL of the service (default: {BASE_URL})")
|
||||
parser.add_argument("--global-max-errors", type=int, default=GLOBAL_MAX_ERRORS,
|
||||
help=f"Global max errors before banning (default: {GLOBAL_MAX_ERRORS})")
|
||||
parser.add_argument("--global-ban-duration", type=int, default=GLOBAL_BAN_DURATION,
|
||||
help=f"Global ban duration in seconds (default: {GLOBAL_BAN_DURATION})")
|
||||
parser.add_argument("--login-max-errors", type=int, default=LOGIN_MAX_ERRORS,
|
||||
help=f"Max errors for /login before banning (default: {LOGIN_MAX_ERRORS})")
|
||||
parser.add_argument("--login-ban-duration", type=int, default=LOGIN_BAN_DURATION,
|
||||
help=f"Ban duration for /login in seconds (default: {LOGIN_BAN_DURATION})")
|
||||
parser.add_argument("--api-max-errors", type=int, default=API_MAX_ERRORS,
|
||||
help=f"Max errors for /api before banning (default: {API_MAX_ERRORS})")
|
||||
parser.add_argument("--api-ban-duration", type=int, default=API_BAN_DURATION,
|
||||
help=f"Ban duration for /api in seconds (default: {API_BAN_DURATION})")
|
||||
args = parser.parse_args()
|
||||
|
||||
# Update configuration with command-line arguments
|
||||
BASE_URL = args.base_url
|
||||
NONEXISTENT_URL = f"{BASE_URL}{NONEXISTENT_URL_PATH}"
|
||||
LOGIN_URL = f"{BASE_URL}{LOGIN_URL_PATH}"
|
||||
API_URL = f"{BASE_URL}{API_URL_PATH}"
|
||||
GLOBAL_MAX_ERRORS = args.global_max_errors
|
||||
GLOBAL_BAN_DURATION = args.global_ban_duration
|
||||
LOGIN_MAX_ERRORS = args.login_max_errors
|
||||
LOGIN_BAN_DURATION = args.login_ban_duration
|
||||
API_MAX_ERRORS = args.api_max_errors
|
||||
API_BAN_DURATION = args.api_ban_duration
|
||||
|
||||
# Run all tests and collect results
|
||||
results = {
|
||||
@@ -333,6 +434,12 @@ if __name__ == "__main__":
|
||||
"API Ban Test": test_api_ban(),
|
||||
"Specific 404 Test": test_specific_404(),
|
||||
"Root Response with fab Test": test_root_response_with_fab(),
|
||||
"Custom Response Header Test": test_custom_response_header(),
|
||||
"Whitelist Test": test_whitelist(),
|
||||
"CIDR Ban Test": test_cidr_ban(),
|
||||
"Ban Duration Multiplier Test": test_ban_duration_multiplier(),
|
||||
"Log Request Headers Test": test_log_request_headers(),
|
||||
"Custom Ban Response Body Test": test_custom_ban_response_body(),
|
||||
}
|
||||
|
||||
# Print summary
|
||||
@@ -378,3 +485,15 @@ if __name__ == "__main__":
|
||||
print(f"{Fore.YELLOW}- The server is not returning the expected 404 status for nonexistent URLs, which could indicate a routing issue.{Style.RESET_ALL}")
|
||||
if test_details.get("Root Response with fab Test") == "FAIL":
|
||||
print(f"{Fore.YELLOW}- The server is not returning an acceptable status code (2xx, 3xx, or 400, excluding 401, 402, 403, and above 404) for the root URL with the 'fab' user-agent.{Style.RESET_ALL}")
|
||||
if test_details.get("Custom Response Header Test") == "FAIL":
|
||||
print(f"{Fore.YELLOW}- The custom response header is not being added to the response.{Style.RESET_ALL}")
|
||||
if test_details.get("Whitelist Test") == "FAIL":
|
||||
print(f"{Fore.YELLOW}- Whitelisted IPs are being banned despite being whitelisted.{Style.RESET_ALL}")
|
||||
if test_details.get("CIDR Ban Test") == "FAIL":
|
||||
print(f"{Fore.YELLOW}- IPs within banned CIDR ranges are not being blocked.{Style.RESET_ALL}")
|
||||
if test_details.get("Ban Duration Multiplier Test") == "FAIL":
|
||||
print(f"{Fore.YELLOW}- The ban duration is not increasing exponentially based on the multiplier.{Style.RESET_ALL}")
|
||||
if test_details.get("Log Request Headers Test") == "FAIL":
|
||||
print(f"{Fore.YELLOW}- Specified request headers are not being logged.{Style.RESET_ALL}")
|
||||
if test_details.get("Custom Ban Response Body Test") == "FAIL":
|
||||
print(f"{Fore.YELLOW}- The custom ban response body is not being returned.{Style.RESET_ALL}")
|
||||
Reference in New Issue
Block a user