From c8c0fed9e217d918abb74e6cf5aa4b1ef80830a2 Mon Sep 17 00:00:00 2001 From: drev74 Date: Wed, 22 Oct 2025 23:04:48 +0300 Subject: [PATCH] fix: lint errors --- GNUmakefile | 26 ++++++++++++++++++++++++++ caddywaf.go | 6 +----- debug_waf.go | 28 +++++++++++++++++----------- go.mod | 12 ++++++------ go.sum | 24 ++++++++++++------------ handler.go | 42 +++++++++++++++++++++++------------------- helpers.go | 2 ++ it_test.go | 2 ++ logging.go | 2 +- response.go | 2 +- response_test.go | 6 +++--- rules.go | 16 ++++++++-------- tor.go | 2 +- 13 files changed, 103 insertions(+), 67 deletions(-) create mode 100644 GNUmakefile diff --git a/GNUmakefile b/GNUmakefile new file mode 100644 index 0000000..6f62bf1 --- /dev/null +++ b/GNUmakefile @@ -0,0 +1,26 @@ +tidy: + @go mod tidy + @echo "Done!" + +upd: + @go get -u ./... + @echo "Done!" + +fmt: + @go fmt ./... + +test: + @go test -v ./... + @echo "Done!" + +it: + @go test -v ./... -tags=it + @echo "Done!" + +lint: + @echo "==> Checking source code with golangci-lint..." + @golangci-lint run + +lintfix: + @echo "==> Checking source code with golangci-lint..." + @golangci-lint run --fix diff --git a/caddywaf.go b/caddywaf.go index f8215ae..25bee43 100644 --- a/caddywaf.go +++ b/caddywaf.go @@ -287,17 +287,13 @@ func (m *Middleware) Shutdown(ctx context.Context) error { m.logger.Debug("Logging worker stopped.") var firstError error - var errorOccurred bool // Close GeoIP databases if m.CountryBlacklist.geoIP != nil { m.logger.Debug("Closing country blacklist GeoIP database...") if err := m.CountryBlacklist.geoIP.Close(); err != nil { m.logger.Error("Error encountered while closing country blacklist GeoIP database", zap.Error(err)) - if !errorOccurred { - firstError = fmt.Errorf("error closing country blacklist GeoIP: %w", err) - errorOccurred = true - } + firstError = fmt.Errorf("error closing country blacklist GeoIP: %w", err) } else { m.logger.Debug("Country blacklist GeoIP database closed successfully.") } diff --git a/debug_waf.go b/debug_waf.go index 4535020..c10e5de 100644 --- a/debug_waf.go +++ b/debug_waf.go @@ -59,25 +59,31 @@ func (m *Middleware) DumpRulesToFile(path string) error { } defer f.Close() - f.WriteString("=== WAF Rules Dump ===\n\n") + if _, err := f.WriteString("=== WAF Rules Dump ===\n\n"); err != nil { + return err + } for phase := 1; phase <= 4; phase++ { - f.WriteString(fmt.Sprintf("== Phase %d Rules ==\n", phase)) + fmt.Fprintf(f, "== Phase %d Rules ==\n", phase) rules, ok := m.Rules[phase] if !ok || len(rules) == 0 { - f.WriteString(" No rules for this phase\n\n") + if _, err := f.WriteString(" No rules for this phase\n\n"); err != nil { + return err + } continue } for i, rule := range rules { - f.WriteString(fmt.Sprintf(" Rule %d:\n", i+1)) - f.WriteString(fmt.Sprintf(" ID: %s\n", rule.ID)) - f.WriteString(fmt.Sprintf(" Pattern: %s\n", rule.Pattern)) - f.WriteString(fmt.Sprintf(" Targets: %v\n", rule.Targets)) - f.WriteString(fmt.Sprintf(" Score: %d\n", rule.Score)) - f.WriteString(fmt.Sprintf(" Action: %s\n", rule.Action)) - f.WriteString(fmt.Sprintf(" Description: %s\n", rule.Description)) - f.WriteString("\n") + fmt.Fprintf(f, " Rule %d:\n", i+1) + fmt.Fprintf(f, " ID: %s\n", rule.ID) + fmt.Fprintf(f, " Pattern: %s\n", rule.Pattern) + fmt.Fprintf(f, " Targets: %v\n", rule.Targets) + fmt.Fprintf(f, " Score: %d\n", rule.Score) + fmt.Fprintf(f, " Action: %s\n", rule.Action) + fmt.Fprintf(f, " Description: %s\n", rule.Description) + if _, err := f.WriteString("\n"); err != nil { + return err + } } } diff --git a/go.mod b/go.mod index ed6305e..21dc682 100644 --- a/go.mod +++ b/go.mod @@ -74,7 +74,7 @@ require ( github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 // indirect github.com/jackc/pgx/v5 v5.7.6 // indirect github.com/jackc/puddle/v2 v2.2.2 // indirect - github.com/klauspost/compress v1.18.0 // indirect + github.com/klauspost/compress v1.18.1 // indirect github.com/klauspost/cpuid/v2 v2.3.0 // indirect github.com/libdns/libdns v1.1.1 // indirect github.com/manifoldco/promptui v0.9.0 // indirect @@ -94,7 +94,7 @@ require ( github.com/prometheus/client_golang v1.23.2 // indirect github.com/prometheus/client_model v0.6.2 // indirect github.com/prometheus/common v0.67.1 // indirect - github.com/prometheus/procfs v0.17.0 // indirect + github.com/prometheus/procfs v0.18.0 // indirect github.com/quic-go/qpack v0.5.1 // indirect github.com/quic-go/quic-go v0.55.0 // indirect github.com/rs/xid v1.6.0 // indirect @@ -106,7 +106,7 @@ require ( github.com/smallstep/certificates v0.28.4 // indirect github.com/smallstep/cli-utils v0.12.2 // indirect github.com/smallstep/go-attestation v0.4.4-0.20241119153605-2306d5b464ca // indirect - github.com/smallstep/linkedca v0.24.0 // indirect + github.com/smallstep/linkedca v0.25.0 // indirect github.com/smallstep/nosql v0.7.0 // indirect github.com/smallstep/pkcs7 v0.2.1 // indirect github.com/smallstep/scep v0.0.0-20250318231241-a25cabb69492 // indirect @@ -153,9 +153,9 @@ require ( golang.org/x/text v0.30.0 // indirect golang.org/x/time v0.14.0 // indirect golang.org/x/tools v0.38.0 // indirect - google.golang.org/api v0.252.0 // indirect - google.golang.org/genproto/googleapis/api v0.0.0-20251014184007-4626949a642f // indirect - google.golang.org/genproto/googleapis/rpc v0.0.0-20251014184007-4626949a642f // indirect + google.golang.org/api v0.253.0 // indirect + google.golang.org/genproto/googleapis/api v0.0.0-20251022142026-3a174f9686a8 // indirect + google.golang.org/genproto/googleapis/rpc v0.0.0-20251022142026-3a174f9686a8 // indirect google.golang.org/grpc v1.76.0 // indirect google.golang.org/grpc/cmd/protoc-gen-go-grpc v1.5.1 // indirect google.golang.org/protobuf v1.36.10 // indirect diff --git a/go.sum b/go.sum index 6b46eb5..24c2924 100644 --- a/go.sum +++ b/go.sum @@ -249,8 +249,8 @@ github.com/json-iterator/go v1.1.6/go.mod h1:+SdeFBvtyEkXs7REEP0seUULqWtbJapLOCV github.com/jstemmer/go-junit-report v0.0.0-20190106144839-af01ea7f8024/go.mod h1:6v2b51hI/fHJwM22ozAgKL4VKDeJcHhJFhtBdhmNjmU= github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck= github.com/klauspost/compress v1.12.3/go.mod h1:8dP1Hq4DHOhN9w426knH3Rhby4rFm6D8eO+e+Dq5Gzg= -github.com/klauspost/compress v1.18.0 h1:c/Cqfb0r+Yi+JtIEq73FWXVkRonBlf0CRNYc8Zttxdo= -github.com/klauspost/compress v1.18.0/go.mod h1:2Pp+KzxcywXVXMr50+X0Q/Lsb43OQHYWRCY2AiWywWQ= +github.com/klauspost/compress v1.18.1 h1:bcSGx7UbpBqMChDtsF28Lw6v/G94LPrrbMbdC3JH2co= +github.com/klauspost/compress v1.18.1/go.mod h1:ZQFFVG+MdnR0P+l6wpXgIL4NTtwiKIdBnrBd8Nrxr+0= github.com/klauspost/cpuid/v2 v2.3.0 h1:S4CRMLnYUhGeDFDqkGriYKdfoFlDnMtqTiI/sFzhA9Y= github.com/klauspost/cpuid/v2 v2.3.0/go.mod h1:hqwkgyIinND0mEev00jJYCxPNVRVXFQeu1XKlok6oO0= github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= @@ -326,8 +326,8 @@ github.com/prometheus/common v0.0.0-20180801064454-c7de2306084e/go.mod h1:daVV7q github.com/prometheus/common v0.67.1 h1:OTSON1P4DNxzTg4hmKCc37o4ZAZDv0cfXLkOt0oEowI= github.com/prometheus/common v0.67.1/go.mod h1:RpmT9v35q2Y+lsieQsdOh5sXZ6ajUGC8NjZAmr8vb0Q= github.com/prometheus/procfs v0.0.0-20180725123919-05ee40e3a273/go.mod h1:c3At6R/oaqEKCNdg8wHV1ftS6bRYblBhIjjI8uT2IGk= -github.com/prometheus/procfs v0.17.0 h1:FuLQ+05u4ZI+SS/w9+BWEM2TXiHKsUQ9TADiRH7DuK0= -github.com/prometheus/procfs v0.17.0/go.mod h1:oPQLaDAMRbA+u8H5Pbfq+dl3VDAvHxMUOVhe0wYB2zw= +github.com/prometheus/procfs v0.18.0 h1:2QTA9cKdznfYJz7EDaa7IiJobHuV7E1WzeBwcrhk0ao= +github.com/prometheus/procfs v0.18.0/go.mod h1:M0aotyiemPhBCM0z5w87kL22CxfcH05ZpYlu+b4J7mw= github.com/quic-go/qpack v0.5.1 h1:giqksBPnT/HDtZ6VhtFKgoLOWmlyo9Ei6u9PqzIMbhI= github.com/quic-go/qpack v0.5.1/go.mod h1:+PC4XFrEskIVkcLzpEkbLqq1uCoxPhQuvK5rH1ZgaEg= github.com/quic-go/quic-go v0.55.0 h1:zccPQIqYCXDt5NmcEabyYvOnomjs8Tlwl7tISjJh9Mk= @@ -381,8 +381,8 @@ github.com/smallstep/cli-utils v0.12.2 h1:lGzM9PJrH/qawbzMC/s2SvgLdJPKDWKwKzx9do github.com/smallstep/cli-utils v0.12.2/go.mod h1:uCPqefO29goHLGqFnwk0i8W7XJu18X3WHQFRtOm/00Y= github.com/smallstep/go-attestation v0.4.4-0.20241119153605-2306d5b464ca h1:VX8L0r8vybH0bPeaIxh4NQzafKQiqvlOn8pmOXbFLO4= github.com/smallstep/go-attestation v0.4.4-0.20241119153605-2306d5b464ca/go.mod h1:vNAduivU014fubg6ewygkAvQC0IQVXqdc8vaGl/0er4= -github.com/smallstep/linkedca v0.24.0 h1:7nQuHLrI7XQVbZUgvNsUiW35mskyK1itsZyboZxor3E= -github.com/smallstep/linkedca v0.24.0/go.mod h1:7VovSkUuLpO4sJPUxp25aEo9+3XIcgEEMoj2noEQFcI= +github.com/smallstep/linkedca v0.25.0 h1:txT9QHGbCsJq0MhAghBq7qhurGY727tQuqUi+n4BVBo= +github.com/smallstep/linkedca v0.25.0/go.mod h1:Q3jVAauFKNlF86W5/RFtgQeyDKz98GL/KN3KG4mJOvc= github.com/smallstep/nosql v0.7.0 h1:YiWC9ZAHcrLCrayfaF+QJUv16I2bZ7KdLC3RpJcnAnE= github.com/smallstep/nosql v0.7.0/go.mod h1:H5VnKMCbeq9QA6SRY5iqPylfxLfYcLwvUff3onQ8+HU= github.com/smallstep/pkcs7 v0.2.1 h1:6Kfzr/QizdIuB6LSv8y1LJdZ3aPSfTNhTLqAx9CTLfA= @@ -635,8 +635,8 @@ gonum.org/v1/gonum v0.16.0/go.mod h1:fef3am4MQ93R2HHpKnLk4/Tbh/s0+wqD5nfa6Pnwy4E google.golang.org/api v0.0.0-20180910000450-7ca32eb868bf/go.mod h1:4mhQ8q/RsB7i+udVvVy5NUi08OU8ZlA0gRVgrF7VFY0= google.golang.org/api v0.0.0-20181030000543-1d582fd0359e/go.mod h1:4mhQ8q/RsB7i+udVvVy5NUi08OU8ZlA0gRVgrF7VFY0= google.golang.org/api v0.1.0/go.mod h1:UGEZY7KEX120AnNLIHFMKIo4obdJhkp2tPbaPlQx13Y= -google.golang.org/api v0.252.0 h1:xfKJeAJaMwb8OC9fesr369rjciQ704AjU/psjkKURSI= -google.golang.org/api v0.252.0/go.mod h1:dnHOv81x5RAmumZ7BWLShB/u7JZNeyalImxHmtTHxqw= +google.golang.org/api v0.253.0 h1:apU86Eq9Q2eQco3NsUYFpVTfy7DwemojL7LmbAj7g/I= +google.golang.org/api v0.253.0/go.mod h1:PX09ad0r/4du83vZVAaGg7OaeyGnaUmT/CYPNvtLCbw= google.golang.org/appengine v1.1.0/go.mod h1:EbEs0AVv82hx2wNQdGPgUI5lhzA/G0D9YwlJXL52JkM= google.golang.org/appengine v1.2.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4= google.golang.org/appengine v1.3.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4= @@ -648,10 +648,10 @@ google.golang.org/genproto v0.0.0-20181202183823-bd91e49a0898/go.mod h1:7Ep/1NZk google.golang.org/genproto v0.0.0-20190306203927-b5d61aea6440/go.mod h1:VzzqZJRnGkLBvHegQrXjBqPurQTc5/KpmUdxsrq26oE= google.golang.org/genproto v0.0.0-20250603155806-513f23925822 h1:rHWScKit0gvAPuOnu87KpaYtjK5zBMLcULh7gxkCXu4= google.golang.org/genproto v0.0.0-20250603155806-513f23925822/go.mod h1:HubltRL7rMh0LfnQPkMH4NPDFEWp0jw3vixw7jEM53s= -google.golang.org/genproto/googleapis/api v0.0.0-20251014184007-4626949a642f h1:OiFuztEyBivVKDvguQJYWq1yDcfAHIID/FVrPR4oiI0= -google.golang.org/genproto/googleapis/api v0.0.0-20251014184007-4626949a642f/go.mod h1:kprOiu9Tr0JYyD6DORrc4Hfyk3RFXqkQ3ctHEum3ZbM= -google.golang.org/genproto/googleapis/rpc v0.0.0-20251014184007-4626949a642f h1:1FTH6cpXFsENbPR5Bu8NQddPSaUUE6NA2XdZdDSAJK4= -google.golang.org/genproto/googleapis/rpc v0.0.0-20251014184007-4626949a642f/go.mod h1:7i2o+ce6H/6BluujYR+kqX3GKH+dChPTQU19wjRPiGk= +google.golang.org/genproto/googleapis/api v0.0.0-20251022142026-3a174f9686a8 h1:mepRgnBZa07I4TRuomDE4sTIYieg/osKmzIf4USdWS4= +google.golang.org/genproto/googleapis/api v0.0.0-20251022142026-3a174f9686a8/go.mod h1:fDMmzKV90WSg1NbozdqrE64fkuTv6mlq2zxo9ad+3yo= +google.golang.org/genproto/googleapis/rpc v0.0.0-20251022142026-3a174f9686a8 h1:M1rk8KBnUsBDg1oPGHNCxG4vc1f49epmTO7xscSajMk= +google.golang.org/genproto/googleapis/rpc v0.0.0-20251022142026-3a174f9686a8/go.mod h1:7i2o+ce6H/6BluujYR+kqX3GKH+dChPTQU19wjRPiGk= google.golang.org/grpc v1.14.0/go.mod h1:yo6s7OP7yaDglbqo1J04qKzAhqBH6lvTonzMVmEdcZw= google.golang.org/grpc v1.16.0/go.mod h1:0JHn/cJsOMiMfNA9+DeHDlAU7KAAB5GDlYFpa9MZMio= google.golang.org/grpc v1.17.0/go.mod h1:6QZJwpn2B+Zp71q/5VxRsJ6NXXVCE5NRUHRo+f3cWCs= diff --git a/handler.go b/handler.go index b608b26..3d62bdc 100644 --- a/handler.go +++ b/handler.go @@ -32,7 +32,14 @@ func (m *Middleware) ServeHTTP(w http.ResponseWriter, r *http.Request, next cadd ) // Return 500 error to client w.WriteHeader(http.StatusInternalServerError) - w.Write([]byte("Internal Server Error")) + if _, err := w.Write([]byte("Internal Server Error")); err != nil { + m.logger.Error(err.Error(), + zap.String("log_id", logID), + zap.Any("panic", rec), + zap.Stack("stack"), + ) + return + } } }() @@ -252,7 +259,6 @@ func (m *Middleware) handlePhase(w http.ResponseWriter, r *http.Request, phase i ) if phase == 1 { - // IP blacklisting - the highest priority m.logger.Debug("Checking for IP blacklisting", zap.String("remote_addr", r.RemoteAddr)) // Added log for checking before to isIPBlacklisted call xForwardedFor := r.Header.Get("X-Forwarded-For") @@ -263,7 +269,7 @@ func (m *Middleware) handlePhase(w http.ResponseWriter, r *http.Request, phase i m.logger.Debug("Checking IP blacklist with X-Forwarded-For", zap.String("remote_addr_xff", firstIP), zap.String("r.RemoteAddr", r.RemoteAddr)) if m.isIPBlacklisted(firstIP) { m.logger.Debug("Starting IP blacklist phase") - m.blockRequest(w, r, state, http.StatusForbidden, "ip_blacklist", "ip_blacklist_rule", firstIP, + m.blockRequest(w, r, state, http.StatusForbidden, "ip_blacklist", "ip_blacklist_rule", zap.String("message", "Request blocked by IP blacklist"), ) if m.CustomResponses != nil { @@ -274,12 +280,11 @@ func (m *Middleware) handlePhase(w http.ResponseWriter, r *http.Request, phase i } else { m.logger.Debug("X-Forwarded-For header present but empty or invalid") } - } else { m.logger.Debug("X-Forwarded-For header not present using r.RemoteAddr") if m.isIPBlacklisted(r.RemoteAddr) { m.logger.Debug("Starting IP blacklist phase") - m.blockRequest(w, r, state, http.StatusForbidden, "ip_blacklist", "ip_blacklist_rule", r.RemoteAddr, + m.blockRequest(w, r, state, http.StatusForbidden, "ip_blacklist", "ip_blacklist_rule", zap.String("message", "Request blocked by IP blacklist"), ) if m.CustomResponses != nil { @@ -292,7 +297,7 @@ func (m *Middleware) handlePhase(w http.ResponseWriter, r *http.Request, phase i // DNS blacklisting if m.isDNSBlacklisted(r.Host) { m.logger.Debug("Starting DNS blacklist phase") - m.blockRequest(w, r, state, http.StatusForbidden, "dns_blacklist", "dns_blacklist_rule", r.Host, + m.blockRequest(w, r, state, http.StatusForbidden, "dns_blacklist", "dns_blacklist_rule", zap.String("message", "Request blocked by DNS blacklist"), zap.String("host", r.Host), ) @@ -309,7 +314,7 @@ func (m *Middleware) handlePhase(w http.ResponseWriter, r *http.Request, phase i path := r.URL.Path // Get the request path if m.rateLimiter.isRateLimited(ip, path) { m.incrementRateLimiterBlockedRequestsMetric() // Increment the counter in the Middleware - m.blockRequest(w, r, state, http.StatusTooManyRequests, "rate_limit", "rate_limit_rule", r.RemoteAddr, + m.blockRequest(w, r, state, http.StatusTooManyRequests, "rate_limit", "rate_limit_rule", zap.String("message", "Request blocked by rate limit"), ) if m.CustomResponses != nil { @@ -329,14 +334,14 @@ func (m *Middleware) handlePhase(w http.ResponseWriter, r *http.Request, phase i r, zap.Error(err), ) - m.blockRequest(w, r, state, http.StatusForbidden, "internal_error", "country_block_rule", r.RemoteAddr, + m.blockRequest(w, r, state, http.StatusForbidden, "internal_error", "country_block_rule", zap.String("message", "Request blocked due to internal error"), ) m.logger.Debug("Country whitelisting phase completed - blocked due to error") m.incrementGeoIPRequestsMetric(false) // Increment with false for error return } else if !allowed { - m.blockRequest(w, r, state, http.StatusForbidden, "country_block", "country_block_rule", r.RemoteAddr, + m.blockRequest(w, r, state, http.StatusForbidden, "country_block", "country_block_rule", zap.String("message", "Request blocked by country")) m.incrementGeoIPRequestsMetric(true) // Increment with true for blocked if m.CustomResponses != nil { @@ -357,15 +362,14 @@ func (m *Middleware) handlePhase(w http.ResponseWriter, r *http.Request, phase i r, zap.Error(err), ) - m.blockRequest(w, r, state, http.StatusForbidden, "internal_error", "country_block_rule", r.RemoteAddr, + m.blockRequest(w, r, state, http.StatusForbidden, "internal_error", "country_block_rule", zap.String("message", "Request blocked due to internal error"), ) m.logger.Debug("Country blacklisting phase completed - blocked due to error") m.incrementGeoIPRequestsMetric(false) // Increment with false for error return } else if blocked { - - m.blockRequest(w, r, state, http.StatusForbidden, "country_block", "country_block_rule", r.RemoteAddr, + m.blockRequest(w, r, state, http.StatusForbidden, "country_block", "country_block_rule", zap.String("message", "Request blocked by country")) m.incrementGeoIPRequestsMetric(true) // Increment with true for blocked if m.CustomResponses != nil { @@ -388,14 +392,14 @@ func (m *Middleware) handlePhase(w http.ResponseWriter, r *http.Request, phase i m.logger.Debug("Starting rule evaluation for phase", zap.Int("phase", phase), zap.Int("rule_count", len(rules))) for _, rule := range rules { - m.logger.Debug("Processing rule", zap.String("rule_id", string(rule.ID)), zap.Int("target_count", len(rule.Targets))) + m.logger.Debug("Processing rule", zap.String("rule_id", rule.ID), zap.Int("target_count", len(rule.Targets))) // Use the custom type as the key ctx := context.WithValue(r.Context(), ContextKeyRule("rule_id"), rule.ID) r = r.WithContext(ctx) for _, target := range rule.Targets { - m.logger.Debug("Extracting value for target", zap.String("target", target), zap.String("rule_id", string(rule.ID))) + m.logger.Debug("Extracting value for target", zap.String("target", target), zap.String("rule_id", rule.ID)) var value string var err error @@ -413,21 +417,21 @@ func (m *Middleware) handlePhase(w http.ResponseWriter, r *http.Request, phase i if err != nil { m.logger.Debug("Failed to extract value for target, skipping rule for this target", zap.String("target", target), - zap.String("rule_id", string(rule.ID)), + zap.String("rule_id", rule.ID), zap.Error(err), ) continue } m.logger.Debug("Extracted value", - zap.String("rule_id", string(rule.ID)), + zap.String("rule_id", rule.ID), zap.String("target", target), zap.String("value", value), ) if rule.regex.MatchString(value) { m.logger.Debug("Rule matched", - zap.String("rule_id", string(rule.ID)), + zap.String("rule_id", rule.ID), zap.String("target", target), zap.String("value", value), ) @@ -448,7 +452,7 @@ func (m *Middleware) handlePhase(w http.ResponseWriter, r *http.Request, phase i if !shouldContinue || state.Blocked || state.ResponseWritten { m.logger.Debug("Rule evaluation stopping due to blocking or rule directive", zap.Int("phase", phase), - zap.String("rule_id", string(rule.ID)), + zap.String("rule_id", rule.ID), zap.Bool("continue", shouldContinue), zap.Bool("blocked", state.Blocked), ) @@ -460,7 +464,7 @@ func (m *Middleware) handlePhase(w http.ResponseWriter, r *http.Request, phase i } } else { m.logger.Debug("Rule did not match", - zap.String("rule_id", string(rule.ID)), + zap.String("rule_id", rule.ID), zap.String("target", target), zap.String("value", value), ) diff --git a/helpers.go b/helpers.go index dcb57c0..f62e25c 100644 --- a/helpers.go +++ b/helpers.go @@ -19,6 +19,8 @@ func fileExists(path string) bool { } // isIPv4 - checks if input IP is of type v4 +// +//nolint:unused func isIPv4(addr string) bool { return strings.Count(addr, ":") < 2 } diff --git a/it_test.go b/it_test.go index 3e84e3d..41e1b2c 100644 --- a/it_test.go +++ b/it_test.go @@ -1,3 +1,5 @@ +//go:build it + package caddywaf_test import ( diff --git a/logging.go b/logging.go index 707152b..937b27a 100644 --- a/logging.go +++ b/logging.go @@ -129,7 +129,7 @@ func (m *Middleware) redactSensitiveFields(fields []zap.Field) []zap.Field { // prepareLogFields consolidates the logic for preparing log fields, including common fields and log_id. func (m *Middleware) prepareLogFields(r *http.Request, fields []zap.Field) []zap.Field { var logID string - var allFields []zap.Field + allFields := make([]zap.Field, 0) // Initialize with common fields var sourceIP, userAgent, requestMethod, requestPath, queryParams string diff --git a/response.go b/response.go index 2a337ec..8fb0712 100644 --- a/response.go +++ b/response.go @@ -18,7 +18,7 @@ func (m *Middleware) allowRequest(state *WAFState) { } // blockRequest handles blocking a request and logging the details. -func (m *Middleware) blockRequest(recorder http.ResponseWriter, r *http.Request, state *WAFState, statusCode int, reason, ruleID, matchedValue string, fields ...zap.Field) { +func (m *Middleware) blockRequest(recorder http.ResponseWriter, r *http.Request, state *WAFState, statusCode int, reason, ruleID string, fields ...zap.Field) { // CRITICAL FIX: Set these flags before any other operations state.Blocked = true state.StatusCode = statusCode diff --git a/response_test.go b/response_test.go index 2d22fd9..5cfb75f 100644 --- a/response_test.go +++ b/response_test.go @@ -37,7 +37,7 @@ func TestBlockRequest(t *testing.T) { r := httptest.NewRequest(http.MethodGet, "/test", nil) state := &WAFState{} - m.blockRequest(w, r, state, http.StatusForbidden, "test reason", "rule1", "match1") + m.blockRequest(w, r, state, http.StatusForbidden, "test reason", "rule1") assert.Equal(t, http.StatusForbidden, w.Code) assert.Equal(t, "Blocked", w.Body.String()) @@ -56,7 +56,7 @@ func TestBlockRequest(t *testing.T) { r = r.WithContext(ctx) state := &WAFState{} - m.blockRequest(w, r, state, http.StatusForbidden, "test reason", "rule1", "match1") + m.blockRequest(w, r, state, http.StatusForbidden, "test reason", "rule1") assert.Equal(t, http.StatusForbidden, w.Code) assert.True(t, state.Blocked) @@ -75,7 +75,7 @@ func TestBlockRequest(t *testing.T) { } recorder := NewResponseRecorder(w) - m.blockRequest(recorder, r, state, http.StatusForbidden, "test reason", "rule1", "match1") + m.blockRequest(recorder, r, state, http.StatusForbidden, "test reason", "rule1") assert.Equal(t, http.StatusForbidden, recorder.StatusCode()) // Check the Recorder status code instead assert.True(t, state.ResponseWritten) // Check that the ResponseWritten flag is set diff --git a/rules.go b/rules.go index f148ab3..b38b5b7 100644 --- a/rules.go +++ b/rules.go @@ -18,7 +18,7 @@ func (m *Middleware) processRuleMatch(w http.ResponseWriter, r *http.Request, ru logID := r.Context().Value(ContextKeyLogId("logID")).(string) m.logRequest(zapcore.DebugLevel, "Rule Matched", r, // More concise log message - zap.String("rule_id", string(rule.ID)), + zap.String("rule_id", rule.ID), zap.String("target", strings.Join(rule.Targets, ",")), zap.String("value", value), zap.String("description", rule.Description), @@ -37,7 +37,7 @@ func (m *Middleware) processRuleMatch(w http.ResponseWriter, r *http.Request, ru state.TotalScore += rule.Score m.logRequest(zapcore.DebugLevel, "Anomaly score increased", r, // Corrected argument order - 'r' is now the third argument zap.String("log_id", logID), - zap.String("rule_id", string(rule.ID)), + zap.String("rule_id", rule.ID), zap.Int("score_increase", rule.Score), zap.Int("old_score", oldScore), zap.Int("new_score", state.TotalScore), @@ -51,7 +51,7 @@ func (m *Middleware) processRuleMatch(w http.ResponseWriter, r *http.Request, ru // Debug the actual action field value to verify what's being used m.logger.Debug("Rule action/mode check", - zap.String("rule_id", string(rule.ID)), + zap.String("rule_id", rule.ID), zap.String("action_field", rule.Action), zap.Int("score", rule.Score), zap.Int("threshold", m.AnomalyThreshold), @@ -77,7 +77,7 @@ func (m *Middleware) processRuleMatch(w http.ResponseWriter, r *http.Request, ru state.StatusCode = http.StatusForbidden // Block the request and write the response immediately - m.blockRequest(w, r, state, http.StatusForbidden, blockReason, string(rule.ID), value, + m.blockRequest(w, r, state, http.StatusForbidden, blockReason, rule.ID, zap.Int("total_score", state.TotalScore), zap.Int("anomaly_threshold", m.AnomalyThreshold), zap.String("final_block_reason", blockReason), @@ -92,14 +92,14 @@ func (m *Middleware) processRuleMatch(w http.ResponseWriter, r *http.Request, ru if rule.Action == "log" { m.logRequest(zapcore.InfoLevel, "Rule action: Log", r, zap.String("log_id", logID), - zap.String("rule_id", string(rule.ID)), + zap.String("rule_id", rule.ID), zap.Int("total_score", state.TotalScore), // ADDED: Log total score for log action zap.Int("anomaly_threshold", m.AnomalyThreshold), // ADDED: Log anomaly threshold for log action ) } else if !shouldBlock && !state.ResponseWritten { m.logRequest(zapcore.DebugLevel, "Rule action: No Block", r, zap.String("log_id", logID), - zap.String("rule_id", string(rule.ID)), + zap.String("rule_id", rule.ID), zap.String("action", rule.Action), zap.Int("total_score", state.TotalScore), zap.Int("anomaly_threshold", m.AnomalyThreshold), @@ -240,11 +240,11 @@ func (m *Middleware) loadRulesFromFile(path string, ruleIDs map[string]bool) (va continue } - if _, exists := ruleIDs[string(rule.ID)]; exists { + if _, exists := ruleIDs[rule.ID]; exists { fileInvalidRules = append(fileInvalidRules, fmt.Sprintf("Duplicate rule ID '%s' at index %d", rule.ID, i)) continue } - ruleIDs[string(rule.ID)] = true // Track rule IDs to prevent duplicates + ruleIDs[rule.ID] = true // Track rule IDs to prevent duplicates // RuleCache handling (compile and cache regex) if cachedRegex, exists := m.ruleCache.Get(rule.ID); exists { diff --git a/tor.go b/tor.go index 44a9c0e..e1e1337 100644 --- a/tor.go +++ b/tor.go @@ -135,7 +135,7 @@ func (t *TorConfig) readExistingBlacklist() ([]string, error) { // writeBlacklist writes the updated IP blacklist to the file. func (t *TorConfig) writeBlacklist(ips []string) error { data := strings.Join(ips, "\n") - err := os.WriteFile(t.TORIPBlacklistFile, []byte(data), 0o644) + err := os.WriteFile(t.TORIPBlacklistFile, []byte(data), 0o600) if err != nil { return fmt.Errorf("failed to write IP blacklist file %s: %w", t.TORIPBlacklistFile, err) // Improved error message with filename }