build(deps): bump github.com/nats-io/nats-server/v2

Bumps [github.com/nats-io/nats-server/v2](https://github.com/nats-io/nats-server) from 2.12.5 to 2.12.6.
- [Release notes](https://github.com/nats-io/nats-server/releases)
- [Changelog](https://github.com/nats-io/nats-server/blob/main/RELEASES.md)
- [Commits](https://github.com/nats-io/nats-server/compare/v2.12.5...v2.12.6)

---
updated-dependencies:
- dependency-name: github.com/nats-io/nats-server/v2
  dependency-version: 2.12.6
  dependency-type: direct:production
  update-type: version-update:semver-patch
...

Signed-off-by: dependabot[bot] <support@github.com>
This commit is contained in:
dependabot[bot]
2026-03-25 14:46:12 +00:00
committed by Ralf Haferkamp
parent 9bee89691f
commit cde52d9e9b
33 changed files with 901 additions and 491 deletions

4
go.mod
View File

@@ -55,7 +55,7 @@ require (
github.com/mitchellh/mapstructure v1.5.0
github.com/mna/pigeon v1.3.0
github.com/mohae/deepcopy v0.0.0-20170929034955-c48cc78d4826
github.com/nats-io/nats-server/v2 v2.12.5
github.com/nats-io/nats-server/v2 v2.12.6
github.com/nats-io/nats.go v1.49.0
github.com/oklog/run v1.2.0
github.com/olekukonko/tablewriter v1.1.4
@@ -304,7 +304,7 @@ require (
github.com/morikuni/aec v1.0.0 // indirect
github.com/mschoch/smat v0.2.0 // indirect
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect
github.com/nats-io/jwt/v2 v2.8.0 // indirect
github.com/nats-io/jwt/v2 v2.8.1 // indirect
github.com/nats-io/nkeys v0.4.15 // indirect
github.com/nats-io/nuid v1.0.1 // indirect
github.com/nxadm/tail v1.4.8 // indirect

8
go.sum
View File

@@ -905,10 +905,10 @@ github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822/go.mod h1:+n7T8m
github.com/mwitkow/go-conntrack v0.0.0-20161129095857-cc309e4a2223/go.mod h1:qRWi+5nqEBWmkhHvq77mSJWrCKwh8bxhgT7d/eI7P4U=
github.com/mwitkow/go-conntrack v0.0.0-20190716064945-2f068394615f/go.mod h1:qRWi+5nqEBWmkhHvq77mSJWrCKwh8bxhgT7d/eI7P4U=
github.com/namedotcom/go v0.0.0-20180403034216-08470befbe04/go.mod h1:5sN+Lt1CaY4wsPvgQH/jsuJi4XO2ssZbdsIizr4CVC8=
github.com/nats-io/jwt/v2 v2.8.0 h1:K7uzyz50+yGZDO5o772eRE7atlcSEENpL7P+b74JV1g=
github.com/nats-io/jwt/v2 v2.8.0/go.mod h1:me11pOkwObtcBNR8AiMrUbtVOUGkqYjMQZ6jnSdVUIA=
github.com/nats-io/nats-server/v2 v2.12.5 h1:EOHLbsLJgUHUwzkj9gBTOlubkX+dmSs0EYWMdBiHivU=
github.com/nats-io/nats-server/v2 v2.12.5/go.mod h1:JQDAKcwdXs0NRhvYO31dzsXkzCyDkOBS7SKU3Nozu14=
github.com/nats-io/jwt/v2 v2.8.1 h1:V0xpGuD/N8Mi+fQNDynXohVvp7ZztevW5io8CUWlPmU=
github.com/nats-io/jwt/v2 v2.8.1/go.mod h1:nWnOEEiVMiKHQpnAy4eXlizVEtSfzacZ1Q43LIRavZg=
github.com/nats-io/nats-server/v2 v2.12.6 h1:Egbx9Vl7Ch8wTtpXPGqbehkZ+IncKqShUxvrt1+Enc8=
github.com/nats-io/nats-server/v2 v2.12.6/go.mod h1:4HPlrvtmSO3yd7KcElDNMx9kv5EBJBnJJzQPptXlheo=
github.com/nats-io/nats.go v1.49.0 h1:yh/WvY59gXqYpgl33ZI+XoVPKyut/IcEaqtsiuTJpoE=
github.com/nats-io/nats.go v1.49.0/go.mod h1:fDCn3mN5cY8HooHwE2ukiLb4p4G4ImmzvXyJt+tGwdw=
github.com/nats-io/nkeys v0.4.15 h1:JACV5jRVO9V856KOapQ7x+EY8Jo3qw1vJt/9Jpwzkk4=

View File

@@ -152,19 +152,22 @@ type Mapping map[Subject][]WeightedMapping
func (m *Mapping) Validate(vr *ValidationResults) {
for ubFrom, wm := range (map[Subject][]WeightedMapping)(*m) {
ubFrom.Validate(vr)
perCluster := make(map[string]uint8)
total := uint8(0)
perCluster := make(map[string]uint32)
total := uint32(0)
for _, e := range wm {
e.Subject.Validate(vr)
if e.GetWeight() > 100 {
vr.AddError("Mapping %q has a weight %d that exceeds 100", ubFrom, e.GetWeight())
}
if e.Cluster != "" {
t := perCluster[e.Cluster]
t += e.Weight
t += uint32(e.GetWeight())
perCluster[e.Cluster] = t
if t > 100 {
vr.AddError("Mapping %q in cluster %q exceeds 100%% among all of it's weighted to mappings", ubFrom, e.Cluster)
}
} else {
total += e.GetWeight()
total += uint32(e.GetWeight())
}
}
if total > 100 {
@@ -286,7 +289,7 @@ func (a *Account) Validate(acct *AccountClaims, vr *ValidationResults) {
tvr := CreateValidationResults()
a.Trace.Destination.Validate(tvr)
if !tvr.IsEmpty() {
vr.AddError(fmt.Sprintf("the account Trace.Destination %s", tvr.Issues[0].Description))
vr.AddError("the account Trace.Destination %s", tvr.Issues[0].Description)
}
if a.Trace.Destination.HasWildCards() {
vr.AddError("the account Trace.Destination subject %q is not a valid publish subject", a.Trace.Destination)
@@ -325,7 +328,7 @@ func (a *Account) Validate(acct *AccountClaims, vr *ValidationResults) {
a.Info.Validate(vr)
if err := a.ClusterTraffic.Valid(); err != nil {
vr.AddError(err.Error())
vr.AddError("%s", err.Error())
}
}

View File

@@ -63,6 +63,9 @@ func formatJwt(kind string, jwtString string) ([]byte, error) {
func DecorateSeed(seed []byte) ([]byte, error) {
w := bytes.NewBuffer(nil)
ts := bytes.TrimSpace(seed)
if len(ts) < 2 {
return nil, errors.New("seed is too short")
}
pre := string(ts[0:2])
kind := ""
switch pre {
@@ -138,6 +141,18 @@ func FormatUserConfig(jwtString string, seed []byte) ([]byte, error) {
return nil, fmt.Errorf("nkey seed is not an user seed")
}
kp, err := nkeys.FromSeed(seed)
if err != nil {
return nil, err
}
pk, err := kp.PublicKey()
if err != nil {
return nil, err
}
if pk != gc.Claims().Subject {
return nil, fmt.Errorf("nkey seed does not match the jwt subject")
}
d, err := DecorateSeed(seed)
if err != nil {
return nil, err

View File

@@ -26,6 +26,12 @@ import (
const libVersion = 2
// MaxTokenSize is the maximum size of a JWT token in bytes
const MaxTokenSize = 1024 * 1024 // 1MB
// ErrTokenTooLarge is returned when a token exceeds MaxTokenSize
var ErrTokenTooLarge = errors.New("token too large")
type identifier struct {
Type ClaimType `json:"type,omitempty"`
GenericFields `json:"nats,omitempty"`
@@ -56,6 +62,9 @@ type v1ClaimsDataDeletedFields struct {
// doesn't match the expected algorithm, or the claim is
// not valid or verification fails an error is returned.
func Decode(token string) (Claims, error) {
if len(token) > MaxTokenSize {
return nil, fmt.Errorf("token size %d exceeds maximum of %d bytes: %w", len(token), MaxTokenSize, ErrTokenTooLarge)
}
// must have 3 chunks
chunks := strings.Split(token, ".")
if len(chunks) != 3 {

View File

@@ -126,7 +126,8 @@ type Imports []*Import
// Validate checks if an import is valid for the wrapping account
func (i *Imports) Validate(acctPubKey string, vr *ValidationResults) {
toSet := make(map[Subject]struct{}, len(*i))
// Group subjects by account to check for overlaps only within the same account
subsByAcct := make(map[string]map[Subject]struct{}, len(*i))
for _, v := range *i {
if v == nil {
vr.AddError("null import is not allowed")
@@ -140,15 +141,19 @@ func (i *Imports) Validate(acctPubKey string, vr *ValidationResults) {
if sub == "" {
sub = v.Subject
}
for k := range toSet {
if sub.IsContainedIn(k) || k.IsContainedIn(sub) {
vr.AddError("overlapping subject namespace for %q and %q", sub, k)
// Check for overlapping subjects only within the same account
for subOther := range subsByAcct[v.Account] {
if sub.IsContainedIn(subOther) || subOther.IsContainedIn(sub) {
vr.AddError("overlapping subject namespace for %q and %q in same account %q", sub, subOther, v.Account)
}
}
if _, ok := toSet[sub]; ok {
vr.AddError("overlapping subject namespace for %q", v.To)
if subsByAcct[v.Account] == nil {
subsByAcct[v.Account] = make(map[Subject]struct{}, len(*i))
}
toSet[sub] = struct{}{}
if _, ok := subsByAcct[v.Account][sub]; ok {
vr.AddError("overlapping subject namespace for %q in account %q", sub, v.Account)
}
subsByAcct[v.Account][sub] = struct{}{}
}
v.Validate(acctPubKey, vr)
}

View File

@@ -71,12 +71,12 @@ func ParseServerVersion(version string) (int, int, int, error) {
// Validate checks the validity of the operators contents
func (o *Operator) Validate(vr *ValidationResults) {
if err := o.validateAccountServerURL(); err != nil {
vr.AddError(err.Error())
vr.AddError("%s", err.Error())
}
for _, v := range o.validateOperatorServiceURLs() {
if v != nil {
vr.AddError(v.Error())
vr.AddError("%s", v.Error())
}
}

View File

@@ -114,42 +114,28 @@ func ParseFileWithChecks(fp string) (map[string]any, error) {
return p.mapping, nil
}
// cleanupUsedEnvVars will recursively remove all already used
// environment variables which might be in the parsed tree.
func cleanupUsedEnvVars(m map[string]any) {
for k, v := range m {
t := v.(*token)
if t.usedVariable {
delete(m, k)
continue
}
// Cleanup any other env var that is still in the map.
if tm, ok := t.value.(map[string]any); ok {
cleanupUsedEnvVars(tm)
}
// configDigest returns a digest for the parsed config.
func configDigest(m map[string]any) (string, error) {
digest := sha256.New()
e := json.NewEncoder(digest)
if err := e.Encode(m); err != nil {
return _EMPTY_, err
}
return fmt.Sprintf("sha256:%x", digest.Sum(nil)), nil
}
// ParseFileWithChecksDigest returns the processed config and a digest
// that represents the configuration.
func ParseFileWithChecksDigest(fp string) (map[string]any, string, error) {
data, err := os.ReadFile(fp)
m, err := ParseFileWithChecks(fp)
if err != nil {
return nil, _EMPTY_, err
}
p, err := parse(string(data), fp, true)
digest, err := configDigest(m)
if err != nil {
return nil, _EMPTY_, err
}
// Filter out any environment variables before taking the digest.
cleanupUsedEnvVars(p.mapping)
digest := sha256.New()
e := json.NewEncoder(digest)
err = e.Encode(p.mapping)
if err != nil {
return nil, _EMPTY_, err
}
return p.mapping, fmt.Sprintf("sha256:%x", digest.Sum(nil)), nil
return m, digest, nil
}
type token struct {

View File

@@ -235,19 +235,19 @@ func (d *DN) RDNsMatch(other *DN) bool {
if len(d.RDNs) != len(other.RDNs) {
return false
}
CheckNextRDN:
matched := make([]bool, len(other.RDNs))
for _, irdn := range d.RDNs {
for _, ordn := range other.RDNs {
if (len(irdn.Attributes) == len(ordn.Attributes)) &&
(irdn.hasAllAttributes(ordn.Attributes) && ordn.hasAllAttributes(irdn.Attributes)) {
// Found the RDN, check if next one matches.
continue CheckNextRDN
found := false
for j, ordn := range other.RDNs {
if !matched[j] && irdn.Equal(ordn) {
matched[j] = true
found = true
break
}
}
// Could not find a matching individual RDN, auth fails.
return false
if !found {
return false
}
}
return true
}

View File

@@ -1616,10 +1616,12 @@ func (a *Account) checkServiceImportsForCycles(from string, visited map[string]b
}
// Push ourselves and check si.acc
visited[a.Name] = true
if subjectIsSubsetMatch(si.from, from) {
from = si.from
// Make a copy to not overwrite the passed value.
f := from
if subjectIsSubsetMatch(si.from, f) {
f = si.from
}
if err := si.acc.checkServiceImportsForCycles(from, visited); err != nil {
if err := si.acc.checkServiceImportsForCycles(f, visited); err != nil {
return err
}
a.mu.RLock()
@@ -1674,10 +1676,12 @@ func (a *Account) checkStreamImportsForCycles(to string, visited map[string]bool
}
// Push ourselves and check si.acc
visited[a.Name] = true
if subjectIsSubsetMatch(si.to, to) {
to = si.to
// Make a copy to not overwrite the passed value.
t := to
if subjectIsSubsetMatch(si.to, t) {
t = si.to
}
if err := si.acc.checkStreamImportsForCycles(to, visited); err != nil {
if err := si.acc.checkStreamImportsForCycles(t, visited); err != nil {
return err
}
a.mu.RLock()

View File

@@ -604,6 +604,7 @@ func processUserPermissionsTemplate(lim jwt.UserPermissionLimits, ujwt *jwt.User
func (s *Server) processClientOrLeafAuthentication(c *client, opts *Options) (authorized bool) {
var (
nkey *NkeyUser
ujwt string
juc *jwt.UserClaims
acc *Account
user *User
@@ -798,16 +799,23 @@ func (s *Server) processClientOrLeafAuthentication(c *client, opts *Options) (au
// Check if we have trustedKeys defined in the server. If so we require a user jwt.
if s.trustedKeys != nil {
if c.opts.JWT == _EMPTY_ && opts.DefaultSentinel != _EMPTY_ {
c.opts.JWT = opts.DefaultSentinel
ujwt = c.opts.JWT
if ujwt == _EMPTY_ && c.isMqtt() {
// For MQTT, we pass the password as the JWT too, but do so here so it's not
// publicly exposed in the client options if it isn't a JWT.
ujwt = c.opts.Password
}
if c.opts.JWT == _EMPTY_ {
if ujwt == _EMPTY_ && opts.DefaultSentinel != _EMPTY_ {
c.opts.JWT = opts.DefaultSentinel
ujwt = c.opts.JWT
}
if ujwt == _EMPTY_ {
s.mu.Unlock()
c.Debugf("Authentication requires a user JWT")
return false
}
// So we have a valid user jwt here.
juc, err = jwt.DecodeUserClaims(c.opts.JWT)
juc, err = jwt.DecodeUserClaims(ujwt)
if err != nil {
s.mu.Unlock()
c.Debugf("User JWT not valid: %v", err)
@@ -1077,6 +1085,11 @@ func (s *Server) processClientOrLeafAuthentication(c *client, opts *Options) (au
// Hold onto the user's public key.
c.mu.Lock()
c.pubKey = juc.Subject
// If this is a MQTT client, we purposefully didn't populate the JWT as it could contain
// a password or token. Now we know it's a valid JWT, we can populate it.
if c.isMqtt() {
c.opts.JWT = ujwt
}
c.tags = juc.Tags
c.nameTag = juc.Name
c.mu.Unlock()

View File

@@ -81,9 +81,6 @@ func (s *Server) processClientOrLeafCallout(c *client, opts *Options, proxyRequi
xkp, xkey = s.xkp, s.info.XKey
}
// FIXME: so things like the server ID that get assigned, are used as a sort of nonce - but
// reality is that the keypair here, is generated, so the response generated a JWT has to be
// this user - no replay possible
// Create a keypair for the user. We will expect this public user to be in the signed response.
// This prevents replay attacks.
ukp, _ := nkeys.CreateUser()

View File

@@ -153,7 +153,6 @@ const (
compressionNegotiated // Marks if this connection has negotiated compression level with remote.
didTLSFirst // Marks if this connection requested and was accepted doing the TLS handshake first (prior to INFO).
isSlowConsumer // Marks connection as a slow consumer.
firstPong // Marks if this is the first PONG received
)
// set the flag (would be equivalent to set the boolean to true)
@@ -869,6 +868,11 @@ func (c *client) registerWithAccount(acc *Account) error {
}
c.mu.Lock()
// This check does not apply to SYSTEM or JETSTREAM or ACCOUNT clients (because they don't have a `nc`...)
if c.isClosed() && !isInternalClient(c.kind) {
c.mu.Unlock()
return ErrConnectionClosed
}
kind := c.kind
srv := c.srv
c.acc = acc
@@ -1353,6 +1357,13 @@ func (c *client) flushClients(budget time.Duration) time.Time {
return last
}
func (c *client) resetReadLoopStallTime() {
if c.in.tst >= stallClientMaxDuration {
c.rateLimitFormatWarnf("Producer was stalled for a total of %v", c.in.tst.Round(time.Millisecond))
}
c.in.tst = 0
}
// readLoop is the main socket read functionality.
// Runs in its own Go routine.
func (c *client) readLoop(pre []byte) {
@@ -1430,21 +1441,6 @@ func (c *client) readLoop(pre []byte) {
return
}
}
if ws {
bufs, err = c.wsRead(wsr, reader, b[:n])
if bufs == nil && err != nil {
if err != io.EOF {
c.Errorf("read error: %v", err)
}
c.closeConnection(closedStateForErr(err))
return
} else if bufs == nil {
continue
}
} else {
bufs[0] = b[:n]
}
// Check if the account has mappings and if so set the local readcache flag.
// We check here to make sure any changes such as config reload are reflected here.
if c.kind == CLIENT || c.kind == LEAF {
@@ -1462,17 +1458,32 @@ func (c *client) readLoop(pre []byte) {
c.in.bytes = 0
c.in.subs = 0
if ws {
err = c.wsReadAndParse(wsr, reader, b[:n])
if err != nil {
// Match the normal parse path: any already-buffered deliveries
// need their pending flush signals drained before we close.
c.flushClients(0)
if err != io.EOF {
c.Errorf("read error: %v", err)
}
c.closeConnection(closedStateForErr(err))
return
}
c.resetReadLoopStallTime()
goto postParse
} else {
bufs[0] = b[:n]
}
// Main call into parser for inbound data. This will generate callouts
// to process messages, etc.
for i := 0; i < len(bufs); i++ {
if err := c.parse(bufs[i]); err != nil {
if err == ErrMinimumVersionRequired {
// Special case here, currently only for leaf node connections.
// When process the CONNECT protocol, if the minimum version
// required was not met, an error was printed and sent back to
// the remote, and connection was closed after a certain delay
// (to avoid "rapid" reconnection from the remote).
// We don't need to do any of the things below, simply return.
// processLeafConnect() already sent the rejection and closed
// the connection, so there is nothing else to do here.
return
}
if dur := time.Since(c.in.start); dur >= readLoopReportThreshold {
@@ -1489,13 +1500,10 @@ func (c *client) readLoop(pre []byte) {
}
return
}
// Clear total stalled time here.
if c.in.tst >= stallClientMaxDuration {
c.rateLimitFormatWarnf("Producer was stalled for a total of %v", c.in.tst.Round(time.Millisecond))
}
c.in.tst = 0
c.resetReadLoopStallTime()
}
postParse:
// If we are a ROUTER/LEAF and have processed an INFO, it is possible that
// we are asked to switch to compression now.
if checkCompress && c.in.flags.isSet(switchToCompression) {
@@ -2115,41 +2123,37 @@ func (c *client) processErr(errStr string) {
}
}
// Password pattern matcher.
var passPat = regexp.MustCompile(`"?\s*pass\S*?"?\s*[:=]\s*"?(([^",\r\n}])*)`)
var tokenPat = regexp.MustCompile(`"?\s*auth_token\S*?"?\s*[:=]\s*"?(([^",\r\n}])*)`)
// Matcher for pass/password and auth_token fields.
var prefixAuthPat = regexp.MustCompile(`"?\s*(?:auth_token\S*?|pass\S*?)"?\s*[:=]\s*"?([^",\r\n}]*)`)
// Exact matcher for fields sig, proxy_sig and nkey.
// Overlapping field "sig" does not match inside "proxy_sig".
var exactAuthPat = regexp.MustCompile(`(?:^|[^A-Za-z0-9_])"?\s*(?:proxy_sig|nkey|sig)"?\s*[:=]\s*"?([^",\r\n}]*)`)
// removeSecretsFromTrace removes any notion of passwords/tokens from trace
// messages for logging.
func removeSecretsFromTrace(arg []byte) []byte {
buf := redact("pass", passPat, arg)
return redact("auth_token", tokenPat, buf)
buf := redact(prefixAuthPat, arg)
return redact(exactAuthPat, buf)
}
func redact(name string, pat *regexp.Regexp, proto []byte) []byte {
if !bytes.Contains(proto, []byte(name)) {
func redact(pat *regexp.Regexp, proto []byte) []byte {
m := pat.FindAllSubmatchIndex(proto, -1)
if len(m) == 0 {
return proto
}
// Take a copy of the connect proto just for the trace message.
var _arg [4096]byte
buf := append(_arg[:0], proto...)
m := pat.FindAllSubmatchIndex(buf, -1)
if len(m) == 0 {
return proto
}
redactedPass := []byte("[REDACTED]")
for _, i := range m {
if len(i) < 4 {
for i := len(m) - 1; i >= 0; i-- {
match := m[i]
if len(match) < 4 {
continue
}
start := i[2]
end := i[3]
start, end := match[2], match[3]
// Replace value substring.
buf = append(buf[:start], append(redactedPass, buf[end:]...)...)
break
}
return buf
}
@@ -2685,11 +2689,9 @@ func (c *client) processPong() {
c.rtt = computeRTT(c.rttStart)
srv := c.srv
reorderGWs := c.kind == GATEWAY && c.gw.outbound
firstPong := c.flags.setIfNotSet(firstPong)
var ri *routeInfo
// When receiving the first PONG, for a route with pooling, we may be
// instructed to start a new route.
if firstPong && c.kind == ROUTER && c.route != nil {
// For a route with pooling, we may be instructed to start a new route.
if c.kind == ROUTER && c.route != nil && c.route.startNewRoute != nil {
ri = c.route.startNewRoute
c.route.startNewRoute = nil
}
@@ -2807,9 +2809,10 @@ func (c *client) processHeaderPub(arg, remaining []byte) error {
// look for the tracing header and if found, we will generate a
// trace event with the max payload ingress error.
// Do this only for CLIENT connections.
if c.kind == CLIENT && len(remaining) > 0 {
if td := getHeader(MsgTraceDest, remaining); len(td) > 0 {
c.initAndSendIngressErrEvent(remaining, string(td), ErrMaxPayload)
if c.kind == CLIENT && c.pa.hdr > 0 && len(remaining) > 0 {
hdr := remaining[:min(len(remaining), c.pa.hdr)]
if td, ok := c.allowedMsgTraceDest(hdr, false); ok && td != _EMPTY_ {
c.initAndSendIngressErrEvent(hdr, td, ErrMaxPayload)
}
}
c.maxPayloadViolation(c.pa.size, maxPayload)
@@ -3242,10 +3245,12 @@ func (c *client) canSubscribe(subject string, optQueue ...string) bool {
queue = optQueue[0]
}
// For CLIENT connections that are MQTT, or other types of connections, we will
// implicitly allow anything that starts with the "$MQTT." prefix. However,
// For CLIENT connections that are MQTT we will implicitly allow anything that starts with
// the "$MQTT.sub." or "$MQTT.deliver.pubrel." prefix. For other types of connections, we
// will implicitly allow anything that starts with the full "$MQTT." prefix. However,
// we don't just return here, we skip the check for "allow" but will check "deny".
if (c.isMqtt() || (c.kind != CLIENT)) && strings.HasPrefix(subject, mqttPrefix) {
if (c.isMqtt() && (strings.HasPrefix(subject, mqttSubPrefix) || strings.HasPrefix(subject, mqttPubRelDeliverySubjectPrefix))) ||
(c.kind != CLIENT && strings.HasPrefix(subject, mqttPrefix)) {
checkAllow = false
}
// Check allow list. If no allow list that means all are allowed. Deny can overrule.
@@ -4055,6 +4060,41 @@ func (c *client) pubAllowed(subject string) bool {
return c.pubAllowedFullCheck(subject, true, false)
}
// allowedMsgTraceDest returns the trace destination if present and authorized.
// It only considers static publish permissions and does not consume dynamic
// reply permissions because the client is not publishing the trace event itself.
func (c *client) allowedMsgTraceDest(hdr []byte, hasLock bool) (string, bool) {
if len(hdr) == 0 {
return _EMPTY_, true
}
td := sliceHeader(MsgTraceDest, hdr)
if len(td) == 0 {
return _EMPTY_, true
}
dest := bytesToString(td)
if c.kind == CLIENT {
if hasGWRoutedReplyPrefix(td) {
return dest, false
}
var acc *Account
var srv *Server
if !hasLock {
c.mu.Lock()
}
acc, srv = c.acc, c.srv
if !hasLock {
c.mu.Unlock()
}
if bytes.HasPrefix(td, clientNRGPrefix) && srv != nil && acc != srv.SystemAccount() {
return dest, false
}
}
if c.perms != nil && (c.perms.pub.allow != nil || c.perms.pub.deny != nil) && !c.pubAllowedFullCheck(dest, false, hasLock) {
return dest, false
}
return dest, true
}
// pubAllowedFullCheck checks on all publish permissioning depending
// on the flag for dynamic reply permissions.
func (c *client) pubAllowedFullCheck(subject string, fullCheck, hasLock bool) bool {
@@ -4067,10 +4107,10 @@ func (c *client) pubAllowedFullCheck(subject string, fullCheck, hasLock bool) bo
return v.(bool)
}
allowed, checkAllow := true, true
// For CLIENT connections that are MQTT, or other types of connections, we will
// implicitly allow anything that starts with the "$MQTT." prefix. However,
// we don't just return here, we skip the check for "allow" but will check "deny".
if (c.isMqtt() || c.kind != CLIENT) && strings.HasPrefix(subject, mqttPrefix) {
// For any connections, other than CLIENT, we will implicitly allow anything that
// starts with the "$MQTT." prefix. However, we don't just return here,
// we skip the check for "allow" but will check "deny".
if c.kind != CLIENT && strings.HasPrefix(subject, mqttPrefix) {
checkAllow = false
}
// Cache miss, check allow then deny as needed.
@@ -4190,10 +4230,19 @@ func (c *client) processInboundClientMsg(msg []byte) (bool, bool) {
genidAddr := &acc.sl.genid
// Check pub permissions
if c.perms != nil && (c.perms.pub.allow != nil || c.perms.pub.deny != nil) && !c.pubAllowedFullCheck(string(c.pa.subject), true, true) {
c.mu.Unlock()
c.pubPermissionViolation(c.pa.subject)
return false, true
if c.perms != nil && (c.perms.pub.allow != nil || c.perms.pub.deny != nil) {
if !c.pubAllowedFullCheck(string(c.pa.subject), true, true) {
c.mu.Unlock()
c.pubPermissionViolation(c.pa.subject)
return false, true
}
}
if c.pa.hdr > 0 {
if td, ok := c.allowedMsgTraceDest(msg[:c.pa.hdr], true); !ok {
c.mu.Unlock()
c.pubPermissionViolation(stringToBytes(td))
return false, true
}
}
c.mu.Unlock()
@@ -4393,28 +4442,43 @@ func (c *client) setupResponseServiceImport(acc *Account, si *serviceImport, tra
return rsi
}
// Will remove a header if present.
func removeHeaderIfPresent(hdr []byte, key string) []byte {
start := getHeaderKeyIndex(key, hdr)
// key can't be first and we want to check that it is preceded by a '\n'
if start < 1 || hdr[start-1] != '\n' {
// Will remove a status and description from the header if present.
func removeHeaderStatusIfPresent(hdr []byte) []byte {
k := []byte("NATS/1.0")
kl, i := len(k), bytes.IndexByte(hdr, '\r')
if !bytes.HasPrefix(hdr, k) || i <= kl {
return hdr
}
index := start + len(key)
if index >= len(hdr) || hdr[index] != ':' {
return hdr
}
end := bytes.Index(hdr[start:], []byte(_CRLF_))
if end < 0 {
return hdr
}
hdr = append(hdr[:start], hdr[start+end+len(_CRLF_):]...)
if len(hdr) <= len(emptyHdrLine) {
hdr = append(hdr[:kl], hdr[i:]...)
if len(hdr) == len(emptyHdrLine) {
return nil
}
return hdr
}
// Will remove a header if present.
func removeHeaderIfPresent(hdr []byte, key string) []byte {
for {
start := getHeaderKeyIndex(key, hdr)
// key can't be first and we want to check that it is preceded by a '\n'
if start < 1 || hdr[start-1] != '\n' {
return hdr
}
index := start + len(key)
if index >= len(hdr) || hdr[index] != ':' {
return hdr
}
end := bytes.Index(hdr[start:], []byte(_CRLF_))
if end < 0 {
return hdr
}
hdr = append(hdr[:start], hdr[start+end+len(_CRLF_):]...)
if len(hdr) <= len(emptyHdrLine) {
return nil
}
}
}
func removeHeaderIfPrefixPresent(hdr []byte, prefix string) []byte {
var index int
for {
@@ -4749,16 +4813,33 @@ func (c *client) processServiceImport(si *serviceImport, acc *Account, msg []byt
if !isResponse {
isSysImport := siAcc == c.srv.SystemAccount()
var ci *ClientInfo
if hadPrevSi && c.pa.hdr >= 0 {
var cis ClientInfo
if err := json.Unmarshal(sliceHeader(ClientInfoHdr, msg[:c.pa.hdr]), &cis); err == nil {
ci = &cis
var cis *ClientInfo
if c.pa.hdr >= 0 {
var hci ClientInfo
if err := json.Unmarshal(sliceHeader(ClientInfoHdr, msg[:c.pa.hdr]), &hci); err == nil {
cis = &hci
}
}
if c.kind == LEAF && c.pa.hdr >= 0 && len(sliceHeader(ClientInfoHdr, msg[:c.pa.hdr])) > 0 {
// Leaf nodes may forward a Nats-Request-Info from a remote domain,
// but the local server must replace it with the identity of the
// authenticated leaf connection instead of trusting forwarded values.
ci = c.getClientInfo(share)
if hadPrevSi {
ci.Service = acc.Name
// Check if we are moving into a share details account from a non-shared
// and add in server and cluster details.
if !share && (si.share || isSysImport) {
c.addServerAndClusterInfo(ci)
}
} else if !share && isSysImport {
c.addServerAndClusterInfo(ci)
}
} else if hadPrevSi && cis != nil {
ci = cis
ci.Service = acc.Name
// Check if we are moving into a share details account from a non-shared
// and add in server and cluster details.
if !share && (si.share || isSysImport) {
c.addServerAndClusterInfo(ci)
}
} else if c.kind != LEAF || c.pa.hdr < 0 || len(sliceHeader(ClientInfoHdr, msg[:c.pa.hdr])) == 0 {
ci = c.getClientInfo(share)
@@ -4766,12 +4847,6 @@ func (c *client) processServiceImport(si *serviceImport, acc *Account, msg []byt
if !share && isSysImport {
c.addServerAndClusterInfo(ci)
}
} else if c.kind == LEAF && (si.share || isSysImport) {
// We have a leaf header here for ci, augment as above.
ci = c.getClientInfo(si.share)
if !si.share && isSysImport {
c.addServerAndClusterInfo(ci)
}
}
// Set clientInfo if present.
if ci != nil {

View File

@@ -131,19 +131,22 @@ func detectProxyProtoVersion(conn net.Conn) (version int, header []byte, err err
// readProxyProtoV1Header parses PROXY protocol v1 text format.
// Expects the "PROXY " prefix (6 bytes) to have already been consumed.
func readProxyProtoV1Header(conn net.Conn) (*proxyProtoAddr, error) {
// Returns any bytes that were read past the trailing CRLF so the caller can
// replay them into the next protocol layer.
func readProxyProtoV1Header(conn net.Conn) (*proxyProtoAddr, []byte, error) {
// Read rest of line (max 107 bytes total, already read 6)
maxRemaining := proxyProtoV1MaxLineLen - 6
// Read up to maxRemaining bytes at once (more efficient than byte-by-byte)
buf := make([]byte, maxRemaining)
var line []byte
var remaining []byte
for len(line) < maxRemaining {
// Read available data
n, err := conn.Read(buf[len(line):])
if err != nil {
return nil, fmt.Errorf("failed to read v1 line: %w", err)
return nil, nil, fmt.Errorf("failed to read v1 line: %w", err)
}
line = buf[:len(line)+n]
@@ -151,7 +154,8 @@ func readProxyProtoV1Header(conn net.Conn) (*proxyProtoAddr, error) {
// Look for CRLF in what we've read so far
for i := 0; i < len(line)-1; i++ {
if line[i] == '\r' && line[i+1] == '\n' {
// Found CRLF - extract just the line portion
// Found CRLF - keep any over-read bytes for the client parser.
remaining = append(remaining, line[i+2:]...)
line = line[:i]
goto foundCRLF
}
@@ -159,7 +163,7 @@ func readProxyProtoV1Header(conn net.Conn) (*proxyProtoAddr, error) {
}
// Exceeded max length without finding CRLF
return nil, fmt.Errorf("%w: v1 line too long", errProxyProtoInvalid)
return nil, nil, fmt.Errorf("%w: v1 line too long", errProxyProtoInvalid)
foundCRLF:
// Get parts from the protocol
@@ -167,17 +171,17 @@ foundCRLF:
// Validate format
if len(parts) < 1 {
return nil, fmt.Errorf("%w: invalid v1 format", errProxyProtoInvalid)
return nil, nil, fmt.Errorf("%w: invalid v1 format", errProxyProtoInvalid)
}
// Handle UNKNOWN (health check, like v2 LOCAL)
if parts[0] == proxyProtoV1Unknown {
return nil, nil
return nil, remaining, nil
}
// Must have exactly 5 parts: protocol, src-ip, dst-ip, src-port, dst-port
if len(parts) != 5 {
return nil, fmt.Errorf("%w: invalid v1 format", errProxyProtoInvalid)
return nil, nil, fmt.Errorf("%w: invalid v1 format", errProxyProtoInvalid)
}
protocol := parts[0]
@@ -185,29 +189,29 @@ foundCRLF:
dstIP := net.ParseIP(parts[2])
if srcIP == nil || dstIP == nil {
return nil, fmt.Errorf("%w: invalid address", errProxyProtoInvalid)
return nil, nil, fmt.Errorf("%w: invalid address", errProxyProtoInvalid)
}
// Parse ports
srcPort, err := strconv.ParseUint(parts[3], 10, 16)
if err != nil {
return nil, fmt.Errorf("invalid source port: %w", err)
return nil, nil, fmt.Errorf("invalid source port: %w", err)
}
dstPort, err := strconv.ParseUint(parts[4], 10, 16)
if err != nil {
return nil, fmt.Errorf("invalid dest port: %w", err)
return nil, nil, fmt.Errorf("invalid dest port: %w", err)
}
// Validate protocol matches IP version
if protocol == proxyProtoV1TCP4 && srcIP.To4() == nil {
return nil, fmt.Errorf("%w: TCP4 with IPv6 address", errProxyProtoInvalid)
return nil, nil, fmt.Errorf("%w: TCP4 with IPv6 address", errProxyProtoInvalid)
}
if protocol == proxyProtoV1TCP6 && srcIP.To4() != nil {
return nil, fmt.Errorf("%w: TCP6 with IPv4 address", errProxyProtoInvalid)
return nil, nil, fmt.Errorf("%w: TCP6 with IPv4 address", errProxyProtoInvalid)
}
if protocol != proxyProtoV1TCP4 && protocol != proxyProtoV1TCP6 {
return nil, fmt.Errorf("%w: invalid protocol %s", errProxyProtoInvalid, protocol)
return nil, nil, fmt.Errorf("%w: invalid protocol %s", errProxyProtoInvalid, protocol)
}
return &proxyProtoAddr{
@@ -215,25 +219,27 @@ foundCRLF:
srcPort: uint16(srcPort),
dstIP: dstIP,
dstPort: uint16(dstPort),
}, nil
}, remaining, nil
}
// readProxyProtoHeader reads and parses PROXY protocol (v1 or v2) from the connection.
// Automatically detects version and routes to appropriate parser.
// If the command is LOCAL/UNKNOWN (health check), it returns nil for addr and no error.
// If the command is PROXY, it returns the parsed address information.
// It also returns any bytes that were read past the v1 header terminator so the
// caller can replay them into the normal client parser.
// The connection must be fresh (no data read yet).
func readProxyProtoHeader(conn net.Conn) (*proxyProtoAddr, error) {
func readProxyProtoHeader(conn net.Conn) (*proxyProtoAddr, []byte, error) {
// Set read deadline to prevent hanging on slow/malicious clients
if err := conn.SetReadDeadline(time.Now().Add(proxyProtoReadTimeout)); err != nil {
return nil, err
return nil, nil, err
}
defer conn.SetReadDeadline(time.Time{})
// Detect version
version, firstBytes, err := detectProxyProtoVersion(conn)
if err != nil {
return nil, err
return nil, nil, err
}
switch version {
@@ -244,25 +250,26 @@ func readProxyProtoHeader(conn net.Conn) (*proxyProtoAddr, error) {
// Read rest of v2 signature (bytes 6-11, total 6 more bytes)
remaining := make([]byte, 6)
if _, err := io.ReadFull(conn, remaining); err != nil {
return nil, fmt.Errorf("failed to read v2 signature: %w", err)
return nil, nil, fmt.Errorf("failed to read v2 signature: %w", err)
}
// Verify full signature
fullSig := string(firstBytes) + string(remaining)
if fullSig != proxyProtoV2Sig {
return nil, fmt.Errorf("%w: invalid signature", errProxyProtoInvalid)
return nil, nil, fmt.Errorf("%w: invalid signature", errProxyProtoInvalid)
}
// Read rest of header: ver/cmd, fam/proto, addr-len (4 bytes)
header := make([]byte, 4)
if _, err := io.ReadFull(conn, header); err != nil {
return nil, fmt.Errorf("failed to read v2 header: %w", err)
return nil, nil, fmt.Errorf("failed to read v2 header: %w", err)
}
// Continue with parsing
return parseProxyProtoV2Header(conn, header)
addr, err := parseProxyProtoV2Header(conn, header)
return addr, nil, err
default:
return nil, fmt.Errorf("unsupported PROXY protocol version: %d", version)
return nil, nil, fmt.Errorf("unsupported PROXY protocol version: %d", version)
}
}

View File

@@ -66,7 +66,7 @@ func init() {
const (
// VERSION is the current version for the server.
VERSION = "2.12.5"
VERSION = "2.12.6"
// PROTO is the currently supported protocol.
// 0 was the original

View File

@@ -215,6 +215,9 @@ var (
// ErrMinimumVersionRequired is returned when a connection is not at the minimum version required.
ErrMinimumVersionRequired = errors.New("minimum version required")
// ErrLeafNodeMinVersionRejected is the leafnode protocol error prefix used
// when rejecting a remote due to leafnodes.min_version.
ErrLeafNodeMinVersionRejected = errors.New("connection rejected since minimum version required is")
// ErrInvalidMappingDestination is used for all subject mapping destination errors
ErrInvalidMappingDestination = errors.New("invalid mapping destination")

View File

@@ -1500,6 +1500,8 @@ func (s *Server) initEventTracking() {
// UserInfo is the response payload for a user info request, describing
// the requesting connection's user and its effective permissions.
type UserInfo struct {
	// UserID is the user identifier of the requesting connection.
	UserID string `json:"user"`
	// Account is the account the requesting connection belongs to.
	Account string `json:"account"`
	// AccountName is the account's friendly name tag, when one is set.
	AccountName string `json:"account_name,omitempty"`
	// UserName is the user's friendly name tag, when one is set.
	UserName string `json:"user_name,omitempty"`
	// Permissions are the publicly visible permissions of the connection.
	Permissions *Permissions `json:"permissions,omitempty"`
	// Expires is the remaining time until the user claim expires, if any.
	Expires time.Duration `json:"expires,omitempty"`
}
@@ -1519,9 +1521,22 @@ func (s *Server) userInfoReq(sub *subscription, c *client, _ *Account, subject,
return
}
// Look up the requester's account directly from ci.Account rather than
// using the acc returned by getRequestInfo, which may resolve to the
// service account (ci.Service) when the request arrives via a chained
// service import.
var accountName string
if ci.Account != _EMPTY_ {
if reqAcc, _ := s.LookupAccount(ci.Account); reqAcc != nil {
accountName = reqAcc.getNameTag()
}
}
response.Data = &UserInfo{
UserID: ci.User,
Account: ci.Account,
AccountName: accountName,
UserName: ci.NameTag,
Permissions: c.publicPermissions(),
Expires: c.claimExpiration(),
}

View File

@@ -2486,25 +2486,25 @@ func (js *jetStream) checkBytesLimits(selectedLimits *JetStreamAccountLimits, ad
if addBytes < 0 {
addBytes = 1
}
totalBytes := addBytes + maxBytesOffset
totalBytes := addSaturate(addBytes, maxBytesOffset)
switch storage {
case MemoryStorage:
// Account limits defined.
if selectedLimits.MaxMemory >= 0 && currentRes+totalBytes > selectedLimits.MaxMemory {
if selectedLimits.MaxMemory >= 0 && (currentRes > selectedLimits.MaxMemory || totalBytes > selectedLimits.MaxMemory-currentRes) {
return NewJSMemoryResourcesExceededError()
}
// Check if this server can handle request.
if checkServer && js.memReserved+totalBytes > js.config.MaxMemory {
if checkServer && (js.memReserved > js.config.MaxMemory || totalBytes > js.config.MaxMemory-js.memReserved) {
return NewJSMemoryResourcesExceededError()
}
case FileStorage:
// Account limits defined.
if selectedLimits.MaxStore >= 0 && currentRes+totalBytes > selectedLimits.MaxStore {
if selectedLimits.MaxStore >= 0 && (currentRes > selectedLimits.MaxStore || totalBytes > selectedLimits.MaxStore-currentRes) {
return NewJSStorageResourcesExceededError()
}
// Check if this server can handle request.
if checkServer && js.storeReserved+totalBytes > js.config.MaxStore {
if checkServer && (js.storeReserved > js.config.MaxStore || totalBytes > js.config.MaxStore-js.storeReserved) {
return NewJSStorageResourcesExceededError()
}
}

View File

@@ -25,9 +25,9 @@ import (
"runtime"
"slices"
"strings"
"sync"
"sync/atomic"
"time"
"unicode"
"github.com/nats-io/nuid"
)
@@ -607,7 +607,7 @@ const JSApiStreamRestoreResponseType = "io.nats.jetstream.api.v1.stream_restore_
// JSApiStreamRemovePeerRequest is the required remove peer request.
type JSApiStreamRemovePeerRequest struct {
// Server name of the peer to be removed.
// Server name or peer ID of the peer to be removed.
Peer string `json:"peer"`
}
@@ -1595,9 +1595,9 @@ func (jsa *jsAccount) tieredReservation(tier string, cfg *StreamConfig) int64 {
// If tier is empty, all storage is flat and we should adjust for replicas.
// Otherwise if tiered, storage replication already taken into consideration.
if tier == _EMPTY_ && sa.cfg.Replicas > 1 {
reservation += (int64(sa.cfg.Replicas) * sa.cfg.MaxBytes)
reservation = addSaturate(reservation, mulSaturate(int64(sa.cfg.Replicas), sa.cfg.MaxBytes))
} else {
reservation += sa.cfg.MaxBytes
reservation = addSaturate(reservation, sa.cfg.MaxBytes)
}
}
}
@@ -1880,7 +1880,7 @@ func (s *Server) jsStreamNamesRequest(sub *subscription, c *client, _ *Account,
s.sendAPIErrResponse(ci, acc, subject, reply, string(msg), s.jsonResponse(&resp))
return
}
offset = req.Offset
offset = max(req.Offset, 0)
if req.Subject != _EMPTY_ {
filter = req.Subject
}
@@ -2016,7 +2016,7 @@ func (s *Server) jsStreamListRequest(sub *subscription, c *client, _ *Account, s
s.sendAPIErrResponse(ci, acc, subject, reply, string(msg), s.jsonResponse(&resp))
return
}
offset = req.Offset
offset = max(req.Offset, 0)
if req.Subject != _EMPTY_ {
filter = req.Subject
}
@@ -2212,7 +2212,7 @@ func (s *Server) jsStreamInfoRequest(sub *subscription, c *client, a *Account, s
return
}
details, subjects = req.DeletedDetails, req.SubjectsFilter
offset = req.Offset
offset = max(req.Offset, 0)
}
mset, err := acc.lookupStream(streamName)
@@ -2625,13 +2625,17 @@ func (s *Server) jsStreamRemovePeerRequest(sub *subscription, c *client, _ *Acco
return
}
// Check to see if we are a member of the group and if the group has no leader.
// Peers here is a server name, convert to node name.
nodeName := getHash(req.Peer)
js.mu.RLock()
rg := sa.Group
// Check to see if we are a member of the group.
// Peer here is either a peer ID or a server name, convert to node name.
nodeName := getHash(req.Peer)
isMember := rg.isMember(nodeName)
if !isMember {
nodeName = req.Peer
isMember = rg.isMember(nodeName)
}
js.mu.RUnlock()
// Make sure we are a member.
@@ -3086,6 +3090,13 @@ func (s *Server) jsLeaderAccountPurgeRequest(sub *subscription, c *client, _ *Ac
var resp = JSApiAccountPurgeResponse{ApiResponse: ApiResponse{Type: JSApiAccountPurgeResponseType}}
// Check for path like separators in the name.
if strings.ContainsAny(accName, `\/`) {
resp.Error = NewJSStreamGeneralError(errors.New("account name can not contain path separators"))
s.sendAPIErrResponse(ci, acc, subject, reply, string(msg), s.jsonResponse(&resp))
return
}
if !s.JetStreamIsClustered() {
var streams []*stream
var ac *Account
@@ -4031,11 +4042,8 @@ func (s *Server) jsStreamRestoreRequest(sub *subscription, c *client, _ *Account
if stream != req.Config.Name && req.Config.Name == _EMPTY_ {
req.Config.Name = stream
}
// check stream config at the start of the restore process, not at the end
cfg, apiErr := s.checkStreamCfg(&req.Config, acc, false)
if apiErr != nil {
resp.Error = apiErr
if stream != req.Config.Name {
resp.Error = NewJSStreamMismatchError()
s.sendAPIErrResponse(ci, acc, subject, reply, string(msg), s.jsonResponse(&resp))
return
}
@@ -4045,6 +4053,14 @@ func (s *Server) jsStreamRestoreRequest(sub *subscription, c *client, _ *Account
return
}
// check stream config at the start of the restore process, not at the end
cfg, apiErr := s.checkStreamCfg(&req.Config, acc, false)
if apiErr != nil {
resp.Error = apiErr
s.sendAPIErrResponse(ci, acc, subject, reply, string(msg), s.jsonResponse(&resp))
return
}
if err := acc.jsNonClusteredStreamLimitsCheck(&cfg); err != nil {
resp.Error = err
s.sendAPIErrResponse(ci, acc, subject, reply, string(msg), s.jsonResponse(&resp))
@@ -4065,30 +4081,12 @@ func (s *Server) jsStreamRestoreRequest(sub *subscription, c *client, _ *Account
return
}
s.processStreamRestore(ci, acc, &req.Config, subject, reply, string(msg))
s.processStreamRestore(ci, acc, &cfg, subject, reply, string(msg))
}
func (s *Server) processStreamRestore(ci *ClientInfo, acc *Account, cfg *StreamConfig, subject, reply, msg string) <-chan error {
js := s.getJetStream()
var resp = JSApiStreamRestoreResponse{ApiResponse: ApiResponse{Type: JSApiStreamRestoreResponseType}}
snapDir := filepath.Join(js.config.StoreDir, snapStagingDir)
if _, err := os.Stat(snapDir); os.IsNotExist(err) {
if err := os.MkdirAll(snapDir, defaultDirPerms); err != nil {
resp.Error = &ApiError{Code: 503, Description: "JetStream unable to create temp storage for restore"}
s.sendAPIErrResponse(ci, acc, subject, reply, string(msg), s.jsonResponse(&resp))
return nil
}
}
tfile, err := os.CreateTemp(snapDir, "js-restore-")
if err != nil {
resp.Error = NewJSTempStorageFailedError()
s.sendAPIErrResponse(ci, acc, subject, reply, msg, s.jsonResponse(&resp))
return nil
}
streamName := cfg.Name
s.Noticef("Starting restore for stream '%s > %s'", acc.Name, streamName)
@@ -4114,29 +4112,59 @@ func (s *Server) processStreamRestore(ci *ClientInfo, acc *Account, cfg *StreamC
}
// For signaling to upper layers.
var resultOnce sync.Once
var closeOnce sync.Once
resultCh := make(chan result, 1)
activeQ := newIPQueue[int](s, fmt.Sprintf("[ACC:%s] stream '%s' restore", acc.Name, streamName)) // of int
pr, pw := io.Pipe()
var total int
setResult := func(err error, reply string) {
resultOnce.Do(func() {
resultCh <- result{err, reply}
})
}
activeQ := newIPQueue[int](s, fmt.Sprintf("[ACC:%s] stream '%s' restore", acc.Name, streamName))
restoreCh := make(chan struct {
mset *stream
err error
}, 1)
closeWithError := func(err error) {
closeOnce.Do(func() {
if err != nil {
pw.CloseWithError(err)
} else {
pw.Close()
}
})
}
s.startGoRoutine(func() {
defer s.grWG.Done()
mset, err := acc.RestoreStream(cfg, pr)
if err != nil {
pr.CloseWithError(err)
} else {
pr.Close()
}
restoreCh <- struct {
mset *stream
err error
}{
mset: mset,
err: err,
}
})
// FIXME(dlc) - Probably take out of network path eventually due to disk I/O?
processChunk := func(sub *subscription, c *client, _ *Account, subject, reply string, msg []byte) {
// We require reply subjects to communicate back failures, flow etc. If they do not have one log and cancel.
if reply == _EMPTY_ {
sub.client.processUnsub(sub.sid)
resultCh <- result{
fmt.Errorf("restore for stream '%s > %s' requires reply subject for each chunk", acc.Name, streamName),
reply,
}
setResult(fmt.Errorf("restore for stream '%s > %s' requires reply subject for each chunk", acc.Name, streamName), reply)
return
}
// Account client messages have \r\n on end. This is an error.
if len(msg) < LEN_CR_LF {
sub.client.processUnsub(sub.sid)
resultCh <- result{
fmt.Errorf("restore for stream '%s > %s' received short chunk", acc.Name, streamName),
reply,
}
setResult(fmt.Errorf("restore for stream '%s > %s' received short chunk", acc.Name, streamName), reply)
return
}
// Adjust.
@@ -4144,26 +4172,32 @@ func (s *Server) processStreamRestore(ci *ClientInfo, acc *Account, cfg *StreamC
// This means we are complete with our transfer from the client.
if len(msg) == 0 {
s.Debugf("Finished staging restore for stream '%s > %s'", acc.Name, streamName)
resultCh <- result{err, reply}
s.Debugf("Finished streaming restore for stream '%s > %s'", acc.Name, streamName)
closeWithError(nil)
setResult(nil, reply)
return
}
// We track total and check on server limits.
// TODO(dlc) - We could check apriori and cancel initial request if we know it won't fit.
total += len(msg)
if js.wouldExceedLimits(FileStorage, total) {
s.resourcesExceededError(FileStorage)
resultCh <- result{NewJSInsufficientResourcesError(), reply}
return
}
// Signal activity before and after the blocking write.
// The pre-write signal refreshes the stall watchdog when the
// chunk arrives; the post-write signal refreshes it again once
// RestoreStream has consumed the data. This keeps the idle
// window between chunks anchored to the end of the previous
// write instead of its start.
activeQ.push(0)
// Append chunk to temp file. Mark as issue if we encounter an error.
if n, err := tfile.Write(msg); n != len(msg) || err != nil {
resultCh <- result{err, reply}
if reply != _EMPTY_ {
s.sendInternalAccountMsg(acc, reply, "-ERR 'storage failure during restore'")
if _, err := pw.Write(msg); err != nil {
closeWithError(err)
sub.client.processUnsub(sub.sid)
var resp = JSApiStreamCreateResponse{ApiResponse: ApiResponse{Type: JSApiStreamCreateResponseType}}
if IsNatsErr(err, JSStorageResourcesExceededErr, JSMemoryResourcesExceededErr) {
s.resourcesExceededError(cfg.Storage)
}
resp.Error = NewJSStreamRestoreError(err, Unless(err))
if s.sendInternalAccountMsg(acc, reply, s.jsonResponse(&resp)) == nil {
reply = _EMPTY_
}
setResult(err, reply)
return
}
@@ -4174,8 +4208,7 @@ func (s *Server) processStreamRestore(ci *ClientInfo, acc *Account, cfg *StreamC
sub, err := acc.subscribeInternal(restoreSubj, processChunk)
if err != nil {
tfile.Close()
os.Remove(tfile.Name())
closeWithError(err)
resp.Error = NewJSRestoreSubscribeFailedError(err, restoreSubj)
s.sendAPIErrResponse(ci, acc, subject, reply, msg, s.jsonResponse(&resp))
return nil
@@ -4185,14 +4218,14 @@ func (s *Server) processStreamRestore(ci *ClientInfo, acc *Account, cfg *StreamC
resp.DeliverSubject = restoreSubj
s.sendAPIResponse(ci, acc, subject, reply, msg, s.jsonResponse(resp))
// Returned to the caller to wait for completion.
doneCh := make(chan error, 1)
// Monitor the progress from another Go routine.
s.startGoRoutine(func() {
defer s.grWG.Done()
defer func() {
tfile.Close()
os.Remove(tfile.Name())
closeWithError(ErrConnectionClosed)
sub.client.processUnsub(sub.sid)
activeQ.unregister()
}()
@@ -4202,71 +4235,97 @@ func (s *Server) processStreamRestore(ci *ClientInfo, acc *Account, cfg *StreamC
defer notActive.Stop()
total := 0
var inputDone bool
var replySubj string
var inputErr error
var restoreDone bool
var restoreResult struct {
mset *stream
err error
}
finish := func(reply string, err error, mset *stream) {
end := time.Now().UTC()
s.publishAdvisory(acc, JSAdvisoryStreamRestoreCompletePre+"."+streamName, &JSRestoreCompleteAdvisory{
TypedEvent: TypedEvent{
Type: JSRestoreCompleteAdvisoryType,
ID: nuid.Next(),
Time: end,
},
Stream: streamName,
Start: start,
End: end,
Bytes: int64(total),
Client: ci.forAdvisory(),
Domain: domain,
})
var resp = JSApiStreamCreateResponse{ApiResponse: ApiResponse{Type: JSApiStreamCreateResponseType}}
if err != nil {
if IsNatsErr(err, JSStorageResourcesExceededErr, JSMemoryResourcesExceededErr) {
s.resourcesExceededError(cfg.Storage)
}
resp.Error = NewJSStreamRestoreError(err, Unless(err))
s.Warnf("Restore failed for %s for stream '%s > %s' in %v",
friendlyBytes(int64(total)), acc.Name, streamName, end.Sub(start))
} else {
msetCfg := mset.config()
resp.StreamInfo = &StreamInfo{
Created: mset.createdTime(),
State: mset.state(),
Config: *setDynamicStreamMetadata(&msetCfg),
TimeStamp: time.Now().UTC(),
}
s.Noticef("Completed restore of %s for stream '%s > %s' in %v",
friendlyBytes(int64(total)), acc.Name, streamName, end.Sub(start).Round(time.Millisecond))
}
if reply != _EMPTY_ {
s.sendInternalAccountMsg(acc, reply, s.jsonResponse(&resp))
}
doneCh <- err
}
for {
select {
case result := <-resultCh:
err := result.err
var mset *stream
// If we staged properly go ahead and do restore now.
if err == nil {
s.Debugf("Finalizing restore for stream '%s > %s'", acc.Name, streamName)
tfile.Seek(0, 0)
mset, err = acc.RestoreStream(cfg, tfile)
} else {
errStr := err.Error()
tmp := []rune(errStr)
tmp[0] = unicode.ToUpper(tmp[0])
s.Warnf(errStr)
replySubj = result.reply
inputDone = true
inputErr = result.err
notActive.Stop()
if result.err != nil {
closeWithError(result.err)
s.Warnf(result.err.Error())
}
end := time.Now().UTC()
// TODO(rip) - Should this have the error code in it??
s.publishAdvisory(acc, JSAdvisoryStreamRestoreCompletePre+"."+streamName, &JSRestoreCompleteAdvisory{
TypedEvent: TypedEvent{
Type: JSRestoreCompleteAdvisoryType,
ID: nuid.Next(),
Time: end,
},
Stream: streamName,
Start: start,
End: end,
Bytes: int64(total),
Client: ci.forAdvisory(),
Domain: domain,
})
var resp = JSApiStreamCreateResponse{ApiResponse: ApiResponse{Type: JSApiStreamCreateResponseType}}
if err != nil {
resp.Error = NewJSStreamRestoreError(err, Unless(err))
s.Warnf("Restore failed for %s for stream '%s > %s' in %v",
friendlyBytes(int64(total)), acc.Name, streamName, end.Sub(start))
} else {
msetCfg := mset.config()
resp.StreamInfo = &StreamInfo{
Created: mset.createdTime(),
State: mset.state(),
Config: *setDynamicStreamMetadata(&msetCfg),
TimeStamp: time.Now().UTC(),
if restoreDone {
err := inputErr
if err == nil {
err = restoreResult.err
}
s.Noticef("Completed restore of %s for stream '%s > %s' in %v",
friendlyBytes(int64(total)), acc.Name, streamName, end.Sub(start).Round(time.Millisecond))
finish(replySubj, err, restoreResult.mset)
return
}
case rr := <-restoreCh:
restoreDone = true
restoreResult = rr
if inputDone {
err := inputErr
if err == nil {
err = rr.err
}
finish(replySubj, err, rr.mset)
return
}
// On the last EOF, send back the stream info or error status.
s.sendInternalAccountMsg(acc, result.reply, s.jsonResponse(&resp))
// Signal to the upper layers.
doneCh <- err
return
case <-activeQ.ch:
if n, ok := activeQ.popOne(); ok {
total += n
notActive.Reset(activityInterval)
if !inputDone {
notActive.Reset(activityInterval)
}
}
case <-notActive.C:
err := fmt.Errorf("restore for stream '%s > %s' is stalled", acc, streamName)
err := fmt.Errorf("restore for stream '%s > %s' is stalled", acc.Name, streamName)
closeWithError(err)
doneCh <- err
return
}
@@ -4794,7 +4853,7 @@ func (s *Server) jsConsumerNamesRequest(sub *subscription, c *client, _ *Account
s.sendAPIErrResponse(ci, acc, subject, reply, string(msg), s.jsonResponse(&resp))
return
}
offset = req.Offset
offset = max(req.Offset, 0)
}
streamName := streamNameFromSubject(subject)
@@ -4922,7 +4981,7 @@ func (s *Server) jsConsumerListRequest(sub *subscription, c *client, _ *Account,
s.sendAPIErrResponse(ci, acc, subject, reply, string(msg), s.jsonResponse(&resp))
return
}
offset = req.Offset
offset = max(req.Offset, 0)
}
streamName := streamNameFromSubject(subject)

View File

@@ -1437,7 +1437,7 @@ func (js *jetStream) getOrphans() (streams []*stream, consumers []*consumer) {
streams = append(streams, mset)
} else {
// This one is good, check consumers now.
for _, o := range mset.getConsumers() {
for _, o := range mset.getPublicConsumers() {
if sa.consumers[o.String()] == nil {
consumers = append(consumers, o)
}
@@ -2216,6 +2216,10 @@ func (js *jetStream) collectStreamAndConsumerChanges(c RaftNodeCheckpoint, strea
as = make(map[string]*streamAssignment)
streams[sa.Client.serviceAccount()] = as
}
// Preserve consumers from the previous assignment.
if osa := as[sa.Config.Name]; osa != nil {
sa.consumers = osa.consumers
}
as[sa.Config.Name] = sa
}
for _, cas := range ru.updateConsumers {
@@ -2483,7 +2487,9 @@ func (js *jetStream) applyMetaEntries(entries []*Entry, ru *recoveryUpdates) (bo
}
if e.Type == EntrySnapshot {
js.applyMetaSnapshot(e.Data, ru, isRecovering)
if err := js.applyMetaSnapshot(e.Data, ru, isRecovering); err != nil {
return isRecovering, didSnap, err
}
didSnap = true
} else if e.Type == EntryRemovePeer {
if !js.isMetaRecovering() {
@@ -6008,6 +6014,9 @@ func (js *jetStream) consumerAssignmentsOrInflightSeq(account, stream string) it
}
}
sa := js.streamAssignment(account, stream)
if sa == nil {
return
}
for _, ca := range sa.consumers {
// Skip if we already iterated over it as inflight.
if _, ok := inflight[ca.Name]; ok {
@@ -7534,9 +7543,9 @@ func (js *jetStream) tieredStreamAndReservationCount(accName, tier string, cfg *
// If tier is empty, all storage is flat and we should adjust for replicas.
// Otherwise if tiered, storage replication already taken into consideration.
if tier == _EMPTY_ && sa.Config.Replicas > 1 {
reservation += sa.Config.MaxBytes * int64(sa.Config.Replicas)
reservation = addSaturate(reservation, mulSaturate(int64(sa.Config.Replicas), sa.Config.MaxBytes))
} else {
reservation += sa.Config.MaxBytes
reservation = addSaturate(reservation, sa.Config.MaxBytes)
}
}
}
@@ -7646,6 +7655,17 @@ func (s *Server) jsClusteredStreamRequest(ci *ClientInfo, acc *Account, subject,
// Capture if we have existing/inflight assignment first.
if osa := js.streamAssignmentOrInflight(acc.Name, cfg.Name); osa != nil {
copyStreamMetadata(cfg, osa.Config)
// Set the index name on both to ensure the DeepEqual works
currentIName := make(map[string]struct{})
for _, s := range osa.Config.Sources {
currentIName[s.iname] = struct{}{}
}
for _, s := range cfg.Sources {
s.setIndexName()
if _, ok := currentIName[s.iname]; !ok {
s.iname = _EMPTY_
}
}
if !reflect.DeepEqual(osa.Config, cfg) {
resp.Error = NewJSStreamNameExistError()
s.sendAPIErrResponse(ci, acc, subject, reply, string(rmsg), s.jsonResponse(&resp))
@@ -8210,6 +8230,16 @@ func (s *Server) jsClusteredStreamRestoreRequest(
return
}
resp := JSApiStreamRestoreResponse{ApiResponse: ApiResponse{Type: JSApiStreamRestoreResponseType}}
// check stream config at the start of the restore process, not at the end
cfg, apiErr := s.checkStreamCfg(&req.Config, acc, false)
if apiErr != nil {
resp.Error = apiErr
s.sendAPIErrResponse(ci, acc, subject, reply, string(rmsg), s.jsonResponse(&resp))
return
}
js.mu.Lock()
defer js.mu.Unlock()
@@ -8217,10 +8247,7 @@ func (s *Server) jsClusteredStreamRestoreRequest(
return
}
cfg := &req.Config
resp := JSApiStreamRestoreResponse{ApiResponse: ApiResponse{Type: JSApiStreamRestoreResponseType}}
if err := js.jsClusteredStreamLimitsCheck(acc, cfg); err != nil {
if err := js.jsClusteredStreamLimitsCheck(acc, &cfg); err != nil {
resp.Error = err
s.sendAPIErrResponse(ci, acc, subject, reply, string(rmsg), s.jsonResponse(&resp))
return
@@ -8233,7 +8260,7 @@ func (s *Server) jsClusteredStreamRestoreRequest(
}
// Raft group selection and placement.
rg, err := js.createGroupForStream(ci, cfg)
rg, err := js.createGroupForStream(ci, &cfg)
if err != nil {
resp.Error = NewJSClusterNoPeersError(err)
s.sendAPIErrResponse(ci, acc, subject, reply, string(rmsg), s.jsonResponse(&resp))
@@ -8241,7 +8268,7 @@ func (s *Server) jsClusteredStreamRestoreRequest(
}
// Pick a preferred leader.
rg.setPreferred(s)
sa := &streamAssignment{Group: rg, Sync: syncSubjForStream(), Config: cfg, Subject: subject, Reply: reply, Client: ci, Created: time.Now().UTC()}
sa := &streamAssignment{Group: rg, Sync: syncSubjForStream(), Config: &cfg, Subject: subject, Reply: reply, Client: ci, Created: time.Now().UTC()}
// Now add in our restore state and pre-select a peer to handle the actual receipt of the snapshot.
sa.Restore = &req.State
if err := cc.meta.Propose(encodeAddStreamAssignment(sa)); err == nil {
@@ -8308,6 +8335,9 @@ func (s *Server) jsClusteredStreamListRequest(acc *Account, ci *ClientInfo, filt
}
scnt := len(streams)
if offset < 0 {
offset = 0
}
if offset > scnt {
offset = scnt
}
@@ -8458,6 +8488,9 @@ func (s *Server) jsClusteredConsumerListRequest(acc *Account, ci *ClientInfo, of
}
ocnt := len(consumers)
if offset < 0 {
offset = 0
}
if offset > ocnt {
offset = ocnt
}

View File

@@ -63,9 +63,9 @@ const (
// LEAF connection as opposed to a CLIENT.
leafNodeWSPath = "/leafnode"
// This is the time the server will wait, when receiving a CONNECT,
// before closing the connection if the required minimum version is not met.
leafNodeWaitBeforeClose = 5 * time.Second
// When a soliciting leafnode is rejected because it does not meet the
// configured minimum version, delay the next reconnect attempt by this long.
leafNodeMinVersionReconnectDelay = 5 * time.Second
)
type leaf struct {
@@ -691,9 +691,8 @@ func (s *Server) connectToRemoteLeafNode(remote *leafNodeCfg, firstConnect bool)
} else {
s.Debugf("Trying to connect as leafnode to remote server on %q%s", rURL.Host, ipStr)
// Check if proxy is configured first, then check if URL supports it
if proxyURL != _EMPTY_ && isWSURL(rURL) {
// Use proxy for WebSocket connections - use original hostname, resolved IP for connection
// Check if proxy is configured
if proxyURL != _EMPTY_ {
targetHost := rURL.Host
// If URL doesn't include port, add the default port for the scheme
if rURL.Port() == _EMPTY_ {
@@ -2082,17 +2081,11 @@ func (c *client) processLeafNodeConnect(s *Server, arg []byte, lang string) erro
if mv := s.getOpts().LeafNode.MinVersion; mv != _EMPTY_ {
major, minor, update, _ := versionComponents(mv)
if !versionAtLeast(proto.Version, major, minor, update) {
// We are going to send back an INFO because otherwise recent
// versions of the remote server would simply break the connection
// after 2 seconds if not receiving it. Instead, we want the
// other side to just "stall" until we finish waiting for the holding
// period and close the connection below.
// Send back an INFO so recent remote servers process the rejection
// cleanly, then close immediately. The soliciting side applies the
// reconnect delay when it processes the error.
s.sendPermsAndAccountInfo(c)
c.sendErrAndErr(fmt.Sprintf("connection rejected since minimum version required is %q", mv))
select {
case <-c.srv.quitCh:
case <-time.After(leafNodeWaitBeforeClose):
}
c.sendErrAndErr(fmt.Sprintf("%s %q", ErrLeafNodeMinVersionRejected, mv))
c.closeConnection(MinimumVersionRequired)
return ErrMinimumVersionRequired
}
@@ -3195,6 +3188,11 @@ func (c *client) leafProcessErr(errStr string) {
c.Errorf("Leafnode connection dropped with same cluster name error. Delaying attempt to reconnect for %v", delay)
return
}
if strings.Contains(errStr, ErrLeafNodeMinVersionRejected.Error()) {
_, delay := c.setLeafConnectDelayIfSoliciting(leafNodeMinVersionReconnectDelay)
c.Errorf("Leafnode connection dropped due to minimum version requirement. Delaying attempt to reconnect for %v", delay)
return
}
// We will look for Loop detected error coming from the other side.
// If we solicit, set the connect delay.
@@ -3217,7 +3215,10 @@ func (c *client) setLeafConnectDelayIfSoliciting(delay time.Duration) (string, t
}
c.leaf.remote.setConnectDelay(delay)
}
accName := c.acc.Name
var accName string
if c.acc != nil {
accName = c.acc.Name
}
c.mu.Unlock()
return accName, delay
}

View File

@@ -241,6 +241,7 @@ var (
errMQTTUnsupportedCharacters = errors.New("character ' ' not supported for MQTT topics")
errMQTTInvalidSession = errors.New("invalid MQTT session")
errMQTTInvalidRetainFlags = errors.New("invalid retained message flags")
errMQTTSessionCollision = errors.New("stored session does not match client ID")
)
type srvMQTT struct {
@@ -260,7 +261,7 @@ type mqttAccountSessionManager struct {
sessions map[string]*mqttSession // key is MQTT client ID
sessByHash map[string]*mqttSession // key is MQTT client ID hash
sessLocked map[string]struct{} // key is MQTT client ID and indicate that a session can not be taken by a new client at this time
flappers map[string]int64 // When connection connects with client ID already in use
flappers map[string]time.Time // When connection connects with client ID already in use
flapTimer *time.Timer // Timer to perform some cleanup of the flappers map
sl *Sublist // sublist allowing to find retained messages for given subscription
retmsgs map[string]*mqttRetainedMsgRef // retained messages
@@ -789,11 +790,17 @@ func (c *client) mqttParse(buf []byte) error {
}
break
}
if err = mqttCheckFixedHeaderFlags(pt, b&mqttPacketFlagMask); err != nil {
break
}
pl, complete, err = r.readPacketLen()
if err != nil || !complete {
break
}
if err = mqttCheckRemainingLength(pt, pl); err != nil {
break
}
switch pt {
// Packets that we receive back when we act as the "sender": PUBACK,
@@ -958,6 +965,43 @@ func (c *client) mqttParse(buf []byte) error {
return err
}
func mqttCheckFixedHeaderFlags(packetType, flags byte) error {
var expected byte
switch packetType {
case mqttPacketConnect, mqttPacketPubAck, mqttPacketPubRec, mqttPacketPubComp,
mqttPacketPing, mqttPacketDisconnect:
expected = 0
case mqttPacketPubRel, mqttPacketSub, mqttPacketUnsub:
expected = 0x2
case mqttPacketPub:
return nil
default:
return nil
}
if flags != expected {
return fmt.Errorf("invalid fixed header flags %x for packet type %x", flags, packetType)
}
return nil
}
func mqttCheckRemainingLength(packetType byte, pl int) error {
var expected int
switch packetType {
case mqttPacketConnect, mqttPacketPub, mqttPacketSub, mqttPacketUnsub:
return nil
case mqttPacketPubAck, mqttPacketPubRec, mqttPacketPubRel, mqttPacketPubComp:
expected = 2
case mqttPacketPing, mqttPacketDisconnect:
expected = 0
default:
return nil
}
if pl != expected {
return fmt.Errorf("invalid remaining length %d for packet type %x", pl, packetType)
}
return nil
}
func (c *client) mqttTraceMsg(msg []byte) {
maxTrace := c.srv.getOpts().MaxTracedMsgLen
if maxTrace > 0 && len(msg) > maxTrace {
@@ -1174,7 +1218,7 @@ func (s *Server) mqttCreateAccountSessionManager(acc *Account, quitCh chan struc
sessions: make(map[string]*mqttSession),
sessByHash: make(map[string]*mqttSession),
sessLocked: make(map[string]struct{}),
flappers: make(map[string]int64),
flappers: make(map[string]time.Time),
jsa: mqttJSA{
id: id,
c: c,
@@ -2090,7 +2134,7 @@ func (as *mqttAccountSessionManager) processSessionPersist(_ *subscription, pc *
//
// Lock held on entry.
func (as *mqttAccountSessionManager) addSessToFlappers(clientID string) {
as.flappers[clientID] = time.Now().UnixNano()
as.flappers[clientID] = time.Now()
if as.flapTimer == nil {
as.flapTimer = time.AfterFunc(mqttFlapCleanItvl, func() {
as.mu.Lock()
@@ -2099,9 +2143,9 @@ func (as *mqttAccountSessionManager) addSessToFlappers(clientID string) {
if as.flapTimer == nil {
return
}
now := time.Now().UnixNano()
now := time.Now()
for cID, tm := range as.flappers {
if now-tm > int64(mqttSessJailDur) {
if now.Sub(tm) > mqttSessJailDur {
delete(as.flappers, cID)
}
}
@@ -2971,16 +3015,12 @@ func mqttDecodeRetainedMessage(subject string, h, m []byte) (*mqttRetainedMsg, e
// Lock not held on entry, but session is in the locked map.
func (as *mqttAccountSessionManager) createOrRestoreSession(clientID string, opts *Options) (*mqttSession, bool, error) {
jsa := &as.jsa
formatError := func(errTxt string, err error) (*mqttSession, bool, error) {
accName := jsa.c.acc.GetName()
return nil, false, fmt.Errorf("%s for account %q, session %q: %v", errTxt, accName, clientID, err)
}
hash := getHash(clientID)
smsg, err := jsa.loadSessionMsg(as.domainTk, hash)
if err != nil {
if isErrorOtherThan(err, JSNoMessageFoundErr) {
return formatError("loading session record", err)
return nil, false, fmt.Errorf("loading session record: %w", err)
}
// Message not found, so create the session...
// Create a session and indicate that this session did not exist.
@@ -2991,7 +3031,10 @@ func (as *mqttAccountSessionManager) createOrRestoreSession(clientID string, opt
// We need to recover the existing record now.
ps := &mqttPersistedSession{}
if err := json.Unmarshal(smsg.Data, ps); err != nil {
return formatError(fmt.Sprintf("unmarshal of session record at sequence %v", smsg.Sequence), err)
return nil, false, fmt.Errorf("unmarshal of session record at sequence %v: %w", smsg.Sequence, err)
}
if ps.ID != clientID {
return nil, false, errMQTTSessionCollision
}
// Restore this session (even if we don't own it), the caller will do the right thing.
@@ -3673,8 +3716,12 @@ func (c *client) mqttParseConnect(r *mqttReader, hasMappings bool) (byte, *mqttC
c.mqtt.cid = nuid.Next()
}
// Spec [MQTT-3.1.3-4] and [MQTT-3.1.3-9]
if !utf8.ValidString(c.mqtt.cid) {
return mqttConnAckRCIdentifierRejected, nil, fmt.Errorf("invalid utf8 for client ID: %q", c.mqtt.cid)
if err := mqttValidateString(c.mqtt.cid, "client ID"); err != nil {
return mqttConnAckRCIdentifierRejected, nil, err
} else if !isValidName(c.mqtt.cid) {
// Should not contain characters that make it an invalid name for NATS subjects, etc.
err = fmt.Errorf("invalid character in %s %q", "client ID", c.mqtt.cid)
return mqttConnAckRCIdentifierRejected, nil, err
}
if hasWill {
@@ -3692,8 +3739,8 @@ func (c *client) mqttParseConnect(r *mqttReader, hasMappings bool) (byte, *mqttC
if len(topic) == 0 {
return 0, nil, errMQTTEmptyWillTopic
}
if !utf8.Valid(topic) {
return 0, nil, fmt.Errorf("invalid utf8 for Will topic %q", topic)
if err := mqttValidateTopic(topic, "Will topic"); err != nil {
return 0, nil, err
}
// Convert MQTT topic to NATS subject
cp.will.subject, err = mqttTopicToNATSPubSubject(topic)
@@ -3734,8 +3781,8 @@ func (c *client) mqttParseConnect(r *mqttReader, hasMappings bool) (byte, *mqttC
return mqttConnAckRCBadUserOrPassword, nil, errMQTTEmptyUsername
}
// Spec [MQTT-3.1.3-11]
if !utf8.ValidString(c.opts.Username) {
return mqttConnAckRCBadUserOrPassword, nil, fmt.Errorf("invalid utf8 for user name %q", c.opts.Username)
if err := mqttValidateString(c.opts.Username, "user name"); err != nil {
return mqttConnAckRCBadUserOrPassword, nil, err
}
}
@@ -3745,7 +3792,6 @@ func (c *client) mqttParseConnect(r *mqttReader, hasMappings bool) (byte, *mqttC
return 0, nil, err
}
c.opts.Token = c.opts.Password
c.opts.JWT = c.opts.Password
}
return 0, cp, nil
}
@@ -3835,7 +3881,7 @@ CHECK:
if tm, ok := asm.flappers[cid]; ok {
// If the last time it tried to connect was more than 1 sec ago,
// then accept and remove from flappers map.
if time.Now().UnixNano()-tm > int64(mqttSessJailDur) {
if time.Since(tm) > mqttSessJailDur {
asm.removeSessFromFlappers(cid)
} else {
// Will hold this client for a second and then close it. We
@@ -3883,13 +3929,19 @@ CHECK:
// Do we have an existing session for this client ID
es, exists := asm.sessions[cid]
asm.mu.Unlock()
formatError := func(err error) error {
return fmt.Errorf("%v for account %q, session %q", err, c.acc.GetName(), cid)
}
// The session is not in the map, but may be on disk, so try to recover
// or create the stream if not.
if !exists {
es, exists, err = asm.createOrRestoreSession(cid, s.getOpts())
if err != nil {
return err
if err == errMQTTSessionCollision {
sendConnAck(mqttConnAckRCIdentifierRejected, false)
}
return formatError(err)
}
}
if exists {
@@ -4041,6 +4093,9 @@ func (c *client) mqttParsePub(r *mqttReader, pl int, pp *mqttPublish, hasMapping
if len(pp.topic) == 0 {
return errMQTTTopicIsEmpty
}
if err := mqttValidateTopic(pp.topic, "topic"); err != nil {
return err
}
// Convert the topic to a NATS subject. This call will also check that
// there is no MQTT wildcards (Spec [MQTT-3.3.2-2] and [MQTT-4.7.1-1])
// Note that this may not result in a copy if there is no conversion.
@@ -4093,6 +4148,26 @@ func (c *client) mqttParsePub(r *mqttReader, pl int, pp *mqttPublish, hasMapping
return nil
}
// mqttValidateTopic verifies that the given topic (or topic filter) bytes
// are well-formed UTF-8 and contain no NUL character. The field name is
// used only to build the returned error message.
func mqttValidateTopic(topic []byte, field string) error {
	switch {
	case !utf8.Valid(topic):
		return fmt.Errorf("invalid utf8 for %s %q", field, topic)
	case bytes.IndexByte(topic, 0) >= 0:
		return fmt.Errorf("invalid null character in %s %q", field, topic)
	}
	return nil
}
// mqttValidateString verifies that the given string value is well-formed
// UTF-8 and contains no NUL character. The field name is used only to
// build the returned error message.
func mqttValidateString(value string, field string) error {
	switch {
	case !utf8.ValidString(value):
		return fmt.Errorf("invalid utf8 for %s %q", field, value)
	case strings.IndexByte(value, 0) >= 0:
		return fmt.Errorf("invalid null character in %s %q", field, value)
	}
	return nil
}
func mqttPubTrace(pp *mqttPublish) string {
dup := pp.flags&mqttPubFlagDup != 0
qos := mqttGetQoS(pp.flags)
@@ -4791,8 +4866,8 @@ func (c *client) mqttParseSubsOrUnsubs(r *mqttReader, b byte, pl int, sub bool)
return 0, nil, errMQTTTopicFilterCannotBeEmpty
}
// Spec [MQTT-3.8.3-1], [MQTT-3.10.3-1]
if !utf8.Valid(topic) {
return 0, nil, fmt.Errorf("invalid utf8 for topic filter %q", topic)
if err := mqttValidateTopic(topic, "topic filter"); err != nil {
return 0, nil, err
}
var qos byte
// We are going to report if we had an error during the conversion,

View File

@@ -367,6 +367,13 @@ func (c *client) initMsgTrace() *msgTrace {
}
}
dest = getHdrVal(MsgTraceDest)
if c.kind == CLIENT {
if td, ok := c.allowedMsgTraceDest(hdr, false); !ok {
return nil
} else if td != _EMPTY_ {
dest = td
}
}
// Check the destination to see if this is a valid public subject.
if !IsValidPublishSubject(dest) {
// We still have to return a msgTrace object (if traceOnly is set)

View File

@@ -6435,3 +6435,18 @@ func expandPath(p string) (string, error) {
return filepath.Join(home, p[1:]), nil
}
// redactedArgRE matches sensitive command-line flags (-user/--user,
// -pass/--pass, -auth/--auth), in either the '--pass=secret' or the
// '--pass secret' form.
var redactedArgRE = regexp.MustCompile("^-{1,2}(user|pass|auth)(=.*)?$")

// RedactArgs redacts sensitive arguments from the command line.
// For example, turns '--pass=secret' into '--pass=[REDACTED]'.
// The args slice is modified in place.
func RedactArgs(args []string) {
	for i, arg := range args {
		if !redactedArgRE.MatchString(arg) {
			continue
		}
		if idx := strings.Index(arg, "="); idx != -1 {
			// Inline form: keep the flag name, redact everything after '='.
			args[i] = arg[:idx] + "=[REDACTED]"
		} else if i+1 < len(args) {
			// Separate-argument form: the next argument is the value.
			args[i+1] = "[REDACTED]"
		}
	}
}

View File

@@ -63,8 +63,14 @@ func protoScanFieldValue(typ int, b []byte) (size int, err error) {
case 0:
_, size, err = protoScanVarint(b)
case 5: // fixed32
if len(b) < 4 {
return 0, errProtoInsufficient
}
size = 4
case 1: // fixed64
if len(b) < 8 {
return 0, errProtoInsufficient
}
size = 8
case 2: // length-delimited
size, err = protoScanBytes(b)

View File

@@ -3521,8 +3521,15 @@ func (n *raft) trackResponse(ar *appendEntryResponse) bool {
indexUpdateQ.push(ar.index)
}
// Ignore items already committed.
if ar.index <= n.commit {
// Ignore items already committed, or skip if this is not about an entry that matches our current term.
if ar.index <= n.commit || ar.term != n.term {
assert.AlwaysOrUnreachable(ar.term <= n.term, "Raft response term mismatch", map[string]any{
"n.accName": n.accName,
"n.group": n.group,
"n.id": n.id,
"n.term": n.term,
"ar.term": ar.term,
})
return false
}
@@ -4306,7 +4313,8 @@ func (n *raft) processAppendEntryResponse(ar *appendEntryResponse) {
if ar.success {
// The remote node successfully committed the append entry.
// They agree with our leadership and are happy with the state of the log.
// In this case ar.term doesn't matter.
// In this case ar.term was populated with the remote's pterm. If this matches
// our term, we can use it to check for quorum and up our commit.
var err error
var committed bool
@@ -5024,6 +5032,7 @@ func (n *raft) switchToCandidate() {
}
// Increment the term.
n.term++
n.vote = noVote
// Clear current Leader.
n.updateLeader(noLeader)
n.switchState(Candidate)

View File

@@ -88,7 +88,7 @@ type route struct {
// an implicit route and sending to the remote.
gossipMode byte
// This will be set in case of pooling so that a route can trigger
// the creation of the next after receiving the first PONG, ensuring
// the creation of the next after receiving a PONG, ensuring
// that authentication did not fail.
startNewRoute *routeInfo
}

View File

@@ -3452,7 +3452,7 @@ func (s *Server) createClientEx(conn net.Conn, inProcess bool) *client {
pre = pre[:n]
}
conn = &tlsMixConn{conn, bytes.NewBuffer(pre)}
addr, err := readProxyProtoHeader(conn)
addr, proxyPre, err := readProxyProtoHeader(conn)
if err != nil && err != errProxyProtoUnrecognized {
// err != errProxyProtoUnrecognized implies that we detected a proxy
// protocol header but we failed to parse it, so don't continue.
@@ -3480,7 +3480,7 @@ func (s *Server) createClientEx(conn net.Conn, inProcess bool) *client {
// that it's a non-proxied connection and we want the pre-read to remain
// for the next step.
if err == nil {
pre = nil
pre = proxyPre
}
// Because we have ProxyProtocol enabled, our earlier INFO message didn't
// include the client_ip. If we need to send it again then we will include

View File

@@ -42,7 +42,6 @@ type winServiceWrapper struct {
}
var dockerized = false
var startupDelay = 10 * time.Second
func init() {
if v, exists := os.LookupEnv("NATS_DOCKERIZED"); exists && v == "1" {
@@ -67,6 +66,7 @@ func (w *winServiceWrapper) Execute(args []string, changes <-chan svc.ChangeRequ
status <- svc.Status{State: svc.StartPending}
go w.server.Start()
var startupDelay = 10 * time.Second
if v, exists := os.LookupEnv("NATS_STARTUP_DELAY"); exists {
if delay, err := time.ParseDuration(v); err == nil {
startupDelay = delay
@@ -86,24 +86,32 @@ func (w *winServiceWrapper) Execute(args []string, changes <-chan svc.ChangeRequ
}
loop:
for change := range changes {
switch change.Cmd {
case svc.Interrogate:
status <- change.CurrentStatus
case svc.Stop, svc.Shutdown:
w.server.Shutdown()
break loop
case reopenLogCmd:
// File log re-open for rotating file logs.
w.server.ReOpenLogFile()
case ldmCmd:
go w.server.lameDuckMode()
case svc.ParamChange:
if err := w.server.Reload(); err != nil {
w.server.Errorf("Failed to reload server configuration: %s", err)
for {
select {
case change, ok := <-changes:
if !ok {
break loop
}
default:
w.server.Debugf("Unexpected control request: %v", change.Cmd)
switch change.Cmd {
case svc.Interrogate:
status <- change.CurrentStatus
case svc.Stop, svc.Shutdown:
w.server.Shutdown()
break loop
case reopenLogCmd:
// File log re-open for rotating file logs.
w.server.ReOpenLogFile()
case ldmCmd:
go w.server.lameDuckMode()
case svc.ParamChange:
if err := w.server.Reload(); err != nil {
w.server.Errorf("Failed to reload server configuration: %s", err)
}
default:
w.server.Debugf("Unexpected control request: %v", change.Cmd)
}
case <-w.server.quitCh:
break loop
}
}

View File

@@ -1830,7 +1830,7 @@ func (s *Server) checkStreamCfg(config *StreamConfig, acc *Account, pedantic boo
}
}
// check for duplicates
// check sources for duplicates
var iNames = make(map[string]struct{})
for _, src := range cfg.Sources {
if src == nil || !isValidName(src.Name) {
@@ -2235,10 +2235,12 @@ func (jsa *jsAccount) configUpdateCheck(old, new *StreamConfig, s *Server, pedan
_, reserved = js.tieredStreamAndReservationCount(acc.Name, tier, &cfg)
}
// reservation does not account for this stream, hence add the old value
if tier == _EMPTY_ && old.Replicas > 1 {
reserved += old.MaxBytes * int64(old.Replicas)
} else {
reserved += old.MaxBytes
if old.MaxBytes > 0 {
if tier == _EMPTY_ && old.Replicas > 1 {
reserved = addSaturate(reserved, mulSaturate(int64(old.Replicas), old.MaxBytes))
} else {
reserved = addSaturate(reserved, old.MaxBytes)
}
}
if err := js.checkAllLimits(&selected, &cfg, reserved, maxBytesOffset); err != nil {
return nil, err
@@ -2807,6 +2809,12 @@ func (mset *stream) processMirrorMsgs(mirror *sourceInfo, ready *sync.WaitGroup)
// Grab stream quit channel.
mset.mu.Lock()
msgs, qch, siqch := mirror.msgs, mset.qch, mirror.qch
// If the mirror was already canceled before we got here, exit early.
if siqch == nil {
mset.mu.Unlock()
ready.Done()
return
}
// Set the last seen as now so that we don't fail at the first check.
mirror.last.Store(time.Now().UnixNano())
mset.mu.Unlock()
@@ -3412,6 +3420,7 @@ func (mset *stream) setupMirrorConsumer() error {
"consumer": mirror.cname,
},
) {
mirror.wg.Done()
ready.Done()
}
}
@@ -3974,7 +3983,6 @@ func (mset *stream) processInboundSourceMsg(si *sourceInfo, m *inMsg) bool {
} else {
err = mset.processJetStreamMsg(m.subj, _EMPTY_, hdr, msg, 0, 0, nil, true, true)
}
if err != nil {
s := mset.srv
if strings.Contains(err.Error(), "no space left") {
@@ -3984,31 +3992,35 @@ func (mset *stream) processInboundSourceMsg(si *sourceInfo, m *inMsg) bool {
mset.mu.RLock()
accName, sname, iName := mset.acc.Name, mset.cfg.Name, si.iname
mset.mu.RUnlock()
// Can happen temporarily all the time during normal operations when the sourcing stream
// is working queue/interest with a limit and discard new.
// TODO - Improve sourcing to WQ with limit and new to use flow control rather than re-creating the consumer.
if errors.Is(err, ErrMaxMsgs) || errors.Is(err, ErrMaxBytes) {
// Can happen temporarily all the time during normal operations when the sourcing stream is discard new
// (example use case is for sourcing into a work queue)
// TODO - Maybe improve sourcing to WQ with limit and new to use flow control rather than re-creating the consumer.
if errors.Is(err, ErrMaxMsgs) || errors.Is(err, ErrMaxBytes) || errors.Is(err, ErrMaxMsgsPerSubject) {
// Do not need to do a full retry that includes finding the last sequence in the stream
// for that source. Just re-create starting with the seq we couldn't store instead.
mset.mu.Lock()
mset.retrySourceConsumerAtSeq(iName, si.sseq)
mset.mu.Unlock()
} else {
// Log some warning for errors other than errLastSeqMismatch or errMaxMsgs.
if !errors.Is(err, errLastSeqMismatch) {
// Log some warning for errors other than errLastSeqMismatch.
if !errors.Is(err, errLastSeqMismatch) && !errors.Is(err, errMsgIdDuplicate) {
s.RateLimitWarnf("Error processing inbound source %q for '%s' > '%s': %v",
iName, accName, sname, err)
}
// Retry in all type of errors if we are still leader.
// Retry in all type of errors we do not want to skip if we are still leader.
if mset.isLeader() {
// This will make sure the source is still in mset.sources map,
// find the last sequence and then call setupSourceConsumer.
iNameMap := map[string]struct{}{iName: {}}
mset.setStartingSequenceForSources(iNameMap)
mset.mu.Lock()
mset.retrySourceConsumerAtSeq(iName, si.sseq+1)
mset.mu.Unlock()
if !errors.Is(err, errMsgIdDuplicate) {
// This will make sure the source is still in mset.sources map,
// find the last sequence and then call setupSourceConsumer.
iNameMap := map[string]struct{}{iName: {}}
mset.setStartingSequenceForSources(iNameMap)
mset.mu.Lock()
mset.retrySourceConsumerAtSeq(iName, si.sseq+1)
mset.mu.Unlock()
} else {
// skipping the message but keep processing the rest of the batch
return true
}
}
}
}
@@ -5403,6 +5415,7 @@ func (mset *stream) getDirectRequest(req *JSApiMsgGetRequest, reply string) {
// processInboundJetStreamMsg handles processing messages bound for a stream.
func (mset *stream) processInboundJetStreamMsg(_ *subscription, c *client, _ *Account, subject, reply string, rmsg []byte) {
hdr, msg := c.msgParts(copyBytes(rmsg)) // Need to copy.
hdr = removeHeaderStatusIfPresent(hdr)
if mt, traceOnly := c.isMsgTraceEnabled(); mt != nil {
// If message is delivered, we need to disable the message trace headers
// to prevent a trace event to be generated when a stored message

View File

@@ -20,6 +20,7 @@ import (
"errors"
"fmt"
"math"
"math/bits"
"net"
"net/url"
"reflect"
@@ -363,3 +364,23 @@ func parallelTaskQueue(mp int) chan<- func() {
}
return tq
}
// addSaturate returns a + b, saturating at math.MaxInt64.
// Both a and b must be non-negative.
func addSaturate(a, b int64) int64 {
sum, carry := bits.Add64(uint64(a), uint64(b), 0)
if carry != 0 || sum > uint64(math.MaxInt64) {
return math.MaxInt64
}
return int64(sum)
}
// mulSaturate returns a * b, saturating at math.MaxInt64.
// Both a and b must be non-negative.
func mulSaturate(a, b int64) int64 {
hi, lo := bits.Mul64(uint64(a), uint64(b))
if hi != 0 || lo > uint64(math.MaxInt64) {
return math.MaxInt64
}
return int64(lo)
}

View File

@@ -60,6 +60,8 @@ const (
wsMaxControlPayloadSize = 125
wsFrameSizeForBrowsers = 4096 // From experiment, web browsers behave better with limited frame size
wsCompressThreshold = 64 // Don't compress for small buffer(s)
wsMaxMsgPayloadMultiple = 8
wsMaxMsgPayloadLimit = 64 * 1024 * 1024
wsCloseSatusSize = 2
// From https://tools.ietf.org/html/rfc6455#section-11.7
@@ -180,6 +182,21 @@ func (r *wsReadInfo) resetCompressedState() {
r.csz = 0
}
// Compressed WebSocket messages have to be accumulated before they can be
// decompressed and handed to the parser, so this transport limit needs to
// allow batching several max_payload-sized NATS operations while still
// capping resource usage on the buffered compressed path.
func wsMaxMessageSize(mpay int) uint64 {
	payload := mpay
	if payload <= 0 {
		payload = MAX_PAYLOAD_SIZE
	}
	if limit := uint64(payload) * wsMaxMsgPayloadMultiple; limit <= wsMaxMsgPayloadLimit {
		return limit
	}
	return wsMaxMsgPayloadLimit
}
// Returns a slice containing `needed` bytes from the given buffer `buf`
// starting at position `pos`, and possibly read from the given reader `r`.
// When bytes are present in `buf`, the `pos` is incremented by the number
@@ -217,17 +234,44 @@ func (c *client) isWebsocket() bool {
// wsRead collects websocket frame payloads from the given buffer/reader and
// returns them as a list of byte slices. This variant rejects compressed
// frames; callers that negotiated permessage-deflate must use wsReadAndParse.
//
// Client lock MUST NOT be held on entry.
func (c *client) wsRead(r *wsReadInfo, ior io.Reader, buf []byte) ([][]byte, error) {
	var frames [][]byte
	collect := func(b []byte, compressed, _ bool) error {
		if compressed {
			return errors.New("compressed websocket frames require wsReadAndParse")
		}
		frames = append(frames, b)
		return nil
	}
	err := c.wsReadLoop(r, ior, buf, collect)
	return frames, err
}
// wsReadAndParse reads websocket frames from the given buffer/reader and
// feeds the decoded payloads straight into the protocol parser instead of
// collecting them. Compressed fragments are handed to wsDecompressAndParse,
// which accumulates them until the final frame; on error the accumulated
// compressed state is reset. r.fc is cleared once a compressed message has
// been fully consumed.
func (c *client) wsReadAndParse(r *wsReadInfo, ior io.Reader, buf []byte) error {
	limit := int(atomic.LoadInt32(&c.mpay))
	if limit <= 0 {
		limit = MAX_PAYLOAD_SIZE
	}
	deliver := func(b []byte, compressed, final bool) error {
		if !compressed {
			return c.parse(b)
		}
		if err := c.wsDecompressAndParse(r, b, final, limit); err != nil {
			r.resetCompressedState()
			return err
		}
		if final {
			r.fc = false
		}
		return nil
	}
	return c.wsReadLoop(r, ior, buf, deliver)
}
func (c *client) wsReadLoop(r *wsReadInfo, ior io.Reader, buf []byte, handle func([]byte, bool, bool) error) error {
var (
bufs [][]byte
tmpBuf []byte
err error
pos uint64
max = uint64(len(buf))
mpay = int(atomic.LoadInt32(&c.mpay))
)
if mpay <= 0 {
mpay = MAX_PAYLOAD_SIZE
}
for pos != max {
if r.fs {
b0 := buf[pos]
@@ -235,23 +279,23 @@ func (c *client) wsRead(r *wsReadInfo, ior io.Reader, buf []byte) ([][]byte, err
final := b0&wsFinalBit != 0
compressed := b0&wsRsv1Bit != 0
if b0&(wsRsv2Bit|wsRsv3Bit) != 0 {
return bufs, c.wsHandleProtocolError("RSV2 and RSV3 must be clear")
return c.wsHandleProtocolError("RSV2 and RSV3 must be clear")
}
if compressed && !c.ws.compress {
return bufs, c.wsHandleProtocolError("compressed frame received without negotiated permessage-deflate")
return c.wsHandleProtocolError("compressed frame received without negotiated permessage-deflate")
}
pos++
tmpBuf, pos, err = wsGet(ior, buf, pos, 1)
if err != nil {
return bufs, err
return err
}
b1 := tmpBuf[0]
// Clients MUST set the mask bit. If not set, reject.
// However, LEAF by default will not have masking, unless they are forced to, by configuration.
if r.mask && b1&wsMaskBit == 0 {
return bufs, c.wsHandleProtocolError("mask bit missing")
return c.wsHandleProtocolError("mask bit missing")
}
// Store size in case it is < 125
@@ -260,46 +304,46 @@ func (c *client) wsRead(r *wsReadInfo, ior io.Reader, buf []byte) ([][]byte, err
switch frameType {
case wsPingMessage, wsPongMessage, wsCloseMessage:
if r.rem > wsMaxControlPayloadSize {
return bufs, c.wsHandleProtocolError(
return c.wsHandleProtocolError(
fmt.Sprintf("control frame length bigger than maximum allowed of %v bytes",
wsMaxControlPayloadSize))
}
if !final {
return bufs, c.wsHandleProtocolError("control frame does not have final bit set")
return c.wsHandleProtocolError("control frame does not have final bit set")
}
if compressed {
return bufs, c.wsHandleProtocolError("control frame must not be compressed")
return c.wsHandleProtocolError("control frame must not be compressed")
}
case wsTextMessage, wsBinaryMessage:
if !r.ff {
return bufs, c.wsHandleProtocolError("new message started before final frame for previous message was received")
return c.wsHandleProtocolError("new message started before final frame for previous message was received")
}
r.ff = final
r.fc = compressed
case wsContinuationFrame:
// Compressed bit must be only set in the first frame
if r.ff || compressed {
return bufs, c.wsHandleProtocolError("invalid continuation frame")
return c.wsHandleProtocolError("invalid continuation frame")
}
r.ff = final
default:
return bufs, c.wsHandleProtocolError(fmt.Sprintf("unknown opcode %v", frameType))
return c.wsHandleProtocolError(fmt.Sprintf("unknown opcode %v", frameType))
}
switch r.rem {
case 126:
tmpBuf, pos, err = wsGet(ior, buf, pos, 2)
if err != nil {
return bufs, err
return err
}
r.rem = uint64(binary.BigEndian.Uint16(tmpBuf))
case 127:
tmpBuf, pos, err = wsGet(ior, buf, pos, 8)
if err != nil {
return bufs, err
return err
}
if r.rem = binary.BigEndian.Uint64(tmpBuf); r.rem&(uint64(1)<<63) != 0 {
return bufs, c.wsHandleProtocolError("invalid 64-bit payload length")
return c.wsHandleProtocolError("invalid 64-bit payload length")
}
}
@@ -307,7 +351,7 @@ func (c *client) wsRead(r *wsReadInfo, ior io.Reader, buf []byte) ([][]byte, err
// Read masking key
tmpBuf, pos, err = wsGet(ior, buf, pos, 4)
if err != nil {
return bufs, err
return err
}
copy(r.mkey[:], tmpBuf)
r.mkpos = 0
@@ -317,7 +361,7 @@ func (c *client) wsRead(r *wsReadInfo, ior io.Reader, buf []byte) ([][]byte, err
if wsIsControlFrame(frameType) {
pos, err = c.wsHandleControlFrame(r, frameType, ior, buf, pos)
if err != nil {
return bufs, err
return err
}
continue
}
@@ -326,59 +370,26 @@ func (c *client) wsRead(r *wsReadInfo, ior io.Reader, buf []byte) ([][]byte, err
r.fs = false
}
if pos < max {
var b []byte
var n uint64
n = r.rem
n := r.rem
if pos+n > max {
n = max - pos
}
b = buf[pos : pos+n]
b := buf[pos : pos+n]
pos += n
r.rem -= n
// If needed, unmask the buffer
if r.mask {
r.unmask(b)
}
addToBufs := true
// Handle compressed message
if r.fc {
// Assume that we may have continuation frames or not the full payload.
addToBufs = false
if r.csz+uint64(len(b)) > uint64(mpay) {
r.resetCompressedState()
return bufs, ErrMaxPayload
}
// Make a copy of the buffer before adding it to the list
// of compressed fragments.
r.cbufs = append(r.cbufs, append([]byte(nil), b...))
r.csz += uint64(len(b))
// When we have the final frame and we have read the full payload,
// we can decompress it.
if r.ff && r.rem == 0 {
b, err = r.decompress(mpay)
if err != nil {
r.resetCompressedState()
return bufs, err
}
r.fc = false
// Now we can add to `bufs`
addToBufs = true
}
if err := handle(b, r.fc, r.ff && r.rem == 0); err != nil {
return err
}
// For non compressed frames, or when we have decompressed the
// whole message.
if addToBufs {
bufs = append(bufs, b)
}
// If payload has been fully read, then indicate that next
// is the start of a frame.
if r.rem == 0 {
r.fs = true
}
}
}
return bufs, nil
return nil
}
func (r *wsReadInfo) Read(dst []byte) (int, error) {
@@ -434,45 +445,65 @@ func (r *wsReadInfo) ReadByte() (byte, error) {
return b, nil
}
// decompress decompresses the collected buffers.
// The size of the decompressed buffer will be limited to the `mpay` value.
// If, while decompressing, the resulting uncompressed buffer exceeds this
// limit, the decompression stops and an empty buffer and the ErrMaxPayload
// error are returned.
func (r *wsReadInfo) decompress(mpay int) ([]byte, error) {
// If not limit is specified, use the default maximum payload size.
if mpay <= 0 {
mpay = MAX_PAYLOAD_SIZE
func (c *client) wsDecompressAndParse(r *wsReadInfo, b []byte, final bool, mpay int) error {
limit := wsMaxMessageSize(mpay)
if len(b) > 0 {
if r.csz+uint64(len(b)) > limit {
return ErrMaxPayload
}
r.cbufs = append(r.cbufs, append([]byte(nil), b...))
r.csz += uint64(len(b))
}
if !final {
return nil
}
if r.csz+uint64(len(compressLastBlock)) > limit {
return ErrMaxPayload
}
r.coff = 0
// As per https://tools.ietf.org/html/rfc7692#section-7.2.2
// add 0x00, 0x00, 0xff, 0xff and then a final block so that flate reader
// does not report unexpected EOF.
r.cbufs = append(r.cbufs, compressLastBlock)
// Get a decompressor from the pool and bind it to this object (wsReadInfo)
// that provides Read() and ReadByte() APIs that will consume the compressed
// buffers (r.cbufs).
r.csz += uint64(len(compressLastBlock))
r.coff = 0
d, _ := decompressorPool.Get().(io.ReadCloser)
if d == nil {
d = flate.NewReader(r)
} else {
d.(flate.Resetter).Reset(r, nil)
}
// Use a LimitedReader to limit the decompressed size.
// We use "limit+1" bytes for "N" so we can detect if the limit is exceeded.
defer func() {
d.Close()
decompressorPool.Put(d)
r.cbufs = nil
r.coff = 0
r.csz = 0
}()
lr := io.LimitedReader{R: d, N: int64(mpay + 1)}
b, err := io.ReadAll(&lr)
if err == nil && len(b) > mpay {
// Decompressed data exceeds the maximum payload size.
b, err = nil, ErrMaxPayload
buf := make([]byte, 32*1024)
total := 0
for {
n, err := lr.Read(buf)
if n > 0 {
pn := n
if total+n > mpay {
pn = mpay - total
}
if pn > 0 {
if err := c.parse(buf[:pn]); err != nil {
return err
}
}
total += n
if total > mpay {
return ErrMaxPayload
}
}
if err == nil {
continue
}
if err == io.EOF {
return nil
}
return err
}
lr.R = nil
decompressorPool.Put(d)
// Now reset the compressed buffers list.
r.cbufs = nil
r.coff = 0
r.csz = 0
return b, err
}
// Handles the PING, PONG and CLOSE websocket control frames.

6
vendor/modules.txt vendored
View File

@@ -1154,10 +1154,10 @@ github.com/mschoch/smat
# github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822
## explicit
github.com/munnerz/goautoneg
# github.com/nats-io/jwt/v2 v2.8.0
## explicit; go 1.23.0
# github.com/nats-io/jwt/v2 v2.8.1
## explicit; go 1.25.0
github.com/nats-io/jwt/v2
# github.com/nats-io/nats-server/v2 v2.12.5
# github.com/nats-io/nats-server/v2 v2.12.6
## explicit; go 1.25.0
github.com/nats-io/nats-server/v2/conf
github.com/nats-io/nats-server/v2/internal/fastrand