fix: lint errors

This commit is contained in:
drev74
2025-10-22 23:04:48 +03:00
parent 06a496e3d3
commit c8c0fed9e2
13 changed files with 103 additions and 67 deletions

26
GNUmakefile Normal file
View File

@@ -0,0 +1,26 @@
tidy:
@go mod tidy
@echo "Done!"
upd:
@go get -u ./...
@echo "Done!"
fmt:
@go fmt ./...
test:
@go test -v ./...
@echo "Done!"
it:
@go test -v ./... -tags=it
@echo "Done!"
lint:
@echo "==> Checking source code with golangci-lint..."
@golangci-lint run
lintfix:
@echo "==> Checking source code with golangci-lint..."
@golangci-lint run --fix

View File

@@ -287,17 +287,13 @@ func (m *Middleware) Shutdown(ctx context.Context) error {
m.logger.Debug("Logging worker stopped.") m.logger.Debug("Logging worker stopped.")
var firstError error var firstError error
var errorOccurred bool
// Close GeoIP databases // Close GeoIP databases
if m.CountryBlacklist.geoIP != nil { if m.CountryBlacklist.geoIP != nil {
m.logger.Debug("Closing country blacklist GeoIP database...") m.logger.Debug("Closing country blacklist GeoIP database...")
if err := m.CountryBlacklist.geoIP.Close(); err != nil { if err := m.CountryBlacklist.geoIP.Close(); err != nil {
m.logger.Error("Error encountered while closing country blacklist GeoIP database", zap.Error(err)) m.logger.Error("Error encountered while closing country blacklist GeoIP database", zap.Error(err))
if !errorOccurred { firstError = fmt.Errorf("error closing country blacklist GeoIP: %w", err)
firstError = fmt.Errorf("error closing country blacklist GeoIP: %w", err)
errorOccurred = true
}
} else { } else {
m.logger.Debug("Country blacklist GeoIP database closed successfully.") m.logger.Debug("Country blacklist GeoIP database closed successfully.")
} }

View File

@@ -59,25 +59,31 @@ func (m *Middleware) DumpRulesToFile(path string) error {
} }
defer f.Close() defer f.Close()
f.WriteString("=== WAF Rules Dump ===\n\n") if _, err := f.WriteString("=== WAF Rules Dump ===\n\n"); err != nil {
return err
}
for phase := 1; phase <= 4; phase++ { for phase := 1; phase <= 4; phase++ {
f.WriteString(fmt.Sprintf("== Phase %d Rules ==\n", phase)) fmt.Fprintf(f, "== Phase %d Rules ==\n", phase)
rules, ok := m.Rules[phase] rules, ok := m.Rules[phase]
if !ok || len(rules) == 0 { if !ok || len(rules) == 0 {
f.WriteString(" No rules for this phase\n\n") if _, err := f.WriteString(" No rules for this phase\n\n"); err != nil {
return err
}
continue continue
} }
for i, rule := range rules { for i, rule := range rules {
f.WriteString(fmt.Sprintf(" Rule %d:\n", i+1)) fmt.Fprintf(f, " Rule %d:\n", i+1)
f.WriteString(fmt.Sprintf(" ID: %s\n", rule.ID)) fmt.Fprintf(f, " ID: %s\n", rule.ID)
f.WriteString(fmt.Sprintf(" Pattern: %s\n", rule.Pattern)) fmt.Fprintf(f, " Pattern: %s\n", rule.Pattern)
f.WriteString(fmt.Sprintf(" Targets: %v\n", rule.Targets)) fmt.Fprintf(f, " Targets: %v\n", rule.Targets)
f.WriteString(fmt.Sprintf(" Score: %d\n", rule.Score)) fmt.Fprintf(f, " Score: %d\n", rule.Score)
f.WriteString(fmt.Sprintf(" Action: %s\n", rule.Action)) fmt.Fprintf(f, " Action: %s\n", rule.Action)
f.WriteString(fmt.Sprintf(" Description: %s\n", rule.Description)) fmt.Fprintf(f, " Description: %s\n", rule.Description)
f.WriteString("\n") if _, err := f.WriteString("\n"); err != nil {
return err
}
} }
} }

12
go.mod
View File

@@ -74,7 +74,7 @@ require (
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 // indirect github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 // indirect
github.com/jackc/pgx/v5 v5.7.6 // indirect github.com/jackc/pgx/v5 v5.7.6 // indirect
github.com/jackc/puddle/v2 v2.2.2 // indirect github.com/jackc/puddle/v2 v2.2.2 // indirect
github.com/klauspost/compress v1.18.0 // indirect github.com/klauspost/compress v1.18.1 // indirect
github.com/klauspost/cpuid/v2 v2.3.0 // indirect github.com/klauspost/cpuid/v2 v2.3.0 // indirect
github.com/libdns/libdns v1.1.1 // indirect github.com/libdns/libdns v1.1.1 // indirect
github.com/manifoldco/promptui v0.9.0 // indirect github.com/manifoldco/promptui v0.9.0 // indirect
@@ -94,7 +94,7 @@ require (
github.com/prometheus/client_golang v1.23.2 // indirect github.com/prometheus/client_golang v1.23.2 // indirect
github.com/prometheus/client_model v0.6.2 // indirect github.com/prometheus/client_model v0.6.2 // indirect
github.com/prometheus/common v0.67.1 // indirect github.com/prometheus/common v0.67.1 // indirect
github.com/prometheus/procfs v0.17.0 // indirect github.com/prometheus/procfs v0.18.0 // indirect
github.com/quic-go/qpack v0.5.1 // indirect github.com/quic-go/qpack v0.5.1 // indirect
github.com/quic-go/quic-go v0.55.0 // indirect github.com/quic-go/quic-go v0.55.0 // indirect
github.com/rs/xid v1.6.0 // indirect github.com/rs/xid v1.6.0 // indirect
@@ -106,7 +106,7 @@ require (
github.com/smallstep/certificates v0.28.4 // indirect github.com/smallstep/certificates v0.28.4 // indirect
github.com/smallstep/cli-utils v0.12.2 // indirect github.com/smallstep/cli-utils v0.12.2 // indirect
github.com/smallstep/go-attestation v0.4.4-0.20241119153605-2306d5b464ca // indirect github.com/smallstep/go-attestation v0.4.4-0.20241119153605-2306d5b464ca // indirect
github.com/smallstep/linkedca v0.24.0 // indirect github.com/smallstep/linkedca v0.25.0 // indirect
github.com/smallstep/nosql v0.7.0 // indirect github.com/smallstep/nosql v0.7.0 // indirect
github.com/smallstep/pkcs7 v0.2.1 // indirect github.com/smallstep/pkcs7 v0.2.1 // indirect
github.com/smallstep/scep v0.0.0-20250318231241-a25cabb69492 // indirect github.com/smallstep/scep v0.0.0-20250318231241-a25cabb69492 // indirect
@@ -153,9 +153,9 @@ require (
golang.org/x/text v0.30.0 // indirect golang.org/x/text v0.30.0 // indirect
golang.org/x/time v0.14.0 // indirect golang.org/x/time v0.14.0 // indirect
golang.org/x/tools v0.38.0 // indirect golang.org/x/tools v0.38.0 // indirect
google.golang.org/api v0.252.0 // indirect google.golang.org/api v0.253.0 // indirect
google.golang.org/genproto/googleapis/api v0.0.0-20251014184007-4626949a642f // indirect google.golang.org/genproto/googleapis/api v0.0.0-20251022142026-3a174f9686a8 // indirect
google.golang.org/genproto/googleapis/rpc v0.0.0-20251014184007-4626949a642f // indirect google.golang.org/genproto/googleapis/rpc v0.0.0-20251022142026-3a174f9686a8 // indirect
google.golang.org/grpc v1.76.0 // indirect google.golang.org/grpc v1.76.0 // indirect
google.golang.org/grpc/cmd/protoc-gen-go-grpc v1.5.1 // indirect google.golang.org/grpc/cmd/protoc-gen-go-grpc v1.5.1 // indirect
google.golang.org/protobuf v1.36.10 // indirect google.golang.org/protobuf v1.36.10 // indirect

24
go.sum
View File

@@ -249,8 +249,8 @@ github.com/json-iterator/go v1.1.6/go.mod h1:+SdeFBvtyEkXs7REEP0seUULqWtbJapLOCV
github.com/jstemmer/go-junit-report v0.0.0-20190106144839-af01ea7f8024/go.mod h1:6v2b51hI/fHJwM22ozAgKL4VKDeJcHhJFhtBdhmNjmU= github.com/jstemmer/go-junit-report v0.0.0-20190106144839-af01ea7f8024/go.mod h1:6v2b51hI/fHJwM22ozAgKL4VKDeJcHhJFhtBdhmNjmU=
github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck= github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck=
github.com/klauspost/compress v1.12.3/go.mod h1:8dP1Hq4DHOhN9w426knH3Rhby4rFm6D8eO+e+Dq5Gzg= github.com/klauspost/compress v1.12.3/go.mod h1:8dP1Hq4DHOhN9w426knH3Rhby4rFm6D8eO+e+Dq5Gzg=
github.com/klauspost/compress v1.18.0 h1:c/Cqfb0r+Yi+JtIEq73FWXVkRonBlf0CRNYc8Zttxdo= github.com/klauspost/compress v1.18.1 h1:bcSGx7UbpBqMChDtsF28Lw6v/G94LPrrbMbdC3JH2co=
github.com/klauspost/compress v1.18.0/go.mod h1:2Pp+KzxcywXVXMr50+X0Q/Lsb43OQHYWRCY2AiWywWQ= github.com/klauspost/compress v1.18.1/go.mod h1:ZQFFVG+MdnR0P+l6wpXgIL4NTtwiKIdBnrBd8Nrxr+0=
github.com/klauspost/cpuid/v2 v2.3.0 h1:S4CRMLnYUhGeDFDqkGriYKdfoFlDnMtqTiI/sFzhA9Y= github.com/klauspost/cpuid/v2 v2.3.0 h1:S4CRMLnYUhGeDFDqkGriYKdfoFlDnMtqTiI/sFzhA9Y=
github.com/klauspost/cpuid/v2 v2.3.0/go.mod h1:hqwkgyIinND0mEev00jJYCxPNVRVXFQeu1XKlok6oO0= github.com/klauspost/cpuid/v2 v2.3.0/go.mod h1:hqwkgyIinND0mEev00jJYCxPNVRVXFQeu1XKlok6oO0=
github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo=
@@ -326,8 +326,8 @@ github.com/prometheus/common v0.0.0-20180801064454-c7de2306084e/go.mod h1:daVV7q
github.com/prometheus/common v0.67.1 h1:OTSON1P4DNxzTg4hmKCc37o4ZAZDv0cfXLkOt0oEowI= github.com/prometheus/common v0.67.1 h1:OTSON1P4DNxzTg4hmKCc37o4ZAZDv0cfXLkOt0oEowI=
github.com/prometheus/common v0.67.1/go.mod h1:RpmT9v35q2Y+lsieQsdOh5sXZ6ajUGC8NjZAmr8vb0Q= github.com/prometheus/common v0.67.1/go.mod h1:RpmT9v35q2Y+lsieQsdOh5sXZ6ajUGC8NjZAmr8vb0Q=
github.com/prometheus/procfs v0.0.0-20180725123919-05ee40e3a273/go.mod h1:c3At6R/oaqEKCNdg8wHV1ftS6bRYblBhIjjI8uT2IGk= github.com/prometheus/procfs v0.0.0-20180725123919-05ee40e3a273/go.mod h1:c3At6R/oaqEKCNdg8wHV1ftS6bRYblBhIjjI8uT2IGk=
github.com/prometheus/procfs v0.17.0 h1:FuLQ+05u4ZI+SS/w9+BWEM2TXiHKsUQ9TADiRH7DuK0= github.com/prometheus/procfs v0.18.0 h1:2QTA9cKdznfYJz7EDaa7IiJobHuV7E1WzeBwcrhk0ao=
github.com/prometheus/procfs v0.17.0/go.mod h1:oPQLaDAMRbA+u8H5Pbfq+dl3VDAvHxMUOVhe0wYB2zw= github.com/prometheus/procfs v0.18.0/go.mod h1:M0aotyiemPhBCM0z5w87kL22CxfcH05ZpYlu+b4J7mw=
github.com/quic-go/qpack v0.5.1 h1:giqksBPnT/HDtZ6VhtFKgoLOWmlyo9Ei6u9PqzIMbhI= github.com/quic-go/qpack v0.5.1 h1:giqksBPnT/HDtZ6VhtFKgoLOWmlyo9Ei6u9PqzIMbhI=
github.com/quic-go/qpack v0.5.1/go.mod h1:+PC4XFrEskIVkcLzpEkbLqq1uCoxPhQuvK5rH1ZgaEg= github.com/quic-go/qpack v0.5.1/go.mod h1:+PC4XFrEskIVkcLzpEkbLqq1uCoxPhQuvK5rH1ZgaEg=
github.com/quic-go/quic-go v0.55.0 h1:zccPQIqYCXDt5NmcEabyYvOnomjs8Tlwl7tISjJh9Mk= github.com/quic-go/quic-go v0.55.0 h1:zccPQIqYCXDt5NmcEabyYvOnomjs8Tlwl7tISjJh9Mk=
@@ -381,8 +381,8 @@ github.com/smallstep/cli-utils v0.12.2 h1:lGzM9PJrH/qawbzMC/s2SvgLdJPKDWKwKzx9do
github.com/smallstep/cli-utils v0.12.2/go.mod h1:uCPqefO29goHLGqFnwk0i8W7XJu18X3WHQFRtOm/00Y= github.com/smallstep/cli-utils v0.12.2/go.mod h1:uCPqefO29goHLGqFnwk0i8W7XJu18X3WHQFRtOm/00Y=
github.com/smallstep/go-attestation v0.4.4-0.20241119153605-2306d5b464ca h1:VX8L0r8vybH0bPeaIxh4NQzafKQiqvlOn8pmOXbFLO4= github.com/smallstep/go-attestation v0.4.4-0.20241119153605-2306d5b464ca h1:VX8L0r8vybH0bPeaIxh4NQzafKQiqvlOn8pmOXbFLO4=
github.com/smallstep/go-attestation v0.4.4-0.20241119153605-2306d5b464ca/go.mod h1:vNAduivU014fubg6ewygkAvQC0IQVXqdc8vaGl/0er4= github.com/smallstep/go-attestation v0.4.4-0.20241119153605-2306d5b464ca/go.mod h1:vNAduivU014fubg6ewygkAvQC0IQVXqdc8vaGl/0er4=
github.com/smallstep/linkedca v0.24.0 h1:7nQuHLrI7XQVbZUgvNsUiW35mskyK1itsZyboZxor3E= github.com/smallstep/linkedca v0.25.0 h1:txT9QHGbCsJq0MhAghBq7qhurGY727tQuqUi+n4BVBo=
github.com/smallstep/linkedca v0.24.0/go.mod h1:7VovSkUuLpO4sJPUxp25aEo9+3XIcgEEMoj2noEQFcI= github.com/smallstep/linkedca v0.25.0/go.mod h1:Q3jVAauFKNlF86W5/RFtgQeyDKz98GL/KN3KG4mJOvc=
github.com/smallstep/nosql v0.7.0 h1:YiWC9ZAHcrLCrayfaF+QJUv16I2bZ7KdLC3RpJcnAnE= github.com/smallstep/nosql v0.7.0 h1:YiWC9ZAHcrLCrayfaF+QJUv16I2bZ7KdLC3RpJcnAnE=
github.com/smallstep/nosql v0.7.0/go.mod h1:H5VnKMCbeq9QA6SRY5iqPylfxLfYcLwvUff3onQ8+HU= github.com/smallstep/nosql v0.7.0/go.mod h1:H5VnKMCbeq9QA6SRY5iqPylfxLfYcLwvUff3onQ8+HU=
github.com/smallstep/pkcs7 v0.2.1 h1:6Kfzr/QizdIuB6LSv8y1LJdZ3aPSfTNhTLqAx9CTLfA= github.com/smallstep/pkcs7 v0.2.1 h1:6Kfzr/QizdIuB6LSv8y1LJdZ3aPSfTNhTLqAx9CTLfA=
@@ -635,8 +635,8 @@ gonum.org/v1/gonum v0.16.0/go.mod h1:fef3am4MQ93R2HHpKnLk4/Tbh/s0+wqD5nfa6Pnwy4E
google.golang.org/api v0.0.0-20180910000450-7ca32eb868bf/go.mod h1:4mhQ8q/RsB7i+udVvVy5NUi08OU8ZlA0gRVgrF7VFY0= google.golang.org/api v0.0.0-20180910000450-7ca32eb868bf/go.mod h1:4mhQ8q/RsB7i+udVvVy5NUi08OU8ZlA0gRVgrF7VFY0=
google.golang.org/api v0.0.0-20181030000543-1d582fd0359e/go.mod h1:4mhQ8q/RsB7i+udVvVy5NUi08OU8ZlA0gRVgrF7VFY0= google.golang.org/api v0.0.0-20181030000543-1d582fd0359e/go.mod h1:4mhQ8q/RsB7i+udVvVy5NUi08OU8ZlA0gRVgrF7VFY0=
google.golang.org/api v0.1.0/go.mod h1:UGEZY7KEX120AnNLIHFMKIo4obdJhkp2tPbaPlQx13Y= google.golang.org/api v0.1.0/go.mod h1:UGEZY7KEX120AnNLIHFMKIo4obdJhkp2tPbaPlQx13Y=
google.golang.org/api v0.252.0 h1:xfKJeAJaMwb8OC9fesr369rjciQ704AjU/psjkKURSI= google.golang.org/api v0.253.0 h1:apU86Eq9Q2eQco3NsUYFpVTfy7DwemojL7LmbAj7g/I=
google.golang.org/api v0.252.0/go.mod h1:dnHOv81x5RAmumZ7BWLShB/u7JZNeyalImxHmtTHxqw= google.golang.org/api v0.253.0/go.mod h1:PX09ad0r/4du83vZVAaGg7OaeyGnaUmT/CYPNvtLCbw=
google.golang.org/appengine v1.1.0/go.mod h1:EbEs0AVv82hx2wNQdGPgUI5lhzA/G0D9YwlJXL52JkM= google.golang.org/appengine v1.1.0/go.mod h1:EbEs0AVv82hx2wNQdGPgUI5lhzA/G0D9YwlJXL52JkM=
google.golang.org/appengine v1.2.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4= google.golang.org/appengine v1.2.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4=
google.golang.org/appengine v1.3.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4= google.golang.org/appengine v1.3.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4=
@@ -648,10 +648,10 @@ google.golang.org/genproto v0.0.0-20181202183823-bd91e49a0898/go.mod h1:7Ep/1NZk
google.golang.org/genproto v0.0.0-20190306203927-b5d61aea6440/go.mod h1:VzzqZJRnGkLBvHegQrXjBqPurQTc5/KpmUdxsrq26oE= google.golang.org/genproto v0.0.0-20190306203927-b5d61aea6440/go.mod h1:VzzqZJRnGkLBvHegQrXjBqPurQTc5/KpmUdxsrq26oE=
google.golang.org/genproto v0.0.0-20250603155806-513f23925822 h1:rHWScKit0gvAPuOnu87KpaYtjK5zBMLcULh7gxkCXu4= google.golang.org/genproto v0.0.0-20250603155806-513f23925822 h1:rHWScKit0gvAPuOnu87KpaYtjK5zBMLcULh7gxkCXu4=
google.golang.org/genproto v0.0.0-20250603155806-513f23925822/go.mod h1:HubltRL7rMh0LfnQPkMH4NPDFEWp0jw3vixw7jEM53s= google.golang.org/genproto v0.0.0-20250603155806-513f23925822/go.mod h1:HubltRL7rMh0LfnQPkMH4NPDFEWp0jw3vixw7jEM53s=
google.golang.org/genproto/googleapis/api v0.0.0-20251014184007-4626949a642f h1:OiFuztEyBivVKDvguQJYWq1yDcfAHIID/FVrPR4oiI0= google.golang.org/genproto/googleapis/api v0.0.0-20251022142026-3a174f9686a8 h1:mepRgnBZa07I4TRuomDE4sTIYieg/osKmzIf4USdWS4=
google.golang.org/genproto/googleapis/api v0.0.0-20251014184007-4626949a642f/go.mod h1:kprOiu9Tr0JYyD6DORrc4Hfyk3RFXqkQ3ctHEum3ZbM= google.golang.org/genproto/googleapis/api v0.0.0-20251022142026-3a174f9686a8/go.mod h1:fDMmzKV90WSg1NbozdqrE64fkuTv6mlq2zxo9ad+3yo=
google.golang.org/genproto/googleapis/rpc v0.0.0-20251014184007-4626949a642f h1:1FTH6cpXFsENbPR5Bu8NQddPSaUUE6NA2XdZdDSAJK4= google.golang.org/genproto/googleapis/rpc v0.0.0-20251022142026-3a174f9686a8 h1:M1rk8KBnUsBDg1oPGHNCxG4vc1f49epmTO7xscSajMk=
google.golang.org/genproto/googleapis/rpc v0.0.0-20251014184007-4626949a642f/go.mod h1:7i2o+ce6H/6BluujYR+kqX3GKH+dChPTQU19wjRPiGk= google.golang.org/genproto/googleapis/rpc v0.0.0-20251022142026-3a174f9686a8/go.mod h1:7i2o+ce6H/6BluujYR+kqX3GKH+dChPTQU19wjRPiGk=
google.golang.org/grpc v1.14.0/go.mod h1:yo6s7OP7yaDglbqo1J04qKzAhqBH6lvTonzMVmEdcZw= google.golang.org/grpc v1.14.0/go.mod h1:yo6s7OP7yaDglbqo1J04qKzAhqBH6lvTonzMVmEdcZw=
google.golang.org/grpc v1.16.0/go.mod h1:0JHn/cJsOMiMfNA9+DeHDlAU7KAAB5GDlYFpa9MZMio= google.golang.org/grpc v1.16.0/go.mod h1:0JHn/cJsOMiMfNA9+DeHDlAU7KAAB5GDlYFpa9MZMio=
google.golang.org/grpc v1.17.0/go.mod h1:6QZJwpn2B+Zp71q/5VxRsJ6NXXVCE5NRUHRo+f3cWCs= google.golang.org/grpc v1.17.0/go.mod h1:6QZJwpn2B+Zp71q/5VxRsJ6NXXVCE5NRUHRo+f3cWCs=

View File

@@ -32,7 +32,14 @@ func (m *Middleware) ServeHTTP(w http.ResponseWriter, r *http.Request, next cadd
) )
// Return 500 error to client // Return 500 error to client
w.WriteHeader(http.StatusInternalServerError) w.WriteHeader(http.StatusInternalServerError)
w.Write([]byte("Internal Server Error")) if _, err := w.Write([]byte("Internal Server Error")); err != nil {
m.logger.Error(err.Error(),
zap.String("log_id", logID),
zap.Any("panic", rec),
zap.Stack("stack"),
)
return
}
} }
}() }()
@@ -252,7 +259,6 @@ func (m *Middleware) handlePhase(w http.ResponseWriter, r *http.Request, phase i
) )
if phase == 1 { if phase == 1 {
// IP blacklisting - the highest priority // IP blacklisting - the highest priority
m.logger.Debug("Checking for IP blacklisting", zap.String("remote_addr", r.RemoteAddr)) // Added log for checking before to isIPBlacklisted call m.logger.Debug("Checking for IP blacklisting", zap.String("remote_addr", r.RemoteAddr)) // Added log for checking before to isIPBlacklisted call
xForwardedFor := r.Header.Get("X-Forwarded-For") xForwardedFor := r.Header.Get("X-Forwarded-For")
@@ -263,7 +269,7 @@ func (m *Middleware) handlePhase(w http.ResponseWriter, r *http.Request, phase i
m.logger.Debug("Checking IP blacklist with X-Forwarded-For", zap.String("remote_addr_xff", firstIP), zap.String("r.RemoteAddr", r.RemoteAddr)) m.logger.Debug("Checking IP blacklist with X-Forwarded-For", zap.String("remote_addr_xff", firstIP), zap.String("r.RemoteAddr", r.RemoteAddr))
if m.isIPBlacklisted(firstIP) { if m.isIPBlacklisted(firstIP) {
m.logger.Debug("Starting IP blacklist phase") m.logger.Debug("Starting IP blacklist phase")
m.blockRequest(w, r, state, http.StatusForbidden, "ip_blacklist", "ip_blacklist_rule", firstIP, m.blockRequest(w, r, state, http.StatusForbidden, "ip_blacklist", "ip_blacklist_rule",
zap.String("message", "Request blocked by IP blacklist"), zap.String("message", "Request blocked by IP blacklist"),
) )
if m.CustomResponses != nil { if m.CustomResponses != nil {
@@ -274,12 +280,11 @@ func (m *Middleware) handlePhase(w http.ResponseWriter, r *http.Request, phase i
} else { } else {
m.logger.Debug("X-Forwarded-For header present but empty or invalid") m.logger.Debug("X-Forwarded-For header present but empty or invalid")
} }
} else { } else {
m.logger.Debug("X-Forwarded-For header not present using r.RemoteAddr") m.logger.Debug("X-Forwarded-For header not present using r.RemoteAddr")
if m.isIPBlacklisted(r.RemoteAddr) { if m.isIPBlacklisted(r.RemoteAddr) {
m.logger.Debug("Starting IP blacklist phase") m.logger.Debug("Starting IP blacklist phase")
m.blockRequest(w, r, state, http.StatusForbidden, "ip_blacklist", "ip_blacklist_rule", r.RemoteAddr, m.blockRequest(w, r, state, http.StatusForbidden, "ip_blacklist", "ip_blacklist_rule",
zap.String("message", "Request blocked by IP blacklist"), zap.String("message", "Request blocked by IP blacklist"),
) )
if m.CustomResponses != nil { if m.CustomResponses != nil {
@@ -292,7 +297,7 @@ func (m *Middleware) handlePhase(w http.ResponseWriter, r *http.Request, phase i
// DNS blacklisting // DNS blacklisting
if m.isDNSBlacklisted(r.Host) { if m.isDNSBlacklisted(r.Host) {
m.logger.Debug("Starting DNS blacklist phase") m.logger.Debug("Starting DNS blacklist phase")
m.blockRequest(w, r, state, http.StatusForbidden, "dns_blacklist", "dns_blacklist_rule", r.Host, m.blockRequest(w, r, state, http.StatusForbidden, "dns_blacklist", "dns_blacklist_rule",
zap.String("message", "Request blocked by DNS blacklist"), zap.String("message", "Request blocked by DNS blacklist"),
zap.String("host", r.Host), zap.String("host", r.Host),
) )
@@ -309,7 +314,7 @@ func (m *Middleware) handlePhase(w http.ResponseWriter, r *http.Request, phase i
path := r.URL.Path // Get the request path path := r.URL.Path // Get the request path
if m.rateLimiter.isRateLimited(ip, path) { if m.rateLimiter.isRateLimited(ip, path) {
m.incrementRateLimiterBlockedRequestsMetric() // Increment the counter in the Middleware m.incrementRateLimiterBlockedRequestsMetric() // Increment the counter in the Middleware
m.blockRequest(w, r, state, http.StatusTooManyRequests, "rate_limit", "rate_limit_rule", r.RemoteAddr, m.blockRequest(w, r, state, http.StatusTooManyRequests, "rate_limit", "rate_limit_rule",
zap.String("message", "Request blocked by rate limit"), zap.String("message", "Request blocked by rate limit"),
) )
if m.CustomResponses != nil { if m.CustomResponses != nil {
@@ -329,14 +334,14 @@ func (m *Middleware) handlePhase(w http.ResponseWriter, r *http.Request, phase i
r, r,
zap.Error(err), zap.Error(err),
) )
m.blockRequest(w, r, state, http.StatusForbidden, "internal_error", "country_block_rule", r.RemoteAddr, m.blockRequest(w, r, state, http.StatusForbidden, "internal_error", "country_block_rule",
zap.String("message", "Request blocked due to internal error"), zap.String("message", "Request blocked due to internal error"),
) )
m.logger.Debug("Country whitelisting phase completed - blocked due to error") m.logger.Debug("Country whitelisting phase completed - blocked due to error")
m.incrementGeoIPRequestsMetric(false) // Increment with false for error m.incrementGeoIPRequestsMetric(false) // Increment with false for error
return return
} else if !allowed { } else if !allowed {
m.blockRequest(w, r, state, http.StatusForbidden, "country_block", "country_block_rule", r.RemoteAddr, m.blockRequest(w, r, state, http.StatusForbidden, "country_block", "country_block_rule",
zap.String("message", "Request blocked by country")) zap.String("message", "Request blocked by country"))
m.incrementGeoIPRequestsMetric(true) // Increment with true for blocked m.incrementGeoIPRequestsMetric(true) // Increment with true for blocked
if m.CustomResponses != nil { if m.CustomResponses != nil {
@@ -357,15 +362,14 @@ func (m *Middleware) handlePhase(w http.ResponseWriter, r *http.Request, phase i
r, r,
zap.Error(err), zap.Error(err),
) )
m.blockRequest(w, r, state, http.StatusForbidden, "internal_error", "country_block_rule", r.RemoteAddr, m.blockRequest(w, r, state, http.StatusForbidden, "internal_error", "country_block_rule",
zap.String("message", "Request blocked due to internal error"), zap.String("message", "Request blocked due to internal error"),
) )
m.logger.Debug("Country blacklisting phase completed - blocked due to error") m.logger.Debug("Country blacklisting phase completed - blocked due to error")
m.incrementGeoIPRequestsMetric(false) // Increment with false for error m.incrementGeoIPRequestsMetric(false) // Increment with false for error
return return
} else if blocked { } else if blocked {
m.blockRequest(w, r, state, http.StatusForbidden, "country_block", "country_block_rule",
m.blockRequest(w, r, state, http.StatusForbidden, "country_block", "country_block_rule", r.RemoteAddr,
zap.String("message", "Request blocked by country")) zap.String("message", "Request blocked by country"))
m.incrementGeoIPRequestsMetric(true) // Increment with true for blocked m.incrementGeoIPRequestsMetric(true) // Increment with true for blocked
if m.CustomResponses != nil { if m.CustomResponses != nil {
@@ -388,14 +392,14 @@ func (m *Middleware) handlePhase(w http.ResponseWriter, r *http.Request, phase i
m.logger.Debug("Starting rule evaluation for phase", zap.Int("phase", phase), zap.Int("rule_count", len(rules))) m.logger.Debug("Starting rule evaluation for phase", zap.Int("phase", phase), zap.Int("rule_count", len(rules)))
for _, rule := range rules { for _, rule := range rules {
m.logger.Debug("Processing rule", zap.String("rule_id", string(rule.ID)), zap.Int("target_count", len(rule.Targets))) m.logger.Debug("Processing rule", zap.String("rule_id", rule.ID), zap.Int("target_count", len(rule.Targets)))
// Use the custom type as the key // Use the custom type as the key
ctx := context.WithValue(r.Context(), ContextKeyRule("rule_id"), rule.ID) ctx := context.WithValue(r.Context(), ContextKeyRule("rule_id"), rule.ID)
r = r.WithContext(ctx) r = r.WithContext(ctx)
for _, target := range rule.Targets { for _, target := range rule.Targets {
m.logger.Debug("Extracting value for target", zap.String("target", target), zap.String("rule_id", string(rule.ID))) m.logger.Debug("Extracting value for target", zap.String("target", target), zap.String("rule_id", rule.ID))
var value string var value string
var err error var err error
@@ -413,21 +417,21 @@ func (m *Middleware) handlePhase(w http.ResponseWriter, r *http.Request, phase i
if err != nil { if err != nil {
m.logger.Debug("Failed to extract value for target, skipping rule for this target", m.logger.Debug("Failed to extract value for target, skipping rule for this target",
zap.String("target", target), zap.String("target", target),
zap.String("rule_id", string(rule.ID)), zap.String("rule_id", rule.ID),
zap.Error(err), zap.Error(err),
) )
continue continue
} }
m.logger.Debug("Extracted value", m.logger.Debug("Extracted value",
zap.String("rule_id", string(rule.ID)), zap.String("rule_id", rule.ID),
zap.String("target", target), zap.String("target", target),
zap.String("value", value), zap.String("value", value),
) )
if rule.regex.MatchString(value) { if rule.regex.MatchString(value) {
m.logger.Debug("Rule matched", m.logger.Debug("Rule matched",
zap.String("rule_id", string(rule.ID)), zap.String("rule_id", rule.ID),
zap.String("target", target), zap.String("target", target),
zap.String("value", value), zap.String("value", value),
) )
@@ -448,7 +452,7 @@ func (m *Middleware) handlePhase(w http.ResponseWriter, r *http.Request, phase i
if !shouldContinue || state.Blocked || state.ResponseWritten { if !shouldContinue || state.Blocked || state.ResponseWritten {
m.logger.Debug("Rule evaluation stopping due to blocking or rule directive", m.logger.Debug("Rule evaluation stopping due to blocking or rule directive",
zap.Int("phase", phase), zap.Int("phase", phase),
zap.String("rule_id", string(rule.ID)), zap.String("rule_id", rule.ID),
zap.Bool("continue", shouldContinue), zap.Bool("continue", shouldContinue),
zap.Bool("blocked", state.Blocked), zap.Bool("blocked", state.Blocked),
) )
@@ -460,7 +464,7 @@ func (m *Middleware) handlePhase(w http.ResponseWriter, r *http.Request, phase i
} }
} else { } else {
m.logger.Debug("Rule did not match", m.logger.Debug("Rule did not match",
zap.String("rule_id", string(rule.ID)), zap.String("rule_id", rule.ID),
zap.String("target", target), zap.String("target", target),
zap.String("value", value), zap.String("value", value),
) )

View File

@@ -19,6 +19,8 @@ func fileExists(path string) bool {
} }
// isIPv4 - checks if input IP is of type v4 // isIPv4 - checks if input IP is of type v4
//
//nolint:unused
func isIPv4(addr string) bool { func isIPv4(addr string) bool {
return strings.Count(addr, ":") < 2 return strings.Count(addr, ":") < 2
} }

View File

@@ -1,3 +1,5 @@
//go:build it
package caddywaf_test package caddywaf_test
import ( import (

View File

@@ -129,7 +129,7 @@ func (m *Middleware) redactSensitiveFields(fields []zap.Field) []zap.Field {
// prepareLogFields consolidates the logic for preparing log fields, including common fields and log_id. // prepareLogFields consolidates the logic for preparing log fields, including common fields and log_id.
func (m *Middleware) prepareLogFields(r *http.Request, fields []zap.Field) []zap.Field { func (m *Middleware) prepareLogFields(r *http.Request, fields []zap.Field) []zap.Field {
var logID string var logID string
var allFields []zap.Field allFields := make([]zap.Field, 0)
// Initialize with common fields // Initialize with common fields
var sourceIP, userAgent, requestMethod, requestPath, queryParams string var sourceIP, userAgent, requestMethod, requestPath, queryParams string

View File

@@ -18,7 +18,7 @@ func (m *Middleware) allowRequest(state *WAFState) {
} }
// blockRequest handles blocking a request and logging the details. // blockRequest handles blocking a request and logging the details.
func (m *Middleware) blockRequest(recorder http.ResponseWriter, r *http.Request, state *WAFState, statusCode int, reason, ruleID, matchedValue string, fields ...zap.Field) { func (m *Middleware) blockRequest(recorder http.ResponseWriter, r *http.Request, state *WAFState, statusCode int, reason, ruleID string, fields ...zap.Field) {
// CRITICAL FIX: Set these flags before any other operations // CRITICAL FIX: Set these flags before any other operations
state.Blocked = true state.Blocked = true
state.StatusCode = statusCode state.StatusCode = statusCode

View File

@@ -37,7 +37,7 @@ func TestBlockRequest(t *testing.T) {
r := httptest.NewRequest(http.MethodGet, "/test", nil) r := httptest.NewRequest(http.MethodGet, "/test", nil)
state := &WAFState{} state := &WAFState{}
m.blockRequest(w, r, state, http.StatusForbidden, "test reason", "rule1", "match1") m.blockRequest(w, r, state, http.StatusForbidden, "test reason", "rule1")
assert.Equal(t, http.StatusForbidden, w.Code) assert.Equal(t, http.StatusForbidden, w.Code)
assert.Equal(t, "Blocked", w.Body.String()) assert.Equal(t, "Blocked", w.Body.String())
@@ -56,7 +56,7 @@ func TestBlockRequest(t *testing.T) {
r = r.WithContext(ctx) r = r.WithContext(ctx)
state := &WAFState{} state := &WAFState{}
m.blockRequest(w, r, state, http.StatusForbidden, "test reason", "rule1", "match1") m.blockRequest(w, r, state, http.StatusForbidden, "test reason", "rule1")
assert.Equal(t, http.StatusForbidden, w.Code) assert.Equal(t, http.StatusForbidden, w.Code)
assert.True(t, state.Blocked) assert.True(t, state.Blocked)
@@ -75,7 +75,7 @@ func TestBlockRequest(t *testing.T) {
} }
recorder := NewResponseRecorder(w) recorder := NewResponseRecorder(w)
m.blockRequest(recorder, r, state, http.StatusForbidden, "test reason", "rule1", "match1") m.blockRequest(recorder, r, state, http.StatusForbidden, "test reason", "rule1")
assert.Equal(t, http.StatusForbidden, recorder.StatusCode()) // Check the Recorder status code instead assert.Equal(t, http.StatusForbidden, recorder.StatusCode()) // Check the Recorder status code instead
assert.True(t, state.ResponseWritten) // Check that the ResponseWritten flag is set assert.True(t, state.ResponseWritten) // Check that the ResponseWritten flag is set

View File

@@ -18,7 +18,7 @@ func (m *Middleware) processRuleMatch(w http.ResponseWriter, r *http.Request, ru
logID := r.Context().Value(ContextKeyLogId("logID")).(string) logID := r.Context().Value(ContextKeyLogId("logID")).(string)
m.logRequest(zapcore.DebugLevel, "Rule Matched", r, // More concise log message m.logRequest(zapcore.DebugLevel, "Rule Matched", r, // More concise log message
zap.String("rule_id", string(rule.ID)), zap.String("rule_id", rule.ID),
zap.String("target", strings.Join(rule.Targets, ",")), zap.String("target", strings.Join(rule.Targets, ",")),
zap.String("value", value), zap.String("value", value),
zap.String("description", rule.Description), zap.String("description", rule.Description),
@@ -37,7 +37,7 @@ func (m *Middleware) processRuleMatch(w http.ResponseWriter, r *http.Request, ru
state.TotalScore += rule.Score state.TotalScore += rule.Score
m.logRequest(zapcore.DebugLevel, "Anomaly score increased", r, // Corrected argument order - 'r' is now the third argument m.logRequest(zapcore.DebugLevel, "Anomaly score increased", r, // Corrected argument order - 'r' is now the third argument
zap.String("log_id", logID), zap.String("log_id", logID),
zap.String("rule_id", string(rule.ID)), zap.String("rule_id", rule.ID),
zap.Int("score_increase", rule.Score), zap.Int("score_increase", rule.Score),
zap.Int("old_score", oldScore), zap.Int("old_score", oldScore),
zap.Int("new_score", state.TotalScore), zap.Int("new_score", state.TotalScore),
@@ -51,7 +51,7 @@ func (m *Middleware) processRuleMatch(w http.ResponseWriter, r *http.Request, ru
// Debug the actual action field value to verify what's being used // Debug the actual action field value to verify what's being used
m.logger.Debug("Rule action/mode check", m.logger.Debug("Rule action/mode check",
zap.String("rule_id", string(rule.ID)), zap.String("rule_id", rule.ID),
zap.String("action_field", rule.Action), zap.String("action_field", rule.Action),
zap.Int("score", rule.Score), zap.Int("score", rule.Score),
zap.Int("threshold", m.AnomalyThreshold), zap.Int("threshold", m.AnomalyThreshold),
@@ -77,7 +77,7 @@ func (m *Middleware) processRuleMatch(w http.ResponseWriter, r *http.Request, ru
state.StatusCode = http.StatusForbidden state.StatusCode = http.StatusForbidden
// Block the request and write the response immediately // Block the request and write the response immediately
m.blockRequest(w, r, state, http.StatusForbidden, blockReason, string(rule.ID), value, m.blockRequest(w, r, state, http.StatusForbidden, blockReason, rule.ID,
zap.Int("total_score", state.TotalScore), zap.Int("total_score", state.TotalScore),
zap.Int("anomaly_threshold", m.AnomalyThreshold), zap.Int("anomaly_threshold", m.AnomalyThreshold),
zap.String("final_block_reason", blockReason), zap.String("final_block_reason", blockReason),
@@ -92,14 +92,14 @@ func (m *Middleware) processRuleMatch(w http.ResponseWriter, r *http.Request, ru
if rule.Action == "log" { if rule.Action == "log" {
m.logRequest(zapcore.InfoLevel, "Rule action: Log", r, m.logRequest(zapcore.InfoLevel, "Rule action: Log", r,
zap.String("log_id", logID), zap.String("log_id", logID),
zap.String("rule_id", string(rule.ID)), zap.String("rule_id", rule.ID),
zap.Int("total_score", state.TotalScore), // ADDED: Log total score for log action zap.Int("total_score", state.TotalScore), // ADDED: Log total score for log action
zap.Int("anomaly_threshold", m.AnomalyThreshold), // ADDED: Log anomaly threshold for log action zap.Int("anomaly_threshold", m.AnomalyThreshold), // ADDED: Log anomaly threshold for log action
) )
} else if !shouldBlock && !state.ResponseWritten { } else if !shouldBlock && !state.ResponseWritten {
m.logRequest(zapcore.DebugLevel, "Rule action: No Block", r, m.logRequest(zapcore.DebugLevel, "Rule action: No Block", r,
zap.String("log_id", logID), zap.String("log_id", logID),
zap.String("rule_id", string(rule.ID)), zap.String("rule_id", rule.ID),
zap.String("action", rule.Action), zap.String("action", rule.Action),
zap.Int("total_score", state.TotalScore), zap.Int("total_score", state.TotalScore),
zap.Int("anomaly_threshold", m.AnomalyThreshold), zap.Int("anomaly_threshold", m.AnomalyThreshold),
@@ -240,11 +240,11 @@ func (m *Middleware) loadRulesFromFile(path string, ruleIDs map[string]bool) (va
continue continue
} }
if _, exists := ruleIDs[string(rule.ID)]; exists { if _, exists := ruleIDs[rule.ID]; exists {
fileInvalidRules = append(fileInvalidRules, fmt.Sprintf("Duplicate rule ID '%s' at index %d", rule.ID, i)) fileInvalidRules = append(fileInvalidRules, fmt.Sprintf("Duplicate rule ID '%s' at index %d", rule.ID, i))
continue continue
} }
ruleIDs[string(rule.ID)] = true // Track rule IDs to prevent duplicates ruleIDs[rule.ID] = true // Track rule IDs to prevent duplicates
// RuleCache handling (compile and cache regex) // RuleCache handling (compile and cache regex)
if cachedRegex, exists := m.ruleCache.Get(rule.ID); exists { if cachedRegex, exists := m.ruleCache.Get(rule.ID); exists {

2
tor.go
View File

@@ -135,7 +135,7 @@ func (t *TorConfig) readExistingBlacklist() ([]string, error) {
// writeBlacklist writes the updated IP blacklist to the file. // writeBlacklist writes the updated IP blacklist to the file.
func (t *TorConfig) writeBlacklist(ips []string) error { func (t *TorConfig) writeBlacklist(ips []string) error {
data := strings.Join(ips, "\n") data := strings.Join(ips, "\n")
err := os.WriteFile(t.TORIPBlacklistFile, []byte(data), 0o644) err := os.WriteFile(t.TORIPBlacklistFile, []byte(data), 0o600)
if err != nil { if err != nil {
return fmt.Errorf("failed to write IP blacklist file %s: %w", t.TORIPBlacklistFile, err) // Improved error message with filename return fmt.Errorf("failed to write IP blacklist file %s: %w", t.TORIPBlacklistFile, err) // Improved error message with filename
} }