feat!!: switch to go-trie

This commit is contained in:
drev74
2025-10-10 20:33:20 +03:00
parent 6429e286fd
commit c905277058
9 changed files with 88 additions and 268 deletions

View File

@@ -4,6 +4,7 @@ import (
"bufio"
"fmt"
"net"
"net/netip"
"os"
"strings"
@@ -64,10 +65,7 @@ func (bl *BlacklistLoader) LoadDNSBlacklistFromFile(path string, dnsBlacklist ma
}
func (m *Middleware) isIPBlacklisted(ip string) bool {
if m.ipBlacklist == nil { // Defensive check: ensure ipBlacklist is not nil
return false
}
if m.ipBlacklist.Contains(ip) {
if m.ipBlacklist.Contains(netip.MustParseAddr(ip)) {
m.muIPBlacklistMetrics.Lock() // Acquire lock before accessing shared counter
m.IPBlacklistBlockCount++ // Increment the counter
m.muIPBlacklistMetrics.Unlock() // Release lock after accessing counter
@@ -127,7 +125,6 @@ func extractIP(remoteAddr string, logger *zap.Logger) string {
return host
}
// LoadIPBlacklistFromFile loads IP addresses from a file into the provided map.
// LoadIPBlacklistFromFile loads IP addresses from a file into the provided map.
func (bl *BlacklistLoader) LoadIPBlacklistFromFile(path string, ipBlacklist map[string]struct{}) error {
bl.logger.Debug("Loading IP blacklist", zap.String("path", path))

View File

@@ -20,6 +20,7 @@ import (
"encoding/json"
"fmt"
"net/http"
"net/netip"
"os"
"strings"
"sync"
@@ -29,6 +30,7 @@ import (
"github.com/caddyserver/caddy/v2/caddyconfig/httpcaddyfile"
"github.com/caddyserver/caddy/v2/modules/caddyhttp"
"github.com/oschwald/maxminddb-golang"
trie "github.com/phemmer/go-iptrie"
"go.uber.org/zap"
"go.uber.org/zap/zapcore"
@@ -235,8 +237,8 @@ func (m *Middleware) Provision(ctx caddy.Context) error {
// Load IP blacklist
if m.IPBlacklistFile != "" {
m.ipBlacklist = NewCIDRTrie()
err = m.loadIPBlacklist(m.IPBlacklistFile, m.ipBlacklist)
m.ipBlacklist = trie.NewTrie()
err = m.loadIPBlacklist(m.IPBlacklistFile, *m.ipBlacklist)
if err != nil {
return fmt.Errorf("failed to load IP blacklist: %w", err)
}
@@ -424,8 +426,8 @@ func (m *Middleware) ReloadConfig() error {
m.logger.Info("Reloading WAF configuration")
if m.IPBlacklistFile != "" {
newIPBlacklist := NewCIDRTrie()
if err := m.loadIPBlacklist(m.IPBlacklistFile, newIPBlacklist); err != nil {
newIPBlacklist := trie.NewTrie()
if err := m.loadIPBlacklist(m.IPBlacklistFile, *newIPBlacklist); err != nil {
m.logger.Error("Failed to reload IP blacklist", zap.String("file", m.IPBlacklistFile), zap.Error(err))
return fmt.Errorf("failed to reload IP blacklist: %v", err)
}
@@ -450,7 +452,7 @@ func (m *Middleware) ReloadConfig() error {
return nil
}
func (m *Middleware) loadIPBlacklist(path string, blacklistMap *CIDRTrie) error {
func (m *Middleware) loadIPBlacklist(path string, blacklistMap trie.Trie) error {
if _, err := os.Stat(path); os.IsNotExist(err) {
m.logger.Warn("Skipping IP blacklist load, file does not exist", zap.String("file", path))
return nil
@@ -464,7 +466,7 @@ func (m *Middleware) loadIPBlacklist(path string, blacklistMap *CIDRTrie) error
// Convert the map to CIDRTrie
for ip := range blacklist {
blacklistMap.Insert(ip)
blacklistMap.Insert(netip.MustParsePrefix(ip), nil)
}
return nil
}

3
go.mod
View File

@@ -9,6 +9,7 @@ require (
github.com/fsnotify/fsnotify v1.9.0
github.com/google/uuid v1.6.0
github.com/oschwald/maxminddb-golang v1.13.1
github.com/phemmer/go-iptrie v0.0.0-20240326174613-ba542f5282c9
github.com/stretchr/testify v1.11.1
go.uber.org/zap v1.27.0
)
@@ -89,7 +90,7 @@ require (
github.com/russross/blackfriday/v2 v2.1.0 // indirect
github.com/shopspring/decimal v1.4.0 // indirect
github.com/shurcooL/sanitized_anchor_name v1.0.0 // indirect
github.com/slackhq/nebula v1.9.6 // indirect
github.com/slackhq/nebula v1.9.7 // indirect
github.com/smallstep/certificates v0.28.4 // indirect
github.com/smallstep/cli-utils v0.12.1 // indirect
github.com/smallstep/linkedca v0.24.0 // indirect

6
go.sum
View File

@@ -281,6 +281,8 @@ github.com/pbnjay/memory v0.0.0-20210728143218-7b4eea64cf58/go.mod h1:DXv8WO4yhM
github.com/pelletier/go-toml v1.2.0/go.mod h1:5z9KED0ma1S8pY6P1sdut58dfprrGBbd/94hg7ilaic=
github.com/peterbourgon/diskv/v3 v3.0.1 h1:x06SQA46+PKIUftmEujdwSEpIx8kR+M9eLYsUxeYveU=
github.com/peterbourgon/diskv/v3 v3.0.1/go.mod h1:kJ5Ny7vLdARGU3WUuy6uzO6T0nb/2gWcT1JiBvRmb5o=
github.com/phemmer/go-iptrie v0.0.0-20240326174613-ba542f5282c9 h1:C8IqpV7kfAyZDRCnAVNi//l1mWlpyPmq1N6DjVvYEnY=
github.com/phemmer/go-iptrie v0.0.0-20240326174613-ba542f5282c9/go.mod h1:dDLiSjNqdp8VjphLdGTx19OeAUsHOzhtc1FFJqpzWMU=
github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
@@ -343,8 +345,8 @@ github.com/shurcooL/webdavfs v0.0.0-20170829043945-18c3829fa133/go.mod h1:hKmq5k
github.com/sirupsen/logrus v1.7.0/go.mod h1:yWOB1SBYBC5VeMP7gHvWumXLIWorT60ONWic61uBYv0=
github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ=
github.com/sirupsen/logrus v1.9.3/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ=
github.com/slackhq/nebula v1.9.6 h1:Fl0LE2dHDeVEK3R+un59Z3V4ZzbZ6q2e/zF4ClaD5yo=
github.com/slackhq/nebula v1.9.6/go.mod h1:1+4q4wd3dDAjO8rKCttSb9JIVbklQhuJiBp5I0lbIsQ=
github.com/slackhq/nebula v1.9.7 h1:v5u46efIyYHGdfjFnozQbRRhMdaB9Ma1SSTcUcE2lfE=
github.com/slackhq/nebula v1.9.7/go.mod h1:1+4q4wd3dDAjO8rKCttSb9JIVbklQhuJiBp5I0lbIsQ=
github.com/smallstep/assert v0.0.0-20200723003110-82e2b9b3b262 h1:unQFBIznI+VYD1/1fApl1A+9VcBk+9dcqGfnePY87LY=
github.com/smallstep/assert v0.0.0-20200723003110-82e2b9b3b262/go.mod h1:MyOHs9Po2fbM1LHej6sBUT8ozbxmMOFG+E+rx/GSGuc=
github.com/smallstep/certificates v0.28.4 h1:JTU6/A5Xes6m+OsR6fw1RACSA362vJc9SOFVG7poBEw=

View File

@@ -6,12 +6,14 @@ import (
"mime/multipart"
"net/http"
"net/http/httptest"
"net/netip"
"regexp"
"strings"
"testing"
"time"
"github.com/caddyserver/caddy/v2/modules/caddyhttp"
trie "github.com/phemmer/go-iptrie"
"github.com/stretchr/testify/assert"
"go.uber.org/zap"
)
@@ -23,7 +25,7 @@ func TestBlockedRequestPhase1_DNSBlacklist(t *testing.T) {
dnsBlacklist: map[string]struct{}{
"malicious.domain": {},
},
ipBlacklist: NewCIDRTrie(), // Initialize ipBlacklist
ipBlacklist: trie.NewTrie(),
CustomResponses: map[int]CustomBlockResponse{
403: {
StatusCode: http.StatusForbidden,
@@ -90,6 +92,38 @@ func TestBlockedRequestPhase1_GeoIPBlocking(t *testing.T) {
assert.Contains(t, w.Body.String(), "Access Denied", "Response body should contain 'Access Denied'")
}
func TestBlockedRequestPhase1_IPBlocking(t *testing.T) {
logger, err := zap.NewDevelopment()
assert.NoError(t, err)
ipBlackList := trie.NewTrie()
ipBlackList.Insert(netip.MustParsePrefix("127.0.0.1"), nil)
middleware := &Middleware{
logger: logger,
ipBlacklist: ipBlackList,
CustomResponses: map[int]CustomBlockResponse{
403: {
StatusCode: http.StatusForbidden,
Body: "Access Denied",
},
},
}
req := httptest.NewRequest("GET", "http://example.com", nil)
req.RemoteAddr = "127.0.0.1"
w := httptest.NewRecorder()
state := &WAFState{}
// Process the request in Phase 1
middleware.handlePhase(w, req, 1, state)
// Verify that the request was blocked
assert.True(t, state.Blocked, "Request should be blocked")
assert.Equal(t, http.StatusForbidden, w.Code, "Expected status code 403")
assert.Contains(t, w.Body.String(), "Access Denied", "Response body should contain 'Access Denied'")
}
func TestHandlePhase_Phase2_NiktoUserAgent(t *testing.T) {
logger := zap.NewNop()
middleware := &Middleware{
@@ -108,7 +142,7 @@ func TestHandlePhase_Phase2_NiktoUserAgent(t *testing.T) {
},
},
ruleCache: NewRuleCache(),
ipBlacklist: NewCIDRTrie(),
ipBlacklist: trie.NewTrie(),
dnsBlacklist: map[string]struct{}{},
requestValueExtractor: NewRequestValueExtractor(logger, false),
CustomResponses: map[int]CustomBlockResponse{
@@ -168,7 +202,7 @@ func TestBlockedRequestPhase1_HeaderRegex(t *testing.T) {
},
},
ruleCache: NewRuleCache(),
ipBlacklist: NewCIDRTrie(),
ipBlacklist: trie.NewTrie(),
dnsBlacklist: map[string]struct{}{},
requestValueExtractor: NewRequestValueExtractor(logger, false),
}
@@ -220,7 +254,7 @@ func TestBlockedRequestPhase1_HeaderRegex_SpecificValue(t *testing.T) {
},
},
ruleCache: NewRuleCache(),
ipBlacklist: NewCIDRTrie(),
ipBlacklist: trie.NewTrie(),
dnsBlacklist: map[string]struct{}{},
requestValueExtractor: NewRequestValueExtractor(logger, false),
}
@@ -272,7 +306,7 @@ func TestBlockedRequestPhase1_HeaderRegex_CommaSeparatedTargets(t *testing.T) {
},
},
ruleCache: NewRuleCache(),
ipBlacklist: NewCIDRTrie(),
ipBlacklist: trie.NewTrie(),
dnsBlacklist: map[string]struct{}{},
requestValueExtractor: NewRequestValueExtractor(logger, false),
}
@@ -325,7 +359,7 @@ func TestBlockedRequestPhase1_CombinedConditions(t *testing.T) {
},
},
ruleCache: NewRuleCache(),
ipBlacklist: NewCIDRTrie(),
ipBlacklist: trie.NewTrie(),
dnsBlacklist: map[string]struct{}{},
requestValueExtractor: NewRequestValueExtractor(logger, false),
}
@@ -377,7 +411,7 @@ func TestBlockedRequestPhase1_NoMatch(t *testing.T) {
},
},
ruleCache: NewRuleCache(),
ipBlacklist: NewCIDRTrie(),
ipBlacklist: trie.NewTrie(),
dnsBlacklist: map[string]struct{}{},
requestValueExtractor: NewRequestValueExtractor(logger, false),
}
@@ -429,7 +463,7 @@ func TestBlockedRequestPhase1_HeaderRegex_EmptyHeader(t *testing.T) {
},
},
ruleCache: NewRuleCache(),
ipBlacklist: NewCIDRTrie(),
ipBlacklist: trie.NewTrie(),
dnsBlacklist: map[string]struct{}{},
requestValueExtractor: NewRequestValueExtractor(logger, false),
}
@@ -479,7 +513,7 @@ func TestBlockedRequestPhase1_HeaderRegex_MissingHeader(t *testing.T) {
},
},
ruleCache: NewRuleCache(),
ipBlacklist: NewCIDRTrie(),
ipBlacklist: trie.NewTrie(),
dnsBlacklist: map[string]struct{}{},
requestValueExtractor: NewRequestValueExtractor(logger, false),
}
@@ -531,7 +565,7 @@ func TestBlockedRequestPhase1_HeaderRegex_ComplexPattern(t *testing.T) {
},
},
ruleCache: NewRuleCache(),
ipBlacklist: NewCIDRTrie(),
ipBlacklist: trie.NewTrie(),
dnsBlacklist: map[string]struct{}{},
requestValueExtractor: NewRequestValueExtractor(logger, false),
}
@@ -583,7 +617,7 @@ func TestBlockedRequestPhase1_MultiTargetMatch(t *testing.T) {
},
},
ruleCache: NewRuleCache(),
ipBlacklist: NewCIDRTrie(),
ipBlacklist: trie.NewTrie(),
dnsBlacklist: map[string]struct{}{},
requestValueExtractor: NewRequestValueExtractor(logger, false),
}
@@ -635,7 +669,7 @@ func TestBlockedRequestPhase1_MultiTargetNoMatch(t *testing.T) {
},
},
ruleCache: NewRuleCache(),
ipBlacklist: NewCIDRTrie(),
ipBlacklist: trie.NewTrie(),
dnsBlacklist: map[string]struct{}{},
requestValueExtractor: NewRequestValueExtractor(logger, false),
}
@@ -688,7 +722,7 @@ func TestBlockedRequestPhase1_URLParameterRegex_NoMatch(t *testing.T) {
},
},
ruleCache: NewRuleCache(),
ipBlacklist: NewCIDRTrie(),
ipBlacklist: trie.NewTrie(),
dnsBlacklist: map[string]struct{}{},
requestValueExtractor: NewRequestValueExtractor(logger, false),
}
@@ -747,7 +781,7 @@ func TestBlockedRequestPhase1_MultipleRules(t *testing.T) {
},
},
ruleCache: NewRuleCache(),
ipBlacklist: NewCIDRTrie(),
ipBlacklist: trie.NewTrie(),
dnsBlacklist: map[string]struct{}{},
requestValueExtractor: NewRequestValueExtractor(logger, false),
}
@@ -821,7 +855,7 @@ func TestBlockedRequestPhase2_BodyRegex(t *testing.T) {
},
},
ruleCache: NewRuleCache(),
ipBlacklist: NewCIDRTrie(),
ipBlacklist: trie.NewTrie(),
dnsBlacklist: map[string]struct{}{},
requestValueExtractor: NewRequestValueExtractor(logger, false),
}
@@ -879,7 +913,7 @@ func TestBlockedRequestPhase2_BodyRegex_JSON(t *testing.T) {
},
},
ruleCache: NewRuleCache(),
ipBlacklist: NewCIDRTrie(),
ipBlacklist: trie.NewTrie(),
dnsBlacklist: map[string]struct{}{},
requestValueExtractor: NewRequestValueExtractor(logger, false),
}
@@ -937,7 +971,7 @@ func TestBlockedRequestPhase2_BodyRegex_FormURLEncoded(t *testing.T) {
},
},
ruleCache: NewRuleCache(),
ipBlacklist: NewCIDRTrie(),
ipBlacklist: trie.NewTrie(),
dnsBlacklist: map[string]struct{}{},
requestValueExtractor: NewRequestValueExtractor(logger, false),
}
@@ -991,7 +1025,7 @@ func TestBlockedRequestPhase2_BodyRegex_SpecificPattern(t *testing.T) {
},
},
ruleCache: NewRuleCache(),
ipBlacklist: NewCIDRTrie(),
ipBlacklist: trie.NewTrie(),
dnsBlacklist: map[string]struct{}{},
requestValueExtractor: NewRequestValueExtractor(logger, false),
}
@@ -1049,7 +1083,7 @@ func TestBlockedRequestPhase2_BodyRegex_NoMatch(t *testing.T) {
},
},
ruleCache: NewRuleCache(),
ipBlacklist: NewCIDRTrie(),
ipBlacklist: trie.NewTrie(),
dnsBlacklist: map[string]struct{}{},
requestValueExtractor: NewRequestValueExtractor(logger, false),
}
@@ -1107,7 +1141,7 @@ func TestBlockedRequestPhase2_BodyRegex_NoMatch_MultipartForm(t *testing.T) {
},
},
ruleCache: NewRuleCache(),
ipBlacklist: NewCIDRTrie(),
ipBlacklist: trie.NewTrie(),
dnsBlacklist: map[string]struct{}{},
requestValueExtractor: NewRequestValueExtractor(logger, false),
}
@@ -1174,7 +1208,7 @@ func TestBlockedRequestPhase2_BodyRegex_NoBody(t *testing.T) {
},
},
ruleCache: NewRuleCache(),
ipBlacklist: NewCIDRTrie(),
ipBlacklist: trie.NewTrie(),
dnsBlacklist: map[string]struct{}{},
requestValueExtractor: NewRequestValueExtractor(logger, false),
}
@@ -1219,7 +1253,7 @@ func TestBlockedRequestPhase3_ResponseHeaderRegex_NoMatch(t *testing.T) {
},
},
ruleCache: NewRuleCache(),
ipBlacklist: NewCIDRTrie(),
ipBlacklist: trie.NewTrie(),
dnsBlacklist: map[string]struct{}{},
requestValueExtractor: NewRequestValueExtractor(logger, false),
}
@@ -1274,7 +1308,7 @@ func TestBlockedRequestPhase4_ResponseBodyRegex_EmptyBody(t *testing.T) {
},
},
ruleCache: NewRuleCache(),
ipBlacklist: NewCIDRTrie(),
ipBlacklist: trie.NewTrie(),
dnsBlacklist: map[string]struct{}{},
requestValueExtractor: NewRequestValueExtractor(logger, false),
}
@@ -1329,7 +1363,7 @@ func TestBlockedRequestPhase4_ResponseBodyRegex_NoBody(t *testing.T) {
},
},
ruleCache: NewRuleCache(),
ipBlacklist: NewCIDRTrie(),
ipBlacklist: trie.NewTrie(),
dnsBlacklist: map[string]struct{}{},
requestValueExtractor: NewRequestValueExtractor(logger, false),
}
@@ -1382,7 +1416,7 @@ func TestBlockedRequestPhase3_ResponseHeaderRegex_NoSetCookie(t *testing.T) {
},
},
ruleCache: NewRuleCache(),
ipBlacklist: NewCIDRTrie(),
ipBlacklist: trie.NewTrie(),
dnsBlacklist: map[string]struct{}{},
requestValueExtractor: NewRequestValueExtractor(logger, false),
}
@@ -1437,7 +1471,7 @@ func TestBlockedRequestPhase1_HeaderRegex_CaseInsensitive(t *testing.T) {
},
},
ruleCache: NewRuleCache(),
ipBlacklist: NewCIDRTrie(),
ipBlacklist: trie.NewTrie(),
dnsBlacklist: map[string]struct{}{},
requestValueExtractor: NewRequestValueExtractor(logger, false),
}
@@ -1489,7 +1523,7 @@ func TestBlockedRequestPhase1_HeaderRegex_MultipleMatchingHeaders(t *testing.T)
},
},
ruleCache: NewRuleCache(),
ipBlacklist: NewCIDRTrie(),
ipBlacklist: trie.NewTrie(),
dnsBlacklist: map[string]struct{}{},
requestValueExtractor: NewRequestValueExtractor(logger, false),
}
@@ -1594,7 +1628,7 @@ func TestBlockedRequestPhase1_RateLimiting_MultiplePaths(t *testing.T) {
Body: "Rate limit exceeded",
},
},
ipBlacklist: NewCIDRTrie(),
ipBlacklist: trie.NewTrie(),
dnsBlacklist: make(map[string]struct{}),
}
@@ -1664,7 +1698,7 @@ func TestBlockedRequestPhase1_RateLimiting_DifferentIPs(t *testing.T) {
Body: "Rate limit exceeded",
},
},
ipBlacklist: NewCIDRTrie(),
ipBlacklist: trie.NewTrie(),
dnsBlacklist: make(map[string]struct{}),
}
@@ -1717,7 +1751,7 @@ func TestBlockedRequestPhase1_RateLimiting_MatchAllPaths(t *testing.T) {
Body: "Rate limit exceeded",
},
},
ipBlacklist: NewCIDRTrie(),
ipBlacklist: trie.NewTrie(),
dnsBlacklist: make(map[string]struct{}),
}

View File

@@ -7,6 +7,7 @@ import (
"testing"
"time"
trie "github.com/phemmer/go-iptrie"
"github.com/stretchr/testify/assert"
"go.uber.org/zap"
)
@@ -397,7 +398,7 @@ func TestBlockedRequestPhase1_RateLimiting(t *testing.T) {
Body: "Rate limit exceeded",
},
},
ipBlacklist: NewCIDRTrie(), // Initialize ipBlacklist
ipBlacklist: trie.NewTrie(), // Initialize ipBlacklist
dnsBlacklist: make(map[string]struct{}), // Initialize dnsBlacklist
}

View File

@@ -6,11 +6,13 @@ import (
"fmt"
"net/http"
"net/http/httptest"
"net/netip"
"sync"
"testing"
"time"
"github.com/caddyserver/caddy/v2/caddyconfig/caddyfile"
trie "github.com/phemmer/go-iptrie"
"github.com/stretchr/testify/assert"
"go.uber.org/zap"
"go.uber.org/zap/zapcore"
@@ -442,7 +444,7 @@ func TestConcurrentRuleEvaluation(t *testing.T) {
},
},
ruleCache: NewRuleCache(),
ipBlacklist: NewCIDRTrie(),
ipBlacklist: trie.NewTrie(),
dnsBlacklist: map[string]struct{}{},
requestValueExtractor: NewRequestValueExtractor(logger, false),
rateLimiter: func() *RateLimiter {
@@ -465,7 +467,7 @@ func TestConcurrentRuleEvaluation(t *testing.T) {
}
// Add some IPs to the blacklist
middleware.ipBlacklist.Insert("192.168.1.0/24")
middleware.ipBlacklist.Insert(netip.MustParsePrefix("192.168.1.0/24"), nil)
var wg sync.WaitGroup
for i := 0; i < 100; i++ {

154
types.go
View File

@@ -1,8 +1,6 @@
package caddywaf
import (
"fmt"
"net"
"regexp"
"sync"
"time"
@@ -11,6 +9,7 @@ import (
"github.com/caddyserver/caddy/v2/caddyconfig/caddyfile"
"github.com/caddyserver/caddy/v2/modules/caddyhttp"
"github.com/oschwald/maxminddb-golang"
trie "github.com/phemmer/go-iptrie"
"go.uber.org/zap"
"go.uber.org/zap/zapcore"
)
@@ -31,68 +30,6 @@ var (
type RuleID string
type HitCount int
// ==================== Struct Definitions ====================
type TrieNode struct {
children map[byte]*TrieNode
isLeaf bool
}
func NewTrieNode() *TrieNode {
return &TrieNode{
children: make(map[byte]*TrieNode), // Initialize the map
isLeaf: false,
}
}
type CIDRTrie struct {
ipv4Root *TrieNode
ipv6Root *TrieNode
mu sync.RWMutex
}
func NewCIDRTrie() *CIDRTrie {
return &CIDRTrie{
ipv4Root: NewTrieNode(), // Initialize with a new TrieNode
ipv6Root: NewTrieNode(), // Initialize with a new TrieNode
}
}
func (t *CIDRTrie) Insert(cidr string) error {
t.mu.Lock()
defer t.mu.Unlock()
ip, ipNet, err := net.ParseCIDR(cidr)
if err != nil {
return err
}
if ip.To4() != nil {
// IPv4
return t.insertIPv4(ipNet)
} else {
// IPv6
return t.insertIPv6(ipNet)
}
}
func (t *CIDRTrie) Contains(ipStr string) bool {
t.mu.RLock()
defer t.mu.RUnlock()
ip := net.ParseIP(ipStr)
if ip == nil {
return false
}
if ip.To4() != nil {
// IPv4
return t.containsIPv4(ip)
} else {
// IPv6
return t.containsIPv6(ip)
}
}
// RuleCache caches compiled regex patterns for rules.
type RuleCache struct {
mu sync.RWMutex
@@ -167,7 +104,7 @@ type Middleware struct {
CountryBlock CountryAccessFilter `json:"country_block"`
CountryWhitelist CountryAccessFilter `json:"country_whitelist"`
Rules map[int][]Rule `json:"-"`
ipBlacklist *CIDRTrie `json:"-"` // Changed to CIDRTrie
ipBlacklist *trie.Trie `json:"-"`
dnsBlacklist map[string]struct{} `json:"-"` // Changed to map[string]struct{}
logger *zap.Logger
LogSeverity string `json:"log_severity,omitempty"`
@@ -244,90 +181,3 @@ func (rc *RuleCache) Set(ruleID string, regex *regexp.Regexp) {
defer rc.mu.Unlock()
rc.rules[ruleID] = regex
}
func (t *CIDRTrie) insertIPv4(ipNet *net.IPNet) error {
ip := ipNet.IP.To4()
if ip == nil {
return fmt.Errorf("invalid IPv4 address")
}
mask, _ := ipNet.Mask.Size()
node := t.ipv4Root
for i := 0; i < mask; i++ {
bit := (ip[i/8] >> (7 - uint(i%8))) & 1
if node.children[bit] == nil {
node.children[bit] = NewTrieNode() // Initialize the child node
}
node = node.children[bit]
}
node.isLeaf = true
return nil
}
func (t *CIDRTrie) insertIPv6(ipNet *net.IPNet) error {
ip := ipNet.IP.To16()
if ip == nil {
return fmt.Errorf("invalid IPv6 address")
}
mask, _ := ipNet.Mask.Size()
node := t.ipv6Root
for i := 0; i < mask; i++ {
bit := (ip[i/8] >> (7 - uint(i%8))) & 1
if node.children[bit] == nil {
node.children[bit] = NewTrieNode() // Initialize the child node
}
node = node.children[bit]
}
node.isLeaf = true
return nil
}
func (t *CIDRTrie) containsIPv4(ip net.IP) bool {
ip = ip.To4()
if ip == nil {
return false
}
node := t.ipv4Root
for i := 0; i < len(ip)*8; i++ {
bit := (ip[i/8] >> (7 - uint(i%8))) & 1
if node.children[bit] == nil {
return false
}
node = node.children[bit]
if node.isLeaf {
return true
}
}
return node.isLeaf
}
func (t *CIDRTrie) containsIPv6(ip net.IP) bool {
ip = ip.To16()
if ip == nil {
return false
}
// Add this check to ensure ip is not empty
if len(ip) == 0 {
return false
}
node := t.ipv6Root
for i := 0; i < len(ip)*8; i++ {
bit := (ip[i/8] >> (7 - uint(i%8))) & 1
if node.children[bit] == nil {
return false
}
node = node.children[bit]
if node.isLeaf {
return true
}
}
return false
}

View File

@@ -5,75 +5,6 @@ import (
"testing"
)
func TestNewCIDRTrie(t *testing.T) {
trie := NewCIDRTrie()
if trie == nil {
t.Fatal("NewCIDRTrie() returned nil")
}
if trie.ipv4Root == nil {
t.Fatal("NewCIDRTrie() created a trie with nil ipv4Root")
}
if trie.ipv6Root == nil {
t.Fatal("NewCIDRTrie() created a trie with nil ipv6Root")
}
if trie.ipv4Root.children == nil {
t.Fatal("NewCIDRTrie() created ipv4Root with nil children map")
}
if trie.ipv6Root.children == nil {
t.Fatal("NewCIDRTrie() created ipv6Root with nil children map")
}
}
func TestCIDRTrie_Insert(t *testing.T) {
tests := []struct {
name string
cidr string
wantErr bool
}{
{"valid IPv4 CIDR", "192.168.1.0/24", false},
{"valid IPv6 CIDR", "2001:db8::/32", false}, // IPv6 is now supported
{"invalid CIDR", "invalid", true},
{"invalid IPv4 mask", "192.168.1.0/33", true},
{"invalid IPv6 mask", "2001:db8::/129", true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
trie := NewCIDRTrie()
err := trie.Insert(tt.cidr)
if (err != nil) != tt.wantErr {
t.Errorf("CIDRTrie.Insert() error = %v, wantErr %v", err, tt.wantErr)
}
})
}
}
func TestCIDRTrie_Contains(t *testing.T) {
trie := NewCIDRTrie()
_ = trie.Insert("192.168.1.0/24")
_ = trie.Insert("2001:db8::/32") // Add an IPv6 CIDR
tests := []struct {
name string
ip string
want bool
}{
{"IPv4 in range", "192.168.1.1", true},
{"IPv4 out of range", "192.168.2.1", false},
{"Invalid IP", "invalid", false},
{"IPv6 in range", "2001:db8::1", true},
{"IPv6 out of range", "2001:db9::1", false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := trie.Contains(tt.ip); got != tt.want {
t.Errorf("CIDRTrie.Contains() = %v, want %v", got, tt.want)
}
})
}
}
func TestNewRuleCache(t *testing.T) {
cache := NewRuleCache()
if cache == nil {