Files
AdGuardDNS/internal/agdtest/interface.go
Ainar Garipov b4faca20be Sync v2.21.0
2026-04-17 13:49:23 +03:00

707 lines
19 KiB
Go
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package agdtest
import (
"context"
"net"
"net/netip"
"time"
"github.com/AdguardTeam/AdGuardDNS/internal/access"
"github.com/AdguardTeam/AdGuardDNS/internal/agd"
"github.com/AdguardTeam/AdGuardDNS/internal/agdpasswd"
"github.com/AdguardTeam/AdGuardDNS/internal/billstat"
"github.com/AdguardTeam/AdGuardDNS/internal/dnscheck"
"github.com/AdguardTeam/AdGuardDNS/internal/dnsdb"
"github.com/AdguardTeam/AdGuardDNS/internal/dnsserver/netext"
"github.com/AdguardTeam/AdGuardDNS/internal/dnsserver/ratelimit"
"github.com/AdguardTeam/AdGuardDNS/internal/errcoll"
"github.com/AdguardTeam/AdGuardDNS/internal/filter"
"github.com/AdguardTeam/AdGuardDNS/internal/filter/filterindex"
"github.com/AdguardTeam/AdGuardDNS/internal/geoip"
"github.com/AdguardTeam/AdGuardDNS/internal/profiledb"
"github.com/AdguardTeam/AdGuardDNS/internal/querylog"
"github.com/AdguardTeam/AdGuardDNS/internal/remotekv"
"github.com/AdguardTeam/AdGuardDNS/internal/rulestat"
"github.com/AdguardTeam/golibs/netutil"
"github.com/AdguardTeam/golibs/testutil"
"github.com/miekg/dns"
"github.com/prometheus/client_golang/prometheus"
)
// Interface Mocks
//
// Keep entities within a module/package in alphabetic order.
// Module AdGuardDNS
// Package access
// type check
var _ access.Interface = (*AccessManager)(nil)

// AccessManager is an [access.Interface] implementation for tests.  Every
// method forwards its arguments to the function field with the same name
// prefixed by "On", so leaving a field nil makes that method panic.
type AccessManager struct {
	OnIsBlockedHost func(host string, qt uint16) (blocked bool)
	OnIsBlockedIP   func(ip netip.Addr) (blocked bool)
}

// IsBlockedHost implements the [access.Interface] interface for
// *AccessManager.  It reports whatever m.OnIsBlockedHost reports.
func (m *AccessManager) IsBlockedHost(host string, qt uint16) (blocked bool) {
	blocked = m.OnIsBlockedHost(host, qt)

	return blocked
}

// IsBlockedIP implements the [access.Interface] interface for *AccessManager.
// It reports whatever m.OnIsBlockedIP reports.
func (m *AccessManager) IsBlockedIP(ip netip.Addr) (blocked bool) {
	blocked = m.OnIsBlockedIP(ip)

	return blocked
}
// Package agd
// type check
var _ agd.DeviceFinder = (*DeviceFinder)(nil)

// DeviceFinder is an [agd.DeviceFinder] implementation for tests.  Its Find
// method forwards all of its arguments to OnFind.
type DeviceFinder struct {
	OnFind func(
		ctx context.Context,
		req *dns.Msg,
		raddr netip.AddrPort,
		laddr netip.AddrPort,
	) (r agd.DeviceResult)
}

// Find implements the [agd.DeviceFinder] interface for *DeviceFinder by
// returning the result of df.OnFind.
func (df *DeviceFinder) Find(
	ctx context.Context,
	req *dns.Msg,
	raddr, laddr netip.AddrPort,
) (r agd.DeviceResult) {
	r = df.OnFind(ctx, req, raddr, laddr)

	return r
}
// Package agdpasswd
// type check
var _ agdpasswd.Authenticator = (*Authenticator)(nil)

// Authenticator is an [agdpasswd.Authenticator] implementation for tests whose
// Authenticate method forwards to OnAuthenticate.
type Authenticator struct {
	OnAuthenticate func(ctx context.Context, passwd []byte) (ok bool)
}

// Authenticate implements the [agdpasswd.Authenticator] interface for
// *Authenticator.  It reports the result of auth.OnAuthenticate.
func (auth *Authenticator) Authenticate(
	ctx context.Context,
	passwd []byte,
) (ok bool) {
	ok = auth.OnAuthenticate(ctx, passwd)

	return ok
}
// Package billstat
// type check
var _ billstat.Recorder = (*BillStatRecorder)(nil)

// BillStatRecorder is a [billstat.Recorder] implementation for tests whose
// Record method forwards every argument to OnRecord.
type BillStatRecorder struct {
	OnRecord func(
		ctx context.Context,
		id agd.DeviceID,
		ctry geoip.Country,
		asn geoip.ASN,
		start time.Time,
		proto agd.Protocol,
	)
}

// Record implements the [billstat.Recorder] interface for *BillStatRecorder.
func (rec *BillStatRecorder) Record(
	ctx context.Context,
	id agd.DeviceID,
	ctry geoip.Country,
	asn geoip.ASN,
	start time.Time,
	proto agd.Protocol,
) {
	rec.OnRecord(
		ctx,
		id,
		ctry,
		asn,
		start,
		proto,
	)
}
// type check
var _ billstat.Uploader = (*BillStatUploader)(nil)

// BillStatUploader is a [billstat.Uploader] implementation for tests whose
// Upload method forwards to OnUpload.
type BillStatUploader struct {
	OnUpload func(ctx context.Context, records billstat.Records) (err error)
}

// Upload implements the [billstat.Uploader] interface for *BillStatUploader.
// It returns the error produced by u.OnUpload.
func (u *BillStatUploader) Upload(
	ctx context.Context,
	records billstat.Records,
) (err error) {
	err = u.OnUpload(ctx, records)

	return err
}
// Package dnscheck
// type check
var _ dnscheck.Interface = (*DNSCheck)(nil)

// DNSCheck is a [dnscheck.Interface] for tests.  Its Check method forwards to
// OnCheck.
type DNSCheck struct {
	OnCheck func(
		ctx context.Context,
		req *dns.Msg,
		ri *agd.RequestInfo,
	) (resp *dns.Msg, err error)
}

// Check implements the [dnscheck.Interface] interface for *DNSCheck.
func (c *DNSCheck) Check(
	ctx context.Context,
	req *dns.Msg,
	ri *agd.RequestInfo,
) (resp *dns.Msg, err error) {
	return c.OnCheck(ctx, req, ri)
}
// Package dnsdb
// type check
var _ dnsdb.Interface = (*DNSDB)(nil)

// DNSDB is a [dnsdb.Interface] implementation for tests whose Record method
// forwards to OnRecord.
type DNSDB struct {
	OnRecord func(ctx context.Context, resp *dns.Msg, ri *agd.RequestInfo)
}

// Record implements the [dnsdb.Interface] interface for *DNSDB.
func (d *DNSDB) Record(
	ctx context.Context,
	resp *dns.Msg,
	ri *agd.RequestInfo,
) {
	d.OnRecord(ctx, resp, ri)
}
// Package errcoll
// type check
var _ errcoll.Interface = (*ErrorCollector)(nil)

// ErrorCollector is an [errcoll.Interface] implementation for tests whose
// Collect method forwards to OnCollect.
//
// TODO(a.garipov): Actually test the error collection where this is used.
type ErrorCollector struct {
	OnCollect func(ctx context.Context, err error)
}

// Collect implements the [errcoll.Interface] interface for *ErrorCollector.
func (ec *ErrorCollector) Collect(ctx context.Context, err error) {
	ec.OnCollect(ctx, err)
}

// NewErrorCollector returns a new *ErrorCollector whose OnCollect panics with
// the value built by [testutil.UnexpectedCall] from its arguments.  Callers may
// replace OnCollect when a call is expected.
func NewErrorCollector() (c *ErrorCollector) {
	c = &ErrorCollector{}
	c.OnCollect = func(ctx context.Context, err error) {
		panic(testutil.UnexpectedCall(ctx, err))
	}

	return c
}
// Package filter
// type check
var _ filter.Interface = (*Filter)(nil)

// Filter is a [filter.Interface] implementation for tests.  FilterRequest
// forwards to OnFilterRequest and FilterResponse forwards to OnFilterResponse.
type Filter struct {
	OnFilterRequest  func(ctx context.Context, req *filter.Request) (r filter.Result, err error)
	OnFilterResponse func(ctx context.Context, resp *filter.Response) (r filter.Result, err error)
}

// FilterRequest implements the [filter.Interface] interface for *Filter.
func (flt *Filter) FilterRequest(
	ctx context.Context,
	req *filter.Request,
) (r filter.Result, err error) {
	r, err = flt.OnFilterRequest(ctx, req)

	return r, err
}

// FilterResponse implements the [filter.Interface] interface for *Filter.
func (flt *Filter) FilterResponse(
	ctx context.Context,
	resp *filter.Response,
) (r filter.Result, err error) {
	r, err = flt.OnFilterResponse(ctx, resp)

	return r, err
}

// NewFilter returns a new *Filter all methods of which panic with the value
// built by [testutil.UnexpectedCall] from their arguments.
func NewFilter() (f *Filter) {
	onReq := func(ctx context.Context, req *filter.Request) (r filter.Result, err error) {
		panic(testutil.UnexpectedCall(ctx, req))
	}

	onResp := func(ctx context.Context, resp *filter.Response) (r filter.Result, err error) {
		panic(testutil.UnexpectedCall(ctx, resp))
	}

	return &Filter{
		OnFilterRequest:  onReq,
		OnFilterResponse: onResp,
	}
}
// type check
var _ filter.HashMatcher = (*HashMatcher)(nil)

// HashMatcher is a [filter.HashMatcher] implementation for tests whose
// MatchByPrefix method forwards to OnMatchByPrefix.
type HashMatcher struct {
	OnMatchByPrefix func(ctx context.Context, host string) (hashes []string, matched bool, err error)
}

// MatchByPrefix implements the [filter.HashMatcher] interface for *HashMatcher.
func (hm *HashMatcher) MatchByPrefix(
	ctx context.Context,
	host string,
) (hashes []string, matched bool, err error) {
	hashes, matched, err = hm.OnMatchByPrefix(ctx, host)

	return hashes, matched, err
}
// type check
var _ filter.Storage = (*FilterStorage)(nil)

// FilterStorage is a [filter.Storage] implementation for tests.  ForConfig
// forwards to OnForConfig and HasListID forwards to OnHasListID.
type FilterStorage struct {
	OnForConfig func(ctx context.Context, c filter.Config) (f filter.Interface)
	OnHasListID func(id filter.ID) (ok bool)
}

// ForConfig implements the [filter.Storage] interface for *FilterStorage.
func (fs *FilterStorage) ForConfig(
	ctx context.Context,
	c filter.Config,
) (f filter.Interface) {
	f = fs.OnForConfig(ctx, c)

	return f
}

// HasListID implements the [filter.Storage] interface for *FilterStorage.
func (fs *FilterStorage) HasListID(id filter.ID) (ok bool) {
	ok = fs.OnHasListID(id)

	return ok
}
// Package filterindex
// type check
var _ filterindex.Storage = (*FilterIndexStorage)(nil)

// FilterIndexStorage is a [filterindex.Storage] implementation for tests whose
// Typosquatting method forwards to OnTyposquatting.
type FilterIndexStorage struct {
	OnTyposquatting func(ctx context.Context) (idx *filterindex.Typosquatting, err error)
}

// Typosquatting implements the [filterindex.Storage] interface for
// *FilterIndexStorage.
func (fis *FilterIndexStorage) Typosquatting(
	ctx context.Context,
) (idx *filterindex.Typosquatting, err error) {
	idx, err = fis.OnTyposquatting(ctx)

	return idx, err
}
// Package geoip
// type check
var _ geoip.Interface = (*GeoIP)(nil)

// GeoIP is a [geoip.Interface] for tests.  Data forwards to OnData and
// SubnetByLocation forwards to OnSubnetByLocation.
type GeoIP struct {
	OnData func(
		ctx context.Context,
		host string,
		ip netip.Addr,
	) (l *geoip.Location, err error)
	OnSubnetByLocation func(
		ctx context.Context,
		l *geoip.Location,
		fam netutil.AddrFamily,
	) (n netip.Prefix, err error)
}

// Data implements the [geoip.Interface] interface for *GeoIP.
func (g *GeoIP) Data(
	ctx context.Context,
	host string,
	ip netip.Addr,
) (l *geoip.Location, err error) {
	return g.OnData(ctx, host, ip)
}

// SubnetByLocation implements the [geoip.Interface] interface for *GeoIP.
func (g *GeoIP) SubnetByLocation(
	ctx context.Context,
	l *geoip.Location,
	fam netutil.AddrFamily,
) (n netip.Prefix, err error) {
	return g.OnSubnetByLocation(ctx, l, fam)
}

// NewGeoIP returns a new *GeoIP all methods of which panic with the value
// built by [testutil.UnexpectedCall] from their arguments.
func NewGeoIP() (g *GeoIP) {
	return &GeoIP{
		OnData: func(
			ctx context.Context,
			host string,
			ip netip.Addr,
		) (l *geoip.Location, err error) {
			panic(testutil.UnexpectedCall(ctx, host, ip))
		},
		OnSubnetByLocation: func(
			ctx context.Context,
			l *geoip.Location,
			fam netutil.AddrFamily,
		) (n netip.Prefix, err error) {
			panic(testutil.UnexpectedCall(ctx, l, fam))
		},
	}
}
// Package profiledb
// type check
var _ profiledb.Interface = (*ProfileDB)(nil)

// ProfileDB is a [profiledb.Interface] implementation for tests.  Each method
// forwards its arguments to the function field with the same name prefixed by
// "On" and returns that field's results unchanged.
type ProfileDB struct {
	OnCreateAutoDevice func(
		ctx context.Context,
		id agd.ProfileID,
		humanID agd.HumanID,
		devType agd.DeviceType,
	) (p *agd.Profile, d *agd.Device, err error)
	OnProfileByDedicatedIP func(
		ctx context.Context,
		ip netip.Addr,
	) (p *agd.Profile, d *agd.Device, err error)
	OnProfileByDeviceID func(
		ctx context.Context,
		id agd.DeviceID,
	) (p *agd.Profile, d *agd.Device, err error)
	OnProfileByHumanID func(
		ctx context.Context,
		id agd.ProfileID,
		humanID agd.HumanIDLower,
	) (p *agd.Profile, d *agd.Device, err error)
	OnProfileByLinkedIP func(
		ctx context.Context,
		ip netip.Addr,
	) (p *agd.Profile, d *agd.Device, err error)
}

// CreateAutoDevice implements the [profiledb.Interface] interface for
// *ProfileDB.
func (pdb *ProfileDB) CreateAutoDevice(
	ctx context.Context,
	id agd.ProfileID,
	humanID agd.HumanID,
	devType agd.DeviceType,
) (p *agd.Profile, d *agd.Device, err error) {
	p, d, err = pdb.OnCreateAutoDevice(ctx, id, humanID, devType)

	return p, d, err
}

// ProfileByDedicatedIP implements the [profiledb.Interface] interface for
// *ProfileDB.
func (pdb *ProfileDB) ProfileByDedicatedIP(
	ctx context.Context,
	ip netip.Addr,
) (p *agd.Profile, d *agd.Device, err error) {
	p, d, err = pdb.OnProfileByDedicatedIP(ctx, ip)

	return p, d, err
}

// ProfileByDeviceID implements the [profiledb.Interface] interface for
// *ProfileDB.
func (pdb *ProfileDB) ProfileByDeviceID(
	ctx context.Context,
	id agd.DeviceID,
) (p *agd.Profile, d *agd.Device, err error) {
	p, d, err = pdb.OnProfileByDeviceID(ctx, id)

	return p, d, err
}

// ProfileByHumanID implements the [profiledb.Interface] interface for
// *ProfileDB.
func (pdb *ProfileDB) ProfileByHumanID(
	ctx context.Context,
	id agd.ProfileID,
	humanID agd.HumanIDLower,
) (p *agd.Profile, d *agd.Device, err error) {
	p, d, err = pdb.OnProfileByHumanID(ctx, id, humanID)

	return p, d, err
}

// ProfileByLinkedIP implements the [profiledb.Interface] interface for
// *ProfileDB.
func (pdb *ProfileDB) ProfileByLinkedIP(
	ctx context.Context,
	ip netip.Addr,
) (p *agd.Profile, d *agd.Device, err error) {
	p, d, err = pdb.OnProfileByLinkedIP(ctx, ip)

	return p, d, err
}

// NewProfileDB returns a new *ProfileDB all methods of which panic with the
// value built by [testutil.UnexpectedCall] from their arguments.
func NewProfileDB() (db *ProfileDB) {
	db = &ProfileDB{}

	db.OnCreateAutoDevice = func(
		ctx context.Context,
		profID agd.ProfileID,
		humanID agd.HumanID,
		devType agd.DeviceType,
	) (p *agd.Profile, d *agd.Device, err error) {
		panic(testutil.UnexpectedCall(ctx, profID, humanID, devType))
	}

	db.OnProfileByDedicatedIP = func(
		ctx context.Context,
		ip netip.Addr,
	) (p *agd.Profile, d *agd.Device, err error) {
		panic(testutil.UnexpectedCall(ctx, ip))
	}

	db.OnProfileByDeviceID = func(
		ctx context.Context,
		devID agd.DeviceID,
	) (p *agd.Profile, d *agd.Device, err error) {
		panic(testutil.UnexpectedCall(ctx, devID))
	}

	db.OnProfileByHumanID = func(
		ctx context.Context,
		profID agd.ProfileID,
		humanID agd.HumanIDLower,
	) (p *agd.Profile, d *agd.Device, err error) {
		panic(testutil.UnexpectedCall(ctx, profID, humanID))
	}

	db.OnProfileByLinkedIP = func(
		ctx context.Context,
		ip netip.Addr,
	) (p *agd.Profile, d *agd.Device, err error) {
		panic(testutil.UnexpectedCall(ctx, ip))
	}

	return db
}
// type check
var _ profiledb.Storage = (*ProfileStorage)(nil)

// ProfileStorage is a [profiledb.Storage] implementation for tests.
// CreateAutoDevice forwards to OnCreateAutoDevice and Profiles forwards to
// OnProfiles.
type ProfileStorage struct {
	OnCreateAutoDevice func(
		ctx context.Context,
		req *profiledb.StorageCreateAutoDeviceRequest,
	) (resp *profiledb.StorageCreateAutoDeviceResponse, err error)
	OnProfiles func(
		ctx context.Context,
		req *profiledb.StorageProfilesRequest,
	) (resp *profiledb.StorageProfilesResponse, err error)
}

// CreateAutoDevice implements the [profiledb.Storage] interface for
// *ProfileStorage.
func (ps *ProfileStorage) CreateAutoDevice(
	ctx context.Context,
	req *profiledb.StorageCreateAutoDeviceRequest,
) (resp *profiledb.StorageCreateAutoDeviceResponse, err error) {
	resp, err = ps.OnCreateAutoDevice(ctx, req)

	return resp, err
}

// Profiles implements the [profiledb.Storage] interface for *ProfileStorage.
func (ps *ProfileStorage) Profiles(
	ctx context.Context,
	req *profiledb.StorageProfilesRequest,
) (resp *profiledb.StorageProfilesResponse, err error) {
	resp, err = ps.OnProfiles(ctx, req)

	return resp, err
}

// NewProfileStorage returns a new *ProfileStorage all methods of which panic
// with the value built by [testutil.UnexpectedCall] from their arguments.
func NewProfileStorage() (s *ProfileStorage) {
	s = &ProfileStorage{}

	s.OnCreateAutoDevice = func(
		ctx context.Context,
		req *profiledb.StorageCreateAutoDeviceRequest,
	) (resp *profiledb.StorageCreateAutoDeviceResponse, err error) {
		panic(testutil.UnexpectedCall(ctx, req))
	}

	s.OnProfiles = func(
		ctx context.Context,
		req *profiledb.StorageProfilesRequest,
	) (resp *profiledb.StorageProfilesResponse, err error) {
		panic(testutil.UnexpectedCall(ctx, req))
	}

	return s
}
// Package querylog
// type check
var _ querylog.Interface = (*QueryLog)(nil)

// QueryLog is a [querylog.Interface] implementation for tests whose Write
// method forwards to OnWrite.
type QueryLog struct {
	OnWrite func(ctx context.Context, e *querylog.Entry) (err error)
}

// Write implements the [querylog.Interface] interface for *QueryLog.  It
// returns the error produced by q.OnWrite.
func (q *QueryLog) Write(ctx context.Context, e *querylog.Entry) (err error) {
	err = q.OnWrite(ctx, e)

	return err
}
// Package rulestat
// type check
var _ rulestat.Interface = (*RuleStat)(nil)

// RuleStat is a [rulestat.Interface] implementation for tests whose Collect
// method forwards to OnCollect.
type RuleStat struct {
	OnCollect func(ctx context.Context, id filter.ID, text filter.RuleText)
}

// Collect implements the [rulestat.Interface] interface for *RuleStat.
func (rs *RuleStat) Collect(
	ctx context.Context,
	id filter.ID,
	text filter.RuleText,
) {
	rs.OnCollect(ctx, id, text)
}
// Module dnsserver
// Package netext
// type check
var _ netext.ListenConfig = (*ListenConfig)(nil)

// ListenConfig is a [netext.ListenConfig] implementation for tests.  Listen
// forwards to OnListen and ListenPacket forwards to OnListenPacket.
type ListenConfig struct {
	OnListen func(
		ctx context.Context,
		network string,
		address string,
	) (l net.Listener, err error)
	OnListenPacket func(ctx context.Context, network, address string) (conn net.PacketConn, err error)
}

// Listen implements the [netext.ListenConfig] interface for *ListenConfig.
func (lc *ListenConfig) Listen(
	ctx context.Context,
	network, address string,
) (l net.Listener, err error) {
	l, err = lc.OnListen(ctx, network, address)

	return l, err
}

// ListenPacket implements the [netext.ListenConfig] interface for
// *ListenConfig.
func (lc *ListenConfig) ListenPacket(
	ctx context.Context,
	network, address string,
) (conn net.PacketConn, err error) {
	conn, err = lc.OnListenPacket(ctx, network, address)

	return conn, err
}
// Package ratelimit
// type check
var _ ratelimit.Interface = (*RateLimit)(nil)

// RateLimit is a [ratelimit.Interface] for tests.  IsRateLimited forwards to
// OnIsRateLimited and CountResponses forwards to OnCountResponses.
type RateLimit struct {
	OnIsRateLimited func(
		ctx context.Context,
		req *dns.Msg,
		ip netip.Addr,
	) (shouldDrop, isAllowlisted bool, err error)
	OnCountResponses func(ctx context.Context, resp *dns.Msg, ip netip.Addr)
}

// IsRateLimited implements the [ratelimit.Interface] interface for *RateLimit.
func (l *RateLimit) IsRateLimited(
	ctx context.Context,
	req *dns.Msg,
	ip netip.Addr,
) (shouldDrop, isAllowlisted bool, err error) {
	return l.OnIsRateLimited(ctx, req, ip)
}

// CountResponses implements the [ratelimit.Interface] interface for *RateLimit.
// resp is the response message, matching the OnCountResponses signature.
func (l *RateLimit) CountResponses(ctx context.Context, resp *dns.Msg, ip netip.Addr) {
	l.OnCountResponses(ctx, resp, ip)
}

// NewRateLimit returns a new *RateLimit all methods of which panic with the
// value built by [testutil.UnexpectedCall] from their arguments.
func NewRateLimit() (l *RateLimit) {
	return &RateLimit{
		OnIsRateLimited: func(
			ctx context.Context,
			req *dns.Msg,
			ip netip.Addr,
		) (shouldDrop, isAllowlisted bool, err error) {
			panic(testutil.UnexpectedCall(ctx, req, ip))
		},
		OnCountResponses: func(ctx context.Context, resp *dns.Msg, ip netip.Addr) {
			panic(testutil.UnexpectedCall(ctx, resp, ip))
		},
	}
}
// type check
var _ remotekv.Interface = (*RemoteKV)(nil)

// RemoteKV is a [remotekv.Interface] implementation for tests.  Get forwards
// to OnGet and Set forwards to OnSet.
type RemoteKV struct {
	OnGet func(ctx context.Context, key string) (val []byte, ok bool, err error)
	OnSet func(ctx context.Context, key string, val []byte) (err error)
}

// Get implements the [remotekv.Interface] interface for *RemoteKV.
func (r *RemoteKV) Get(
	ctx context.Context,
	key string,
) (val []byte, ok bool, err error) {
	val, ok, err = r.OnGet(ctx, key)

	return val, ok, err
}

// Set implements the [remotekv.Interface] interface for *RemoteKV.
func (r *RemoteKV) Set(
	ctx context.Context,
	key string,
	val []byte,
) (err error) {
	err = r.OnSet(ctx, key, val)

	return err
}
// Module prometheus
// PrometheusRegisterer is a [prometheus.Registerer] implementation for tests.
// Register, MustRegister, and Unregister forward to OnRegister,
// OnMustRegister, and OnUnregister respectively.
type PrometheusRegisterer struct {
	OnRegister     func(prometheus.Collector) (err error)
	OnMustRegister func(...prometheus.Collector)
	OnUnregister   func(prometheus.Collector) (ok bool)
}

// type check
var _ prometheus.Registerer = (*PrometheusRegisterer)(nil)

// Register implements the [prometheus.Registerer] interface for
// *PrometheusRegisterer.
func (pr *PrometheusRegisterer) Register(c prometheus.Collector) (err error) {
	err = pr.OnRegister(c)

	return err
}

// MustRegister implements the [prometheus.Registerer] interface for
// *PrometheusRegisterer.  The variadic collectors are passed through as is.
func (pr *PrometheusRegisterer) MustRegister(collectors ...prometheus.Collector) {
	pr.OnMustRegister(collectors...)
}

// Unregister implements the [prometheus.Registerer] interface for
// *PrometheusRegisterer.
func (pr *PrometheusRegisterer) Unregister(c prometheus.Collector) (ok bool) {
	ok = pr.OnUnregister(c)

	return ok
}

// NewTestPrometheusRegisterer returns a *PrometheusRegisterer that does
// nothing: its Register always returns nil, its MustRegister ignores its
// arguments, and its Unregister always returns true.
func NewTestPrometheusRegisterer() (r *PrometheusRegisterer) {
	r = &PrometheusRegisterer{}
	r.OnRegister = func(_ prometheus.Collector) (err error) { return nil }
	r.OnMustRegister = func(_ ...prometheus.Collector) {}
	r.OnUnregister = func(_ prometheus.Collector) (ok bool) { return true }

	return r
}