mirror of
https://github.com/fabriziosalmi/caddy-waf.git
synced 2025-12-23 22:27:46 -05:00
refactor: update priorities for block/allow actions
This commit is contained in:
2
.github/workflows/test.yml
vendored
2
.github/workflows/test.yml
vendored
@@ -25,7 +25,7 @@ jobs:
|
||||
os:
|
||||
- ubuntu-latest
|
||||
go:
|
||||
- '1.25.2'
|
||||
- '1.25'
|
||||
|
||||
runs-on: ${{ matrix.os }}
|
||||
permissions:
|
||||
|
||||
5
go.mod
5
go.mod
@@ -2,7 +2,7 @@ module github.com/fabriziosalmi/caddy-waf
|
||||
|
||||
go 1.25
|
||||
|
||||
toolchain go1.25.2
|
||||
toolchain go1.25.3
|
||||
|
||||
require (
|
||||
github.com/caddyserver/caddy/v2 v2.10.2
|
||||
@@ -111,9 +111,8 @@ require (
|
||||
go.opentelemetry.io/otel v1.38.0 // indirect
|
||||
go.opentelemetry.io/otel/metric v1.38.0 // indirect
|
||||
go.opentelemetry.io/otel/trace v1.38.0 // indirect
|
||||
go.step.sm/crypto v0.70.0 // indirect
|
||||
go.step.sm/crypto v0.71.0 // indirect
|
||||
go.uber.org/automaxprocs v1.6.0 // indirect
|
||||
go.uber.org/mock v0.6.0 // indirect
|
||||
go.uber.org/multierr v1.11.0 // indirect
|
||||
go.uber.org/zap/exp v0.3.0 // indirect
|
||||
go.yaml.in/yaml/v2 v2.4.3 // indirect
|
||||
|
||||
68
go.sum
68
go.sum
@@ -14,8 +14,8 @@ cloud.google.com/go/compute/metadata v0.9.0 h1:pDUj4QMoPejqq20dK0Pg2N4yG9zIkYGdB
|
||||
cloud.google.com/go/compute/metadata v0.9.0/go.mod h1:E0bWwX5wTnLPedCKqk3pJmVgCBSM6qQI1yTBdEb3C10=
|
||||
cloud.google.com/go/iam v1.5.2 h1:qgFRAGEmd8z6dJ/qyEchAuL9jpswyODjA2lS+w234g8=
|
||||
cloud.google.com/go/iam v1.5.2/go.mod h1:SE1vg0N81zQqLzQEwxL2WI6yhetBdbNQuTvIKCSkUHE=
|
||||
cloud.google.com/go/kms v1.22.0 h1:dBRIj7+GDeeEvatJeTB19oYZNV0aj6wEqSIT/7gLqtk=
|
||||
cloud.google.com/go/kms v1.22.0/go.mod h1:U7mf8Sva5jpOb4bxYZdtw/9zsbIjrklYwPcvMk34AL8=
|
||||
cloud.google.com/go/kms v1.23.0 h1:WaqAZsUptyHwOo9II8rFC1Kd2I+yvNsNP2IJ14H2sUw=
|
||||
cloud.google.com/go/kms v1.23.0/go.mod h1:rZ5kK0I7Kn9W4erhYVoIRPtpizjunlrfU4fUkumUp8g=
|
||||
cloud.google.com/go/longrunning v0.6.7 h1:IGtfDWHhQCgCjwQjV9iiLnUta9LBCo8R9QmAFsS/PrE=
|
||||
cloud.google.com/go/longrunning v0.6.7/go.mod h1:EAFV3IZAKmM56TyiE6VAP3VoTzhZzySwI/YI1s/nRsY=
|
||||
dario.cat/mergo v1.0.2 h1:85+piFYR1tMbRrLcDwR18y4UKJ3aH1Tbzi24VRW1TK8=
|
||||
@@ -50,34 +50,34 @@ github.com/antlr4-go/antlr/v4 v4.13.1/go.mod h1:GKmUxMtwp6ZgGwZSva4eWPC5mS6vUAmO
|
||||
github.com/armon/consul-api v0.0.0-20180202201655-eb2c6b5be1b6/go.mod h1:grANhF5doyWs3UAsr3K4I6qtAmlQcZDesFNEHPZAzj8=
|
||||
github.com/aryann/difflib v0.0.0-20210328193216-ff5ff6dc229b h1:uUXgbcPDK3KpW29o4iy7GtuappbWT0l5NaMo9H9pJDw=
|
||||
github.com/aryann/difflib v0.0.0-20210328193216-ff5ff6dc229b/go.mod h1:DAHtR1m6lCRdSC2Tm3DSWRPvIPr6xNKyeHdqDQSQT+A=
|
||||
github.com/aws/aws-sdk-go-v2 v1.38.0 h1:UCRQ5mlqcFk9HJDIqENSLR3wiG1VTWlyUfLDEvY7RxU=
|
||||
github.com/aws/aws-sdk-go-v2 v1.38.0/go.mod h1:9Q0OoGQoboYIAJyslFyF1f5K1Ryddop8gqMhWx/n4Wg=
|
||||
github.com/aws/aws-sdk-go-v2/config v1.31.0 h1:9yH0xiY5fUnVNLRWO0AtayqwU1ndriZdN78LlhruJR4=
|
||||
github.com/aws/aws-sdk-go-v2/config v1.31.0/go.mod h1:VeV3K72nXnhbe4EuxxhzsDc/ByrCSlZwUnWH52Nde/I=
|
||||
github.com/aws/aws-sdk-go-v2/credentials v1.18.4 h1:IPd0Algf1b+Qy9BcDp0sCUcIWdCQPSzDoMK3a8pcbUM=
|
||||
github.com/aws/aws-sdk-go-v2/credentials v1.18.4/go.mod h1:nwg78FjH2qvsRM1EVZlX9WuGUJOL5od+0qvm0adEzHk=
|
||||
github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.3 h1:GicIdnekoJsjq9wqnvyi2elW6CGMSYKhdozE7/Svh78=
|
||||
github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.3/go.mod h1:R7BIi6WNC5mc1kfRM7XM/VHC3uRWkjc396sfabq4iOo=
|
||||
github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.3 h1:o9RnO+YZ4X+kt5Z7Nvcishlz0nksIt2PIzDglLMP0vA=
|
||||
github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.3/go.mod h1:+6aLJzOG1fvMOyzIySYjOFjcguGvVRL68R+uoRencN4=
|
||||
github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.3 h1:joyyUFhiTQQmVK6ImzNU9TQSNRNeD9kOklqTzyk5v6s=
|
||||
github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.3/go.mod h1:+vNIyZQP3b3B1tSLI0lxvrU9cfM7gpdRXMFfm67ZcPc=
|
||||
github.com/aws/aws-sdk-go-v2 v1.39.2 h1:EJLg8IdbzgeD7xgvZ+I8M1e0fL0ptn/M47lianzth0I=
|
||||
github.com/aws/aws-sdk-go-v2 v1.39.2/go.mod h1:sDioUELIUO9Znk23YVmIk86/9DOpkbyyVb1i/gUNFXY=
|
||||
github.com/aws/aws-sdk-go-v2/config v1.31.12 h1:pYM1Qgy0dKZLHX2cXslNacbcEFMkDMl+Bcj5ROuS6p8=
|
||||
github.com/aws/aws-sdk-go-v2/config v1.31.12/go.mod h1:/MM0dyD7KSDPR+39p9ZNVKaHDLb9qnfDurvVS2KAhN8=
|
||||
github.com/aws/aws-sdk-go-v2/credentials v1.18.16 h1:4JHirI4zp958zC026Sm+V4pSDwW4pwLefKrc0bF2lwI=
|
||||
github.com/aws/aws-sdk-go-v2/credentials v1.18.16/go.mod h1:qQMtGx9OSw7ty1yLclzLxXCRbrkjWAM7JnObZjmCB7I=
|
||||
github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.9 h1:Mv4Bc0mWmv6oDuSWTKnk+wgeqPL5DRFu5bQL9BGPQ8Y=
|
||||
github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.9/go.mod h1:IKlKfRppK2a1y0gy1yH6zD+yX5uplJ6UuPlgd48dJiQ=
|
||||
github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.9 h1:se2vOWGD3dWQUtfn4wEjRQJb1HK1XsNIt825gskZ970=
|
||||
github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.9/go.mod h1:hijCGH2VfbZQxqCDN7bwz/4dzxV+hkyhjawAtdPWKZA=
|
||||
github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.9 h1:6RBnKZLkJM4hQ+kN6E7yWFveOTg8NLPHAkqrs4ZPlTU=
|
||||
github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.9/go.mod h1:V9rQKRmK7AWuEsOMnHzKj8WyrIir1yUJbZxDuZLFvXI=
|
||||
github.com/aws/aws-sdk-go-v2/internal/ini v1.8.3 h1:bIqFDwgGXXN1Kpp99pDOdKMTTb5d2KyU5X/BZxjOkRo=
|
||||
github.com/aws/aws-sdk-go-v2/internal/ini v1.8.3/go.mod h1:H5O/EsxDWyU+LP/V8i5sm8cxoZgc2fdNR9bxlOFrQTo=
|
||||
github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.0 h1:6+lZi2JeGKtCraAj1rpoZfKqnQ9SptseRZioejfUOLM=
|
||||
github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.0/go.mod h1:eb3gfbVIxIoGgJsi9pGne19dhCBpK6opTYpQqAmdy44=
|
||||
github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.3 h1:ieRzyHXypu5ByllM7Sp4hC5f/1Fy5wqxqY0yB85hC7s=
|
||||
github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.3/go.mod h1:O5ROz8jHiOAKAwx179v+7sHMhfobFVi6nZt8DEyiYoM=
|
||||
github.com/aws/aws-sdk-go-v2/service/kms v1.44.0 h1:Z95XCqqSnwXr0AY7PgsiOUBhUG2GoDM5getw6RfD1Lg=
|
||||
github.com/aws/aws-sdk-go-v2/service/kms v1.44.0/go.mod h1:DqcSngL7jJeU1fOzh5Ll5rSvX/MlMV6OZlE4mVdFAQc=
|
||||
github.com/aws/aws-sdk-go-v2/service/sso v1.28.0 h1:Mc/MKBf2m4VynyJkABoVEN+QzkfLqGj0aiJuEe7cMeM=
|
||||
github.com/aws/aws-sdk-go-v2/service/sso v1.28.0/go.mod h1:iS5OmxEcN4QIPXARGhavH7S8kETNL11kym6jhoS7IUQ=
|
||||
github.com/aws/aws-sdk-go-v2/service/ssooidc v1.33.0 h1:6csaS/aJmqZQbKhi1EyEMM7yBW653Wy/B9hnBofW+sw=
|
||||
github.com/aws/aws-sdk-go-v2/service/ssooidc v1.33.0/go.mod h1:59qHWaY5B+Rs7HGTuVGaC32m0rdpQ68N8QCN3khYiqs=
|
||||
github.com/aws/aws-sdk-go-v2/service/sts v1.37.0 h1:MG9VFW43M4A8BYeAfaJJZWrroinxeTi2r3+SnmLQfSA=
|
||||
github.com/aws/aws-sdk-go-v2/service/sts v1.37.0/go.mod h1:JdeBDPgpJfuS6rU/hNglmOigKhyEZtBmbraLE4GK1J8=
|
||||
github.com/aws/smithy-go v1.22.5 h1:P9ATCXPMb2mPjYBgueqJNCA5S9UfktsW0tTxi+a7eqw=
|
||||
github.com/aws/smithy-go v1.22.5/go.mod h1:t1ufH5HMublsJYulve2RKmHDC15xu1f26kHCp/HgceI=
|
||||
github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.1 h1:oegbebPEMA/1Jny7kvwejowCaHz1FWZAQ94WXFNCyTM=
|
||||
github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.1/go.mod h1:kemo5Myr9ac0U9JfSjMo9yHLtw+pECEHsFtJ9tqCEI8=
|
||||
github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.9 h1:5r34CgVOD4WZudeEKZ9/iKpiT6cM1JyEROpXjOcdWv8=
|
||||
github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.9/go.mod h1:dB12CEbNWPbzO2uC6QSWHteqOg4JfBVJOojbAoAUb5I=
|
||||
github.com/aws/aws-sdk-go-v2/service/kms v1.45.6 h1:Br3kil4j7RPW+7LoLVkYt8SuhIWlg6ylmbmzXJ7PgXY=
|
||||
github.com/aws/aws-sdk-go-v2/service/kms v1.45.6/go.mod h1:FKXkHzw1fJZtg1P1qoAIiwen5thz/cDRTTDCIu8ljxc=
|
||||
github.com/aws/aws-sdk-go-v2/service/sso v1.29.6 h1:A1oRkiSQOWstGh61y4Wc/yQ04sqrQZr1Si/oAXj20/s=
|
||||
github.com/aws/aws-sdk-go-v2/service/sso v1.29.6/go.mod h1:5PfYspyCU5Vw1wNPsxi15LZovOnULudOQuVxphSflQA=
|
||||
github.com/aws/aws-sdk-go-v2/service/ssooidc v1.35.1 h1:5fm5RTONng73/QA73LhCNR7UT9RpFH3hR6HWL6bIgVY=
|
||||
github.com/aws/aws-sdk-go-v2/service/ssooidc v1.35.1/go.mod h1:xBEjWD13h+6nq+z4AkqSfSvqRKFgDIQeaMguAJndOWo=
|
||||
github.com/aws/aws-sdk-go-v2/service/sts v1.38.6 h1:p3jIvqYwUZgu/XYeI48bJxOhvm47hZb5HUQ0tn6Q9kA=
|
||||
github.com/aws/aws-sdk-go-v2/service/sts v1.38.6/go.mod h1:WtKK+ppze5yKPkZ0XwqIVWD4beCwv056ZbPQNoeHqM8=
|
||||
github.com/aws/smithy-go v1.23.0 h1:8n6I3gXzWJB2DxBDnfxgBaSX6oe0d/t10qGz7OKqMCE=
|
||||
github.com/aws/smithy-go v1.23.0/go.mod h1:t1ufH5HMublsJYulve2RKmHDC15xu1f26kHCp/HgceI=
|
||||
github.com/beorn7/perks v0.0.0-20180321164747-3a771d992973/go.mod h1:Dwedo/Wpr24TaqPxmxbtue+5NUziq4I4S80YR8gNf3Q=
|
||||
github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM=
|
||||
github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw=
|
||||
@@ -184,10 +184,10 @@ github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8=
|
||||
github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU=
|
||||
github.com/google/go-github v17.0.0+incompatible/go.mod h1:zLgOLi98H3fifZn+44m+umXrS52loVEgC2AApnigrVQ=
|
||||
github.com/google/go-querystring v1.0.0/go.mod h1:odCYkC5MyYFN7vkCjXpyrEuKhc/BUO6wN/zVPAxq5ck=
|
||||
github.com/google/go-tpm v0.9.5 h1:ocUmnDebX54dnW+MQWGQRbdaAcJELsa6PqZhJ48KwVU=
|
||||
github.com/google/go-tpm v0.9.5/go.mod h1:h9jEsEECg7gtLis0upRBQU+GhYVH6jMjrFxI8u6bVUY=
|
||||
github.com/google/go-tpm-tools v0.4.5 h1:3fhthtyMDbIZFR5/0y1hvUoZ1Kf4i1eZ7C73R4Pvd+k=
|
||||
github.com/google/go-tpm-tools v0.4.5/go.mod h1:ktjTNq8yZFD6TzdBFefUfen96rF3NpYwpSb2d8bc+Y8=
|
||||
github.com/google/go-tpm v0.9.6 h1:Ku42PT4LmjDu1H5C5ISWLlpI1mj+Zq7sPGKoRw2XROA=
|
||||
github.com/google/go-tpm v0.9.6/go.mod h1:h9jEsEECg7gtLis0upRBQU+GhYVH6jMjrFxI8u6bVUY=
|
||||
github.com/google/go-tpm-tools v0.4.6 h1:hwIwPG7w4z5eQEBq11gYw8YYr9xXLfBQ/0JsKyq5AJM=
|
||||
github.com/google/go-tpm-tools v0.4.6/go.mod h1:MsVQbJnRhKDfWwf5zgr3cDGpj13P1uLAFF0wMEP/n5w=
|
||||
github.com/google/go-tspi v0.3.0 h1:ADtq8RKfP+jrTyIWIZDIYcKOMecRqNJFOew2IT0Inus=
|
||||
github.com/google/go-tspi v0.3.0/go.mod h1:xfMGI3G0PhxCdNVcYr1C4C+EizojDg/TXuX5by8CiHI=
|
||||
github.com/google/martian v2.1.0+incompatible/go.mod h1:9I4somxYTbIHy5NJKHRl3wXiIaQGbYVAs8BPL6v8lEs=
|
||||
@@ -435,8 +435,8 @@ go.opentelemetry.io/otel/sdk/metric v1.38.0 h1:aSH66iL0aZqo//xXzQLYozmWrXxyFkBJ6
|
||||
go.opentelemetry.io/otel/sdk/metric v1.38.0/go.mod h1:dg9PBnW9XdQ1Hd6ZnRz689CbtrUp0wMMs9iPcgT9EZA=
|
||||
go.opentelemetry.io/otel/trace v1.38.0 h1:Fxk5bKrDZJUH+AMyyIXGcFAPah0oRcT+LuNtJrmcNLE=
|
||||
go.opentelemetry.io/otel/trace v1.38.0/go.mod h1:j1P9ivuFsTceSWe1oY+EeW3sc+Pp42sO++GHkg4wwhs=
|
||||
go.step.sm/crypto v0.70.0 h1:Q9Ft7N637mucyZcHZd1+0VVQJVwDCKqcb9CYcYi7cds=
|
||||
go.step.sm/crypto v0.70.0/go.mod h1:pzfUhS5/ue7ev64PLlEgXvhx1opwbhFCjkvlhsxVds0=
|
||||
go.step.sm/crypto v0.71.0 h1:rAvlQMckgRXjgc8QuwxrbExW/jiX+57mHgX0ka5IGrw=
|
||||
go.step.sm/crypto v0.71.0/go.mod h1:pCXPI1/bChFwKY3eU8V29P3d9nwQfqraF/I9f74mNU8=
|
||||
go.uber.org/automaxprocs v1.6.0 h1:O3y2/QNTOdbF+e/dpXNNW7Rx2hZ4sTIPyybbxyNqTUs=
|
||||
go.uber.org/automaxprocs v1.6.0/go.mod h1:ifeIMSnPZuznNm6jmdzmU3/bfk01Fe2fotchwEFJ8r8=
|
||||
go.uber.org/goleak v1.3.0 h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto=
|
||||
|
||||
170
handler.go
170
handler.go
@@ -230,79 +230,9 @@ func (m *Middleware) handlePhase(w http.ResponseWriter, r *http.Request, phase i
|
||||
zap.String("user_agent", r.UserAgent()),
|
||||
)
|
||||
|
||||
if phase == 1 && m.CountryBlacklist.Enabled {
|
||||
m.logger.Debug("Starting country blacklisting phase")
|
||||
blocked, err := m.isCountryInList(r.RemoteAddr, m.CountryBlacklist.CountryList, m.CountryBlacklist.geoIP)
|
||||
if err != nil {
|
||||
m.logRequest(zapcore.ErrorLevel, "Failed to check country blacklisting",
|
||||
r,
|
||||
zap.Error(err),
|
||||
)
|
||||
m.blockRequest(w, r, state, http.StatusForbidden, "internal_error", "country_block_rule", r.RemoteAddr,
|
||||
zap.String("message", "Request blocked due to internal error"),
|
||||
)
|
||||
m.logger.Debug("Country blacklisting phase completed - blocked due to error")
|
||||
m.incrementGeoIPRequestsMetric(false) // Increment with false for error
|
||||
return
|
||||
} else if blocked {
|
||||
|
||||
m.blockRequest(w, r, state, http.StatusForbidden, "country_block", "country_block_rule", r.RemoteAddr,
|
||||
zap.String("message", "Request blocked by country"))
|
||||
m.incrementGeoIPRequestsMetric(true) // Increment with true for blocked
|
||||
if m.CustomResponses != nil {
|
||||
m.writeCustomResponse(w, state.StatusCode)
|
||||
}
|
||||
return
|
||||
}
|
||||
m.logger.Debug("Country blacklisting phase completed - not blocked")
|
||||
m.incrementGeoIPRequestsMetric(false) // Increment with false for no block
|
||||
}
|
||||
|
||||
if phase == 1 && m.CountryWhitelist.Enabled {
|
||||
m.logger.Debug("Starting country whitelisting phase")
|
||||
allowed, err := m.isCountryInList(r.RemoteAddr, m.CountryWhitelist.CountryList, m.CountryWhitelist.geoIP)
|
||||
if err != nil {
|
||||
m.logRequest(zapcore.ErrorLevel, "Failed to check country whitelist",
|
||||
r,
|
||||
zap.Error(err),
|
||||
)
|
||||
m.blockRequest(w, r, state, http.StatusForbidden, "internal_error", "country_block_rule", r.RemoteAddr,
|
||||
zap.String("message", "Request blocked due to internal error"),
|
||||
)
|
||||
m.logger.Debug("Country whitelisting phase completed - blocked due to error")
|
||||
m.incrementGeoIPRequestsMetric(false) // Increment with false for error
|
||||
return
|
||||
} else if !allowed {
|
||||
m.blockRequest(w, r, state, http.StatusForbidden, "country_block", "country_block_rule", r.RemoteAddr,
|
||||
zap.String("message", "Request blocked by country"))
|
||||
m.incrementGeoIPRequestsMetric(true) // Increment with true for blocked
|
||||
if m.CustomResponses != nil {
|
||||
m.writeCustomResponse(w, state.StatusCode)
|
||||
}
|
||||
return
|
||||
}
|
||||
m.logger.Debug("Country whitelisting phase completed - not blocked")
|
||||
m.incrementGeoIPRequestsMetric(false) // Increment with false for no block
|
||||
}
|
||||
|
||||
if phase == 1 && m.rateLimiter != nil {
|
||||
m.logger.Debug("Starting rate limiting phase")
|
||||
ip := extractIP(r.RemoteAddr, m.logger) // Pass the logger here
|
||||
path := r.URL.Path // Get the request path
|
||||
if m.rateLimiter.isRateLimited(ip, path) {
|
||||
m.incrementRateLimiterBlockedRequestsMetric() // Increment the counter in the Middleware
|
||||
m.blockRequest(w, r, state, http.StatusTooManyRequests, "rate_limit", "rate_limit_rule", r.RemoteAddr,
|
||||
zap.String("message", "Request blocked by rate limit"),
|
||||
)
|
||||
if m.CustomResponses != nil {
|
||||
m.writeCustomResponse(w, state.StatusCode)
|
||||
}
|
||||
return
|
||||
}
|
||||
m.logger.Debug("Rate limiting phase completed - not blocked")
|
||||
}
|
||||
|
||||
if phase == 1 {
|
||||
|
||||
// IP blacklisting - the highest priority
|
||||
m.logger.Debug("Checking for IP blacklisting", zap.String("remote_addr", r.RemoteAddr)) // Added log for checking before to isIPBlacklisted call
|
||||
xForwardedFor := r.Header.Get("X-Forwarded-For")
|
||||
if xForwardedFor != "" {
|
||||
@@ -337,18 +267,94 @@ func (m *Middleware) handlePhase(w http.ResponseWriter, r *http.Request, phase i
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if phase == 1 && m.isDNSBlacklisted(r.Host) {
|
||||
m.logger.Debug("Starting DNS blacklist phase")
|
||||
m.blockRequest(w, r, state, http.StatusForbidden, "dns_blacklist", "dns_blacklist_rule", r.Host,
|
||||
zap.String("message", "Request blocked by DNS blacklist"),
|
||||
zap.String("host", r.Host),
|
||||
)
|
||||
if m.CustomResponses != nil {
|
||||
m.writeCustomResponse(w, state.StatusCode)
|
||||
// DNS blacklisting
|
||||
if m.isDNSBlacklisted(r.Host) {
|
||||
m.logger.Debug("Starting DNS blacklist phase")
|
||||
m.blockRequest(w, r, state, http.StatusForbidden, "dns_blacklist", "dns_blacklist_rule", r.Host,
|
||||
zap.String("message", "Request blocked by DNS blacklist"),
|
||||
zap.String("host", r.Host),
|
||||
)
|
||||
if m.CustomResponses != nil {
|
||||
m.writeCustomResponse(w, state.StatusCode)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// Rate limiting
|
||||
if m.rateLimiter != nil {
|
||||
m.logger.Debug("Starting rate limiting phase")
|
||||
ip := extractIP(r.RemoteAddr, m.logger) // Pass the logger here
|
||||
path := r.URL.Path // Get the request path
|
||||
if m.rateLimiter.isRateLimited(ip, path) {
|
||||
m.incrementRateLimiterBlockedRequestsMetric() // Increment the counter in the Middleware
|
||||
m.blockRequest(w, r, state, http.StatusTooManyRequests, "rate_limit", "rate_limit_rule", r.RemoteAddr,
|
||||
zap.String("message", "Request blocked by rate limit"),
|
||||
)
|
||||
if m.CustomResponses != nil {
|
||||
m.writeCustomResponse(w, state.StatusCode)
|
||||
}
|
||||
return
|
||||
}
|
||||
m.logger.Debug("Rate limiting phase completed - not blocked")
|
||||
}
|
||||
|
||||
// Whitelisting
|
||||
if m.CountryWhitelist.Enabled {
|
||||
m.logger.Debug("Starting country whitelisting phase")
|
||||
allowed, err := m.isCountryInList(r.RemoteAddr, m.CountryWhitelist.CountryList, m.CountryWhitelist.geoIP)
|
||||
if err != nil {
|
||||
m.logRequest(zapcore.ErrorLevel, "Failed to check country whitelist",
|
||||
r,
|
||||
zap.Error(err),
|
||||
)
|
||||
m.blockRequest(w, r, state, http.StatusForbidden, "internal_error", "country_block_rule", r.RemoteAddr,
|
||||
zap.String("message", "Request blocked due to internal error"),
|
||||
)
|
||||
m.logger.Debug("Country whitelisting phase completed - blocked due to error")
|
||||
m.incrementGeoIPRequestsMetric(false) // Increment with false for error
|
||||
return
|
||||
} else if !allowed {
|
||||
m.blockRequest(w, r, state, http.StatusForbidden, "country_block", "country_block_rule", r.RemoteAddr,
|
||||
zap.String("message", "Request blocked by country"))
|
||||
m.incrementGeoIPRequestsMetric(true) // Increment with true for blocked
|
||||
if m.CustomResponses != nil {
|
||||
m.writeCustomResponse(w, state.StatusCode)
|
||||
}
|
||||
return
|
||||
}
|
||||
m.logger.Debug("Country whitelisting phase completed - not blocked")
|
||||
m.incrementGeoIPRequestsMetric(false) // Increment with false for no block
|
||||
}
|
||||
|
||||
// Blacklisting
|
||||
if m.CountryBlacklist.Enabled {
|
||||
m.logger.Debug("Starting country blacklisting phase")
|
||||
blocked, err := m.isCountryInList(r.RemoteAddr, m.CountryBlacklist.CountryList, m.CountryBlacklist.geoIP)
|
||||
if err != nil {
|
||||
m.logRequest(zapcore.ErrorLevel, "Failed to check country blacklisting",
|
||||
r,
|
||||
zap.Error(err),
|
||||
)
|
||||
m.blockRequest(w, r, state, http.StatusForbidden, "internal_error", "country_block_rule", r.RemoteAddr,
|
||||
zap.String("message", "Request blocked due to internal error"),
|
||||
)
|
||||
m.logger.Debug("Country blacklisting phase completed - blocked due to error")
|
||||
m.incrementGeoIPRequestsMetric(false) // Increment with false for error
|
||||
return
|
||||
} else if blocked {
|
||||
|
||||
m.blockRequest(w, r, state, http.StatusForbidden, "country_block", "country_block_rule", r.RemoteAddr,
|
||||
zap.String("message", "Request blocked by country"))
|
||||
m.incrementGeoIPRequestsMetric(true) // Increment with true for blocked
|
||||
if m.CustomResponses != nil {
|
||||
m.writeCustomResponse(w, state.StatusCode)
|
||||
}
|
||||
return
|
||||
}
|
||||
m.logger.Debug("Country blacklisting phase completed - not blocked")
|
||||
m.incrementGeoIPRequestsMetric(false) // Increment with false for no block
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
rules, ok := m.Rules[phase]
|
||||
|
||||
@@ -67,7 +67,7 @@ func TestBlockedRequestPhase1_GeoIPBlocking(t *testing.T) {
|
||||
geoIPBlock, err := geoIPHandler.LoadGeoIPDatabase(geoIPdata)
|
||||
assert.NoError(t, err)
|
||||
|
||||
blackListMiddleware := &Middleware{
|
||||
blMiddleware := &Middleware{
|
||||
logger: logger,
|
||||
ipBlacklist: iptrie.NewTrie(),
|
||||
geoIPHandler: geoIPHandler,
|
||||
@@ -80,7 +80,7 @@ func TestBlockedRequestPhase1_GeoIPBlocking(t *testing.T) {
|
||||
CustomResponses: customResponse,
|
||||
}
|
||||
|
||||
whiteListMiddleware := &Middleware{
|
||||
wlMiddleware := &Middleware{
|
||||
logger: logger,
|
||||
ipBlacklist: iptrie.NewTrie(),
|
||||
geoIPHandler: geoIPHandler,
|
||||
@@ -93,6 +93,25 @@ func TestBlockedRequestPhase1_GeoIPBlocking(t *testing.T) {
|
||||
CustomResponses: customResponse,
|
||||
}
|
||||
|
||||
blackWhiteMw := &Middleware{
|
||||
logger: logger,
|
||||
ipBlacklist: iptrie.NewTrie(),
|
||||
geoIPHandler: geoIPHandler,
|
||||
CountryWhitelist: CountryAccessFilter{
|
||||
Enabled: true,
|
||||
CountryList: []string{"BR"},
|
||||
GeoIPDBPath: geoIPdata, // Path to a test GeoIP database
|
||||
geoIP: geoIPBlock,
|
||||
},
|
||||
CountryBlacklist: CountryAccessFilter{
|
||||
Enabled: true,
|
||||
CountryList: []string{"US", "RU"},
|
||||
GeoIPDBPath: geoIPdata, // Path to a test GeoIP database
|
||||
geoIP: geoIPBlock,
|
||||
},
|
||||
CustomResponses: customResponse,
|
||||
}
|
||||
|
||||
req := httptest.NewRequest("GET", testURL, nil)
|
||||
|
||||
state := &WAFState{}
|
||||
@@ -102,7 +121,7 @@ func TestBlockedRequestPhase1_GeoIPBlocking(t *testing.T) {
|
||||
req.RemoteAddr = aliCNIP
|
||||
|
||||
// Process the request in Phase 1
|
||||
blackListMiddleware.handlePhase(w, req, 1, state)
|
||||
blMiddleware.handlePhase(w, req, 1, state)
|
||||
assert.False(t, state.Blocked, "Request should be allowed")
|
||||
assert.Equal(t, http.StatusOK, w.Code, "Expected status code 200")
|
||||
})
|
||||
@@ -112,7 +131,7 @@ func TestBlockedRequestPhase1_GeoIPBlocking(t *testing.T) {
|
||||
req.RemoteAddr = googleUSIP
|
||||
|
||||
// Process the request in Phase 1
|
||||
blackListMiddleware.handlePhase(w, req, 1, state)
|
||||
blMiddleware.handlePhase(w, req, 1, state)
|
||||
|
||||
// Verify that the request was blocked
|
||||
assert.True(t, state.Blocked, "Request should be blocked")
|
||||
@@ -125,7 +144,7 @@ func TestBlockedRequestPhase1_GeoIPBlocking(t *testing.T) {
|
||||
req.RemoteAddr = googleBRIP
|
||||
|
||||
// Process the request in Phase 1
|
||||
whiteListMiddleware.handlePhase(w, req, 1, state)
|
||||
wlMiddleware.handlePhase(w, req, 1, state)
|
||||
assert.False(t, state.Blocked, "Request should be allowed")
|
||||
assert.Equal(t, http.StatusOK, w.Code, "Expected status code 200")
|
||||
})
|
||||
@@ -135,13 +154,36 @@ func TestBlockedRequestPhase1_GeoIPBlocking(t *testing.T) {
|
||||
req.RemoteAddr = googleRUIP
|
||||
|
||||
// Process the request in Phase 1
|
||||
whiteListMiddleware.handlePhase(w, req, 1, state)
|
||||
wlMiddleware.handlePhase(w, req, 1, state)
|
||||
|
||||
// Verify that the request was blocked
|
||||
assert.True(t, state.Blocked, "Request should be blocked")
|
||||
assert.Equal(t, http.StatusForbidden, w.Code, "Expected status code 403")
|
||||
assert.Contains(t, w.Body.String(), "Access Denied", "Response body should contain 'Access Denied'")
|
||||
})
|
||||
|
||||
t.Run("GeoIP whitelist and blacklist: whitelist has the priority", func(t *testing.T) {
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
// BR should be allowed
|
||||
req0 := httptest.NewRequest("GET", testURL, nil)
|
||||
req0.RemoteAddr = googleBRIP
|
||||
|
||||
blackWhiteMw.handlePhase(w, req0, 1, state)
|
||||
assert.False(t, state.Blocked, "Request should be allowed")
|
||||
assert.Equal(t, http.StatusOK, w.Code, "Expected status code 200")
|
||||
|
||||
// US must be blocked
|
||||
req1 := httptest.NewRequest("GET", testURL, nil)
|
||||
req1.RemoteAddr = googleUSIP
|
||||
|
||||
blackWhiteMw.handlePhase(w, req1, 1, state)
|
||||
// Verify that the request was blocked
|
||||
assert.True(t, state.Blocked, "Request should be blocked")
|
||||
assert.Equal(t, http.StatusForbidden, w.Code, "Expected status code 403")
|
||||
assert.Contains(t, w.Body.String(), "Access Denied", "Response body should contain 'Access Denied'")
|
||||
})
|
||||
}
|
||||
|
||||
func TestBlockedRequestPhase1_IPBlocking(t *testing.T) {
|
||||
|
||||
Reference in New Issue
Block a user