diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index ca71212..a51e55e 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -25,7 +25,7 @@ jobs: os: - ubuntu-latest go: - - '1.25.2' + - '1.25' runs-on: ${{ matrix.os }} permissions: diff --git a/go.mod b/go.mod index 1b464c1..6cfa1c3 100644 --- a/go.mod +++ b/go.mod @@ -2,7 +2,7 @@ module github.com/fabriziosalmi/caddy-waf go 1.25 -toolchain go1.25.2 +toolchain go1.25.3 require ( github.com/caddyserver/caddy/v2 v2.10.2 @@ -111,9 +111,8 @@ require ( go.opentelemetry.io/otel v1.38.0 // indirect go.opentelemetry.io/otel/metric v1.38.0 // indirect go.opentelemetry.io/otel/trace v1.38.0 // indirect - go.step.sm/crypto v0.70.0 // indirect + go.step.sm/crypto v0.71.0 // indirect go.uber.org/automaxprocs v1.6.0 // indirect - go.uber.org/mock v0.6.0 // indirect go.uber.org/multierr v1.11.0 // indirect go.uber.org/zap/exp v0.3.0 // indirect go.yaml.in/yaml/v2 v2.4.3 // indirect diff --git a/go.sum b/go.sum index 16030ba..2518b03 100644 --- a/go.sum +++ b/go.sum @@ -14,8 +14,8 @@ cloud.google.com/go/compute/metadata v0.9.0 h1:pDUj4QMoPejqq20dK0Pg2N4yG9zIkYGdB cloud.google.com/go/compute/metadata v0.9.0/go.mod h1:E0bWwX5wTnLPedCKqk3pJmVgCBSM6qQI1yTBdEb3C10= cloud.google.com/go/iam v1.5.2 h1:qgFRAGEmd8z6dJ/qyEchAuL9jpswyODjA2lS+w234g8= cloud.google.com/go/iam v1.5.2/go.mod h1:SE1vg0N81zQqLzQEwxL2WI6yhetBdbNQuTvIKCSkUHE= -cloud.google.com/go/kms v1.22.0 h1:dBRIj7+GDeeEvatJeTB19oYZNV0aj6wEqSIT/7gLqtk= -cloud.google.com/go/kms v1.22.0/go.mod h1:U7mf8Sva5jpOb4bxYZdtw/9zsbIjrklYwPcvMk34AL8= +cloud.google.com/go/kms v1.23.0 h1:WaqAZsUptyHwOo9II8rFC1Kd2I+yvNsNP2IJ14H2sUw= +cloud.google.com/go/kms v1.23.0/go.mod h1:rZ5kK0I7Kn9W4erhYVoIRPtpizjunlrfU4fUkumUp8g= cloud.google.com/go/longrunning v0.6.7 h1:IGtfDWHhQCgCjwQjV9iiLnUta9LBCo8R9QmAFsS/PrE= cloud.google.com/go/longrunning v0.6.7/go.mod h1:EAFV3IZAKmM56TyiE6VAP3VoTzhZzySwI/YI1s/nRsY= dario.cat/mergo v1.0.2 h1:85+piFYR1tMbRrLcDwR18y4UKJ3aH1Tbzi24VRW1TK8= @@ -50,34 +50,34 @@ github.com/antlr4-go/antlr/v4 v4.13.1/go.mod h1:GKmUxMtwp6ZgGwZSva4eWPC5mS6vUAmO github.com/armon/consul-api v0.0.0-20180202201655-eb2c6b5be1b6/go.mod h1:grANhF5doyWs3UAsr3K4I6qtAmlQcZDesFNEHPZAzj8= github.com/aryann/difflib v0.0.0-20210328193216-ff5ff6dc229b h1:uUXgbcPDK3KpW29o4iy7GtuappbWT0l5NaMo9H9pJDw= github.com/aryann/difflib v0.0.0-20210328193216-ff5ff6dc229b/go.mod h1:DAHtR1m6lCRdSC2Tm3DSWRPvIPr6xNKyeHdqDQSQT+A= -github.com/aws/aws-sdk-go-v2 v1.38.0 h1:UCRQ5mlqcFk9HJDIqENSLR3wiG1VTWlyUfLDEvY7RxU= -github.com/aws/aws-sdk-go-v2 v1.38.0/go.mod h1:9Q0OoGQoboYIAJyslFyF1f5K1Ryddop8gqMhWx/n4Wg= -github.com/aws/aws-sdk-go-v2/config v1.31.0 h1:9yH0xiY5fUnVNLRWO0AtayqwU1ndriZdN78LlhruJR4= -github.com/aws/aws-sdk-go-v2/config v1.31.0/go.mod h1:VeV3K72nXnhbe4EuxxhzsDc/ByrCSlZwUnWH52Nde/I= -github.com/aws/aws-sdk-go-v2/credentials v1.18.4 h1:IPd0Algf1b+Qy9BcDp0sCUcIWdCQPSzDoMK3a8pcbUM= -github.com/aws/aws-sdk-go-v2/credentials v1.18.4/go.mod h1:nwg78FjH2qvsRM1EVZlX9WuGUJOL5od+0qvm0adEzHk= -github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.3 h1:GicIdnekoJsjq9wqnvyi2elW6CGMSYKhdozE7/Svh78= -github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.3/go.mod h1:R7BIi6WNC5mc1kfRM7XM/VHC3uRWkjc396sfabq4iOo= -github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.3 h1:o9RnO+YZ4X+kt5Z7Nvcishlz0nksIt2PIzDglLMP0vA= -github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.3/go.mod h1:+6aLJzOG1fvMOyzIySYjOFjcguGvVRL68R+uoRencN4= -github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.3 h1:joyyUFhiTQQmVK6ImzNU9TQSNRNeD9kOklqTzyk5v6s= -github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.3/go.mod h1:+vNIyZQP3b3B1tSLI0lxvrU9cfM7gpdRXMFfm67ZcPc= +github.com/aws/aws-sdk-go-v2 v1.39.2 h1:EJLg8IdbzgeD7xgvZ+I8M1e0fL0ptn/M47lianzth0I= +github.com/aws/aws-sdk-go-v2 v1.39.2/go.mod h1:sDioUELIUO9Znk23YVmIk86/9DOpkbyyVb1i/gUNFXY= +github.com/aws/aws-sdk-go-v2/config v1.31.12 h1:pYM1Qgy0dKZLHX2cXslNacbcEFMkDMl+Bcj5ROuS6p8= +github.com/aws/aws-sdk-go-v2/config v1.31.12/go.mod h1:/MM0dyD7KSDPR+39p9ZNVKaHDLb9qnfDurvVS2KAhN8= +github.com/aws/aws-sdk-go-v2/credentials v1.18.16 h1:4JHirI4zp958zC026Sm+V4pSDwW4pwLefKrc0bF2lwI= +github.com/aws/aws-sdk-go-v2/credentials v1.18.16/go.mod h1:qQMtGx9OSw7ty1yLclzLxXCRbrkjWAM7JnObZjmCB7I= +github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.9 h1:Mv4Bc0mWmv6oDuSWTKnk+wgeqPL5DRFu5bQL9BGPQ8Y= +github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.9/go.mod h1:IKlKfRppK2a1y0gy1yH6zD+yX5uplJ6UuPlgd48dJiQ= +github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.9 h1:se2vOWGD3dWQUtfn4wEjRQJb1HK1XsNIt825gskZ970= +github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.9/go.mod h1:hijCGH2VfbZQxqCDN7bwz/4dzxV+hkyhjawAtdPWKZA= +github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.9 h1:6RBnKZLkJM4hQ+kN6E7yWFveOTg8NLPHAkqrs4ZPlTU= +github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.9/go.mod h1:V9rQKRmK7AWuEsOMnHzKj8WyrIir1yUJbZxDuZLFvXI= github.com/aws/aws-sdk-go-v2/internal/ini v1.8.3 h1:bIqFDwgGXXN1Kpp99pDOdKMTTb5d2KyU5X/BZxjOkRo= github.com/aws/aws-sdk-go-v2/internal/ini v1.8.3/go.mod h1:H5O/EsxDWyU+LP/V8i5sm8cxoZgc2fdNR9bxlOFrQTo= -github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.0 h1:6+lZi2JeGKtCraAj1rpoZfKqnQ9SptseRZioejfUOLM= -github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.0/go.mod h1:eb3gfbVIxIoGgJsi9pGne19dhCBpK6opTYpQqAmdy44= -github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.3 h1:ieRzyHXypu5ByllM7Sp4hC5f/1Fy5wqxqY0yB85hC7s= -github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.3/go.mod h1:O5ROz8jHiOAKAwx179v+7sHMhfobFVi6nZt8DEyiYoM= -github.com/aws/aws-sdk-go-v2/service/kms v1.44.0 h1:Z95XCqqSnwXr0AY7PgsiOUBhUG2GoDM5getw6RfD1Lg= -github.com/aws/aws-sdk-go-v2/service/kms v1.44.0/go.mod h1:DqcSngL7jJeU1fOzh5Ll5rSvX/MlMV6OZlE4mVdFAQc= -github.com/aws/aws-sdk-go-v2/service/sso v1.28.0 h1:Mc/MKBf2m4VynyJkABoVEN+QzkfLqGj0aiJuEe7cMeM= -github.com/aws/aws-sdk-go-v2/service/sso v1.28.0/go.mod h1:iS5OmxEcN4QIPXARGhavH7S8kETNL11kym6jhoS7IUQ= -github.com/aws/aws-sdk-go-v2/service/ssooidc v1.33.0 h1:6csaS/aJmqZQbKhi1EyEMM7yBW653Wy/B9hnBofW+sw= -github.com/aws/aws-sdk-go-v2/service/ssooidc v1.33.0/go.mod h1:59qHWaY5B+Rs7HGTuVGaC32m0rdpQ68N8QCN3khYiqs= -github.com/aws/aws-sdk-go-v2/service/sts v1.37.0 h1:MG9VFW43M4A8BYeAfaJJZWrroinxeTi2r3+SnmLQfSA= -github.com/aws/aws-sdk-go-v2/service/sts v1.37.0/go.mod h1:JdeBDPgpJfuS6rU/hNglmOigKhyEZtBmbraLE4GK1J8= -github.com/aws/smithy-go v1.22.5 h1:P9ATCXPMb2mPjYBgueqJNCA5S9UfktsW0tTxi+a7eqw= -github.com/aws/smithy-go v1.22.5/go.mod h1:t1ufH5HMublsJYulve2RKmHDC15xu1f26kHCp/HgceI= +github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.1 h1:oegbebPEMA/1Jny7kvwejowCaHz1FWZAQ94WXFNCyTM= +github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.1/go.mod h1:kemo5Myr9ac0U9JfSjMo9yHLtw+pECEHsFtJ9tqCEI8= +github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.9 h1:5r34CgVOD4WZudeEKZ9/iKpiT6cM1JyEROpXjOcdWv8= +github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.9/go.mod h1:dB12CEbNWPbzO2uC6QSWHteqOg4JfBVJOojbAoAUb5I= +github.com/aws/aws-sdk-go-v2/service/kms v1.45.6 h1:Br3kil4j7RPW+7LoLVkYt8SuhIWlg6ylmbmzXJ7PgXY= +github.com/aws/aws-sdk-go-v2/service/kms v1.45.6/go.mod h1:FKXkHzw1fJZtg1P1qoAIiwen5thz/cDRTTDCIu8ljxc= +github.com/aws/aws-sdk-go-v2/service/sso v1.29.6 h1:A1oRkiSQOWstGh61y4Wc/yQ04sqrQZr1Si/oAXj20/s= +github.com/aws/aws-sdk-go-v2/service/sso v1.29.6/go.mod h1:5PfYspyCU5Vw1wNPsxi15LZovOnULudOQuVxphSflQA= +github.com/aws/aws-sdk-go-v2/service/ssooidc v1.35.1 h1:5fm5RTONng73/QA73LhCNR7UT9RpFH3hR6HWL6bIgVY= +github.com/aws/aws-sdk-go-v2/service/ssooidc v1.35.1/go.mod h1:xBEjWD13h+6nq+z4AkqSfSvqRKFgDIQeaMguAJndOWo= +github.com/aws/aws-sdk-go-v2/service/sts v1.38.6 h1:p3jIvqYwUZgu/XYeI48bJxOhvm47hZb5HUQ0tn6Q9kA= +github.com/aws/aws-sdk-go-v2/service/sts v1.38.6/go.mod h1:WtKK+ppze5yKPkZ0XwqIVWD4beCwv056ZbPQNoeHqM8= +github.com/aws/smithy-go v1.23.0 h1:8n6I3gXzWJB2DxBDnfxgBaSX6oe0d/t10qGz7OKqMCE= +github.com/aws/smithy-go v1.23.0/go.mod h1:t1ufH5HMublsJYulve2RKmHDC15xu1f26kHCp/HgceI= github.com/beorn7/perks v0.0.0-20180321164747-3a771d992973/go.mod h1:Dwedo/Wpr24TaqPxmxbtue+5NUziq4I4S80YR8gNf3Q= github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM= github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw= @@ -184,10 +184,10 @@ github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU= github.com/google/go-github v17.0.0+incompatible/go.mod h1:zLgOLi98H3fifZn+44m+umXrS52loVEgC2AApnigrVQ= github.com/google/go-querystring v1.0.0/go.mod h1:odCYkC5MyYFN7vkCjXpyrEuKhc/BUO6wN/zVPAxq5ck= -github.com/google/go-tpm v0.9.5 h1:ocUmnDebX54dnW+MQWGQRbdaAcJELsa6PqZhJ48KwVU= -github.com/google/go-tpm v0.9.5/go.mod h1:h9jEsEECg7gtLis0upRBQU+GhYVH6jMjrFxI8u6bVUY= -github.com/google/go-tpm-tools v0.4.5 h1:3fhthtyMDbIZFR5/0y1hvUoZ1Kf4i1eZ7C73R4Pvd+k= -github.com/google/go-tpm-tools v0.4.5/go.mod h1:ktjTNq8yZFD6TzdBFefUfen96rF3NpYwpSb2d8bc+Y8= +github.com/google/go-tpm v0.9.6 h1:Ku42PT4LmjDu1H5C5ISWLlpI1mj+Zq7sPGKoRw2XROA= +github.com/google/go-tpm v0.9.6/go.mod h1:h9jEsEECg7gtLis0upRBQU+GhYVH6jMjrFxI8u6bVUY= +github.com/google/go-tpm-tools v0.4.6 h1:hwIwPG7w4z5eQEBq11gYw8YYr9xXLfBQ/0JsKyq5AJM= +github.com/google/go-tpm-tools v0.4.6/go.mod h1:MsVQbJnRhKDfWwf5zgr3cDGpj13P1uLAFF0wMEP/n5w= github.com/google/go-tspi v0.3.0 h1:ADtq8RKfP+jrTyIWIZDIYcKOMecRqNJFOew2IT0Inus= github.com/google/go-tspi v0.3.0/go.mod h1:xfMGI3G0PhxCdNVcYr1C4C+EizojDg/TXuX5by8CiHI= github.com/google/martian v2.1.0+incompatible/go.mod h1:9I4somxYTbIHy5NJKHRl3wXiIaQGbYVAs8BPL6v8lEs= @@ -435,8 +435,8 @@ go.opentelemetry.io/otel/sdk/metric v1.38.0 h1:aSH66iL0aZqo//xXzQLYozmWrXxyFkBJ6 go.opentelemetry.io/otel/sdk/metric v1.38.0/go.mod h1:dg9PBnW9XdQ1Hd6ZnRz689CbtrUp0wMMs9iPcgT9EZA= go.opentelemetry.io/otel/trace v1.38.0 h1:Fxk5bKrDZJUH+AMyyIXGcFAPah0oRcT+LuNtJrmcNLE= go.opentelemetry.io/otel/trace v1.38.0/go.mod h1:j1P9ivuFsTceSWe1oY+EeW3sc+Pp42sO++GHkg4wwhs= -go.step.sm/crypto v0.70.0 h1:Q9Ft7N637mucyZcHZd1+0VVQJVwDCKqcb9CYcYi7cds= -go.step.sm/crypto v0.70.0/go.mod h1:pzfUhS5/ue7ev64PLlEgXvhx1opwbhFCjkvlhsxVds0= +go.step.sm/crypto v0.71.0 h1:rAvlQMckgRXjgc8QuwxrbExW/jiX+57mHgX0ka5IGrw= +go.step.sm/crypto v0.71.0/go.mod h1:pCXPI1/bChFwKY3eU8V29P3d9nwQfqraF/I9f74mNU8= go.uber.org/automaxprocs v1.6.0 h1:O3y2/QNTOdbF+e/dpXNNW7Rx2hZ4sTIPyybbxyNqTUs= go.uber.org/automaxprocs v1.6.0/go.mod h1:ifeIMSnPZuznNm6jmdzmU3/bfk01Fe2fotchwEFJ8r8= go.uber.org/goleak v1.3.0 h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto= diff --git a/handler.go b/handler.go index badbd17..1728a80 100644 --- a/handler.go +++ b/handler.go @@ -230,79 +230,9 @@ func (m *Middleware) handlePhase(w http.ResponseWriter, r *http.Request, phase i zap.String("user_agent", r.UserAgent()), ) - if phase == 1 && m.CountryBlacklist.Enabled { - m.logger.Debug("Starting country blacklisting phase") - blocked, err := m.isCountryInList(r.RemoteAddr, m.CountryBlacklist.CountryList, m.CountryBlacklist.geoIP) - if err != nil { - m.logRequest(zapcore.ErrorLevel, "Failed to check country blacklisting", - r, - zap.Error(err), - ) - m.blockRequest(w, r, state, http.StatusForbidden, "internal_error", "country_block_rule", r.RemoteAddr, - zap.String("message", "Request blocked due to internal error"), - ) - m.logger.Debug("Country blacklisting phase completed - blocked due to error") - m.incrementGeoIPRequestsMetric(false) // Increment with false for error - return - } else if blocked { - - m.blockRequest(w, r, state, http.StatusForbidden, "country_block", "country_block_rule", r.RemoteAddr, - zap.String("message", "Request blocked by country")) - m.incrementGeoIPRequestsMetric(true) // Increment with true for blocked - if m.CustomResponses != nil { - m.writeCustomResponse(w, state.StatusCode) - } - return - } - m.logger.Debug("Country blacklisting phase completed - not blocked") - m.incrementGeoIPRequestsMetric(false) // Increment with false for no block - } - - if phase == 1 && m.CountryWhitelist.Enabled { - m.logger.Debug("Starting country whitelisting phase") - allowed, err := m.isCountryInList(r.RemoteAddr, m.CountryWhitelist.CountryList, m.CountryWhitelist.geoIP) - if err != nil { - m.logRequest(zapcore.ErrorLevel, "Failed to check country whitelist", - r, - zap.Error(err), - ) - m.blockRequest(w, r, state, http.StatusForbidden, "internal_error", "country_block_rule", r.RemoteAddr, - zap.String("message", "Request blocked due to internal error"), - ) - m.logger.Debug("Country whitelisting phase completed - blocked due to error") - m.incrementGeoIPRequestsMetric(false) // Increment with false for error - return - } else if !allowed { - m.blockRequest(w, r, state, http.StatusForbidden, "country_block", "country_block_rule", r.RemoteAddr, - zap.String("message", "Request blocked by country")) - m.incrementGeoIPRequestsMetric(true) // Increment with true for blocked - if m.CustomResponses != nil { - m.writeCustomResponse(w, state.StatusCode) - } - return - } - m.logger.Debug("Country whitelisting phase completed - not blocked") - m.incrementGeoIPRequestsMetric(false) // Increment with false for no block - } - - if phase == 1 && m.rateLimiter != nil { - m.logger.Debug("Starting rate limiting phase") - ip := extractIP(r.RemoteAddr, m.logger) // Pass the logger here - path := r.URL.Path // Get the request path - if m.rateLimiter.isRateLimited(ip, path) { - m.incrementRateLimiterBlockedRequestsMetric() // Increment the counter in the Middleware - m.blockRequest(w, r, state, http.StatusTooManyRequests, "rate_limit", "rate_limit_rule", r.RemoteAddr, - zap.String("message", "Request blocked by rate limit"), - ) - if m.CustomResponses != nil { - m.writeCustomResponse(w, state.StatusCode) - } - return - } - m.logger.Debug("Rate limiting phase completed - not blocked") - } - if phase == 1 { + + // IP blacklisting - the highest priority m.logger.Debug("Checking for IP blacklisting", zap.String("remote_addr", r.RemoteAddr)) // Added log for checking before to isIPBlacklisted call xForwardedFor := r.Header.Get("X-Forwarded-For") if xForwardedFor != "" { @@ -337,18 +267,94 @@ func (m *Middleware) handlePhase(w http.ResponseWriter, r *http.Request, phase i return } } - } - if phase == 1 && m.isDNSBlacklisted(r.Host) { - m.logger.Debug("Starting DNS blacklist phase") - m.blockRequest(w, r, state, http.StatusForbidden, "dns_blacklist", "dns_blacklist_rule", r.Host, - zap.String("message", "Request blocked by DNS blacklist"), - zap.String("host", r.Host), - ) - if m.CustomResponses != nil { - m.writeCustomResponse(w, state.StatusCode) + // DNS blacklisting + if m.isDNSBlacklisted(r.Host) { + m.logger.Debug("Starting DNS blacklist phase") + m.blockRequest(w, r, state, http.StatusForbidden, "dns_blacklist", "dns_blacklist_rule", r.Host, + zap.String("message", "Request blocked by DNS blacklist"), + zap.String("host", r.Host), + ) + if m.CustomResponses != nil { + m.writeCustomResponse(w, state.StatusCode) + } + return + } + + // Rate limiting + if m.rateLimiter != nil { + m.logger.Debug("Starting rate limiting phase") + ip := extractIP(r.RemoteAddr, m.logger) // Pass the logger here + path := r.URL.Path // Get the request path + if m.rateLimiter.isRateLimited(ip, path) { + m.incrementRateLimiterBlockedRequestsMetric() // Increment the counter in the Middleware + m.blockRequest(w, r, state, http.StatusTooManyRequests, "rate_limit", "rate_limit_rule", r.RemoteAddr, + zap.String("message", "Request blocked by rate limit"), + ) + if m.CustomResponses != nil { + m.writeCustomResponse(w, state.StatusCode) + } + return + } + m.logger.Debug("Rate limiting phase completed - not blocked") + } + + // Whitelisting + if m.CountryWhitelist.Enabled { + m.logger.Debug("Starting country whitelisting phase") + allowed, err := m.isCountryInList(r.RemoteAddr, m.CountryWhitelist.CountryList, m.CountryWhitelist.geoIP) + if err != nil { + m.logRequest(zapcore.ErrorLevel, "Failed to check country whitelist", + r, + zap.Error(err), + ) + m.blockRequest(w, r, state, http.StatusForbidden, "internal_error", "country_block_rule", r.RemoteAddr, + zap.String("message", "Request blocked due to internal error"), + ) + m.logger.Debug("Country whitelisting phase completed - blocked due to error") + m.incrementGeoIPRequestsMetric(false) // Increment with false for error + return + } else if !allowed { + m.blockRequest(w, r, state, http.StatusForbidden, "country_block", "country_block_rule", r.RemoteAddr, + zap.String("message", "Request blocked by country")) + m.incrementGeoIPRequestsMetric(true) // Increment with true for blocked + if m.CustomResponses != nil { + m.writeCustomResponse(w, state.StatusCode) + } + return + } + m.logger.Debug("Country whitelisting phase completed - not blocked") + m.incrementGeoIPRequestsMetric(false) // Increment with false for no block + } + + // Blacklisting + if m.CountryBlacklist.Enabled { + m.logger.Debug("Starting country blacklisting phase") + blocked, err := m.isCountryInList(r.RemoteAddr, m.CountryBlacklist.CountryList, m.CountryBlacklist.geoIP) + if err != nil { + m.logRequest(zapcore.ErrorLevel, "Failed to check country blacklisting", + r, + zap.Error(err), + ) + m.blockRequest(w, r, state, http.StatusForbidden, "internal_error", "country_block_rule", r.RemoteAddr, + zap.String("message", "Request blocked due to internal error"), + ) + m.logger.Debug("Country blacklisting phase completed - blocked due to error") + m.incrementGeoIPRequestsMetric(false) // Increment with false for error + return + } else if blocked { + + m.blockRequest(w, r, state, http.StatusForbidden, "country_block", "country_block_rule", r.RemoteAddr, + zap.String("message", "Request blocked by country")) + m.incrementGeoIPRequestsMetric(true) // Increment with true for blocked + if m.CustomResponses != nil { + m.writeCustomResponse(w, state.StatusCode) + } + return + } + m.logger.Debug("Country blacklisting phase completed - not blocked") + m.incrementGeoIPRequestsMetric(false) // Increment with false for no block } - return } rules, ok := m.Rules[phase] diff --git a/handler_test.go b/handler_test.go index f0d1178..b6c41d0 100644 --- a/handler_test.go +++ b/handler_test.go @@ -67,7 +67,7 @@ func TestBlockedRequestPhase1_GeoIPBlocking(t *testing.T) { geoIPBlock, err := geoIPHandler.LoadGeoIPDatabase(geoIPdata) assert.NoError(t, err) - blackListMiddleware := &Middleware{ + blMiddleware := &Middleware{ logger: logger, ipBlacklist: iptrie.NewTrie(), geoIPHandler: geoIPHandler, @@ -80,7 +80,7 @@ func TestBlockedRequestPhase1_GeoIPBlocking(t *testing.T) { CustomResponses: customResponse, } - whiteListMiddleware := &Middleware{ + wlMiddleware := &Middleware{ logger: logger, ipBlacklist: iptrie.NewTrie(), geoIPHandler: geoIPHandler, @@ -93,6 +93,25 @@ func TestBlockedRequestPhase1_GeoIPBlocking(t *testing.T) { CustomResponses: customResponse, } + blackWhiteMw := &Middleware{ + logger: logger, + ipBlacklist: iptrie.NewTrie(), + geoIPHandler: geoIPHandler, + CountryWhitelist: CountryAccessFilter{ + Enabled: true, + CountryList: []string{"BR"}, + GeoIPDBPath: geoIPdata, // Path to a test GeoIP database + geoIP: geoIPBlock, + }, + CountryBlacklist: CountryAccessFilter{ + Enabled: true, + CountryList: []string{"US", "RU"}, + GeoIPDBPath: geoIPdata, // Path to a test GeoIP database + geoIP: geoIPBlock, + }, + CustomResponses: customResponse, + } + req := httptest.NewRequest("GET", testURL, nil) state := &WAFState{} @@ -102,7 +121,7 @@ func TestBlockedRequestPhase1_GeoIPBlocking(t *testing.T) { req.RemoteAddr = aliCNIP // Process the request in Phase 1 - blackListMiddleware.handlePhase(w, req, 1, state) + blMiddleware.handlePhase(w, req, 1, state) assert.False(t, state.Blocked, "Request should be allowed") assert.Equal(t, http.StatusOK, w.Code, "Expected status code 200") }) @@ -112,7 +131,7 @@ func TestBlockedRequestPhase1_GeoIPBlocking(t *testing.T) { req.RemoteAddr = googleUSIP // Process the request in Phase 1 - blackListMiddleware.handlePhase(w, req, 1, state) + blMiddleware.handlePhase(w, req, 1, state) // Verify that the request was blocked assert.True(t, state.Blocked, "Request should be blocked") @@ -125,7 +144,7 @@ func TestBlockedRequestPhase1_GeoIPBlocking(t *testing.T) { req.RemoteAddr = googleBRIP // Process the request in Phase 1 - whiteListMiddleware.handlePhase(w, req, 1, state) + wlMiddleware.handlePhase(w, req, 1, state) assert.False(t, state.Blocked, "Request should be allowed") assert.Equal(t, http.StatusOK, w.Code, "Expected status code 200") }) @@ -135,13 +154,36 @@ func TestBlockedRequestPhase1_GeoIPBlocking(t *testing.T) { req.RemoteAddr = googleRUIP // Process the request in Phase 1 - whiteListMiddleware.handlePhase(w, req, 1, state) + wlMiddleware.handlePhase(w, req, 1, state) // Verify that the request was blocked assert.True(t, state.Blocked, "Request should be blocked") assert.Equal(t, http.StatusForbidden, w.Code, "Expected status code 403") assert.Contains(t, w.Body.String(), "Access Denied", "Response body should contain 'Access Denied'") }) + + t.Run("GeoIP whitelist and blacklist: whitelist has the priority", func(t *testing.T) { + + w := httptest.NewRecorder() + + // BR should be allowed + req0 := httptest.NewRequest("GET", testURL, nil) + req0.RemoteAddr = googleBRIP + + blackWhiteMw.handlePhase(w, req0, 1, state) + assert.False(t, state.Blocked, "Request should be allowed") + assert.Equal(t, http.StatusOK, w.Code, "Expected status code 200") + + // US must be blocked + req1 := httptest.NewRequest("GET", testURL, nil) + req1.RemoteAddr = googleUSIP + + blackWhiteMw.handlePhase(w, req1, 1, state) + // Verify that the request was blocked + assert.True(t, state.Blocked, "Request should be blocked") + assert.Equal(t, http.StatusForbidden, w.Code, "Expected status code 403") + assert.Contains(t, w.Body.String(), "Access Denied", "Response body should contain 'Access Denied'") + }) } func TestBlockedRequestPhase1_IPBlocking(t *testing.T) {