From 1da1fea22ba212d4ad87a1db85fdddd0ab43c626 Mon Sep 17 00:00:00 2001 From: Fabrizio Salmi Date: Sat, 6 Dec 2025 22:18:10 +0100 Subject: [PATCH] Feat: Implement ASN Blocking (#73) --- caddywaf.go | 29 ++++++++++++++++++++++++++++ config.go | 27 +++++++++++++++++++++++++++ geoip.go | 54 +++++++++++++++++++++++++++++++++++++++++++++++++++++ handler.go | 30 +++++++++++++++++++++++++++++ types.go | 15 +++++++++++++++ 5 files changed, 155 insertions(+) diff --git a/caddywaf.go b/caddywaf.go index 62966fe..b40898f 100644 --- a/caddywaf.go +++ b/caddywaf.go @@ -220,6 +220,22 @@ func (m *Middleware) Provision(ctx caddy.Context) error { } } + + // Configure ASN blocking + if m.BlockASNs.Enabled { + if !fileExists(m.BlockASNs.GeoIPDBPath) { + m.logger.Warn("ASN GeoIP database not found. ASN blocking will be disabled", zap.String("path", m.BlockASNs.GeoIPDBPath)) + } else { + reader, err := maxminddb.Open(m.BlockASNs.GeoIPDBPath) + if err != nil { + m.logger.Error("Failed to load ASN GeoIP database", zap.String("path", m.BlockASNs.GeoIPDBPath), zap.Error(err)) + } else { + m.logger.Info("ASN GeoIP database loaded successfully", zap.String("path", m.BlockASNs.GeoIPDBPath)) + m.BlockASNs.geoIP = reader + } + } + } + // Initialize config and blacklist loaders m.configLoader = NewConfigLoader(m.logger) m.blacklistLoader = NewBlacklistLoader(m.logger) @@ -321,6 +337,19 @@ func (m *Middleware) Shutdown(ctx context.Context) error { m.logger.Debug("Country whitelist GeoIP database was not open, skipping close.") } + if m.BlockASNs.geoIP != nil { + m.logger.Debug("Closing ASN GeoIP database...") + if err := m.BlockASNs.geoIP.Close(); err != nil { + m.logger.Error("Error encountered while closing ASN GeoIP database", zap.Error(err)) + if firstError == nil { + firstError = fmt.Errorf("error closing ASN GeoIP: %w", err) + } + } else { + m.logger.Debug("ASN GeoIP database closed successfully.") + } + m.BlockASNs.geoIP = nil + } + // Log rule hit statistics m.logger.Info("Rule Hit Statistics:") m.ruleHits.Range(func(key, value interface{}) bool { diff --git a/config.go b/config.go index 231c3c3..dc5edd0 100644 --- a/config.go +++ b/config.go @@ -145,6 +145,7 @@ func (cl *ConfigLoader) UnmarshalCaddyfile(d *caddyfile.Dispenser, m *Middleware m.LogFilePath = "debug.json" m.RedactSensitiveData = false m.LogBuffer = 1000 + m.BlockASNs.Enabled = false // Default to false directiveHandlers := map[string]func(d *caddyfile.Dispenser, m *Middleware) error{ "metrics_endpoint": cl.parseMetricsEndpoint, @@ -152,6 +153,7 @@ func (cl *ConfigLoader) UnmarshalCaddyfile(d *caddyfile.Dispenser, m *Middleware "rate_limit": cl.parseRateLimit, "block_countries": cl.parseCountryBlockDirective(true), // Use directive-specific helper "whitelist_countries": cl.parseCountryBlockDirective(false), // Use directive-specific helper + "block_asns": cl.parseBlockASNsDirective, // Add ASN block directive "log_severity": cl.parseLogSeverity, "log_json": cl.parseLogJSON, "rule_file": cl.parseRuleFile, @@ -300,6 +302,31 @@ func (cl *ConfigLoader) parseCountryBlockDirective(isBlock bool) func(d *caddyfi } } +// parseBlockASNsDirective handles the block_asns directive +func (cl *ConfigLoader) parseBlockASNsDirective(d *caddyfile.Dispenser, m *Middleware) error { + target := &m.BlockASNs + target.Enabled = true + + if !d.NextArg() { + return d.ArgErr() + } + target.GeoIPDBPath = d.Val() + target.BlockedASNs = []string{} + + for d.NextArg() { + asn := d.Val() + target.BlockedASNs = append(target.BlockedASNs, asn) + } + + cl.logger.Debug("ASN block list configured", + zap.Strings("asns", target.BlockedASNs), + zap.String("geoip_db_path", target.GeoIPDBPath), + zap.String("file", d.File()), + zap.Int("line", d.Line()), + ) + return nil +} + func (cl *ConfigLoader) parseLogSeverity(d *caddyfile.Dispenser, m *Middleware) error { if !d.NextArg() { return d.ArgErr() diff --git a/geoip.go b/geoip.go index ee39fe4..5251357 100644 --- a/geoip.go +++ b/geoip.go @@ -3,6 +3,7 @@ package caddywaf import ( "fmt" "net" + "strconv" "strings" "sync" "time" @@ -197,3 +198,56 @@ func (gh *GeoIPHandler) cacheGeoIPRecord(ip string, record GeoIPRecord) { }) } } + +// IsASNInList checks if an IP belongs to a list of blocked ASNs +func (gh *GeoIPHandler) IsASNInList(remoteAddr string, blockedASNs []string, geoIP *maxminddb.Reader) (bool, error) { + if geoIP == nil { + return false, fmt.Errorf("geoip database not loaded") + } + + // Extract IP address without port + ip := extractIP(remoteAddr) + parsedIP := net.ParseIP(ip) + if parsedIP == nil { + gh.logger.Debug("Invalid IP address for ASN lookup", zap.String("ip", ip)) + return false, fmt.Errorf("invalid IP address: %s", ip) + } + + var record ASNRecord + err := geoIP.Lookup(parsedIP, &record) + if err != nil { + gh.logger.Error("GeoIP ASN lookup failed", zap.String("ip", ip), zap.Error(err)) + return false, fmt.Errorf("geoip lookup failed: %w", err) + } + + asnStr := strconv.FormatUint(uint64(record.AutonomousSystemNumber), 10) + for _, blockedASN := range blockedASNs { + if asnStr == blockedASN { + return true, nil + } + } + return false, nil +} + +// GetASN extracts the ASN for logging purposes +func (gh *GeoIPHandler) GetASN(remoteAddr string, geoIP *maxminddb.Reader) string { + if geoIP == nil { + return "N/A" + } + ipConf, _, err := net.SplitHostPort(remoteAddr) + if err != nil { + ipConf = remoteAddr + } + + parsedIP := net.ParseIP(ipConf) + if parsedIP == nil { + return "N/A" + } + + var record ASNRecord + err = geoIP.Lookup(parsedIP, &record) + if err != nil { + return "N/A" + } + return fmt.Sprintf("AS%d %s", record.AutonomousSystemNumber, record.AutonomousSystemOrganization) +} diff --git a/handler.go b/handler.go index 3d62bdc..6aadb2a 100644 --- a/handler.go +++ b/handler.go @@ -353,6 +353,36 @@ func (m *Middleware) handlePhase(w http.ResponseWriter, r *http.Request, phase i m.incrementGeoIPRequestsMetric(false) // Increment with false for no block } + // ASN Blocking + if m.BlockASNs.Enabled { + m.logger.Debug("Starting ASN blocking phase") + blocked, err := m.geoIPHandler.IsASNInList(r.RemoteAddr, m.BlockASNs.BlockedASNs, m.BlockASNs.geoIP) + if err != nil { + m.logRequest(zapcore.ErrorLevel, "Failed to check ASN blocking", + r, + zap.Error(err), + ) + m.blockRequest(w, r, state, http.StatusForbidden, "internal_error", "asn_block_rule", + zap.String("message", "Request blocked due to internal error"), + ) + m.logger.Debug("ASN blocking phase completed - blocked due to error") + m.incrementGeoIPRequestsMetric(false) // Increment with false for error + return + } else if blocked { + asnInfo := m.geoIPHandler.GetASN(r.RemoteAddr, m.BlockASNs.geoIP) + m.blockRequest(w, r, state, http.StatusForbidden, "asn_block", "asn_block_rule", + zap.String("message", "Request blocked by ASN"), + zap.String("asn", asnInfo), + ) + m.incrementGeoIPRequestsMetric(true) // Increment with true for blocked + if m.CustomResponses != nil { + m.writeCustomResponse(w, state.StatusCode) + } + return + } + m.logger.Debug("ASN blocking phase completed - not blocked") + } + // Blacklisting if m.CountryBlacklist.Enabled { m.logger.Debug("Starting country blacklisting phase") diff --git a/types.go b/types.go index b4024f5..3e8e221 100644 --- a/types.go +++ b/types.go @@ -47,6 +47,14 @@ type CountryAccessFilter struct { geoIP *maxminddb.Reader `json:"-"` // Explicitly mark as not serialized } +// ASNAccessFilter struct +type ASNAccessFilter struct { + Enabled bool `json:"enabled"` + BlockedASNs []string `json:"blocked_asns"` + GeoIPDBPath string `json:"geoip_db_path"` + geoIP *maxminddb.Reader `json:"-"` // Explicitly mark as not serialized +} + // GeoIPRecord struct type GeoIPRecord struct { Country struct { @@ -54,6 +62,12 @@ type GeoIPRecord struct { } `maxminddb:"country"` } +// ASNRecord struct +type ASNRecord struct { + AutonomousSystemOrganization string `maxminddb:"autonomous_system_organization"` + AutonomousSystemNumber uint `maxminddb:"autonomous_system_number"` +} + // Rule struct type Rule struct { ID string `json:"id"` @@ -106,6 +120,7 @@ type Middleware struct { AnomalyThreshold int `json:"anomaly_threshold"` CountryBlacklist CountryAccessFilter `json:"country_blacklist"` CountryWhitelist CountryAccessFilter `json:"country_whitelist"` + BlockASNs ASNAccessFilter `json:"block_asns"` Rules map[int][]Rule `json:"-"` ipBlacklist *iptrie.Trie `json:"-"` dnsBlacklist map[string]struct{} `json:"-"` // Changed to map[string]struct{}