From c9052770585f63b4ee667c993e97d287cb61410f Mon Sep 17 00:00:00 2001 From: drev74 Date: Fri, 10 Oct 2025 20:33:20 +0300 Subject: [PATCH] feat!!: switch to go-trie --- blacklist.go | 7 +- caddywaf.go | 14 ++-- go.mod | 3 +- go.sum | 6 +- handler_test.go | 94 ++++++++++++++++++--------- ratelimiter_test.go | 3 +- request_test.go | 6 +- types.go | 154 +------------------------------------------- types_test.go | 69 -------------------- 9 files changed, 88 insertions(+), 268 deletions(-) diff --git a/blacklist.go b/blacklist.go index c5628ae..9ab18d1 100644 --- a/blacklist.go +++ b/blacklist.go @@ -4,6 +4,7 @@ import ( "bufio" "fmt" "net" + "net/netip" "os" "strings" @@ -64,10 +65,7 @@ func (bl *BlacklistLoader) LoadDNSBlacklistFromFile(path string, dnsBlacklist ma } func (m *Middleware) isIPBlacklisted(ip string) bool { - if m.ipBlacklist == nil { // Defensive check: ensure ipBlacklist is not nil - return false - } - if m.ipBlacklist.Contains(ip) { + if m.ipBlacklist.Contains(netip.MustParseAddr(ip)) { m.muIPBlacklistMetrics.Lock() // Acquire lock before accessing shared counter m.IPBlacklistBlockCount++ // Increment the counter m.muIPBlacklistMetrics.Unlock() // Release lock after accessing counter @@ -127,7 +125,6 @@ func extractIP(remoteAddr string, logger *zap.Logger) string { return host } -// LoadIPBlacklistFromFile loads IP addresses from a file into the provided map. // LoadIPBlacklistFromFile loads IP addresses from a file into the provided map. func (bl *BlacklistLoader) LoadIPBlacklistFromFile(path string, ipBlacklist map[string]struct{}) error { bl.logger.Debug("Loading IP blacklist", zap.String("path", path)) diff --git a/caddywaf.go b/caddywaf.go index 2aa7936..34af732 100644 --- a/caddywaf.go +++ b/caddywaf.go @@ -20,6 +20,7 @@ import ( "encoding/json" "fmt" "net/http" + "net/netip" "os" "strings" "sync" @@ -29,6 +30,7 @@ import ( "github.com/caddyserver/caddy/v2/caddyconfig/httpcaddyfile" "github.com/caddyserver/caddy/v2/modules/caddyhttp" "github.com/oschwald/maxminddb-golang" + trie "github.com/phemmer/go-iptrie" "go.uber.org/zap" "go.uber.org/zap/zapcore" @@ -235,8 +237,8 @@ func (m *Middleware) Provision(ctx caddy.Context) error { // Load IP blacklist if m.IPBlacklistFile != "" { - m.ipBlacklist = NewCIDRTrie() - err = m.loadIPBlacklist(m.IPBlacklistFile, m.ipBlacklist) + m.ipBlacklist = trie.NewTrie() + err = m.loadIPBlacklist(m.IPBlacklistFile, *m.ipBlacklist) if err != nil { return fmt.Errorf("failed to load IP blacklist: %w", err) } @@ -424,8 +426,8 @@ func (m *Middleware) ReloadConfig() error { m.logger.Info("Reloading WAF configuration") if m.IPBlacklistFile != "" { - newIPBlacklist := NewCIDRTrie() - if err := m.loadIPBlacklist(m.IPBlacklistFile, newIPBlacklist); err != nil { + newIPBlacklist := trie.NewTrie() + if err := m.loadIPBlacklist(m.IPBlacklistFile, *newIPBlacklist); err != nil { m.logger.Error("Failed to reload IP blacklist", zap.String("file", m.IPBlacklistFile), zap.Error(err)) return fmt.Errorf("failed to reload IP blacklist: %v", err) } @@ -450,7 +452,7 @@ func (m *Middleware) ReloadConfig() error { return nil } -func (m *Middleware) loadIPBlacklist(path string, blacklistMap *CIDRTrie) error { +func (m *Middleware) loadIPBlacklist(path string, blacklistMap trie.Trie) error { if _, err := os.Stat(path); os.IsNotExist(err) { m.logger.Warn("Skipping IP blacklist load, file does not exist", zap.String("file", path)) return nil @@ -464,7 +466,7 @@ func (m *Middleware) loadIPBlacklist(path string, blacklistMap *CIDRTrie) error // Convert the map to CIDRTrie for ip := range blacklist { - blacklistMap.Insert(ip) + blacklistMap.Insert(netip.MustParsePrefix(ip), nil) } return nil } diff --git a/go.mod b/go.mod index 0e15c91..1b464c1 100644 --- a/go.mod +++ b/go.mod @@ -9,6 +9,7 @@ require ( github.com/fsnotify/fsnotify v1.9.0 github.com/google/uuid v1.6.0 github.com/oschwald/maxminddb-golang v1.13.1 + github.com/phemmer/go-iptrie v0.0.0-20240326174613-ba542f5282c9 github.com/stretchr/testify v1.11.1 go.uber.org/zap v1.27.0 ) @@ -89,7 +90,7 @@ require ( github.com/russross/blackfriday/v2 v2.1.0 // indirect github.com/shopspring/decimal v1.4.0 // indirect github.com/shurcooL/sanitized_anchor_name v1.0.0 // indirect - github.com/slackhq/nebula v1.9.6 // indirect + github.com/slackhq/nebula v1.9.7 // indirect github.com/smallstep/certificates v0.28.4 // indirect github.com/smallstep/cli-utils v0.12.1 // indirect github.com/smallstep/linkedca v0.24.0 // indirect diff --git a/go.sum b/go.sum index aebfebd..16030ba 100644 --- a/go.sum +++ b/go.sum @@ -281,6 +281,8 @@ github.com/pbnjay/memory v0.0.0-20210728143218-7b4eea64cf58/go.mod h1:DXv8WO4yhM github.com/pelletier/go-toml v1.2.0/go.mod h1:5z9KED0ma1S8pY6P1sdut58dfprrGBbd/94hg7ilaic= github.com/peterbourgon/diskv/v3 v3.0.1 h1:x06SQA46+PKIUftmEujdwSEpIx8kR+M9eLYsUxeYveU= github.com/peterbourgon/diskv/v3 v3.0.1/go.mod h1:kJ5Ny7vLdARGU3WUuy6uzO6T0nb/2gWcT1JiBvRmb5o= +github.com/phemmer/go-iptrie v0.0.0-20240326174613-ba542f5282c9 h1:C8IqpV7kfAyZDRCnAVNi//l1mWlpyPmq1N6DjVvYEnY= +github.com/phemmer/go-iptrie v0.0.0-20240326174613-ba542f5282c9/go.mod h1:dDLiSjNqdp8VjphLdGTx19OeAUsHOzhtc1FFJqpzWMU= github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= @@ -343,8 +345,8 @@ github.com/shurcooL/webdavfs v0.0.0-20170829043945-18c3829fa133/go.mod h1:hKmq5k github.com/sirupsen/logrus v1.7.0/go.mod h1:yWOB1SBYBC5VeMP7gHvWumXLIWorT60ONWic61uBYv0= github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ= github.com/sirupsen/logrus v1.9.3/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ= -github.com/slackhq/nebula v1.9.6 h1:Fl0LE2dHDeVEK3R+un59Z3V4ZzbZ6q2e/zF4ClaD5yo= -github.com/slackhq/nebula v1.9.6/go.mod h1:1+4q4wd3dDAjO8rKCttSb9JIVbklQhuJiBp5I0lbIsQ= +github.com/slackhq/nebula v1.9.7 h1:v5u46efIyYHGdfjFnozQbRRhMdaB9Ma1SSTcUcE2lfE= +github.com/slackhq/nebula v1.9.7/go.mod h1:1+4q4wd3dDAjO8rKCttSb9JIVbklQhuJiBp5I0lbIsQ= github.com/smallstep/assert v0.0.0-20200723003110-82e2b9b3b262 h1:unQFBIznI+VYD1/1fApl1A+9VcBk+9dcqGfnePY87LY= github.com/smallstep/assert v0.0.0-20200723003110-82e2b9b3b262/go.mod h1:MyOHs9Po2fbM1LHej6sBUT8ozbxmMOFG+E+rx/GSGuc= github.com/smallstep/certificates v0.28.4 h1:JTU6/A5Xes6m+OsR6fw1RACSA362vJc9SOFVG7poBEw= diff --git a/handler_test.go b/handler_test.go index 45a5e06..f8bf7e0 100644 --- a/handler_test.go +++ b/handler_test.go @@ -6,12 +6,14 @@ import ( "mime/multipart" "net/http" "net/http/httptest" + "net/netip" "regexp" "strings" "testing" "time" "github.com/caddyserver/caddy/v2/modules/caddyhttp" + trie "github.com/phemmer/go-iptrie" "github.com/stretchr/testify/assert" "go.uber.org/zap" ) @@ -23,7 +25,7 @@ func TestBlockedRequestPhase1_DNSBlacklist(t *testing.T) { dnsBlacklist: map[string]struct{}{ "malicious.domain": {}, }, - ipBlacklist: NewCIDRTrie(), // Initialize ipBlacklist + ipBlacklist: trie.NewTrie(), CustomResponses: map[int]CustomBlockResponse{ 403: { StatusCode: http.StatusForbidden, @@ -90,6 +92,38 @@ func TestBlockedRequestPhase1_GeoIPBlocking(t *testing.T) { assert.Contains(t, w.Body.String(), "Access Denied", "Response body should contain 'Access Denied'") } +func TestBlockedRequestPhase1_IPBlocking(t *testing.T) { + logger, err := zap.NewDevelopment() + assert.NoError(t, err) + + ipBlackList := trie.NewTrie() + ipBlackList.Insert(netip.MustParsePrefix("127.0.0.1"), nil) + + middleware := &Middleware{ + logger: logger, + ipBlacklist: ipBlackList, + CustomResponses: map[int]CustomBlockResponse{ + 403: { + StatusCode: http.StatusForbidden, + Body: "Access Denied", + }, + }, + } + + req := httptest.NewRequest("GET", "http://example.com", nil) + req.RemoteAddr = "127.0.0.1" + w := httptest.NewRecorder() + state := &WAFState{} + + // Process the request in Phase 1 + middleware.handlePhase(w, req, 1, state) + + // Verify that the request was blocked + assert.True(t, state.Blocked, "Request should be blocked") + assert.Equal(t, http.StatusForbidden, w.Code, "Expected status code 403") + assert.Contains(t, w.Body.String(), "Access Denied", "Response body should contain 'Access Denied'") +} + func TestHandlePhase_Phase2_NiktoUserAgent(t *testing.T) { logger := zap.NewNop() middleware := &Middleware{ @@ -108,7 +142,7 @@ func TestHandlePhase_Phase2_NiktoUserAgent(t *testing.T) { }, }, ruleCache: NewRuleCache(), - ipBlacklist: NewCIDRTrie(), + ipBlacklist: trie.NewTrie(), dnsBlacklist: map[string]struct{}{}, requestValueExtractor: NewRequestValueExtractor(logger, false), CustomResponses: map[int]CustomBlockResponse{ @@ -168,7 +202,7 @@ func TestBlockedRequestPhase1_HeaderRegex(t *testing.T) { }, }, ruleCache: NewRuleCache(), - ipBlacklist: NewCIDRTrie(), + ipBlacklist: trie.NewTrie(), dnsBlacklist: map[string]struct{}{}, requestValueExtractor: NewRequestValueExtractor(logger, false), } @@ -220,7 +254,7 @@ func TestBlockedRequestPhase1_HeaderRegex_SpecificValue(t *testing.T) { }, }, ruleCache: NewRuleCache(), - ipBlacklist: NewCIDRTrie(), + ipBlacklist: trie.NewTrie(), dnsBlacklist: map[string]struct{}{}, requestValueExtractor: NewRequestValueExtractor(logger, false), } @@ -272,7 +306,7 @@ func TestBlockedRequestPhase1_HeaderRegex_CommaSeparatedTargets(t *testing.T) { }, }, ruleCache: NewRuleCache(), - ipBlacklist: NewCIDRTrie(), + ipBlacklist: trie.NewTrie(), dnsBlacklist: map[string]struct{}{}, requestValueExtractor: NewRequestValueExtractor(logger, false), } @@ -325,7 +359,7 @@ func TestBlockedRequestPhase1_CombinedConditions(t *testing.T) { }, }, ruleCache: NewRuleCache(), - ipBlacklist: NewCIDRTrie(), + ipBlacklist: trie.NewTrie(), dnsBlacklist: map[string]struct{}{}, requestValueExtractor: NewRequestValueExtractor(logger, false), } @@ -377,7 +411,7 @@ func TestBlockedRequestPhase1_NoMatch(t *testing.T) { }, }, ruleCache: NewRuleCache(), - ipBlacklist: NewCIDRTrie(), + ipBlacklist: trie.NewTrie(), dnsBlacklist: map[string]struct{}{}, requestValueExtractor: NewRequestValueExtractor(logger, false), } @@ -429,7 +463,7 @@ func TestBlockedRequestPhase1_HeaderRegex_EmptyHeader(t *testing.T) { }, }, ruleCache: NewRuleCache(), - ipBlacklist: NewCIDRTrie(), + ipBlacklist: trie.NewTrie(), dnsBlacklist: map[string]struct{}{}, requestValueExtractor: NewRequestValueExtractor(logger, false), } @@ -479,7 +513,7 @@ func TestBlockedRequestPhase1_HeaderRegex_MissingHeader(t *testing.T) { }, }, ruleCache: NewRuleCache(), - ipBlacklist: NewCIDRTrie(), + ipBlacklist: trie.NewTrie(), dnsBlacklist: map[string]struct{}{}, requestValueExtractor: NewRequestValueExtractor(logger, false), } @@ -531,7 +565,7 @@ func TestBlockedRequestPhase1_HeaderRegex_ComplexPattern(t *testing.T) { }, }, ruleCache: NewRuleCache(), - ipBlacklist: NewCIDRTrie(), + ipBlacklist: trie.NewTrie(), dnsBlacklist: map[string]struct{}{}, requestValueExtractor: NewRequestValueExtractor(logger, false), } @@ -583,7 +617,7 @@ func TestBlockedRequestPhase1_MultiTargetMatch(t *testing.T) { }, }, ruleCache: NewRuleCache(), - ipBlacklist: NewCIDRTrie(), + ipBlacklist: trie.NewTrie(), dnsBlacklist: map[string]struct{}{}, requestValueExtractor: NewRequestValueExtractor(logger, false), } @@ -635,7 +669,7 @@ func TestBlockedRequestPhase1_MultiTargetNoMatch(t *testing.T) { }, }, ruleCache: NewRuleCache(), - ipBlacklist: NewCIDRTrie(), + ipBlacklist: trie.NewTrie(), dnsBlacklist: map[string]struct{}{}, requestValueExtractor: NewRequestValueExtractor(logger, false), } @@ -688,7 +722,7 @@ func TestBlockedRequestPhase1_URLParameterRegex_NoMatch(t *testing.T) { }, }, ruleCache: NewRuleCache(), - ipBlacklist: NewCIDRTrie(), + ipBlacklist: trie.NewTrie(), dnsBlacklist: map[string]struct{}{}, requestValueExtractor: NewRequestValueExtractor(logger, false), } @@ -747,7 +781,7 @@ func TestBlockedRequestPhase1_MultipleRules(t *testing.T) { }, }, ruleCache: NewRuleCache(), - ipBlacklist: NewCIDRTrie(), + ipBlacklist: trie.NewTrie(), dnsBlacklist: map[string]struct{}{}, requestValueExtractor: NewRequestValueExtractor(logger, false), } @@ -821,7 +855,7 @@ func TestBlockedRequestPhase2_BodyRegex(t *testing.T) { }, }, ruleCache: NewRuleCache(), - ipBlacklist: NewCIDRTrie(), + ipBlacklist: trie.NewTrie(), dnsBlacklist: map[string]struct{}{}, requestValueExtractor: NewRequestValueExtractor(logger, false), } @@ -879,7 +913,7 @@ func TestBlockedRequestPhase2_BodyRegex_JSON(t *testing.T) { }, }, ruleCache: NewRuleCache(), - ipBlacklist: NewCIDRTrie(), + ipBlacklist: trie.NewTrie(), dnsBlacklist: map[string]struct{}{}, requestValueExtractor: NewRequestValueExtractor(logger, false), } @@ -937,7 +971,7 @@ func TestBlockedRequestPhase2_BodyRegex_FormURLEncoded(t *testing.T) { }, }, ruleCache: NewRuleCache(), - ipBlacklist: NewCIDRTrie(), + ipBlacklist: trie.NewTrie(), dnsBlacklist: map[string]struct{}{}, requestValueExtractor: NewRequestValueExtractor(logger, false), } @@ -991,7 +1025,7 @@ func TestBlockedRequestPhase2_BodyRegex_SpecificPattern(t *testing.T) { }, }, ruleCache: NewRuleCache(), - ipBlacklist: NewCIDRTrie(), + ipBlacklist: trie.NewTrie(), dnsBlacklist: map[string]struct{}{}, requestValueExtractor: NewRequestValueExtractor(logger, false), } @@ -1049,7 +1083,7 @@ func TestBlockedRequestPhase2_BodyRegex_NoMatch(t *testing.T) { }, }, ruleCache: NewRuleCache(), - ipBlacklist: NewCIDRTrie(), + ipBlacklist: trie.NewTrie(), dnsBlacklist: map[string]struct{}{}, requestValueExtractor: NewRequestValueExtractor(logger, false), } @@ -1107,7 +1141,7 @@ func TestBlockedRequestPhase2_BodyRegex_NoMatch_MultipartForm(t *testing.T) { }, }, ruleCache: NewRuleCache(), - ipBlacklist: NewCIDRTrie(), + ipBlacklist: trie.NewTrie(), dnsBlacklist: map[string]struct{}{}, requestValueExtractor: NewRequestValueExtractor(logger, false), } @@ -1174,7 +1208,7 @@ func TestBlockedRequestPhase2_BodyRegex_NoBody(t *testing.T) { }, }, ruleCache: NewRuleCache(), - ipBlacklist: NewCIDRTrie(), + ipBlacklist: trie.NewTrie(), dnsBlacklist: map[string]struct{}{}, requestValueExtractor: NewRequestValueExtractor(logger, false), } @@ -1219,7 +1253,7 @@ func TestBlockedRequestPhase3_ResponseHeaderRegex_NoMatch(t *testing.T) { }, }, ruleCache: NewRuleCache(), - ipBlacklist: NewCIDRTrie(), + ipBlacklist: trie.NewTrie(), dnsBlacklist: map[string]struct{}{}, requestValueExtractor: NewRequestValueExtractor(logger, false), } @@ -1274,7 +1308,7 @@ func TestBlockedRequestPhase4_ResponseBodyRegex_EmptyBody(t *testing.T) { }, }, ruleCache: NewRuleCache(), - ipBlacklist: NewCIDRTrie(), + ipBlacklist: trie.NewTrie(), dnsBlacklist: map[string]struct{}{}, requestValueExtractor: NewRequestValueExtractor(logger, false), } @@ -1329,7 +1363,7 @@ func TestBlockedRequestPhase4_ResponseBodyRegex_NoBody(t *testing.T) { }, }, ruleCache: NewRuleCache(), - ipBlacklist: NewCIDRTrie(), + ipBlacklist: trie.NewTrie(), dnsBlacklist: map[string]struct{}{}, requestValueExtractor: NewRequestValueExtractor(logger, false), } @@ -1382,7 +1416,7 @@ func TestBlockedRequestPhase3_ResponseHeaderRegex_NoSetCookie(t *testing.T) { }, }, ruleCache: NewRuleCache(), - ipBlacklist: NewCIDRTrie(), + ipBlacklist: trie.NewTrie(), dnsBlacklist: map[string]struct{}{}, requestValueExtractor: NewRequestValueExtractor(logger, false), } @@ -1437,7 +1471,7 @@ func TestBlockedRequestPhase1_HeaderRegex_CaseInsensitive(t *testing.T) { }, }, ruleCache: NewRuleCache(), - ipBlacklist: NewCIDRTrie(), + ipBlacklist: trie.NewTrie(), dnsBlacklist: map[string]struct{}{}, requestValueExtractor: NewRequestValueExtractor(logger, false), } @@ -1489,7 +1523,7 @@ func TestBlockedRequestPhase1_HeaderRegex_MultipleMatchingHeaders(t *testing.T) }, }, ruleCache: NewRuleCache(), - ipBlacklist: NewCIDRTrie(), + ipBlacklist: trie.NewTrie(), dnsBlacklist: map[string]struct{}{}, requestValueExtractor: NewRequestValueExtractor(logger, false), } @@ -1594,7 +1628,7 @@ func TestBlockedRequestPhase1_RateLimiting_MultiplePaths(t *testing.T) { Body: "Rate limit exceeded", }, }, - ipBlacklist: NewCIDRTrie(), + ipBlacklist: trie.NewTrie(), dnsBlacklist: make(map[string]struct{}), } @@ -1664,7 +1698,7 @@ func TestBlockedRequestPhase1_RateLimiting_DifferentIPs(t *testing.T) { Body: "Rate limit exceeded", }, }, - ipBlacklist: NewCIDRTrie(), + ipBlacklist: trie.NewTrie(), dnsBlacklist: make(map[string]struct{}), } @@ -1717,7 +1751,7 @@ func TestBlockedRequestPhase1_RateLimiting_MatchAllPaths(t *testing.T) { Body: "Rate limit exceeded", }, }, - ipBlacklist: NewCIDRTrie(), + ipBlacklist: trie.NewTrie(), dnsBlacklist: make(map[string]struct{}), } diff --git a/ratelimiter_test.go b/ratelimiter_test.go index 220a8bd..e1e2156 100644 --- a/ratelimiter_test.go +++ b/ratelimiter_test.go @@ -7,6 +7,7 @@ import ( "testing" "time" + trie "github.com/phemmer/go-iptrie" "github.com/stretchr/testify/assert" "go.uber.org/zap" ) @@ -397,7 +398,7 @@ func TestBlockedRequestPhase1_RateLimiting(t *testing.T) { Body: "Rate limit exceeded", }, }, - ipBlacklist: NewCIDRTrie(), // Initialize ipBlacklist + ipBlacklist: trie.NewTrie(), // Initialize ipBlacklist dnsBlacklist: make(map[string]struct{}), // Initialize dnsBlacklist } diff --git a/request_test.go b/request_test.go index 6957251..d21eebf 100644 --- a/request_test.go +++ b/request_test.go @@ -6,11 +6,13 @@ import ( "fmt" "net/http" "net/http/httptest" + "net/netip" "sync" "testing" "time" "github.com/caddyserver/caddy/v2/caddyconfig/caddyfile" + trie "github.com/phemmer/go-iptrie" "github.com/stretchr/testify/assert" "go.uber.org/zap" "go.uber.org/zap/zapcore" @@ -442,7 +444,7 @@ func TestConcurrentRuleEvaluation(t *testing.T) { }, }, ruleCache: NewRuleCache(), - ipBlacklist: NewCIDRTrie(), + ipBlacklist: trie.NewTrie(), dnsBlacklist: map[string]struct{}{}, requestValueExtractor: NewRequestValueExtractor(logger, false), rateLimiter: func() *RateLimiter { @@ -465,7 +467,7 @@ func TestConcurrentRuleEvaluation(t *testing.T) { } // Add some IPs to the blacklist - middleware.ipBlacklist.Insert("192.168.1.0/24") + middleware.ipBlacklist.Insert(netip.MustParsePrefix("192.168.1.0/24"), nil) var wg sync.WaitGroup for i := 0; i < 100; i++ { diff --git a/types.go b/types.go index 12ae580..6b53844 100644 --- a/types.go +++ b/types.go @@ -1,8 +1,6 @@ package caddywaf import ( - "fmt" - "net" "regexp" "sync" "time" @@ -11,6 +9,7 @@ import ( "github.com/caddyserver/caddy/v2/caddyconfig/caddyfile" "github.com/caddyserver/caddy/v2/modules/caddyhttp" "github.com/oschwald/maxminddb-golang" + trie "github.com/phemmer/go-iptrie" "go.uber.org/zap" "go.uber.org/zap/zapcore" ) @@ -31,68 +30,6 @@ var ( type RuleID string type HitCount int -// ==================== Struct Definitions ==================== -type TrieNode struct { - children map[byte]*TrieNode - isLeaf bool -} - -func NewTrieNode() *TrieNode { - return &TrieNode{ - children: make(map[byte]*TrieNode), // Initialize the map - isLeaf: false, - } -} - -type CIDRTrie struct { - ipv4Root *TrieNode - ipv6Root *TrieNode - mu sync.RWMutex -} - -func NewCIDRTrie() *CIDRTrie { - return &CIDRTrie{ - ipv4Root: NewTrieNode(), // Initialize with a new TrieNode - ipv6Root: NewTrieNode(), // Initialize with a new TrieNode - } -} - -func (t *CIDRTrie) Insert(cidr string) error { - t.mu.Lock() - defer t.mu.Unlock() - - ip, ipNet, err := net.ParseCIDR(cidr) - if err != nil { - return err - } - - if ip.To4() != nil { - // IPv4 - return t.insertIPv4(ipNet) - } else { - // IPv6 - return t.insertIPv6(ipNet) - } -} - -func (t *CIDRTrie) Contains(ipStr string) bool { - t.mu.RLock() - defer t.mu.RUnlock() - - ip := net.ParseIP(ipStr) - if ip == nil { - return false - } - - if ip.To4() != nil { - // IPv4 - return t.containsIPv4(ip) - } else { - // IPv6 - return t.containsIPv6(ip) - } -} - // RuleCache caches compiled regex patterns for rules. type RuleCache struct { mu sync.RWMutex @@ -167,7 +104,7 @@ type Middleware struct { CountryBlock CountryAccessFilter `json:"country_block"` CountryWhitelist CountryAccessFilter `json:"country_whitelist"` Rules map[int][]Rule `json:"-"` - ipBlacklist *CIDRTrie `json:"-"` // Changed to CIDRTrie + ipBlacklist *trie.Trie `json:"-"` dnsBlacklist map[string]struct{} `json:"-"` // Changed to map[string]struct{} logger *zap.Logger LogSeverity string `json:"log_severity,omitempty"` @@ -244,90 +181,3 @@ func (rc *RuleCache) Set(ruleID string, regex *regexp.Regexp) { defer rc.mu.Unlock() rc.rules[ruleID] = regex } - -func (t *CIDRTrie) insertIPv4(ipNet *net.IPNet) error { - ip := ipNet.IP.To4() - if ip == nil { - return fmt.Errorf("invalid IPv4 address") - } - - mask, _ := ipNet.Mask.Size() - node := t.ipv4Root - - for i := 0; i < mask; i++ { - bit := (ip[i/8] >> (7 - uint(i%8))) & 1 - if node.children[bit] == nil { - node.children[bit] = NewTrieNode() // Initialize the child node - } - node = node.children[bit] - } - - node.isLeaf = true - return nil -} - -func (t *CIDRTrie) insertIPv6(ipNet *net.IPNet) error { - ip := ipNet.IP.To16() - if ip == nil { - return fmt.Errorf("invalid IPv6 address") - } - - mask, _ := ipNet.Mask.Size() - node := t.ipv6Root - - for i := 0; i < mask; i++ { - bit := (ip[i/8] >> (7 - uint(i%8))) & 1 - if node.children[bit] == nil { - node.children[bit] = NewTrieNode() // Initialize the child node - } - node = node.children[bit] - } - - node.isLeaf = true - return nil -} - -func (t *CIDRTrie) containsIPv4(ip net.IP) bool { - ip = ip.To4() - if ip == nil { - return false - } - - node := t.ipv4Root - for i := 0; i < len(ip)*8; i++ { - bit := (ip[i/8] >> (7 - uint(i%8))) & 1 - if node.children[bit] == nil { - return false - } - node = node.children[bit] - if node.isLeaf { - return true - } - } - return node.isLeaf -} - -func (t *CIDRTrie) containsIPv6(ip net.IP) bool { - ip = ip.To16() - if ip == nil { - return false - } - - // Add this check to ensure ip is not empty - if len(ip) == 0 { - return false - } - - node := t.ipv6Root - for i := 0; i < len(ip)*8; i++ { - bit := (ip[i/8] >> (7 - uint(i%8))) & 1 - if node.children[bit] == nil { - return false - } - node = node.children[bit] - if node.isLeaf { - return true - } - } - return false -} diff --git a/types_test.go b/types_test.go index 5d164d2..24a02e3 100644 --- a/types_test.go +++ b/types_test.go @@ -5,75 +5,6 @@ import ( "testing" ) -func TestNewCIDRTrie(t *testing.T) { - trie := NewCIDRTrie() - if trie == nil { - t.Fatal("NewCIDRTrie() returned nil") - } - if trie.ipv4Root == nil { - t.Fatal("NewCIDRTrie() created a trie with nil ipv4Root") - } - if trie.ipv6Root == nil { - t.Fatal("NewCIDRTrie() created a trie with nil ipv6Root") - } - if trie.ipv4Root.children == nil { - t.Fatal("NewCIDRTrie() created ipv4Root with nil children map") - } - if trie.ipv6Root.children == nil { - t.Fatal("NewCIDRTrie() created ipv6Root with nil children map") - } -} - -func TestCIDRTrie_Insert(t *testing.T) { - tests := []struct { - name string - cidr string - wantErr bool - }{ - {"valid IPv4 CIDR", "192.168.1.0/24", false}, - {"valid IPv6 CIDR", "2001:db8::/32", false}, // IPv6 is now supported - {"invalid CIDR", "invalid", true}, - {"invalid IPv4 mask", "192.168.1.0/33", true}, - {"invalid IPv6 mask", "2001:db8::/129", true}, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - trie := NewCIDRTrie() - err := trie.Insert(tt.cidr) - if (err != nil) != tt.wantErr { - t.Errorf("CIDRTrie.Insert() error = %v, wantErr %v", err, tt.wantErr) - } - }) - } -} - -func TestCIDRTrie_Contains(t *testing.T) { - trie := NewCIDRTrie() - _ = trie.Insert("192.168.1.0/24") - _ = trie.Insert("2001:db8::/32") // Add an IPv6 CIDR - - tests := []struct { - name string - ip string - want bool - }{ - {"IPv4 in range", "192.168.1.1", true}, - {"IPv4 out of range", "192.168.2.1", false}, - {"Invalid IP", "invalid", false}, - {"IPv6 in range", "2001:db8::1", true}, - {"IPv6 out of range", "2001:db9::1", false}, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - if got := trie.Contains(tt.ip); got != tt.want { - t.Errorf("CIDRTrie.Contains() = %v, want %v", got, tt.want) - } - }) - } -} - func TestNewRuleCache(t *testing.T) { cache := NewRuleCache() if cache == nil {