mirror of
https://github.com/fabriziosalmi/caddy-waf.git
synced 2025-12-23 22:27:46 -05:00
559 lines
18 KiB
Go
559 lines
18 KiB
Go
package caddywaf
|
|
|
|
import (
|
|
"context"
|
|
"net/http"
|
|
"strings"
|
|
|
|
"github.com/google/uuid"
|
|
"go.uber.org/zap"
|
|
"go.uber.org/zap/zapcore"
|
|
|
|
"github.com/caddyserver/caddy/v2/modules/caddyhttp"
|
|
)
|
|
|
|
// ContextKeyLogId is the context key type under which ServeHTTP stores the
// per-request log ID. Using a dedicated named type (rather than a plain
// string) keeps the key distinct from values other packages put in the
// request context.
type ContextKeyLogId string

// ContextKeyRule is the context key type under which handlePhase stores the
// ID of the rule currently being evaluated.
type ContextKeyRule string
|
|
|
|
// ServeHTTP implements caddyhttp.Handler.
|
|
// handler.go
|
|
func (m *Middleware) ServeHTTP(w http.ResponseWriter, r *http.Request, next caddyhttp.Handler) error {
|
|
logID := uuid.New().String()
|
|
|
|
// Add panic recovery to catch and log panics
|
|
defer func() {
|
|
if rec := recover(); rec != nil {
|
|
m.logger.Error("PANIC in ServeHTTP",
|
|
zap.String("log_id", logID),
|
|
zap.Any("panic", rec),
|
|
zap.Stack("stack"),
|
|
)
|
|
// Return 500 error to client
|
|
w.WriteHeader(http.StatusInternalServerError)
|
|
if _, err := w.Write([]byte("Internal Server Error")); err != nil {
|
|
m.logger.Error(err.Error(),
|
|
zap.String("log_id", logID),
|
|
zap.Any("panic", rec),
|
|
zap.Stack("stack"),
|
|
)
|
|
return
|
|
}
|
|
}
|
|
}()
|
|
|
|
m.logRequestStart(r, logID)
|
|
|
|
// Propagate log ID within the request context for logging
|
|
ctx := context.WithValue(r.Context(), ContextKeyLogId("logID"), logID)
|
|
r = r.WithContext(ctx)
|
|
|
|
m.incrementTotalRequestsMetric()
|
|
|
|
// Initialize WAF state for this request
|
|
state := m.initializeWAFState()
|
|
|
|
// Phase 1: Pre-request checks and blocking
|
|
if m.isPhaseBlocked(w, r, 1, state) {
|
|
return nil // Request blocked, short-circuit
|
|
}
|
|
|
|
// Phase 2: Request analysis and blocking
|
|
if m.isPhaseBlocked(w, r, 2, state) {
|
|
return nil // Request blocked, short-circuit
|
|
}
|
|
|
|
// Response capture and processing
|
|
recorder := NewResponseRecorder(w)
|
|
err := next.ServeHTTP(recorder, r)
|
|
|
|
// Phase 3: Response Header analysis
|
|
if m.isPhaseBlocked(recorder, r, 3, state) {
|
|
return nil // Request blocked in Phase 3, short-circuit
|
|
}
|
|
|
|
// Phase 4: Response Body analysis (if not already blocked)
|
|
m.handleResponseBodyPhase(recorder, r, state)
|
|
|
|
if state.Blocked {
|
|
// Metrics and response handling if blocked after headers phase
|
|
m.incrementBlockedRequestsMetric()
|
|
m.writeCustomResponse(recorder, state.StatusCode)
|
|
return nil
|
|
}
|
|
|
|
// Handle metrics request separately
|
|
if m.isMetricsRequest(r) {
|
|
return m.handleMetricsRequest(w, r)
|
|
}
|
|
|
|
// If not blocked, copy recorded response back to original writer
|
|
// Moved this inside if check to call only if not blocked
|
|
if !state.Blocked {
|
|
m.incrementAllowedRequestsMetric() // Increment here only if not blocked
|
|
m.copyResponse(w, recorder, r)
|
|
}
|
|
m.logRequestCompletion(logID, state)
|
|
|
|
return err // Return any error from the next handler
|
|
}
|
|
|
|
// isPhaseBlocked encapsulates the phase handling and blocking check logic.
|
|
func (m *Middleware) isPhaseBlocked(w http.ResponseWriter, r *http.Request, phase int, state *WAFState) bool {
|
|
m.handlePhase(w, r, phase, state)
|
|
|
|
if state.Blocked {
|
|
m.incrementBlockedRequestsMetric()
|
|
|
|
// IMPORTANT: Log the block event with details
|
|
m.logger.Warn("Request blocked in phase evaluation",
|
|
zap.Int("phase", phase),
|
|
zap.Int("status_code", state.StatusCode),
|
|
zap.Int("total_score", state.TotalScore),
|
|
zap.Int("anomaly_threshold", m.AnomalyThreshold),
|
|
)
|
|
|
|
// Only write the status if not already written
|
|
if !state.ResponseWritten {
|
|
w.WriteHeader(state.StatusCode)
|
|
state.ResponseWritten = true
|
|
}
|
|
|
|
return true
|
|
}
|
|
|
|
return false
|
|
}
|
|
|
|
// logRequestStart logs the start of WAF evaluation.
|
|
func (m *Middleware) logRequestStart(r *http.Request, logID string) {
|
|
m.logger.Info("WAF request evaluation started",
|
|
zap.String("log_id", logID),
|
|
zap.String("method", r.Method),
|
|
zap.String("uri", r.RequestURI),
|
|
zap.String("remote_address", r.RemoteAddr),
|
|
zap.String("user_agent", r.UserAgent()),
|
|
)
|
|
}
|
|
|
|
// incrementTotalRequestsMetric increments the total requests metric.
|
|
func (m *Middleware) incrementTotalRequestsMetric() {
|
|
m.muMetrics.Lock()
|
|
m.totalRequests++
|
|
m.muMetrics.Unlock()
|
|
}
|
|
|
|
// initializeWAFState initializes the WAF state.
|
|
func (m *Middleware) initializeWAFState() *WAFState {
|
|
return &WAFState{
|
|
TotalScore: 0,
|
|
Blocked: false,
|
|
StatusCode: http.StatusOK,
|
|
ResponseWritten: false,
|
|
}
|
|
}
|
|
|
|
// getLogID extracts the logID from the request context.
|
|
func getLogID(ctx context.Context) string {
|
|
if logID, ok := ctx.Value(ContextKeyLogId("logID")).(string); ok {
|
|
return logID
|
|
}
|
|
return "unknown"
|
|
}
|
|
|
|
// handleResponseBodyPhase processes Phase 4 (response body).
|
|
func (m *Middleware) handleResponseBodyPhase(recorder *responseRecorder, r *http.Request, state *WAFState) {
|
|
// No need to check if recorder.body is nil here, it's always initialized in NewResponseRecorder
|
|
body := recorder.BodyString()
|
|
logID := getLogID(r.Context())
|
|
if logID == "unknown" {
|
|
m.logger.Error("Log ID missing in context")
|
|
return
|
|
}
|
|
m.logger.Debug("Response body captured for Phase 4 analysis", zap.String("log_id", logID))
|
|
|
|
// Check if rules exist for Phase 4 before iterating
|
|
rules, ok := m.Rules[4]
|
|
if !ok || len(rules) == 0 {
|
|
m.logger.Debug("No rules found for Phase 4")
|
|
return
|
|
}
|
|
|
|
for _, rule := range rules {
|
|
if rule.regex.MatchString(body) {
|
|
if m.processRuleMatch(recorder, r, &rule, "RESPONSE_BODY", body, state) { // Pass RESPONSE_BODY as target
|
|
return
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
// incrementBlockedRequestsMetric increments the blocked requests metric.
|
|
func (m *Middleware) incrementBlockedRequestsMetric() {
|
|
m.muMetrics.Lock()
|
|
m.blockedRequests++
|
|
m.muMetrics.Unlock()
|
|
}
|
|
|
|
// incrementAllowedRequestsMetric increments the allowed requests metric.
|
|
func (m *Middleware) incrementAllowedRequestsMetric() {
|
|
m.muMetrics.Lock()
|
|
m.allowedRequests++
|
|
m.muMetrics.Unlock()
|
|
}
|
|
|
|
// isMetricsRequest checks if it's a metrics request.
|
|
func (m *Middleware) isMetricsRequest(r *http.Request) bool {
|
|
return m.MetricsEndpoint != "" && r.URL.Path == m.MetricsEndpoint
|
|
}
|
|
|
|
// writeCustomResponse writes a custom response.
|
|
func (m *Middleware) writeCustomResponse(w http.ResponseWriter, statusCode int) {
|
|
if customResponse, ok := m.CustomResponses[statusCode]; ok {
|
|
for key, value := range customResponse.Headers {
|
|
w.Header().Set(key, value)
|
|
}
|
|
w.WriteHeader(customResponse.StatusCode)
|
|
if _, err := w.Write([]byte(customResponse.Body)); err != nil {
|
|
m.logger.Error("Failed to write custom response body", zap.Error(err))
|
|
}
|
|
}
|
|
}
|
|
|
|
// logRequestCompletion logs the completion of WAF evaluation.
|
|
func (m *Middleware) logRequestCompletion(logID string, state *WAFState) {
|
|
m.logger.Info("WAF request evaluation completed",
|
|
zap.String("log_id", logID),
|
|
zap.Int("total_score", state.TotalScore),
|
|
zap.Bool("blocked", state.Blocked),
|
|
zap.Int("status_code", state.StatusCode),
|
|
)
|
|
}
|
|
|
|
// copyResponse copies the captured response from the recorder to the original writer
|
|
func (m *Middleware) copyResponse(w http.ResponseWriter, recorder *responseRecorder, r *http.Request) {
|
|
header := w.Header()
|
|
for key, values := range recorder.Header() {
|
|
for _, value := range values {
|
|
header.Add(key, value)
|
|
}
|
|
}
|
|
w.WriteHeader(recorder.StatusCode())
|
|
|
|
logID := getLogID(r.Context())
|
|
if logID == "unknown" {
|
|
m.logger.Error("Log ID not found in context during response copy") // added error log for clarity
|
|
}
|
|
_, err := w.Write(recorder.body.Bytes()) // Copy body from recorder to original writer
|
|
if err != nil {
|
|
m.logger.Error("Failed to write recorded response body to client", zap.Error(err), zap.String("log_id", logID))
|
|
}
|
|
}
|
|
|
|
func (m *Middleware) handlePhase(w http.ResponseWriter, r *http.Request, phase int, state *WAFState) {
|
|
m.logger.Debug("Starting phase evaluation",
|
|
zap.Int("phase", phase),
|
|
zap.String("source_ip", r.RemoteAddr),
|
|
zap.String("user_agent", r.UserAgent()),
|
|
)
|
|
|
|
if phase == 1 {
|
|
// IP blacklisting - the highest priority
|
|
m.logger.Debug("Checking for IP blacklisting", zap.String("remote_addr", r.RemoteAddr)) // Added log for checking before to isIPBlacklisted call
|
|
xForwardedFor := r.Header.Get("X-Forwarded-For")
|
|
if xForwardedFor != "" {
|
|
ips := strings.Split(xForwardedFor, ",")
|
|
if len(ips) > 0 {
|
|
firstIP := strings.TrimSpace(ips[0])
|
|
m.logger.Debug("Checking IP blacklist with X-Forwarded-For", zap.String("remote_addr_xff", firstIP), zap.String("r.RemoteAddr", r.RemoteAddr))
|
|
if m.isIPBlacklisted(firstIP) {
|
|
m.logger.Debug("Starting IP blacklist phase")
|
|
m.blockRequest(w, r, state, http.StatusForbidden, "ip_blacklist", "ip_blacklist_rule",
|
|
zap.String("message", "Request blocked by IP blacklist"),
|
|
)
|
|
if m.CustomResponses != nil {
|
|
m.writeCustomResponse(w, state.StatusCode)
|
|
}
|
|
return
|
|
}
|
|
} else {
|
|
m.logger.Debug("X-Forwarded-For header present but empty or invalid")
|
|
}
|
|
} else {
|
|
m.logger.Debug("X-Forwarded-For header not present using r.RemoteAddr")
|
|
if m.isIPBlacklisted(r.RemoteAddr) {
|
|
m.logger.Debug("Starting IP blacklist phase")
|
|
m.blockRequest(w, r, state, http.StatusForbidden, "ip_blacklist", "ip_blacklist_rule",
|
|
zap.String("message", "Request blocked by IP blacklist"),
|
|
)
|
|
if m.CustomResponses != nil {
|
|
m.writeCustomResponse(w, state.StatusCode)
|
|
}
|
|
return
|
|
}
|
|
}
|
|
|
|
// DNS blacklisting
|
|
if m.isDNSBlacklisted(r.Host) {
|
|
m.logger.Debug("Starting DNS blacklist phase")
|
|
m.blockRequest(w, r, state, http.StatusForbidden, "dns_blacklist", "dns_blacklist_rule",
|
|
zap.String("message", "Request blocked by DNS blacklist"),
|
|
zap.String("host", r.Host),
|
|
)
|
|
if m.CustomResponses != nil {
|
|
m.writeCustomResponse(w, state.StatusCode)
|
|
}
|
|
return
|
|
}
|
|
|
|
// Rate limiting
|
|
if m.rateLimiter != nil {
|
|
m.logger.Debug("Starting rate limiting phase")
|
|
ip := extractIP(r.RemoteAddr) // Pass the logger here
|
|
path := r.URL.Path // Get the request path
|
|
if m.rateLimiter.isRateLimited(ip, path) {
|
|
m.incrementRateLimiterBlockedRequestsMetric() // Increment the counter in the Middleware
|
|
m.blockRequest(w, r, state, http.StatusTooManyRequests, "rate_limit", "rate_limit_rule",
|
|
zap.String("message", "Request blocked by rate limit"),
|
|
)
|
|
if m.CustomResponses != nil {
|
|
m.writeCustomResponse(w, state.StatusCode)
|
|
}
|
|
return
|
|
}
|
|
m.logger.Debug("Rate limiting phase completed - not blocked")
|
|
}
|
|
|
|
// Whitelisting
|
|
if m.CountryWhitelist.Enabled {
|
|
m.logger.Debug("Starting country whitelisting phase")
|
|
allowed, err := m.isCountryInList(r.RemoteAddr, m.CountryWhitelist.CountryList, m.CountryWhitelist.geoIP)
|
|
if err != nil {
|
|
m.logRequest(zapcore.ErrorLevel, "Failed to check country whitelist",
|
|
r,
|
|
zap.Error(err),
|
|
)
|
|
if m.GeoIPFailOpen {
|
|
m.logger.Warn("GeoIP lookup failed (Whitelist); Failing OPEN")
|
|
} else {
|
|
m.blockRequest(w, r, state, http.StatusForbidden, "internal_error", "country_block_rule",
|
|
zap.String("message", "Request blocked due to internal error"),
|
|
)
|
|
m.logger.Debug("Country whitelisting phase completed - blocked due to error")
|
|
m.incrementGeoIPRequestsMetric(false) // Increment with false for error
|
|
return
|
|
}
|
|
} else if !allowed {
|
|
m.blockRequest(w, r, state, http.StatusForbidden, "country_block", "country_block_rule",
|
|
zap.String("message", "Request blocked by country"))
|
|
m.incrementGeoIPRequestsMetric(true) // Increment with true for blocked
|
|
if m.CustomResponses != nil {
|
|
m.writeCustomResponse(w, state.StatusCode)
|
|
}
|
|
return
|
|
}
|
|
m.logger.Debug("Country whitelisting phase completed - not blocked")
|
|
m.incrementGeoIPRequestsMetric(false) // Increment with false for no block
|
|
}
|
|
|
|
// ASN Blocking
|
|
if m.BlockASNs.Enabled {
|
|
m.logger.Debug("Starting ASN blocking phase")
|
|
blocked, err := m.geoIPHandler.IsASNInList(r.RemoteAddr, m.BlockASNs.BlockedASNs, m.BlockASNs.geoIP)
|
|
if err != nil {
|
|
m.logRequest(zapcore.ErrorLevel, "Failed to check ASN blocking",
|
|
r,
|
|
zap.Error(err),
|
|
)
|
|
if m.GeoIPFailOpen {
|
|
m.logger.Warn("ASN lookup failed; Failing OPEN")
|
|
} else {
|
|
m.blockRequest(w, r, state, http.StatusForbidden, "internal_error", "asn_block_rule",
|
|
zap.String("message", "Request blocked due to internal error"),
|
|
)
|
|
m.logger.Debug("ASN blocking phase completed - blocked due to error")
|
|
m.incrementGeoIPRequestsMetric(false) // Increment with false for error
|
|
return
|
|
}
|
|
} else if blocked {
|
|
asnInfo := m.geoIPHandler.GetASN(r.RemoteAddr, m.BlockASNs.geoIP)
|
|
m.blockRequest(w, r, state, http.StatusForbidden, "asn_block", "asn_block_rule",
|
|
zap.String("message", "Request blocked by ASN"),
|
|
zap.String("asn", asnInfo),
|
|
)
|
|
m.incrementGeoIPRequestsMetric(true) // Increment with true for blocked
|
|
if m.CustomResponses != nil {
|
|
m.writeCustomResponse(w, state.StatusCode)
|
|
}
|
|
return
|
|
}
|
|
m.logger.Debug("ASN blocking phase completed - not blocked")
|
|
}
|
|
|
|
// Blacklisting
|
|
if m.CountryBlacklist.Enabled {
|
|
m.logger.Debug("Starting country blacklisting phase")
|
|
blocked, err := m.isCountryInList(r.RemoteAddr, m.CountryBlacklist.CountryList, m.CountryBlacklist.geoIP)
|
|
if err != nil {
|
|
m.logRequest(zapcore.ErrorLevel, "Failed to check country blacklisting",
|
|
r,
|
|
zap.Error(err),
|
|
)
|
|
if m.GeoIPFailOpen {
|
|
m.logger.Warn("GeoIP lookup failed (Blacklist); Failing OPEN")
|
|
} else {
|
|
m.blockRequest(w, r, state, http.StatusForbidden, "internal_error", "country_block_rule",
|
|
zap.String("message", "Request blocked due to internal error"),
|
|
)
|
|
m.logger.Debug("Country blacklisting phase completed - blocked due to error")
|
|
m.incrementGeoIPRequestsMetric(false) // Increment with false for error
|
|
return
|
|
}
|
|
} else if blocked {
|
|
m.blockRequest(w, r, state, http.StatusForbidden, "country_block", "country_block_rule",
|
|
zap.String("message", "Request blocked by country"))
|
|
m.incrementGeoIPRequestsMetric(true) // Increment with true for blocked
|
|
if m.CustomResponses != nil {
|
|
m.writeCustomResponse(w, state.StatusCode)
|
|
}
|
|
return
|
|
}
|
|
m.logger.Debug("Country blacklisting phase completed - not blocked")
|
|
m.incrementGeoIPRequestsMetric(false) // Increment with false for no block
|
|
}
|
|
}
|
|
|
|
rules, ok := m.Rules[phase]
|
|
if !ok {
|
|
m.logger.Debug("No rules found for phase", zap.Int("phase", phase))
|
|
// Don't block on empty rules. There may be no rules specified
|
|
// return
|
|
}
|
|
|
|
m.logger.Debug("Starting rule evaluation for phase", zap.Int("phase", phase), zap.Int("rule_count", len(rules)))
|
|
|
|
for _, rule := range rules {
|
|
m.logger.Debug("Processing rule", zap.String("rule_id", rule.ID), zap.Int("target_count", len(rule.Targets)))
|
|
|
|
// Use the custom type as the key
|
|
ctx := context.WithValue(r.Context(), ContextKeyRule("rule_id"), rule.ID)
|
|
r = r.WithContext(ctx)
|
|
|
|
for _, target := range rule.Targets {
|
|
m.logger.Debug("Extracting value for target", zap.String("target", target), zap.String("rule_id", rule.ID))
|
|
var value string
|
|
var err error
|
|
|
|
if phase == 3 || phase == 4 {
|
|
if recorder, ok := w.(*responseRecorder); ok {
|
|
value, err = m.extractValue(target, r, recorder)
|
|
} else {
|
|
m.logger.Error("response recorder is not available in phase 3 or 4 when required")
|
|
value, err = m.extractValue(target, r, nil)
|
|
}
|
|
} else {
|
|
value, err = m.extractValue(target, r, nil)
|
|
}
|
|
|
|
if err != nil {
|
|
m.logger.Debug("Failed to extract value for target, skipping rule for this target",
|
|
zap.String("target", target),
|
|
zap.String("rule_id", rule.ID),
|
|
zap.Error(err),
|
|
)
|
|
continue
|
|
}
|
|
|
|
redactedValue := m.requestValueExtractor.RedactValueIfSensitive(target, value)
|
|
|
|
m.logger.Debug("Extracted value",
|
|
zap.String("rule_id", rule.ID),
|
|
zap.String("target", target),
|
|
zap.String("value", redactedValue),
|
|
)
|
|
|
|
if rule.regex.MatchString(value) {
|
|
m.logger.Debug("Rule matched",
|
|
zap.String("rule_id", rule.ID),
|
|
zap.String("target", target),
|
|
zap.String("value", redactedValue),
|
|
)
|
|
|
|
// FIXED: Correctly interpret processRuleMatch return value
|
|
var shouldContinue bool
|
|
if phase == 3 || phase == 4 {
|
|
if recorder, ok := w.(*responseRecorder); ok {
|
|
shouldContinue = m.processRuleMatch(recorder, r, &rule, target, value, state)
|
|
} else {
|
|
shouldContinue = m.processRuleMatch(w, r, &rule, target, value, state)
|
|
}
|
|
} else {
|
|
shouldContinue = m.processRuleMatch(w, r, &rule, target, value, state)
|
|
}
|
|
|
|
// If processRuleMatch returned false or state is now blocked, stop processing
|
|
if !shouldContinue || state.Blocked || state.ResponseWritten {
|
|
m.logger.Debug("Rule evaluation stopping due to blocking or rule directive",
|
|
zap.Int("phase", phase),
|
|
zap.String("rule_id", rule.ID),
|
|
zap.Bool("continue", shouldContinue),
|
|
zap.Bool("blocked", state.Blocked),
|
|
)
|
|
|
|
if m.CustomResponses != nil {
|
|
m.writeCustomResponse(w, state.StatusCode)
|
|
}
|
|
return
|
|
}
|
|
} else {
|
|
m.logger.Debug("Rule did not match",
|
|
zap.String("rule_id", rule.ID),
|
|
zap.String("target", target),
|
|
zap.String("value", redactedValue),
|
|
)
|
|
}
|
|
}
|
|
}
|
|
|
|
m.logger.Debug("Rule evaluation completed for phase", zap.Int("phase", phase))
|
|
|
|
if phase == 3 {
|
|
m.logger.Debug("Starting response headers phase")
|
|
if _, ok := w.(*responseRecorder); ok {
|
|
m.logger.Debug("Response headers phase completed")
|
|
}
|
|
}
|
|
|
|
if phase == 4 {
|
|
m.logger.Debug("Starting response body phase")
|
|
if _, ok := w.(*responseRecorder); ok {
|
|
m.logger.Debug("Response body phase completed")
|
|
}
|
|
}
|
|
|
|
m.logger.Debug("Completed phase evaluation",
|
|
zap.Int("phase", phase),
|
|
zap.Int("total_score", state.TotalScore),
|
|
zap.Int("anomaly_threshold", m.AnomalyThreshold),
|
|
)
|
|
|
|
m.allowRequest(state)
|
|
}
|
|
|
|
// incrementRateLimiterBlockedRequestsMetric increments the blocked requests metric for the rate limiter.
|
|
func (m *Middleware) incrementRateLimiterBlockedRequestsMetric() {
|
|
m.muRateLimiterMetrics.Lock()
|
|
defer m.muRateLimiterMetrics.Unlock()
|
|
m.rateLimiterBlockedRequests++
|
|
}
|
|
|
|
// incrementGeoIPRequestsMetric increments the GeoIP requests metric.
|
|
func (m *Middleware) incrementGeoIPRequestsMetric(blocked bool) {
|
|
m.muMetrics.Lock()
|
|
defer m.muMetrics.Unlock()
|
|
if blocked {
|
|
m.geoIPBlocked++
|
|
}
|
|
}
|