Feat: Implement ASN Blocking (#73)

This commit is contained in:
Fabrizio Salmi
2025-12-06 22:18:10 +01:00
parent 34d7a29119
commit 1da1fea22b
5 changed files with 155 additions and 0 deletions

View File

@@ -220,6 +220,22 @@ func (m *Middleware) Provision(ctx caddy.Context) error {
} }
} }
// Configure ASN blocking
if m.BlockASNs.Enabled {
if !fileExists(m.BlockASNs.GeoIPDBPath) {
m.logger.Warn("ASN GeoIP database not found. ASN blocking will be disabled", zap.String("path", m.BlockASNs.GeoIPDBPath))
} else {
reader, err := maxminddb.Open(m.BlockASNs.GeoIPDBPath)
if err != nil {
m.logger.Error("Failed to load ASN GeoIP database", zap.String("path", m.BlockASNs.GeoIPDBPath), zap.Error(err))
} else {
m.logger.Info("ASN GeoIP database loaded successfully", zap.String("path", m.BlockASNs.GeoIPDBPath))
m.BlockASNs.geoIP = reader
}
}
}
// Initialize config and blacklist loaders // Initialize config and blacklist loaders
m.configLoader = NewConfigLoader(m.logger) m.configLoader = NewConfigLoader(m.logger)
m.blacklistLoader = NewBlacklistLoader(m.logger) m.blacklistLoader = NewBlacklistLoader(m.logger)
@@ -321,6 +337,19 @@ func (m *Middleware) Shutdown(ctx context.Context) error {
m.logger.Debug("Country whitelist GeoIP database was not open, skipping close.") m.logger.Debug("Country whitelist GeoIP database was not open, skipping close.")
} }
if m.BlockASNs.geoIP != nil {
m.logger.Debug("Closing ASN GeoIP database...")
if err := m.BlockASNs.geoIP.Close(); err != nil {
m.logger.Error("Error encountered while closing ASN GeoIP database", zap.Error(err))
if firstError == nil {
firstError = fmt.Errorf("error closing ASN GeoIP: %w", err)
}
} else {
m.logger.Debug("ASN GeoIP database closed successfully.")
}
m.BlockASNs.geoIP = nil
}
// Log rule hit statistics // Log rule hit statistics
m.logger.Info("Rule Hit Statistics:") m.logger.Info("Rule Hit Statistics:")
m.ruleHits.Range(func(key, value interface{}) bool { m.ruleHits.Range(func(key, value interface{}) bool {

View File

@@ -145,6 +145,7 @@ func (cl *ConfigLoader) UnmarshalCaddyfile(d *caddyfile.Dispenser, m *Middleware
m.LogFilePath = "debug.json" m.LogFilePath = "debug.json"
m.RedactSensitiveData = false m.RedactSensitiveData = false
m.LogBuffer = 1000 m.LogBuffer = 1000
m.BlockASNs.Enabled = false // Default to false
directiveHandlers := map[string]func(d *caddyfile.Dispenser, m *Middleware) error{ directiveHandlers := map[string]func(d *caddyfile.Dispenser, m *Middleware) error{
"metrics_endpoint": cl.parseMetricsEndpoint, "metrics_endpoint": cl.parseMetricsEndpoint,
@@ -152,6 +153,7 @@ func (cl *ConfigLoader) UnmarshalCaddyfile(d *caddyfile.Dispenser, m *Middleware
"rate_limit": cl.parseRateLimit, "rate_limit": cl.parseRateLimit,
"block_countries": cl.parseCountryBlockDirective(true), // Use directive-specific helper "block_countries": cl.parseCountryBlockDirective(true), // Use directive-specific helper
"whitelist_countries": cl.parseCountryBlockDirective(false), // Use directive-specific helper "whitelist_countries": cl.parseCountryBlockDirective(false), // Use directive-specific helper
"block_asns": cl.parseBlockASNsDirective, // Add ASN block directive
"log_severity": cl.parseLogSeverity, "log_severity": cl.parseLogSeverity,
"log_json": cl.parseLogJSON, "log_json": cl.parseLogJSON,
"rule_file": cl.parseRuleFile, "rule_file": cl.parseRuleFile,
@@ -300,6 +302,31 @@ func (cl *ConfigLoader) parseCountryBlockDirective(isBlock bool) func(d *caddyfi
} }
} }
// parseBlockASNsDirective handles the block_asns directive
func (cl *ConfigLoader) parseBlockASNsDirective(d *caddyfile.Dispenser, m *Middleware) error {
target := &m.BlockASNs
target.Enabled = true
if !d.NextArg() {
return d.ArgErr()
}
target.GeoIPDBPath = d.Val()
target.BlockedASNs = []string{}
for d.NextArg() {
asn := d.Val()
target.BlockedASNs = append(target.BlockedASNs, asn)
}
cl.logger.Debug("ASN block list configured",
zap.Strings("asns", target.BlockedASNs),
zap.String("geoip_db_path", target.GeoIPDBPath),
zap.String("file", d.File()),
zap.Int("line", d.Line()),
)
return nil
}
func (cl *ConfigLoader) parseLogSeverity(d *caddyfile.Dispenser, m *Middleware) error { func (cl *ConfigLoader) parseLogSeverity(d *caddyfile.Dispenser, m *Middleware) error {
if !d.NextArg() { if !d.NextArg() {
return d.ArgErr() return d.ArgErr()

View File

@@ -3,6 +3,7 @@ package caddywaf
import ( import (
"fmt" "fmt"
"net" "net"
"strconv"
"strings" "strings"
"sync" "sync"
"time" "time"
@@ -197,3 +198,56 @@ func (gh *GeoIPHandler) cacheGeoIPRecord(ip string, record GeoIPRecord) {
}) })
} }
} }
// IsASNInList checks if an IP belongs to a list of blocked ASNs
func (gh *GeoIPHandler) IsASNInList(remoteAddr string, blockedASNs []string, geoIP *maxminddb.Reader) (bool, error) {
if geoIP == nil {
return false, fmt.Errorf("geoip database not loaded")
}
// Extract IP address without port
ip := extractIP(remoteAddr)
parsedIP := net.ParseIP(ip)
if parsedIP == nil {
gh.logger.Debug("Invalid IP address for ASN lookup", zap.String("ip", ip))
return false, fmt.Errorf("invalid IP address: %s", ip)
}
var record ASNRecord
err := geoIP.Lookup(parsedIP, &record)
if err != nil {
gh.logger.Error("GeoIP ASN lookup failed", zap.String("ip", ip), zap.Error(err))
return false, fmt.Errorf("geoip lookup failed: %w", err)
}
asnStr := strconv.FormatUint(uint64(record.AutonomousSystemNumber), 10)
for _, blockedASN := range blockedASNs {
if asnStr == blockedASN {
return true, nil
}
}
return false, nil
}
// GetASN extracts the ASN for logging purposes
func (gh *GeoIPHandler) GetASN(remoteAddr string, geoIP *maxminddb.Reader) string {
if geoIP == nil {
return "N/A"
}
ipConf, _, err := net.SplitHostPort(remoteAddr)
if err != nil {
ipConf = remoteAddr
}
parsedIP := net.ParseIP(ipConf)
if parsedIP == nil {
return "N/A"
}
var record ASNRecord
err = geoIP.Lookup(parsedIP, &record)
if err != nil {
return "N/A"
}
return fmt.Sprintf("AS%d %s", record.AutonomousSystemNumber, record.AutonomousSystemOrganization)
}

View File

@@ -353,6 +353,36 @@ func (m *Middleware) handlePhase(w http.ResponseWriter, r *http.Request, phase i
m.incrementGeoIPRequestsMetric(false) // Increment with false for no block m.incrementGeoIPRequestsMetric(false) // Increment with false for no block
} }
// ASN Blocking
if m.BlockASNs.Enabled {
m.logger.Debug("Starting ASN blocking phase")
blocked, err := m.geoIPHandler.IsASNInList(r.RemoteAddr, m.BlockASNs.BlockedASNs, m.BlockASNs.geoIP)
if err != nil {
m.logRequest(zapcore.ErrorLevel, "Failed to check ASN blocking",
r,
zap.Error(err),
)
m.blockRequest(w, r, state, http.StatusForbidden, "internal_error", "asn_block_rule",
zap.String("message", "Request blocked due to internal error"),
)
m.logger.Debug("ASN blocking phase completed - blocked due to error")
m.incrementGeoIPRequestsMetric(false) // Increment with false for error
return
} else if blocked {
asnInfo := m.geoIPHandler.GetASN(r.RemoteAddr, m.BlockASNs.geoIP)
m.blockRequest(w, r, state, http.StatusForbidden, "asn_block", "asn_block_rule",
zap.String("message", "Request blocked by ASN"),
zap.String("asn", asnInfo),
)
m.incrementGeoIPRequestsMetric(true) // Increment with true for blocked
if m.CustomResponses != nil {
m.writeCustomResponse(w, state.StatusCode)
}
return
}
m.logger.Debug("ASN blocking phase completed - not blocked")
}
// Blacklisting // Blacklisting
if m.CountryBlacklist.Enabled { if m.CountryBlacklist.Enabled {
m.logger.Debug("Starting country blacklisting phase") m.logger.Debug("Starting country blacklisting phase")

View File

@@ -47,6 +47,14 @@ type CountryAccessFilter struct {
geoIP *maxminddb.Reader `json:"-"` // Explicitly mark as not serialized geoIP *maxminddb.Reader `json:"-"` // Explicitly mark as not serialized
} }
// ASNAccessFilter struct
type ASNAccessFilter struct {
Enabled bool `json:"enabled"`
BlockedASNs []string `json:"blocked_asns"`
GeoIPDBPath string `json:"geoip_db_path"`
geoIP *maxminddb.Reader `json:"-"` // Explicitly mark as not serialized
}
// GeoIPRecord struct // GeoIPRecord struct
type GeoIPRecord struct { type GeoIPRecord struct {
Country struct { Country struct {
@@ -54,6 +62,12 @@ type GeoIPRecord struct {
} `maxminddb:"country"` } `maxminddb:"country"`
} }
// ASNRecord struct
type ASNRecord struct {
AutonomousSystemOrganization string `maxminddb:"autonomous_system_organization"`
AutonomousSystemNumber uint `maxminddb:"autonomous_system_number"`
}
// Rule struct // Rule struct
type Rule struct { type Rule struct {
ID string `json:"id"` ID string `json:"id"`
@@ -106,6 +120,7 @@ type Middleware struct {
AnomalyThreshold int `json:"anomaly_threshold"` AnomalyThreshold int `json:"anomaly_threshold"`
CountryBlacklist CountryAccessFilter `json:"country_blacklist"` CountryBlacklist CountryAccessFilter `json:"country_blacklist"`
CountryWhitelist CountryAccessFilter `json:"country_whitelist"` CountryWhitelist CountryAccessFilter `json:"country_whitelist"`
BlockASNs ASNAccessFilter `json:"block_asns"`
Rules map[int][]Rule `json:"-"` Rules map[int][]Rule `json:"-"`
ipBlacklist *iptrie.Trie `json:"-"` ipBlacklist *iptrie.Trie `json:"-"`
dnsBlacklist map[string]struct{} `json:"-"` // Changed to map[string]struct{} dnsBlacklist map[string]struct{} `json:"-"` // Changed to map[string]struct{}