mirror of
https://github.com/fabriziosalmi/caddy-waf.git
synced 2025-12-23 22:27:46 -05:00
feat!!: switch to go-trie
This commit is contained in:
@@ -4,6 +4,7 @@ import (
|
|||||||
"bufio"
|
"bufio"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
|
"net/netip"
|
||||||
"os"
|
"os"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
@@ -64,10 +65,7 @@ func (bl *BlacklistLoader) LoadDNSBlacklistFromFile(path string, dnsBlacklist ma
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (m *Middleware) isIPBlacklisted(ip string) bool {
|
func (m *Middleware) isIPBlacklisted(ip string) bool {
|
||||||
if m.ipBlacklist == nil { // Defensive check: ensure ipBlacklist is not nil
|
if m.ipBlacklist.Contains(netip.MustParseAddr(ip)) {
|
||||||
return false
|
|
||||||
}
|
|
||||||
if m.ipBlacklist.Contains(ip) {
|
|
||||||
m.muIPBlacklistMetrics.Lock() // Acquire lock before accessing shared counter
|
m.muIPBlacklistMetrics.Lock() // Acquire lock before accessing shared counter
|
||||||
m.IPBlacklistBlockCount++ // Increment the counter
|
m.IPBlacklistBlockCount++ // Increment the counter
|
||||||
m.muIPBlacklistMetrics.Unlock() // Release lock after accessing counter
|
m.muIPBlacklistMetrics.Unlock() // Release lock after accessing counter
|
||||||
@@ -127,7 +125,6 @@ func extractIP(remoteAddr string, logger *zap.Logger) string {
|
|||||||
return host
|
return host
|
||||||
}
|
}
|
||||||
|
|
||||||
// LoadIPBlacklistFromFile loads IP addresses from a file into the provided map.
|
|
||||||
// LoadIPBlacklistFromFile loads IP addresses from a file into the provided map.
|
// LoadIPBlacklistFromFile loads IP addresses from a file into the provided map.
|
||||||
func (bl *BlacklistLoader) LoadIPBlacklistFromFile(path string, ipBlacklist map[string]struct{}) error {
|
func (bl *BlacklistLoader) LoadIPBlacklistFromFile(path string, ipBlacklist map[string]struct{}) error {
|
||||||
bl.logger.Debug("Loading IP blacklist", zap.String("path", path))
|
bl.logger.Debug("Loading IP blacklist", zap.String("path", path))
|
||||||
|
|||||||
14
caddywaf.go
14
caddywaf.go
@@ -20,6 +20,7 @@ import (
|
|||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"net/netip"
|
||||||
"os"
|
"os"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
@@ -29,6 +30,7 @@ import (
|
|||||||
"github.com/caddyserver/caddy/v2/caddyconfig/httpcaddyfile"
|
"github.com/caddyserver/caddy/v2/caddyconfig/httpcaddyfile"
|
||||||
"github.com/caddyserver/caddy/v2/modules/caddyhttp"
|
"github.com/caddyserver/caddy/v2/modules/caddyhttp"
|
||||||
"github.com/oschwald/maxminddb-golang"
|
"github.com/oschwald/maxminddb-golang"
|
||||||
|
trie "github.com/phemmer/go-iptrie"
|
||||||
"go.uber.org/zap"
|
"go.uber.org/zap"
|
||||||
"go.uber.org/zap/zapcore"
|
"go.uber.org/zap/zapcore"
|
||||||
|
|
||||||
@@ -235,8 +237,8 @@ func (m *Middleware) Provision(ctx caddy.Context) error {
|
|||||||
|
|
||||||
// Load IP blacklist
|
// Load IP blacklist
|
||||||
if m.IPBlacklistFile != "" {
|
if m.IPBlacklistFile != "" {
|
||||||
m.ipBlacklist = NewCIDRTrie()
|
m.ipBlacklist = trie.NewTrie()
|
||||||
err = m.loadIPBlacklist(m.IPBlacklistFile, m.ipBlacklist)
|
err = m.loadIPBlacklist(m.IPBlacklistFile, *m.ipBlacklist)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to load IP blacklist: %w", err)
|
return fmt.Errorf("failed to load IP blacklist: %w", err)
|
||||||
}
|
}
|
||||||
@@ -424,8 +426,8 @@ func (m *Middleware) ReloadConfig() error {
|
|||||||
|
|
||||||
m.logger.Info("Reloading WAF configuration")
|
m.logger.Info("Reloading WAF configuration")
|
||||||
if m.IPBlacklistFile != "" {
|
if m.IPBlacklistFile != "" {
|
||||||
newIPBlacklist := NewCIDRTrie()
|
newIPBlacklist := trie.NewTrie()
|
||||||
if err := m.loadIPBlacklist(m.IPBlacklistFile, newIPBlacklist); err != nil {
|
if err := m.loadIPBlacklist(m.IPBlacklistFile, *newIPBlacklist); err != nil {
|
||||||
m.logger.Error("Failed to reload IP blacklist", zap.String("file", m.IPBlacklistFile), zap.Error(err))
|
m.logger.Error("Failed to reload IP blacklist", zap.String("file", m.IPBlacklistFile), zap.Error(err))
|
||||||
return fmt.Errorf("failed to reload IP blacklist: %v", err)
|
return fmt.Errorf("failed to reload IP blacklist: %v", err)
|
||||||
}
|
}
|
||||||
@@ -450,7 +452,7 @@ func (m *Middleware) ReloadConfig() error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Middleware) loadIPBlacklist(path string, blacklistMap *CIDRTrie) error {
|
func (m *Middleware) loadIPBlacklist(path string, blacklistMap trie.Trie) error {
|
||||||
if _, err := os.Stat(path); os.IsNotExist(err) {
|
if _, err := os.Stat(path); os.IsNotExist(err) {
|
||||||
m.logger.Warn("Skipping IP blacklist load, file does not exist", zap.String("file", path))
|
m.logger.Warn("Skipping IP blacklist load, file does not exist", zap.String("file", path))
|
||||||
return nil
|
return nil
|
||||||
@@ -464,7 +466,7 @@ func (m *Middleware) loadIPBlacklist(path string, blacklistMap *CIDRTrie) error
|
|||||||
|
|
||||||
// Convert the map to CIDRTrie
|
// Convert the map to CIDRTrie
|
||||||
for ip := range blacklist {
|
for ip := range blacklist {
|
||||||
blacklistMap.Insert(ip)
|
blacklistMap.Insert(netip.MustParsePrefix(ip), nil)
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|||||||
3
go.mod
3
go.mod
@@ -9,6 +9,7 @@ require (
|
|||||||
github.com/fsnotify/fsnotify v1.9.0
|
github.com/fsnotify/fsnotify v1.9.0
|
||||||
github.com/google/uuid v1.6.0
|
github.com/google/uuid v1.6.0
|
||||||
github.com/oschwald/maxminddb-golang v1.13.1
|
github.com/oschwald/maxminddb-golang v1.13.1
|
||||||
|
github.com/phemmer/go-iptrie v0.0.0-20240326174613-ba542f5282c9
|
||||||
github.com/stretchr/testify v1.11.1
|
github.com/stretchr/testify v1.11.1
|
||||||
go.uber.org/zap v1.27.0
|
go.uber.org/zap v1.27.0
|
||||||
)
|
)
|
||||||
@@ -89,7 +90,7 @@ require (
|
|||||||
github.com/russross/blackfriday/v2 v2.1.0 // indirect
|
github.com/russross/blackfriday/v2 v2.1.0 // indirect
|
||||||
github.com/shopspring/decimal v1.4.0 // indirect
|
github.com/shopspring/decimal v1.4.0 // indirect
|
||||||
github.com/shurcooL/sanitized_anchor_name v1.0.0 // indirect
|
github.com/shurcooL/sanitized_anchor_name v1.0.0 // indirect
|
||||||
github.com/slackhq/nebula v1.9.6 // indirect
|
github.com/slackhq/nebula v1.9.7 // indirect
|
||||||
github.com/smallstep/certificates v0.28.4 // indirect
|
github.com/smallstep/certificates v0.28.4 // indirect
|
||||||
github.com/smallstep/cli-utils v0.12.1 // indirect
|
github.com/smallstep/cli-utils v0.12.1 // indirect
|
||||||
github.com/smallstep/linkedca v0.24.0 // indirect
|
github.com/smallstep/linkedca v0.24.0 // indirect
|
||||||
|
|||||||
6
go.sum
6
go.sum
@@ -281,6 +281,8 @@ github.com/pbnjay/memory v0.0.0-20210728143218-7b4eea64cf58/go.mod h1:DXv8WO4yhM
|
|||||||
github.com/pelletier/go-toml v1.2.0/go.mod h1:5z9KED0ma1S8pY6P1sdut58dfprrGBbd/94hg7ilaic=
|
github.com/pelletier/go-toml v1.2.0/go.mod h1:5z9KED0ma1S8pY6P1sdut58dfprrGBbd/94hg7ilaic=
|
||||||
github.com/peterbourgon/diskv/v3 v3.0.1 h1:x06SQA46+PKIUftmEujdwSEpIx8kR+M9eLYsUxeYveU=
|
github.com/peterbourgon/diskv/v3 v3.0.1 h1:x06SQA46+PKIUftmEujdwSEpIx8kR+M9eLYsUxeYveU=
|
||||||
github.com/peterbourgon/diskv/v3 v3.0.1/go.mod h1:kJ5Ny7vLdARGU3WUuy6uzO6T0nb/2gWcT1JiBvRmb5o=
|
github.com/peterbourgon/diskv/v3 v3.0.1/go.mod h1:kJ5Ny7vLdARGU3WUuy6uzO6T0nb/2gWcT1JiBvRmb5o=
|
||||||
|
github.com/phemmer/go-iptrie v0.0.0-20240326174613-ba542f5282c9 h1:C8IqpV7kfAyZDRCnAVNi//l1mWlpyPmq1N6DjVvYEnY=
|
||||||
|
github.com/phemmer/go-iptrie v0.0.0-20240326174613-ba542f5282c9/go.mod h1:dDLiSjNqdp8VjphLdGTx19OeAUsHOzhtc1FFJqpzWMU=
|
||||||
github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
|
github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
|
||||||
github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
|
github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
|
||||||
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
|
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
|
||||||
@@ -343,8 +345,8 @@ github.com/shurcooL/webdavfs v0.0.0-20170829043945-18c3829fa133/go.mod h1:hKmq5k
|
|||||||
github.com/sirupsen/logrus v1.7.0/go.mod h1:yWOB1SBYBC5VeMP7gHvWumXLIWorT60ONWic61uBYv0=
|
github.com/sirupsen/logrus v1.7.0/go.mod h1:yWOB1SBYBC5VeMP7gHvWumXLIWorT60ONWic61uBYv0=
|
||||||
github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ=
|
github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ=
|
||||||
github.com/sirupsen/logrus v1.9.3/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ=
|
github.com/sirupsen/logrus v1.9.3/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ=
|
||||||
github.com/slackhq/nebula v1.9.6 h1:Fl0LE2dHDeVEK3R+un59Z3V4ZzbZ6q2e/zF4ClaD5yo=
|
github.com/slackhq/nebula v1.9.7 h1:v5u46efIyYHGdfjFnozQbRRhMdaB9Ma1SSTcUcE2lfE=
|
||||||
github.com/slackhq/nebula v1.9.6/go.mod h1:1+4q4wd3dDAjO8rKCttSb9JIVbklQhuJiBp5I0lbIsQ=
|
github.com/slackhq/nebula v1.9.7/go.mod h1:1+4q4wd3dDAjO8rKCttSb9JIVbklQhuJiBp5I0lbIsQ=
|
||||||
github.com/smallstep/assert v0.0.0-20200723003110-82e2b9b3b262 h1:unQFBIznI+VYD1/1fApl1A+9VcBk+9dcqGfnePY87LY=
|
github.com/smallstep/assert v0.0.0-20200723003110-82e2b9b3b262 h1:unQFBIznI+VYD1/1fApl1A+9VcBk+9dcqGfnePY87LY=
|
||||||
github.com/smallstep/assert v0.0.0-20200723003110-82e2b9b3b262/go.mod h1:MyOHs9Po2fbM1LHej6sBUT8ozbxmMOFG+E+rx/GSGuc=
|
github.com/smallstep/assert v0.0.0-20200723003110-82e2b9b3b262/go.mod h1:MyOHs9Po2fbM1LHej6sBUT8ozbxmMOFG+E+rx/GSGuc=
|
||||||
github.com/smallstep/certificates v0.28.4 h1:JTU6/A5Xes6m+OsR6fw1RACSA362vJc9SOFVG7poBEw=
|
github.com/smallstep/certificates v0.28.4 h1:JTU6/A5Xes6m+OsR6fw1RACSA362vJc9SOFVG7poBEw=
|
||||||
|
|||||||
@@ -6,12 +6,14 @@ import (
|
|||||||
"mime/multipart"
|
"mime/multipart"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
|
"net/netip"
|
||||||
"regexp"
|
"regexp"
|
||||||
"strings"
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/caddyserver/caddy/v2/modules/caddyhttp"
|
"github.com/caddyserver/caddy/v2/modules/caddyhttp"
|
||||||
|
trie "github.com/phemmer/go-iptrie"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"go.uber.org/zap"
|
"go.uber.org/zap"
|
||||||
)
|
)
|
||||||
@@ -23,7 +25,7 @@ func TestBlockedRequestPhase1_DNSBlacklist(t *testing.T) {
|
|||||||
dnsBlacklist: map[string]struct{}{
|
dnsBlacklist: map[string]struct{}{
|
||||||
"malicious.domain": {},
|
"malicious.domain": {},
|
||||||
},
|
},
|
||||||
ipBlacklist: NewCIDRTrie(), // Initialize ipBlacklist
|
ipBlacklist: trie.NewTrie(),
|
||||||
CustomResponses: map[int]CustomBlockResponse{
|
CustomResponses: map[int]CustomBlockResponse{
|
||||||
403: {
|
403: {
|
||||||
StatusCode: http.StatusForbidden,
|
StatusCode: http.StatusForbidden,
|
||||||
@@ -90,6 +92,38 @@ func TestBlockedRequestPhase1_GeoIPBlocking(t *testing.T) {
|
|||||||
assert.Contains(t, w.Body.String(), "Access Denied", "Response body should contain 'Access Denied'")
|
assert.Contains(t, w.Body.String(), "Access Denied", "Response body should contain 'Access Denied'")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestBlockedRequestPhase1_IPBlocking(t *testing.T) {
|
||||||
|
logger, err := zap.NewDevelopment()
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
ipBlackList := trie.NewTrie()
|
||||||
|
ipBlackList.Insert(netip.MustParsePrefix("127.0.0.1"), nil)
|
||||||
|
|
||||||
|
middleware := &Middleware{
|
||||||
|
logger: logger,
|
||||||
|
ipBlacklist: ipBlackList,
|
||||||
|
CustomResponses: map[int]CustomBlockResponse{
|
||||||
|
403: {
|
||||||
|
StatusCode: http.StatusForbidden,
|
||||||
|
Body: "Access Denied",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
req := httptest.NewRequest("GET", "http://example.com", nil)
|
||||||
|
req.RemoteAddr = "127.0.0.1"
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
state := &WAFState{}
|
||||||
|
|
||||||
|
// Process the request in Phase 1
|
||||||
|
middleware.handlePhase(w, req, 1, state)
|
||||||
|
|
||||||
|
// Verify that the request was blocked
|
||||||
|
assert.True(t, state.Blocked, "Request should be blocked")
|
||||||
|
assert.Equal(t, http.StatusForbidden, w.Code, "Expected status code 403")
|
||||||
|
assert.Contains(t, w.Body.String(), "Access Denied", "Response body should contain 'Access Denied'")
|
||||||
|
}
|
||||||
|
|
||||||
func TestHandlePhase_Phase2_NiktoUserAgent(t *testing.T) {
|
func TestHandlePhase_Phase2_NiktoUserAgent(t *testing.T) {
|
||||||
logger := zap.NewNop()
|
logger := zap.NewNop()
|
||||||
middleware := &Middleware{
|
middleware := &Middleware{
|
||||||
@@ -108,7 +142,7 @@ func TestHandlePhase_Phase2_NiktoUserAgent(t *testing.T) {
|
|||||||
},
|
},
|
||||||
},
|
},
|
||||||
ruleCache: NewRuleCache(),
|
ruleCache: NewRuleCache(),
|
||||||
ipBlacklist: NewCIDRTrie(),
|
ipBlacklist: trie.NewTrie(),
|
||||||
dnsBlacklist: map[string]struct{}{},
|
dnsBlacklist: map[string]struct{}{},
|
||||||
requestValueExtractor: NewRequestValueExtractor(logger, false),
|
requestValueExtractor: NewRequestValueExtractor(logger, false),
|
||||||
CustomResponses: map[int]CustomBlockResponse{
|
CustomResponses: map[int]CustomBlockResponse{
|
||||||
@@ -168,7 +202,7 @@ func TestBlockedRequestPhase1_HeaderRegex(t *testing.T) {
|
|||||||
},
|
},
|
||||||
},
|
},
|
||||||
ruleCache: NewRuleCache(),
|
ruleCache: NewRuleCache(),
|
||||||
ipBlacklist: NewCIDRTrie(),
|
ipBlacklist: trie.NewTrie(),
|
||||||
dnsBlacklist: map[string]struct{}{},
|
dnsBlacklist: map[string]struct{}{},
|
||||||
requestValueExtractor: NewRequestValueExtractor(logger, false),
|
requestValueExtractor: NewRequestValueExtractor(logger, false),
|
||||||
}
|
}
|
||||||
@@ -220,7 +254,7 @@ func TestBlockedRequestPhase1_HeaderRegex_SpecificValue(t *testing.T) {
|
|||||||
},
|
},
|
||||||
},
|
},
|
||||||
ruleCache: NewRuleCache(),
|
ruleCache: NewRuleCache(),
|
||||||
ipBlacklist: NewCIDRTrie(),
|
ipBlacklist: trie.NewTrie(),
|
||||||
dnsBlacklist: map[string]struct{}{},
|
dnsBlacklist: map[string]struct{}{},
|
||||||
requestValueExtractor: NewRequestValueExtractor(logger, false),
|
requestValueExtractor: NewRequestValueExtractor(logger, false),
|
||||||
}
|
}
|
||||||
@@ -272,7 +306,7 @@ func TestBlockedRequestPhase1_HeaderRegex_CommaSeparatedTargets(t *testing.T) {
|
|||||||
},
|
},
|
||||||
},
|
},
|
||||||
ruleCache: NewRuleCache(),
|
ruleCache: NewRuleCache(),
|
||||||
ipBlacklist: NewCIDRTrie(),
|
ipBlacklist: trie.NewTrie(),
|
||||||
dnsBlacklist: map[string]struct{}{},
|
dnsBlacklist: map[string]struct{}{},
|
||||||
requestValueExtractor: NewRequestValueExtractor(logger, false),
|
requestValueExtractor: NewRequestValueExtractor(logger, false),
|
||||||
}
|
}
|
||||||
@@ -325,7 +359,7 @@ func TestBlockedRequestPhase1_CombinedConditions(t *testing.T) {
|
|||||||
},
|
},
|
||||||
},
|
},
|
||||||
ruleCache: NewRuleCache(),
|
ruleCache: NewRuleCache(),
|
||||||
ipBlacklist: NewCIDRTrie(),
|
ipBlacklist: trie.NewTrie(),
|
||||||
dnsBlacklist: map[string]struct{}{},
|
dnsBlacklist: map[string]struct{}{},
|
||||||
requestValueExtractor: NewRequestValueExtractor(logger, false),
|
requestValueExtractor: NewRequestValueExtractor(logger, false),
|
||||||
}
|
}
|
||||||
@@ -377,7 +411,7 @@ func TestBlockedRequestPhase1_NoMatch(t *testing.T) {
|
|||||||
},
|
},
|
||||||
},
|
},
|
||||||
ruleCache: NewRuleCache(),
|
ruleCache: NewRuleCache(),
|
||||||
ipBlacklist: NewCIDRTrie(),
|
ipBlacklist: trie.NewTrie(),
|
||||||
dnsBlacklist: map[string]struct{}{},
|
dnsBlacklist: map[string]struct{}{},
|
||||||
requestValueExtractor: NewRequestValueExtractor(logger, false),
|
requestValueExtractor: NewRequestValueExtractor(logger, false),
|
||||||
}
|
}
|
||||||
@@ -429,7 +463,7 @@ func TestBlockedRequestPhase1_HeaderRegex_EmptyHeader(t *testing.T) {
|
|||||||
},
|
},
|
||||||
},
|
},
|
||||||
ruleCache: NewRuleCache(),
|
ruleCache: NewRuleCache(),
|
||||||
ipBlacklist: NewCIDRTrie(),
|
ipBlacklist: trie.NewTrie(),
|
||||||
dnsBlacklist: map[string]struct{}{},
|
dnsBlacklist: map[string]struct{}{},
|
||||||
requestValueExtractor: NewRequestValueExtractor(logger, false),
|
requestValueExtractor: NewRequestValueExtractor(logger, false),
|
||||||
}
|
}
|
||||||
@@ -479,7 +513,7 @@ func TestBlockedRequestPhase1_HeaderRegex_MissingHeader(t *testing.T) {
|
|||||||
},
|
},
|
||||||
},
|
},
|
||||||
ruleCache: NewRuleCache(),
|
ruleCache: NewRuleCache(),
|
||||||
ipBlacklist: NewCIDRTrie(),
|
ipBlacklist: trie.NewTrie(),
|
||||||
dnsBlacklist: map[string]struct{}{},
|
dnsBlacklist: map[string]struct{}{},
|
||||||
requestValueExtractor: NewRequestValueExtractor(logger, false),
|
requestValueExtractor: NewRequestValueExtractor(logger, false),
|
||||||
}
|
}
|
||||||
@@ -531,7 +565,7 @@ func TestBlockedRequestPhase1_HeaderRegex_ComplexPattern(t *testing.T) {
|
|||||||
},
|
},
|
||||||
},
|
},
|
||||||
ruleCache: NewRuleCache(),
|
ruleCache: NewRuleCache(),
|
||||||
ipBlacklist: NewCIDRTrie(),
|
ipBlacklist: trie.NewTrie(),
|
||||||
dnsBlacklist: map[string]struct{}{},
|
dnsBlacklist: map[string]struct{}{},
|
||||||
requestValueExtractor: NewRequestValueExtractor(logger, false),
|
requestValueExtractor: NewRequestValueExtractor(logger, false),
|
||||||
}
|
}
|
||||||
@@ -583,7 +617,7 @@ func TestBlockedRequestPhase1_MultiTargetMatch(t *testing.T) {
|
|||||||
},
|
},
|
||||||
},
|
},
|
||||||
ruleCache: NewRuleCache(),
|
ruleCache: NewRuleCache(),
|
||||||
ipBlacklist: NewCIDRTrie(),
|
ipBlacklist: trie.NewTrie(),
|
||||||
dnsBlacklist: map[string]struct{}{},
|
dnsBlacklist: map[string]struct{}{},
|
||||||
requestValueExtractor: NewRequestValueExtractor(logger, false),
|
requestValueExtractor: NewRequestValueExtractor(logger, false),
|
||||||
}
|
}
|
||||||
@@ -635,7 +669,7 @@ func TestBlockedRequestPhase1_MultiTargetNoMatch(t *testing.T) {
|
|||||||
},
|
},
|
||||||
},
|
},
|
||||||
ruleCache: NewRuleCache(),
|
ruleCache: NewRuleCache(),
|
||||||
ipBlacklist: NewCIDRTrie(),
|
ipBlacklist: trie.NewTrie(),
|
||||||
dnsBlacklist: map[string]struct{}{},
|
dnsBlacklist: map[string]struct{}{},
|
||||||
requestValueExtractor: NewRequestValueExtractor(logger, false),
|
requestValueExtractor: NewRequestValueExtractor(logger, false),
|
||||||
}
|
}
|
||||||
@@ -688,7 +722,7 @@ func TestBlockedRequestPhase1_URLParameterRegex_NoMatch(t *testing.T) {
|
|||||||
},
|
},
|
||||||
},
|
},
|
||||||
ruleCache: NewRuleCache(),
|
ruleCache: NewRuleCache(),
|
||||||
ipBlacklist: NewCIDRTrie(),
|
ipBlacklist: trie.NewTrie(),
|
||||||
dnsBlacklist: map[string]struct{}{},
|
dnsBlacklist: map[string]struct{}{},
|
||||||
requestValueExtractor: NewRequestValueExtractor(logger, false),
|
requestValueExtractor: NewRequestValueExtractor(logger, false),
|
||||||
}
|
}
|
||||||
@@ -747,7 +781,7 @@ func TestBlockedRequestPhase1_MultipleRules(t *testing.T) {
|
|||||||
},
|
},
|
||||||
},
|
},
|
||||||
ruleCache: NewRuleCache(),
|
ruleCache: NewRuleCache(),
|
||||||
ipBlacklist: NewCIDRTrie(),
|
ipBlacklist: trie.NewTrie(),
|
||||||
dnsBlacklist: map[string]struct{}{},
|
dnsBlacklist: map[string]struct{}{},
|
||||||
requestValueExtractor: NewRequestValueExtractor(logger, false),
|
requestValueExtractor: NewRequestValueExtractor(logger, false),
|
||||||
}
|
}
|
||||||
@@ -821,7 +855,7 @@ func TestBlockedRequestPhase2_BodyRegex(t *testing.T) {
|
|||||||
},
|
},
|
||||||
},
|
},
|
||||||
ruleCache: NewRuleCache(),
|
ruleCache: NewRuleCache(),
|
||||||
ipBlacklist: NewCIDRTrie(),
|
ipBlacklist: trie.NewTrie(),
|
||||||
dnsBlacklist: map[string]struct{}{},
|
dnsBlacklist: map[string]struct{}{},
|
||||||
requestValueExtractor: NewRequestValueExtractor(logger, false),
|
requestValueExtractor: NewRequestValueExtractor(logger, false),
|
||||||
}
|
}
|
||||||
@@ -879,7 +913,7 @@ func TestBlockedRequestPhase2_BodyRegex_JSON(t *testing.T) {
|
|||||||
},
|
},
|
||||||
},
|
},
|
||||||
ruleCache: NewRuleCache(),
|
ruleCache: NewRuleCache(),
|
||||||
ipBlacklist: NewCIDRTrie(),
|
ipBlacklist: trie.NewTrie(),
|
||||||
dnsBlacklist: map[string]struct{}{},
|
dnsBlacklist: map[string]struct{}{},
|
||||||
requestValueExtractor: NewRequestValueExtractor(logger, false),
|
requestValueExtractor: NewRequestValueExtractor(logger, false),
|
||||||
}
|
}
|
||||||
@@ -937,7 +971,7 @@ func TestBlockedRequestPhase2_BodyRegex_FormURLEncoded(t *testing.T) {
|
|||||||
},
|
},
|
||||||
},
|
},
|
||||||
ruleCache: NewRuleCache(),
|
ruleCache: NewRuleCache(),
|
||||||
ipBlacklist: NewCIDRTrie(),
|
ipBlacklist: trie.NewTrie(),
|
||||||
dnsBlacklist: map[string]struct{}{},
|
dnsBlacklist: map[string]struct{}{},
|
||||||
requestValueExtractor: NewRequestValueExtractor(logger, false),
|
requestValueExtractor: NewRequestValueExtractor(logger, false),
|
||||||
}
|
}
|
||||||
@@ -991,7 +1025,7 @@ func TestBlockedRequestPhase2_BodyRegex_SpecificPattern(t *testing.T) {
|
|||||||
},
|
},
|
||||||
},
|
},
|
||||||
ruleCache: NewRuleCache(),
|
ruleCache: NewRuleCache(),
|
||||||
ipBlacklist: NewCIDRTrie(),
|
ipBlacklist: trie.NewTrie(),
|
||||||
dnsBlacklist: map[string]struct{}{},
|
dnsBlacklist: map[string]struct{}{},
|
||||||
requestValueExtractor: NewRequestValueExtractor(logger, false),
|
requestValueExtractor: NewRequestValueExtractor(logger, false),
|
||||||
}
|
}
|
||||||
@@ -1049,7 +1083,7 @@ func TestBlockedRequestPhase2_BodyRegex_NoMatch(t *testing.T) {
|
|||||||
},
|
},
|
||||||
},
|
},
|
||||||
ruleCache: NewRuleCache(),
|
ruleCache: NewRuleCache(),
|
||||||
ipBlacklist: NewCIDRTrie(),
|
ipBlacklist: trie.NewTrie(),
|
||||||
dnsBlacklist: map[string]struct{}{},
|
dnsBlacklist: map[string]struct{}{},
|
||||||
requestValueExtractor: NewRequestValueExtractor(logger, false),
|
requestValueExtractor: NewRequestValueExtractor(logger, false),
|
||||||
}
|
}
|
||||||
@@ -1107,7 +1141,7 @@ func TestBlockedRequestPhase2_BodyRegex_NoMatch_MultipartForm(t *testing.T) {
|
|||||||
},
|
},
|
||||||
},
|
},
|
||||||
ruleCache: NewRuleCache(),
|
ruleCache: NewRuleCache(),
|
||||||
ipBlacklist: NewCIDRTrie(),
|
ipBlacklist: trie.NewTrie(),
|
||||||
dnsBlacklist: map[string]struct{}{},
|
dnsBlacklist: map[string]struct{}{},
|
||||||
requestValueExtractor: NewRequestValueExtractor(logger, false),
|
requestValueExtractor: NewRequestValueExtractor(logger, false),
|
||||||
}
|
}
|
||||||
@@ -1174,7 +1208,7 @@ func TestBlockedRequestPhase2_BodyRegex_NoBody(t *testing.T) {
|
|||||||
},
|
},
|
||||||
},
|
},
|
||||||
ruleCache: NewRuleCache(),
|
ruleCache: NewRuleCache(),
|
||||||
ipBlacklist: NewCIDRTrie(),
|
ipBlacklist: trie.NewTrie(),
|
||||||
dnsBlacklist: map[string]struct{}{},
|
dnsBlacklist: map[string]struct{}{},
|
||||||
requestValueExtractor: NewRequestValueExtractor(logger, false),
|
requestValueExtractor: NewRequestValueExtractor(logger, false),
|
||||||
}
|
}
|
||||||
@@ -1219,7 +1253,7 @@ func TestBlockedRequestPhase3_ResponseHeaderRegex_NoMatch(t *testing.T) {
|
|||||||
},
|
},
|
||||||
},
|
},
|
||||||
ruleCache: NewRuleCache(),
|
ruleCache: NewRuleCache(),
|
||||||
ipBlacklist: NewCIDRTrie(),
|
ipBlacklist: trie.NewTrie(),
|
||||||
dnsBlacklist: map[string]struct{}{},
|
dnsBlacklist: map[string]struct{}{},
|
||||||
requestValueExtractor: NewRequestValueExtractor(logger, false),
|
requestValueExtractor: NewRequestValueExtractor(logger, false),
|
||||||
}
|
}
|
||||||
@@ -1274,7 +1308,7 @@ func TestBlockedRequestPhase4_ResponseBodyRegex_EmptyBody(t *testing.T) {
|
|||||||
},
|
},
|
||||||
},
|
},
|
||||||
ruleCache: NewRuleCache(),
|
ruleCache: NewRuleCache(),
|
||||||
ipBlacklist: NewCIDRTrie(),
|
ipBlacklist: trie.NewTrie(),
|
||||||
dnsBlacklist: map[string]struct{}{},
|
dnsBlacklist: map[string]struct{}{},
|
||||||
requestValueExtractor: NewRequestValueExtractor(logger, false),
|
requestValueExtractor: NewRequestValueExtractor(logger, false),
|
||||||
}
|
}
|
||||||
@@ -1329,7 +1363,7 @@ func TestBlockedRequestPhase4_ResponseBodyRegex_NoBody(t *testing.T) {
|
|||||||
},
|
},
|
||||||
},
|
},
|
||||||
ruleCache: NewRuleCache(),
|
ruleCache: NewRuleCache(),
|
||||||
ipBlacklist: NewCIDRTrie(),
|
ipBlacklist: trie.NewTrie(),
|
||||||
dnsBlacklist: map[string]struct{}{},
|
dnsBlacklist: map[string]struct{}{},
|
||||||
requestValueExtractor: NewRequestValueExtractor(logger, false),
|
requestValueExtractor: NewRequestValueExtractor(logger, false),
|
||||||
}
|
}
|
||||||
@@ -1382,7 +1416,7 @@ func TestBlockedRequestPhase3_ResponseHeaderRegex_NoSetCookie(t *testing.T) {
|
|||||||
},
|
},
|
||||||
},
|
},
|
||||||
ruleCache: NewRuleCache(),
|
ruleCache: NewRuleCache(),
|
||||||
ipBlacklist: NewCIDRTrie(),
|
ipBlacklist: trie.NewTrie(),
|
||||||
dnsBlacklist: map[string]struct{}{},
|
dnsBlacklist: map[string]struct{}{},
|
||||||
requestValueExtractor: NewRequestValueExtractor(logger, false),
|
requestValueExtractor: NewRequestValueExtractor(logger, false),
|
||||||
}
|
}
|
||||||
@@ -1437,7 +1471,7 @@ func TestBlockedRequestPhase1_HeaderRegex_CaseInsensitive(t *testing.T) {
|
|||||||
},
|
},
|
||||||
},
|
},
|
||||||
ruleCache: NewRuleCache(),
|
ruleCache: NewRuleCache(),
|
||||||
ipBlacklist: NewCIDRTrie(),
|
ipBlacklist: trie.NewTrie(),
|
||||||
dnsBlacklist: map[string]struct{}{},
|
dnsBlacklist: map[string]struct{}{},
|
||||||
requestValueExtractor: NewRequestValueExtractor(logger, false),
|
requestValueExtractor: NewRequestValueExtractor(logger, false),
|
||||||
}
|
}
|
||||||
@@ -1489,7 +1523,7 @@ func TestBlockedRequestPhase1_HeaderRegex_MultipleMatchingHeaders(t *testing.T)
|
|||||||
},
|
},
|
||||||
},
|
},
|
||||||
ruleCache: NewRuleCache(),
|
ruleCache: NewRuleCache(),
|
||||||
ipBlacklist: NewCIDRTrie(),
|
ipBlacklist: trie.NewTrie(),
|
||||||
dnsBlacklist: map[string]struct{}{},
|
dnsBlacklist: map[string]struct{}{},
|
||||||
requestValueExtractor: NewRequestValueExtractor(logger, false),
|
requestValueExtractor: NewRequestValueExtractor(logger, false),
|
||||||
}
|
}
|
||||||
@@ -1594,7 +1628,7 @@ func TestBlockedRequestPhase1_RateLimiting_MultiplePaths(t *testing.T) {
|
|||||||
Body: "Rate limit exceeded",
|
Body: "Rate limit exceeded",
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
ipBlacklist: NewCIDRTrie(),
|
ipBlacklist: trie.NewTrie(),
|
||||||
dnsBlacklist: make(map[string]struct{}),
|
dnsBlacklist: make(map[string]struct{}),
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1664,7 +1698,7 @@ func TestBlockedRequestPhase1_RateLimiting_DifferentIPs(t *testing.T) {
|
|||||||
Body: "Rate limit exceeded",
|
Body: "Rate limit exceeded",
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
ipBlacklist: NewCIDRTrie(),
|
ipBlacklist: trie.NewTrie(),
|
||||||
dnsBlacklist: make(map[string]struct{}),
|
dnsBlacklist: make(map[string]struct{}),
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1717,7 +1751,7 @@ func TestBlockedRequestPhase1_RateLimiting_MatchAllPaths(t *testing.T) {
|
|||||||
Body: "Rate limit exceeded",
|
Body: "Rate limit exceeded",
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
ipBlacklist: NewCIDRTrie(),
|
ipBlacklist: trie.NewTrie(),
|
||||||
dnsBlacklist: make(map[string]struct{}),
|
dnsBlacklist: make(map[string]struct{}),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -7,6 +7,7 @@ import (
|
|||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
trie "github.com/phemmer/go-iptrie"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"go.uber.org/zap"
|
"go.uber.org/zap"
|
||||||
)
|
)
|
||||||
@@ -397,7 +398,7 @@ func TestBlockedRequestPhase1_RateLimiting(t *testing.T) {
|
|||||||
Body: "Rate limit exceeded",
|
Body: "Rate limit exceeded",
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
ipBlacklist: NewCIDRTrie(), // Initialize ipBlacklist
|
ipBlacklist: trie.NewTrie(), // Initialize ipBlacklist
|
||||||
dnsBlacklist: make(map[string]struct{}), // Initialize dnsBlacklist
|
dnsBlacklist: make(map[string]struct{}), // Initialize dnsBlacklist
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -6,11 +6,13 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
|
"net/netip"
|
||||||
"sync"
|
"sync"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/caddyserver/caddy/v2/caddyconfig/caddyfile"
|
"github.com/caddyserver/caddy/v2/caddyconfig/caddyfile"
|
||||||
|
trie "github.com/phemmer/go-iptrie"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"go.uber.org/zap"
|
"go.uber.org/zap"
|
||||||
"go.uber.org/zap/zapcore"
|
"go.uber.org/zap/zapcore"
|
||||||
@@ -442,7 +444,7 @@ func TestConcurrentRuleEvaluation(t *testing.T) {
|
|||||||
},
|
},
|
||||||
},
|
},
|
||||||
ruleCache: NewRuleCache(),
|
ruleCache: NewRuleCache(),
|
||||||
ipBlacklist: NewCIDRTrie(),
|
ipBlacklist: trie.NewTrie(),
|
||||||
dnsBlacklist: map[string]struct{}{},
|
dnsBlacklist: map[string]struct{}{},
|
||||||
requestValueExtractor: NewRequestValueExtractor(logger, false),
|
requestValueExtractor: NewRequestValueExtractor(logger, false),
|
||||||
rateLimiter: func() *RateLimiter {
|
rateLimiter: func() *RateLimiter {
|
||||||
@@ -465,7 +467,7 @@ func TestConcurrentRuleEvaluation(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Add some IPs to the blacklist
|
// Add some IPs to the blacklist
|
||||||
middleware.ipBlacklist.Insert("192.168.1.0/24")
|
middleware.ipBlacklist.Insert(netip.MustParsePrefix("192.168.1.0/24"), nil)
|
||||||
|
|
||||||
var wg sync.WaitGroup
|
var wg sync.WaitGroup
|
||||||
for i := 0; i < 100; i++ {
|
for i := 0; i < 100; i++ {
|
||||||
|
|||||||
154
types.go
154
types.go
@@ -1,8 +1,6 @@
|
|||||||
package caddywaf
|
package caddywaf
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
|
||||||
"net"
|
|
||||||
"regexp"
|
"regexp"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
@@ -11,6 +9,7 @@ import (
|
|||||||
"github.com/caddyserver/caddy/v2/caddyconfig/caddyfile"
|
"github.com/caddyserver/caddy/v2/caddyconfig/caddyfile"
|
||||||
"github.com/caddyserver/caddy/v2/modules/caddyhttp"
|
"github.com/caddyserver/caddy/v2/modules/caddyhttp"
|
||||||
"github.com/oschwald/maxminddb-golang"
|
"github.com/oschwald/maxminddb-golang"
|
||||||
|
trie "github.com/phemmer/go-iptrie"
|
||||||
"go.uber.org/zap"
|
"go.uber.org/zap"
|
||||||
"go.uber.org/zap/zapcore"
|
"go.uber.org/zap/zapcore"
|
||||||
)
|
)
|
||||||
@@ -31,68 +30,6 @@ var (
|
|||||||
type RuleID string
|
type RuleID string
|
||||||
type HitCount int
|
type HitCount int
|
||||||
|
|
||||||
// ==================== Struct Definitions ====================
|
|
||||||
type TrieNode struct {
|
|
||||||
children map[byte]*TrieNode
|
|
||||||
isLeaf bool
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewTrieNode() *TrieNode {
|
|
||||||
return &TrieNode{
|
|
||||||
children: make(map[byte]*TrieNode), // Initialize the map
|
|
||||||
isLeaf: false,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
type CIDRTrie struct {
|
|
||||||
ipv4Root *TrieNode
|
|
||||||
ipv6Root *TrieNode
|
|
||||||
mu sync.RWMutex
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewCIDRTrie() *CIDRTrie {
|
|
||||||
return &CIDRTrie{
|
|
||||||
ipv4Root: NewTrieNode(), // Initialize with a new TrieNode
|
|
||||||
ipv6Root: NewTrieNode(), // Initialize with a new TrieNode
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t *CIDRTrie) Insert(cidr string) error {
|
|
||||||
t.mu.Lock()
|
|
||||||
defer t.mu.Unlock()
|
|
||||||
|
|
||||||
ip, ipNet, err := net.ParseCIDR(cidr)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
if ip.To4() != nil {
|
|
||||||
// IPv4
|
|
||||||
return t.insertIPv4(ipNet)
|
|
||||||
} else {
|
|
||||||
// IPv6
|
|
||||||
return t.insertIPv6(ipNet)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t *CIDRTrie) Contains(ipStr string) bool {
|
|
||||||
t.mu.RLock()
|
|
||||||
defer t.mu.RUnlock()
|
|
||||||
|
|
||||||
ip := net.ParseIP(ipStr)
|
|
||||||
if ip == nil {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
if ip.To4() != nil {
|
|
||||||
// IPv4
|
|
||||||
return t.containsIPv4(ip)
|
|
||||||
} else {
|
|
||||||
// IPv6
|
|
||||||
return t.containsIPv6(ip)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// RuleCache caches compiled regex patterns for rules.
|
// RuleCache caches compiled regex patterns for rules.
|
||||||
type RuleCache struct {
|
type RuleCache struct {
|
||||||
mu sync.RWMutex
|
mu sync.RWMutex
|
||||||
@@ -167,7 +104,7 @@ type Middleware struct {
|
|||||||
CountryBlock CountryAccessFilter `json:"country_block"`
|
CountryBlock CountryAccessFilter `json:"country_block"`
|
||||||
CountryWhitelist CountryAccessFilter `json:"country_whitelist"`
|
CountryWhitelist CountryAccessFilter `json:"country_whitelist"`
|
||||||
Rules map[int][]Rule `json:"-"`
|
Rules map[int][]Rule `json:"-"`
|
||||||
ipBlacklist *CIDRTrie `json:"-"` // Changed to CIDRTrie
|
ipBlacklist *trie.Trie `json:"-"`
|
||||||
dnsBlacklist map[string]struct{} `json:"-"` // Changed to map[string]struct{}
|
dnsBlacklist map[string]struct{} `json:"-"` // Changed to map[string]struct{}
|
||||||
logger *zap.Logger
|
logger *zap.Logger
|
||||||
LogSeverity string `json:"log_severity,omitempty"`
|
LogSeverity string `json:"log_severity,omitempty"`
|
||||||
@@ -244,90 +181,3 @@ func (rc *RuleCache) Set(ruleID string, regex *regexp.Regexp) {
|
|||||||
defer rc.mu.Unlock()
|
defer rc.mu.Unlock()
|
||||||
rc.rules[ruleID] = regex
|
rc.rules[ruleID] = regex
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *CIDRTrie) insertIPv4(ipNet *net.IPNet) error {
|
|
||||||
ip := ipNet.IP.To4()
|
|
||||||
if ip == nil {
|
|
||||||
return fmt.Errorf("invalid IPv4 address")
|
|
||||||
}
|
|
||||||
|
|
||||||
mask, _ := ipNet.Mask.Size()
|
|
||||||
node := t.ipv4Root
|
|
||||||
|
|
||||||
for i := 0; i < mask; i++ {
|
|
||||||
bit := (ip[i/8] >> (7 - uint(i%8))) & 1
|
|
||||||
if node.children[bit] == nil {
|
|
||||||
node.children[bit] = NewTrieNode() // Initialize the child node
|
|
||||||
}
|
|
||||||
node = node.children[bit]
|
|
||||||
}
|
|
||||||
|
|
||||||
node.isLeaf = true
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t *CIDRTrie) insertIPv6(ipNet *net.IPNet) error {
|
|
||||||
ip := ipNet.IP.To16()
|
|
||||||
if ip == nil {
|
|
||||||
return fmt.Errorf("invalid IPv6 address")
|
|
||||||
}
|
|
||||||
|
|
||||||
mask, _ := ipNet.Mask.Size()
|
|
||||||
node := t.ipv6Root
|
|
||||||
|
|
||||||
for i := 0; i < mask; i++ {
|
|
||||||
bit := (ip[i/8] >> (7 - uint(i%8))) & 1
|
|
||||||
if node.children[bit] == nil {
|
|
||||||
node.children[bit] = NewTrieNode() // Initialize the child node
|
|
||||||
}
|
|
||||||
node = node.children[bit]
|
|
||||||
}
|
|
||||||
|
|
||||||
node.isLeaf = true
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t *CIDRTrie) containsIPv4(ip net.IP) bool {
|
|
||||||
ip = ip.To4()
|
|
||||||
if ip == nil {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
node := t.ipv4Root
|
|
||||||
for i := 0; i < len(ip)*8; i++ {
|
|
||||||
bit := (ip[i/8] >> (7 - uint(i%8))) & 1
|
|
||||||
if node.children[bit] == nil {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
node = node.children[bit]
|
|
||||||
if node.isLeaf {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return node.isLeaf
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t *CIDRTrie) containsIPv6(ip net.IP) bool {
|
|
||||||
ip = ip.To16()
|
|
||||||
if ip == nil {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
// Add this check to ensure ip is not empty
|
|
||||||
if len(ip) == 0 {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
node := t.ipv6Root
|
|
||||||
for i := 0; i < len(ip)*8; i++ {
|
|
||||||
bit := (ip[i/8] >> (7 - uint(i%8))) & 1
|
|
||||||
if node.children[bit] == nil {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
node = node.children[bit]
|
|
||||||
if node.isLeaf {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -5,75 +5,6 @@ import (
|
|||||||
"testing"
|
"testing"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestNewCIDRTrie(t *testing.T) {
|
|
||||||
trie := NewCIDRTrie()
|
|
||||||
if trie == nil {
|
|
||||||
t.Fatal("NewCIDRTrie() returned nil")
|
|
||||||
}
|
|
||||||
if trie.ipv4Root == nil {
|
|
||||||
t.Fatal("NewCIDRTrie() created a trie with nil ipv4Root")
|
|
||||||
}
|
|
||||||
if trie.ipv6Root == nil {
|
|
||||||
t.Fatal("NewCIDRTrie() created a trie with nil ipv6Root")
|
|
||||||
}
|
|
||||||
if trie.ipv4Root.children == nil {
|
|
||||||
t.Fatal("NewCIDRTrie() created ipv4Root with nil children map")
|
|
||||||
}
|
|
||||||
if trie.ipv6Root.children == nil {
|
|
||||||
t.Fatal("NewCIDRTrie() created ipv6Root with nil children map")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestCIDRTrie_Insert(t *testing.T) {
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
cidr string
|
|
||||||
wantErr bool
|
|
||||||
}{
|
|
||||||
{"valid IPv4 CIDR", "192.168.1.0/24", false},
|
|
||||||
{"valid IPv6 CIDR", "2001:db8::/32", false}, // IPv6 is now supported
|
|
||||||
{"invalid CIDR", "invalid", true},
|
|
||||||
{"invalid IPv4 mask", "192.168.1.0/33", true},
|
|
||||||
{"invalid IPv6 mask", "2001:db8::/129", true},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
trie := NewCIDRTrie()
|
|
||||||
err := trie.Insert(tt.cidr)
|
|
||||||
if (err != nil) != tt.wantErr {
|
|
||||||
t.Errorf("CIDRTrie.Insert() error = %v, wantErr %v", err, tt.wantErr)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestCIDRTrie_Contains(t *testing.T) {
|
|
||||||
trie := NewCIDRTrie()
|
|
||||||
_ = trie.Insert("192.168.1.0/24")
|
|
||||||
_ = trie.Insert("2001:db8::/32") // Add an IPv6 CIDR
|
|
||||||
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
ip string
|
|
||||||
want bool
|
|
||||||
}{
|
|
||||||
{"IPv4 in range", "192.168.1.1", true},
|
|
||||||
{"IPv4 out of range", "192.168.2.1", false},
|
|
||||||
{"Invalid IP", "invalid", false},
|
|
||||||
{"IPv6 in range", "2001:db8::1", true},
|
|
||||||
{"IPv6 out of range", "2001:db9::1", false},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
if got := trie.Contains(tt.ip); got != tt.want {
|
|
||||||
t.Errorf("CIDRTrie.Contains() = %v, want %v", got, tt.want)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestNewRuleCache(t *testing.T) {
|
func TestNewRuleCache(t *testing.T) {
|
||||||
cache := NewRuleCache()
|
cache := NewRuleCache()
|
||||||
if cache == nil {
|
if cache == nil {
|
||||||
|
|||||||
Reference in New Issue
Block a user