mirror of
https://github.com/AdguardTeam/AdGuardDNS.git
synced 2026-05-24 06:25:07 -04:00
459 lines
13 KiB
Go
459 lines
13 KiB
Go
package forward
|
|
|
|
import (
|
|
"context"
|
|
"encoding/binary"
|
|
"fmt"
|
|
"io"
|
|
"net"
|
|
"net/netip"
|
|
"strings"
|
|
"time"
|
|
|
|
"github.com/AdguardTeam/AdGuardDNS/internal/dnsserver/pool"
|
|
"github.com/AdguardTeam/golibs/errors"
|
|
"github.com/AdguardTeam/golibs/syncutil"
|
|
"github.com/miekg/dns"
|
|
)
|
|
|
|
const (
|
|
// ErrQuestion is returned if the response from the upstream is invalid,
|
|
// i.e. this is a response to a different query.
|
|
ErrQuestion errors.Error = "response has invalid question section"
|
|
)
|
|
|
|
const (
|
|
// poolMaxCapacity is the default pool.Pool capacity we use in
|
|
// UpstreamPlain.
|
|
poolMaxCapacity = 1024
|
|
|
|
// poolIdleTimeout is the default value pool.Pool.IdleTimeout. We're not
|
|
// making it configurable just yet, 30 seconds looks like a reasonable
|
|
// value for DNS.
|
|
poolIdleTimeout = time.Second * 30
|
|
|
|
// minDNSMessageSize is a minimum theoretical size of a DNS message.
|
|
minDNSMessageSize = 12 + 5
|
|
|
|
// udpBufSize is the size of buffers we use for UDP messages. We use
|
|
// 4096 since it's highly unlikely that a UDP message can be larger.
|
|
//
|
|
// TODO(ameshkov): consider making it configurable in the future.
|
|
udpBufSize = 4096
|
|
|
|
// tcpBufSize is the size of buffers we use for TCP messages.
|
|
tcpBufSize = dns.MaxMsgSize
|
|
|
|
// tcpLengthPrefixSize is the size of the length prefix in responses via
|
|
// TCP.
|
|
tcpLengthPrefixSize = 2
|
|
)
|
|
|
|
// UpstreamPlain is a simple plain DNS client.
|
|
type UpstreamPlain struct {
|
|
// connection pools for TCP and TCP
|
|
connsPoolUDP *pool.Pool
|
|
connsPoolTCP *pool.Pool
|
|
|
|
// Pools used for TCP and UDP messages buffers in order to avoid extra
|
|
// allocations.
|
|
udpBufs *syncutil.Pool[[]byte]
|
|
tcpBufs *syncutil.Pool[[]byte]
|
|
|
|
addr netip.AddrPort
|
|
network Network
|
|
|
|
// timeout is the query timeout for this upstream.
|
|
timeout time.Duration
|
|
}
|
|
|
|
// type check
|
|
var _ Upstream = (*UpstreamPlain)(nil)
|
|
|
|
// UpstreamPlainConfig is the configuration structure for a plain-DNS upstream.
|
|
type UpstreamPlainConfig struct {
|
|
// Network is the network to use for this upstream.
|
|
Network Network
|
|
|
|
// Address is the address of the upstream DNS server.
|
|
Address netip.AddrPort
|
|
|
|
// Timeout is the optional query timeout for upstreams. If not set, the
|
|
// context timeout or [defaultUDPTimeout] is used in case of UDP network.
|
|
Timeout time.Duration
|
|
}
|
|
|
|
// NewUpstreamPlain returns a new properly initialized *UpstreamPlain. c must
|
|
// not be nil.
|
|
func NewUpstreamPlain(c *UpstreamPlainConfig) (ups *UpstreamPlain) {
|
|
ups = &UpstreamPlain{
|
|
udpBufs: syncutil.NewSlicePool[byte](udpBufSize),
|
|
tcpBufs: syncutil.NewSlicePool[byte](tcpBufSize),
|
|
|
|
addr: c.Address,
|
|
network: c.Network,
|
|
|
|
timeout: c.Timeout,
|
|
}
|
|
|
|
ups.connsPoolUDP = pool.NewPool(poolMaxCapacity, makeConnsPoolFactory(ups, NetworkUDP))
|
|
ups.connsPoolUDP.IdleTimeout = poolIdleTimeout
|
|
ups.connsPoolTCP = pool.NewPool(poolMaxCapacity, makeConnsPoolFactory(ups, NetworkTCP))
|
|
ups.connsPoolTCP.IdleTimeout = poolIdleTimeout
|
|
|
|
return ups
|
|
}
|
|
|
|
// Exchange implements the [Upstream] interface for [*UpstreamPlain]. It
|
|
// handles gracefully the situation with truncated responses and fallbacks to
|
|
// TCP when needed. Uses the first of context's deadline and the configured
|
|
// timeout specify exchange deadline. Ignores [net.Error] and [io.EOF] errors
|
|
// that occur when writing response. Returns response, network type over which
|
|
// the request has been processed and error if happened.
|
|
func (u *UpstreamPlain) Exchange(
|
|
ctx context.Context,
|
|
req *dns.Msg,
|
|
) (resp *dns.Msg, nw Network, err error) {
|
|
defer func() { err = errors.Annotate(err, "upstreamplain: %w") }()
|
|
|
|
if u.timeout > 0 {
|
|
var cancel func()
|
|
ctx, cancel = context.WithTimeout(ctx, u.timeout)
|
|
defer cancel()
|
|
}
|
|
|
|
nw, networkOverriden := networkOverrideFromContext(ctx)
|
|
if !networkOverriden {
|
|
nw = u.network
|
|
}
|
|
|
|
isUDPOnly := nw == NetworkUDP
|
|
doUDP := isUDPOnly || nw == NetworkAny
|
|
doTCP := nw == NetworkTCP
|
|
|
|
if doUDP {
|
|
nw = NetworkUDP
|
|
resp, err = u.exchangeNet(ctx, req, NetworkUDP)
|
|
if err != nil {
|
|
// The network error always causes the subsequent query attempt
|
|
// using fresh UDP connection, so if it happened again, the upstream
|
|
// is likely dead and using TCP appears meaningless. See
|
|
// [exchangeNet].
|
|
//
|
|
// Thus, non-network errors are considered being related to the
|
|
// response. It may also happen the received response is intended
|
|
// for another timed out request sent from the same source port, but
|
|
// falling back to TCP in this case shouldn't hurt.
|
|
//
|
|
// TODO(e.burkov): Investigate, why UDP-only upstreams are using
|
|
// TCP.
|
|
doTCP = !isExpectedConnErr(err)
|
|
} else {
|
|
// Also, fallback to TCP if the received response is truncated and
|
|
// the network isn't UDP-only.
|
|
doTCP = !isUDPOnly && resp != nil && resp.Truncated
|
|
}
|
|
}
|
|
|
|
if doTCP {
|
|
nw = NetworkTCP
|
|
resp, err = u.exchangeNet(ctx, req, NetworkTCP)
|
|
}
|
|
|
|
return resp, nw, err
|
|
}
|
|
|
|
// Close implements the [io.Closer] interface for *UpstreamPlain.
|
|
func (u *UpstreamPlain) Close() (err error) {
|
|
udpErr := u.connsPoolUDP.Close()
|
|
tcpErr := u.connsPoolTCP.Close()
|
|
|
|
return errors.Annotate(errors.Join(udpErr, tcpErr), "closing upstream: %w")
|
|
}
|
|
|
|
// String implements the [fmt.Stringer] interface for *UpstreamPlain.
|
|
// - If upstream's network is [NetworkAny], it will simply return the IP:port.
|
|
// - If the network is specified, it will return address in the
|
|
// "network://IP:port" format.
|
|
func (u *UpstreamPlain) String() (str string) {
|
|
if u.network == NetworkAny {
|
|
return u.addr.String()
|
|
}
|
|
|
|
return fmt.Sprintf("%s://%s", u.network, u.addr)
|
|
}
|
|
|
|
// exchangeNet sends a DNS query using the specified network (either
|
|
// [NetworkTCP] or [NetworkUDP]).
|
|
func (u *UpstreamPlain) exchangeNet(
|
|
ctx context.Context,
|
|
req *dns.Msg,
|
|
network Network,
|
|
) (resp *dns.Msg, err error) {
|
|
var connsPool *pool.Pool
|
|
if network == NetworkTCP {
|
|
connsPool = u.connsPoolTCP
|
|
} else {
|
|
connsPool = u.connsPoolUDP
|
|
}
|
|
|
|
// Get the buffer to use for packing the request and reading the response.
|
|
// This buffer needs to be returned back to the pool once we're done.
|
|
bufPtr := u.getBuffer(network)
|
|
defer u.putBuffer(network, bufPtr)
|
|
|
|
buf := (*bufPtr)
|
|
|
|
// Pack the query into the specified buffer.
|
|
bufReqLen, err := u.packReq(network, buf, req)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("packing request: %w", err)
|
|
}
|
|
|
|
// Try connecting to the upstream.
|
|
var conn *pool.Conn
|
|
conn, err = connsPool.Get(ctx)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("getting connection: %w", err)
|
|
}
|
|
|
|
// err is already wrapped inside processConn.
|
|
resp, err = u.processConn(ctx, conn, connsPool, network, req, buf, bufReqLen)
|
|
if isExpectedConnErr(err) {
|
|
conn, err = connsPool.Create(ctx)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("creating connection: %w", err)
|
|
}
|
|
|
|
// err is already wrapped inside processConn.
|
|
resp, err = u.processConn(ctx, conn, connsPool, network, req, buf, bufReqLen)
|
|
}
|
|
|
|
return resp, err
|
|
}
|
|
|
|
// validatePlainResponse returns an error if the response is not valid for the
|
|
// original request. This is required because we might receive a response to a
|
|
// different query, e.g. when the server is under heavy load.
|
|
func validatePlainResponse(req, resp *dns.Msg) (err error) {
|
|
if req.Id != resp.Id {
|
|
return dns.ErrId
|
|
}
|
|
|
|
if qlen := len(resp.Question); qlen != 1 {
|
|
return fmt.Errorf("%w: only 1 question allowed; got %d", ErrQuestion, qlen)
|
|
}
|
|
|
|
reqQ, respQ := req.Question[0], resp.Question[0]
|
|
|
|
if reqQ.Qtype != respQ.Qtype {
|
|
return fmt.Errorf("%w: mismatched type %s", ErrQuestion, dns.Type(respQ.Qtype))
|
|
}
|
|
|
|
// Compare the names case-insensitively, just like CoreDNS does.
|
|
if !strings.EqualFold(reqQ.Name, respQ.Name) {
|
|
return fmt.Errorf("%w: mismatched name %q", ErrQuestion, respQ.Name)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// defaultUDPTimeout is the default time to wait for a valid DNS message or a
// network error over UDP.  It is applied when the context has no deadline.
const defaultUDPTimeout = time.Minute
|
|
|
|
// processConn writes the query to the connection and then reads the response
|
|
// from it. We might be dealing with an idle dead connection so if we get
|
|
// a network error here, we'll attempt to open a new connection and call this
|
|
// function again.
|
|
//
|
|
// TODO(ameshkov): 7 parameters in a method is not okay, rework this.
|
|
func (u *UpstreamPlain) processConn(
|
|
ctx context.Context,
|
|
conn *pool.Conn,
|
|
connsPool *pool.Pool,
|
|
network Network,
|
|
req *dns.Msg,
|
|
buf []byte,
|
|
bufReqLen int,
|
|
) (resp *dns.Msg, err error) {
|
|
// Make sure that we return the connection to the pool in the end or close
|
|
// if there was any error.
|
|
defer func() {
|
|
if err != nil {
|
|
err = errors.WithDeferred(err, conn.Close())
|
|
} else {
|
|
err = connsPool.Put(conn)
|
|
}
|
|
}()
|
|
|
|
// Prepare a context with a deadline if needed.
|
|
deadline, ok := ctx.Deadline()
|
|
if !ok && network == NetworkUDP {
|
|
deadline, ok = time.Now().Add(defaultUDPTimeout), true
|
|
}
|
|
|
|
if ok {
|
|
err = conn.SetDeadline(deadline)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("setting deadline: %w", err)
|
|
}
|
|
}
|
|
|
|
// Write the request to the connection.
|
|
_, err = conn.Write(buf[:bufReqLen])
|
|
if err != nil {
|
|
return nil, fmt.Errorf("writing request: %w", err)
|
|
}
|
|
|
|
return u.readValidMsg(req, network, conn, buf)
|
|
}
|
|
|
|
// readValidMsg reads the response from conn to buf, parses and validates it.
|
|
func (u *UpstreamPlain) readValidMsg(
|
|
req *dns.Msg,
|
|
network Network,
|
|
conn net.Conn,
|
|
buf []byte,
|
|
) (resp *dns.Msg, err error) {
|
|
resp, err = u.readMsg(network, conn, buf)
|
|
if err != nil {
|
|
// Don't wrap the error, because it's informative enough as is.
|
|
return nil, err
|
|
}
|
|
|
|
err = validatePlainResponse(req, resp)
|
|
if err != nil {
|
|
return resp, fmt.Errorf("validating %s response: %w", network, err)
|
|
}
|
|
|
|
return resp, nil
|
|
}
|
|
|
|
// readMsg reads the response from the specified connection and parses it. conn
|
|
// must not be nil. If network is [NetworkTCP] then buf must be at least
|
|
// [tcpBufSize] long or [udpBufSize] if network is [NetworkUDP].
|
|
func (u *UpstreamPlain) readMsg(
|
|
network Network,
|
|
conn net.Conn,
|
|
buf []byte,
|
|
) (res *dns.Msg, err error) {
|
|
var n int
|
|
|
|
if network == NetworkTCP {
|
|
_, err = io.ReadFull(conn, buf[:tcpLengthPrefixSize])
|
|
if err != nil {
|
|
return nil, fmt.Errorf("reading binary data: %w", err)
|
|
}
|
|
|
|
length := binary.BigEndian.Uint16(buf[:tcpLengthPrefixSize])
|
|
n, err = io.ReadFull(conn, buf[:length])
|
|
if err != nil {
|
|
return nil, fmt.Errorf("reading full: %w", err)
|
|
}
|
|
} else {
|
|
n, err = conn.Read(buf)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("udp network reading: %w", err)
|
|
}
|
|
}
|
|
|
|
if n < minDNSMessageSize {
|
|
return nil, fmt.Errorf("invalid msg: %w", dns.ErrShortRead)
|
|
}
|
|
|
|
res = &dns.Msg{}
|
|
err = res.Unpack(buf[:n])
|
|
if err != nil {
|
|
return nil, fmt.Errorf("unpacking msg: %w", err)
|
|
}
|
|
|
|
return res, nil
|
|
}
|
|
|
|
// packReq packs the DNS query to the specified buffer.
|
|
func (u *UpstreamPlain) packReq(network Network, buf []byte, req *dns.Msg) (n int, err error) {
|
|
reqLen := req.Len()
|
|
if reqLen > dns.MaxMsgSize {
|
|
return 0, dns.ErrBuf
|
|
}
|
|
|
|
if network == NetworkTCP {
|
|
if reqLen > len(buf)-tcpLengthPrefixSize {
|
|
return 0, dns.ErrBuf
|
|
}
|
|
|
|
// #nosec G115 -- reqLen has already been checked against
|
|
// dns.MaxMsgSize, which equals math.MaxUint16.
|
|
binary.BigEndian.PutUint16(buf, uint16(reqLen))
|
|
_, err = req.PackBuffer(buf[tcpLengthPrefixSize:])
|
|
|
|
return reqLen + tcpLengthPrefixSize, err
|
|
}
|
|
|
|
if reqLen > len(buf) {
|
|
return 0, dns.ErrBuf
|
|
}
|
|
|
|
_, err = req.PackBuffer(buf)
|
|
|
|
return reqLen, err
|
|
}
|
|
|
|
// getBuffer gets a bytes buffer that used for packing the request and then for
|
|
// reading the response.
|
|
func (u *UpstreamPlain) getBuffer(network Network) (bufPtr *[]byte) {
|
|
switch network {
|
|
case NetworkTCP:
|
|
return u.tcpBufs.Get()
|
|
case NetworkUDP:
|
|
return u.udpBufs.Get()
|
|
default:
|
|
panic(fmt.Errorf("no bufs for network %q in get", network))
|
|
}
|
|
}
|
|
|
|
// putBuffer puts the buffer back to the corresponding pool.
|
|
func (u *UpstreamPlain) putBuffer(network Network, bufPtr *[]byte) {
|
|
switch network {
|
|
case NetworkTCP:
|
|
u.tcpBufs.Put(bufPtr)
|
|
case NetworkUDP:
|
|
u.udpBufs.Put(bufPtr)
|
|
default:
|
|
panic(fmt.Errorf("no bufs for network %q in put", network))
|
|
}
|
|
}
|
|
|
|
// makeConnsPoolFactory makes a pool.Factory method for the specified address
|
|
// and network.
|
|
func makeConnsPoolFactory(u *UpstreamPlain, network Network) (f pool.Factory) {
|
|
var dialNetwork string
|
|
switch network {
|
|
case NetworkTCP:
|
|
dialNetwork = "tcp"
|
|
case NetworkUDP:
|
|
dialNetwork = "udp"
|
|
default:
|
|
panic("invalid network passed to makeConnsPoolFactory")
|
|
}
|
|
|
|
return func(ctx context.Context) (conn net.Conn, err error) {
|
|
deadline, ok := ctx.Deadline()
|
|
var timeout time.Duration
|
|
if ok {
|
|
timeout = time.Until(deadline)
|
|
}
|
|
|
|
return net.DialTimeout(dialNetwork, u.addr.String(), timeout)
|
|
}
|
|
}
|
|
|
|
// isExpectedConnErr returns true if err is an expected connection error, that
// is, if its chain contains either [io.EOF] or a [net.Error].  In this case, a
// second attempt to process the request is made.
func isExpectedConnErr(err error) (is bool) {
	if err == nil {
		return false
	}

	if errors.Is(err, io.EOF) {
		return true
	}

	var netErr net.Error

	return errors.As(err, &netErr)
}
|