From 00d49804cf9be771d9ab7eaaf4730d2be0effa9f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Andr=C3=A9=20Duffeck?= Date: Fri, 21 Mar 2025 12:29:25 +0100 Subject: [PATCH] Bump reva to pull in the latest fixes --- go.mod | 9 +- go.sum | 20 +- vendor/github.com/google/go-tpm/LICENSE | 202 ++ .../google/go-tpm/legacy/tpm2/README.md | 35 + .../google/go-tpm/legacy/tpm2/constants.go | 575 ++++ .../google/go-tpm/legacy/tpm2/error.go | 362 +++ .../google/go-tpm/legacy/tpm2/kdf.go | 116 + .../google/go-tpm/legacy/tpm2/open_other.go | 57 + .../google/go-tpm/legacy/tpm2/open_windows.go | 39 + .../google/go-tpm/legacy/tpm2/structures.go | 1112 ++++++++ .../google/go-tpm/legacy/tpm2/tpm2.go | 2326 +++++++++++++++++ .../google/go-tpm/tpmutil/encoding.go | 211 ++ .../google/go-tpm/tpmutil/poll_other.go | 10 + .../google/go-tpm/tpmutil/poll_unix.go | 32 + .../github.com/google/go-tpm/tpmutil/run.go | 113 + .../google/go-tpm/tpmutil/run_other.go | 111 + .../google/go-tpm/tpmutil/run_windows.go | 84 + .../google/go-tpm/tpmutil/structures.go | 195 ++ .../google/go-tpm/tpmutil/tbs/tbs_windows.go | 267 ++ .../nats-io/nats-server/v2/conf/fuzz.go | 1 - .../nats-io/nats-server/v2/conf/parse.go | 55 + .../nats-io/nats-server/v2/logger/syslog.go | 1 - .../nats-io/nats-server/v2/server/README.md | 2 +- .../nats-io/nats-server/v2/server/accounts.go | 207 +- .../nats-io/nats-server/v2/server/auth.go | 187 +- .../nats-io/nats-server/v2/server/client.go | 338 ++- .../nats-io/nats-server/v2/server/const.go | 7 +- .../nats-io/nats-server/v2/server/consumer.go | 602 ++++- .../nats-server/v2/server/disk_avail.go | 1 - .../v2/server/disk_avail_netbsd.go | 1 - .../v2/server/disk_avail_openbsd.go | 1 - .../nats-server/v2/server/disk_avail_wasm.go | 1 - .../v2/server/disk_avail_windows.go | 1 - .../nats-io/nats-server/v2/server/errors.go | 10 +- .../nats-io/nats-server/v2/server/errors.json | 150 +- .../nats-io/nats-server/v2/server/events.go | 106 +- .../nats-server/v2/server/filestore.go | 673 ++++- 
.../nats-io/nats-server/v2/server/fuzz.go | 1 - .../nats-io/nats-server/v2/server/gateway.go | 35 +- .../nats-io/nats-server/v2/server/ipqueue.go | 167 +- .../nats-server/v2/server/jetstream.go | 68 +- .../nats-server/v2/server/jetstream_api.go | 899 +++++-- .../v2/server/jetstream_cluster.go | 550 ++-- .../v2/server/jetstream_errors_generated.go | 290 +- .../nats-server/v2/server/jetstream_events.go | 39 + .../v2/server/jetstream_versioning.go | 179 ++ .../nats-io/nats-server/v2/server/jwt.go | 6 +- .../nats-io/nats-server/v2/server/leafnode.go | 126 +- .../nats-io/nats-server/v2/server/memstore.go | 426 ++- .../nats-io/nats-server/v2/server/monitor.go | 108 +- .../nats-io/nats-server/v2/server/mqtt.go | 423 ++- .../nats-io/nats-server/v2/server/msgtrace.go | 846 ++++++ .../nats-io/nats-server/v2/server/opts.go | 557 +++- .../nats-io/nats-server/v2/server/parser.go | 24 +- .../nats-io/nats-server/v2/server/proto.go | 269 ++ .../nats-server/v2/server/pse/pse_freebsd.go | 1 - .../nats-server/v2/server/pse/pse_rumprun.go | 1 - .../nats-server/v2/server/pse/pse_wasm.go | 1 - .../nats-server/v2/server/pse/pse_windows.go | 1 - .../nats-server/v2/server/pse/pse_zos.go | 1 - .../nats-io/nats-server/v2/server/raft.go | 266 +- .../nats-io/nats-server/v2/server/reload.go | 19 +- .../nats-io/nats-server/v2/server/route.go | 282 +- .../nats-io/nats-server/v2/server/sendq.go | 7 +- .../nats-io/nats-server/v2/server/server.go | 135 +- .../nats-io/nats-server/v2/server/service.go | 1 - .../nats-io/nats-server/v2/server/signal.go | 5 +- .../nats-server/v2/server/signal_wasm.go | 1 - .../nats-io/nats-server/v2/server/store.go | 15 +- .../nats-io/nats-server/v2/server/stream.go | 800 +++++- .../v2/server/subject_transform.go | 11 +- .../nats-io/nats-server/v2/server/sublist.go | 14 +- .../nats-server/v2/server/sysmem/mem_bsd.go | 1 - .../v2/server/sysmem/mem_darwin.go | 1 - .../nats-server/v2/server/sysmem/mem_linux.go | 1 - .../nats-server/v2/server/sysmem/mem_wasm.go | 1 - 
.../v2/server/sysmem/mem_windows.go | 1 - .../nats-server/v2/server/sysmem/mem_zos.go | 1 - .../nats-server/v2/server/sysmem/sysctl.go | 1 - .../nats-io/nats-server/v2/server/thw/thw.go | 257 ++ .../v2/server/tpm/js_ek_tpm_other.go | 23 + .../v2/server/tpm/js_ek_tpm_windows.go | 281 ++ .../nats-io/nats-server/v2/server/util.go | 5 +- .../nats-server/v2/server/websocket.go | 85 +- vendor/github.com/onsi/ginkgo/v2/CHANGELOG.md | 12 + .../ginkgo/v2/ginkgo/build/build_command.go | 2 +- .../onsi/ginkgo/v2/ginkgo/command/command.go | 2 +- .../onsi/ginkgo/v2/ginkgo/internal/compile.go | 8 +- .../onsi/ginkgo/v2/ginkgo/run/run_command.go | 2 +- .../ginkgo/v2/ginkgo/watch/watch_command.go | 2 +- .../github.com/onsi/ginkgo/v2/types/config.go | 10 +- .../github.com/onsi/ginkgo/v2/types/errors.go | 7 + .../onsi/ginkgo/v2/types/version.go | 2 +- .../v2/pkg/storage/fs/posix/lookup/lookup.go | 8 +- .../reva/v2/pkg/storage/fs/posix/posix.go | 2 +- .../v2/pkg/storage/fs/posix/tree/revisions.go | 40 +- .../reva/v2/pkg/storage/fs/posix/tree/tree.go | 2 +- .../decomposedfs/metadata/hybrid_backend.go | 48 +- .../metadata/messagepack_backend.go | 14 +- .../decomposedfs/metadata/xattrs_backend.go | 2 +- .../pkg/decomposedfs/tree/revisions.go | 50 +- vendor/golang.org/x/time/rate/rate.go | 17 +- vendor/modules.txt | 17 +- 103 files changed, 14058 insertions(+), 1641 deletions(-) create mode 100644 vendor/github.com/google/go-tpm/LICENSE create mode 100644 vendor/github.com/google/go-tpm/legacy/tpm2/README.md create mode 100644 vendor/github.com/google/go-tpm/legacy/tpm2/constants.go create mode 100644 vendor/github.com/google/go-tpm/legacy/tpm2/error.go create mode 100644 vendor/github.com/google/go-tpm/legacy/tpm2/kdf.go create mode 100644 vendor/github.com/google/go-tpm/legacy/tpm2/open_other.go create mode 100644 vendor/github.com/google/go-tpm/legacy/tpm2/open_windows.go create mode 100644 vendor/github.com/google/go-tpm/legacy/tpm2/structures.go create mode 100644 
vendor/github.com/google/go-tpm/legacy/tpm2/tpm2.go create mode 100644 vendor/github.com/google/go-tpm/tpmutil/encoding.go create mode 100644 vendor/github.com/google/go-tpm/tpmutil/poll_other.go create mode 100644 vendor/github.com/google/go-tpm/tpmutil/poll_unix.go create mode 100644 vendor/github.com/google/go-tpm/tpmutil/run.go create mode 100644 vendor/github.com/google/go-tpm/tpmutil/run_other.go create mode 100644 vendor/github.com/google/go-tpm/tpmutil/run_windows.go create mode 100644 vendor/github.com/google/go-tpm/tpmutil/structures.go create mode 100644 vendor/github.com/google/go-tpm/tpmutil/tbs/tbs_windows.go create mode 100644 vendor/github.com/nats-io/nats-server/v2/server/jetstream_versioning.go create mode 100644 vendor/github.com/nats-io/nats-server/v2/server/msgtrace.go create mode 100644 vendor/github.com/nats-io/nats-server/v2/server/proto.go create mode 100644 vendor/github.com/nats-io/nats-server/v2/server/thw/thw.go create mode 100644 vendor/github.com/nats-io/nats-server/v2/server/tpm/js_ek_tpm_other.go create mode 100644 vendor/github.com/nats-io/nats-server/v2/server/tpm/js_ek_tpm_windows.go diff --git a/go.mod b/go.mod index 64199647ee..83870ea959 100644 --- a/go.mod +++ b/go.mod @@ -55,15 +55,15 @@ require ( github.com/mitchellh/mapstructure v1.5.0 github.com/mna/pigeon v1.3.0 github.com/mohae/deepcopy v0.0.0-20170929034955-c48cc78d4826 - github.com/nats-io/nats-server/v2 v2.10.26 + github.com/nats-io/nats-server/v2 v2.11.0 github.com/nats-io/nats.go v1.39.1 github.com/oklog/run v1.1.0 github.com/olekukonko/tablewriter v0.0.5 github.com/onsi/ginkgo v1.16.5 - github.com/onsi/ginkgo/v2 v2.23.1 + github.com/onsi/ginkgo/v2 v2.23.2 github.com/onsi/gomega v1.36.2 github.com/open-policy-agent/opa v1.2.0 - github.com/opencloud-eu/reva/v2 v2.28.1-0.20250320135948-a946c0d6d289 + github.com/opencloud-eu/reva/v2 v2.28.1-0.20250321112659-61a430bfb4c5 github.com/orcaman/concurrent-map v1.0.0 github.com/owncloud/libre-graph-api-go 
v1.0.5-0.20240829135935-80dc00d6f5ea github.com/pkg/errors v0.9.1 @@ -215,6 +215,7 @@ require ( github.com/golang/snappy v0.0.4 // indirect github.com/gomodule/redigo v1.9.2 // indirect github.com/google/go-querystring v1.1.0 // indirect + github.com/google/go-tpm v0.9.3 // indirect github.com/google/pprof v0.0.0-20241210010833-40e02aabc2ad // indirect github.com/google/renameio/v2 v2.0.0 // indirect github.com/gookit/color v1.5.4 // indirect @@ -321,7 +322,7 @@ require ( go.uber.org/zap v1.23.0 // indirect golang.org/x/mod v0.24.0 // indirect golang.org/x/sys v0.31.0 // indirect - golang.org/x/time v0.10.0 // indirect + golang.org/x/time v0.11.0 // indirect golang.org/x/tools v0.31.0 // indirect golang.org/x/xerrors v0.0.0-20220907171357-04be3eba64a2 // indirect google.golang.org/genproto v0.0.0-20241118233622-e639e219e697 // indirect diff --git a/go.sum b/go.sum index 8c4bfcb58d..8da190d312 100644 --- a/go.sum +++ b/go.sum @@ -111,6 +111,8 @@ github.com/amoghe/go-crypt v0.0.0-20220222110647-20eada5f5964 h1:I9YN9WMo3SUh7p/ github.com/amoghe/go-crypt v0.0.0-20220222110647-20eada5f5964/go.mod h1:eFiR01PwTcpbzXtdMces7zxg6utvFM5puiWHpWB8D/k= github.com/anmitsu/go-shlex v0.0.0-20200514113438-38f4b401e2be h1:9AeTilPcZAjCFIImctFaOjnTIavg87rW78vTPkQqLI8= github.com/anmitsu/go-shlex v0.0.0-20200514113438-38f4b401e2be/go.mod h1:ySMOLuWl6zY27l47sB3qLNK6tF2fkHG55UZxx8oIVo4= +github.com/antithesishq/antithesis-sdk-go v0.4.3-default-no-op h1:+OSa/t11TFhqfrX0EOSqQBDJ0YlpmK0rDSiB19dg9M0= +github.com/antithesishq/antithesis-sdk-go v0.4.3-default-no-op/go.mod h1:IUpT2DPAKh6i/YhSbt6Gl3v2yvUZjmKncl7U91fup7E= github.com/apache/thrift v0.12.0/go.mod h1:cp2SuWMxlEZw2r+iP2GNCdIi4C1qmUzdZFSVb+bacwQ= github.com/arbovm/levenshtein v0.0.0-20160628152529-48b4e1c0c4d0 h1:jfIu9sQUG6Ig+0+Ap1h4unLjW6YQJpKZVmUzxsD4E/Q= github.com/arbovm/levenshtein v0.0.0-20160628152529-48b4e1c0c4d0/go.mod h1:t2tdKJDJF9BV14lnkjHmOQgcvEKgtqs5a1N3LNdJhGE= @@ -526,6 +528,8 @@ github.com/google/go-querystring v1.1.0 
h1:AnCroh3fv4ZBgVIf1Iwtovgjaw/GiKJo8M8yD github.com/google/go-querystring v1.1.0/go.mod h1:Kcdr2DB4koayq7X8pmAG4sNG59So17icRSOU623lUBU= github.com/google/go-tika v0.3.1 h1:l+jr10hDhZjcgxFRfcQChRLo1bPXQeLFluMyvDhXTTA= github.com/google/go-tika v0.3.1/go.mod h1:DJh5N8qxXIl85QkqmXknd+PeeRkUOTbvwyYf7ieDz6c= +github.com/google/go-tpm v0.9.3 h1:+yx0/anQuGzi+ssRqeD6WpXjW2L/V0dItUayO0i9sRc= +github.com/google/go-tpm v0.9.3/go.mod h1:h9jEsEECg7gtLis0upRBQU+GhYVH6jMjrFxI8u6bVUY= github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= github.com/google/martian v2.1.0+incompatible/go.mod h1:9I4somxYTbIHy5NJKHRl3wXiIaQGbYVAs8BPL6v8lEs= github.com/google/martian/v3 v3.0.0/go.mod h1:y5Zk1BBys9G+gd6Jrk0W3cC1+ELVxBWuIGO+w/tUAp0= @@ -823,8 +827,8 @@ github.com/mwitkow/go-conntrack v0.0.0-20190716064945-2f068394615f/go.mod h1:qRW github.com/namedotcom/go v0.0.0-20180403034216-08470befbe04/go.mod h1:5sN+Lt1CaY4wsPvgQH/jsuJi4XO2ssZbdsIizr4CVC8= github.com/nats-io/jwt/v2 v2.7.3 h1:6bNPK+FXgBeAqdj4cYQ0F8ViHRbi7woQLq4W29nUAzE= github.com/nats-io/jwt/v2 v2.7.3/go.mod h1:GvkcbHhKquj3pkioy5put1wvPxs78UlZ7D/pY+BgZk4= -github.com/nats-io/nats-server/v2 v2.10.26 h1:2i3rAsn4x5/2eOt2NEmuI/iSb8zfHpIUI7yiaOWbo2c= -github.com/nats-io/nats-server/v2 v2.10.26/go.mod h1:SGzoWGU8wUVnMr/HJhEMv4R8U4f7hF4zDygmRxpNsvg= +github.com/nats-io/nats-server/v2 v2.11.0 h1:fdwAT1d6DZW/4LUz5rkvQUe5leGEwjjOQYntzVRKvjE= +github.com/nats-io/nats-server/v2 v2.11.0/go.mod h1:leXySghbdtXSUmWem8K9McnJ6xbJOb0t9+NQ5HTRZjI= github.com/nats-io/nats.go v1.39.1 h1:oTkfKBmz7W047vRxV762M67ZdXeOtUgvbBaNoQ+3PPk= github.com/nats-io/nats.go v1.39.1/go.mod h1:MgRb8oOdigA6cYpEPhXJuRVH6UE/V4jblJ2jQ27IXYM= github.com/nats-io/nkeys v0.4.10 h1:glmRrpCmYLHByYcePvnTBEAwawwapjCPMjy2huw20wc= @@ -852,8 +856,8 @@ github.com/onsi/ginkgo v1.7.0/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+W github.com/onsi/ginkgo v1.12.1/go.mod h1:zj2OWP4+oCPe1qIXoGWkgMRwljMUYCdkwsT2108oapk= github.com/onsi/ginkgo v1.16.5 
h1:8xi0RTUf59SOSfEtZMvwTvXYMzG4gV23XVHOZiXNtnE= github.com/onsi/ginkgo v1.16.5/go.mod h1:+E8gABHa3K6zRBolWtd+ROzc/U5bkGt0FwiG042wbpU= -github.com/onsi/ginkgo/v2 v2.23.1 h1:Ox0cOPv/t8RzKJUfDo9ZKtRvBOJY369sFJnl00CjqwY= -github.com/onsi/ginkgo/v2 v2.23.1/go.mod h1:zXTP6xIp3U8aVuXN8ENK9IXRaTjFnpVB9mGmaSRvxnM= +github.com/onsi/ginkgo/v2 v2.23.2 h1:LYLd7Wz401p0N7xR8y7WL6D2QZwKpbirDg0EVIvzvMM= +github.com/onsi/ginkgo/v2 v2.23.2/go.mod h1:zXTP6xIp3U8aVuXN8ENK9IXRaTjFnpVB9mGmaSRvxnM= github.com/onsi/gomega v1.4.3/go.mod h1:ex+gbHU/CVuBBDIJjb2X0qEXbFg53c61hWP/1CpauHY= github.com/onsi/gomega v1.7.1/go.mod h1:XdKZgCCFLUoM/7CFJVPcG8C1xQ1AJ0vpAezJrB7JYyY= github.com/onsi/gomega v1.10.1/go.mod h1:iN09h71vgCQne3DLsj+A5owkum+a2tYe+TOCB1ybHNo= @@ -861,8 +865,8 @@ github.com/onsi/gomega v1.36.2 h1:koNYke6TVk6ZmnyHrCXba/T/MoLBXFjeC1PtvYgw0A8= github.com/onsi/gomega v1.36.2/go.mod h1:DdwyADRjrc825LhMEkD76cHR5+pUnjhUN8GlHlRPHzY= github.com/open-policy-agent/opa v1.2.0 h1:88NDVCM0of1eO6Z4AFeL3utTEtMuwloFmWWU7dRV1z0= github.com/open-policy-agent/opa v1.2.0/go.mod h1:30euUmOvuBoebRCcJ7DMF42bRBOPznvt0ACUMYDUGVY= -github.com/opencloud-eu/reva/v2 v2.28.1-0.20250320135948-a946c0d6d289 h1:gg37XG4j3Y7yWLrD+B+2uNQ72g4YasdvpzOKJnuQH1Y= -github.com/opencloud-eu/reva/v2 v2.28.1-0.20250320135948-a946c0d6d289/go.mod h1:iK0tNdLgqK0zBi0l7Q4uWSn9GPUbYtNxz3YAMfYvYNg= +github.com/opencloud-eu/reva/v2 v2.28.1-0.20250321112659-61a430bfb4c5 h1:R2HXrbl4RP78Pgjs9d/djzzc9h7RrePjFRZnBdXHiFM= +github.com/opencloud-eu/reva/v2 v2.28.1-0.20250321112659-61a430bfb4c5/go.mod h1:6KR5qe5pUogF48rnybIksQcxWIACB2ISEqBa3kbZzZA= github.com/opentracing/opentracing-go v1.1.0/go.mod h1:UkNAQd3GIcIGf0SeVgPpRdFStlNbqXla1AfSYxPUl2o= github.com/opentracing/opentracing-go v1.2.0 h1:uEJPy/1a5RIPAJ0Ov+OIO8OxWu77jEv+1B0VhjKrZUs= github.com/opentracing/opentracing-go v1.2.0/go.mod h1:GxEUsuufX4nBwe+T+Wl9TAgYrxe9dPLANfrWvHYVTgc= @@ -1477,8 +1481,8 @@ golang.org/x/time v0.0.0-20191024005414-555d28b269f0/go.mod h1:tRJNPiyCQ0inRvYxb 
golang.org/x/time v0.0.0-20200630173020-3af7569d3a1e/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/time v0.0.0-20201208040808-7e3f01d25324/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/time v0.0.0-20210220033141-f8bda1e9f3ba/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= -golang.org/x/time v0.10.0 h1:3usCWA8tQn0L8+hFJQNgzpWbd89begxN66o1Ojdn5L4= -golang.org/x/time v0.10.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM= +golang.org/x/time v0.11.0 h1:/bpjEDfN9tkoN/ryeYHnv5hcMlc8ncjMcM4XBk5NWV0= +golang.org/x/time v0.11.0/go.mod h1:CDIdPxbZBQxdj6cxyCIdrNogrJKMJ7pr37NYpMcMDSg= golang.org/x/tools v0.0.0-20180221164845-07fd8470d635/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20180828015842-6cd1fcedba52/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= diff --git a/vendor/github.com/google/go-tpm/LICENSE b/vendor/github.com/google/go-tpm/LICENSE new file mode 100644 index 0000000000..d645695673 --- /dev/null +++ b/vendor/github.com/google/go-tpm/LICENSE @@ -0,0 +1,202 @@ + + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. 
For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. 
For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. 
If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. 
You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. 
Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. 
+ + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/vendor/github.com/google/go-tpm/legacy/tpm2/README.md b/vendor/github.com/google/go-tpm/legacy/tpm2/README.md new file mode 100644 index 0000000000..4d0ff8befa --- /dev/null +++ b/vendor/github.com/google/go-tpm/legacy/tpm2/README.md @@ -0,0 +1,35 @@ +# TPM 2.0 client library + +## Tests + +This library contains unit tests in `github.com/google/go-tpm/tpm2`, which just +tests that various encoding and error checking functions work correctly. It also +contains more comprehensive integration tests in +`github.com/google/go-tpm/tpm2/test`, which run actual commands on a TPM. + +By default, these integration tests are run against the +[`go-tpm-tools`](https://github.com/google/go-tpm-tools) +simulator, which is baesed on the +[Microsoft Reference TPM2 code](https://github.com/microsoft/ms-tpm-20-ref). To +run both the unit and integration tests, run (in this directory) +```bash +go test . ./test +``` + +These integration tests can also be run against a real TPM device. This is +slightly more complex as the tests often need to be built as a normal user and +then executed as root. 
For example, +```bash +# Build the test binary without running it +go test -c github.com/google/go-tpm/tpm2/test +# Execute the test binary as root +sudo ./test.test --tpm-path=/dev/tpmrm0 +``` +On Linux, The `--tpm-path` causes the integration tests to be run against a +real TPM located at that path (usually `/dev/tpmrm0` or `/dev/tpm0`). On Windows, the story is similar, execept that +the `--use-tbs` flag is used instead. + +Tip: if your TPM host is remote and you don't want to install Go on it, this +same two-step process can be used. The test binary can be copied to a remote +host and run without extra installation (as the test binary has very few +*runtime* dependancies). diff --git a/vendor/github.com/google/go-tpm/legacy/tpm2/constants.go b/vendor/github.com/google/go-tpm/legacy/tpm2/constants.go new file mode 100644 index 0000000000..2b0de54444 --- /dev/null +++ b/vendor/github.com/google/go-tpm/legacy/tpm2/constants.go @@ -0,0 +1,575 @@ +// Copyright (c) 2018, Google LLC All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tpm2 + +import ( + "crypto" + "crypto/elliptic" + "fmt" + "strings" + + // Register the relevant hash implementations to prevent a runtime failure. 
+ _ "crypto/sha1" + _ "crypto/sha256" + _ "crypto/sha512" + + "github.com/google/go-tpm/tpmutil" +) + +var hashInfo = []struct { + alg Algorithm + hash crypto.Hash +}{ + {AlgSHA1, crypto.SHA1}, + {AlgSHA256, crypto.SHA256}, + {AlgSHA384, crypto.SHA384}, + {AlgSHA512, crypto.SHA512}, + {AlgSHA3_256, crypto.SHA3_256}, + {AlgSHA3_384, crypto.SHA3_384}, + {AlgSHA3_512, crypto.SHA3_512}, +} + +// MAX_DIGEST_BUFFER is the maximum size of []byte request or response fields. +// Typically used for chunking of big blobs of data (such as for hashing or +// encryption). +const maxDigestBuffer = 1024 + +// Algorithm represents a TPM_ALG_ID value. +type Algorithm uint16 + +// HashToAlgorithm looks up the TPM2 algorithm corresponding to the provided crypto.Hash +func HashToAlgorithm(hash crypto.Hash) (Algorithm, error) { + for _, info := range hashInfo { + if info.hash == hash { + return info.alg, nil + } + } + return AlgUnknown, fmt.Errorf("go hash algorithm #%d has no TPM2 algorithm", hash) +} + +// IsNull returns true if a is AlgNull or zero (unset). +func (a Algorithm) IsNull() bool { + return a == AlgNull || a == AlgUnknown +} + +// UsesCount returns true if a signature algorithm uses count value. +func (a Algorithm) UsesCount() bool { + return a == AlgECDAA +} + +// UsesHash returns true if the algorithm requires the use of a hash. +func (a Algorithm) UsesHash() bool { + return a == AlgOAEP +} + +// Hash returns a crypto.Hash based on the given TPM_ALG_ID. +// An error is returned if the given algorithm is not a hash algorithm or is not available. 
+func (a Algorithm) Hash() (crypto.Hash, error) { + for _, info := range hashInfo { + if info.alg == a { + if !info.hash.Available() { + return crypto.Hash(0), fmt.Errorf("go hash algorithm #%d not available", info.hash) + } + return info.hash, nil + } + } + return crypto.Hash(0), fmt.Errorf("hash algorithm not supported: 0x%x", a) +} + +func (a Algorithm) String() string { + var s strings.Builder + var err error + switch a { + case AlgUnknown: + _, err = s.WriteString("AlgUnknown") + case AlgRSA: + _, err = s.WriteString("RSA") + case AlgSHA1: + _, err = s.WriteString("SHA1") + case AlgHMAC: + _, err = s.WriteString("HMAC") + case AlgAES: + _, err = s.WriteString("AES") + case AlgKeyedHash: + _, err = s.WriteString("KeyedHash") + case AlgXOR: + _, err = s.WriteString("XOR") + case AlgSHA256: + _, err = s.WriteString("SHA256") + case AlgSHA384: + _, err = s.WriteString("SHA384") + case AlgSHA512: + _, err = s.WriteString("SHA512") + case AlgNull: + _, err = s.WriteString("AlgNull") + case AlgRSASSA: + _, err = s.WriteString("RSASSA") + case AlgRSAES: + _, err = s.WriteString("RSAES") + case AlgRSAPSS: + _, err = s.WriteString("RSAPSS") + case AlgOAEP: + _, err = s.WriteString("OAEP") + case AlgECDSA: + _, err = s.WriteString("ECDSA") + case AlgECDH: + _, err = s.WriteString("ECDH") + case AlgECDAA: + _, err = s.WriteString("ECDAA") + case AlgKDF2: + _, err = s.WriteString("KDF2") + case AlgECC: + _, err = s.WriteString("ECC") + case AlgSymCipher: + _, err = s.WriteString("SymCipher") + case AlgSHA3_256: + _, err = s.WriteString("SHA3_256") + case AlgSHA3_384: + _, err = s.WriteString("SHA3_384") + case AlgSHA3_512: + _, err = s.WriteString("SHA3_512") + case AlgCTR: + _, err = s.WriteString("CTR") + case AlgOFB: + _, err = s.WriteString("OFB") + case AlgCBC: + _, err = s.WriteString("CBC") + case AlgCFB: + _, err = s.WriteString("CFB") + case AlgECB: + _, err = s.WriteString("ECB") + default: + return fmt.Sprintf("Alg?<%d>", int(a)) + } + if err != nil { + return 
fmt.Sprintf("Writing to string builder failed: %v", err) + } + return s.String() +} + +// Supported Algorithms. +const ( + AlgUnknown Algorithm = 0x0000 + AlgRSA Algorithm = 0x0001 + AlgSHA1 Algorithm = 0x0004 + AlgHMAC Algorithm = 0x0005 + AlgAES Algorithm = 0x0006 + AlgKeyedHash Algorithm = 0x0008 + AlgXOR Algorithm = 0x000A + AlgSHA256 Algorithm = 0x000B + AlgSHA384 Algorithm = 0x000C + AlgSHA512 Algorithm = 0x000D + AlgNull Algorithm = 0x0010 + AlgRSASSA Algorithm = 0x0014 + AlgRSAES Algorithm = 0x0015 + AlgRSAPSS Algorithm = 0x0016 + AlgOAEP Algorithm = 0x0017 + AlgECDSA Algorithm = 0x0018 + AlgECDH Algorithm = 0x0019 + AlgECDAA Algorithm = 0x001A + AlgKDF2 Algorithm = 0x0021 + AlgECC Algorithm = 0x0023 + AlgSymCipher Algorithm = 0x0025 + AlgSHA3_256 Algorithm = 0x0027 + AlgSHA3_384 Algorithm = 0x0028 + AlgSHA3_512 Algorithm = 0x0029 + AlgCTR Algorithm = 0x0040 + AlgOFB Algorithm = 0x0041 + AlgCBC Algorithm = 0x0042 + AlgCFB Algorithm = 0x0043 + AlgECB Algorithm = 0x0044 +) + +// HandleType defines a type of handle. +type HandleType uint8 + +// Supported handle types +const ( + HandleTypePCR HandleType = 0x00 + HandleTypeNVIndex HandleType = 0x01 + HandleTypeHMACSession HandleType = 0x02 + HandleTypeLoadedSession HandleType = 0x02 + HandleTypePolicySession HandleType = 0x03 + HandleTypeSavedSession HandleType = 0x03 + HandleTypePermanent HandleType = 0x40 + HandleTypeTransient HandleType = 0x80 + HandleTypePersistent HandleType = 0x81 +) + +// SessionType defines the type of session created in StartAuthSession. +type SessionType uint8 + +// Supported session types. +const ( + SessionHMAC SessionType = 0x00 + SessionPolicy SessionType = 0x01 + SessionTrial SessionType = 0x03 +) + +// SessionAttributes represents an attribute of a session. 
type SessionAttributes byte

// Session Attributes (Structures 8.4 TPMA_SESSION)
// Every named constant is a single bit. iota keeps advancing across the two
// reserved positions, so the later attributes land on bits 5, 6 and 7.
const (
	AttrContinueSession SessionAttributes = 1 << iota
	AttrAuditExclusive
	AttrAuditReset
	_ // bit 3 reserved
	_ // bit 4 reserved
	AttrDecrypt
	// AttrEcrypt is a misspelling of "encrypt" that is kept so existing
	// callers continue to compile; new code should use AttrEncrypt.
	AttrEcrypt
	AttrAudit
)

// AttrEncrypt is the correctly spelled name of the session encrypt attribute.
// It has the same type and value as AttrEcrypt.
const AttrEncrypt = AttrEcrypt

// EmptyAuth represents the empty authorization value.
var EmptyAuth []byte

// KeyProp is a bitmask used in Attributes field of key templates. Individual
// flags should be OR-ed to form a full mask.
type KeyProp uint32

// Key properties.
const (
	FlagFixedTPM            KeyProp = 0x00000002
	FlagStClear             KeyProp = 0x00000004
	FlagFixedParent         KeyProp = 0x00000010
	FlagSensitiveDataOrigin KeyProp = 0x00000020
	FlagUserWithAuth        KeyProp = 0x00000040
	FlagAdminWithPolicy     KeyProp = 0x00000080
	FlagNoDA                KeyProp = 0x00000400
	FlagRestricted          KeyProp = 0x00010000
	FlagDecrypt             KeyProp = 0x00020000
	FlagSign                KeyProp = 0x00040000

	// Ready-made masks that combine the single-bit flags above.
	FlagSealDefault   = FlagFixedTPM | FlagFixedParent
	FlagSignerDefault = FlagSign | FlagRestricted | FlagFixedTPM |
		FlagFixedParent | FlagSensitiveDataOrigin | FlagUserWithAuth
	FlagStorageDefault = FlagDecrypt | FlagRestricted | FlagFixedTPM |
		FlagFixedParent | FlagSensitiveDataOrigin | FlagUserWithAuth
)

// TPMProp represents a Property Tag (TPM_PT) used with calls to GetCapability(CapabilityTPMProperties).
type TPMProp uint32

// TPM Capability Properties, see TPM 2.0 Spec, Rev 1.38, Table 23.
// Fixed TPM Properties (PT_FIXED)
// Values are assigned consecutively from 0x100 through iota, so the order of
// the names in this block determines each property's value.
const (
	FamilyIndicator TPMProp = 0x100 + iota
	SpecLevel
	SpecRevision
	SpecDayOfYear
	SpecYear
	Manufacturer
	VendorString1
	VendorString2
	VendorString3
	VendorString4
	VendorTPMType
	FirmwareVersion1
	FirmwareVersion2
	InputMaxBufferSize
	TransientObjectsMin
	PersistentObjectsMin
	LoadedObjectsMin
	ActiveSessionsMax
	PCRCount
	PCRSelectMin
	ContextGapMax
	_ // (PT_FIXED + 21) is skipped
	NVCountersMax
	NVIndexMax
	MemoryMethod
	ClockUpdate
	ContextHash
	ContextSym
	ContextSymSize
	OrderlyCount
	CommandMaxSize
	ResponseMaxSize
	DigestMaxSize
	ObjectContextMaxSize
	SessionContextMaxSize
	PSFamilyIndicator
	PSSpecLevel
	PSSpecRevision
	PSSpecDayOfYear
	PSSpecYear
	SplitSigningMax
	TotalCommands
	LibraryCommands
	VendorCommands
	NVMaxBufferSize
	TPMModes
	CapabilityMaxBufferSize
)

// Variable TPM Properties (PT_VAR)
// Values are assigned consecutively from 0x200 through iota, in the order
// listed.
const (
	TPMAPermanent TPMProp = 0x200 + iota
	TPMAStartupClear
	HRNVIndex
	HRLoaded
	HRLoadedAvail
	HRActive
	HRActiveAvail
	HRTransientAvail
	CurrentPersistent
	AvailPersistent
	NVCounters
	NVCountersAvail
	AlgorithmSet
	LoadedCurves
	LockoutCounter
	MaxAuthFail
	LockoutInterval
	LockoutRecovery
	NVWriteRecovery
	AuditCounter0
	AuditCounter1
)

// Allowed ranges of different kinds of Handles (TPM_HANDLE)
// These constants have type TPMProp for backwards compatibility.
// HMACSessionFirst/LoadedSessionFirst share one value, as do
// PolicySessionFirst/ActiveSessionFirst.
const (
	PCRFirst           TPMProp = 0x00000000
	HMACSessionFirst   TPMProp = 0x02000000
	LoadedSessionFirst TPMProp = 0x02000000
	PolicySessionFirst TPMProp = 0x03000000
	ActiveSessionFirst TPMProp = 0x03000000
	TransientFirst     TPMProp = 0x80000000
	PersistentFirst    TPMProp = 0x81000000
	PersistentLast     TPMProp = 0x81FFFFFF
	PlatformPersistent TPMProp = 0x81800000
	NVIndexFirst       TPMProp = 0x01000000
	NVIndexLast        TPMProp = 0x01FFFFFF
	PermanentFirst     TPMProp = 0x40000000
	PermanentLast      TPMProp = 0x4000010F
)

// Reserved Handles.
+const ( + HandleOwner tpmutil.Handle = 0x40000001 + iota + HandleRevoke + HandleTransport + HandleOperator + HandleAdmin + HandleEK + HandleNull + HandleUnassigned + HandlePasswordSession + HandleLockout + HandleEndorsement + HandlePlatform +) + +// Capability identifies some TPM property or state type. +type Capability uint32 + +// TPM Capabilities. +const ( + CapabilityAlgs Capability = iota + CapabilityHandles + CapabilityCommands + CapabilityPPCommands + CapabilityAuditCommands + CapabilityPCRs + CapabilityTPMProperties + CapabilityPCRProperties + CapabilityECCCurves + CapabilityAuthPolicies +) + +// TPM Structure Tags. Tags are used to disambiguate structures, similar to Alg +// values: tag value defines what kind of data lives in a nested field. +const ( + TagNull tpmutil.Tag = 0x8000 + TagNoSessions tpmutil.Tag = 0x8001 + TagSessions tpmutil.Tag = 0x8002 + TagAttestCertify tpmutil.Tag = 0x8017 + TagAttestQuote tpmutil.Tag = 0x8018 + TagAttestCreation tpmutil.Tag = 0x801a + TagAuthSecret tpmutil.Tag = 0x8023 + TagHashCheck tpmutil.Tag = 0x8024 + TagAuthSigned tpmutil.Tag = 0x8025 +) + +// StartupType instructs the TPM on how to handle its state during Shutdown or +// Startup. +type StartupType uint16 + +// Startup types +const ( + StartupClear StartupType = iota + StartupState +) + +// EllipticCurve identifies specific EC curves. +type EllipticCurve uint16 + +// ECC curves supported by TPM 2.0 spec. +const ( + CurveNISTP192 = EllipticCurve(iota + 1) + CurveNISTP224 + CurveNISTP256 + CurveNISTP384 + CurveNISTP521 + + CurveBNP256 = EllipticCurve(iota + 10) + CurveBNP638 + + CurveSM2P256 = EllipticCurve(0x0020) +) + +var toGoCurve = map[EllipticCurve]elliptic.Curve{ + CurveNISTP224: elliptic.P224(), + CurveNISTP256: elliptic.P256(), + CurveNISTP384: elliptic.P384(), + CurveNISTP521: elliptic.P521(), +} + +// Supported TPM operations. 
+const ( + CmdNVUndefineSpaceSpecial tpmutil.Command = 0x0000011F + CmdEvictControl tpmutil.Command = 0x00000120 + CmdUndefineSpace tpmutil.Command = 0x00000122 + CmdClear tpmutil.Command = 0x00000126 + CmdHierarchyChangeAuth tpmutil.Command = 0x00000129 + CmdDefineSpace tpmutil.Command = 0x0000012A + CmdCreatePrimary tpmutil.Command = 0x00000131 + CmdIncrementNVCounter tpmutil.Command = 0x00000134 + CmdWriteNV tpmutil.Command = 0x00000137 + CmdWriteLockNV tpmutil.Command = 0x00000138 + CmdDictionaryAttackLockReset tpmutil.Command = 0x00000139 + CmdDictionaryAttackParameters tpmutil.Command = 0x0000013A + CmdPCREvent tpmutil.Command = 0x0000013C + CmdPCRReset tpmutil.Command = 0x0000013D + CmdSequenceComplete tpmutil.Command = 0x0000013E + CmdStartup tpmutil.Command = 0x00000144 + CmdShutdown tpmutil.Command = 0x00000145 + CmdActivateCredential tpmutil.Command = 0x00000147 + CmdCertify tpmutil.Command = 0x00000148 + CmdCertifyCreation tpmutil.Command = 0x0000014A + CmdReadNV tpmutil.Command = 0x0000014E + CmdReadLockNV tpmutil.Command = 0x0000014F + CmdPolicySecret tpmutil.Command = 0x00000151 + CmdCreate tpmutil.Command = 0x00000153 + CmdECDHZGen tpmutil.Command = 0x00000154 + CmdImport tpmutil.Command = 0x00000156 + CmdLoad tpmutil.Command = 0x00000157 + CmdQuote tpmutil.Command = 0x00000158 + CmdRSADecrypt tpmutil.Command = 0x00000159 + CmdSequenceUpdate tpmutil.Command = 0x0000015C + CmdSign tpmutil.Command = 0x0000015D + CmdUnseal tpmutil.Command = 0x0000015E + CmdPolicySigned tpmutil.Command = 0x00000160 + CmdContextLoad tpmutil.Command = 0x00000161 + CmdContextSave tpmutil.Command = 0x00000162 + CmdECDHKeyGen tpmutil.Command = 0x00000163 + CmdEncryptDecrypt tpmutil.Command = 0x00000164 + CmdFlushContext tpmutil.Command = 0x00000165 + CmdLoadExternal tpmutil.Command = 0x00000167 + CmdMakeCredential tpmutil.Command = 0x00000168 + CmdReadPublicNV tpmutil.Command = 0x00000169 + CmdPolicyCommandCode tpmutil.Command = 0x0000016C + CmdPolicyOr tpmutil.Command = 
0x00000171 + CmdReadPublic tpmutil.Command = 0x00000173 + CmdRSAEncrypt tpmutil.Command = 0x00000174 + CmdStartAuthSession tpmutil.Command = 0x00000176 + CmdGetCapability tpmutil.Command = 0x0000017A + CmdGetRandom tpmutil.Command = 0x0000017B + CmdHash tpmutil.Command = 0x0000017D + CmdPCRRead tpmutil.Command = 0x0000017E + CmdPolicyPCR tpmutil.Command = 0x0000017F + CmdReadClock tpmutil.Command = 0x00000181 + CmdPCRExtend tpmutil.Command = 0x00000182 + CmdEventSequenceComplete tpmutil.Command = 0x00000185 + CmdHashSequenceStart tpmutil.Command = 0x00000186 + CmdPolicyGetDigest tpmutil.Command = 0x00000189 + CmdPolicyPassword tpmutil.Command = 0x0000018C + CmdEncryptDecrypt2 tpmutil.Command = 0x00000193 +) + +// Regular TPM 2.0 devices use 24-bit mask (3 bytes) for PCR selection. +const sizeOfPCRSelect = 3 + +const defaultRSAExponent = 1<<16 + 1 + +// NVAttr is a bitmask used in Attributes field of NV indexes. Individual +// flags should be OR-ed to form a full mask. +type NVAttr uint32 + +// NV Attributes +const ( + AttrPPWrite NVAttr = 0x00000001 + AttrOwnerWrite NVAttr = 0x00000002 + AttrAuthWrite NVAttr = 0x00000004 + AttrPolicyWrite NVAttr = 0x00000008 + AttrPolicyDelete NVAttr = 0x00000400 + AttrWriteLocked NVAttr = 0x00000800 + AttrWriteAll NVAttr = 0x00001000 + AttrWriteDefine NVAttr = 0x00002000 + AttrWriteSTClear NVAttr = 0x00004000 + AttrGlobalLock NVAttr = 0x00008000 + AttrPPRead NVAttr = 0x00010000 + AttrOwnerRead NVAttr = 0x00020000 + AttrAuthRead NVAttr = 0x00040000 + AttrPolicyRead NVAttr = 0x00080000 + AttrNoDA NVAttr = 0x02000000 + AttrOrderly NVAttr = 0x04000000 + AttrClearSTClear NVAttr = 0x08000000 + AttrReadLocked NVAttr = 0x10000000 + AttrWritten NVAttr = 0x20000000 + AttrPlatformCreate NVAttr = 0x40000000 + AttrReadSTClear NVAttr = 0x80000000 +) + +var permMap = map[NVAttr]string{ + AttrPPWrite: "PPWrite", + AttrOwnerWrite: "OwnerWrite", + AttrAuthWrite: "AuthWrite", + AttrPolicyWrite: "PolicyWrite", + AttrPolicyDelete: "PolicyDelete", + 
AttrWriteLocked: "WriteLocked", + AttrWriteAll: "WriteAll", + AttrWriteDefine: "WriteDefine", + AttrWriteSTClear: "WriteSTClear", + AttrGlobalLock: "GlobalLock", + AttrPPRead: "PPRead", + AttrOwnerRead: "OwnerRead", + AttrAuthRead: "AuthRead", + AttrPolicyRead: "PolicyRead", + AttrNoDA: "No Do", + AttrOrderly: "Oderly", + AttrClearSTClear: "ClearSTClear", + AttrReadLocked: "ReadLocked", + AttrWritten: "Writte", + AttrPlatformCreate: "PlatformCreate", + AttrReadSTClear: "ReadSTClear", +} + +// String returns a textual representation of the set of NVAttr +func (p NVAttr) String() string { + var retString strings.Builder + for iterator, item := range permMap { + if (p & iterator) != 0 { + retString.WriteString(item + " + ") + } + } + if retString.String() == "" { + return "Permission/s not found" + } + return strings.TrimSuffix(retString.String(), " + ") + +} diff --git a/vendor/github.com/google/go-tpm/legacy/tpm2/error.go b/vendor/github.com/google/go-tpm/legacy/tpm2/error.go new file mode 100644 index 0000000000..e1983356fe --- /dev/null +++ b/vendor/github.com/google/go-tpm/legacy/tpm2/error.go @@ -0,0 +1,362 @@ +package tpm2 + +import ( + "fmt" + + "github.com/google/go-tpm/tpmutil" +) + +type ( + // RCFmt0 holds Format 0 error codes + RCFmt0 uint8 + + // RCFmt1 holds Format 1 error codes + RCFmt1 uint8 + + // RCWarn holds error codes used in warnings + RCWarn uint8 + + // RCIndex is used to reference arguments, handles and sessions in errors + RCIndex uint8 +) + +// Format 0 error codes. 
+const ( + RCInitialize RCFmt0 = 0x00 + RCFailure RCFmt0 = 0x01 + RCSequence RCFmt0 = 0x03 + RCPrivate RCFmt0 = 0x0B + RCHMAC RCFmt0 = 0x19 + RCDisabled RCFmt0 = 0x20 + RCExclusive RCFmt0 = 0x21 + RCAuthType RCFmt0 = 0x24 + RCAuthMissing RCFmt0 = 0x25 + RCPolicy RCFmt0 = 0x26 + RCPCR RCFmt0 = 0x27 + RCPCRChanged RCFmt0 = 0x28 + RCUpgrade RCFmt0 = 0x2D + RCTooManyContexts RCFmt0 = 0x2E + RCAuthUnavailable RCFmt0 = 0x2F + RCReboot RCFmt0 = 0x30 + RCUnbalanced RCFmt0 = 0x31 + RCCommandSize RCFmt0 = 0x42 + RCCommandCode RCFmt0 = 0x43 + RCAuthSize RCFmt0 = 0x44 + RCAuthContext RCFmt0 = 0x45 + RCNVRange RCFmt0 = 0x46 + RCNVSize RCFmt0 = 0x47 + RCNVLocked RCFmt0 = 0x48 + RCNVAuthorization RCFmt0 = 0x49 + RCNVUninitialized RCFmt0 = 0x4A + RCNVSpace RCFmt0 = 0x4B + RCNVDefined RCFmt0 = 0x4C + RCBadContext RCFmt0 = 0x50 + RCCPHash RCFmt0 = 0x51 + RCParent RCFmt0 = 0x52 + RCNeedsTest RCFmt0 = 0x53 + RCNoResult RCFmt0 = 0x54 + RCSensitive RCFmt0 = 0x55 +) + +var fmt0Msg = map[RCFmt0]string{ + RCInitialize: "TPM not initialized by TPM2_Startup or already initialized", + RCFailure: "commands not being accepted because of a TPM failure", + RCSequence: "improper use of a sequence handle", + RCPrivate: "not currently used", + RCHMAC: "not currently used", + RCDisabled: "the command is disabled", + RCExclusive: "command failed because audit sequence required exclusivity", + RCAuthType: "authorization handle is not correct for command", + RCAuthMissing: "5 command requires an authorization session for handle and it is not present", + RCPolicy: "policy failure in math operation or an invalid authPolicy value", + RCPCR: "PCR check fail", + RCPCRChanged: "PCR have changed since checked", + RCUpgrade: "TPM is in field upgrade mode unless called via TPM2_FieldUpgradeData(), then it is not in field upgrade mode", + RCTooManyContexts: "context ID counter is at maximum", + RCAuthUnavailable: "authValue or authPolicy is not available for selected entity", + RCReboot: "a _TPM_Init and 
Startup(CLEAR) is required before the TPM can resume operation", + RCUnbalanced: "the protection algorithms (hash and symmetric) are not reasonably balanced; the digest size of the hash must be larger than the key size of the symmetric algorithm", + RCCommandSize: "command commandSize value is inconsistent with contents of the command buffer; either the size is not the same as the octets loaded by the hardware interface layer or the value is not large enough to hold a command header", + RCCommandCode: "command code not supported", + RCAuthSize: "the value of authorizationSize is out of range or the number of octets in the Authorization Area is greater than required", + RCAuthContext: "use of an authorization session with a context command or another command that cannot have an authorization session", + RCNVRange: "NV offset+size is out of range", + RCNVSize: "Requested allocation size is larger than allowed", + RCNVLocked: "NV access locked", + RCNVAuthorization: "NV access authorization fails in command actions", + RCNVUninitialized: "an NV Index is used before being initialized or the state saved by TPM2_Shutdown(STATE) could not be restored", + RCNVSpace: "insufficient space for NV allocation", + RCNVDefined: "NV Index or persistent object already defined", + RCBadContext: "context in TPM2_ContextLoad() is not valid", + RCCPHash: "cpHash value already set or not correct for use", + RCParent: "handle for parent is not a valid parent", + RCNeedsTest: "some function needs testing", + RCNoResult: "returned when an internal function cannot process a request due to an unspecified problem; this code is usually related to invalid parameters that are not properly filtered by the input unmarshaling code", + RCSensitive: "the sensitive area did not unmarshal correctly after decryption", +} + +// Format 1 error codes. 
// These constants are untyped; they take the type RCFmt1 where they are used
// as keys of fmt1Msg below.
const (
	RCAsymmetric   = 0x01
	RCAttributes   = 0x02
	RCHash         = 0x03
	RCValue        = 0x04
	RCHierarchy    = 0x05
	RCKeySize      = 0x07
	RCMGF          = 0x08
	RCMode         = 0x09
	RCType         = 0x0A
	RCHandle       = 0x0B
	RCKDF          = 0x0C
	RCRange        = 0x0D
	RCAuthFail     = 0x0E
	RCNonce        = 0x0F
	RCPP           = 0x10
	RCScheme       = 0x12
	RCSize         = 0x15
	RCSymmetric    = 0x16
	RCTag          = 0x17
	RCSelector     = 0x18
	RCInsufficient = 0x1A
	RCSignature    = 0x1B
	RCKey          = 0x1C
	RCPolicyFail   = 0x1D
	RCIntegrity    = 0x1F
	RCTicket       = 0x20
	RCReservedBits = 0x21
	RCBadAuth      = 0x22
	RCExpired      = 0x23
	RCPolicyCC     = 0x24
	RCBinding      = 0x25
	RCCurve        = 0x26
	RCECCPoint     = 0x27
)

// fmt1Msg holds the human-readable description of every Format 1 error code
// declared above. It is shared by ParameterError, HandleError and
// SessionError.
var fmt1Msg = map[RCFmt1]string{
	RCAsymmetric:   "asymmetric algorithm not supported or not correct",
	RCAttributes:   "inconsistent attributes",
	RCHash:         "hash algorithm not supported or not appropriate",
	RCValue:        "value is out of range or is not correct for the context",
	RCHierarchy:    "hierarchy is not enabled or is not correct for the use",
	RCKeySize:      "key size is not supported",
	RCMGF:          "mask generation function not supported",
	RCMode:         "mode of operation not supported",
	RCType:         "the type of the value is not appropriate for the use",
	RCHandle:       "the handle is not correct for the use",
	RCKDF:          "unsupported key derivation function or function not appropriate for use",
	RCRange:        "value was out of allowed range",
	RCAuthFail:     "the authorization HMAC check failed and DA counter incremented",
	RCNonce:        "invalid nonce size or nonce value mismatch",
	RCPP:           "authorization requires assertion of PP",
	RCScheme:       "unsupported or incompatible scheme",
	RCSize:         "structure is the wrong size",
	RCSymmetric:    "unsupported symmetric algorithm or key size, or not appropriate for instance",
	RCTag:          "incorrect structure tag",
	RCSelector:     "union selector is incorrect",
	RCInsufficient: "the TPM was unable to unmarshal a value because there were not enough octets in the input buffer",
	RCSignature:    "the signature is not valid",
	RCKey:          "key fields are not compatible with the selected use",
	RCPolicyFail:   "a policy check failed",
	RCIntegrity:    "integrity check failed",
	RCTicket:       "invalid ticket",
	RCReservedBits: "reserved bits not set to zero as required",
	RCBadAuth:      "authorization failure without DA implications",
	RCExpired:      "the policy has expired",
	RCPolicyCC:     "the commandCode in the policy is not the commandCode of the command or the command code in a policy command references a command that is not implemented",
	RCBinding:      "public and sensitive portions of an object are not cryptographically bound",
	RCCurve:        "curve not supported",
	RCECCPoint:     "point is not on the required curve",
}

// Warning codes.
const (
	RCContextGap     RCWarn = 0x01
	RCObjectMemory   RCWarn = 0x02
	RCSessionMemory  RCWarn = 0x03
	RCMemory         RCWarn = 0x04
	RCSessionHandles RCWarn = 0x05
	RCObjectHandles  RCWarn = 0x06
	RCLocality       RCWarn = 0x07
	RCYielded        RCWarn = 0x08
	RCCanceled       RCWarn = 0x09
	RCTesting        RCWarn = 0x0A
	RCReferenceH0    RCWarn = 0x10
	RCReferenceH1    RCWarn = 0x11
	RCReferenceH2    RCWarn = 0x12
	RCReferenceH3    RCWarn = 0x13
	RCReferenceH4    RCWarn = 0x14
	RCReferenceH5    RCWarn = 0x15
	RCReferenceH6    RCWarn = 0x16
	RCReferenceS0    RCWarn = 0x18
	RCReferenceS1    RCWarn = 0x19
	RCReferenceS2    RCWarn = 0x1A
	RCReferenceS3    RCWarn = 0x1B
	RCReferenceS4    RCWarn = 0x1C
	RCReferenceS5    RCWarn = 0x1D
	RCReferenceS6    RCWarn = 0x1E
	RCNVRate         RCWarn = 0x20
	RCLockout        RCWarn = 0x21
	RCRetry          RCWarn = 0x22
	RCNVUnavailable  RCWarn = 0x23
)

// warnMsg holds the human-readable description of every warning code declared
// above.
var warnMsg = map[RCWarn]string{
	RCContextGap:     "gap for context ID is too large",
	RCObjectMemory:   "out of memory for object contexts",
	RCSessionMemory:  "out of memory for session contexts",
	RCMemory:         "out of shared object/session memory or need space for internal operations",
	RCSessionHandles: "out of session handles",
	RCObjectHandles:  "out of object handles",
	RCLocality:       "bad locality",
	RCYielded:        "the TPM has suspended operation on the command; forward progress was made and the command may be retried",
	RCCanceled:       "the command was canceled",
	RCTesting:        "TPM is performing self-tests",
	RCReferenceH0:    "the 1st handle in the handle area references a transient object or session that is not loaded",
	RCReferenceH1:    "the 2nd handle in the handle area references a transient object or session that is not loaded",
	RCReferenceH2:    "the 3rd handle in the handle area references a transient object or session that is not loaded",
	RCReferenceH3:    "the 4th handle in the handle area references a transient object or session that is not loaded",
	RCReferenceH4:    "the 5th handle in the handle area references a transient object or session that is not loaded",
	RCReferenceH5:    "the 6th handle in the handle area references a transient object or session that is not loaded",
	RCReferenceH6:    "the 7th handle in the handle area references a transient object or session that is not loaded",
	RCReferenceS0:    "the 1st authorization session handle references a session that is not loaded",
	RCReferenceS1:    "the 2nd authorization session handle references a session that is not loaded",
	RCReferenceS2:    "the 3rd authorization session handle references a session that is not loaded",
	RCReferenceS3:    "the 4th authorization session handle references a session that is not loaded",
	RCReferenceS4:    "the 5th authorization session handle references a session that is not loaded",
	RCReferenceS5:    "the 6th authorization session handle references a session that is not loaded",
	RCReferenceS6:    "the 7th authorization session handle references a session that is not loaded",
	RCNVRate:         "the TPM is rate-limiting accesses to prevent wearout of NV",
	RCLockout:        "authorizations for objects subject to DA protection are not allowed at this time because the TPM is in DA lockout mode",
	RCRetry:          "the TPM was not able to start the command",
	RCNVUnavailable:  "the command may require writing of NV and NV is not current accessible",
}

// Indexes for arguments, handles and sessions.
// The values run from 1 (RC1) to 15 (RCF).
const (
	RC1 RCIndex = iota + 1
	RC2
	RC3
	RC4
	RC5
	RC6
	RC7
	RC8
	RC9
	RCA
	RCB
	RCC
	RCD
	RCE
	RCF
)

// unknownCode is the message used by the Error methods below when a code has
// no entry in the corresponding message map.
const unknownCode = "unknown error code"

// Error is returned for all Format 0 errors from the TPM. It is used for general
// errors not specific to a parameter, handle or session.
type Error struct {
	Code RCFmt0
}

// Error renders the code in hexadecimal followed by its fmt0Msg description,
// or unknownCode when the code has none.
func (e Error) Error() string {
	msg := fmt0Msg[e.Code]
	if msg == "" {
		msg = unknownCode
	}
	return fmt.Sprintf("error code 0x%x : %s", e.Code, msg)
}

// VendorError represents a vendor-specific error response. These types of responses
// are not decoded and Code contains the complete response code.
type VendorError struct {
	Code uint32
}

// Error renders the complete, undecoded response code in hexadecimal.
func (e VendorError) Error() string {
	return fmt.Sprintf("vendor error code 0x%x", e.Code)
}

// Warning is typically used to report transient errors.
type Warning struct {
	Code RCWarn
}

// Error renders the code in hexadecimal followed by its warnMsg description,
// or unknownCode when the code has none.
func (w Warning) Error() string {
	msg := warnMsg[w.Code]
	if msg == "" {
		msg = unknownCode
	}
	return fmt.Sprintf("warning code 0x%x : %s", w.Code, msg)
}

// ParameterError describes an error related to a parameter, and the parameter number.
type ParameterError struct {
	Code      RCFmt1
	Parameter RCIndex
}

// Error renders the parameter number, the code in hexadecimal and its fmt1Msg
// description, or unknownCode when the code has none.
func (e ParameterError) Error() string {
	msg := fmt1Msg[e.Code]
	if msg == "" {
		msg = unknownCode
	}
	return fmt.Sprintf("parameter %d, error code 0x%x : %s", e.Parameter, e.Code, msg)
}

// HandleError describes an error related to a handle, and the handle number.
type HandleError struct {
	Code   RCFmt1
	Handle RCIndex
}

// Error renders the handle number, the code in hexadecimal and its fmt1Msg
// description, or unknownCode when the code has none.
func (e HandleError) Error() string {
	msg := fmt1Msg[e.Code]
	if msg == "" {
		msg = unknownCode
	}
	return fmt.Sprintf("handle %d, error code 0x%x : %s", e.Handle, e.Code, msg)
}

// SessionError describes an error related to a session, and the session number.
+type SessionError struct { + Code RCFmt1 + Session RCIndex +} + +func (e SessionError) Error() string { + msg := fmt1Msg[e.Code] + if msg == "" { + msg = unknownCode + } + return fmt.Sprintf("session %d, error code 0x%x : %s", e.Session, e.Code, msg) +} + +// Decode a TPM2 response code and return the appropriate error. Logic +// according to the "Response Code Evaluation" chart in Part 1 of the TPM 2.0 +// spec. +func decodeResponse(code tpmutil.ResponseCode) error { + if code == tpmutil.RCSuccess { + return nil + } + if code&0x180 == 0 { // Bits 7:8 == 0 is a TPM1 error + return fmt.Errorf("response status 0x%x", code) + } + if code&0x80 == 0 { // Bit 7 unset + if code&0x400 > 0 { // Bit 10 set, vendor specific code + return VendorError{uint32(code)} + } + if code&0x800 > 0 { // Bit 11 set, warning with code in bit 0:6 + return Warning{RCWarn(code & 0x7f)} + } + // error with code in bit 0:6 + return Error{RCFmt0(code & 0x7f)} + } + if code&0x40 > 0 { // Bit 6 set, code in 0:5, parameter number in 8:11 + return ParameterError{RCFmt1(code & 0x3f), RCIndex((code & 0xf00) >> 8)} + } + if code&0x800 == 0 { // Bit 11 unset, code in 0:5, handle in 8:10 + return HandleError{RCFmt1(code & 0x3f), RCIndex((code & 0x700) >> 8)} + } + // Code in 0:5, Session in 8:10 + return SessionError{RCFmt1(code & 0x3f), RCIndex((code & 0x700) >> 8)} +} diff --git a/vendor/github.com/google/go-tpm/legacy/tpm2/kdf.go b/vendor/github.com/google/go-tpm/legacy/tpm2/kdf.go new file mode 100644 index 0000000000..3a22e8be77 --- /dev/null +++ b/vendor/github.com/google/go-tpm/legacy/tpm2/kdf.go @@ -0,0 +1,116 @@ +// Copyright (c) 2018, Google LLC All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tpm2 + +import ( + "crypto" + "crypto/hmac" + "encoding/binary" + "hash" +) + +// KDFa implements TPM 2.0's default key derivation function, as defined in +// section 11.4.9.2 of the TPM revision 2 specification part 1. +// See: https://trustedcomputinggroup.org/resource/tpm-library-specification/ +// The key & label parameters must not be zero length. +// The label parameter is a non-null-terminated string. +// The contextU & contextV parameters are optional. +// Deprecated: Use KDFaHash. +func KDFa(hashAlg Algorithm, key []byte, label string, contextU, contextV []byte, bits int) ([]byte, error) { + h, err := hashAlg.Hash() + if err != nil { + return nil, err + } + return KDFaHash(h, key, label, contextU, contextV, bits), nil +} + +// KDFe implements TPM 2.0's ECDH key derivation function, as defined in +// section 11.4.9.3 of the TPM revision 2 specification part 1. +// See: https://trustedcomputinggroup.org/resource/tpm-library-specification/ +// The z parameter is the x coordinate of one party's private ECC key multiplied +// by the other party's public ECC point. +// The use parameter is a non-null-terminated string. +// The partyUInfo and partyVInfo are the x coordinates of the initiator's and +// Deprecated: Use KDFeHash. 
+func KDFe(hashAlg Algorithm, z []byte, use string, partyUInfo, partyVInfo []byte, bits int) ([]byte, error) { + h, err := hashAlg.Hash() + if err != nil { + return nil, err + } + return KDFeHash(h, z, use, partyUInfo, partyVInfo, bits), nil +} + +// KDFaHash implements TPM 2.0's default key derivation function, as defined in +// section 11.4.9.2 of the TPM revision 2 specification part 1. +// See: https://trustedcomputinggroup.org/resource/tpm-library-specification/ +// The key & label parameters must not be zero length. +// The label parameter is a non-null-terminated string. +// The contextU & contextV parameters are optional. +func KDFaHash(h crypto.Hash, key []byte, label string, contextU, contextV []byte, bits int) []byte { + mac := hmac.New(h.New, key) + + out := kdf(mac, bits, func() { + mac.Write([]byte(label)) + mac.Write([]byte{0}) // Terminating null character for C-string. + mac.Write(contextU) + mac.Write(contextV) + binary.Write(mac, binary.BigEndian, uint32(bits)) + }) + return out +} + +// KDFeHash implements TPM 2.0's ECDH key derivation function, as defined in +// section 11.4.9.3 of the TPM revision 2 specification part 1. +// See: https://trustedcomputinggroup.org/resource/tpm-library-specification/ +// The z parameter is the x coordinate of one party's private ECC key multiplied +// by the other party's public ECC point. +// The use parameter is a non-null-terminated string. +// The partyUInfo and partyVInfo are the x coordinates of the initiator's and +// the responder's ECC points, respectively. +func KDFeHash(h crypto.Hash, z []byte, use string, partyUInfo, partyVInfo []byte, bits int) []byte { + hash := h.New() + + out := kdf(hash, bits, func() { + hash.Write(z) + hash.Write([]byte(use)) + hash.Write([]byte{0}) // Terminating null character for C-string. 
+ hash.Write(partyUInfo) + hash.Write(partyVInfo) + }) + return out +} + +func kdf(h hash.Hash, bits int, update func()) []byte { + bytes := (bits + 7) / 8 + out := []byte{} + + for counter := 1; len(out) < bytes; counter++ { + h.Reset() + binary.Write(h, binary.BigEndian, uint32(counter)) + update() + + out = h.Sum(out) + } + // out's length is a multiple of hash size, so there will be excess + // bytes if bytes isn't a multiple of hash size. + out = out[:bytes] + + // As mentioned in the KDFa and KDFe specs mentioned above, + // the unused bits of the most significant octet are masked off. + if maskBits := uint8(bits % 8); maskBits > 0 { + out[0] &= (1 << maskBits) - 1 + } + return out +} diff --git a/vendor/github.com/google/go-tpm/legacy/tpm2/open_other.go b/vendor/github.com/google/go-tpm/legacy/tpm2/open_other.go new file mode 100644 index 0000000000..7d6d9a31b4 --- /dev/null +++ b/vendor/github.com/google/go-tpm/legacy/tpm2/open_other.go @@ -0,0 +1,57 @@ +//go:build !windows + +// Copyright (c) 2019, Google LLC All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tpm2 + +import ( + "errors" + "fmt" + "io" + "os" + + "github.com/google/go-tpm/tpmutil" +) + +// OpenTPM opens a channel to the TPM at the given path. If the file is a +// device, then it treats it like a normal TPM device, and if the file is a +// Unix domain socket, then it opens a connection to the socket. 
+// +// This function may also be invoked with no paths, as tpm2.OpenTPM(). In this +// case, the default paths on Linux (/dev/tpmrm0 then /dev/tpm0), will be used. +func OpenTPM(path ...string) (tpm io.ReadWriteCloser, err error) { + switch len(path) { + case 0: + tpm, err = tpmutil.OpenTPM("/dev/tpmrm0") + if errors.Is(err, os.ErrNotExist) { + tpm, err = tpmutil.OpenTPM("/dev/tpm0") + } + case 1: + tpm, err = tpmutil.OpenTPM(path[0]) + default: + return nil, errors.New("cannot specify multiple paths to tpm2.OpenTPM") + } + if err != nil { + return nil, err + } + + // Make sure this is a TPM 2.0 + _, err = GetManufacturer(tpm) + if err != nil { + tpm.Close() + return nil, fmt.Errorf("open %s: device is not a TPM 2.0", path) + } + return tpm, nil +} diff --git a/vendor/github.com/google/go-tpm/legacy/tpm2/open_windows.go b/vendor/github.com/google/go-tpm/legacy/tpm2/open_windows.go new file mode 100644 index 0000000000..ad37a60213 --- /dev/null +++ b/vendor/github.com/google/go-tpm/legacy/tpm2/open_windows.go @@ -0,0 +1,39 @@ +//go:build windows + +// Copyright (c) 2018, Google LLC All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tpm2 + +import ( + "fmt" + "io" + + "github.com/google/go-tpm/tpmutil" + "github.com/google/go-tpm/tpmutil/tbs" +) + +// OpenTPM opens a channel to the TPM. 
+func OpenTPM() (io.ReadWriteCloser, error) { + info, err := tbs.GetDeviceInfo() + if err != nil { + return nil, err + } + + if info.TPMVersion != tbs.TPMVersion20 { + return nil, fmt.Errorf("openTPM: device is not a TPM 2.0") + } + + return tpmutil.OpenTPM() +} diff --git a/vendor/github.com/google/go-tpm/legacy/tpm2/structures.go b/vendor/github.com/google/go-tpm/legacy/tpm2/structures.go new file mode 100644 index 0000000000..6df9f7f0d7 --- /dev/null +++ b/vendor/github.com/google/go-tpm/legacy/tpm2/structures.go @@ -0,0 +1,1112 @@ +// Copyright (c) 2018, Google LLC All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tpm2 + +import ( + "bytes" + "crypto" + "crypto/ecdsa" + "crypto/rsa" + "encoding/binary" + "errors" + "fmt" + "math/big" + "reflect" + + "github.com/google/go-tpm/tpmutil" +) + +// NVPublic contains the public area of an NV index. +type NVPublic struct { + NVIndex tpmutil.Handle + NameAlg Algorithm + Attributes NVAttr + AuthPolicy tpmutil.U16Bytes + DataSize uint16 +} + +type tpmsSensitiveCreate struct { + UserAuth tpmutil.U16Bytes + Data tpmutil.U16Bytes +} + +// PCRSelection contains a slice of PCR indexes and a hash algorithm used in +// them. +type PCRSelection struct { + Hash Algorithm + PCRs []int +} + +type tpmsPCRSelection struct { + Hash Algorithm + Size byte + PCRs tpmutil.RawBytes +} + +// Public contains the public area of an object. 
+type Public struct { + Type Algorithm + NameAlg Algorithm + Attributes KeyProp + AuthPolicy tpmutil.U16Bytes + + // Exactly one of the following fields should be set + // When encoding/decoding, one will be picked based on Type. + + // RSAParameters contains both [rsa]parameters and [rsa]unique. + RSAParameters *RSAParams + // ECCParameters contains both [ecc]parameters and [ecc]unique. + ECCParameters *ECCParams + // SymCipherParameters contains both [sym]parameters and [sym]unique. + SymCipherParameters *SymCipherParams + // KeyedHashParameters contains both [keyedHash]parameters and [keyedHash]unique. + KeyedHashParameters *KeyedHashParams +} + +// Encode serializes a Public structure in TPM wire format. +func (p Public) Encode() ([]byte, error) { + head, err := tpmutil.Pack(p.Type, p.NameAlg, p.Attributes, p.AuthPolicy) + if err != nil { + return nil, fmt.Errorf("encoding Type, NameAlg, Attributes, AuthPolicy: %v", err) + } + var params []byte + switch p.Type { + case AlgRSA: + params, err = p.RSAParameters.encode() + case AlgKeyedHash: + params, err = p.KeyedHashParameters.encode() + case AlgECC: + params, err = p.ECCParameters.encode() + case AlgSymCipher: + params, err = p.SymCipherParameters.encode() + default: + err = fmt.Errorf("unsupported type in TPMT_PUBLIC: 0x%x", p.Type) + } + if err != nil { + return nil, fmt.Errorf("encoding RSAParameters, ECCParameters, SymCipherParameters or KeyedHash: %v", err) + } + return concat(head, params) +} + +// Key returns the (public) key from the public area of an object. +func (p Public) Key() (crypto.PublicKey, error) { + var pubKey crypto.PublicKey + switch p.Type { + case AlgRSA: + // Endianness of big.Int.Bytes/SetBytes and modulus in the TPM is the same + // (big-endian). 
+ pubKey = &rsa.PublicKey{N: p.RSAParameters.Modulus(), E: int(p.RSAParameters.Exponent())} + case AlgECC: + curve, ok := toGoCurve[p.ECCParameters.CurveID] + if !ok { + return nil, fmt.Errorf("can't map TPM EC curve ID 0x%x to Go elliptic.Curve value", p.ECCParameters.CurveID) + } + pubKey = &ecdsa.PublicKey{ + X: p.ECCParameters.Point.X(), + Y: p.ECCParameters.Point.Y(), + Curve: curve, + } + default: + return nil, fmt.Errorf("unsupported public key type 0x%x", p.Type) + } + return pubKey, nil +} + +// Name computes the Digest-based Name from the public area of an object. +func (p Public) Name() (Name, error) { + pubEncoded, err := p.Encode() + if err != nil { + return Name{}, err + } + hash, err := p.NameAlg.Hash() + if err != nil { + return Name{}, err + } + nameHash := hash.New() + nameHash.Write(pubEncoded) + return Name{ + Digest: &HashValue{ + Alg: p.NameAlg, + Value: nameHash.Sum(nil), + }, + }, nil +} + +// MatchesTemplate checks if the Public area has the same algorithms and +// parameters as the provided template. Note that this does not necessarily +// mean that the key was created from this template, as the Unique field is +// both provided in the template and overridden in the key creation process. +func (p Public) MatchesTemplate(template Public) bool { + if p.Type != template.Type || + p.NameAlg != template.NameAlg || + p.Attributes != template.Attributes || + !bytes.Equal(p.AuthPolicy, template.AuthPolicy) { + return false + } + switch p.Type { + case AlgRSA: + return p.RSAParameters.matchesTemplate(template.RSAParameters) + case AlgECC: + return p.ECCParameters.matchesTemplate(template.ECCParameters) + case AlgSymCipher: + return p.SymCipherParameters.matchesTemplate(template.SymCipherParameters) + case AlgKeyedHash: + return p.KeyedHashParameters.matchesTemplate(template.KeyedHashParameters) + default: + return true + } +} + +// DecodePublic decodes a TPMT_PUBLIC message. No error is returned if +// the input has extra trailing data. 
+func DecodePublic(buf []byte) (Public, error) { + in := bytes.NewBuffer(buf) + var pub Public + var err error + if err = tpmutil.UnpackBuf(in, &pub.Type, &pub.NameAlg, &pub.Attributes, &pub.AuthPolicy); err != nil { + return pub, fmt.Errorf("decoding TPMT_PUBLIC: %v", err) + } + + switch pub.Type { + case AlgRSA: + pub.RSAParameters, err = decodeRSAParams(in) + case AlgECC: + pub.ECCParameters, err = decodeECCParams(in) + case AlgSymCipher: + pub.SymCipherParameters, err = decodeSymCipherParams(in) + case AlgKeyedHash: + pub.KeyedHashParameters, err = decodeKeyedHashParams(in) + default: + err = fmt.Errorf("unsupported type in TPMT_PUBLIC: 0x%x", pub.Type) + } + return pub, err +} + +// RSAParams represents parameters of an RSA key pair: +// both the TPMS_RSA_PARMS and the TPM2B_PUBLIC_KEY_RSA. +// +// Symmetric and Sign may be nil, depending on key Attributes in Public. +// +// ExponentRaw and ModulusRaw are the actual data encoded in the template, which +// is useful for templates that differ in zero-padding, for example. +type RSAParams struct { + Symmetric *SymScheme + Sign *SigScheme + KeyBits uint16 + ExponentRaw uint32 + ModulusRaw tpmutil.U16Bytes +} + +// Exponent returns the RSA exponent value represented by ExponentRaw, handling +// the fact that an exponent of 0 represents a value of 65537 (2^16 + 1). +func (p *RSAParams) Exponent() uint32 { + if p.ExponentRaw == 0 { + return defaultRSAExponent + } + return p.ExponentRaw +} + +// Modulus returns the RSA modulus value represented by ModulusRaw, handling the +// fact that the same modulus value can have multiple different representations. 
+func (p *RSAParams) Modulus() *big.Int { + return new(big.Int).SetBytes(p.ModulusRaw) +} + +func (p *RSAParams) matchesTemplate(t *RSAParams) bool { + return reflect.DeepEqual(p.Symmetric, t.Symmetric) && + reflect.DeepEqual(p.Sign, t.Sign) && + p.KeyBits == t.KeyBits && p.ExponentRaw == t.ExponentRaw +} + +func (p *RSAParams) encode() ([]byte, error) { + if p == nil { + return nil, nil + } + sym, err := p.Symmetric.encode() + if err != nil { + return nil, fmt.Errorf("encoding Symmetric: %v", err) + } + sig, err := p.Sign.encode() + if err != nil { + return nil, fmt.Errorf("encoding Sign: %v", err) + } + rest, err := tpmutil.Pack(p.KeyBits, p.ExponentRaw, p.ModulusRaw) + if err != nil { + return nil, fmt.Errorf("encoding KeyBits, Exponent, Modulus: %v", err) + } + return concat(sym, sig, rest) +} + +func decodeRSAParams(in *bytes.Buffer) (*RSAParams, error) { + var params RSAParams + var err error + + if params.Symmetric, err = decodeSymScheme(in); err != nil { + return nil, fmt.Errorf("decoding Symmetric: %v", err) + } + if params.Sign, err = decodeSigScheme(in); err != nil { + return nil, fmt.Errorf("decoding Sign: %v", err) + } + if err := tpmutil.UnpackBuf(in, ¶ms.KeyBits, ¶ms.ExponentRaw, ¶ms.ModulusRaw); err != nil { + return nil, fmt.Errorf("decoding KeyBits, Exponent, Modulus: %v", err) + } + return ¶ms, nil +} + +// ECCParams represents parameters of an ECC key pair: +// both the TPMS_ECC_PARMS and the TPMS_ECC_POINT. +// +// Symmetric, Sign and KDF may be nil, depending on key Attributes in Public. +type ECCParams struct { + Symmetric *SymScheme + Sign *SigScheme + CurveID EllipticCurve + KDF *KDFScheme + Point ECPoint +} + +// ECPoint represents a ECC coordinates for a point using byte buffers. +type ECPoint struct { + XRaw, YRaw tpmutil.U16Bytes +} + +// X returns the X Point value reprsented by XRaw. +func (p ECPoint) X() *big.Int { + return new(big.Int).SetBytes(p.XRaw) +} + +// Y returns the Y Point value reprsented by YRaw. 
+func (p ECPoint) Y() *big.Int { + return new(big.Int).SetBytes(p.YRaw) +} + +func (p *ECCParams) matchesTemplate(t *ECCParams) bool { + return reflect.DeepEqual(p.Symmetric, t.Symmetric) && + reflect.DeepEqual(p.Sign, t.Sign) && + p.CurveID == t.CurveID && reflect.DeepEqual(p.KDF, t.KDF) +} + +func (p *ECCParams) encode() ([]byte, error) { + if p == nil { + return nil, nil + } + sym, err := p.Symmetric.encode() + if err != nil { + return nil, fmt.Errorf("encoding Symmetric: %v", err) + } + sig, err := p.Sign.encode() + if err != nil { + return nil, fmt.Errorf("encoding Sign: %v", err) + } + curve, err := tpmutil.Pack(p.CurveID) + if err != nil { + return nil, fmt.Errorf("encoding CurveID: %v", err) + } + kdf, err := p.KDF.encode() + if err != nil { + return nil, fmt.Errorf("encoding KDF: %v", err) + } + point, err := tpmutil.Pack(p.Point.XRaw, p.Point.YRaw) + if err != nil { + return nil, fmt.Errorf("encoding Point: %v", err) + } + return concat(sym, sig, curve, kdf, point) +} + +func decodeECCParams(in *bytes.Buffer) (*ECCParams, error) { + var params ECCParams + var err error + + if params.Symmetric, err = decodeSymScheme(in); err != nil { + return nil, fmt.Errorf("decoding Symmetric: %v", err) + } + if params.Sign, err = decodeSigScheme(in); err != nil { + return nil, fmt.Errorf("decoding Sign: %v", err) + } + if err := tpmutil.UnpackBuf(in, ¶ms.CurveID); err != nil { + return nil, fmt.Errorf("decoding CurveID: %v", err) + } + if params.KDF, err = decodeKDFScheme(in); err != nil { + return nil, fmt.Errorf("decoding KDF: %v", err) + } + if err := tpmutil.UnpackBuf(in, ¶ms.Point.XRaw, ¶ms.Point.YRaw); err != nil { + return nil, fmt.Errorf("decoding Point: %v", err) + } + return ¶ms, nil +} + +// SymCipherParams represents parameters of a symmetric block cipher TPM object: +// both the TPMS_SYMCIPHER_PARMS and the TPM2B_DIGEST (hash of the key). 
+type SymCipherParams struct { + Symmetric *SymScheme + Unique tpmutil.U16Bytes +} + +func (p *SymCipherParams) matchesTemplate(t *SymCipherParams) bool { + return reflect.DeepEqual(p.Symmetric, t.Symmetric) +} + +func (p *SymCipherParams) encode() ([]byte, error) { + sym, err := p.Symmetric.encode() + if err != nil { + return nil, fmt.Errorf("encoding Symmetric: %v", err) + } + unique, err := tpmutil.Pack(p.Unique) + if err != nil { + return nil, fmt.Errorf("encoding Unique: %v", err) + } + return concat(sym, unique) +} + +func decodeSymCipherParams(in *bytes.Buffer) (*SymCipherParams, error) { + var params SymCipherParams + var err error + + if params.Symmetric, err = decodeSymScheme(in); err != nil { + return nil, fmt.Errorf("decoding Symmetric: %v", err) + } + if err := tpmutil.UnpackBuf(in, ¶ms.Unique); err != nil { + return nil, fmt.Errorf("decoding Unique: %v", err) + } + return ¶ms, nil +} + +// KeyedHashParams represents parameters of a keyed hash TPM object: +// both the TPMS_KEYEDHASH_PARMS and the TPM2B_DIGEST (hash of the key). 
+type KeyedHashParams struct { + Alg Algorithm + Hash Algorithm + KDF Algorithm + Unique tpmutil.U16Bytes +} + +func (p *KeyedHashParams) matchesTemplate(t *KeyedHashParams) bool { + if p.Alg != t.Alg { + return false + } + switch p.Alg { + case AlgHMAC: + return p.Hash == t.Hash + case AlgXOR: + return p.Hash == t.Hash && p.KDF == t.KDF + default: + return true + } +} + +func (p *KeyedHashParams) encode() ([]byte, error) { + if p == nil { + return tpmutil.Pack(AlgNull, tpmutil.U16Bytes(nil)) + } + var params []byte + var err error + switch p.Alg { + case AlgNull: + params, err = tpmutil.Pack(p.Alg) + case AlgHMAC: + params, err = tpmutil.Pack(p.Alg, p.Hash) + case AlgXOR: + params, err = tpmutil.Pack(p.Alg, p.Hash, p.KDF) + default: + err = fmt.Errorf("unsupported KeyedHash Algorithm: 0x%x", p.Alg) + } + if err != nil { + return nil, fmt.Errorf("encoding Alg Params: %v", err) + } + unique, err := tpmutil.Pack(p.Unique) + if err != nil { + return nil, fmt.Errorf("encoding Unique: %v", err) + } + return concat(params, unique) +} + +func decodeKeyedHashParams(in *bytes.Buffer) (*KeyedHashParams, error) { + var p KeyedHashParams + var err error + if err = tpmutil.UnpackBuf(in, &p.Alg); err != nil { + return nil, fmt.Errorf("decoding Alg: %v", err) + } + switch p.Alg { + case AlgNull: + err = nil + case AlgHMAC: + err = tpmutil.UnpackBuf(in, &p.Hash) + case AlgXOR: + err = tpmutil.UnpackBuf(in, &p.Hash, &p.KDF) + default: + err = fmt.Errorf("unsupported KeyedHash Algorithm: 0x%x", p.Alg) + } + if err != nil { + return nil, fmt.Errorf("decoding Alg Params: %v", err) + } + if err = tpmutil.UnpackBuf(in, &p.Unique); err != nil { + return nil, fmt.Errorf("decoding Unique: %v", err) + } + return &p, nil +} + +// SymScheme represents a symmetric encryption scheme. +// Known in the specification by TPMT_SYM_DEF_OBJECT. 
+type SymScheme struct { + Alg Algorithm + KeyBits uint16 + Mode Algorithm +} + +func (s *SymScheme) encode() ([]byte, error) { + if s == nil || s.Alg.IsNull() { + return tpmutil.Pack(AlgNull) + } + return tpmutil.Pack(s.Alg, s.KeyBits, s.Mode) +} + +func decodeSymScheme(in *bytes.Buffer) (*SymScheme, error) { + var scheme SymScheme + if err := tpmutil.UnpackBuf(in, &scheme.Alg); err != nil { + return nil, fmt.Errorf("decoding Alg: %v", err) + } + if scheme.Alg == AlgNull { + return nil, nil + } + if err := tpmutil.UnpackBuf(in, &scheme.KeyBits, &scheme.Mode); err != nil { + return nil, fmt.Errorf("decoding KeyBits, Mode: %v", err) + } + return &scheme, nil +} + +// AsymScheme represents am asymmetric encryption scheme. +type AsymScheme struct { + Alg Algorithm + Hash Algorithm +} + +func (s *AsymScheme) encode() ([]byte, error) { + if s == nil || s.Alg.IsNull() { + return tpmutil.Pack(AlgNull) + } + if s.Alg.UsesHash() { + return tpmutil.Pack(s.Alg, s.Hash) + } + return tpmutil.Pack(s.Alg) +} + +// SigScheme represents a signing scheme. +type SigScheme struct { + Alg Algorithm + Hash Algorithm + Count uint32 +} + +func (s *SigScheme) encode() ([]byte, error) { + if s == nil || s.Alg.IsNull() { + return tpmutil.Pack(AlgNull) + } + if s.Alg.UsesCount() { + return tpmutil.Pack(s.Alg, s.Hash, s.Count) + } + return tpmutil.Pack(s.Alg, s.Hash) +} + +func decodeSigScheme(in *bytes.Buffer) (*SigScheme, error) { + var scheme SigScheme + if err := tpmutil.UnpackBuf(in, &scheme.Alg); err != nil { + return nil, fmt.Errorf("decoding Alg: %v", err) + } + if scheme.Alg == AlgNull { + return nil, nil + } + if err := tpmutil.UnpackBuf(in, &scheme.Hash); err != nil { + return nil, fmt.Errorf("decoding Hash: %v", err) + } + if scheme.Alg.UsesCount() { + if err := tpmutil.UnpackBuf(in, &scheme.Count); err != nil { + return nil, fmt.Errorf("decoding Count: %v", err) + } + } + return &scheme, nil +} + +// KDFScheme represents a KDF (Key Derivation Function) scheme. 
+type KDFScheme struct { + Alg Algorithm + Hash Algorithm +} + +func (s *KDFScheme) encode() ([]byte, error) { + if s == nil || s.Alg.IsNull() { + return tpmutil.Pack(AlgNull) + } + return tpmutil.Pack(s.Alg, s.Hash) +} + +func decodeKDFScheme(in *bytes.Buffer) (*KDFScheme, error) { + var scheme KDFScheme + if err := tpmutil.UnpackBuf(in, &scheme.Alg); err != nil { + return nil, fmt.Errorf("decoding Alg: %v", err) + } + if scheme.Alg == AlgNull { + return nil, nil + } + if err := tpmutil.UnpackBuf(in, &scheme.Hash); err != nil { + return nil, fmt.Errorf("decoding Hash: %v", err) + } + return &scheme, nil +} + +// Signature combines all possible signatures from RSA and ECC keys. Only one +// of RSA or ECC will be populated. +type Signature struct { + Alg Algorithm + RSA *SignatureRSA + ECC *SignatureECC +} + +// Encode serializes a Signature structure in TPM wire format. +func (s Signature) Encode() ([]byte, error) { + head, err := tpmutil.Pack(s.Alg) + if err != nil { + return nil, fmt.Errorf("encoding Alg: %v", err) + } + var signature []byte + switch s.Alg { + case AlgRSASSA, AlgRSAPSS: + if signature, err = tpmutil.Pack(s.RSA); err != nil { + return nil, fmt.Errorf("encoding RSA: %v", err) + } + case AlgECDSA: + signature, err = tpmutil.Pack(s.ECC.HashAlg, tpmutil.U16Bytes(s.ECC.R.Bytes()), tpmutil.U16Bytes(s.ECC.S.Bytes())) + if err != nil { + return nil, fmt.Errorf("encoding ECC: %v", err) + } + } + return concat(head, signature) +} + +// DecodeSignature decodes a serialized TPMT_SIGNATURE structure. 
+func DecodeSignature(in *bytes.Buffer) (*Signature, error) { + var sig Signature + if err := tpmutil.UnpackBuf(in, &sig.Alg); err != nil { + return nil, fmt.Errorf("decoding Alg: %v", err) + } + switch sig.Alg { + case AlgRSASSA, AlgRSAPSS: + sig.RSA = new(SignatureRSA) + if err := tpmutil.UnpackBuf(in, sig.RSA); err != nil { + return nil, fmt.Errorf("decoding RSA: %v", err) + } + case AlgECDSA: + sig.ECC = new(SignatureECC) + var r, s tpmutil.U16Bytes + if err := tpmutil.UnpackBuf(in, &sig.ECC.HashAlg, &r, &s); err != nil { + return nil, fmt.Errorf("decoding ECC: %v", err) + } + sig.ECC.R = big.NewInt(0).SetBytes(r) + sig.ECC.S = big.NewInt(0).SetBytes(s) + default: + return nil, fmt.Errorf("unsupported signature algorithm 0x%x", sig.Alg) + } + return &sig, nil +} + +// SignatureRSA is an RSA-specific signature value. +type SignatureRSA struct { + HashAlg Algorithm + Signature tpmutil.U16Bytes +} + +// SignatureECC is an ECC-specific signature value. +type SignatureECC struct { + HashAlg Algorithm + R *big.Int + S *big.Int +} + +// Private contains private section of a TPM key. +type Private struct { + Type Algorithm + AuthValue tpmutil.U16Bytes + SeedValue tpmutil.U16Bytes + Sensitive tpmutil.U16Bytes +} + +// Encode serializes a Private structure in TPM wire format. +func (p Private) Encode() ([]byte, error) { + if p.Type.IsNull() { + return nil, nil + } + return tpmutil.Pack(p) +} + +// AttestationData contains data attested by TPM commands (like Certify). +type AttestationData struct { + Magic uint32 + Type tpmutil.Tag + QualifiedSigner Name + ExtraData tpmutil.U16Bytes + ClockInfo ClockInfo + FirmwareVersion uint64 + AttestedCertifyInfo *CertifyInfo + AttestedQuoteInfo *QuoteInfo + AttestedCreationInfo *CreationInfo +} + +// DecodeAttestationData decode a TPMS_ATTEST message. No error is returned if +// the input has extra trailing data. 
+func DecodeAttestationData(in []byte) (*AttestationData, error) { + buf := bytes.NewBuffer(in) + + var ad AttestationData + if err := tpmutil.UnpackBuf(buf, &ad.Magic, &ad.Type); err != nil { + return nil, fmt.Errorf("decoding Magic/Type: %v", err) + } + // All attestation structures have the magic prefix + // TPMS_GENERATED_VALUE to symbolize they were created by + // the TPM when signed with an AK. + if ad.Magic != 0xff544347 { + return nil, fmt.Errorf("incorrect magic value: %x", ad.Magic) + } + + n, err := DecodeName(buf) + if err != nil { + return nil, fmt.Errorf("decoding QualifiedSigner: %v", err) + } + ad.QualifiedSigner = *n + if err := tpmutil.UnpackBuf(buf, &ad.ExtraData, &ad.ClockInfo, &ad.FirmwareVersion); err != nil { + return nil, fmt.Errorf("decoding ExtraData/ClockInfo/FirmwareVersion: %v", err) + } + + // The spec specifies several other types of attestation data. We only need + // parsing of Certify & Creation attestation data for now. If you need + // support for other attestation types, add them here. + switch ad.Type { + case TagAttestCertify: + if ad.AttestedCertifyInfo, err = decodeCertifyInfo(buf); err != nil { + return nil, fmt.Errorf("decoding AttestedCertifyInfo: %v", err) + } + case TagAttestCreation: + if ad.AttestedCreationInfo, err = decodeCreationInfo(buf); err != nil { + return nil, fmt.Errorf("decoding AttestedCreationInfo: %v", err) + } + case TagAttestQuote: + if ad.AttestedQuoteInfo, err = decodeQuoteInfo(buf); err != nil { + return nil, fmt.Errorf("decoding AttestedQuoteInfo: %v", err) + } + default: + return nil, fmt.Errorf("only Quote, Certify & Creation attestation structures are supported, got type 0x%x", ad.Type) + } + + return &ad, nil +} + +// Encode serializes an AttestationData structure in TPM wire format. 
+func (ad AttestationData) Encode() ([]byte, error) { + head, err := tpmutil.Pack(ad.Magic, ad.Type) + if err != nil { + return nil, fmt.Errorf("encoding Magic, Type: %v", err) + } + signer, err := ad.QualifiedSigner.Encode() + if err != nil { + return nil, fmt.Errorf("encoding QualifiedSigner: %v", err) + } + tail, err := tpmutil.Pack(ad.ExtraData, ad.ClockInfo, ad.FirmwareVersion) + if err != nil { + return nil, fmt.Errorf("encoding ExtraData, ClockInfo, FirmwareVersion: %v", err) + } + + var info []byte + switch ad.Type { + case TagAttestCertify: + if info, err = ad.AttestedCertifyInfo.encode(); err != nil { + return nil, fmt.Errorf("encoding AttestedCertifyInfo: %v", err) + } + case TagAttestCreation: + if info, err = ad.AttestedCreationInfo.encode(); err != nil { + return nil, fmt.Errorf("encoding AttestedCreationInfo: %v", err) + } + case TagAttestQuote: + if info, err = ad.AttestedQuoteInfo.encode(); err != nil { + return nil, fmt.Errorf("encoding AttestedQuoteInfo: %v", err) + } + default: + return nil, fmt.Errorf("only Quote, Certify & Creation attestation structures are supported, got type 0x%x", ad.Type) + } + + return concat(head, signer, tail, info) +} + +// CreationInfo contains Creation-specific data for TPMS_ATTEST. +type CreationInfo struct { + Name Name + // Most TPM2B_Digest structures contain a TPMU_HA structure + // and get parsed to HashValue. This is never the case for the + // digest in TPMS_CREATION_INFO. 
+ OpaqueDigest tpmutil.U16Bytes +} + +func decodeCreationInfo(in *bytes.Buffer) (*CreationInfo, error) { + var ci CreationInfo + + n, err := DecodeName(in) + if err != nil { + return nil, fmt.Errorf("decoding Name: %v", err) + } + ci.Name = *n + + if err := tpmutil.UnpackBuf(in, &ci.OpaqueDigest); err != nil { + return nil, fmt.Errorf("decoding Digest: %v", err) + } + + return &ci, nil +} + +func (ci CreationInfo) encode() ([]byte, error) { + n, err := ci.Name.Encode() + if err != nil { + return nil, fmt.Errorf("encoding Name: %v", err) + } + + d, err := tpmutil.Pack(ci.OpaqueDigest) + if err != nil { + return nil, fmt.Errorf("encoding Digest: %v", err) + } + + return concat(n, d) +} + +// CertifyInfo contains Certify-specific data for TPMS_ATTEST. +type CertifyInfo struct { + Name Name + QualifiedName Name +} + +func decodeCertifyInfo(in *bytes.Buffer) (*CertifyInfo, error) { + var ci CertifyInfo + + n, err := DecodeName(in) + if err != nil { + return nil, fmt.Errorf("decoding Name: %v", err) + } + ci.Name = *n + + n, err = DecodeName(in) + if err != nil { + return nil, fmt.Errorf("decoding QualifiedName: %v", err) + } + ci.QualifiedName = *n + + return &ci, nil +} + +func (ci CertifyInfo) encode() ([]byte, error) { + n, err := ci.Name.Encode() + if err != nil { + return nil, fmt.Errorf("encoding Name: %v", err) + } + qn, err := ci.QualifiedName.Encode() + if err != nil { + return nil, fmt.Errorf("encoding QualifiedName: %v", err) + } + return concat(n, qn) +} + +// QuoteInfo represents a TPMS_QUOTE_INFO structure. 
+type QuoteInfo struct { + PCRSelection PCRSelection + PCRDigest tpmutil.U16Bytes +} + +func decodeQuoteInfo(in *bytes.Buffer) (*QuoteInfo, error) { + var out QuoteInfo + sel, err := decodeOneTPMLPCRSelection(in) + if err != nil { + return nil, fmt.Errorf("decoding PCRSelection: %v", err) + } + out.PCRSelection = sel + + if err := tpmutil.UnpackBuf(in, &out.PCRDigest); err != nil { + return nil, fmt.Errorf("decoding PCRDigest: %v", err) + } + return &out, nil +} + +func (qi QuoteInfo) encode() ([]byte, error) { + sel, err := encodeTPMLPCRSelection(qi.PCRSelection) + if err != nil { + return nil, fmt.Errorf("encoding PCRSelection: %v", err) + } + + digest, err := tpmutil.Pack(qi.PCRDigest) + if err != nil { + return nil, fmt.Errorf("encoding PCRDigest: %v", err) + } + + return concat(sel, digest) +} + +// IDObject represents an encrypted credential bound to a TPM object. +type IDObject struct { + IntegrityHMAC tpmutil.U16Bytes + // EncIdentity is packed raw, as the bytes representing the size + // of the credential value are present within the encrypted blob. + EncIdentity tpmutil.RawBytes +} + +// CreationData describes the attributes and environment for an object created +// on the TPM. This structure encodes/decodes to/from TPMS_CREATION_DATA. +type CreationData struct { + PCRSelection PCRSelection + PCRDigest tpmutil.U16Bytes + Locality byte + ParentNameAlg Algorithm + ParentName Name + ParentQualifiedName Name + OutsideInfo tpmutil.U16Bytes +} + +// EncodeCreationData encodes byte array to TPMS_CREATION_DATA message. 
+func (cd *CreationData) EncodeCreationData() ([]byte, error) { + sel, err := encodeTPMLPCRSelection(cd.PCRSelection) + if err != nil { + return nil, fmt.Errorf("encoding PCRSelection: %v", err) + } + d, err := tpmutil.Pack(cd.PCRDigest, cd.Locality, cd.ParentNameAlg) + if err != nil { + return nil, fmt.Errorf("encoding PCRDigest, Locality, ParentNameAlg: %v", err) + } + pn, err := cd.ParentName.Encode() + if err != nil { + return nil, fmt.Errorf("encoding ParentName: %v", err) + } + pqn, err := cd.ParentQualifiedName.Encode() + if err != nil { + return nil, fmt.Errorf("encoding ParentQualifiedName: %v", err) + } + o, err := tpmutil.Pack(cd.OutsideInfo) + if err != nil { + return nil, fmt.Errorf("encoding OutsideInfo: %v", err) + } + return concat(sel, d, pn, pqn, o) +} + +// DecodeCreationData decodes a TPMS_CREATION_DATA message. No error is +// returned if the input has extra trailing data. +func DecodeCreationData(buf []byte) (*CreationData, error) { + in := bytes.NewBuffer(buf) + var out CreationData + + sel, err := decodeOneTPMLPCRSelection(in) + if err != nil { + return nil, fmt.Errorf("decodeOneTPMLPCRSelection returned error %v", err) + } + out.PCRSelection = sel + + if err := tpmutil.UnpackBuf(in, &out.PCRDigest, &out.Locality, &out.ParentNameAlg); err != nil { + return nil, fmt.Errorf("decoding PCRDigest, Locality, ParentNameAlg: %v", err) + } + + n, err := DecodeName(in) + if err != nil { + return nil, fmt.Errorf("decoding ParentName: %v", err) + } + out.ParentName = *n + if n, err = DecodeName(in); err != nil { + return nil, fmt.Errorf("decoding ParentQualifiedName: %v", err) + } + out.ParentQualifiedName = *n + + if err := tpmutil.UnpackBuf(in, &out.OutsideInfo); err != nil { + return nil, fmt.Errorf("decoding OutsideInfo: %v", err) + } + + return &out, nil +} + +// Name represents a TPM2B_NAME, a name for TPM entities. Only one of +// Handle or Digest should be set. 
+type Name struct { + Handle *tpmutil.Handle + Digest *HashValue +} + +// DecodeName deserializes a Name hash from the TPM wire format. +func DecodeName(in *bytes.Buffer) (*Name, error) { + var nameBuf tpmutil.U16Bytes + if err := tpmutil.UnpackBuf(in, &nameBuf); err != nil { + return nil, err + } + + name := new(Name) + switch len(nameBuf) { + case 0: + // No name is present. + case 4: + name.Handle = new(tpmutil.Handle) + if err := tpmutil.UnpackBuf(bytes.NewBuffer(nameBuf), name.Handle); err != nil { + return nil, fmt.Errorf("decoding Handle: %v", err) + } + default: + var err error + name.Digest, err = decodeHashValue(bytes.NewBuffer(nameBuf)) + if err != nil { + return nil, fmt.Errorf("decoding Digest: %v", err) + } + } + return name, nil +} + +// Encode serializes a Name hash into the TPM wire format. +func (n Name) Encode() ([]byte, error) { + var buf []byte + var err error + switch { + case n.Handle != nil: + if buf, err = tpmutil.Pack(*n.Handle); err != nil { + return nil, fmt.Errorf("encoding Handle: %v", err) + } + case n.Digest != nil: + if buf, err = n.Digest.Encode(); err != nil { + return nil, fmt.Errorf("encoding Digest: %v", err) + } + default: + // Name is empty, which is valid. + } + return tpmutil.Pack(tpmutil.U16Bytes(buf)) +} + +// MatchesPublic compares Digest in Name against given Public structure. Note: +// this only works for regular Names, not Qualified Names. +func (n Name) MatchesPublic(p Public) (bool, error) { + if n.Digest == nil { + return false, errors.New("Name doesn't have a Digest, can't compare to Public") + } + expected, err := p.Name() + if err != nil { + return false, err + } + // No secrets, so no constant-time comparison + return bytes.Equal(expected.Digest.Value, n.Digest.Value), nil +} + +// HashValue is an algorithm-specific hash value. 
+type HashValue struct { + Alg Algorithm + Value tpmutil.U16Bytes +} + +func decodeHashValue(in *bytes.Buffer) (*HashValue, error) { + var hv HashValue + if err := tpmutil.UnpackBuf(in, &hv.Alg); err != nil { + return nil, fmt.Errorf("decoding Alg: %v", err) + } + hfn, err := hv.Alg.Hash() + if err != nil { + return nil, err + } + hv.Value = make(tpmutil.U16Bytes, hfn.Size()) + if _, err := in.Read(hv.Value); err != nil { + return nil, fmt.Errorf("decoding Value: %v", err) + } + return &hv, nil +} + +// Encode represents the given hash value as a TPMT_HA structure. +func (hv HashValue) Encode() ([]byte, error) { + return tpmutil.Pack(hv.Alg, tpmutil.RawBytes(hv.Value)) +} + +// ClockInfo contains TPM state info included in AttestationData. +type ClockInfo struct { + Clock uint64 + ResetCount uint32 + RestartCount uint32 + Safe byte +} + +// AlgorithmAttributes represents a TPMA_ALGORITHM value. +type AlgorithmAttributes uint32 + +// AlgorithmDescription represents a TPMS_ALGORITHM_DESCRIPTION structure. +type AlgorithmDescription struct { + ID Algorithm + Attributes AlgorithmAttributes +} + +// TaggedProperty represents a TPMS_TAGGED_PROPERTY structure. +type TaggedProperty struct { + Tag TPMProp + Value uint32 +} + +// Ticket represents evidence the TPM previously processed +// information. +type Ticket struct { + Type tpmutil.Tag + Hierarchy tpmutil.Handle + Digest tpmutil.U16Bytes +} + +// AuthCommand represents a TPMS_AUTH_COMMAND. This structure encapsulates parameters +// which authorize the use of a given handle or parameter. +type AuthCommand struct { + Session tpmutil.Handle + Nonce tpmutil.U16Bytes + Attributes SessionAttributes + Auth tpmutil.U16Bytes +} + +// TPMLDigest represents the TPML_Digest structure +// It is used to convey a list of digest values. 
+// This type is used in TPM2_PolicyOR() and in TPM2_PCR_Read() +type TPMLDigest struct { + Digests []tpmutil.U16Bytes +} + +// Encode converts the TPMLDigest structure into a byte slice +func (list *TPMLDigest) Encode() ([]byte, error) { + res, err := tpmutil.Pack(uint32(len(list.Digests))) + if err != nil { + return nil, err + } + for _, item := range list.Digests { + b, err := tpmutil.Pack(item) + if err != nil { + return nil, err + } + res = append(res, b...) + + } + return res, nil +} + +// DecodeTPMLDigest decodes a TPML_Digest part of a message. +func DecodeTPMLDigest(buf []byte) (*TPMLDigest, error) { + in := bytes.NewBuffer(buf) + var tpmld TPMLDigest + var count uint32 + if err := binary.Read(in, binary.BigEndian, &count); err != nil { + return nil, fmt.Errorf("decoding TPML_Digest: %v", err) + } + for in.Len() > 0 { + var hash tpmutil.U16Bytes + if err := hash.TPMUnmarshal(in); err != nil { + return nil, err + } + tpmld.Digests = append(tpmld.Digests, hash) + } + if count != uint32(len(tpmld.Digests)) { + return nil, fmt.Errorf("expected size and read size does not match") + } + return &tpmld, nil +} diff --git a/vendor/github.com/google/go-tpm/legacy/tpm2/tpm2.go b/vendor/github.com/google/go-tpm/legacy/tpm2/tpm2.go new file mode 100644 index 0000000000..18d5a96033 --- /dev/null +++ b/vendor/github.com/google/go-tpm/legacy/tpm2/tpm2.go @@ -0,0 +1,2326 @@ +// Copyright (c) 2018, Google LLC All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +// Package tpm2 supports direct communication with a TPM 2.0 device under Linux. +package tpm2 + +import ( + "bytes" + "crypto" + "fmt" + "io" + + "github.com/google/go-tpm/tpmutil" +) + +// GetRandom gets random bytes from the TPM. +// size is sent to the TPM as the only parameter of the GetRandom command; the +// reply is unpacked as a single size-prefixed byte string. +func GetRandom(rw io.ReadWriter, size uint16) ([]byte, error) { + resp, err := runCommand(rw, TagNoSessions, CmdGetRandom, size) + if err != nil { + return nil, err + } + + var randBytes tpmutil.U16Bytes + if _, err := tpmutil.Unpack(resp, &randBytes); err != nil { + return nil, err + } + return randBytes, nil +} + +// FlushContext removes the object or session referenced by handle from +// the TPM. This must be called for any loaded handle to avoid out-of-memory +// errors in TPM. +func FlushContext(rw io.ReadWriter, handle tpmutil.Handle) error { + _, err := runCommand(rw, TagNoSessions, CmdFlushContext, handle) + return err +} + +// encodeTPMLPCRSelection packs sel as a count-prefixed list of PCR selections. +// If sel is empty, or any selection in it lists no PCRs, the result is an +// empty list (a zero count). A PCR index of 8*sizeOfPCRSelect or more is +// rejected with an error. +func encodeTPMLPCRSelection(sel ...PCRSelection) ([]byte, error) { + if len(sel) == 0 { + return tpmutil.Pack(uint32(0)) + } + + // PCR selection is a variable-size bitmask, where position of a set bit is + // the selected PCR index. + // Size of the bitmask in bytes is pre-pended. It should be at least + // sizeOfPCRSelect. + // + // For example, selecting PCRs 3 and 9 looks like: + // size(3) mask mask mask + // 00000011 00000000 00000001 00000100 + var retBytes []byte + for _, s := range sel { + if len(s.PCRs) == 0 { + return tpmutil.Pack(uint32(0)) + } + + ts := tpmsPCRSelection{ + Hash: s.Hash, + Size: sizeOfPCRSelect, + PCRs: make(tpmutil.RawBytes, sizeOfPCRSelect), + } + + // s[i].PCRs parameter is indexes of PCRs, convert that to set bits.
+ for _, n := range s.PCRs { + if n >= 8*sizeOfPCRSelect { + return nil, fmt.Errorf("PCR index %d is out of range (exceeds maximum value %d)", n, 8*sizeOfPCRSelect-1) + } + byteNum := n / 8 + bytePos := byte(1 << byte(n%8)) + ts.PCRs[byteNum] |= bytePos + } + + tmpBytes, err := tpmutil.Pack(ts) + if err != nil { + return nil, err + } + + retBytes = append(retBytes, tmpBytes...) + } + tmpSize, err := tpmutil.Pack(uint32(len(sel))) + if err != nil { + return nil, err + } + retBytes = append(tmpSize, retBytes...) + + return retBytes, nil +} + +func decodeTPMLPCRSelection(buf *bytes.Buffer) ([]PCRSelection, error) { + var count uint32 + var sel []PCRSelection + + // This unpacks buffer which is of type TPMLPCRSelection + // and returns the count of TPMSPCRSelections. + if err := tpmutil.UnpackBuf(buf, &count); err != nil { + return sel, err + } + + var ts tpmsPCRSelection + for i := 0; i < int(count); i++ { + var s PCRSelection + if err := tpmutil.UnpackBuf(buf, &ts.Hash, &ts.Size); err != nil { + return sel, err + } + ts.PCRs = make(tpmutil.RawBytes, ts.Size) + if _, err := buf.Read(ts.PCRs); err != nil { + return sel, err + } + s.Hash = ts.Hash + for j := 0; j < int(ts.Size); j++ { + for k := 0; k < 8; k++ { + set := ts.PCRs[j] & byte(1< 0, nil + case CapabilityAlgs: + var numAlgs uint32 + if err := tpmutil.UnpackBuf(buf, &numAlgs); err != nil { + return nil, false, fmt.Errorf("could not unpack algorithm count: %v", err) + } + + var algs []interface{} + for i := 0; i < int(numAlgs); i++ { + var alg AlgorithmDescription + if err := tpmutil.UnpackBuf(buf, &alg); err != nil { + return nil, false, fmt.Errorf("could not unpack algorithm description: %v", err) + } + algs = append(algs, alg) + } + return algs, moreData > 0, nil + case CapabilityTPMProperties: + var numProps uint32 + if err := tpmutil.UnpackBuf(buf, &numProps); err != nil { + return nil, false, fmt.Errorf("could not unpack fixed properties count: %v", err) + } + + var props []interface{} + for i := 0; i < 
int(numProps); i++ { + var prop TaggedProperty + if err := tpmutil.UnpackBuf(buf, &prop); err != nil { + return nil, false, fmt.Errorf("could not unpack tagged property: %v", err) + } + props = append(props, prop) + } + return props, moreData > 0, nil + + case CapabilityPCRs: + var pcrss []interface{} + pcrs, err := decodeTPMLPCRSelection(buf) + if err != nil { + return nil, false, fmt.Errorf("could not unpack pcr selection: %v", err) + } + for i := 0; i < len(pcrs); i++ { + pcrss = append(pcrss, pcrs[i]) + } + + return pcrss, moreData > 0, nil + + default: + return nil, false, fmt.Errorf("unsupported capability %v", capReported) + } +} + +// GetCapability returns various information about the TPM state. +// +// The element type of vals depends on the capability the TPM reports: +// CapabilityHandles (list active handles) yields tpmutil.Handle values, +// CapabilityAlgs (list supported algorithms) yields AlgorithmDescription +// values, CapabilityTPMProperties yields TaggedProperty values and +// CapabilityPCRs yields PCRSelection values. Any other reported capability +// is an error. +// +// Note that the command is sent with property before count, the reverse of +// this function's parameter order. +// +// moreData is true if the TPM indicated that more data is available. Follow +// the spec for the capability in question on how to query for more data. +func GetCapability(rw io.ReadWriter, capa Capability, count, property uint32) (vals []interface{}, moreData bool, err error) { + resp, err := runCommand(rw, TagNoSessions, CmdGetCapability, capa, property, count) + if err != nil { + return nil, false, err + } + return decodeGetCapability(resp) +} + +// GetManufacturer returns the manufacturer ID: the packed Value of the single +// Manufacturer property requested from the TPM. It returns an error, rather +// than panicking, if the TPM answers with no properties or with a capability +// of a different kind. +func GetManufacturer(rw io.ReadWriter) ([]byte, error) { + caps, _, err := GetCapability(rw, CapabilityTPMProperties, 1, uint32(Manufacturer)) + if err != nil { + return nil, err + } + + // The reply is TPM-controlled: it may be empty or hold another type. + if len(caps) == 0 { + return nil, fmt.Errorf("TPM returned no properties for the Manufacturer query") + } + prop, ok := caps[0].(TaggedProperty) + if !ok { + return nil, fmt.Errorf("unexpected capability data %T, want TaggedProperty", caps[0]) + } + return tpmutil.Pack(prop.Value) +} + +// encodeAuthArea packs every section in order and prefixes the result with +// the combined length of the packed sections as a 32-bit value. +func encodeAuthArea(sections ...AuthCommand) ([]byte, error) { + var res tpmutil.RawBytes + for _, s := range sections { + buf, err := tpmutil.Pack(s) + if err != nil { + return nil, err + } + res = append(res, buf...)
+ } + + // Prefix the packed sections with their total length in bytes. + size, err := tpmutil.Pack(uint32(len(res))) + if err != nil { + return nil, err + } + + return concat(size, res) +} + +// encodePCREvent builds the PCREvent command body: the PCR handle, a password +// authorization area using EmptyAuth, and eventData as a size-prefixed blob. +func encodePCREvent(pcr tpmutil.Handle, eventData []byte) ([]byte, error) { + ha, err := tpmutil.Pack(pcr) + if err != nil { + return nil, err + } + auth, err := encodeAuthArea(AuthCommand{Session: HandlePasswordSession, Attributes: AttrContinueSession, Auth: EmptyAuth}) + if err != nil { + return nil, err + } + event, err := tpmutil.Pack(tpmutil.U16Bytes(eventData)) + if err != nil { + return nil, err + } + return concat(ha, auth, event) +} + +// PCREvent writes an update to the specified PCR. +// The command is authorized with an empty password; the TPM's response body +// is discarded and only the error is reported. +func PCREvent(rw io.ReadWriter, pcr tpmutil.Handle, eventData []byte) error { + Cmd, err := encodePCREvent(pcr, eventData) + if err != nil { + return err + } + _, err = runCommand(rw, TagSessions, CmdPCREvent, tpmutil.RawBytes(Cmd)) + return err +} + +// encodeSensitiveArea packs s and then wraps the packed bytes in a +// size-prefixed blob. +func encodeSensitiveArea(s tpmsSensitiveCreate) ([]byte, error) { + // TPMS_SENSITIVE_CREATE + buf, err := tpmutil.Pack(s) + if err != nil { + return nil, err + } + // TPM2B_SENSITIVE_CREATE + return tpmutil.Pack(tpmutil.U16Bytes(buf)) +} + +// encodeCreate works for both TPM2_Create and TPM2_CreatePrimary.
+func encodeCreate(owner tpmutil.Handle, sel PCRSelection, auth AuthCommand, ownerPassword string, sensitiveData []byte, pub Public, outsideInfo []byte) ([]byte, error) { + parent, err := tpmutil.Pack(owner) + if err != nil { + return nil, err + } + encodedAuth, err := encodeAuthArea(auth) + if err != nil { + return nil, err + } + inSensitive, err := encodeSensitiveArea(tpmsSensitiveCreate{ + UserAuth: []byte(ownerPassword), + Data: sensitiveData, + }) + if err != nil { + return nil, err + } + inPublic, err := pub.Encode() + if err != nil { + return nil, err + } + publicBlob, err := tpmutil.Pack(tpmutil.U16Bytes(inPublic)) + if err != nil { + return nil, err + } + outsideInfoBlob, err := tpmutil.Pack(tpmutil.U16Bytes(outsideInfo)) + if err != nil { + return nil, err + } + creationPCR, err := encodeTPMLPCRSelection(sel) + if err != nil { + return nil, err + } + return concat( + parent, + encodedAuth, + inSensitive, + publicBlob, + outsideInfoBlob, + creationPCR, + ) +} + +func decodeCreatePrimary(in []byte) (handle tpmutil.Handle, public, creationData, creationHash tpmutil.U16Bytes, ticket Ticket, creationName tpmutil.U16Bytes, err error) { + var paramSize uint32 + + buf := bytes.NewBuffer(in) + // Handle and auth data. + if err := tpmutil.UnpackBuf(buf, &handle, ¶mSize); err != nil { + return 0, nil, nil, nil, Ticket{}, nil, fmt.Errorf("decoding handle, paramSize: %v", err) + } + + if err := tpmutil.UnpackBuf(buf, &public, &creationData, &creationHash, &ticket, &creationName); err != nil { + return 0, nil, nil, nil, Ticket{}, nil, fmt.Errorf("decoding public, creationData, creationHash, ticket, creationName: %v", err) + } + + if _, err := DecodeCreationData(creationData); err != nil { + return 0, nil, nil, nil, Ticket{}, nil, fmt.Errorf("parsing CreationData: %v", err) + } + return handle, public, creationData, creationHash, ticket, creationName, err +} + +// CreatePrimary initializes the primary key in a given hierarchy. 
+// The second return value is the public part of the generated key. +func CreatePrimary(rw io.ReadWriter, owner tpmutil.Handle, sel PCRSelection, parentPassword, ownerPassword string, p Public) (tpmutil.Handle, crypto.PublicKey, error) { + hnd, public, _, _, _, _, err := CreatePrimaryEx(rw, owner, sel, parentPassword, ownerPassword, p) + if err != nil { + return 0, nil, err + } + + pub, err := DecodePublic(public) + if err != nil { + return 0, nil, fmt.Errorf("parsing public: %v", err) + } + + pubKey, err := pub.Key() + if err != nil { + return 0, nil, fmt.Errorf("extracting cryto.PublicKey from Public part of primary key: %v", err) + } + + return hnd, pubKey, err +} + +// CreatePrimaryEx initializes the primary key in a given hierarchy. +// This function differs from CreatePrimary in that all response elements +// are returned, and they are returned in relatively raw form. +func CreatePrimaryEx(rw io.ReadWriter, owner tpmutil.Handle, sel PCRSelection, parentPassword, ownerPassword string, pub Public) (keyHandle tpmutil.Handle, public, creationData, creationHash []byte, ticket Ticket, creationName []byte, err error) { + auth := AuthCommand{Session: HandlePasswordSession, Attributes: AttrContinueSession, Auth: []byte(parentPassword)} + Cmd, err := encodeCreate(owner, sel, auth, ownerPassword, nil /*inSensitive*/, pub, nil /*OutsideInfo*/) + if err != nil { + return 0, nil, nil, nil, Ticket{}, nil, err + } + resp, err := runCommand(rw, TagSessions, CmdCreatePrimary, tpmutil.RawBytes(Cmd)) + if err != nil { + return 0, nil, nil, nil, Ticket{}, nil, err + } + + return decodeCreatePrimary(resp) +} + +// CreatePrimaryRawTemplate is CreatePrimary, but with the public template +// (TPMT_PUBLIC) provided pre-encoded. This is commonly used with key templates +// stored in NV RAM. 
+func CreatePrimaryRawTemplate(rw io.ReadWriter, owner tpmutil.Handle, sel PCRSelection, parentPassword, ownerPassword string, public []byte) (tpmutil.Handle, crypto.PublicKey, error) { + pub, err := DecodePublic(public) + if err != nil { + return 0, nil, fmt.Errorf("parsing input template: %v", err) + } + return CreatePrimary(rw, owner, sel, parentPassword, ownerPassword, pub) +} + +func decodeReadPublic(in []byte) (Public, []byte, []byte, error) { + var resp struct { + Public tpmutil.U16Bytes + Name tpmutil.U16Bytes + QualifiedName tpmutil.U16Bytes + } + if _, err := tpmutil.Unpack(in, &resp); err != nil { + return Public{}, nil, nil, err + } + pub, err := DecodePublic(resp.Public) + if err != nil { + return Public{}, nil, nil, err + } + return pub, resp.Name, resp.QualifiedName, nil +} + +// ReadPublic reads the public part of the object under handle. +// Returns the public data, name and qualified name. +func ReadPublic(rw io.ReadWriter, handle tpmutil.Handle) (Public, []byte, []byte, error) { + resp, err := runCommand(rw, TagNoSessions, CmdReadPublic, handle) + if err != nil { + return Public{}, nil, nil, err + } + + return decodeReadPublic(resp) +} + +func decodeCreate(in []byte) (private, public, creationData, creationHash tpmutil.U16Bytes, creationTicket Ticket, err error) { + buf := bytes.NewBuffer(in) + var paramSize uint32 + if err := tpmutil.UnpackBuf(buf, ¶mSize, &private, &public, &creationData, &creationHash, &creationTicket); err != nil { + return nil, nil, nil, nil, Ticket{}, fmt.Errorf("decoding Handle, Private, Public, CreationData, CreationHash, CreationTicket: %v", err) + } + if err != nil { + return nil, nil, nil, nil, Ticket{}, fmt.Errorf("decoding CreationTicket: %v", err) + } + if _, err := DecodeCreationData(creationData); err != nil { + return nil, nil, nil, nil, Ticket{}, fmt.Errorf("decoding CreationData: %v", err) + } + return private, public, creationData, creationHash, creationTicket, nil +} + +func create(rw io.ReadWriter, 
parentHandle tpmutil.Handle, auth AuthCommand, objectPassword string, sensitiveData []byte, pub Public, pcrSelection PCRSelection, outsideInfo []byte) (private, public, creationData, creationHash []byte, creationTicket Ticket, err error) { + cmd, err := encodeCreate(parentHandle, pcrSelection, auth, objectPassword, sensitiveData, pub, outsideInfo) + if err != nil { + return nil, nil, nil, nil, Ticket{}, err + } + resp, err := runCommand(rw, TagSessions, CmdCreate, tpmutil.RawBytes(cmd)) + if err != nil { + return nil, nil, nil, nil, Ticket{}, err + } + return decodeCreate(resp) +} + +// CreateKey creates a new key pair under the owner handle. +// Returns private key and public key blobs as well as the +// creation data, a hash of said data and the creation ticket. +func CreateKey(rw io.ReadWriter, owner tpmutil.Handle, sel PCRSelection, parentPassword, ownerPassword string, pub Public) (private, public, creationData, creationHash []byte, creationTicket Ticket, err error) { + auth := AuthCommand{Session: HandlePasswordSession, Attributes: AttrContinueSession, Auth: []byte(parentPassword)} + return create(rw, owner, auth, ownerPassword, nil /*inSensitive*/, pub, sel, nil /*OutsideInfo*/) +} + +// CreateKeyUsingAuth creates a new key pair under the owner handle using the +// provided AuthCommand. Returns private key and public key blobs as well as +// the creation data, a hash of said data, and the creation ticket. +func CreateKeyUsingAuth(rw io.ReadWriter, owner tpmutil.Handle, sel PCRSelection, auth AuthCommand, ownerPassword string, pub Public) (private, public, creationData, creationHash []byte, creationTicket Ticket, err error) { + return create(rw, owner, auth, ownerPassword, nil /*inSensitive*/, pub, sel, nil /*OutsideInfo*/) +} + +// CreateKeyWithSensitive is very similar to CreateKey, except +// that it can take in a piece of sensitive data. 
+func CreateKeyWithSensitive(rw io.ReadWriter, owner tpmutil.Handle, sel PCRSelection, parentPassword, ownerPassword string, pub Public, sensitive []byte) (private, public, creationData, creationHash []byte, creationTicket Ticket, err error) { + auth := AuthCommand{Session: HandlePasswordSession, Attributes: AttrContinueSession, Auth: []byte(parentPassword)} + return create(rw, owner, auth, ownerPassword, sensitive, pub, sel, nil /*OutsideInfo*/) +} + +// CreateKeyWithOutsideInfo is very similar to CreateKey, except +// that it returns the outside information. +func CreateKeyWithOutsideInfo(rw io.ReadWriter, owner tpmutil.Handle, sel PCRSelection, parentPassword, ownerPassword string, pub Public, outsideInfo []byte) (private, public, creationData, creationHash []byte, creationTicket Ticket, err error) { + auth := AuthCommand{Session: HandlePasswordSession, Attributes: AttrContinueSession, Auth: []byte(parentPassword)} + return create(rw, owner, auth, ownerPassword, nil /*inSensitive*/, pub, sel, outsideInfo) +} + +// Seal creates a data blob object that seals the sensitive data under a parent and with a +// password and auth policy. Access to the parent must be available with a simple password. +// Returns private and public portions of the created object. 
+func Seal(rw io.ReadWriter, parentHandle tpmutil.Handle, parentPassword, objectPassword string, objectAuthPolicy []byte, sensitiveData []byte) ([]byte, []byte, error) { + inPublic := Public{ + Type: AlgKeyedHash, + NameAlg: AlgSHA256, + Attributes: FlagFixedTPM | FlagFixedParent, + AuthPolicy: objectAuthPolicy, + } + auth := AuthCommand{Session: HandlePasswordSession, Attributes: AttrContinueSession, Auth: []byte(parentPassword)} + private, public, _, _, _, err := create(rw, parentHandle, auth, objectPassword, sensitiveData, inPublic, PCRSelection{}, nil /*OutsideInfo*/) + if err != nil { + return nil, nil, err + } + return private, public, nil +} + +func encodeImport(parentHandle tpmutil.Handle, auth AuthCommand, publicBlob, privateBlob, symSeed, encryptionKey tpmutil.U16Bytes, sym *SymScheme) ([]byte, error) { + ph, err := tpmutil.Pack(parentHandle) + if err != nil { + return nil, err + } + encodedAuth, err := encodeAuthArea(auth) + if err != nil { + return nil, err + } + data, err := tpmutil.Pack(encryptionKey, publicBlob, privateBlob, symSeed) + if err != nil { + return nil, err + } + encodedScheme, err := sym.encode() + if err != nil { + return nil, err + } + + return concat(ph, encodedAuth, data, encodedScheme) +} + +func decodeImport(resp []byte) ([]byte, error) { + var paramSize uint32 + var outPrivate tpmutil.U16Bytes + _, err := tpmutil.Unpack(resp, ¶mSize, &outPrivate) + return outPrivate, err +} + +// Import allows a user to import a key created on a different computer +// or in a different TPM. The publicBlob and privateBlob must always be +// provided. symSeed should be non-nil iff an "outer wrapper" is used. Both of +// encryptionKey and sym should be non-nil iff an "inner wrapper" is used. 
+func Import(rw io.ReadWriter, parentHandle tpmutil.Handle, auth AuthCommand, publicBlob, privateBlob, symSeed, encryptionKey []byte, sym *SymScheme) ([]byte, error) { + Cmd, err := encodeImport(parentHandle, auth, publicBlob, privateBlob, symSeed, encryptionKey, sym) + if err != nil { + return nil, err + } + resp, err := runCommand(rw, TagSessions, CmdImport, tpmutil.RawBytes(Cmd)) + if err != nil { + return nil, err + } + return decodeImport(resp) +} + +func encodeLoad(parentHandle tpmutil.Handle, auth AuthCommand, publicBlob, privateBlob tpmutil.U16Bytes) ([]byte, error) { + ah, err := tpmutil.Pack(parentHandle) + if err != nil { + return nil, err + } + encodedAuth, err := encodeAuthArea(auth) + if err != nil { + return nil, err + } + params, err := tpmutil.Pack(privateBlob, publicBlob) + if err != nil { + return nil, err + } + return concat(ah, encodedAuth, params) +} + +func decodeLoad(in []byte) (tpmutil.Handle, []byte, error) { + var handle tpmutil.Handle + var paramSize uint32 + var name tpmutil.U16Bytes + + if _, err := tpmutil.Unpack(in, &handle, ¶mSize, &name); err != nil { + return 0, nil, err + } + + // Re-encode the name as a TPM2B_NAME so it can be parsed by DecodeName(). + b := &bytes.Buffer{} + if err := name.TPMMarshal(b); err != nil { + return 0, nil, err + } + return handle, b.Bytes(), nil +} + +// Load loads public/private blobs into an object in the TPM. +// Returns loaded object handle and its name. +func Load(rw io.ReadWriter, parentHandle tpmutil.Handle, parentAuth string, publicBlob, privateBlob []byte) (tpmutil.Handle, []byte, error) { + auth := AuthCommand{Session: HandlePasswordSession, Attributes: AttrContinueSession, Auth: []byte(parentAuth)} + return LoadUsingAuth(rw, parentHandle, auth, publicBlob, privateBlob) +} + +// LoadUsingAuth loads public/private blobs into an object in the TPM using the +// provided AuthCommand. Returns loaded object handle and its name. 
+func LoadUsingAuth(rw io.ReadWriter, parentHandle tpmutil.Handle, auth AuthCommand, publicBlob, privateBlob []byte) (tpmutil.Handle, []byte, error) { + Cmd, err := encodeLoad(parentHandle, auth, publicBlob, privateBlob) + if err != nil { + return 0, nil, err + } + resp, err := runCommand(rw, TagSessions, CmdLoad, tpmutil.RawBytes(Cmd)) + if err != nil { + return 0, nil, err + } + return decodeLoad(resp) +} + +func encodeLoadExternal(pub Public, private Private, hierarchy tpmutil.Handle) ([]byte, error) { + privateBlob, err := private.Encode() + if err != nil { + return nil, err + } + publicBlob, err := pub.Encode() + if err != nil { + return nil, err + } + + return tpmutil.Pack(tpmutil.U16Bytes(privateBlob), tpmutil.U16Bytes(publicBlob), hierarchy) +} + +func decodeLoadExternal(in []byte) (tpmutil.Handle, []byte, error) { + var handle tpmutil.Handle + var name tpmutil.U16Bytes + + if _, err := tpmutil.Unpack(in, &handle, &name); err != nil { + return 0, nil, err + } + return handle, name, nil +} + +// LoadExternal loads a public (and optionally a private) key into an object in +// the TPM. Returns loaded object handle and its name. +func LoadExternal(rw io.ReadWriter, pub Public, private Private, hierarchy tpmutil.Handle) (tpmutil.Handle, []byte, error) { + Cmd, err := encodeLoadExternal(pub, private, hierarchy) + if err != nil { + return 0, nil, err + } + resp, err := runCommand(rw, TagNoSessions, CmdLoadExternal, tpmutil.RawBytes(Cmd)) + if err != nil { + return 0, nil, err + } + handle, name, err := decodeLoadExternal(resp) + if err != nil { + return 0, nil, err + } + return handle, name, nil +} + +// PolicyPassword sets password authorization requirement on the object. 
+func PolicyPassword(rw io.ReadWriter, handle tpmutil.Handle) error { + _, err := runCommand(rw, TagNoSessions, CmdPolicyPassword, handle) + return err +} + +func encodePolicySecret(entityHandle tpmutil.Handle, entityAuth AuthCommand, policyHandle tpmutil.Handle, policyNonce, cpHash, policyRef tpmutil.U16Bytes, expiry int32) ([]byte, error) { + auth, err := encodeAuthArea(entityAuth) + if err != nil { + return nil, err + } + handles, err := tpmutil.Pack(entityHandle, policyHandle) + if err != nil { + return nil, err + } + params, err := tpmutil.Pack(policyNonce, cpHash, policyRef, expiry) + if err != nil { + return nil, err + } + return concat(handles, auth, params) +} + +func decodePolicySecret(in []byte) ([]byte, *Ticket, error) { + buf := bytes.NewBuffer(in) + + var paramSize uint32 + var timeout tpmutil.U16Bytes + if err := tpmutil.UnpackBuf(buf, ¶mSize, &timeout); err != nil { + return nil, nil, fmt.Errorf("decoding timeout: %v", err) + } + var t Ticket + if err := tpmutil.UnpackBuf(buf, &t); err != nil { + return nil, nil, fmt.Errorf("decoding ticket: %v", err) + } + return timeout, &t, nil +} + +// PolicySecret sets a secret authorization requirement on the provided entity. 
+func PolicySecret(rw io.ReadWriter, entityHandle tpmutil.Handle, entityAuth AuthCommand, policyHandle tpmutil.Handle, policyNonce, cpHash, policyRef []byte, expiry int32) ([]byte, *Ticket, error) { + Cmd, err := encodePolicySecret(entityHandle, entityAuth, policyHandle, policyNonce, cpHash, policyRef, expiry) + if err != nil { + return nil, nil, err + } + resp, err := runCommand(rw, TagSessions, CmdPolicySecret, tpmutil.RawBytes(Cmd)) + if err != nil { + return nil, nil, err + } + return decodePolicySecret(resp) +} + +func encodePolicySigned(validationKeyHandle tpmutil.Handle, policyHandle tpmutil.Handle, policyNonce, cpHash, policyRef tpmutil.U16Bytes, expiry int32, auth []byte) ([]byte, error) { + handles, err := tpmutil.Pack(validationKeyHandle, policyHandle) + if err != nil { + return nil, err + } + params, err := tpmutil.Pack(policyNonce, cpHash, policyRef, expiry, auth) + if err != nil { + return nil, err + } + return concat(handles, params) +} + +func decodePolicySigned(in []byte) ([]byte, *Ticket, error) { + buf := bytes.NewBuffer(in) + + var timeout tpmutil.U16Bytes + if err := tpmutil.UnpackBuf(buf, &timeout); err != nil { + return nil, nil, fmt.Errorf("decoding timeout: %v", err) + } + var t Ticket + if err := tpmutil.UnpackBuf(buf, &t); err != nil { + return nil, nil, fmt.Errorf("decoding ticket: %v", err) + } + return timeout, &t, nil +} + +// PolicySigned sets a signed authorization requirement on the provided policy. 
+func PolicySigned(rw io.ReadWriter, validationKeyHandle tpmutil.Handle, policyHandle tpmutil.Handle, policyNonce, cpHash, policyRef []byte, expiry int32, signedAuth []byte) ([]byte, *Ticket, error) { + Cmd, err := encodePolicySigned(validationKeyHandle, policyHandle, policyNonce, cpHash, policyRef, expiry, signedAuth) + if err != nil { + return nil, nil, err + } + resp, err := runCommand(rw, TagNoSessions, CmdPolicySigned, tpmutil.RawBytes(Cmd)) + if err != nil { + return nil, nil, err + } + return decodePolicySigned(resp) +} + +func encodePolicyPCR(session tpmutil.Handle, expectedDigest tpmutil.U16Bytes, sel PCRSelection) ([]byte, error) { + params, err := tpmutil.Pack(session, expectedDigest) + if err != nil { + return nil, err + } + pcrs, err := encodeTPMLPCRSelection(sel) + if err != nil { + return nil, err + } + return concat(params, pcrs) +} + +// PolicyPCR sets PCR state binding for authorization on a session. +// +// expectedDigest is optional. When specified, it's compared against the digest +// of PCRs matched by sel. +// +// Note that expectedDigest must be a *digest* of the expected PCR value. You +// must compute the digest manually. ReadPCR returns raw PCR values, not their +// digests. +// If you wish to select multiple PCRs, concatenate their values before +// computing the digest. See "TPM 2.0 Part 1, Selecting Multiple PCR". +func PolicyPCR(rw io.ReadWriter, session tpmutil.Handle, expectedDigest []byte, sel PCRSelection) error { + Cmd, err := encodePolicyPCR(session, expectedDigest, sel) + if err != nil { + return err + } + _, err = runCommand(rw, TagNoSessions, CmdPolicyPCR, tpmutil.RawBytes(Cmd)) + return err +} + +// PolicyOr compares PolicySession→Digest against the list of provided values. +// If the current Session→Digest does not match any value in the list, +// the TPM shall return TPM_RC_VALUE. Otherwise, the TPM will reset policySession→Digest +// to a Zero Digest. 
Then policySession→Digest is extended by the concatenation of +// TPM_CC_PolicyOR and the concatenation of all of the digests. +func PolicyOr(rw io.ReadWriter, session tpmutil.Handle, digests TPMLDigest) error { + d, err := digests.Encode() + if err != nil { + return err + } + data, err := tpmutil.Pack(session, d) + if err != nil { + return err + } + _, err = runCommand(rw, TagNoSessions, CmdPolicyOr, data) + return err +} + +// PolicyGetDigest returns the current policyDigest of the session. +func PolicyGetDigest(rw io.ReadWriter, handle tpmutil.Handle) ([]byte, error) { + resp, err := runCommand(rw, TagNoSessions, CmdPolicyGetDigest, handle) + if err != nil { + return nil, err + } + + var digest tpmutil.U16Bytes + _, err = tpmutil.Unpack(resp, &digest) + return digest, err +} + +func encodeStartAuthSession(tpmKey, bindKey tpmutil.Handle, nonceCaller, secret tpmutil.U16Bytes, se SessionType, sym, hashAlg Algorithm) ([]byte, error) { + ha, err := tpmutil.Pack(tpmKey, bindKey) + if err != nil { + return nil, err + } + params, err := tpmutil.Pack(nonceCaller, secret, se, sym, hashAlg) + if err != nil { + return nil, err + } + return concat(ha, params) +} + +func decodeStartAuthSession(in []byte) (tpmutil.Handle, []byte, error) { + var handle tpmutil.Handle + var nonce tpmutil.U16Bytes + if _, err := tpmutil.Unpack(in, &handle, &nonce); err != nil { + return 0, nil, err + } + return handle, nonce, nil +} + +// StartAuthSession initializes a session object. +// Returns session handle and the initial nonce from the TPM. 
+func StartAuthSession(rw io.ReadWriter, tpmKey, bindKey tpmutil.Handle, nonceCaller, secret []byte, se SessionType, sym, hashAlg Algorithm) (tpmutil.Handle, []byte, error) { + Cmd, err := encodeStartAuthSession(tpmKey, bindKey, nonceCaller, secret, se, sym, hashAlg) + if err != nil { + return 0, nil, err + } + resp, err := runCommand(rw, TagNoSessions, CmdStartAuthSession, tpmutil.RawBytes(Cmd)) + if err != nil { + return 0, nil, err + } + return decodeStartAuthSession(resp) +} + +func encodeUnseal(sessionHandle, itemHandle tpmutil.Handle, password string) ([]byte, error) { + ha, err := tpmutil.Pack(itemHandle) + if err != nil { + return nil, err + } + auth, err := encodeAuthArea(AuthCommand{Session: sessionHandle, Attributes: AttrContinueSession, Auth: []byte(password)}) + if err != nil { + return nil, err + } + return concat(ha, auth) +} + +func decodeUnseal(in []byte) ([]byte, error) { + var paramSize uint32 + var unsealed tpmutil.U16Bytes + + if _, err := tpmutil.Unpack(in, ¶mSize, &unsealed); err != nil { + return nil, err + } + return unsealed, nil +} + +// Unseal returns the data for a loaded sealed object. +func Unseal(rw io.ReadWriter, itemHandle tpmutil.Handle, password string) ([]byte, error) { + return UnsealWithSession(rw, HandlePasswordSession, itemHandle, password) +} + +// UnsealWithSession returns the data for a loaded sealed object. 
+func UnsealWithSession(rw io.ReadWriter, sessionHandle, itemHandle tpmutil.Handle, password string) ([]byte, error) { + Cmd, err := encodeUnseal(sessionHandle, itemHandle, password) + if err != nil { + return nil, err + } + resp, err := runCommand(rw, TagSessions, CmdUnseal, tpmutil.RawBytes(Cmd)) + if err != nil { + return nil, err + } + return decodeUnseal(resp) +} + +func encodeQuote(signingHandle tpmutil.Handle, signerAuth string, toQuote tpmutil.U16Bytes, sel PCRSelection, sigAlg Algorithm) ([]byte, error) { + ha, err := tpmutil.Pack(signingHandle) + if err != nil { + return nil, err + } + auth, err := encodeAuthArea(AuthCommand{Session: HandlePasswordSession, Attributes: AttrContinueSession, Auth: []byte(signerAuth)}) + if err != nil { + return nil, err + } + params, err := tpmutil.Pack(toQuote, sigAlg) + if err != nil { + return nil, err + } + pcrs, err := encodeTPMLPCRSelection(sel) + if err != nil { + return nil, err + } + return concat(ha, auth, params, pcrs) +} + +func decodeQuote(in []byte) ([]byte, []byte, error) { + buf := bytes.NewBuffer(in) + var paramSize uint32 + if err := tpmutil.UnpackBuf(buf, ¶mSize); err != nil { + return nil, nil, err + } + buf.Truncate(int(paramSize)) + var attest tpmutil.U16Bytes + if err := tpmutil.UnpackBuf(buf, &attest); err != nil { + return nil, nil, err + } + return attest, buf.Bytes(), nil +} + +// Quote returns a quote of PCR values. A quote is a signature of the PCR +// values, created using a signing TPM key. +// +// Returns attestation data and the decoded signature. +func Quote(rw io.ReadWriter, signingHandle tpmutil.Handle, signerAuth, unused string, toQuote []byte, sel PCRSelection, sigAlg Algorithm) ([]byte, *Signature, error) { + // TODO: Remove "unused" parameter on next breaking change. 
+ attest, sigRaw, err := QuoteRaw(rw, signingHandle, signerAuth, unused, toQuote, sel, sigAlg) + if err != nil { + return nil, nil, err + } + sig, err := DecodeSignature(bytes.NewBuffer(sigRaw)) + if err != nil { + return nil, nil, err + } + return attest, sig, nil +} + +// QuoteRaw is very similar to Quote, except that it will return +// the raw signature in a byte array without decoding. +func QuoteRaw(rw io.ReadWriter, signingHandle tpmutil.Handle, signerAuth, _ string, toQuote []byte, sel PCRSelection, sigAlg Algorithm) ([]byte, []byte, error) { + // TODO: Remove "unused" parameter on next breaking change. + Cmd, err := encodeQuote(signingHandle, signerAuth, toQuote, sel, sigAlg) + if err != nil { + return nil, nil, err + } + resp, err := runCommand(rw, TagSessions, CmdQuote, tpmutil.RawBytes(Cmd)) + if err != nil { + return nil, nil, err + } + return decodeQuote(resp) +} + +func encodeActivateCredential(auth []AuthCommand, activeHandle tpmutil.Handle, keyHandle tpmutil.Handle, credBlob, secret tpmutil.U16Bytes) ([]byte, error) { + ha, err := tpmutil.Pack(activeHandle, keyHandle) + if err != nil { + return nil, err + } + a, err := encodeAuthArea(auth...) + if err != nil { + return nil, err + } + params, err := tpmutil.Pack(credBlob, secret) + if err != nil { + return nil, err + } + return concat(ha, a, params) +} + +func decodeActivateCredential(in []byte) ([]byte, error) { + var paramSize uint32 + var certInfo tpmutil.U16Bytes + + if _, err := tpmutil.Unpack(in, ¶mSize, &certInfo); err != nil { + return nil, err + } + return certInfo, nil +} + +// ActivateCredential associates an object with a credential. +// Returns decrypted certificate information. 
+func ActivateCredential(rw io.ReadWriter, activeHandle, keyHandle tpmutil.Handle, activePassword, protectorPassword string, credBlob, secret []byte) ([]byte, error) { + return ActivateCredentialUsingAuth(rw, []AuthCommand{ + {Session: HandlePasswordSession, Attributes: AttrContinueSession, Auth: []byte(activePassword)}, + {Session: HandlePasswordSession, Attributes: AttrContinueSession, Auth: []byte(protectorPassword)}, + }, activeHandle, keyHandle, credBlob, secret) +} + +// ActivateCredentialUsingAuth associates an object with a credential, using the +// given set of authorizations. Two authorization must be provided. +// Returns decrypted certificate information. +func ActivateCredentialUsingAuth(rw io.ReadWriter, auth []AuthCommand, activeHandle, keyHandle tpmutil.Handle, credBlob, secret []byte) ([]byte, error) { + if len(auth) != 2 { + return nil, fmt.Errorf("len(auth) = %d, want 2", len(auth)) + } + + Cmd, err := encodeActivateCredential(auth, activeHandle, keyHandle, credBlob, secret) + if err != nil { + return nil, err + } + resp, err := runCommand(rw, TagSessions, CmdActivateCredential, tpmutil.RawBytes(Cmd)) + if err != nil { + return nil, err + } + return decodeActivateCredential(resp) +} + +func encodeMakeCredential(protectorHandle tpmutil.Handle, credential, activeName tpmutil.U16Bytes) ([]byte, error) { + ha, err := tpmutil.Pack(protectorHandle) + if err != nil { + return nil, err + } + params, err := tpmutil.Pack(credential, activeName) + if err != nil { + return nil, err + } + return concat(ha, params) +} + +func decodeMakeCredential(in []byte) ([]byte, []byte, error) { + var credBlob, encryptedSecret tpmutil.U16Bytes + + if _, err := tpmutil.Unpack(in, &credBlob, &encryptedSecret); err != nil { + return nil, nil, err + } + return credBlob, encryptedSecret, nil +} + +// MakeCredential creates an encrypted credential for use in MakeCredential. +// Returns encrypted credential and wrapped secret used to encrypt it. 
+func MakeCredential(rw io.ReadWriter, protectorHandle tpmutil.Handle, credential, activeName []byte) ([]byte, []byte, error) { + Cmd, err := encodeMakeCredential(protectorHandle, credential, activeName) + if err != nil { + return nil, nil, err + } + resp, err := runCommand(rw, TagNoSessions, CmdMakeCredential, tpmutil.RawBytes(Cmd)) + if err != nil { + return nil, nil, err + } + return decodeMakeCredential(resp) +} + +func encodeEvictControl(ownerAuth string, owner, objectHandle, persistentHandle tpmutil.Handle) ([]byte, error) { + ha, err := tpmutil.Pack(owner, objectHandle) + if err != nil { + return nil, err + } + auth, err := encodeAuthArea(AuthCommand{Session: HandlePasswordSession, Attributes: AttrContinueSession, Auth: []byte(ownerAuth)}) + if err != nil { + return nil, err + } + params, err := tpmutil.Pack(persistentHandle) + if err != nil { + return nil, err + } + return concat(ha, auth, params) +} + +// EvictControl toggles persistence of an object within the TPM. +func EvictControl(rw io.ReadWriter, ownerAuth string, owner, objectHandle, persistentHandle tpmutil.Handle) error { + Cmd, err := encodeEvictControl(ownerAuth, owner, objectHandle, persistentHandle) + if err != nil { + return err + } + _, err = runCommand(rw, TagSessions, CmdEvictControl, tpmutil.RawBytes(Cmd)) + return err +} + +func encodeClear(handle tpmutil.Handle, auth AuthCommand) ([]byte, error) { + ah, err := tpmutil.Pack(handle) + if err != nil { + return nil, err + } + encodedAuth, err := encodeAuthArea(auth) + if err != nil { + return nil, err + } + return concat(ah, encodedAuth) +} + +// Clear clears lockout, endorsement and owner hierarchy authorization values +func Clear(rw io.ReadWriter, handle tpmutil.Handle, auth AuthCommand) error { + Cmd, err := encodeClear(handle, auth) + if err != nil { + return err + } + _, err = runCommand(rw, TagSessions, CmdClear, tpmutil.RawBytes(Cmd)) + return err +} + +func encodeHierarchyChangeAuth(handle tpmutil.Handle, auth AuthCommand, newAuth 
string) ([]byte, error) { + ah, err := tpmutil.Pack(handle) + if err != nil { + return nil, err + } + encodedAuth, err := encodeAuthArea(auth) + if err != nil { + return nil, err + } + param, err := tpmutil.Pack(tpmutil.U16Bytes(newAuth)) + if err != nil { + return nil, err + } + return concat(ah, encodedAuth, param) +} + +// HierarchyChangeAuth changes the authorization values for a hierarchy or for the lockout authority +func HierarchyChangeAuth(rw io.ReadWriter, handle tpmutil.Handle, auth AuthCommand, newAuth string) error { + Cmd, err := encodeHierarchyChangeAuth(handle, auth, newAuth) + if err != nil { + return err + } + _, err = runCommand(rw, TagSessions, CmdHierarchyChangeAuth, tpmutil.RawBytes(Cmd)) + return err +} + +// ContextSave returns an encrypted version of the session, object or sequence +// context for storage outside of the TPM. The handle references context to +// store. +func ContextSave(rw io.ReadWriter, handle tpmutil.Handle) ([]byte, error) { + return runCommand(rw, TagNoSessions, CmdContextSave, handle) +} + +// ContextLoad reloads context data created by ContextSave. +func ContextLoad(rw io.ReadWriter, saveArea []byte) (tpmutil.Handle, error) { + resp, err := runCommand(rw, TagNoSessions, CmdContextLoad, tpmutil.RawBytes(saveArea)) + if err != nil { + return 0, err + } + var handle tpmutil.Handle + _, err = tpmutil.Unpack(resp, &handle) + return handle, err +} + +func encodeIncrementNV(handle tpmutil.Handle, authString string) ([]byte, error) { + auth, err := encodeAuthArea(AuthCommand{Session: HandlePasswordSession, Attributes: AttrContinueSession, Auth: []byte(authString)}) + if err != nil { + return nil, err + } + out, err := tpmutil.Pack(handle, handle) + if err != nil { + return nil, err + } + return concat(out, auth) +} + +// NVIncrement increments a counter in NVRAM. 
+func NVIncrement(rw io.ReadWriter, handle tpmutil.Handle, authString string) error { + Cmd, err := encodeIncrementNV(handle, authString) + if err != nil { + return err + } + _, err = runCommand(rw, TagSessions, CmdIncrementNVCounter, tpmutil.RawBytes(Cmd)) + return err +} + +// NVUndefineSpace removes an index from TPM's NV storage. +func NVUndefineSpace(rw io.ReadWriter, ownerAuth string, owner, index tpmutil.Handle) error { + authArea := AuthCommand{Session: HandlePasswordSession, Attributes: AttrContinueSession, Auth: []byte(ownerAuth)} + return NVUndefineSpaceEx(rw, owner, index, authArea) +} + +// NVUndefineSpaceEx removes an index from NVRAM. Unlike, NVUndefineSpace(), custom command +// authorization can be provided. +func NVUndefineSpaceEx(rw io.ReadWriter, owner, index tpmutil.Handle, authArea AuthCommand) error { + out, err := tpmutil.Pack(owner, index) + if err != nil { + return err + } + auth, err := encodeAuthArea(authArea) + if err != nil { + return err + } + cmd, err := concat(out, auth) + if err != nil { + return err + } + _, err = runCommand(rw, TagSessions, CmdUndefineSpace, tpmutil.RawBytes(cmd)) + return err +} + +// NVUndefineSpaceSpecial This command allows removal of a platform-created NV Index that has TPMA_NV_POLICY_DELETE SET. +// The policy to authorize NV index access needs to be created with PolicyCommandCode(rw, sessionHandle, CmdNVUndefineSpaceSpecial) function +// nvAuthCmd takes the session handle for the policy and the AuthValue (which can be emptyAuth) for the authorization. +// platformAuth takes either a sessionHandle for the platform policy or HandlePasswordSession and the platformAuth value for authorization. 
+func NVUndefineSpaceSpecial(rw io.ReadWriter, nvIndex tpmutil.Handle, nvAuth, platformAuth AuthCommand) error { + authBytes, err := encodeAuthArea(nvAuth, platformAuth) + if err != nil { + return err + } + auth, err := tpmutil.Pack(authBytes) + if err != nil { + return err + } + _, err = runCommand(rw, TagSessions, CmdNVUndefineSpaceSpecial, nvIndex, HandlePlatform, tpmutil.RawBytes(auth)) + return err +} + +// NVDefineSpace creates an index in TPM's NV storage. +func NVDefineSpace(rw io.ReadWriter, owner, handle tpmutil.Handle, ownerAuth, authString string, policy []byte, attributes NVAttr, dataSize uint16) error { + nvPub := NVPublic{ + NVIndex: handle, + NameAlg: AlgSHA1, + Attributes: attributes, + AuthPolicy: policy, + DataSize: dataSize, + } + authArea := AuthCommand{ + Session: HandlePasswordSession, + Attributes: AttrContinueSession, + Auth: []byte(ownerAuth), + } + return NVDefineSpaceEx(rw, owner, authString, nvPub, authArea) +} + +// NVDefineSpaceEx accepts NVPublic structure and AuthCommand, allowing more flexibility. +func NVDefineSpaceEx(rw io.ReadWriter, owner tpmutil.Handle, authVal string, pubInfo NVPublic, authArea AuthCommand) error { + ha, err := tpmutil.Pack(owner) + if err != nil { + return err + } + auth, err := encodeAuthArea(authArea) + if err != nil { + return err + } + publicInfo, err := tpmutil.Pack(pubInfo) + if err != nil { + return err + } + params, err := tpmutil.Pack(tpmutil.U16Bytes(authVal), tpmutil.U16Bytes(publicInfo)) + if err != nil { + return err + } + cmd, err := concat(ha, auth, params) + if err != nil { + return err + } + _, err = runCommand(rw, TagSessions, CmdDefineSpace, tpmutil.RawBytes(cmd)) + return err +} + +// NVWrite writes data into the TPM's NV storage. 
+func NVWrite(rw io.ReadWriter, authHandle, nvIndex tpmutil.Handle, authString string, data tpmutil.U16Bytes, offset uint16) error { + auth := AuthCommand{Session: HandlePasswordSession, Attributes: AttrContinueSession, Auth: []byte(authString)} + return NVWriteEx(rw, authHandle, nvIndex, auth, data, offset) +} + +// NVWriteEx does the same as NVWrite with the exception of letting the user take care of the AuthCommand before calling the function. +// This allows more flexibility and does not limit the AuthCommand to PasswordSession. +func NVWriteEx(rw io.ReadWriter, authHandle, nvIndex tpmutil.Handle, authArea AuthCommand, data tpmutil.U16Bytes, offset uint16) error { + h, err := tpmutil.Pack(authHandle, nvIndex) + if err != nil { + return err + } + authEnc, err := encodeAuthArea(authArea) + if err != nil { + return err + } + + d, err := tpmutil.Pack(data, offset) + if err != nil { + return err + } + + b, err := concat(h, authEnc, d) + if err != nil { + return err + } + _, err = runCommand(rw, TagSessions, CmdWriteNV, tpmutil.RawBytes(b)) + return err +} + +func encodeLockNV(owner, handle tpmutil.Handle, authString string) ([]byte, error) { + auth, err := encodeAuthArea(AuthCommand{Session: HandlePasswordSession, Attributes: AttrContinueSession, Auth: []byte(authString)}) + if err != nil { + return nil, err + } + out, err := tpmutil.Pack(owner, handle) + if err != nil { + return nil, err + } + return concat(out, auth) +} + +// NVWriteLock inhibits further writes on the given NV index if at least one of +// the AttrWriteSTClear or AttrWriteDefine bits is set. +// +// AttrWriteSTClear causes the index to be locked until the TPM is restarted +// (see the Startup function). +// +// AttrWriteDefine causes the index to be locked permanently if data has been +// written to the index; otherwise the lock is removed on startup. +// +// NVWriteLock returns an error if neither bit is set. 
+// +// It is not an error to call NVWriteLock for an index that is already locked +// for writing. +func NVWriteLock(rw io.ReadWriter, owner, handle tpmutil.Handle, authString string) error { + Cmd, err := encodeLockNV(owner, handle, authString) + if err != nil { + return err + } + _, err = runCommand(rw, TagSessions, CmdWriteLockNV, tpmutil.RawBytes(Cmd)) + return err +} + +func decodeNVReadPublic(in []byte) (NVPublic, error) { + var pub NVPublic + var buf tpmutil.U16Bytes + if _, err := tpmutil.Unpack(in, &buf); err != nil { + return pub, err + } + _, err := tpmutil.Unpack(buf, &pub) + return pub, err +} + +// NVReadPublic reads the public data of an NV index. +func NVReadPublic(rw io.ReadWriter, index tpmutil.Handle) (NVPublic, error) { + // Read public area to determine data size. + resp, err := runCommand(rw, TagNoSessions, CmdReadPublicNV, index) + if err != nil { + return NVPublic{}, err + } + return decodeNVReadPublic(resp) +} + +func decodeNVRead(in []byte) ([]byte, error) { + var paramSize uint32 + var data tpmutil.U16Bytes + if _, err := tpmutil.Unpack(in, ¶mSize, &data); err != nil { + return nil, err + } + return data, nil +} + +func encodeNVRead(nvIndex, authHandle tpmutil.Handle, password string, offset, dataSize uint16) ([]byte, error) { + handles, err := tpmutil.Pack(authHandle, nvIndex) + if err != nil { + return nil, err + } + auth, err := encodeAuthArea(AuthCommand{Session: HandlePasswordSession, Attributes: AttrContinueSession, Auth: []byte(password)}) + if err != nil { + return nil, err + } + + params, err := tpmutil.Pack(dataSize, offset) + if err != nil { + return nil, err + } + + return concat(handles, auth, params) +} + +// NVRead reads a full data blob from an NV index. This function is +// deprecated; use NVReadEx instead. +func NVRead(rw io.ReadWriter, index tpmutil.Handle) ([]byte, error) { + return NVReadEx(rw, index, index, "", 0) +} + +// NVReadEx reads a full data blob from an NV index, using the given +// authorization handle. 
NVRead commands are done in blocks of blockSize. +// If blockSize is 0, the TPM is queried for TPM_PT_NV_BUFFER_MAX, and that +// value is used. +func NVReadEx(rw io.ReadWriter, index, authHandle tpmutil.Handle, password string, blockSize int) ([]byte, error) { + if blockSize == 0 { + readBuff, _, err := GetCapability(rw, CapabilityTPMProperties, 1, uint32(NVMaxBufferSize)) + if err != nil { + return nil, fmt.Errorf("GetCapability for TPM_PT_NV_BUFFER_MAX failed: %v", err) + } + if len(readBuff) != 1 { + return nil, fmt.Errorf("could not determine NVRAM read/write buffer size") + } + rb, ok := readBuff[0].(TaggedProperty) + if !ok { + return nil, fmt.Errorf("GetCapability returned unexpected type: %T, expected TaggedProperty", readBuff[0]) + } + blockSize = int(rb.Value) + } + + // Read public area to determine data size. + pub, err := NVReadPublic(rw, index) + if err != nil { + return nil, fmt.Errorf("decoding NV_ReadPublic response: %v", err) + } + + // Read the NVRAM area in blocks. + outBuff := make([]byte, 0, int(pub.DataSize)) + for len(outBuff) < int(pub.DataSize) { + readSize := blockSize + if readSize > (int(pub.DataSize) - len(outBuff)) { + readSize = int(pub.DataSize) - len(outBuff) + } + + Cmd, err := encodeNVRead(index, authHandle, password, uint16(len(outBuff)), uint16(readSize)) + if err != nil { + return nil, fmt.Errorf("building NV_Read command: %v", err) + } + resp, err := runCommand(rw, TagSessions, CmdReadNV, tpmutil.RawBytes(Cmd)) + if err != nil { + return nil, fmt.Errorf("running NV_Read command (cursor=%d,size=%d): %v", len(outBuff), readSize, err) + } + data, err := decodeNVRead(resp) + if err != nil { + return nil, fmt.Errorf("decoding NV_Read command: %v", err) + } + outBuff = append(outBuff, data...) + } + return outBuff, nil +} + +// NVReadLock inhibits further reads of the given NV index if AttrReadSTClear +// is set. After the TPM is restarted the index can be read again (see the +// Startup function). 
+// +// NVReadLock returns an error if the AttrReadSTClear bit is not set. +// +// It is not an error to call NVReadLock for an index that is already locked +// for reading. +func NVReadLock(rw io.ReadWriter, owner, handle tpmutil.Handle, authString string) error { + Cmd, err := encodeLockNV(owner, handle, authString) + if err != nil { + return err + } + _, err = runCommand(rw, TagSessions, CmdReadLockNV, tpmutil.RawBytes(Cmd)) + return err +} + +// decodeHash unpacks a successful response to TPM2_Hash, returning the computed digest and +// validation ticket. +func decodeHash(resp []byte) ([]byte, *Ticket, error) { + var digest tpmutil.U16Bytes + var validation Ticket + + buf := bytes.NewBuffer(resp) + if err := tpmutil.UnpackBuf(buf, &digest, &validation); err != nil { + return nil, nil, err + } + return digest, &validation, nil +} + +// Hash computes a hash of data in buf using TPM2_Hash, returning the computed +// digest and validation ticket. The validation ticket serves as confirmation +// from the TPM that the data in buf did not begin with TPM_GENERATED_VALUE. +// NOTE: TPM2_Hash can only accept data up to MAX_DIGEST_BUFFER in size, which +// is implementation-dependent, but guaranteed to be at least 1024 octets. +func Hash(rw io.ReadWriter, alg Algorithm, buf tpmutil.U16Bytes, hierarchy tpmutil.Handle) (digest []byte, validation *Ticket, err error) { + resp, err := runCommand(rw, TagNoSessions, CmdHash, buf, alg, hierarchy) + if err != nil { + return nil, nil, err + } + return decodeHash(resp) +} + +// HashSequenceStart starts a hash or an event sequence. If hashAlg is an +// implemented hash, then a hash sequence is started. If hashAlg is +// TPM_ALG_NULL, then an event sequence is started. 
+func HashSequenceStart(rw io.ReadWriter, sequenceAuth string, hashAlg Algorithm) (seqHandle tpmutil.Handle, err error) { + resp, err := runCommand(rw, TagNoSessions, CmdHashSequenceStart, tpmutil.U16Bytes(sequenceAuth), hashAlg) + if err != nil { + return 0, err + } + var handle tpmutil.Handle + _, err = tpmutil.Unpack(resp, &handle) + return handle, err +} + +func encodeSequenceUpdate(sequenceAuth string, seqHandle tpmutil.Handle, buf tpmutil.U16Bytes) ([]byte, error) { + ha, err := tpmutil.Pack(seqHandle) + if err != nil { + return nil, err + } + auth, err := encodeAuthArea(AuthCommand{Session: HandlePasswordSession, Attributes: AttrContinueSession, Auth: []byte(sequenceAuth)}) + if err != nil { + return nil, err + } + params, err := tpmutil.Pack(buf) + if err != nil { + return nil, err + } + return concat(ha, auth, params) +} + +// SequenceUpdate is used to add data to a hash or HMAC sequence. +func SequenceUpdate(rw io.ReadWriter, sequenceAuth string, seqHandle tpmutil.Handle, buffer []byte) error { + cmd, err := encodeSequenceUpdate(sequenceAuth, seqHandle, buffer) + if err != nil { + return err + } + _, err = runCommand(rw, TagSessions, CmdSequenceUpdate, tpmutil.RawBytes(cmd)) + return err +} + +func decodeSequenceComplete(resp []byte) ([]byte, *Ticket, error) { + var digest tpmutil.U16Bytes + var validation Ticket + var paramSize uint32 + + if _, err := tpmutil.Unpack(resp, ¶mSize, &digest, &validation); err != nil { + return nil, nil, err + } + return digest, &validation, nil +} + +func encodeSequenceComplete(sequenceAuth string, seqHandle, hierarchy tpmutil.Handle, buf tpmutil.U16Bytes) ([]byte, error) { + ha, err := tpmutil.Pack(seqHandle) + if err != nil { + return nil, err + } + auth, err := encodeAuthArea(AuthCommand{Session: HandlePasswordSession, Attributes: AttrContinueSession, Auth: []byte(sequenceAuth)}) + if err != nil { + return nil, err + } + params, err := tpmutil.Pack(buf, hierarchy) + if err != nil { + return nil, err + } + return 
concat(ha, auth, params) +} + +// SequenceComplete adds the last part of data, if any, to a hash/HMAC sequence +// and returns the result. +func SequenceComplete(rw io.ReadWriter, sequenceAuth string, seqHandle, hierarchy tpmutil.Handle, buffer []byte) (digest []byte, validation *Ticket, err error) { + cmd, err := encodeSequenceComplete(sequenceAuth, seqHandle, hierarchy, buffer) + if err != nil { + return nil, nil, err + } + resp, err := runCommand(rw, TagSessions, CmdSequenceComplete, tpmutil.RawBytes(cmd)) + if err != nil { + return nil, nil, err + } + return decodeSequenceComplete(resp) +} + +func encodeEventSequenceComplete(auths []AuthCommand, pcrHandle, seqHandle tpmutil.Handle, buf tpmutil.U16Bytes) ([]byte, error) { + ha, err := tpmutil.Pack(pcrHandle, seqHandle) + if err != nil { + return nil, err + } + auth, err := encodeAuthArea(auths...) + if err != nil { + return nil, err + } + params, err := tpmutil.Pack(buf) + if err != nil { + return nil, err + } + return concat(ha, auth, params) +} + +func decodeEventSequenceComplete(resp []byte) ([]*HashValue, error) { + var paramSize uint32 + var hashCount uint32 + var err error + + buf := bytes.NewBuffer(resp) + if err := tpmutil.UnpackBuf(buf, ¶mSize, &hashCount); err != nil { + return nil, err + } + + buf.Truncate(int(paramSize)) + digests := make([]*HashValue, hashCount) + for i := uint32(0); i < hashCount; i++ { + if digests[i], err = decodeHashValue(buf); err != nil { + return nil, err + } + } + + return digests, nil +} + +// EventSequenceComplete adds the last part of data, if any, to an Event +// Sequence and returns the result in a digest list. If pcrHandle references a +// PCR and not AlgNull, then the returned digest list is processed in the same +// manner as the digest list input parameter to PCRExtend() with the pcrHandle +// in each bank extended with the associated digest value. 
+func EventSequenceComplete(rw io.ReadWriter, pcrAuth, sequenceAuth string, pcrHandle, seqHandle tpmutil.Handle, buffer []byte) (digests []*HashValue, err error) { + auth := []AuthCommand{ + {Session: HandlePasswordSession, Attributes: AttrContinueSession, Auth: []byte(pcrAuth)}, + {Session: HandlePasswordSession, Attributes: AttrContinueSession, Auth: []byte(sequenceAuth)}, + } + cmd, err := encodeEventSequenceComplete(auth, pcrHandle, seqHandle, buffer) + if err != nil { + return nil, err + } + resp, err := runCommand(rw, TagSessions, CmdEventSequenceComplete, tpmutil.RawBytes(cmd)) + if err != nil { + return nil, err + } + return decodeEventSequenceComplete(resp) +} + +// Startup initializes a TPM (usually done by the OS). +func Startup(rw io.ReadWriter, typ StartupType) error { + _, err := runCommand(rw, TagNoSessions, CmdStartup, typ) + return err +} + +// Shutdown shuts down a TPM (usually done by the OS). +func Shutdown(rw io.ReadWriter, typ StartupType) error { + _, err := runCommand(rw, TagNoSessions, CmdShutdown, typ) + return err +} + +// nullTicket is a hard-coded null ticket of type TPMT_TK_HASHCHECK. +// It is for Sign commands that do not require the TPM to verify that the digest +// is not from data that started with TPM_GENERATED_VALUE. 
+var nullTicket = Ticket{ + Type: TagHashCheck, + Hierarchy: HandleNull, + Digest: tpmutil.U16Bytes{}, +} + +func encodeSign(sessionHandle, key tpmutil.Handle, password string, digest tpmutil.U16Bytes, sigScheme *SigScheme, validation *Ticket) ([]byte, error) { + ha, err := tpmutil.Pack(key) + if err != nil { + return nil, err + } + auth, err := encodeAuthArea(AuthCommand{Session: sessionHandle, Attributes: AttrContinueSession, Auth: []byte(password)}) + if err != nil { + return nil, err + } + d, err := tpmutil.Pack(digest) + if err != nil { + return nil, err + } + s, err := sigScheme.encode() + if err != nil { + return nil, err + } + if validation == nil { + validation = &nullTicket + } + v, err := tpmutil.Pack(validation) + if err != nil { + return nil, err + } + + return concat(ha, auth, d, s, v) +} + +func decodeSign(buf []byte) (*Signature, error) { + in := bytes.NewBuffer(buf) + var paramSize uint32 + if err := tpmutil.UnpackBuf(in, ¶mSize); err != nil { + return nil, err + } + return DecodeSignature(in) +} + +// SignWithSession computes a signature for digest using a given loaded key. Signature +// algorithm depends on the key type. Used for keys with non-password authorization policies. +// If 'key' references a Restricted Decryption key, 'validation' must be a valid hash verification +// ticket from the TPM, which can be obtained by using Hash() to hash the data with the TPM. +// If 'validation' is nil, a NULL ticket is passed to TPM2_Sign. +func SignWithSession(rw io.ReadWriter, sessionHandle, key tpmutil.Handle, password string, digest []byte, validation *Ticket, sigScheme *SigScheme) (*Signature, error) { + Cmd, err := encodeSign(sessionHandle, key, password, digest, sigScheme, validation) + if err != nil { + return nil, err + } + resp, err := runCommand(rw, TagSessions, CmdSign, tpmutil.RawBytes(Cmd)) + if err != nil { + return nil, err + } + return decodeSign(resp) +} + +// Sign computes a signature for digest using a given loaded key. 
Signature +// algorithm depends on the key type. +// If 'key' references a Restricted Decryption key, 'validation' must be a valid hash verification +// ticket from the TPM, which can be obtained by using Hash() to hash the data with the TPM. +// If 'validation' is nil, a NULL ticket is passed to TPM2_Sign. +func Sign(rw io.ReadWriter, key tpmutil.Handle, password string, digest []byte, validation *Ticket, sigScheme *SigScheme) (*Signature, error) { + return SignWithSession(rw, HandlePasswordSession, key, password, digest, validation, sigScheme) +} + +func encodeCertify(objectAuth, signerAuth string, object, signer tpmutil.Handle, qualifyingData tpmutil.U16Bytes) ([]byte, error) { + ha, err := tpmutil.Pack(object, signer) + if err != nil { + return nil, err + } + + auth, err := encodeAuthArea(AuthCommand{Session: HandlePasswordSession, Attributes: AttrContinueSession, Auth: []byte(objectAuth)}, AuthCommand{Session: HandlePasswordSession, Attributes: AttrContinueSession, Auth: []byte(signerAuth)}) + if err != nil { + return nil, err + } + + scheme := SigScheme{Alg: AlgRSASSA, Hash: AlgSHA256} + // Use signing key's scheme. + s, err := scheme.encode() + if err != nil { + return nil, err + } + data, err := tpmutil.Pack(qualifyingData) + if err != nil { + return nil, err + } + return concat(ha, auth, data, s) +} + +// This function differs from encodeCertify in that it takes the scheme to be used as an additional argument. 
+func encodeCertifyEx(objectAuth, signerAuth string, object, signer tpmutil.Handle, qualifyingData tpmutil.U16Bytes, scheme SigScheme) ([]byte, error) { + ha, err := tpmutil.Pack(object, signer) + if err != nil { + return nil, err + } + + auth, err := encodeAuthArea(AuthCommand{Session: HandlePasswordSession, Attributes: AttrContinueSession, Auth: []byte(objectAuth)}, AuthCommand{Session: HandlePasswordSession, Attributes: AttrContinueSession, Auth: []byte(signerAuth)}) + if err != nil { + return nil, err + } + + s, err := scheme.encode() + if err != nil { + return nil, err + } + data, err := tpmutil.Pack(qualifyingData) + if err != nil { + return nil, err + } + return concat(ha, auth, data, s) +} + +func decodeCertify(resp []byte) ([]byte, []byte, error) { + var paramSize uint32 + var attest tpmutil.U16Bytes + + buf := bytes.NewBuffer(resp) + if err := tpmutil.UnpackBuf(buf, ¶mSize); err != nil { + return nil, nil, err + } + buf.Truncate(int(paramSize)) + if err := tpmutil.UnpackBuf(buf, &attest); err != nil { + return nil, nil, err + } + return attest, buf.Bytes(), nil +} + +// Certify generates a signature of a loaded TPM object with a signing key +// signer. This function calls encodeCertify which makes use of the hardcoded +// signing scheme {AlgRSASSA, AlgSHA256}. Returned values are: attestation data (TPMS_ATTEST), +// signature and error, if any. +func Certify(rw io.ReadWriter, objectAuth, signerAuth string, object, signer tpmutil.Handle, qualifyingData []byte) ([]byte, []byte, error) { + cmd, err := encodeCertify(objectAuth, signerAuth, object, signer, qualifyingData) + if err != nil { + return nil, nil, err + } + resp, err := runCommand(rw, TagSessions, CmdCertify, tpmutil.RawBytes(cmd)) + if err != nil { + return nil, nil, err + } + return decodeCertify(resp) +} + +// CertifyEx generates a signature of a loaded TPM object with a signing key +// signer. 
This function differs from Certify in that it takes the scheme +// to be used as an additional argument and calls encodeCertifyEx instead +// of encodeCertify. Returned values are: attestation data (TPMS_ATTEST), +// signature and error, if any. +func CertifyEx(rw io.ReadWriter, objectAuth, signerAuth string, object, signer tpmutil.Handle, qualifyingData []byte, scheme SigScheme) ([]byte, []byte, error) { + cmd, err := encodeCertifyEx(objectAuth, signerAuth, object, signer, qualifyingData, scheme) + if err != nil { + return nil, nil, err + } + resp, err := runCommand(rw, TagSessions, CmdCertify, tpmutil.RawBytes(cmd)) + if err != nil { + return nil, nil, err + } + return decodeCertify(resp) +} + +func encodeCertifyCreation(objectAuth string, object, signer tpmutil.Handle, qualifyingData, creationHash tpmutil.U16Bytes, scheme SigScheme, ticket Ticket) ([]byte, error) { + handles, err := tpmutil.Pack(signer, object) + if err != nil { + return nil, err + } + auth, err := encodeAuthArea(AuthCommand{Session: HandlePasswordSession, Attributes: AttrContinueSession, Auth: []byte(objectAuth)}) + if err != nil { + return nil, err + } + s, err := scheme.encode() + if err != nil { + return nil, err + } + params, err := tpmutil.Pack(qualifyingData, creationHash, tpmutil.RawBytes(s), ticket) + if err != nil { + return nil, err + } + return concat(handles, auth, params) +} + +// CertifyCreation generates a signature of a newly-created & +// loaded TPM object, using signer as the signing key. 
+func CertifyCreation(rw io.ReadWriter, objectAuth string, object, signer tpmutil.Handle, qualifyingData, creationHash []byte, sigScheme SigScheme, creationTicket Ticket) (attestation, signature []byte, err error) { + Cmd, err := encodeCertifyCreation(objectAuth, object, signer, qualifyingData, creationHash, sigScheme, creationTicket) + if err != nil { + return nil, nil, err + } + resp, err := runCommand(rw, TagSessions, CmdCertifyCreation, tpmutil.RawBytes(Cmd)) + if err != nil { + return nil, nil, err + } + return decodeCertify(resp) +} + +func runCommand(rw io.ReadWriter, tag tpmutil.Tag, Cmd tpmutil.Command, in ...interface{}) ([]byte, error) { + resp, code, err := tpmutil.RunCommand(rw, tag, Cmd, in...) + if err != nil { + return nil, err + } + if code != tpmutil.RCSuccess { + return nil, decodeResponse(code) + } + return resp, decodeResponse(code) +} + +// concat is a helper for encoding functions that separately encode handle, +// auth and param areas. A nil error is always returned, so that callers can +// simply return concat(a, b, c). 
+func concat(chunks ...[]byte) ([]byte, error) { + return bytes.Join(chunks, nil), nil +} + +func encodePCRExtend(pcr tpmutil.Handle, hashAlg Algorithm, hash tpmutil.RawBytes, password string) ([]byte, error) { + ha, err := tpmutil.Pack(pcr) + if err != nil { + return nil, err + } + auth, err := encodeAuthArea(AuthCommand{Session: HandlePasswordSession, Attributes: AttrContinueSession, Auth: []byte(password)}) + if err != nil { + return nil, err + } + pcrCount := uint32(1) + extend, err := tpmutil.Pack(pcrCount, hashAlg, hash) + if err != nil { + return nil, err + } + return concat(ha, auth, extend) +} + +// PCRExtend extends a value into the selected PCR +func PCRExtend(rw io.ReadWriter, pcr tpmutil.Handle, hashAlg Algorithm, hash []byte, password string) error { + Cmd, err := encodePCRExtend(pcr, hashAlg, hash, password) + if err != nil { + return err + } + _, err = runCommand(rw, TagSessions, CmdPCRExtend, tpmutil.RawBytes(Cmd)) + return err +} + +// ReadPCR reads the value of the given PCR. +func ReadPCR(rw io.ReadWriter, pcr int, hashAlg Algorithm) ([]byte, error) { + pcrSelection := PCRSelection{ + Hash: hashAlg, + PCRs: []int{pcr}, + } + pcrVals, err := ReadPCRs(rw, pcrSelection) + if err != nil { + return nil, fmt.Errorf("unable to read PCRs from TPM: %v", err) + } + pcrVal, present := pcrVals[pcr] + if !present { + return nil, fmt.Errorf("PCR %d value missing from response", pcr) + } + return pcrVal, nil +} + +func encodePCRReset(pcr tpmutil.Handle) ([]byte, error) { + ha, err := tpmutil.Pack(pcr) + if err != nil { + return nil, err + } + auth, err := encodeAuthArea(AuthCommand{Session: HandlePasswordSession, Attributes: AttrContinueSession, Auth: EmptyAuth}) + if err != nil { + return nil, err + } + return concat(ha, auth) +} + +// PCRReset resets the value of the given PCR. Usually, only PCR 16 (Debug) and +// PCR 23 (Application) are resettable on the default locality. 
+func PCRReset(rw io.ReadWriter, pcr tpmutil.Handle) error { + Cmd, err := encodePCRReset(pcr) + if err != nil { + return err + } + _, err = runCommand(rw, TagSessions, CmdPCRReset, tpmutil.RawBytes(Cmd)) + return err +} + +// EncryptSymmetric encrypts data using a symmetric key. +// +// WARNING: This command performs low-level cryptographic operations. +// Secure use of this command is subtle and requires careful analysis. +// Please consult with experts in cryptography for how to use it securely. +// +// The iv is the initialization vector. The iv must not be empty and its size depends on the +// details of the symmetric encryption scheme. +// +// The data may be longer than block size, EncryptSymmetric will chain +// multiple TPM calls to encrypt the entire blob. +// +// Key handle should point at SymCipher object which is a child of the key (and +// not e.g. RSA key itself). +func EncryptSymmetric(rw io.ReadWriteCloser, keyAuth string, key tpmutil.Handle, iv, data []byte) ([]byte, error) { + return encryptDecryptSymmetric(rw, keyAuth, key, iv, data, false) +} + +// DecryptSymmetric decrypts data using a symmetric key. +// +// WARNING: This command performs low-level cryptographic operations. +// Secure use of this command is subtle and requires careful analysis. +// Please consult with experts in cryptography for how to use it securely. +// +// The iv is the initialization vector. The iv must not be empty and its size +// depends on the details of the symmetric encryption scheme. +// +// The data may be longer than block size, DecryptSymmetric will chain multiple +// TPM calls to decrypt the entire blob. +// +// Key handle should point at SymCipher object which is a child of the key (and +// not e.g. RSA key itself). 
+func DecryptSymmetric(rw io.ReadWriteCloser, keyAuth string, key tpmutil.Handle, iv, data []byte) ([]byte, error) { + return encryptDecryptSymmetric(rw, keyAuth, key, iv, data, true) +} + +func encodeEncryptDecrypt(keyAuth string, key tpmutil.Handle, iv, data tpmutil.U16Bytes, decrypt bool) ([]byte, error) { + ha, err := tpmutil.Pack(key) + if err != nil { + return nil, err + } + auth, err := encodeAuthArea(AuthCommand{Session: HandlePasswordSession, Attributes: AttrContinueSession, Auth: []byte(keyAuth)}) + if err != nil { + return nil, err + } + // Use encryption key's mode. + params, err := tpmutil.Pack(decrypt, AlgNull, iv, data) + if err != nil { + return nil, err + } + return concat(ha, auth, params) +} + +func encodeEncryptDecrypt2(keyAuth string, key tpmutil.Handle, iv, data tpmutil.U16Bytes, decrypt bool) ([]byte, error) { + ha, err := tpmutil.Pack(key) + if err != nil { + return nil, err + } + auth, err := encodeAuthArea(AuthCommand{Session: HandlePasswordSession, Attributes: AttrContinueSession, Auth: []byte(keyAuth)}) + if err != nil { + return nil, err + } + // Use encryption key's mode. 
+ params, err := tpmutil.Pack(data, decrypt, AlgNull, iv) + if err != nil { + return nil, err + } + return concat(ha, auth, params) +} + +func decodeEncryptDecrypt(resp []byte) ([]byte, []byte, error) { + var paramSize uint32 + var out, nextIV tpmutil.U16Bytes + if _, err := tpmutil.Unpack(resp, ¶mSize, &out, &nextIV); err != nil { + return nil, nil, err + } + return out, nextIV, nil +} + +func encryptDecryptBlockSymmetric(rw io.ReadWriteCloser, keyAuth string, key tpmutil.Handle, iv, data []byte, decrypt bool) ([]byte, []byte, error) { + Cmd, err := encodeEncryptDecrypt2(keyAuth, key, iv, data, decrypt) + if err != nil { + return nil, nil, err + } + resp, err := runCommand(rw, TagSessions, CmdEncryptDecrypt2, tpmutil.RawBytes(Cmd)) + if err != nil { + fmt0Err, ok := err.(Error) + if ok && fmt0Err.Code == RCCommandCode { + // If TPM2_EncryptDecrypt2 is not supported, fall back to + // TPM2_EncryptDecrypt. + Cmd, _ := encodeEncryptDecrypt(keyAuth, key, iv, data, decrypt) + resp, err = runCommand(rw, TagSessions, CmdEncryptDecrypt, tpmutil.RawBytes(Cmd)) + if err != nil { + return nil, nil, err + } + } + } + if err != nil { + return nil, nil, err + } + return decodeEncryptDecrypt(resp) +} + +func encryptDecryptSymmetric(rw io.ReadWriteCloser, keyAuth string, key tpmutil.Handle, iv, data []byte, decrypt bool) ([]byte, error) { + var out, block []byte + var err error + + for rest := data; len(rest) > 0; { + if len(rest) > maxDigestBuffer { + block, rest = rest[:maxDigestBuffer], rest[maxDigestBuffer:] + } else { + block, rest = rest, nil + } + block, iv, err = encryptDecryptBlockSymmetric(rw, keyAuth, key, iv, block, decrypt) + if err != nil { + return nil, err + } + out = append(out, block...) 
+ } + + return out, nil +} + +func encodeRSAEncrypt(key tpmutil.Handle, message tpmutil.U16Bytes, scheme *AsymScheme, label string) ([]byte, error) { + ha, err := tpmutil.Pack(key) + if err != nil { + return nil, err + } + m, err := tpmutil.Pack(message) + if err != nil { + return nil, err + } + s, err := scheme.encode() + if err != nil { + return nil, err + } + if label != "" { + label += "\x00" + } + l, err := tpmutil.Pack(tpmutil.U16Bytes(label)) + if err != nil { + return nil, err + } + return concat(ha, m, s, l) +} + +func decodeRSAEncrypt(resp []byte) ([]byte, error) { + var out tpmutil.U16Bytes + _, err := tpmutil.Unpack(resp, &out) + return out, err +} + +// RSAEncrypt performs RSA encryption in the TPM according to RFC 3447. The key must be +// a (public) key loaded into the TPM beforehand. Note that when using OAEP with a label, +// a null byte is appended to the label and the null byte is included in the padding +// scheme. +func RSAEncrypt(rw io.ReadWriter, key tpmutil.Handle, message []byte, scheme *AsymScheme, label string) ([]byte, error) { + Cmd, err := encodeRSAEncrypt(key, message, scheme, label) + if err != nil { + return nil, err + } + resp, err := runCommand(rw, TagNoSessions, CmdRSAEncrypt, tpmutil.RawBytes(Cmd)) + if err != nil { + return nil, err + } + return decodeRSAEncrypt(resp) +} + +func encodeRSADecrypt(sessionHandle, key tpmutil.Handle, password string, message tpmutil.U16Bytes, scheme *AsymScheme, label string) ([]byte, error) { + ha, err := tpmutil.Pack(key) + if err != nil { + return nil, err + } + auth, err := encodeAuthArea(AuthCommand{Session: sessionHandle, Attributes: AttrContinueSession, Auth: []byte(password)}) + if err != nil { + return nil, err + } + m, err := tpmutil.Pack(message) + if err != nil { + return nil, err + } + s, err := scheme.encode() + if err != nil { + return nil, err + } + if label != "" { + label += "\x00" + } + l, err := tpmutil.Pack(tpmutil.U16Bytes(label)) + if err != nil { + return nil, err + } + 
return concat(ha, auth, m, s, l) +} + +func decodeRSADecrypt(resp []byte) ([]byte, error) { + var out tpmutil.U16Bytes + var paramSize uint32 + _, err := tpmutil.Unpack(resp, ¶mSize, &out) + return out, err +} + +// RSADecrypt performs RSA decryption in the TPM according to RFC 3447. The key must be +// a private RSA key in the TPM with FlagDecrypt set. Note that when using OAEP with a +// label, a null byte is appended to the label and the null byte is included in the +// padding scheme. +func RSADecrypt(rw io.ReadWriter, key tpmutil.Handle, password string, message []byte, scheme *AsymScheme, label string) ([]byte, error) { + return RSADecryptWithSession(rw, HandlePasswordSession, key, password, message, scheme, label) +} + +// RSADecryptWithSession performs RSA decryption in the TPM according to RFC 3447. The key must be +// a private RSA key in the TPM with FlagDecrypt set. Note that when using OAEP with a +// label, a null byte is appended to the label and the null byte is included in the +// padding scheme. +func RSADecryptWithSession(rw io.ReadWriter, sessionHandle, key tpmutil.Handle, password string, message []byte, scheme *AsymScheme, label string) ([]byte, error) { + Cmd, err := encodeRSADecrypt(sessionHandle, key, password, message, scheme, label) + if err != nil { + return nil, err + } + resp, err := runCommand(rw, TagSessions, CmdRSADecrypt, tpmutil.RawBytes(Cmd)) + if err != nil { + return nil, err + } + return decodeRSADecrypt(resp) +} + +func encodeECDHKeyGen(key tpmutil.Handle) ([]byte, error) { + return tpmutil.Pack(key) +} + +func decodeECDHKeyGen(resp []byte) (*ECPoint, *ECPoint, error) { + // Unpack z and pub as TPM2B_ECC_POINT, which is a TPMS_ECC_POINT with a total size prepended. 
+ var z2B, pub2B tpmutil.U16Bytes + _, err := tpmutil.Unpack(resp, &z2B, &pub2B) + if err != nil { + return nil, nil, err + } + var zPoint, pubPoint ECPoint + _, err = tpmutil.Unpack(z2B, &zPoint.XRaw, &zPoint.YRaw) + if err != nil { + return nil, nil, err + } + _, err = tpmutil.Unpack(pub2B, &pubPoint.XRaw, &pubPoint.YRaw) + if err != nil { + return nil, nil, err + } + return &zPoint, &pubPoint, nil +} + +// ECDHKeyGen generates an ephemeral ECC key, calculates the ECDH point multiplcation of the +// ephemeral private key and a loaded public key, and returns the public ephemeral point along with +// the coordinates of the resulting point. +func ECDHKeyGen(rw io.ReadWriter, key tpmutil.Handle) (zPoint, pubPoint *ECPoint, err error) { + Cmd, err := encodeECDHKeyGen(key) + if err != nil { + return nil, nil, err + } + resp, err := runCommand(rw, TagNoSessions, CmdECDHKeyGen, tpmutil.RawBytes(Cmd)) + if err != nil { + return nil, nil, err + } + return decodeECDHKeyGen(resp) +} + +func encodeECDHZGen(key tpmutil.Handle, password string, inPoint ECPoint) ([]byte, error) { + ha, err := tpmutil.Pack(key) + if err != nil { + return nil, err + } + auth, err := encodeAuthArea(AuthCommand{Session: HandlePasswordSession, Attributes: AttrContinueSession, Auth: []byte(password)}) + if err != nil { + return nil, err + } + p, err := tpmutil.Pack(inPoint) + if err != nil { + return nil, err + } + // Pack the TPMS_ECC_POINT as a TPM2B_ECC_POINT. + p2B, err := tpmutil.Pack(tpmutil.U16Bytes(p)) + if err != nil { + return nil, err + } + return concat(ha, auth, p2B) +} + +func decodeECDHZGen(resp []byte) (*ECPoint, error) { + var paramSize uint32 + // Unpack a TPM2B_ECC_POINT, which is a TPMS_ECC_POINT with a total size prepended. 
+ var z2B tpmutil.U16Bytes + _, err := tpmutil.Unpack(resp, ¶mSize, &z2B) + if err != nil { + return nil, err + } + var zPoint ECPoint + _, err = tpmutil.Unpack(z2B, &zPoint.XRaw, &zPoint.YRaw) + if err != nil { + return nil, err + } + return &zPoint, nil +} + +// ECDHZGen performs ECDH point multiplication between a private key held in the TPM and a given +// public point, returning the coordinates of the resulting point. The key must have FlagDecrypt +// set. +func ECDHZGen(rw io.ReadWriter, key tpmutil.Handle, password string, inPoint ECPoint) (zPoint *ECPoint, err error) { + Cmd, err := encodeECDHZGen(key, password, inPoint) + if err != nil { + return nil, err + } + resp, err := runCommand(rw, TagSessions, CmdECDHZGen, tpmutil.RawBytes(Cmd)) + if err != nil { + return nil, err + } + return decodeECDHZGen(resp) +} + +// DictionaryAttackLockReset cancels the effect of a TPM lockout due to a number +// of successive authorization failures, by setting the lockout counter to zero. +// The command requires Lockout Authorization and only one lockoutAuth authorization +// failure is allowed for this command during a lockoutRecovery interval. +// Lockout Authorization value by default is empty and can be changed via +// a call to HierarchyChangeAuth(HandleLockout). +func DictionaryAttackLockReset(rw io.ReadWriter, auth AuthCommand) error { + ha, err := tpmutil.Pack(HandleLockout) + if err != nil { + return err + } + encodedAuth, err := encodeAuthArea(auth) + if err != nil { + return err + } + Cmd, err := concat(ha, encodedAuth) + if err != nil { + return err + } + _, err = runCommand(rw, TagSessions, CmdDictionaryAttackLockReset, tpmutil.RawBytes(Cmd)) + return err +} + +// DictionaryAttackParameters changes the lockout parameters. +// The command requires Lockout Authorization and has same authorization policy +// as in DictionaryAttackLockReset. 
+func DictionaryAttackParameters(rw io.ReadWriter, auth AuthCommand, maxTries, recoveryTime, lockoutRecovery uint32) error { + ha, err := tpmutil.Pack(HandleLockout) + if err != nil { + return err + } + encodedAuth, err := encodeAuthArea(auth) + if err != nil { + return err + } + params, err := tpmutil.Pack(maxTries, recoveryTime, lockoutRecovery) + if err != nil { + return err + } + Cmd, err := concat(ha, encodedAuth, params) + if err != nil { + return err + } + _, err = runCommand(rw, TagSessions, CmdDictionaryAttackParameters, tpmutil.RawBytes(Cmd)) + return err +} + +// PolicyCommandCode indicates that the authorization will be limited to a specific command code +func PolicyCommandCode(rw io.ReadWriter, session tpmutil.Handle, cc tpmutil.Command) error { + data, err := tpmutil.Pack(session, cc) + if err != nil { + return err + } + _, err = runCommand(rw, TagNoSessions, CmdPolicyCommandCode, data) + return err +} diff --git a/vendor/github.com/google/go-tpm/tpmutil/encoding.go b/vendor/github.com/google/go-tpm/tpmutil/encoding.go new file mode 100644 index 0000000000..5983cc215c --- /dev/null +++ b/vendor/github.com/google/go-tpm/tpmutil/encoding.go @@ -0,0 +1,211 @@ +// Copyright (c) 2018, Google LLC All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package tpmutil + +import ( + "bytes" + "encoding/binary" + "errors" + "fmt" + "io" + "reflect" +) + +var ( + selfMarshalerType = reflect.TypeOf((*SelfMarshaler)(nil)).Elem() + handlesAreaType = reflect.TypeOf((*[]Handle)(nil)) +) + +// packWithHeader takes a header and a sequence of elements that are either of +// fixed length or slices of fixed-length types and packs them into a single +// byte array using binary.Write. It updates the CommandHeader to have the right +// length. +func packWithHeader(ch commandHeader, cmd ...interface{}) ([]byte, error) { + hdrSize := binary.Size(ch) + body, err := Pack(cmd...) + if err != nil { + return nil, fmt.Errorf("couldn't pack message body: %v", err) + } + bodySize := len(body) + ch.Size = uint32(hdrSize + bodySize) + header, err := Pack(ch) + if err != nil { + return nil, fmt.Errorf("couldn't pack message header: %v", err) + } + return append(header, body...), nil +} + +// Pack encodes a set of elements into a single byte array, using +// encoding/binary. This means that all the elements must be encodeable +// according to the rules of encoding/binary. +// +// It has one difference from encoding/binary: it encodes byte slices with a +// prepended length, to match how the TPM encodes variable-length arrays. If +// you wish to add a byte slice without length prefix, use RawBytes. +func Pack(elts ...interface{}) ([]byte, error) { + buf := new(bytes.Buffer) + if err := packType(buf, elts...); err != nil { + return nil, err + } + return buf.Bytes(), nil +} + +// tryMarshal attempts to use a TPMMarshal() method defined on the type +// to pack v into buf. True is returned if the method exists and the +// marshal was attempted. 
+func tryMarshal(buf io.Writer, v reflect.Value) (bool, error) { + t := v.Type() + if t.Implements(selfMarshalerType) { + if v.Kind() == reflect.Ptr && v.IsNil() { + return true, fmt.Errorf("cannot call TPMMarshal on a nil pointer of type %T", v) + } + return true, v.Interface().(SelfMarshaler).TPMMarshal(buf) + } + + // We might have a non-pointer struct field, but we dont have a + // pointer with which to implement the interface. + // If the pointer of the type implements the interface, we should be + // able to construct a value to call TPMMarshal() with. + // TODO(awly): Try and avoid blowing away private data by using Addr() instead of Set() + if reflect.PtrTo(t).Implements(selfMarshalerType) { + tmp := reflect.New(t) + tmp.Elem().Set(v) + return true, tmp.Interface().(SelfMarshaler).TPMMarshal(buf) + } + + return false, nil +} + +func packValue(buf io.Writer, v reflect.Value) error { + if v.Type() == handlesAreaType { + v = v.Convert(reflect.TypeOf((*handleList)(nil))) + } + if canMarshal, err := tryMarshal(buf, v); canMarshal { + return err + } + + switch v.Kind() { + case reflect.Ptr: + if v.IsNil() { + return fmt.Errorf("cannot pack nil %s", v.Type().String()) + } + return packValue(buf, v.Elem()) + case reflect.Struct: + for i := 0; i < v.NumField(); i++ { + f := v.Field(i) + if err := packValue(buf, f); err != nil { + return err + } + } + default: + return binary.Write(buf, binary.BigEndian, v.Interface()) + } + return nil +} + +func packType(buf io.Writer, elts ...interface{}) error { + for _, e := range elts { + if err := packValue(buf, reflect.ValueOf(e)); err != nil { + return err + } + } + + return nil +} + +// tryUnmarshal attempts to use TPMUnmarshal() to perform the +// unpack, if the given value implements SelfMarshaler. +// True is returned if v implements SelfMarshaler & TPMUnmarshal +// was called, along with an error returned from TPMUnmarshal. 
+func tryUnmarshal(buf io.Reader, v reflect.Value) (bool, error) { + t := v.Type() + if t.Implements(selfMarshalerType) { + if v.Kind() == reflect.Ptr && v.IsNil() { + return true, fmt.Errorf("cannot call TPMUnmarshal on a nil pointer") + } + return true, v.Interface().(SelfMarshaler).TPMUnmarshal(buf) + } + + // We might have a non-pointer struct field, which is addressable, + // If the pointer of the type implements the interface, and the + // value is addressable, we should be able to call TPMUnmarshal(). + if v.CanAddr() && reflect.PtrTo(t).Implements(selfMarshalerType) { + return true, v.Addr().Interface().(SelfMarshaler).TPMUnmarshal(buf) + } + + return false, nil +} + +// Unpack is a convenience wrapper around UnpackBuf. Unpack returns the number +// of bytes read from b to fill elts and error, if any. +func Unpack(b []byte, elts ...interface{}) (int, error) { + buf := bytes.NewBuffer(b) + err := UnpackBuf(buf, elts...) + read := len(b) - buf.Len() + return read, err +} + +func unpackValue(buf io.Reader, v reflect.Value) error { + if v.Type() == handlesAreaType { + v = v.Convert(reflect.TypeOf((*handleList)(nil))) + } + if didUnmarshal, err := tryUnmarshal(buf, v); didUnmarshal { + return err + } + + switch v.Kind() { + case reflect.Ptr: + if v.IsNil() { + return fmt.Errorf("cannot unpack nil %s", v.Type().String()) + } + return unpackValue(buf, v.Elem()) + case reflect.Struct: + for i := 0; i < v.NumField(); i++ { + f := v.Field(i) + if err := unpackValue(buf, f); err != nil { + return err + } + } + return nil + default: + // binary.Read can only set pointer values, so we need to take the address. 
+ if !v.CanAddr() { + return fmt.Errorf("cannot unpack unaddressable leaf type %q", v.Type().String()) + } + return binary.Read(buf, binary.BigEndian, v.Addr().Interface()) + } +} + +// UnpackBuf recursively unpacks types from a reader just as encoding/binary +// does under binary.BigEndian, but with one difference: it unpacks a byte +// slice by first reading an integer with lengthPrefixSize bytes, then reading +// that many bytes. It assumes that incoming values are pointers to values so +// that, e.g., underlying slices can be resized as needed. +func UnpackBuf(buf io.Reader, elts ...interface{}) error { + for _, e := range elts { + v := reflect.ValueOf(e) + if v.Kind() != reflect.Ptr { + return fmt.Errorf("non-pointer value %q passed to UnpackBuf", v.Type().String()) + } + if v.IsNil() { + return errors.New("nil pointer passed to UnpackBuf") + } + + if err := unpackValue(buf, v); err != nil { + return err + } + } + return nil +} diff --git a/vendor/github.com/google/go-tpm/tpmutil/poll_other.go b/vendor/github.com/google/go-tpm/tpmutil/poll_other.go new file mode 100644 index 0000000000..ba7e062e32 --- /dev/null +++ b/vendor/github.com/google/go-tpm/tpmutil/poll_other.go @@ -0,0 +1,10 @@ +//go:build !linux && !darwin + +package tpmutil + +import ( + "os" +) + +// Not implemented on Windows. +func poll(_ *os.File) error { return nil } diff --git a/vendor/github.com/google/go-tpm/tpmutil/poll_unix.go b/vendor/github.com/google/go-tpm/tpmutil/poll_unix.go new file mode 100644 index 0000000000..89d85d3814 --- /dev/null +++ b/vendor/github.com/google/go-tpm/tpmutil/poll_unix.go @@ -0,0 +1,32 @@ +//go:build linux || darwin + +package tpmutil + +import ( + "fmt" + "os" + + "golang.org/x/sys/unix" +) + +// poll blocks until the file descriptor is ready for reading or an error occurs. 
+func poll(f *os.File) error { + var ( + fds = []unix.PollFd{{ + Fd: int32(f.Fd()), + Events: 0x1, // POLLIN + }} + timeout = -1 // Indefinite timeout + ) + + if _, err := unix.Poll(fds, timeout); err != nil { + return err + } + + // Revents is filled in by the kernel. + // If the expected event happened, Revents should match Events. + if fds[0].Revents != fds[0].Events { + return fmt.Errorf("unexpected poll Revents 0x%x", fds[0].Revents) + } + return nil +} diff --git a/vendor/github.com/google/go-tpm/tpmutil/run.go b/vendor/github.com/google/go-tpm/tpmutil/run.go new file mode 100644 index 0000000000..c07e3abab4 --- /dev/null +++ b/vendor/github.com/google/go-tpm/tpmutil/run.go @@ -0,0 +1,113 @@ +// Copyright (c) 2018, Google LLC All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package tpmutil provides common utility functions for both TPM 1.2 and TPM +// 2.0 devices. +package tpmutil + +import ( + "errors" + "io" + "os" + "time" +) + +// maxTPMResponse is the largest possible response from the TPM. We need to know +// this because we don't always know the length of the TPM response, and +// /dev/tpm insists on giving it all back in a single value rather than +// returning a header and a body in separate responses. +const maxTPMResponse = 4096 + +// RunCommandRaw executes the given raw command and returns the raw response. +// Does not check the response code except to execute retry logic. 
+func RunCommandRaw(rw io.ReadWriter, inb []byte) ([]byte, error) { + if rw == nil { + return nil, errors.New("nil TPM handle") + } + + // f(t) = (2^t)ms, up to 2s + var backoffFac uint + var rh responseHeader + var outb []byte + + for { + if _, err := rw.Write(inb); err != nil { + return nil, err + } + + // If the TPM is a real device, it may not be ready for reading + // immediately after writing the command. Wait until the file + // descriptor is ready to be read from. + if f, ok := rw.(*os.File); ok { + if err := poll(f); err != nil { + return nil, err + } + } + + outb = make([]byte, maxTPMResponse) + outlen, err := rw.Read(outb) + if err != nil { + return nil, err + } + // Resize the buffer to match the amount read from the TPM. + outb = outb[:outlen] + + _, err = Unpack(outb, &rh) + if err != nil { + return nil, err + } + + // If TPM is busy, retry the command after waiting a few ms. + if rh.Res == RCRetry { + if backoffFac < 11 { + dur := (1 << backoffFac) * time.Millisecond + time.Sleep(dur) + backoffFac++ + } else { + return nil, err + } + } else { + break + } + } + + return outb, nil +} + +// RunCommand executes cmd with given tag and arguments. Returns TPM response +// body (without response header) and response code from the header. Returned +// error may be nil if response code is not RCSuccess; caller should check +// both. +func RunCommand(rw io.ReadWriter, tag Tag, cmd Command, in ...interface{}) ([]byte, ResponseCode, error) { + inb, err := packWithHeader(commandHeader{tag, 0, cmd}, in...) 
+ if err != nil { + return nil, 0, err + } + + outb, err := RunCommandRaw(rw, inb) + if err != nil { + return nil, 0, err + } + + var rh responseHeader + read, err := Unpack(outb, &rh) + if err != nil { + return nil, 0, err + } + if rh.Res != RCSuccess { + return nil, rh.Res, nil + } + + return outb[read:], rh.Res, nil +} diff --git a/vendor/github.com/google/go-tpm/tpmutil/run_other.go b/vendor/github.com/google/go-tpm/tpmutil/run_other.go new file mode 100644 index 0000000000..2a142d39ec --- /dev/null +++ b/vendor/github.com/google/go-tpm/tpmutil/run_other.go @@ -0,0 +1,111 @@ +//go:build !windows + +// Copyright (c) 2018, Google LLC All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tpmutil + +import ( + "fmt" + "io" + "net" + "os" +) + +// OpenTPM opens a channel to the TPM at the given path. If the file is a +// device, then it treats it like a normal TPM device, and if the file is a +// Unix domain socket, then it opens a connection to the socket. 
+func OpenTPM(path string) (io.ReadWriteCloser, error) { + // If it's a regular file, then open it + var rwc io.ReadWriteCloser + fi, err := os.Stat(path) + if err != nil { + return nil, err + } + + if fi.Mode()&os.ModeDevice != 0 { + var f *os.File + f, err = os.OpenFile(path, os.O_RDWR, 0600) + if err != nil { + return nil, err + } + rwc = io.ReadWriteCloser(f) + } else if fi.Mode()&os.ModeSocket != 0 { + rwc = NewEmulatorReadWriteCloser(path) + } else { + return nil, fmt.Errorf("unsupported TPM file mode %s", fi.Mode().String()) + } + + return rwc, nil +} + +// dialer abstracts the net.Dial call so test code can provide its own net.Conn +// implementation. +type dialer func(network, path string) (net.Conn, error) + +// EmulatorReadWriteCloser manages connections with a TPM emulator over a Unix +// domain socket. These emulators often operate in a write/read/disconnect +// sequence, so the Write method always connects, and the Read method always +// closes. EmulatorReadWriteCloser is not thread safe. +type EmulatorReadWriteCloser struct { + path string + conn net.Conn + dialer dialer +} + +// NewEmulatorReadWriteCloser stores information about a Unix domain socket to +// write to and read from. +func NewEmulatorReadWriteCloser(path string) *EmulatorReadWriteCloser { + return &EmulatorReadWriteCloser{ + path: path, + dialer: net.Dial, + } +} + +// Read implements io.Reader by reading from the Unix domain socket and closing +// it. +func (erw *EmulatorReadWriteCloser) Read(p []byte) (int, error) { + // Read is always the second operation in a Write/Read sequence. + if erw.conn == nil { + return 0, fmt.Errorf("must call Write then Read in an alternating sequence") + } + n, err := erw.conn.Read(p) + erw.conn.Close() + erw.conn = nil + return n, err +} + +// Write implements io.Writer by connecting to the Unix domain socket and +// writing. 
+func (erw *EmulatorReadWriteCloser) Write(p []byte) (int, error) { + if erw.conn != nil { + return 0, fmt.Errorf("must call Write then Read in an alternating sequence") + } + var err error + erw.conn, err = erw.dialer("unix", erw.path) + if err != nil { + return 0, err + } + return erw.conn.Write(p) +} + +// Close implements io.Closer by closing the Unix domain socket if one is open. +func (erw *EmulatorReadWriteCloser) Close() error { + if erw.conn == nil { + return fmt.Errorf("cannot call Close when no connection is open") + } + err := erw.conn.Close() + erw.conn = nil + return err +} diff --git a/vendor/github.com/google/go-tpm/tpmutil/run_windows.go b/vendor/github.com/google/go-tpm/tpmutil/run_windows.go new file mode 100644 index 0000000000..f355b81012 --- /dev/null +++ b/vendor/github.com/google/go-tpm/tpmutil/run_windows.go @@ -0,0 +1,84 @@ +// Copyright (c) 2018, Google LLC All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tpmutil + +import ( + "io" + + "github.com/google/go-tpm/tpmutil/tbs" +) + +// winTPMBuffer is a ReadWriteCloser to access the TPM in Windows. +type winTPMBuffer struct { + context tbs.Context + outBuffer []byte +} + +// Executes the TPM command specified by commandBuffer (at Normal Priority), returning the number +// of bytes in the command and any error code returned by executing the TPM command. Command +// response can be read by calling Read(). 
+func (rwc *winTPMBuffer) Write(commandBuffer []byte) (int, error) { + // TPM spec defines longest possible response to be maxTPMResponse. + rwc.outBuffer = rwc.outBuffer[:maxTPMResponse] + + outBufferLen, err := rwc.context.SubmitCommand( + tbs.NormalPriority, + commandBuffer, + rwc.outBuffer, + ) + + if err != nil { + rwc.outBuffer = rwc.outBuffer[:0] + return 0, err + } + // Shrink outBuffer so it is length of response. + rwc.outBuffer = rwc.outBuffer[:outBufferLen] + return len(commandBuffer), nil +} + +// Provides TPM response from the command called in the last Write call. +func (rwc *winTPMBuffer) Read(responseBuffer []byte) (int, error) { + if len(rwc.outBuffer) == 0 { + return 0, io.EOF + } + lenCopied := copy(responseBuffer, rwc.outBuffer) + // Cut out the piece of slice which was just read out, maintaining original slice capacity. + rwc.outBuffer = append(rwc.outBuffer[:0], rwc.outBuffer[lenCopied:]...) + return lenCopied, nil +} + +func (rwc *winTPMBuffer) Close() error { + return rwc.context.Close() +} + +// OpenTPM creates a new instance of a ReadWriteCloser which can interact with a +// Windows TPM. +func OpenTPM() (io.ReadWriteCloser, error) { + tpmContext, err := tbs.CreateContext(tbs.TPMVersion20, tbs.IncludeTPM12|tbs.IncludeTPM20) + rwc := &winTPMBuffer{ + context: tpmContext, + outBuffer: make([]byte, 0, maxTPMResponse), + } + return rwc, err +} + +// FromContext creates a new instance of a ReadWriteCloser which can +// interact with a Windows TPM, using the specified TBS handle. +func FromContext(ctx tbs.Context) io.ReadWriteCloser { + return &winTPMBuffer{ + context: ctx, + outBuffer: make([]byte, 0, maxTPMResponse), + } +} diff --git a/vendor/github.com/google/go-tpm/tpmutil/structures.go b/vendor/github.com/google/go-tpm/tpmutil/structures.go new file mode 100644 index 0000000000..893b6b6df9 --- /dev/null +++ b/vendor/github.com/google/go-tpm/tpmutil/structures.go @@ -0,0 +1,195 @@ +// Copyright (c) 2018, Google LLC All rights reserved. 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tpmutil + +import ( + "bytes" + "encoding/binary" + "fmt" + "io" +) + +// maxBytesBufferSize sets a sane upper bound on the size of a U32Bytes +// buffer. This limit exists to prevent a maliciously large size prefix +// from resulting in a massive memory allocation, potentially causing +// an OOM condition on the system. +// We expect no buffer from a TPM to approach 1Mb in size. +const maxBytesBufferSize uint32 = 1024 * 1024 // 1Mb. + +// RawBytes is for Pack and RunCommand arguments that are already encoded. +// Compared to []byte, RawBytes will not be prepended with slice length during +// encoding. +type RawBytes []byte + +// U16Bytes is a byte slice with a 16-bit header +type U16Bytes []byte + +// TPMMarshal packs U16Bytes +func (b *U16Bytes) TPMMarshal(out io.Writer) error { + size := len([]byte(*b)) + if err := binary.Write(out, binary.BigEndian, uint16(size)); err != nil { + return err + } + + n, err := out.Write(*b) + if err != nil { + return err + } + if n != size { + return fmt.Errorf("unable to write all contents of U16Bytes") + } + return nil +} + +// TPMUnmarshal unpacks a U16Bytes +func (b *U16Bytes) TPMUnmarshal(in io.Reader) error { + var tmpSize uint16 + if err := binary.Read(in, binary.BigEndian, &tmpSize); err != nil { + return err + } + size := int(tmpSize) + + if len(*b) >= size { + *b = (*b)[:size] + } else { + *b = append(*b, make([]byte, size-len(*b))...) 
+ } + + n, err := in.Read(*b) + if err != nil { + return err + } + if n != size { + return io.ErrUnexpectedEOF + } + return nil +} + +// U32Bytes is a byte slice with a 32-bit header +type U32Bytes []byte + +// TPMMarshal packs U32Bytes +func (b *U32Bytes) TPMMarshal(out io.Writer) error { + size := len([]byte(*b)) + if err := binary.Write(out, binary.BigEndian, uint32(size)); err != nil { + return err + } + + n, err := out.Write(*b) + if err != nil { + return err + } + if n != size { + return fmt.Errorf("unable to write all contents of U32Bytes") + } + return nil +} + +// TPMUnmarshal unpacks a U32Bytes +func (b *U32Bytes) TPMUnmarshal(in io.Reader) error { + var tmpSize uint32 + if err := binary.Read(in, binary.BigEndian, &tmpSize); err != nil { + return err + } + + if tmpSize > maxBytesBufferSize { + return bytes.ErrTooLarge + } + // We can now safely cast to an int on 32-bit or 64-bit machines + size := int(tmpSize) + + if len(*b) >= size { + *b = (*b)[:size] + } else { + *b = append(*b, make([]byte, size-len(*b))...) + } + + n, err := in.Read(*b) + if err != nil { + return err + } + if n != size { + return fmt.Errorf("unable to read all contents in to U32Bytes") + } + return nil +} + +// Tag is a command tag. +type Tag uint16 + +// Command is an identifier of a TPM command. +type Command uint32 + +// A commandHeader is the header for a TPM command. +type commandHeader struct { + Tag Tag + Size uint32 + Cmd Command +} + +// ResponseCode is a response code returned by TPM. +type ResponseCode uint32 + +// RCSuccess is response code for successful command. Identical for TPM 1.2 and +// 2.0. +const RCSuccess ResponseCode = 0x000 + +// RCRetry is response code for TPM is busy. +const RCRetry ResponseCode = 0x922 + +// A responseHeader is a header for TPM responses. +type responseHeader struct { + Tag Tag + Size uint32 + Res ResponseCode +} + +// A Handle is a reference to a TPM object. +type Handle uint32 + +// HandleValue returns the handle value. 
This behavior is intended to satisfy +// an interface that can be implemented by other, more complex types as well. +func (h Handle) HandleValue() uint32 { + return uint32(h) +} + +type handleList []Handle + +func (l *handleList) TPMMarshal(_ io.Writer) error { + return fmt.Errorf("TPMMarhsal on []Handle is not supported yet") +} + +func (l *handleList) TPMUnmarshal(in io.Reader) error { + var numHandles uint16 + if err := binary.Read(in, binary.BigEndian, &numHandles); err != nil { + return err + } + + // Make len(e) match size exactly. + size := int(numHandles) + if len(*l) >= size { + *l = (*l)[:size] + } else { + *l = append(*l, make([]Handle, size-len(*l))...) + } + return binary.Read(in, binary.BigEndian, *l) +} + +// SelfMarshaler allows custom types to override default encoding/decoding +// behavior in Pack, Unpack and UnpackBuf. +type SelfMarshaler interface { + TPMMarshal(out io.Writer) error + TPMUnmarshal(in io.Reader) error +} diff --git a/vendor/github.com/google/go-tpm/tpmutil/tbs/tbs_windows.go b/vendor/github.com/google/go-tpm/tpmutil/tbs/tbs_windows.go new file mode 100644 index 0000000000..b23bf96a1e --- /dev/null +++ b/vendor/github.com/google/go-tpm/tpmutil/tbs/tbs_windows.go @@ -0,0 +1,267 @@ +// Copyright (c) 2018, Google LLC All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +// Package tbs provides an low-level interface directly mapping to Windows +// Tbs.dll system library commands: +// https://docs.microsoft.com/en-us/windows/desktop/TBS/tpm-base-services-portal +// Public field descriptions contain links to the high-level Windows documentation. +package tbs + +import ( + "fmt" + "syscall" + "unsafe" +) + +// Context references the current TPM context +type Context uintptr + +// Version of TPM being used by the application. +type Version uint32 + +// Flag indicates TPM versions that are supported by the application. +type Flag uint32 + +// CommandPriority is used to determine which pending command to submit whenever the TPM is free. +type CommandPriority uint32 + +// Command parameters: +// https://github.com/tpn/winsdk-10/blob/master/Include/10.0.10240.0/shared/tbs.h +const ( + // https://docs.microsoft.com/en-us/windows/desktop/api/Tbs/ns-tbs-tdtbs_context_params2 + // OR flags to use multiple. + RequestRaw Flag = 1 << iota // Add flag to request raw context + IncludeTPM12 // Add flag to support TPM 1.2 + IncludeTPM20 // Add flag to support TPM 2 + + TPMVersion12 Version = 1 // For TPM 1.2 applications + TPMVersion20 Version = 2 // For TPM 2 applications or applications using multiple TPM versions + + // https://docs.microsoft.com/en-us/windows/desktop/tbs/command-scheduling + // https://docs.microsoft.com/en-us/windows/desktop/api/Tbs/nf-tbs-tbsip_submit_command#parameters + LowPriority CommandPriority = 100 // For low priority application use + NormalPriority CommandPriority = 200 // For normal priority application use + HighPriority CommandPriority = 300 // For high priority application use + SystemPriority CommandPriority = 400 // For system tasks that access the TPM + + commandLocalityZero uint32 = 0 // Windows currently only supports TBS_COMMAND_LOCALITY_ZERO. +) + +// Error is the return type of all functions in this package. 
+type Error uint32 + +func (err Error) Error() string { + if description, ok := errorDescriptions[err]; ok { + return fmt.Sprintf("TBS Error 0x%X: %s", uint32(err), description) + } + return fmt.Sprintf("Unrecognized TBS Error 0x%X", uint32(err)) +} + +func getError(err uintptr) error { + // tbs.dll uses 0x0 as the return value for success. + if err == 0 { + return nil + } + return Error(err) +} + +// TBS Return Codes: +// https://docs.microsoft.com/en-us/windows/desktop/TBS/tbs-return-codes +const ( + ErrInternalError Error = 0x80284001 + ErrBadParameter Error = 0x80284002 + ErrInvalidOutputPointer Error = 0x80284003 + ErrInvalidContext Error = 0x80284004 + ErrInsufficientBuffer Error = 0x80284005 + ErrIOError Error = 0x80284006 + ErrInvalidContextParam Error = 0x80284007 + ErrServiceNotRunning Error = 0x80284008 + ErrTooManyTBSContexts Error = 0x80284009 + ErrTooManyResources Error = 0x8028400A + ErrServiceStartPending Error = 0x8028400B + ErrPPINotSupported Error = 0x8028400C + ErrCommandCanceled Error = 0x8028400D + ErrBufferTooLarge Error = 0x8028400E + ErrTPMNotFound Error = 0x8028400F + ErrServiceDisabled Error = 0x80284010 + ErrNoEventLog Error = 0x80284011 + ErrAccessDenied Error = 0x80284012 + ErrProvisioningNotAllowed Error = 0x80284013 + ErrPPIFunctionUnsupported Error = 0x80284014 + ErrOwnerauthNotFound Error = 0x80284015 +) + +var errorDescriptions = map[Error]string{ + ErrInternalError: "An internal software error occurred.", + ErrBadParameter: "One or more parameter values are not valid.", + ErrInvalidOutputPointer: "A specified output pointer is bad.", + ErrInvalidContext: "The specified context handle does not refer to a valid context.", + ErrInsufficientBuffer: "The specified output buffer is too small.", + ErrIOError: "An error occurred while communicating with the TPM.", + ErrInvalidContextParam: "A context parameter that is not valid was passed when attempting to create a TBS context.", + ErrServiceNotRunning: "The TBS service is not running 
and could not be started.", + ErrTooManyTBSContexts: "A new context could not be created because there are too many open contexts.", + ErrTooManyResources: "A new virtual resource could not be created because there are too many open virtual resources.", + ErrServiceStartPending: "The TBS service has been started but is not yet running.", + ErrPPINotSupported: "The physical presence interface is not supported.", + ErrCommandCanceled: "The command was canceled.", + ErrBufferTooLarge: "The input or output buffer is too large.", + ErrTPMNotFound: "A compatible Trusted Platform Module (TPM) Security Device cannot be found on this computer.", + ErrServiceDisabled: "The TBS service has been disabled.", + ErrNoEventLog: "The TBS event log is not available.", + ErrAccessDenied: "The caller does not have the appropriate rights to perform the requested operation.", + ErrProvisioningNotAllowed: "The TPM provisioning action is not allowed by the specified flags.", + ErrPPIFunctionUnsupported: "The Physical Presence Interface of this firmware does not support the requested method.", + ErrOwnerauthNotFound: "The requested TPM OwnerAuth value was not found.", +} + +// Tbs.dll provides an API for making calls to the TPM: +// https://docs.microsoft.com/en-us/windows/desktop/TBS/tpm-base-services-portal +var ( + tbsDLL = syscall.NewLazyDLL("Tbs.dll") + tbsGetDeviceInfo = tbsDLL.NewProc("Tbsi_GetDeviceInfo") + tbsCreateContext = tbsDLL.NewProc("Tbsi_Context_Create") + tbsContextClose = tbsDLL.NewProc("Tbsip_Context_Close") + tbsSubmitCommand = tbsDLL.NewProc("Tbsip_Submit_Command") + tbsGetTCGLog = tbsDLL.NewProc("Tbsi_Get_TCG_Log") +) + +// Returns the address of the beginning of a slice or 0 for a nil slice. 
+func sliceAddress(s []byte) uintptr { + if len(s) == 0 { + return 0 + } + return uintptr(unsafe.Pointer(&(s[0]))) +} + +// DeviceInfo is TPM_DEVICE_INFO from tbs.h +type DeviceInfo struct { + StructVersion uint32 + TPMVersion Version + TPMInterfaceType uint32 + TPMImpRevision uint32 +} + +// GetDeviceInfo gets the DeviceInfo of the current TPM: +// https://docs.microsoft.com/en-us/windows/win32/api/tbs/nf-tbs-tbsi_getdeviceinfo +func GetDeviceInfo() (*DeviceInfo, error) { + info := DeviceInfo{} + // TBS_RESULT Tbsi_GetDeviceInfo( + // UINT32 Size, + // PVOID Info + // ); + if err := tbsGetDeviceInfo.Find(); err != nil { + return nil, err + } + result, _, _ := tbsGetDeviceInfo.Call( + unsafe.Sizeof(info), + uintptr(unsafe.Pointer(&info)), + ) + return &info, getError(result) +} + +// CreateContext creates a new TPM context: +// https://docs.microsoft.com/en-us/windows/desktop/api/Tbs/nf-tbs-tbsi_context_create +func CreateContext(version Version, flag Flag) (Context, error) { + var context Context + params := struct { + Version + Flag + }{version, flag} + // TBS_RESULT Tbsi_Context_Create( + // _In_ PCTBS_CONTEXT_PARAMS pContextParams, + // _Out_ PTBS_HCONTEXT *phContext + // ); + if err := tbsCreateContext.Find(); err != nil { + return context, err + } + result, _, _ := tbsCreateContext.Call( + uintptr(unsafe.Pointer(¶ms)), + uintptr(unsafe.Pointer(&context)), + ) + return context, getError(result) +} + +// Close closes an existing TPM context: +// https://docs.microsoft.com/en-us/windows/desktop/api/Tbs/nf-tbs-tbsip_context_close +func (context Context) Close() error { + // TBS_RESULT Tbsip_Context_Close( + // _In_ TBS_HCONTEXT hContext + // ); + if err := tbsContextClose.Find(); err != nil { + return err + } + result, _, _ := tbsContextClose.Call(uintptr(context)) + return getError(result) +} + +// SubmitCommand sends commandBuffer to the TPM, returning the number of bytes +// written to responseBuffer. 
ErrInsufficientBuffer is returned if the +// responseBuffer is too short. ErrInvalidOutputPointer is returned if the +// responseBuffer is nil. On failure, the returned length is unspecified. +// https://docs.microsoft.com/en-us/windows/desktop/api/Tbs/nf-tbs-tbsip_submit_command +func (context Context) SubmitCommand( + priority CommandPriority, + commandBuffer []byte, + responseBuffer []byte, +) (uint32, error) { + responseBufferLen := uint32(len(responseBuffer)) + + // TBS_RESULT Tbsip_Submit_Command( + // _In_ TBS_HCONTEXT hContext, + // _In_ TBS_COMMAND_LOCALITY Locality, + // _In_ TBS_COMMAND_PRIORITY Priority, + // _In_ const PCBYTE *pabCommand, + // _In_ UINT32 cbCommand, + // _Out_ PBYTE *pabResult, + // _Inout_ UINT32 *pcbOutput + // ); + if err := tbsSubmitCommand.Find(); err != nil { + return 0, err + } + result, _, _ := tbsSubmitCommand.Call( + uintptr(context), + uintptr(commandLocalityZero), + uintptr(priority), + sliceAddress(commandBuffer), + uintptr(len(commandBuffer)), + sliceAddress(responseBuffer), + uintptr(unsafe.Pointer(&responseBufferLen)), + ) + return responseBufferLen, getError(result) +} + +// GetTCGLog gets the system event log, returning the number of bytes written +// to logBuffer. If logBuffer is nil, the size of the TCG log is returned. +// ErrInsufficientBuffer is returned if the logBuffer is too short. On failure, +// the returned length is unspecified. 
+// https://docs.microsoft.com/en-us/windows/desktop/api/Tbs/nf-tbs-tbsi_get_tcg_log +func (context Context) GetTCGLog(logBuffer []byte) (uint32, error) { + logBufferLen := uint32(len(logBuffer)) + + // TBS_RESULT Tbsi_Get_TCG_Log( + // TBS_HCONTEXT hContext, + // PBYTE pOutputBuf, + // PUINT32 pOutputBufLen + // ); + if err := tbsGetTCGLog.Find(); err != nil { + return 0, err + } + result, _, _ := tbsGetTCGLog.Call( + uintptr(context), + sliceAddress(logBuffer), + uintptr(unsafe.Pointer(&logBufferLen)), + ) + return logBufferLen, getError(result) +} diff --git a/vendor/github.com/nats-io/nats-server/v2/conf/fuzz.go b/vendor/github.com/nats-io/nats-server/v2/conf/fuzz.go index 3aba1551eb..2db114ce72 100644 --- a/vendor/github.com/nats-io/nats-server/v2/conf/fuzz.go +++ b/vendor/github.com/nats-io/nats-server/v2/conf/fuzz.go @@ -12,7 +12,6 @@ // limitations under the License. //go:build gofuzz -// +build gofuzz package conf diff --git a/vendor/github.com/nats-io/nats-server/v2/conf/parse.go b/vendor/github.com/nats-io/nats-server/v2/conf/parse.go index 4e91c6667b..c1f064ae75 100644 --- a/vendor/github.com/nats-io/nats-server/v2/conf/parse.go +++ b/vendor/github.com/nats-io/nats-server/v2/conf/parse.go @@ -26,6 +26,8 @@ package conf // see parse_test.go for more examples. import ( + "crypto/sha256" + "encoding/json" "fmt" "os" "path/filepath" @@ -35,6 +37,8 @@ import ( "unicode" ) +const _EMPTY_ = "" + type parser struct { mapping map[string]any lx *lexer @@ -69,6 +73,15 @@ func Parse(data string) (map[string]any, error) { return p.mapping, nil } +// ParseWithChecks is equivalent to Parse but runs in pedantic mode. +func ParseWithChecks(data string) (map[string]any, error) { + p, err := parse(data, "", true) + if err != nil { + return nil, err + } + return p.mapping, nil +} + // ParseFile is a helper to open file, etc. and parse the contents. 
func ParseFile(fp string) (map[string]any, error) { data, err := os.ReadFile(fp) @@ -98,6 +111,44 @@ func ParseFileWithChecks(fp string) (map[string]any, error) { return p.mapping, nil } +// cleanupUsedEnvVars will recursively remove all already used +// environment variables which might be in the parsed tree. +func cleanupUsedEnvVars(m map[string]any) { + for k, v := range m { + t := v.(*token) + if t.usedVariable { + delete(m, k) + continue + } + // Cleanup any other env var that is still in the map. + if tm, ok := t.value.(map[string]any); ok { + cleanupUsedEnvVars(tm) + } + } +} + +// ParseFileWithChecksDigest returns the processed config and a digest +// that represents the configuration. +func ParseFileWithChecksDigest(fp string) (map[string]any, string, error) { + data, err := os.ReadFile(fp) + if err != nil { + return nil, _EMPTY_, err + } + p, err := parse(string(data), fp, true) + if err != nil { + return nil, _EMPTY_, err + } + // Filter out any environment variables before taking the digest. + cleanupUsedEnvVars(p.mapping) + digest := sha256.New() + e := json.NewEncoder(digest) + err = e.Encode(p.mapping) + if err != nil { + return nil, _EMPTY_, err + } + return p.mapping, fmt.Sprintf("sha256:%x", digest.Sum(nil)), nil +} + type token struct { item item value any @@ -105,6 +156,10 @@ type token struct { sourceFile string } +func (t *token) MarshalJSON() ([]byte, error) { + return json.Marshal(t.value) +} + func (t *token) Value() any { return t.value } diff --git a/vendor/github.com/nats-io/nats-server/v2/logger/syslog.go b/vendor/github.com/nats-io/nats-server/v2/logger/syslog.go index d1c9fea2cc..211dd97cad 100644 --- a/vendor/github.com/nats-io/nats-server/v2/logger/syslog.go +++ b/vendor/github.com/nats-io/nats-server/v2/logger/syslog.go @@ -12,7 +12,6 @@ // limitations under the License. 
//go:build !windows -// +build !windows package logger diff --git a/vendor/github.com/nats-io/nats-server/v2/server/README.md b/vendor/github.com/nats-io/nats-server/v2/server/README.md index 3184eeda87..38e8621eef 100644 --- a/vendor/github.com/nats-io/nats-server/v2/server/README.md +++ b/vendor/github.com/nats-io/nats-server/v2/server/README.md @@ -10,7 +10,7 @@ The script `runTestsOnTravis.sh` will run a given job based on the definition fo As for the naming convention: -- All JetStream tests name should start with `TestJetStream` +- All JetStream test name should start with `TestJetStream` - Cluster tests should go into `jetstream_cluster_test.go` and start with `TestJetStreamCluster` - Super-cluster tests should go into `jetstream_super_cluster_test.go` and start with `TestJetStreamSuperCluster` diff --git a/vendor/github.com/nats-io/nats-server/v2/server/accounts.go b/vendor/github.com/nats-io/nats-server/v2/server/accounts.go index a821b9d8e5..8146bb02a1 100644 --- a/vendor/github.com/nats-io/nats-server/v2/server/accounts.go +++ b/vendor/github.com/nats-io/nats-server/v2/server/accounts.go @@ -1,4 +1,4 @@ -// Copyright 2018-2024 The NATS Authors +// Copyright 2018-2025 The NATS Authors // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at @@ -61,6 +61,7 @@ type Account struct { sqmu sync.Mutex sl *Sublist ic *client + sq *sendq isid uint64 etmr *time.Timer ctmr *time.Timer @@ -97,6 +98,12 @@ type Account struct { nameTag string lastLimErr int64 routePoolIdx int + // If the trace destination is specified and a message with a traceParentHdr + // is received, and has the least significant bit of the last token set to 1, + // then if traceDestSampling is > 0 and < 100, a random value will be selected + // and if it falls between 0 and that value, message tracing will be triggered. 
+ traceDest string + traceDestSampling int // Guarantee that only one goroutine can be running either checkJetStreamMigrate // or clearObserverState at a given time for this account to prevent interleaving. jscmMu sync.Mutex @@ -132,6 +139,10 @@ type streamImport struct { claim *jwt.Import usePub bool invalid bool + // This is `allow_trace` and when true and message tracing is happening, + // we will trace egresses past the account boundary, if `false`, we stop + // at the account boundary. + atrc bool } const ClientInfoHdr = "Nats-Request-Info" @@ -156,6 +167,7 @@ type serviceImport struct { share bool tracking bool didDeliver bool + atrc bool // allow trace (got from service export) trackingHdr http.Header // header from request } @@ -213,6 +225,11 @@ type serviceExport struct { latency *serviceLatency rtmr *time.Timer respThresh time.Duration + // This is `allow_trace` and when true and message tracing is happening, + // when processing a service import we will go through account boundary + // and trace egresses on that other account. If `false`, we stop at the + // account boundary. + atrc bool } // Used to track service latency. @@ -250,11 +267,29 @@ func (a *Account) String() string { return a.Name } +func (a *Account) setTraceDest(dest string) { + a.mu.Lock() + a.traceDest = dest + a.mu.Unlock() +} + +func (a *Account) getTraceDestAndSampling() (string, int) { + a.mu.RLock() + dest := a.traceDest + sampling := a.traceDestSampling + a.mu.RUnlock() + return dest, sampling +} + // Used to create shallow copies of accounts for transfer // from opts to real accounts in server struct. +// Account `na` write lock is expected to be held on entry +// while account `a` is the one from the Options struct +// being loaded/reloaded and do not need locking. 
func (a *Account) shallowCopy(na *Account) { na.Nkey = a.Nkey na.Issuer = a.Issuer + na.traceDest, na.traceDestSampling = a.traceDest, a.traceDestSampling if a.imports.streams != nil { na.imports.streams = make([]*streamImport, 0, len(a.imports.streams)) @@ -425,6 +460,29 @@ func (a *Account) GetName() string { return name } +// getNameTag will return the name tag or the account name if not set. +func (a *Account) getNameTag() string { + if a == nil { + return _EMPTY_ + } + a.mu.RLock() + defer a.mu.RUnlock() + return a.getNameTagLocked() +} + +// getNameTagLocked will return the name tag or the account name if not set. +// Lock should be held. +func (a *Account) getNameTagLocked() string { + if a == nil { + return _EMPTY_ + } + nameTag := a.nameTag + if nameTag == _EMPTY_ { + nameTag = a.Name + } + return nameTag +} + // NumConnections returns active number of clients for this account for // all known servers. func (a *Account) NumConnections() int { @@ -623,7 +681,7 @@ func (a *Account) AddWeightedMappings(src string, dests ...*MapDest) error { if tw[d.Cluster] > 100 { return fmt.Errorf("total weight needs to be <= 100") } - err := ValidateMappingDestination(d.Subject) + err := ValidateMapping(src, d.Subject) if err != nil { return err } @@ -1905,11 +1963,13 @@ func (a *Account) addServiceImport(dest *Account, from, to string, claim *jwt.Im return nil, ErrMissingAccount } + var atrc bool dest.mu.RLock() se := dest.getServiceExport(to) if se != nil { rt = se.respType lat = se.latency + atrc = se.atrc } dest.mu.RUnlock() @@ -1954,7 +2014,7 @@ func (a *Account) addServiceImport(dest *Account, from, to string, claim *jwt.Im if claim != nil { share = claim.Share } - si := &serviceImport{dest, claim, se, nil, from, to, tr, 0, rt, lat, nil, nil, usePub, false, false, share, false, false, nil} + si := &serviceImport{dest, claim, se, nil, from, to, tr, 0, rt, lat, nil, nil, usePub, false, false, share, false, false, atrc, nil} a.imports.services[from] = si a.mu.Unlock() 
@@ -2178,9 +2238,15 @@ func shouldSample(l *serviceLatency, c *client) (bool, http.Header) { } return true, http.Header{trcB3: b3} // sampling allowed or left to recipient of header } else if tId := h[trcCtx]; len(tId) != 0 { + var sample bool // sample 00-4bf92f3577b34da6a3ce929d0e0e4736-00f067aa0ba902b7-01 tk := strings.Split(tId[0], "-") - if len(tk) == 4 && len([]byte(tk[3])) == 2 && tk[3] == "01" { + if len(tk) == 4 && len([]byte(tk[3])) == 2 { + if hexVal, err := strconv.ParseInt(tk[3], 16, 8); err == nil { + sample = hexVal&0x1 == 0x1 + } + } + if sample { return true, newTraceCtxHeader(h, tId) } else { return false, nil @@ -2392,6 +2458,18 @@ func (a *Account) SetServiceExportResponseThreshold(export string, maxTime time. return nil } +func (a *Account) SetServiceExportAllowTrace(export string, allowTrace bool) error { + a.mu.Lock() + se := a.getServiceExport(export) + if se == nil { + a.mu.Unlock() + return fmt.Errorf("no export defined for %q", export) + } + se.atrc = allowTrace + a.mu.Unlock() + return nil +} + // This is for internal service import responses. func (a *Account) addRespServiceImport(dest *Account, to string, osi *serviceImport, tracking bool, header http.Header) *serviceImport { nrr := string(osi.acc.newServiceReply(tracking)) @@ -2401,7 +2479,7 @@ func (a *Account) addRespServiceImport(dest *Account, to string, osi *serviceImp // dest is the requestor's account. a is the service responder with the export. // Marked as internal here, that is how we distinguish. 
- si := &serviceImport{dest, nil, osi.se, nil, nrr, to, nil, 0, rt, nil, nil, nil, false, true, false, osi.share, false, false, nil} + si := &serviceImport{dest, nil, osi.se, nil, nrr, to, nil, 0, rt, nil, nil, nil, false, true, false, osi.share, false, false, false, nil} if a.exports.responses == nil { a.exports.responses = make(map[string]*serviceImport) @@ -2430,6 +2508,10 @@ func (a *Account) addRespServiceImport(dest *Account, to string, osi *serviceImp // AddStreamImportWithClaim will add in the stream import from a specific account with optional token. func (a *Account) AddStreamImportWithClaim(account *Account, from, prefix string, imClaim *jwt.Import) error { + return a.addStreamImportWithClaim(account, from, prefix, false, imClaim) +} + +func (a *Account) addStreamImportWithClaim(account *Account, from, prefix string, allowTrace bool, imClaim *jwt.Import) error { if account == nil { return ErrMissingAccount } @@ -2452,7 +2534,7 @@ func (a *Account) AddStreamImportWithClaim(account *Account, from, prefix string } } - return a.AddMappedStreamImportWithClaim(account, from, prefix+from, imClaim) + return a.addMappedStreamImportWithClaim(account, from, prefix+from, allowTrace, imClaim) } // AddMappedStreamImport helper for AddMappedStreamImportWithClaim @@ -2462,6 +2544,10 @@ func (a *Account) AddMappedStreamImport(account *Account, from, to string) error // AddMappedStreamImportWithClaim will add in the stream import from a specific account with optional token. 
func (a *Account) AddMappedStreamImportWithClaim(account *Account, from, to string, imClaim *jwt.Import) error { + return a.addMappedStreamImportWithClaim(account, from, to, false, imClaim) +} + +func (a *Account) addMappedStreamImportWithClaim(account *Account, from, to string, allowTrace bool, imClaim *jwt.Import) error { if account == nil { return ErrMissingAccount } @@ -2507,7 +2593,10 @@ func (a *Account) AddMappedStreamImportWithClaim(account *Account, from, to stri a.mu.Unlock() return ErrStreamImportDuplicate } - a.imports.streams = append(a.imports.streams, &streamImport{account, from, to, tr, nil, imClaim, usePub, false}) + if imClaim != nil { + allowTrace = imClaim.AllowTrace + } + a.imports.streams = append(a.imports.streams, &streamImport{account, from, to, tr, nil, imClaim, usePub, false, allowTrace}) a.mu.Unlock() return nil } @@ -2525,7 +2614,7 @@ func (a *Account) isStreamImportDuplicate(acc *Account, from string) bool { // AddStreamImport will add in the stream import from a specific account. func (a *Account) AddStreamImport(account *Account, from, prefix string) error { - return a.AddStreamImportWithClaim(account, from, prefix, nil) + return a.addStreamImportWithClaim(account, from, prefix, false, nil) } // IsPublicExport is a placeholder to denote a public export. @@ -2844,7 +2933,9 @@ func (a *Account) checkStreamImportsEqual(b *Account) bool { bm[bim.acc.Name+bim.from+bim.to] = bim } for _, aim := range a.imports.streams { - if _, ok := bm[aim.acc.Name+aim.from+aim.to]; !ok { + if bim, ok := bm[aim.acc.Name+aim.from+aim.to]; !ok { + return false + } else if aim.atrc != bim.atrc { return false } } @@ -2930,6 +3021,9 @@ func isServiceExportEqual(a, b *serviceExport) bool { return false } } + if a.atrc != b.atrc { + return false + } return true } @@ -3205,6 +3299,19 @@ func (s *Server) updateAccountClaimsWithRefresh(a *Account, ac *jwt.AccountClaim // Grab trace label under lock. 
tl := a.traceLabel() + var td string + var tds int + if ac.Trace != nil { + // Update trace destination and sampling + td, tds = string(ac.Trace.Destination), ac.Trace.Sampling + if !IsValidPublishSubject(td) { + td, tds = _EMPTY_, 0 + } else if tds <= 0 || tds > 100 { + tds = 100 + } + } + a.traceDest, a.traceDestSampling = td, tds + // Check for external authorization. if ac.HasExternalAuthorization() { a.extAuth = &jwt.ExternalAuthorization{} @@ -3333,6 +3440,9 @@ func (s *Server) updateAccountClaimsWithRefresh(a *Account, ac *jwt.AccountClaim s.Debugf("Error adding service export response threshold for [%s]: %v", tl, err) } } + if err := a.SetServiceExportAllowTrace(sub, e.AllowTrace); err != nil { + s.Debugf("Error adding allow_trace for %q: %v", sub, err) + } } var revocationChanged *bool @@ -3470,10 +3580,15 @@ func (s *Server) updateAccountClaimsWithRefresh(a *Account, ac *jwt.AccountClaim if si != nil && si.acc.Name == a.Name { // Check for if we are still authorized for an import. si.invalid = !a.checkServiceImportAuthorized(acc, si.to, si.claim) - if si.latency != nil && !si.response { - // Make sure we should still be tracking latency. + // Make sure we should still be tracking latency and if we + // are allowed to trace. + if !si.response { if se := a.getServiceExport(si.to); se != nil { - si.latency = se.latency + if si.latency != nil { + si.latency = se.latency + } + // Update allow trace. + si.atrc = se.atrc } } } @@ -3567,6 +3682,7 @@ func (s *Server) updateAccountClaimsWithRefresh(a *Account, ac *jwt.AccountClaim a.updated = time.Now() clients := a.getClientsLocked() + ajs := a.js a.mu.Unlock() // Sort if we are over the limit. @@ -3591,6 +3707,26 @@ func (s *Server) updateAccountClaimsWithRefresh(a *Account, ac *jwt.AccountClaim a.enableAllJetStreamServiceImportsAndMappings() } + if ajs != nil { + // Check whether the account NRG status changed. 
If it has then we need to notify the + // Raft groups running on the system so that they can move their subs if needed. + a.mu.Lock() + previous := ajs.nrgAccount + switch ac.ClusterTraffic { + case "system", _EMPTY_: + ajs.nrgAccount = _EMPTY_ + case "owner": + ajs.nrgAccount = a.Name + default: + s.Errorf("Account claim for %q has invalid value %q for cluster traffic account", a.Name, ac.ClusterTraffic) + } + changed := ajs.nrgAccount != previous + a.mu.Unlock() + if changed { + s.updateNRGAccountStatus() + } + } + for i, c := range clients { a.mu.RLock() exceeded := a.mconns != jwt.NoLimit && i >= int(a.mconns) @@ -3906,6 +4042,25 @@ func (dr *DirAccResolver) Reload() error { return dr.DirJWTStore.Reload() } +// ServerAPIClaimUpdateResponse is the response to $SYS.REQ.ACCOUNT..CLAIMS.UPDATE and $SYS.REQ.CLAIMS.UPDATE +type ServerAPIClaimUpdateResponse struct { + Server *ServerInfo `json:"server"` + Data *ClaimUpdateStatus `json:"data,omitempty"` + Error *ClaimUpdateError `json:"error,omitempty"` +} + +type ClaimUpdateError struct { + Account string `json:"account,omitempty"` + Code int `json:"code"` + Description string `json:"description,omitempty"` +} + +type ClaimUpdateStatus struct { + Account string `json:"account,omitempty"` + Code int `json:"code,omitempty"` + Message string `json:"message,omitempty"` +} + func respondToUpdate(s *Server, respSubj string, acc string, message string, err error) { if err == nil { if acc == _EMPTY_ { @@ -3923,22 +4078,26 @@ func respondToUpdate(s *Server, respSubj string, acc string, message string, err if respSubj == _EMPTY_ { return } - server := &ServerInfo{} - response := map[string]interface{}{"server": server} - m := map[string]interface{}{} - if acc != _EMPTY_ { - m["account"] = acc + + response := ServerAPIClaimUpdateResponse{ + Server: &ServerInfo{}, } + if err == nil { - m["code"] = http.StatusOK - m["message"] = message - response["data"] = m + response.Data = &ClaimUpdateStatus{ + Account: acc, + Code: 
http.StatusOK, + Message: message, + } } else { - m["code"] = http.StatusInternalServerError - m["description"] = fmt.Sprintf("%s - %v", message, err) - response["error"] = m + response.Error = &ClaimUpdateError{ + Account: acc, + Code: http.StatusInternalServerError, + Description: fmt.Sprintf("%s - %v", message, err), + } } - s.sendInternalMsgLocked(respSubj, _EMPTY_, server, response) + + s.sendInternalMsgLocked(respSubj, _EMPTY_, response.Server, response) } func handleListRequest(store *DirJWTStore, s *Server, reply string) { diff --git a/vendor/github.com/nats-io/nats-server/v2/server/auth.go b/vendor/github.com/nats-io/nats-server/v2/server/auth.go index 5a1a4acd54..6578dee114 100644 --- a/vendor/github.com/nats-io/nats-server/v2/server/auth.go +++ b/vendor/github.com/nats-io/nats-server/v2/server/auth.go @@ -417,6 +417,10 @@ func (c *client) matchesPinnedCert(tlsPinnedCerts PinnedCertSet) bool { return true } +var ( + mustacheRE = regexp.MustCompile(`{{2}([^}]+)}{2}`) +) + func processUserPermissionsTemplate(lim jwt.UserPermissionLimits, ujwt *jwt.UserClaims, acc *Account) (jwt.UserPermissionLimits, error) { nArrayCartesianProduct := func(a ...[]string) [][]string { c := 1 @@ -448,16 +452,26 @@ func processUserPermissionsTemplate(lim jwt.UserPermissionLimits, ujwt *jwt.User } return p } + isTag := func(op string) []string { + if strings.EqualFold("tag(", op[:4]) && strings.HasSuffix(op, ")") { + v := strings.TrimPrefix(op, "tag(") + v = strings.TrimSuffix(v, ")") + return []string{"tag", v} + } else if strings.EqualFold("account-tag(", op[:12]) && strings.HasSuffix(op, ")") { + v := strings.TrimPrefix(op, "account-tag(") + v = strings.TrimSuffix(v, ")") + return []string{"account-tag", v} + } + return nil + } applyTemplate := func(list jwt.StringList, failOnBadSubject bool) (jwt.StringList, error) { found := false FOR_FIND: for i := 0; i < len(list); i++ { // check if templates are present - for _, tk := range strings.Split(list[i], tsep) { - if 
strings.HasPrefix(tk, "{{") && strings.HasSuffix(tk, "}}") { - found = true - break FOR_FIND - } + if mustacheRE.MatchString(list[i]) { + found = true + break FOR_FIND } } if !found { @@ -466,94 +480,78 @@ func processUserPermissionsTemplate(lim jwt.UserPermissionLimits, ujwt *jwt.User // process the templates emittedList := make([]string, 0, len(list)) for i := 0; i < len(list); i++ { - tokens := strings.Split(list[i], tsep) - - newTokens := make([]string, len(tokens)) - tagValues := [][]string{} - + // find all the templates {{}} in this acl + tokens := mustacheRE.FindAllString(list[i], -1) + srcs := make([]string, len(tokens)) + values := make([][]string, len(tokens)) + hasTags := false for tokenNum, tk := range tokens { - if strings.HasPrefix(tk, "{{") && strings.HasSuffix(tk, "}}") { - op := strings.ToLower(strings.TrimSuffix(strings.TrimPrefix(tk, "{{"), "}}")) - switch { - case op == "name()": - tk = ujwt.Name - case op == "subject()": - tk = ujwt.Subject - case op == "account-name()": + srcs[tokenNum] = tk + op := strings.TrimSpace(strings.TrimSuffix(strings.TrimPrefix(tk, "{{"), "}}")) + if strings.EqualFold("name()", op) { + values[tokenNum] = []string{ujwt.Name} + } else if strings.EqualFold("subject()", op) { + values[tokenNum] = []string{ujwt.Subject} + } else if strings.EqualFold("account-name()", op) { + acc.mu.RLock() + values[tokenNum] = []string{acc.nameTag} + acc.mu.RUnlock() + } else if strings.EqualFold("account-subject()", op) { + // this always has an issuer account since this is a scoped signer + values[tokenNum] = []string{ujwt.IssuerAccount} + } else if isTag(op) != nil { + hasTags = true + match := isTag(op) + var tags jwt.TagList + if match[0] == "account-tag" { acc.mu.RLock() - name := acc.nameTag + tags = acc.tags acc.mu.RUnlock() - tk = name - case op == "account-subject()": - tk = ujwt.IssuerAccount - case (strings.HasPrefix(op, "tag(") || strings.HasPrefix(op, "account-tag(")) && - strings.HasSuffix(op, ")"): - // insert dummy tav 
value that will throw of subject validation (in case nothing is found) - tk = _EMPTY_ - // collect list of matching tag values - - var tags jwt.TagList - var tagPrefix string - if strings.HasPrefix(op, "account-tag(") { - acc.mu.RLock() - tags = acc.tags - acc.mu.RUnlock() - tagPrefix = fmt.Sprintf("%s:", strings.ToLower( - strings.TrimSuffix(strings.TrimPrefix(op, "account-tag("), ")"))) - } else { - tags = ujwt.Tags - tagPrefix = fmt.Sprintf("%s:", strings.ToLower( - strings.TrimSuffix(strings.TrimPrefix(op, "tag("), ")"))) - } - - valueList := []string{} - for _, tag := range tags { - if strings.HasPrefix(tag, tagPrefix) { - tagValue := strings.TrimPrefix(tag, tagPrefix) - valueList = append(valueList, tagValue) - } - } - if len(valueList) != 0 { - tagValues = append(tagValues, valueList) - } - default: - // if macro is not recognized, throw off subject check on purpose - tk = " " + } else { + tags = ujwt.Tags } + tagPrefix := fmt.Sprintf("%s:", strings.ToLower(match[1])) + var valueList []string + for _, tag := range tags { + if strings.HasPrefix(tag, tagPrefix) { + tagValue := strings.TrimPrefix(tag, tagPrefix) + valueList = append(valueList, tagValue) + } + } + if len(valueList) != 0 { + values[tokenNum] = valueList + } else if failOnBadSubject { + return nil, fmt.Errorf("generated invalid subject %q: %q is not defined", list[i], match[1]) + } else { + // generate an invalid subject? 
+ values[tokenNum] = []string{" "} + } + } else if failOnBadSubject { + return nil, fmt.Errorf("template operation in %q: %q is not defined", list[i], op) } - newTokens[tokenNum] = tk } - // fill in tag value placeholders - if len(tagValues) == 0 { - emitSubj := strings.Join(newTokens, tsep) - if IsValidSubject(emitSubj) { - emittedList = append(emittedList, emitSubj) + if !hasTags { + subj := list[i] + for idx, m := range srcs { + subj = strings.Replace(subj, m, values[idx][0], -1) + } + if IsValidSubject(subj) { + emittedList = append(emittedList, subj) } else if failOnBadSubject { return nil, fmt.Errorf("generated invalid subject") } - // else skip emitting } else { - // compute the cartesian product and compute subject to emit for each combination - for _, valueList := range nArrayCartesianProduct(tagValues...) { - b := strings.Builder{} - for i, token := range newTokens { - if token == _EMPTY_ && len(valueList) > 0 { - b.WriteString(valueList[0]) - valueList = valueList[1:] - } else { - b.WriteString(token) - } - if i != len(newTokens)-1 { - b.WriteString(tsep) - } + a := nArrayCartesianProduct(values...) + for _, aa := range a { + subj := list[i] + for j := 0; j < len(srcs); j++ { + subj = strings.Replace(subj, srcs[j], aa[j], -1) } - emitSubj := b.String() - if IsValidSubject(emitSubj) { - emittedList = append(emittedList, emitSubj) + if IsValidSubject(subj) { + emittedList = append(emittedList, subj) } else if failOnBadSubject { return nil, fmt.Errorf("generated invalid subject") } - // else skip emitting } } } @@ -606,13 +604,39 @@ func (s *Server) processClientOrLeafAuthentication(c *client, opts *Options) (au } return } - // We have a juc defined here, check account. + // We have a juc, check if externally managed, i.e. should be delegated + // to the auth callout service. if juc != nil && !acc.hasExternalAuth() { if !authorized { s.sendAccountAuthErrorEvent(c, c.acc, reason) } return } + // Check config-mode. 
The global account is a condition since users that + // are not found in the config are implicitly bound to the global account. + // This means those users should be implicitly delegated to auth callout + // if configured. Exclude LEAF connections from this check. + if c.kind != LEAF && juc == nil && opts.AuthCallout != nil && c.acc.Name != globalAccountName { + // If no allowed accounts are defined, then all accounts are in scope. + // Otherwise see if the account is in the list. + delegated := len(opts.AuthCallout.AllowedAccounts) == 0 + if !delegated { + for _, n := range opts.AuthCallout.AllowedAccounts { + if n == c.acc.Name { + delegated = true + break + } + } + } + + // Not delegated, so return with previous authorized result. + if !delegated { + if !authorized { + s.sendAccountAuthErrorEvent(c, c.acc, reason) + } + return + } + } // We have auth callout set here. var skip bool @@ -1471,7 +1495,8 @@ func validateAllowedConnectionTypes(m map[string]struct{}) error { switch ctuc { case jwt.ConnectionTypeStandard, jwt.ConnectionTypeWebsocket, jwt.ConnectionTypeLeafnode, jwt.ConnectionTypeLeafnodeWS, - jwt.ConnectionTypeMqtt, jwt.ConnectionTypeMqttWS: + jwt.ConnectionTypeMqtt, jwt.ConnectionTypeMqttWS, + jwt.ConnectionTypeInProcess: default: return fmt.Errorf("unknown connection type %q", ct) } diff --git a/vendor/github.com/nats-io/nats-server/v2/server/client.go b/vendor/github.com/nats-io/nats-server/v2/server/client.go index 8cbf98e517..138164eb39 100644 --- a/vendor/github.com/nats-io/nats-server/v2/server/client.go +++ b/vendor/github.com/nats-io/nats-server/v2/server/client.go @@ -144,6 +144,7 @@ const ( connectProcessFinished // Marks if this connection has finished the connect process. compressionNegotiated // Marks if this connection has negotiated compression level with remote. didTLSFirst // Marks if this connection requested and was accepted doing the TLS handshake first (prior to INFO). + isSlowConsumer // Marks connection as a slow consumer. 
) // set the flag (would be equivalent to set the boolean to true) @@ -283,6 +284,7 @@ type client struct { trace bool echo bool noIcb bool + iproc bool // In-Process connection, set at creation and immutable. tags jwt.TagList nameTag string @@ -1703,9 +1705,11 @@ func (c *client) flushOutbound() bool { } // Ignore ErrShortWrite errors, they will be handled as partials. + var gotWriteTimeout bool if err != nil && err != io.ErrShortWrite { // Handle timeout error (slow consumer) differently if ne, ok := err.(net.Error); ok && ne.Timeout() { + gotWriteTimeout = true if closed := c.handleWriteTimeout(n, attempted, len(orig)); closed { return true } @@ -1743,6 +1747,11 @@ func (c *client) flushOutbound() bool { close(c.out.stc) c.out.stc = nil } + // Check if the connection is recovering from being a slow consumer. + if !gotWriteTimeout && c.flags.isSet(isSlowConsumer) { + c.Noticef("Slow Consumer Recovered: Flush took %.3fs with %d chunks of %d total bytes.", time.Since(start).Seconds(), len(orig), attempted) + c.flags.clear(isSlowConsumer) + } return true } @@ -1768,6 +1777,11 @@ func (c *client) handleWriteTimeout(written, attempted int64, numChunks int) boo c.markConnAsClosed(SlowConsumerWriteDeadline) return true } + alreadySC := c.flags.isSet(isSlowConsumer) + scState := "Detected" + if alreadySC { + scState = "State" + } // Aggregate slow consumers. atomic.AddInt64(&c.srv.slowConsumers, 1) @@ -1775,7 +1789,10 @@ func (c *client) handleWriteTimeout(written, attempted int64, numChunks int) boo case CLIENT: c.srv.scStats.clients.Add(1) case ROUTER: - c.srv.scStats.routes.Add(1) + // Only count each Slow Consumer event once. 
+ if !alreadySC { + c.srv.scStats.routes.Add(1) + } case GATEWAY: c.srv.scStats.gateways.Add(1) case LEAF: @@ -1784,13 +1801,15 @@ func (c *client) handleWriteTimeout(written, attempted int64, numChunks int) boo if c.acc != nil { atomic.AddInt64(&c.acc.slowConsumers, 1) } - c.Noticef("Slow Consumer Detected: WriteDeadline of %v exceeded with %d chunks of %d total bytes.", - c.out.wdl, numChunks, attempted) + c.Noticef("Slow Consumer %s: WriteDeadline of %v exceeded with %d chunks of %d total bytes.", + scState, c.out.wdl, numChunks, attempted) // We always close CLIENT connections, or when nothing was written at all... if c.kind == CLIENT || written == 0 { c.markConnAsClosed(SlowConsumerWriteDeadline) return true + } else { + c.flags.setIfNotSet(isSlowConsumer) } return false } @@ -2064,10 +2083,26 @@ func (c *client) processConnect(arg []byte) error { } } - // If websocket client and JWT not in the CONNECT, use the cookie JWT (possibly empty). - if ws := c.ws; ws != nil && c.opts.JWT == "" { - c.opts.JWT = ws.cookieJwt + // if websocket client, maybe some options through cookies + if ws := c.ws; ws != nil { + // if JWT not in the CONNECT, use the cookie JWT (possibly empty). + if c.opts.JWT == _EMPTY_ { + c.opts.JWT = ws.cookieJwt + } + // if user not in the CONNECT, use the cookie user (possibly empty) + if c.opts.Username == _EMPTY_ { + c.opts.Username = ws.cookieUsername + } + // if pass not in the CONNECT, use the cookie password (possibly empty). + if c.opts.Password == _EMPTY_ { + c.opts.Password = ws.cookiePassword + } + // if token not in the CONNECT, use the cookie token (possibly empty). 
+ if c.opts.Token == _EMPTY_ { + c.opts.Token = ws.cookieToken + } } + // when not in operator mode, discard the jwt if srv != nil && srv.trustedKeys == nil { c.opts.JWT = _EMPTY_ @@ -2526,7 +2561,7 @@ func (c *client) msgParts(data []byte) (hdr []byte, msg []byte) { } // Header pubs take form HPUB [reply] \r\n -func (c *client) processHeaderPub(arg []byte) error { +func (c *client) processHeaderPub(arg, remaining []byte) error { if !c.headers { return ErrMsgHeadersNotSupported } @@ -2584,6 +2619,16 @@ func (c *client) processHeaderPub(arg []byte) error { maxPayload := atomic.LoadInt32(&c.mpay) // Use int64() to avoid int32 overrun... if maxPayload != jwt.NoLimit && int64(c.pa.size) > int64(maxPayload) { + // If we are given the remaining read buffer (since we do blind reads + // we may have the beginning of the message header/payload), we will + // look for the tracing header and if found, we will generate a + // trace event with the max payload ingress error. + // Do this only for CLIENT connections. + if c.kind == CLIENT && len(remaining) > 0 { + if td := getHeader(MsgTraceDest, remaining); len(td) > 0 { + c.initAndSendIngressErrEvent(remaining, string(td), ErrMaxPayload) + } + } c.maxPayloadViolation(c.pa.size, maxPayload) return ErrMaxPayload } @@ -3386,23 +3431,33 @@ var needFlush = struct{}{} // deliverMsg will deliver a message to a matching subscription and its underlying client. // We process all connection/client types. mh is the part that will be protocol/client specific. func (c *client) deliverMsg(prodIsMQTT bool, sub *subscription, acc *Account, subject, reply, mh, msg []byte, gwrply bool) bool { + // Check if message tracing is enabled. + mt, traceOnly := c.isMsgTraceEnabled() + + client := sub.client // Check sub client and check echo. Only do this if not a service import. 
- if sub.client == nil || (c == sub.client && !sub.client.echo && !sub.si) { + if client == nil || (c == client && !client.echo && !sub.si) { + if client != nil && mt != nil { + client.mu.Lock() + mt.addEgressEvent(client, sub, errMsgTraceNoEcho) + client.mu.Unlock() + } return false } - client := sub.client client.mu.Lock() // Check if we have a subscribe deny clause. This will trigger us to check the subject // for a match against the denied subjects. if client.mperms != nil && client.checkDenySub(string(subject)) { + mt.addEgressEvent(client, sub, errMsgTraceSubDeny) client.mu.Unlock() return false } // New race detector forces this now. if sub.isClosed() { + mt.addEgressEvent(client, sub, errMsgTraceSubClosed) client.mu.Unlock() return false } @@ -3410,15 +3465,56 @@ func (c *client) deliverMsg(prodIsMQTT bool, sub *subscription, acc *Account, su // Check if we are a leafnode and have perms to check. if client.kind == LEAF && client.perms != nil { if !client.pubAllowedFullCheck(string(subject), true, true) { + mt.addEgressEvent(client, sub, errMsgTracePubViolation) client.mu.Unlock() client.Debugf("Not permitted to deliver to %q", subject) return false } } + var mtErr string + if mt != nil { + // For non internal subscription, and if the remote does not support + // the tracing feature... + if sub.icb == nil && !client.msgTraceSupport() { + if traceOnly { + // We are not sending the message at all because the user + // expects a trace-only and the remote does not support + // tracing, which means that it would process/deliver this + // message, which may break applications. + // Add the Egress with the no-support error message. + mt.addEgressEvent(client, sub, errMsgTraceOnlyNoSupport) + client.mu.Unlock() + return false + } + // If we are doing delivery, we will still forward the message, + // but we add an error to the Egress event to hint that one should + // not expect a tracing event from that remote. 
+ mtErr = errMsgTraceNoSupport + } + // For ROUTER, GATEWAY and LEAF, even if we intend to do tracing only, + // we will still deliver the message. The remote side will + // generate an event based on what happened on that server. + if traceOnly && (client.kind == ROUTER || client.kind == GATEWAY || client.kind == LEAF) { + traceOnly = false + } + // If we skip delivery and this is not for a service import, we are done. + if traceOnly && (sub.icb == nil || c.noIcb) { + mt.addEgressEvent(client, sub, _EMPTY_) + client.mu.Unlock() + // Although the message is not actually delivered, for the + // purpose of "didDeliver", we need to return "true" here. + return true + } + } + srv := client.srv - sub.nm++ + // We don't want to bump the number of delivered messages to the subscription + // if we are doing trace-only (since really we are not sending it to the sub). + if !traceOnly { + sub.nm++ + } // Check if we should auto-unsubscribe. if sub.max > 0 { @@ -3442,6 +3538,7 @@ func (c *client) deliverMsg(prodIsMQTT bool, sub *subscription, acc *Account, su defer client.unsubscribe(client.acc, sub, true, true) } else if sub.nm > sub.max { client.Debugf("Auto-unsubscribe limit [%d] exceeded", sub.max) + mt.addEgressEvent(client, sub, errMsgTraceAutoSubExceeded) client.mu.Unlock() client.unsubscribe(client.acc, sub, true, true) if shouldForward { @@ -3472,7 +3569,7 @@ func (c *client) deliverMsg(prodIsMQTT bool, sub *subscription, acc *Account, su // We do not update the outbound stats if we are doing trace only since // this message will not be sent out. // Also do not update on internal callbacks. - if sub.icb == nil { + if !traceOnly && sub.icb == nil { // No atomic needed since accessed under client lock. // Monitor is reading those also under client's lock. client.outMsgs++ @@ -3514,6 +3611,7 @@ func (c *client) deliverMsg(prodIsMQTT bool, sub *subscription, acc *Account, su // with a limit. 
if c.kind == CLIENT && client.out.stc != nil { if srv.getOpts().NoFastProducerStall { + mt.addEgressEvent(client, sub, errMsgTraceFastProdNoStall) client.mu.Unlock() return false } @@ -3522,10 +3620,17 @@ func (c *client) deliverMsg(prodIsMQTT bool, sub *subscription, acc *Account, su // Check for closed connection if client.isClosed() { + mt.addEgressEvent(client, sub, errMsgTraceClientClosed) client.mu.Unlock() return false } + // We have passed cases where we could possibly fail to deliver. + // Do not call for service-import. + if mt != nil && sub.icb == nil { + mt.addEgressEvent(client, sub, mtErr) + } + // Do a fast check here to see if we should be tracking this from a latency // perspective. This will be for a request being received for an exported service. // This needs to be from a non-client (otherwise tracking happens at requestor). @@ -3715,24 +3820,34 @@ func (c *client) pruneDenyCache() { // prunePubPermsCache will prune the cache via randomly // deleting items. Doing so pruneSize items at a time. func (c *client) prunePubPermsCache() { - // There is a case where we can invoke this from multiple go routines, - // (in deliverMsg() if sub.client is a LEAF), so we make sure to prune - // from only one go routine at a time. - if !atomic.CompareAndSwapInt32(&c.perms.prun, 0, 1) { - return - } - const maxPruneAtOnce = 1000 - r := 0 - c.perms.pcache.Range(func(k, _ any) bool { - c.perms.pcache.Delete(k) - if r++; (r > pruneSize && atomic.LoadInt32(&c.perms.pcsz) < int32(maxPermCacheSize)) || - (r > maxPruneAtOnce) { - return false + // With parallel additions to the cache, it is possible that this function + // would not be able to reduce the cache to its max size in one go. We + // will try a few times but will release/reacquire the "lock" at each + // attempt to give a chance to another go routine to take over and not + // have this go routine do too many attempts. 
+ for i := 0; i < 5; i++ { + // There is a case where we can invoke this from multiple go routines, + // (in deliverMsg() if sub.client is a LEAF), so we make sure to prune + // from only one go routine at a time. + if !atomic.CompareAndSwapInt32(&c.perms.prun, 0, 1) { + return } - return true - }) - atomic.AddInt32(&c.perms.pcsz, -int32(r)) - atomic.StoreInt32(&c.perms.prun, 0) + const maxPruneAtOnce = 1000 + r := 0 + c.perms.pcache.Range(func(k, _ any) bool { + c.perms.pcache.Delete(k) + if r++; (r > pruneSize && atomic.LoadInt32(&c.perms.pcsz) < int32(maxPermCacheSize)) || + (r > maxPruneAtOnce) { + return false + } + return true + }) + n := atomic.AddInt32(&c.perms.pcsz, -int32(r)) + atomic.StoreInt32(&c.perms.prun, 0) + if n <= int32(maxPermCacheSize) { + return + } + } } // pubAllowed checks on publish permissioning. @@ -3841,6 +3956,10 @@ func (c *client) selectMappedSubject() bool { return changed } +// clientNRGPrefix is used in processInboundClientMsg to detect if publishes +// are being made from normal clients to NRG subjects. +var clientNRGPrefix = []byte("$NRG.") + // processInboundClientMsg is called to process an inbound msg from a client. // Return if the message was delivered, and if the message was not delivered // due to a permission issue. @@ -3873,6 +3992,13 @@ func (c *client) processInboundClientMsg(msg []byte) (bool, bool) { } c.mu.Unlock() + // Check if the client is trying to publish to reserved NRG subjects. + // Doesn't apply to NRGs themselves as they use SYSTEM-kind clients instead. + if c.kind == CLIENT && bytes.HasPrefix(c.pa.subject, clientNRGPrefix) && acc != c.srv.SystemAccount() { + c.pubPermissionViolation(c.pa.subject) + return false, true + } + // Now check for reserved replies. These are used for service imports. 
if c.kind == CLIENT && len(c.pa.reply) > 0 && isReservedReply(c.pa.reply) { c.replySubjectViolation(c.pa.reply) @@ -4244,6 +4370,7 @@ func (c *client) processServiceImport(si *serviceImport, acc *Account, msg []byt } } siAcc := si.acc + allowTrace := si.atrc acc.mu.RUnlock() // We have a special case where JetStream pulls in all service imports through one export. @@ -4254,6 +4381,8 @@ func (c *client) processServiceImport(si *serviceImport, acc *Account, msg []byt return false } + mt, traceOnly := c.isMsgTraceEnabled() + var nrr []byte var rsi *serviceImport @@ -4382,17 +4511,42 @@ func (c *client) processServiceImport(si *serviceImport, acc *Account, msg []byt var lrts [routeTargetInit]routeTarget c.in.rts = lrts[:0] + var skipProcessing bool + // If message tracing enabled, add the service import trace. + if mt != nil { + mt.addServiceImportEvent(siAcc.GetName(), string(pacopy.subject), to) + // If we are not allowing tracing and doing trace only, we stop at this level. + if !allowTrace { + if traceOnly { + skipProcessing = true + } else { + // We are going to do normal processing, and possibly chainning + // with other server imports, but the rest won't be traced. + // We do so by setting the c.pa.trace to nil (it will be restored + // with c.pa = pacopy). + c.pa.trace = nil + // We also need to disable the message trace headers so that + // if the message is routed, it does not initialize tracing in the + // remote. + positions := disableTraceHeaders(c, msg) + defer enableTraceHeaders(msg, positions) + } + } + } + var didDeliver bool - // If this is not a gateway connection but gateway is enabled, - // try to send this converted message to all gateways. 
- if c.srv.gateway.enabled { - flags |= pmrCollectQueueNames - var queues [][]byte - didDeliver, queues = c.processMsgResults(siAcc, rr, msg, c.pa.deliver, []byte(to), nrr, flags) - didDeliver = c.sendMsgToGateways(siAcc, msg, []byte(to), nrr, queues, false) || didDeliver - } else { - didDeliver, _ = c.processMsgResults(siAcc, rr, msg, c.pa.deliver, []byte(to), nrr, flags) + if !skipProcessing { + // If this is not a gateway connection but gateway is enabled, + // try to send this converted message to all gateways. + if c.srv.gateway.enabled { + flags |= pmrCollectQueueNames + var queues [][]byte + didDeliver, queues = c.processMsgResults(siAcc, rr, msg, c.pa.deliver, []byte(to), nrr, flags) + didDeliver = c.sendMsgToGateways(siAcc, msg, []byte(to), nrr, queues, false) || didDeliver + } else { + didDeliver, _ = c.processMsgResults(siAcc, rr, msg, c.pa.deliver, []byte(to), nrr, flags) + } } // Restore to original values. @@ -4403,6 +4557,12 @@ func (c *client) processServiceImport(si *serviceImport, acc *Account, msg []byt // If we override due to tracing and traceOnly we do not want to send back a no responders. c.pa.delivered = didDeliver + // If this was a message trace but we skip last-mile delivery, we need to + // do the remove, so: + if mt != nil && traceOnly && didDeliver { + didDeliver = false + } + // Determine if we should remove this service import. This is for response service imports. // We will remove if we did not deliver, or if we are a response service import and we are // a singleton, or we have an EOF message. @@ -4551,6 +4711,8 @@ func (c *client) processMsgResults(acc *Account, r *SublistResult, msg, deliver, } } + mt, traceOnly := c.isMsgTraceEnabled() + // Loop over all normal subscriptions that match. for _, sub := range r.psubs { // Check if this is a send to a ROUTER. 
We now process @@ -4579,6 +4741,11 @@ func (c *client) processMsgResults(acc *Account, r *SublistResult, msg, deliver, // Assume delivery subject is the normal subject to this point. dsubj = subj + // We may need to disable tracing, by setting c.pa.trace to `nil` + // before the call to deliverMsg, if so, this will indicate that + // we need to put it back. + var restorePaTrace bool + // Check for stream import mapped subs (shadow subs). These apply to local subs only. if sub.im != nil { // If this message was a service import do not re-export to an exported stream. @@ -4594,6 +4761,25 @@ func (c *client) processMsgResults(acc *Account, r *SublistResult, msg, deliver, dsubj = append(_dsubj[:0], sub.im.to...) } + if mt != nil { + mt.addStreamExportEvent(sub.client, dsubj) + // If allow_trace is false... + if !sub.im.atrc { + // If we are doing only message tracing, we can move to the + // next sub. + if traceOnly { + // Although the message was not delivered, for the purpose + // of didDeliver, we need to set to true (to avoid possible + // no responders). + didDeliver = true + continue + } + // If we are delivering the message, we need to disable tracing + // before calling deliverMsg(). + c.pa.trace, restorePaTrace = nil, true + } + } + // Make sure deliver is set if inbound from a route. if remapped && (c.kind == GATEWAY || c.kind == ROUTER || c.kind == LEAF) { deliver = subj @@ -4620,6 +4806,9 @@ func (c *client) processMsgResults(acc *Account, r *SublistResult, msg, deliver, } didDeliver = true } + if restorePaTrace { + c.pa.trace = mt + } } // Set these up to optionally filter based on the queue lists. @@ -4774,6 +4963,13 @@ func (c *client) processMsgResults(acc *Account, r *SublistResult, msg, deliver, // Assume delivery subject is normal subject to this point. dsubj = subj + + // We may need to disable tracing, by setting c.pa.trace to `nil` + // before the call to deliverMsg, if so, this will indicate that + // we need to put it back. 
+ var restorePaTrace bool + var skipDelivery bool + // Check for stream import mapped subs. These apply to local subs only. if sub.im != nil { // If this message was a service import do not re-export to an exported stream. @@ -4788,6 +4984,23 @@ func (c *client) processMsgResults(acc *Account, r *SublistResult, msg, deliver, } else { dsubj = append(_dsubj[:0], sub.im.to...) } + + if mt != nil { + mt.addStreamExportEvent(sub.client, dsubj) + // If allow_trace is false... + if !sub.im.atrc { + // If we are doing only message tracing, we are done + // with this queue group. + if traceOnly { + skipDelivery = true + } else { + // If we are delivering, we need to disable tracing + // before the call to deliverMsg() + c.pa.trace, restorePaTrace = nil, true + } + } + } + // Make sure deliver is set if inbound from a route. if remapped && (c.kind == GATEWAY || c.kind == ROUTER || c.kind == LEAF) { deliver = subj @@ -4800,11 +5013,20 @@ func (c *client) processMsgResults(acc *Account, r *SublistResult, msg, deliver, } } - mh := c.msgHeader(dsubj, creply, sub) - if c.deliverMsg(prodIsMQTT, sub, acc, subject, creply, mh, msg, rplyHasGWPrefix) { - if sub.icb == nil { + var delivered bool + if !skipDelivery { + mh := c.msgHeader(dsubj, creply, sub) + delivered = c.deliverMsg(prodIsMQTT, sub, acc, subject, creply, mh, msg, rplyHasGWPrefix) + if restorePaTrace { + c.pa.trace = mt + } + } + if skipDelivery || delivered { + // Update only if not skipped. + if !skipDelivery && sub.icb == nil { dlvMsgs++ } + // Do the rest even when message delivery was skipped. didDeliver = true // Clear rsub rsub = nil @@ -4845,6 +5067,16 @@ sendToRoutesOrLeafs: // Copy off original pa in case it changes. pa := c.pa + if mt != nil { + // We are going to replace "pa" with our copy of c.pa, but to restore + // to the original copy of c.pa, we need to save it again. 
+ cpa := pa + msg = mt.setOriginAccountHeaderIfNeeded(c, acc, msg) + defer func() { c.pa = cpa }() + // Update pa with our current c.pa state. + pa = c.pa + } + // We address by index to avoid struct copy. // We have inline structs for memory layout and cache coherency. for i := range c.in.rts { @@ -4874,6 +5106,11 @@ sendToRoutesOrLeafs: } } + if mt != nil { + dmsg = mt.setHopHeader(c, dmsg) + hset = true + } + mh := c.msgHeaderForRouteOrLeaf(subject, reply, rt, acc) if c.deliverMsg(prodIsMQTT, rt.sub, acc, subject, reply, mh, dmsg, false) { if rt.sub.icb == nil { @@ -4920,7 +5157,11 @@ func (c *client) checkLeafClientInfoHeader(msg []byte) (dmsg []byte, setHdr bool } func (c *client) pubPermissionViolation(subject []byte) { - c.sendErr(fmt.Sprintf("Permissions Violation for Publish to %q", subject)) + errTxt := fmt.Sprintf("Permissions Violation for Publish to %q", subject) + if mt, _ := c.isMsgTraceEnabled(); mt != nil { + mt.setIngressError(errTxt) + } + c.sendErr(errTxt) c.Errorf("Publish Violation - %s, Subject %q", c.getAuthUser(), subject) } @@ -4940,7 +5181,11 @@ func (c *client) subPermissionViolation(sub *subscription) { } func (c *client) replySubjectViolation(reply []byte) { - c.sendErr(fmt.Sprintf("Permissions Violation for Publish with Reply of %q", reply)) + errTxt := fmt.Sprintf("Permissions Violation for Publish with Reply of %q", reply) + if mt, _ := c.isMsgTraceEnabled(); mt != nil { + mt.setIngressError(errTxt) + } + c.sendErr(errTxt) c.Errorf("Publish Violation - %s, Reply %q", c.getAuthUser(), reply) } @@ -5874,7 +6119,8 @@ func convertAllowedConnectionTypes(cts []string) (map[string]struct{}, error) { switch i { case jwt.ConnectionTypeStandard, jwt.ConnectionTypeWebsocket, jwt.ConnectionTypeLeafnode, jwt.ConnectionTypeLeafnodeWS, - jwt.ConnectionTypeMqtt, jwt.ConnectionTypeMqttWS: + jwt.ConnectionTypeMqtt, jwt.ConnectionTypeMqttWS, + jwt.ConnectionTypeInProcess: m[i] = struct{}{} default: unknown = append(unknown, i) @@ -5901,7 +6147,11 @@ 
func (c *client) connectionTypeAllowed(acts map[string]struct{}) bool { case CLIENT: switch c.clientType() { case NATS: - want = jwt.ConnectionTypeStandard + if c.iproc { + want = jwt.ConnectionTypeInProcess + } else { + want = jwt.ConnectionTypeStandard + } case WS: want = jwt.ConnectionTypeWebsocket case MQTT: diff --git a/vendor/github.com/nats-io/nats-server/v2/server/const.go b/vendor/github.com/nats-io/nats-server/v2/server/const.go index 1a6be0b2f6..ff78445b77 100644 --- a/vendor/github.com/nats-io/nats-server/v2/server/const.go +++ b/vendor/github.com/nats-io/nats-server/v2/server/const.go @@ -1,4 +1,4 @@ -// Copyright 2012-2025 The NATS Authors +// Copyright 2012-2024 The NATS Authors // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at @@ -14,6 +14,7 @@ package server import ( + "regexp" "runtime/debug" "time" ) @@ -38,6 +39,8 @@ var ( gitCommit, serverVersion string // trustedKeys is a whitespace separated array of trusted operator's public nkeys. trustedKeys string + // SemVer regexp to validate the VERSION. + semVerRe = regexp.MustCompile(`^(0|[1-9]\d*)\.(0|[1-9]\d*)\.(0|[1-9]\d*)(?:-((?:0|[1-9]\d*|\d*[a-zA-Z-][0-9a-zA-Z-]*)(?:\.(?:0|[1-9]\d*|\d*[a-zA-Z-][0-9a-zA-Z-]*))*))?(?:\+([0-9a-zA-Z-]+(?:\.[0-9a-zA-Z-]+)*))?$`) ) func init() { @@ -55,7 +58,7 @@ func init() { const ( // VERSION is the current version for the server. - VERSION = "2.10.26" + VERSION = "2.11.0" // PROTO is the currently supported protocol. 
// 0 was the original diff --git a/vendor/github.com/nats-io/nats-server/v2/server/consumer.go b/vendor/github.com/nats-io/nats-server/v2/server/consumer.go index 83f2f3ce83..740806190e 100644 --- a/vendor/github.com/nats-io/nats-server/v2/server/consumer.go +++ b/vendor/github.com/nats-io/nats-server/v2/server/consumer.go @@ -21,6 +21,7 @@ import ( "fmt" "math/rand" "reflect" + "regexp" "slices" "strconv" "strings" @@ -37,6 +38,12 @@ import ( const ( JSPullRequestPendingMsgs = "Nats-Pending-Messages" JSPullRequestPendingBytes = "Nats-Pending-Bytes" + JSPullRequestWrongPinID = "NATS/1.0 423 Nats-Wrong-Pin-Id\r\n\r\n" + JSPullRequestNatsPinId = "Nats-Pin-Id" +) + +var ( + validGroupName = regexp.MustCompile(`^[a-zA-Z0-9/_=-]{1,16}$`) ) // Headers sent when batch size was completed, but there were remaining bytes. @@ -55,8 +62,17 @@ type ConsumerInfo struct { NumPending uint64 `json:"num_pending"` Cluster *ClusterInfo `json:"cluster,omitempty"` PushBound bool `json:"push_bound,omitempty"` + Paused bool `json:"paused,omitempty"` + PauseRemaining time.Duration `json:"pause_remaining,omitempty"` // TimeStamp indicates when the info was gathered - TimeStamp time.Time `json:"ts"` + TimeStamp time.Time `json:"ts"` + PriorityGroups []PriorityGroupState `json:"priority_groups,omitempty"` +} + +type PriorityGroupState struct { + Group string `json:"group"` + PinnedClientID string `json:"pinned_client_id,omitempty"` + PinnedTS time.Time `json:"pinned_ts,omitempty"` } type ConsumerConfig struct { @@ -77,7 +93,6 @@ type ConsumerConfig struct { SampleFrequency string `json:"sample_freq,omitempty"` MaxWaiting int `json:"max_waiting,omitempty"` MaxAckPending int `json:"max_ack_pending,omitempty"` - Heartbeat time.Duration `json:"idle_heartbeat,omitempty"` FlowControl bool `json:"flow_control,omitempty"` HeadersOnly bool `json:"headers_only,omitempty"` @@ -87,8 +102,9 @@ type ConsumerConfig struct { MaxRequestMaxBytes int `json:"max_bytes,omitempty"` // Push based consumers. 
- DeliverSubject string `json:"deliver_subject,omitempty"` - DeliverGroup string `json:"deliver_group,omitempty"` + DeliverSubject string `json:"deliver_subject,omitempty"` + DeliverGroup string `json:"deliver_group,omitempty"` + Heartbeat time.Duration `json:"idle_heartbeat,omitempty"` // Ephemeral inactivity threshold. InactiveThreshold time.Duration `json:"inactive_threshold,omitempty"` @@ -103,6 +119,14 @@ type ConsumerConfig struct { // Metadata is additional metadata for the Consumer. Metadata map[string]string `json:"metadata,omitempty"` + + // PauseUntil is for suspending the consumer until the deadline. + PauseUntil *time.Time `json:"pause_until,omitempty"` + + // Priority groups + PriorityGroups []string `json:"priority_groups,omitempty"` + PriorityPolicy PriorityPolicy `json:"priority_policy,omitempty"` + PinnedTTL time.Duration `json:"priority_timeout,omitempty"` } // SequenceInfo has both the consumer and the stream sequence and last activity. @@ -113,9 +137,10 @@ type SequenceInfo struct { } type CreateConsumerRequest struct { - Stream string `json:"stream_name"` - Config ConsumerConfig `json:"config"` - Action ConsumerAction `json:"action"` + Stream string `json:"stream_name"` + Config ConsumerConfig `json:"config"` + Action ConsumerAction `json:"action"` + Pedantic bool `json:"pedantic,omitempty"` } type ConsumerAction int @@ -182,6 +207,68 @@ type ConsumerNakOptions struct { Delay time.Duration `json:"delay"` } +// PriorityPolicy determines policy for selecting messages based on priority. +type PriorityPolicy int + +const ( + // No priority policy. + PriorityNone PriorityPolicy = iota + // Clients will get the messages only if certain criteria are specified. + PriorityOverflow + // Single client takes over handling of the messages, while others are on standby. 
+ PriorityPinnedClient +) + +const ( + PriorityNoneJSONString = `"none"` + PriorityOverflowJSONString = `"overflow"` + PriorityPinnedClientJSONString = `"pinned_client"` +) + +var ( + PriorityNoneJSONBytes = []byte(PriorityNoneJSONString) + PriorityOverflowJSONBytes = []byte(PriorityOverflowJSONString) + PriorityPinnedClientJSONBytes = []byte(PriorityPinnedClientJSONString) +) + +func (pp PriorityPolicy) String() string { + switch pp { + case PriorityOverflow: + return PriorityOverflowJSONString + case PriorityPinnedClient: + return PriorityPinnedClientJSONString + default: + return PriorityNoneJSONString + } +} + +func (pp PriorityPolicy) MarshalJSON() ([]byte, error) { + switch pp { + case PriorityOverflow: + return PriorityOverflowJSONBytes, nil + case PriorityPinnedClient: + return PriorityPinnedClientJSONBytes, nil + case PriorityNone: + return PriorityNoneJSONBytes, nil + default: + return nil, fmt.Errorf("unknown priority policy: %v", pp) + } +} + +func (pp *PriorityPolicy) UnmarshalJSON(data []byte) error { + switch string(data) { + case PriorityOverflowJSONString: + *pp = PriorityOverflow + case PriorityPinnedClientJSONString: + *pp = PriorityPinnedClient + case PriorityNoneJSONString: + *pp = PriorityNone + default: + return fmt.Errorf("unknown priority policy: %v", string(data)) + } + return nil +} + // DeliverPolicy determines how the consumer should select the first message to deliver. type DeliverPolicy int @@ -357,11 +444,12 @@ type consumer struct { active bool replay bool dtmr *time.Timer + uptmr *time.Timer // Unpause timer gwdtmr *time.Timer dthresh time.Duration - mch chan struct{} - qch chan struct{} - inch chan bool + mch chan struct{} // Message channel + qch chan struct{} // Quit channel + inch chan bool // Interest change channel sfreq int32 ackEventT string nakEventT string @@ -395,6 +483,17 @@ type consumer struct { // for stream signaling when multiple filters are set. 
sigSubs []string + + // Priority groups + // Details described in ADR-42. + + // currentPinId is the current nuid for the pinned consumer. + // If the Consumer is running in `PriorityPinnedClient` mode, server will + // pick up a new nuid and assign it to first pending pull request. + currentPinId string + /// pinnedTtl is the remaining time before the current PinId expires. + pinnedTtl *time.Timer + pinnedTS time.Time } // A single subject filter. @@ -431,10 +530,13 @@ const ( JsFlowControlMaxPending = 32 * 1024 * 1024 // JsDefaultMaxAckPending is set for consumers with explicit ack that do not set the max ack pending. JsDefaultMaxAckPending = 1000 + // JsDefaultPinnedTTL is the default grace period for the pinned consumer to send a new request before a new pin + // is picked by a server. + JsDefaultPinnedTTL = 2 * time.Minute ) // Helper function to set consumer config defaults from above. -func setConsumerConfigDefaults(config *ConsumerConfig, streamCfg *StreamConfig, lim *JSLimitOpts, accLim *JetStreamAccountLimits) { +func setConsumerConfigDefaults(config *ConsumerConfig, streamCfg *StreamConfig, lim *JSLimitOpts, accLim *JetStreamAccountLimits, pedantic bool) *ApiError { // Set to default if not specified. if config.DeliverSubject == _EMPTY_ && config.MaxWaiting == 0 { config.MaxWaiting = JSWaitQueueDefaultMax @@ -449,12 +551,21 @@ func setConsumerConfigDefaults(config *ConsumerConfig, streamCfg *StreamConfig, } // If BackOff was specified that will override the AckWait and the MaxDeliver. 
if len(config.BackOff) > 0 { + if pedantic && config.AckWait != config.BackOff[0] { + return NewJSPedanticError(errors.New("first backoff value has to equal batch AckWait")) + } config.AckWait = config.BackOff[0] } if config.MaxAckPending == 0 { + if pedantic && streamCfg.ConsumerLimits.MaxAckPending > 0 { + return NewJSPedanticError(errors.New("max_ack_pending must be set if it's configured in stream limits")) + } config.MaxAckPending = streamCfg.ConsumerLimits.MaxAckPending } if config.InactiveThreshold == 0 { + if pedantic && streamCfg.ConsumerLimits.InactiveThreshold > 0 { + return NewJSPedanticError(errors.New("inactive_threshold must be set if it's configured in stream limits")) + } config.InactiveThreshold = streamCfg.ConsumerLimits.InactiveThreshold } // Set proper default for max ack pending if we are ack explicit and none has been set. @@ -470,8 +581,17 @@ func setConsumerConfigDefaults(config *ConsumerConfig, streamCfg *StreamConfig, } // if applicable set max request batch size if config.DeliverSubject == _EMPTY_ && config.MaxRequestBatch == 0 && lim.MaxRequestBatch > 0 { + if pedantic { + return NewJSPedanticError(errors.New("max_request_batch must be set if it's JetStream limits are set")) + } config.MaxRequestBatch = lim.MaxRequestBatch } + + // set the default value only if pinned policy is used. + if config.PriorityPolicy == PriorityPinnedClient && config.PinnedTTL == 0 { + config.PinnedTTL = JsDefaultPinnedTTL + } + return nil } // Check the consumer config. If we are recovering don't check filter subjects. 
@@ -696,18 +816,38 @@ func checkConsumerCfg( return NewJSConsumerMetadataLengthError(fmt.Sprintf("%dKB", JSMaxMetadataLen/1024)) } + if config.PriorityPolicy != PriorityNone { + if len(config.PriorityGroups) == 0 { + return NewJSConsumerPriorityPolicyWithoutGroupError() + } + + for _, group := range config.PriorityGroups { + if group == _EMPTY_ { + return NewJSConsumerEmptyGroupNameError() + } + if !validGroupName.MatchString(group) { + return NewJSConsumerInvalidGroupNameError() + } + } + } + + // For now don't allow preferred server in placement. + if cfg.Placement != nil && cfg.Placement.Preferred != _EMPTY_ { + return NewJSStreamInvalidConfigError(fmt.Errorf("preferred server not permitted in placement")) + } + return nil } -func (mset *stream) addConsumerWithAction(config *ConsumerConfig, action ConsumerAction) (*consumer, error) { - return mset.addConsumerWithAssignment(config, _EMPTY_, nil, false, action) +func (mset *stream) addConsumerWithAction(config *ConsumerConfig, action ConsumerAction, pedantic bool) (*consumer, error) { + return mset.addConsumerWithAssignment(config, _EMPTY_, nil, false, action, pedantic) } func (mset *stream) addConsumer(config *ConsumerConfig) (*consumer, error) { - return mset.addConsumerWithAction(config, ActionCreateOrUpdate) + return mset.addConsumerWithAction(config, ActionCreateOrUpdate, false) } -func (mset *stream) addConsumerWithAssignment(config *ConsumerConfig, oname string, ca *consumerAssignment, isRecovering bool, action ConsumerAction) (*consumer, error) { +func (mset *stream) addConsumerWithAssignment(config *ConsumerConfig, oname string, ca *consumerAssignment, isRecovering bool, action ConsumerAction, pedantic bool) (*consumer, error) { // Check if this stream has closed. if mset.closed.Load() { return nil, NewJSStreamInvalidError() @@ -737,8 +877,11 @@ func (mset *stream) addConsumerWithAssignment(config *ConsumerConfig, oname stri // Make sure we have sane defaults. 
Do so with the JS lock, otherwise a // badly timed meta snapshot can result in a race condition. mset.js.mu.Lock() - setConsumerConfigDefaults(config, &cfg, srvLim, selectedLimits) + err := setConsumerConfigDefaults(config, &cfg, srvLim, selectedLimits, pedantic) mset.js.mu.Unlock() + if err != nil { + return nil, err + } if err := checkConsumerCfg(config, srvLim, &cfg, acc, selectedLimits, isRecovering); err != nil { return nil, err @@ -1080,6 +1223,34 @@ func (o *consumer) updateInactiveThreshold(cfg *ConsumerConfig) { } } +// Updates the paused state. If we are the leader and the pause deadline +// hasn't passed yet then we will start a timer to kick the consumer once +// that deadline is reached. Lock should be held. +func (o *consumer) updatePauseState(cfg *ConsumerConfig) { + if o.uptmr != nil { + stopAndClearTimer(&o.uptmr) + } + if !o.isLeader() { + // Only the leader will run the timer as only the leader will run + // loopAndGatherMsgs. + return + } + if cfg.PauseUntil == nil || cfg.PauseUntil.IsZero() || cfg.PauseUntil.Before(time.Now()) { + // Either the PauseUntil is unset (is effectively zero) or the + // deadline has already passed, in which case there is nothing + // to do. + return + } + o.uptmr = time.AfterFunc(time.Until(*cfg.PauseUntil), func() { + o.mu.Lock() + defer o.mu.Unlock() + + stopAndClearTimer(&o.uptmr) + o.sendPauseAdvisoryLocked(&o.cfg) + o.signalNewMessages() + }) +} + func (o *consumer) consumerAssignment() *consumerAssignment { o.mu.RLock() defer o.mu.RUnlock() @@ -1202,8 +1373,14 @@ func (o *consumer) setLeader(isLeader bool) { o.rdq = nil o.rdqi.Empty() - // Restore our saved state. During non-leader status we just update our underlying store. - o.readStoredState(lseq) + // Restore our saved state. + // During non-leader status we just update our underlying store when not clustered. + // If clustered we need to propose our initial (possibly skipped ahead) o.sseq to the group. 
+ if o.node == nil || o.dseq > 1 || (o.store != nil && o.store.HasState()) { + o.readStoredState(lseq) + } else if o.node != nil && o.sseq >= 1 { + o.updateSkipped(o.sseq) + } // Setup initial num pending. o.streamNumPending() @@ -1213,11 +1390,6 @@ func (o *consumer) setLeader(isLeader bool) { o.lss = nil } - // Update the group on the our starting sequence if we are starting but we skipped some in the stream. - if o.dseq == 1 && o.sseq > 1 { - o.updateSkipped(o.sseq) - } - // Do info sub. if o.infoSub == nil && jsa != nil { isubj := fmt.Sprintf(clusterConsumerInfoT, jsa.acc(), stream, o.name) @@ -1277,6 +1449,9 @@ func (o *consumer) setLeader(isLeader bool) { o.dtmr = time.AfterFunc(o.dthresh, o.deleteNotActive) } + // Update the consumer pause tracking. + o.updatePauseState(&o.cfg) + // If we are not in ReplayInstant mode mark us as in replay state until resolved. if o.cfg.ReplayPolicy != ReplayInstant { o.replay = true @@ -1347,7 +1522,8 @@ func (o *consumer) setLeader(isLeader bool) { } // Stop any inactivity timers. Should only be running on leaders. stopAndClearTimer(&o.dtmr) - + // Stop any unpause timers. Should only be running on leaders. + stopAndClearTimer(&o.uptmr) // Make sure to clear out any re-deliver queues o.stopAndClearPtmr() o.rdq = nil @@ -1452,6 +1628,45 @@ func (o *consumer) sendDeleteAdvisoryLocked() { o.sendAdvisory(subj, e) } +func (o *consumer) sendPinnedAdvisoryLocked(group string) { + e := JSConsumerGroupPinnedAdvisory{ + TypedEvent: TypedEvent{ + Type: JSConsumerGroupPinnedAdvisoryType, + ID: nuid.Next(), + Time: time.Now().UTC(), + }, + Account: o.acc.Name, + Stream: o.stream, + Consumer: o.name, + Domain: o.srv.getOpts().JetStreamDomain, + PinnedClientId: o.currentPinId, + Group: group, + } + + subj := JSAdvisoryConsumerPinnedPre + "." + o.stream + "." 
+ o.name + o.sendAdvisory(subj, e) + +} +func (o *consumer) sendUnpinnedAdvisoryLocked(group string, reason string) { + e := JSConsumerGroupUnpinnedAdvisory{ + TypedEvent: TypedEvent{ + Type: JSConsumerGroupUnpinnedAdvisoryType, + ID: nuid.Next(), + Time: time.Now().UTC(), + }, + Account: o.acc.Name, + Stream: o.stream, + Consumer: o.name, + Domain: o.srv.getOpts().JetStreamDomain, + Group: group, + Reason: reason, + } + + subj := JSAdvisoryConsumerUnpinnedPre + "." + o.stream + "." + o.name + o.sendAdvisory(subj, e) + +} + func (o *consumer) sendCreateAdvisory() { o.mu.Lock() defer o.mu.Unlock() @@ -1472,6 +1687,27 @@ func (o *consumer) sendCreateAdvisory() { o.sendAdvisory(subj, e) } +func (o *consumer) sendPauseAdvisoryLocked(cfg *ConsumerConfig) { + e := JSConsumerPauseAdvisory{ + TypedEvent: TypedEvent{ + Type: JSConsumerPauseAdvisoryType, + ID: nuid.Next(), + Time: time.Now().UTC(), + }, + Stream: o.stream, + Consumer: o.name, + Domain: o.srv.getOpts().JetStreamDomain, + } + + if cfg.PauseUntil != nil { + e.PauseUntil = *cfg.PauseUntil + e.Paused = time.Now().Before(e.PauseUntil) + } + + subj := JSAdvisoryConsumerPausePre + "." + o.stream + "." + o.name + o.sendAdvisory(subj, e) +} + // Created returns created time. func (o *consumer) createdTime() time.Time { o.mu.Lock() @@ -1687,8 +1923,8 @@ func (o *consumer) deleteNotActive() { } nca := js.consumerAssignment(acc, stream, name) js.mu.RUnlock() - // Make sure this is not a new consumer with the same name. - if nca != nil && nca == ca { + // Make sure this is the same consumer assignment, and not a new consumer with the same name. + if nca != nil && reflect.DeepEqual(nca, ca) { s.Warnf("Consumer assignment for '%s > %s > %s' not cleaned up, retrying", acc, stream, name) meta.ForwardProposal(removeEntry) if interval < cnaMax { @@ -1908,6 +2144,12 @@ func (o *consumer) updateConfig(cfg *ConsumerConfig) error { return err } + // Make sure we always store PauseUntil in UTC. 
+ if cfg.PauseUntil != nil { + utc := (*cfg.PauseUntil).UTC() + cfg.PauseUntil = &utc + } + if o.store != nil { // Update local state always. if err := o.store.UpdateConfig(cfg); err != nil { @@ -1956,6 +2198,22 @@ func (o *consumer) updateConfig(cfg *ConsumerConfig) error { o.dtmr = time.AfterFunc(o.dthresh, o.deleteNotActive) } } + // Check whether the pause has changed + { + var old, new time.Time + if o.cfg.PauseUntil != nil { + old = *o.cfg.PauseUntil + } + if cfg.PauseUntil != nil { + new = *cfg.PauseUntil + } + if !old.Equal(new) { + o.updatePauseState(cfg) + if o.isLeader() { + o.sendPauseAdvisoryLocked(cfg) + } + } + } // Check for Subject Filters update. newSubjects := gatherSubjectFilters(cfg.FilterSubject, cfg.FilterSubjects) @@ -2179,7 +2437,7 @@ func (o *consumer) loopAndForwardProposals(qch chan struct{}) { forwardProposals := func() error { o.mu.Lock() - if o.node == nil || o.node.State() != Leader { + if o.node == nil || !o.node.Leader() { o.mu.Unlock() return errors.New("no longer leader") } @@ -2559,10 +2817,7 @@ func (o *consumer) applyState(state *ConsumerState) { return } - // If o.sseq is greater don't update. Don't go backwards on o.sseq if leader. - if !o.isLeader() || o.sseq <= state.Delivered.Stream { - o.sseq = state.Delivered.Stream + 1 - } + o.sseq = state.Delivered.Stream + 1 o.dseq = state.Delivered.Consumer + 1 o.adflr = state.AckFloor.Consumer o.asflr = state.AckFloor.Stream @@ -2673,6 +2928,16 @@ func (o *consumer) infoWithSnapAndReply(snap bool, reply string) *ConsumerInfo { rg = o.ca.Group } + priorityGroups := []PriorityGroupState{} + // TODO(jrm): when we introduce supporting many priority groups, we need to update assigning `o.currentNuid` for each group. 
+ if len(o.cfg.PriorityGroups) > 0 { + priorityGroups = append(priorityGroups, PriorityGroupState{ + Group: o.cfg.PriorityGroups[0], + PinnedClientID: o.currentPinId, + PinnedTS: o.pinnedTS, + }) + } + cfg := o.cfg info := &ConsumerInfo{ Stream: o.stream, @@ -2692,6 +2957,13 @@ func (o *consumer) infoWithSnapAndReply(snap bool, reply string) *ConsumerInfo { NumPending: o.checkNumPending(), PushBound: o.isPushMode() && o.active, TimeStamp: time.Now().UTC(), + PriorityGroups: priorityGroups, + } + if o.cfg.PauseUntil != nil { + p := *o.cfg.PauseUntil + if info.Paused = time.Now().Before(p); info.Paused { + info.PauseRemaining = time.Until(p) + } } // If we are replicated, we need to pull certain data from our store. @@ -2703,9 +2975,13 @@ func (o *consumer) infoWithSnapAndReply(snap bool, reply string) *ConsumerInfo { } // If we are the leader we could have o.sseq that is skipped ahead. // To maintain consistency in reporting (e.g. jsz) we always take the state for our delivered/ackfloor stream sequence. - info.Delivered.Consumer, info.Delivered.Stream = state.Delivered.Consumer, state.Delivered.Stream + // Only use skipped ahead o.sseq if we're a new consumer and have not yet replicated this state yet. 
+ leader := o.isLeader() + if !leader || o.store.HasState() { + info.Delivered.Consumer, info.Delivered.Stream = state.Delivered.Consumer, state.Delivered.Stream + } info.AckFloor.Consumer, info.AckFloor.Stream = state.AckFloor.Consumer, state.AckFloor.Stream - if !o.isLeader() { + if !leader { info.NumAckPending = len(state.Pending) info.NumRedelivered = len(state.Redelivered) } @@ -3054,6 +3330,13 @@ func (o *consumer) needAck(sseq uint64, subj string) bool { return needAck } +type PriorityGroup struct { + Group string `json:"group,omitempty"` + MinPending int64 `json:"min_pending,omitempty"` + MinAckPending int64 `json:"min_ack_pending,omitempty"` + Id string `json:"id,omitempty"` +} + // Used in nextReqFromMsg, since the json.Unmarshal causes the request // struct to escape to the heap always. This should reduce GC pressure. var jsGetNextPool = sync.Pool{ @@ -3063,12 +3346,12 @@ var jsGetNextPool = sync.Pool{ } // Helper for the next message requests. -func nextReqFromMsg(msg []byte) (time.Time, int, int, bool, time.Duration, time.Time, error) { +func nextReqFromMsg(msg []byte) (time.Time, int, int, bool, time.Duration, time.Time, *PriorityGroup, error) { req := bytes.TrimSpace(msg) switch { case len(req) == 0: - return time.Time{}, 1, 0, false, 0, time.Time{}, nil + return time.Time{}, 1, 0, false, 0, time.Time{}, nil, nil case req[0] == '{': cr := jsGetNextPool.Get().(*JSApiConsumerGetNextRequest) @@ -3077,42 +3360,44 @@ func nextReqFromMsg(msg []byte) (time.Time, int, int, bool, time.Duration, time. 
jsGetNextPool.Put(cr) }() if err := json.Unmarshal(req, &cr); err != nil { - return time.Time{}, -1, 0, false, 0, time.Time{}, err + return time.Time{}, -1, 0, false, 0, time.Time{}, nil, err } var hbt time.Time if cr.Heartbeat > 0 { if cr.Heartbeat*2 > cr.Expires { - return time.Time{}, 1, 0, false, 0, time.Time{}, errors.New("heartbeat value too large") + return time.Time{}, 1, 0, false, 0, time.Time{}, nil, errors.New("heartbeat value too large") } hbt = time.Now().Add(cr.Heartbeat) } + priorityGroup := cr.PriorityGroup if cr.Expires == time.Duration(0) { - return time.Time{}, cr.Batch, cr.MaxBytes, cr.NoWait, cr.Heartbeat, hbt, nil + return time.Time{}, cr.Batch, cr.MaxBytes, cr.NoWait, cr.Heartbeat, hbt, &priorityGroup, nil } - return time.Now().Add(cr.Expires), cr.Batch, cr.MaxBytes, cr.NoWait, cr.Heartbeat, hbt, nil + return time.Now().Add(cr.Expires), cr.Batch, cr.MaxBytes, cr.NoWait, cr.Heartbeat, hbt, &priorityGroup, nil default: if n, err := strconv.Atoi(string(req)); err == nil { - return time.Time{}, n, 0, false, 0, time.Time{}, nil + return time.Time{}, n, 0, false, 0, time.Time{}, nil, nil } } - return time.Time{}, 1, 0, false, 0, time.Time{}, nil + return time.Time{}, 1, 0, false, 0, time.Time{}, nil, nil } // Represents a request that is on the internal waiting queue type waitingRequest struct { - next *waitingRequest - acc *Account - interest string - reply string - n int // For batching - d int // num delivered - b int // For max bytes tracking - expires time.Time - received time.Time - hb time.Duration - hbt time.Time - noWait bool + next *waitingRequest + acc *Account + interest string + reply string + n int // For batching + d int // num delivered + b int // For max bytes tracking + expires time.Time + received time.Time + hb time.Duration + hbt time.Time + noWait bool + priorityGroup *PriorityGroup } // sync.Pool for waiting requests. 
@@ -3210,6 +3495,16 @@ func (wq *waitQueue) peek() *waitingRequest { return wq.head } +func (wq *waitQueue) cycle() { + wr := wq.peek() + if wr != nil { + // Always remove current now on a pop, and move to end if still valid. + // If we were the only one don't need to remove since this can be a no-op. + wq.removeCurrent() + wq.add(wr) + } +} + // pop will return the next request and move the read cursor. // This will now place a request that still has pending items at the ends of the list. func (wq *waitQueue) pop() *waitingRequest { @@ -3271,6 +3566,20 @@ func (o *consumer) pendingRequests() map[string]*waitingRequest { return m } +func (o *consumer) setPinnedTimer(priorityGroup string) { + if o.pinnedTtl != nil { + o.pinnedTtl.Reset(o.cfg.PinnedTTL) + } else { + o.pinnedTtl = time.AfterFunc(o.cfg.PinnedTTL, func() { + o.mu.Lock() + o.currentPinId = _EMPTY_ + o.sendUnpinnedAdvisoryLocked(priorityGroup, "timeout") + o.mu.Unlock() + o.signalNewMessages() + }) + } +} + // Return next waiting request. This will check for expirations but not noWait or interest. // That will be handled by processWaiting. // Lock should be held. @@ -3278,6 +3587,16 @@ func (o *consumer) nextWaiting(sz int) *waitingRequest { if o.waiting == nil || o.waiting.isEmpty() { return nil } + + // Check if server needs to assign a new pin id. + needNewPin := o.currentPinId == _EMPTY_ && o.cfg.PriorityPolicy == PriorityPinnedClient + // As long as we support only one priority group, we can capture that group here and reuse it. 
+ var priorityGroup string + if len(o.cfg.PriorityGroups) > 0 { + priorityGroup = o.cfg.PriorityGroups[0] + } + + lastRequest := o.waiting.tail for wr := o.waiting.peek(); !o.waiting.isEmpty(); wr = o.waiting.peek() { if wr == nil { break @@ -3307,11 +3626,73 @@ func (o *consumer) nextWaiting(sz int) *waitingRequest { } if wr.expires.IsZero() || time.Now().Before(wr.expires) { + if needNewPin { + if wr.priorityGroup.Id == _EMPTY_ { + o.currentPinId = nuid.Next() + o.pinnedTS = time.Now().UTC() + wr.priorityGroup.Id = o.currentPinId + o.setPinnedTimer(priorityGroup) + + } else { + // There is pin id set, but not a matching one. Send a notification to the client and remove the request. + // Probably this is the old pin id. + o.outq.send(newJSPubMsg(wr.reply, _EMPTY_, _EMPTY_, []byte(JSPullRequestWrongPinID), nil, nil, 0)) + o.waiting.removeCurrent() + if o.node != nil { + o.removeClusterPendingRequest(wr.reply) + } + wr.recycle() + continue + } + } else if o.currentPinId != _EMPTY_ { + // Check if we have a match on the currentNuid + if wr.priorityGroup != nil && wr.priorityGroup.Id == o.currentPinId { + // If we have a match, we do nothing here and will deliver the message later down the code path. + } else if wr.priorityGroup.Id == _EMPTY_ { + o.waiting.cycle() + if wr == lastRequest { + return nil + } + continue + } else { + // There is pin id set, but not a matching one. Send a notification to the client and remove the request. 
+ o.outq.send(newJSPubMsg(wr.reply, _EMPTY_, _EMPTY_, []byte(JSPullRequestWrongPinID), nil, nil, 0)) + o.waiting.removeCurrent() + if o.node != nil { + o.removeClusterPendingRequest(wr.reply) + } + wr.recycle() + continue + } + } + + if o.cfg.PriorityPolicy == PriorityOverflow { + if wr.priorityGroup != nil && + // We need to check o.npc+1, because before calling nextWaiting, we do o.npc-- + (wr.priorityGroup.MinPending > 0 && wr.priorityGroup.MinPending > o.npc+1 || + wr.priorityGroup.MinAckPending > 0 && wr.priorityGroup.MinAckPending > int64(len(o.pending))) { + o.waiting.cycle() + // We're done cycling through the requests. + if wr == lastRequest { + return nil + } + continue + } + } if wr.acc.sl.HasInterest(wr.interest) { + if needNewPin { + o.sendPinnedAdvisoryLocked(priorityGroup) + } return o.waiting.pop() } else if time.Since(wr.received) < defaultGatewayRecentSubExpiration && (o.srv.leafNodeEnabled || o.srv.gateway.enabled) { + if needNewPin { + o.sendPinnedAdvisoryLocked(priorityGroup) + } return o.waiting.pop() } else if o.srv.gateway.enabled && o.srv.hasGatewayInterest(wr.acc.Name, wr.interest) { + if needNewPin { + o.sendPinnedAdvisoryLocked(priorityGroup) + } return o.waiting.pop() } } else { @@ -3338,6 +3719,7 @@ func (o *consumer) nextWaiting(sz int) *waitingRequest { } wr.recycle() } + return nil } @@ -3411,7 +3793,7 @@ func (o *consumer) processNextMsgRequest(reply string, msg []byte) { } // Check payload here to see if they sent in batch size or a formal request. 
- expires, batchSize, maxBytes, noWait, hb, hbt, err := nextReqFromMsg(msg) + expires, batchSize, maxBytes, noWait, hb, hbt, priorityGroup, err := nextReqFromMsg(msg) if err != nil { sendErr(400, fmt.Sprintf("Bad Request - %v", err)) return @@ -3433,6 +3815,44 @@ func (o *consumer) processNextMsgRequest(reply string, msg []byte) { return } + if priorityGroup != nil { + if (priorityGroup.MinPending != 0 || priorityGroup.MinAckPending != 0) && o.cfg.PriorityPolicy != PriorityOverflow { + sendErr(400, "Bad Request - Not a Overflow Priority consumer") + } + + if priorityGroup.Id != _EMPTY_ && o.cfg.PriorityPolicy != PriorityPinnedClient { + sendErr(400, "Bad Request - Not a Pinned Client Priority consumer") + } + } + + if priorityGroup != nil && o.cfg.PriorityPolicy != PriorityNone { + if priorityGroup.Group == _EMPTY_ { + sendErr(400, "Bad Request - Priority Group missing") + return + } + + found := false + for _, group := range o.cfg.PriorityGroups { + if group == priorityGroup.Group { + found = true + break + } + } + if !found { + sendErr(400, "Bad Request - Invalid Priority Group") + return + } + + if o.currentPinId != _EMPTY_ { + if priorityGroup.Id == o.currentPinId { + o.setPinnedTimer(priorityGroup.Group) + } else if priorityGroup.Id != _EMPTY_ { + sendErr(423, "Nats-Pin-Id mismatch") + return + } + } + } + // If we have the max number of requests already pending try to expire. if o.waiting.isFull() { // Try to expire some of the requests. @@ -3469,7 +3889,7 @@ func (o *consumer) processNextMsgRequest(reply string, msg []byte) { // Create a waiting request. 
wr := wrPool.Get().(*waitingRequest) - wr.acc, wr.interest, wr.reply, wr.n, wr.d, wr.noWait, wr.expires, wr.hb, wr.hbt = acc, interest, reply, batchSize, 0, noWait, expires, hb, hbt + wr.acc, wr.interest, wr.reply, wr.n, wr.d, wr.noWait, wr.expires, wr.hb, wr.hbt, wr.priorityGroup = acc, interest, reply, batchSize, 0, noWait, expires, hb, hbt, priorityGroup wr.b = maxBytes wr.received = time.Now() @@ -3980,6 +4400,8 @@ func (o *consumer) suppressDeletion() { } } +// loopAndGatherMsgs waits for messages for the consumer. qch is the quit channel, +// upch is the unpause channel which fires when the PauseUntil deadline is reached. func (o *consumer) loopAndGatherMsgs(qch chan struct{}) { // On startup check to see if we are in a reply situation where replay policy is not instant. var ( @@ -4046,6 +4468,13 @@ func (o *consumer) loopAndGatherMsgs(qch chan struct{}) { // Clear last error. err = nil + // If the consumer is paused then stop sending. + if o.cfg.PauseUntil != nil && !o.cfg.PauseUntil.IsZero() && time.Now().Before(*o.cfg.PauseUntil) { + // If the consumer is paused and we haven't reached the deadline yet then + // go back to waiting. + goto waitForMsgs + } + // If we are in push mode and not active or under flowcontrol let's stop sending. if o.isPushMode() { if !o.active || (o.maxpb > 0 && o.pbytes > o.maxpb) { @@ -4106,6 +4535,21 @@ func (o *consumer) loopAndGatherMsgs(qch chan struct{}) { } else if wr := o.nextWaiting(sz); wr != nil { wrn, wrb = wr.n, wr.b dsubj = wr.reply + if o.cfg.PriorityPolicy == PriorityPinnedClient { + // FIXME(jrm): Can we make this prettier? + if len(pmsg.hdr) == 0 { + pmsg.hdr = genHeader(pmsg.hdr, JSPullRequestNatsPinId, o.currentPinId) + pmsg.buf = append(pmsg.hdr, pmsg.msg...) + } else { + pmsg.hdr = genHeader(pmsg.hdr, JSPullRequestNatsPinId, o.currentPinId) + bufLen := len(pmsg.hdr) + len(pmsg.msg) + pmsg.buf = make([]byte, bufLen) + pmsg.buf = append(pmsg.hdr, pmsg.msg...) 
+ } + + sz = len(pmsg.subj) + len(ackReply) + len(pmsg.hdr) + len(pmsg.msg) + + } if done := wr.recycleIfDone(); done && o.node != nil { o.removeClusterPendingRequest(dsubj) } else if !done && wr.hb > 0 { @@ -4388,9 +4832,6 @@ func (o *consumer) deliverMsg(dsubj, ackReply string, pmsg *jsPubMsg, dc uint64, // Update delivered first. o.updateDelivered(dseq, seq, dc, ts) - // Send message. - o.outq.send(pmsg) - if ap == AckExplicit || ap == AckAll { o.trackPending(seq, dseq) } else if ap == AckNone { @@ -4398,6 +4839,9 @@ func (o *consumer) deliverMsg(dsubj, ackReply string, pmsg *jsPubMsg, dc uint64, o.asflr = seq } + // Send message. + o.outq.send(pmsg) + // Flow control. if o.maxpb > 0 && o.needFlowControl(psz) { o.sendFlowControl() @@ -4858,13 +5302,13 @@ func (o *consumer) selectStartingSeqNo() { } else if o.cfg.DeliverPolicy == DeliverLast { if o.subjf == nil { o.sseq = state.LastSeq - return - } - // If we are partitioned here this will be properly set when we become leader. - for _, filter := range o.subjf { - ss := o.mset.store.FilteredState(1, filter.subject) - if ss.Last > o.sseq { - o.sseq = ss.Last + } else { + // If we are partitioned here this will be properly set when we become leader. + for _, filter := range o.subjf { + ss := o.mset.store.FilteredState(1, filter.subject) + if ss.Last > o.sseq { + o.sseq = ss.Last + } } } } else if o.cfg.DeliverPolicy == DeliverLastPerSubject { @@ -4942,12 +5386,18 @@ func (o *consumer) selectStartingSeqNo() { o.sseq = o.cfg.OptStartSeq } - if state.FirstSeq == 0 { + if state.FirstSeq == 0 && (o.cfg.Direct || o.cfg.OptStartSeq == 0) { + // If the stream is empty, deliver only new. + // But only if mirroring/sourcing, or start seq is unset, otherwise need to respect provided value. 
o.sseq = 1 - } else if o.sseq < state.FirstSeq { - o.sseq = state.FirstSeq - } else if o.sseq > state.LastSeq { + } else if o.sseq > state.LastSeq && (o.cfg.Direct || o.cfg.OptStartSeq == 0) { + // If selected sequence is in the future, clamp back down. + // But only if mirroring/sourcing, or start seq is unset, otherwise need to respect provided value. o.sseq = state.LastSeq + 1 + } else if o.sseq < state.FirstSeq { + // If the first sequence is further ahead than the starting sequence, + // there are no messages there anymore, so move the sequence up. + o.sseq = state.FirstSeq } } @@ -4958,7 +5408,8 @@ func (o *consumer) selectStartingSeqNo() { // Set ack store floor to store-1 o.asflr = o.sseq - 1 // Set our starting sequence state. - if o.store != nil && o.sseq > 0 { + // But only if we're not clustered, if clustered we propose upon becoming leader. + if o.store != nil && o.sseq > 0 && o.cfg.replicas(&o.mset.cfg) == 1 { o.store.SetStarting(o.sseq - 1) } } @@ -5182,7 +5633,7 @@ func (o *consumer) stopWithFlags(dflag, sdflag, doSignal, advisory bool) error { // Check if we are the leader and are being deleted (as a node). if dflag && o.isLeader() { // If we are clustered and node leader (probable from above), stepdown. - if node := o.node; node != nil && node.Leader() { + if node := o.node; node != nil { node.StepDown() } @@ -5335,8 +5786,10 @@ func (o *consumer) cleanupNoInterestMessages(mset *stream, ignoreInterest bool) var rmseqs []uint64 mset.mu.RLock() - // If over this amount of messages to check, defer to checkInterestState() which - // will do the right thing since we are now removed. + // If over this amount of messages to check, optimistically call to checkInterestState(). + // It will not always do the right thing in removing messages that lost interest, but ensures + // we don't degrade performance by doing a linear scan through the whole stream. + // Messages might need to expire based on limits to be cleaned up. // TODO(dlc) - Better way? 
const bailThresh = 100_000 @@ -5392,6 +5845,7 @@ func (o *consumer) switchToEphemeral() { interest := o.acc.sl.HasInterest(o.cfg.DeliverSubject) // Setup dthresh. o.updateInactiveThreshold(&o.cfg) + o.updatePauseState(&o.cfg) o.mu.Unlock() // Update interest @@ -5613,7 +6067,7 @@ func (o *consumer) checkStateForInterestStream(ss *StreamState) error { // Only ack though if no error and seq <= ack floor. if err == nil && seq <= asflr { didRemove := mset.ackMsg(o, seq) - // Removing the message could fail. For example if we're behind on stream applies. + // Removing the message could fail. For example if clustered since we need to propose it. // Overwrite retry floor (only the first time) to allow us to check next time if the removal was successful. if didRemove && retryAsflr == 0 { retryAsflr = seq diff --git a/vendor/github.com/nats-io/nats-server/v2/server/disk_avail.go b/vendor/github.com/nats-io/nats-server/v2/server/disk_avail.go index d3056cf81f..65e4ecb789 100644 --- a/vendor/github.com/nats-io/nats-server/v2/server/disk_avail.go +++ b/vendor/github.com/nats-io/nats-server/v2/server/disk_avail.go @@ -12,7 +12,6 @@ // limitations under the License. //go:build !windows && !openbsd && !netbsd && !wasm -// +build !windows,!openbsd,!netbsd,!wasm package server diff --git a/vendor/github.com/nats-io/nats-server/v2/server/disk_avail_netbsd.go b/vendor/github.com/nats-io/nats-server/v2/server/disk_avail_netbsd.go index d50af70764..1ce3920868 100644 --- a/vendor/github.com/nats-io/nats-server/v2/server/disk_avail_netbsd.go +++ b/vendor/github.com/nats-io/nats-server/v2/server/disk_avail_netbsd.go @@ -12,7 +12,6 @@ // limitations under the License. 
//go:build netbsd -// +build netbsd package server diff --git a/vendor/github.com/nats-io/nats-server/v2/server/disk_avail_openbsd.go b/vendor/github.com/nats-io/nats-server/v2/server/disk_avail_openbsd.go index 1dd4c0d201..6ed468fc38 100644 --- a/vendor/github.com/nats-io/nats-server/v2/server/disk_avail_openbsd.go +++ b/vendor/github.com/nats-io/nats-server/v2/server/disk_avail_openbsd.go @@ -12,7 +12,6 @@ // limitations under the License. //go:build openbsd -// +build openbsd package server diff --git a/vendor/github.com/nats-io/nats-server/v2/server/disk_avail_wasm.go b/vendor/github.com/nats-io/nats-server/v2/server/disk_avail_wasm.go index c668243439..47648834c6 100644 --- a/vendor/github.com/nats-io/nats-server/v2/server/disk_avail_wasm.go +++ b/vendor/github.com/nats-io/nats-server/v2/server/disk_avail_wasm.go @@ -12,7 +12,6 @@ // limitations under the License. //go:build wasm -// +build wasm package server diff --git a/vendor/github.com/nats-io/nats-server/v2/server/disk_avail_windows.go b/vendor/github.com/nats-io/nats-server/v2/server/disk_avail_windows.go index b7f95f313a..9c21243747 100644 --- a/vendor/github.com/nats-io/nats-server/v2/server/disk_avail_windows.go +++ b/vendor/github.com/nats-io/nats-server/v2/server/disk_avail_windows.go @@ -12,7 +12,6 @@ // limitations under the License. //go:build windows -// +build windows package server diff --git a/vendor/github.com/nats-io/nats-server/v2/server/errors.go b/vendor/github.com/nats-io/nats-server/v2/server/errors.go index ff718648a0..1bd4e8f777 100644 --- a/vendor/github.com/nats-io/nats-server/v2/server/errors.go +++ b/vendor/github.com/nats-io/nats-server/v2/server/errors.go @@ -153,6 +153,9 @@ var ( // Gateway's name. ErrWrongGateway = errors.New("wrong gateway") + // ErrGatewayNameHasSpaces signals that the gateway name contains spaces, which is not allowed. 
+ ErrGatewayNameHasSpaces = errors.New("gateway name cannot contain spaces") + // ErrNoSysAccount is returned when an attempt to publish or subscribe is made // when there is no internal system account defined. ErrNoSysAccount = errors.New("system account not setup") @@ -163,6 +166,9 @@ var ( // ErrServerNotRunning is used to signal an error that a server is not running. ErrServerNotRunning = errors.New("server is not running") + // ErrServerNameHasSpaces signals that the server name contains spaces, which is not allowed. + ErrServerNameHasSpaces = errors.New("server name cannot contain spaces") + // ErrBadMsgHeader signals the parser detected a bad message header ErrBadMsgHeader = errors.New("bad message header detected") @@ -181,7 +187,7 @@ var ( ErrClusterNameRemoteConflict = errors.New("cluster name from remote server conflicts") // ErrClusterNameHasSpaces signals that the cluster name contains spaces, which is not allowed. - ErrClusterNameHasSpaces = errors.New("cluster name cannot contain spaces or new lines") + ErrClusterNameHasSpaces = errors.New("cluster name cannot contain spaces") // ErrMalformedSubject is returned when a subscription is made with a subject that does not conform to subject rules. 
ErrMalformedSubject = errors.New("malformed subject") @@ -206,7 +212,7 @@ var ( ErrInvalidMappingDestination = errors.New("invalid mapping destination") // ErrInvalidMappingDestinationSubject is used to error on a bad transform destination mapping - ErrInvalidMappingDestinationSubject = fmt.Errorf("%w: invalid subject", ErrInvalidMappingDestination) + ErrInvalidMappingDestinationSubject = fmt.Errorf("%w: invalid transform", ErrInvalidMappingDestination) // ErrMappingDestinationNotUsingAllWildcards is used to error on a transform destination not using all of the token wildcards ErrMappingDestinationNotUsingAllWildcards = fmt.Errorf("%w: not using all of the token wildcard(s)", ErrInvalidMappingDestination) diff --git a/vendor/github.com/nats-io/nats-server/v2/server/errors.json b/vendor/github.com/nats-io/nats-server/v2/server/errors.json index 79602466a1..7b90366a6e 100644 --- a/vendor/github.com/nats-io/nats-server/v2/server/errors.json +++ b/vendor/github.com/nats-io/nats-server/v2/server/errors.json @@ -203,7 +203,7 @@ "constant": "JSInvalidJSONErr", "code": 400, "error_code": 10025, - "description": "invalid JSON", + "description": "invalid JSON: {err}", "comment": "", "help": "", "url": "", @@ -833,7 +833,7 @@ "constant": "JSConsumerPullRequiresAckErr", "code": 400, "error_code": 10084, - "description": "consumer in pull mode requires ack policy", + "description": "consumer in pull mode requires explicit ack policy on workqueue stream", "comment": "", "help": "", "url": "", @@ -1433,7 +1433,7 @@ "constant": "JSSourceInvalidSubjectFilter", "code": 400, "error_code": 10145, - "description": "source subject filter is invalid", + "description": "source transform source: {err}", "comment": "", "help": "", "url": "", @@ -1443,7 +1443,7 @@ "constant": "JSSourceInvalidTransformDestination", "code": 400, "error_code": 10146, - "description": "source transform destination is invalid", + "description": "source transform: {err}", "comment": "", "help": "", "url": "", @@ 
-1493,7 +1493,7 @@ "constant": "JSMirrorInvalidSubjectFilter", "code": 400, "error_code": 10151, - "description": "mirror subject filter is invalid", + "description": "mirror transform source: {err}", "comment": "", "help": "", "url": "", @@ -1518,5 +1518,145 @@ "help": "", "url": "", "deprecates": "" + }, + { + "constant": "JSMirrorInvalidTransformDestination", + "code": 400, + "error_code": 10154, + "description": "mirror transform: {err}", + "comment": "", + "help": "", + "url": "", + "deprecates": "" + }, + { + "constant": "JSStreamTransformInvalidSource", + "code": 400, + "error_code": 10155, + "description": "stream transform source: {err}", + "comment": "", + "help": "", + "url": "", + "deprecates": "" + }, + { + "constant": "JSStreamTransformInvalidDestination", + "code": 400, + "error_code": 10156, + "description": "stream transform: {err}", + "comment": "", + "help": "", + "url": "", + "deprecates": "" + }, + { + "constant": "JSPedanticErrF", + "code": 400, + "error_code": 10157, + "description": "pedantic mode: {err}", + "comment": "", + "help": "", + "url": "", + "deprecates": "" + }, + { + "constant": "JSStreamDuplicateMessageConflict", + "code": 409, + "error_code": 10158, + "description": "duplicate message id is in process", + "comment": "", + "help": "", + "url": "", + "deprecates": "" + }, + { + "constant": "JSConsumerPriorityPolicyWithoutGroup", + "code": 400, + "error_code": 10159, + "description": "Setting PriorityPolicy requires at least one PriorityGroup to be set", + "comment": "", + "help": "", + "url": "", + "deprecates": "" + }, + { + "constant": "JSConsumerInvalidPriorityGroupErr", + "code": 400, + "error_code": 10160, + "description": "Provided priority group does not exist for this consumer", + "comment": "", + "help": "", + "url": "", + "deprecates": "" + }, + { + "constant": "JSConsumerEmptyGroupName", + "code": 400, + "error_code": 10161, + "description": "Group name cannot be an empty string", + "comment": "", + "help": "", + 
"url": "", + "deprecates": "" + }, + { + "constant": "JSConsumerInvalidGroupNameErr", + "code": 400, + "error_code": 10162, + "description": "Valid priority group name must match A-Z, a-z, 0-9, -_/=)+ and may not exceed 16 characters", + "comment": "", + "help": "", + "url": "", + "deprecates": "" + }, + { + "constant": "JSStreamExpectedLastSeqPerSubjectNotReady", + "code": 503, + "error_code": 10163, + "description": "expected last sequence per subject temporarily unavailable", + "comment": "", + "help": "", + "url": "", + "deprecates": "" + }, + { + "constant": "JSStreamWrongLastSequenceConstantErr", + "code": 400, + "error_code": 10164, + "description": "wrong last sequence", + "comment": "", + "help": "", + "url": "", + "deprecates": "" + }, + { + "constant": "JSMessageTTLInvalidErr", + "code": 400, + "error_code": 10165, + "description": "invalid per-message TTL", + "comment": "", + "help": "", + "url": "", + "deprecates": "" + }, + { + "constant": "JSMessageTTLDisabledErr", + "code": 400, + "error_code": 10166, + "description": "per-message TTL is disabled", + "comment": "", + "help": "", + "url": "", + "deprecates": "" + }, + { + "constant": "JSStreamTooManyRequests", + "code": 429, + "error_code": 10167, + "description": "too many requests", + "comment": "", + "help": "", + "url": "", + "deprecates": "" } ] diff --git a/vendor/github.com/nats-io/nats-server/v2/server/events.go b/vendor/github.com/nats-io/nats-server/v2/server/events.go index c050e6525d..533ca9fa4d 100644 --- a/vendor/github.com/nats-io/nats-server/v2/server/events.go +++ b/vendor/github.com/nats-io/nats-server/v2/server/events.go @@ -31,7 +31,6 @@ import ( "time" "github.com/klauspost/compress/s2" - "github.com/nats-io/jwt/v2" "github.com/nats-io/nats-server/v2/server/certidp" "github.com/nats-io/nats-server/v2/server/pse" @@ -215,6 +214,7 @@ type AccountNumConns struct { // AccountStat contains the data common between AccountNumConns and AccountStatz type AccountStat struct { Account 
string `json:"acc"` + Name string `json:"name"` Conns int `json:"conns"` LeafNodes int `json:"leafnodes"` TotalConns int `json:"total_conns"` @@ -264,6 +264,7 @@ type ServerInfo struct { const ( JetStreamEnabled ServerCapability = 1 << iota // Server had JetStream enabled. BinaryStreamSnapshot // New stream snapshot capability. + AccountNRG // Move NRG traffic out of system account. ) // Set JetStream capability. @@ -289,6 +290,17 @@ func (si *ServerInfo) BinaryStreamSnapshot() bool { return si.Flags&BinaryStreamSnapshot != 0 } +// Set account NRG capability. +func (si *ServerInfo) SetAccountNRG() { + si.Flags |= AccountNRG +} + +// AccountNRG indicates whether or not we support moving the NRG traffic out of the +// system account and into the asset account. +func (si *ServerInfo) AccountNRG() bool { + return si.Flags&AccountNRG != 0 +} + // ClientInfo is detailed information about the client forming a connection. type ClientInfo struct { Start *time.Time `json:"start,omitempty"` @@ -348,21 +360,22 @@ func (ci *ClientInfo) forAdvisory() *ClientInfo { // ServerStats hold various statistics that we will periodically send out. 
type ServerStats struct { - Start time.Time `json:"start"` - Mem int64 `json:"mem"` - Cores int `json:"cores"` - CPU float64 `json:"cpu"` - Connections int `json:"connections"` - TotalConnections uint64 `json:"total_connections"` - ActiveAccounts int `json:"active_accounts"` - NumSubs uint32 `json:"subscriptions"` - Sent DataStats `json:"sent"` - Received DataStats `json:"received"` - SlowConsumers int64 `json:"slow_consumers"` - Routes []*RouteStat `json:"routes,omitempty"` - Gateways []*GatewayStat `json:"gateways,omitempty"` - ActiveServers int `json:"active_servers,omitempty"` - JetStream *JetStreamVarz `json:"jetstream,omitempty"` + Start time.Time `json:"start"` + Mem int64 `json:"mem"` + Cores int `json:"cores"` + CPU float64 `json:"cpu"` + Connections int `json:"connections"` + TotalConnections uint64 `json:"total_connections"` + ActiveAccounts int `json:"active_accounts"` + NumSubs uint32 `json:"subscriptions"` + Sent DataStats `json:"sent"` + Received DataStats `json:"received"` + SlowConsumers int64 `json:"slow_consumers"` + SlowConsumersStats *SlowConsumersStats `json:"slow_consumer_stats,omitempty"` + Routes []*RouteStat `json:"routes,omitempty"` + Gateways []*GatewayStat `json:"gateways,omitempty"` + ActiveServers int `json:"active_servers,omitempty"` + JetStream *JetStreamVarz `json:"jetstream,omitempty"` } // RouteStat holds route statistics. @@ -506,10 +519,14 @@ RESET: si.Version = VERSION si.Time = time.Now().UTC() si.Tags = tags + si.Flags = 0 if js { // New capability based flags. si.SetJetStreamEnabled() si.SetBinaryStreamSnapshot() + if s.accountNRGAllowed.Load() { + si.SetAccountNRG() + } } } var b []byte @@ -684,7 +701,7 @@ func (s *Server) sendInternalAccountMsgWithReply(a *Account, subject, reply stri } // Send system style message to an account scope. 
-func (s *Server) sendInternalAccountSysMsg(a *Account, subj string, si *ServerInfo, msg interface{}) { +func (s *Server) sendInternalAccountSysMsg(a *Account, subj string, si *ServerInfo, msg any, ct compressionType) { s.mu.RLock() if s.sys == nil || s.sys.sendq == nil || a == nil { s.mu.RUnlock() @@ -697,7 +714,7 @@ func (s *Server) sendInternalAccountSysMsg(a *Account, subj string, si *ServerIn c := a.internalClient() a.mu.Unlock() - sendq.push(newPubMsg(c, subj, _EMPTY_, si, nil, msg, noCompression, false, false)) + sendq.push(newPubMsg(c, subj, _EMPTY_, si, nil, msg, ct, false, false)) } // This will queue up a message to be sent. @@ -895,6 +912,16 @@ func (s *Server) sendStatsz(subj string) { m.Stats.Sent.Msgs = atomic.LoadInt64(&s.outMsgs) m.Stats.Sent.Bytes = atomic.LoadInt64(&s.outBytes) m.Stats.SlowConsumers = atomic.LoadInt64(&s.slowConsumers) + // Evaluate the slow consumer stats, but set it only if one of the value is not 0. + scs := &SlowConsumersStats{ + Clients: s.NumSlowConsumersClients(), + Routes: s.NumSlowConsumersRoutes(), + Gateways: s.NumSlowConsumersGateways(), + Leafs: s.NumSlowConsumersLeafs(), + } + if scs.Clients != 0 || scs.Routes != 0 || scs.Gateways != 0 || scs.Leafs != 0 { + m.Stats.SlowConsumersStats = scs + } m.Stats.NumSubs = s.numSubscriptions() // Routes s.forEachRoute(func(r *client) { @@ -980,6 +1007,7 @@ func (s *Server) sendStatsz(subj string) { jStat.Meta.Pending = ipq.len() } } + jStat.Limits = &s.getOpts().JetStreamLimits m.Stats.JetStream = jStat s.mu.RLock() } @@ -1657,7 +1685,8 @@ func (s *Server) remoteServerUpdate(sub *subscription, c *client, _ *Account, su } node := getHash(si.Name) - s.nodeToInfo.Store(node, nodeInfo{ + accountNRG := si.AccountNRG() + oldInfo, _ := s.nodeToInfo.Swap(node, nodeInfo{ si.Name, si.Version, si.Cluster, @@ -1669,7 +1698,14 @@ func (s *Server) remoteServerUpdate(sub *subscription, c *client, _ *Account, su false, si.JetStreamEnabled(), si.BinaryStreamSnapshot(), + accountNRG, }) + if 
oldInfo == nil || accountNRG != oldInfo.(nodeInfo).accountNRG { + // One of the servers we received statsz from changed its mind about + // whether or not it supports in-account NRG, so update the groups + // with this information. + s.updateNRGAccountStatus() + } } // updateRemoteServer is called when we have an update from a remote server. @@ -1716,14 +1752,35 @@ func (s *Server) processNewServer(si *ServerInfo) { false, si.JetStreamEnabled(), si.BinaryStreamSnapshot(), + si.AccountNRG(), }) } } + go s.updateNRGAccountStatus() // Announce ourselves.. // Do this in a separate Go routine. go s.sendStatszUpdate() } +// Works out whether all nodes support moving the NRG traffic into +// the account and moves it appropriately. +// Server lock MUST NOT be held on entry. +func (s *Server) updateNRGAccountStatus() { + s.rnMu.RLock() + raftNodes := make([]RaftNode, 0, len(s.raftNodes)) + for _, n := range s.raftNodes { + raftNodes = append(raftNodes, n) + } + s.rnMu.RUnlock() + for _, n := range raftNodes { + // In the event that the node is happy that all nodes that + // it cares about haven't changed, this will be a no-op. + if err := n.RecreateInternalSubs(); err != nil { + n.Stop() + } + } +} + // If GW is enabled on this server and there are any leaf node connections, // this function will send a LeafNode connect system event to the super cluster // to ensure that the GWs are in interest-only mode for this account. 
@@ -2338,6 +2395,7 @@ func (a *Account) statz() *AccountStat { leafConns := a.numLocalLeafNodes() return &AccountStat{ Account: a.Name, + Name: a.getNameTagLocked(), Conns: localConns, LeafNodes: leafConns, TotalConns: localConns + leafConns, @@ -2408,7 +2466,7 @@ func (s *Server) accountConnectEvent(c *client) { Jwt: c.opts.JWT, IssuerKey: issuerForClient(c), Tags: c.tags, - NameTag: c.nameTag, + NameTag: c.acc.getNameTag(), Kind: c.kindString(), ClientType: c.clientTypeString(), MQTTClient: c.getMQTTClientID(), @@ -2460,7 +2518,7 @@ func (s *Server) accountDisconnectEvent(c *client, now time.Time, reason string) Jwt: c.opts.JWT, IssuerKey: issuerForClient(c), Tags: c.tags, - NameTag: c.nameTag, + NameTag: c.acc.getNameTag(), Kind: c.kindString(), ClientType: c.clientTypeString(), MQTTClient: c.getMQTTClientID(), @@ -2514,7 +2572,7 @@ func (s *Server) sendAuthErrorEvent(c *client) { Jwt: c.opts.JWT, IssuerKey: issuerForClient(c), Tags: c.tags, - NameTag: c.nameTag, + NameTag: c.acc.getNameTag(), Kind: c.kindString(), ClientType: c.clientTypeString(), MQTTClient: c.getMQTTClientID(), @@ -2572,7 +2630,7 @@ func (s *Server) sendAccountAuthErrorEvent(c *client, acc *Account, reason strin Jwt: c.opts.JWT, IssuerKey: issuerForClient(c), Tags: c.tags, - NameTag: c.nameTag, + NameTag: c.acc.getNameTag(), Kind: c.kindString(), ClientType: c.clientTypeString(), MQTTClient: c.getMQTTClientID(), @@ -2589,7 +2647,7 @@ func (s *Server) sendAccountAuthErrorEvent(c *client, acc *Account, reason strin } c.mu.Unlock() - s.sendInternalAccountSysMsg(acc, authErrorAccountEventSubj, &m.Server, &m) + s.sendInternalAccountSysMsg(acc, authErrorAccountEventSubj, &m.Server, &m, noCompression) } // Internal message callback. 
diff --git a/vendor/github.com/nats-io/nats-server/v2/server/filestore.go b/vendor/github.com/nats-io/nats-server/v2/server/filestore.go index 7168e76a45..528219fb7d 100644 --- a/vendor/github.com/nats-io/nats-server/v2/server/filestore.go +++ b/vendor/github.com/nats-io/nats-server/v2/server/filestore.go @@ -45,6 +45,7 @@ import ( "github.com/minio/highwayhash" "github.com/nats-io/nats-server/v2/server/avl" "github.com/nats-io/nats-server/v2/server/stree" + "github.com/nats-io/nats-server/v2/server/thw" "golang.org/x/crypto/chacha20" "golang.org/x/crypto/chacha20poly1305" ) @@ -173,6 +174,7 @@ type fileStore struct { tombs []uint64 ld *LostStreamData scb StorageUpdateHandler + sdmcb SubjectDeleteMarkerUpdateHandler ageChk *time.Timer syncTmr *time.Timer cfg FileStreamInfo @@ -198,6 +200,9 @@ type fileStore struct { fip bool receivedAny bool firstMoved bool + ttls *thw.HashWheel + ttlseq uint64 // How up-to-date is the `ttls` THW? + markers []string } // Represents a message store block and its data. @@ -218,6 +223,7 @@ type msgBlock struct { index uint32 bytes uint64 // User visible bytes count. rbytes uint64 // Total bytes (raw) including deleted. Used for rolling to new blk. + cbytes uint64 // Bytes count after last compaction. 0 if no compaction happened yet. msgs uint64 // User visible message count. fss *stree.SubjectTree[SimpleState] kfn string @@ -244,6 +250,7 @@ type msgBlock struct { syncAlways bool noCompact bool closed bool + ttls uint64 // How many msgs have TTLs? // Used to mock write failures. mockWriteErr bool @@ -282,6 +289,8 @@ const ( purgeDir = "__msgs__" // used to scan blk file names. blkScan = "%d.blk" + // suffix of a block file + blkSuffix = ".blk" // used for compacted blocks that are staged. newScan = "%d.new" // used to scan index file names. @@ -321,6 +330,9 @@ const ( // This is the full snapshotted state for the stream. streamStreamStateFile = "index.db" + // This is the encoded time hash wheel for TTLs. 
+ ttlStreamStateFile = "thw.db" + // AEK key sizes minMetaKeySize = 64 minBlkKeySize = 64 @@ -408,6 +420,11 @@ func newFileStoreWithCreated(fcfg FileStoreConfig, cfg StreamConfig, created tim srv: fcfg.srv, } + // Only create a THW if we're going to allow TTLs. + if cfg.AllowMsgTTL { + fs.ttls = thw.NewHashWheel() + } + // Set flush in place to AsyncFlush which by default is false. fs.fip = !fcfg.AsyncFlush @@ -459,6 +476,10 @@ func newFileStoreWithCreated(fcfg FileStoreConfig, cfg StreamConfig, created tim // Check if our prior state remembers a last sequence past where we can see. if fs.ld != nil && prior.LastSeq > fs.state.LastSeq { fs.state.LastSeq, fs.state.LastTime = prior.LastSeq, prior.LastTime + if fs.state.Msgs == 0 { + fs.state.FirstSeq = fs.state.LastSeq + 1 + fs.state.FirstTime = time.Time{} + } if _, err := fs.newMsgBlockForWrite(); err == nil { if err = fs.writeTombstone(prior.LastSeq, prior.LastTime.UnixNano()); err != nil { return nil, err @@ -471,6 +492,13 @@ func newFileStoreWithCreated(fcfg FileStoreConfig, cfg StreamConfig, created tim fs.dirty++ } + // See if we can bring back our TTL timed hash wheel state from disk. + if cfg.AllowMsgTTL { + if err = fs.recoverTTLState(); err != nil && !os.IsNotExist(err) { + fs.warn("Recovering TTL state from index errored: %v", err) + } + } + // Also make sure we get rid of old idx and fss files on return. // Do this in separate go routine vs inline and at end of processing. defer func() { @@ -518,7 +546,7 @@ func newFileStoreWithCreated(fcfg FileStoreConfig, cfg StreamConfig, created tim // sequence. Need to do this locked as by now the age check timer // has started. 
if cfg.FirstSeq > 0 && firstSeq <= cfg.FirstSeq { - if _, err := fs.purge(cfg.FirstSeq); err != nil { + if _, err := fs.purge(cfg.FirstSeq, true); err != nil { return nil, err } } @@ -1383,6 +1411,10 @@ func (mb *msgBlock) rebuildStateLocked() (*LostStreamData, []uint64, error) { rl, slen := le.Uint32(hdr[0:]), int(le.Uint16(hdr[20:])) hasHeaders := rl&hbit != 0 + var ttl int64 + if mb.fs.ttls != nil && len(hdr) > 0 { + ttl, _ = getMessageTTL(hdr) + } // Clear any headers bit that could be set. rl &^= hbit dlen := int(rl) - msgHdrSize @@ -1457,6 +1489,11 @@ func (mb *msgBlock) rebuildStateLocked() (*LostStreamData, []uint64, error) { if !mb.dmap.Exists(seq) { mb.msgs++ mb.bytes += uint64(rl) + if mb.fs.ttls != nil && ttl > 0 { + expires := time.Duration(ts) + (time.Second * time.Duration(ttl)) + mb.fs.ttls.Add(seq, int64(expires)) + mb.ttls++ + } } // Check for any gaps from compaction, meaning no ebit entry. @@ -1595,7 +1632,8 @@ func (fs *fileStore) recoverFullState() (rerr error) { } } - if buf[0] != fullStateMagic || buf[1] != fullStateVersion { + version := buf[1] + if buf[0] != fullStateMagic || version < fullStateMinVersion || version > fullStateVersion { os.Remove(fn) fs.warn("Stream state magic and version mismatch") return errCorruptState @@ -1687,6 +1725,10 @@ func (fs *fileStore) recoverFullState() (rerr error) { fs.blks = make([]*msgBlock, 0, numBlocks) for i := 0; i < int(numBlocks); i++ { index, nbytes, fseq, fts, lseq, lts, numDeleted := uint32(readU64()), readU64(), readU64(), readI64(), readU64(), readI64(), readU64() + var ttls uint64 + if version >= 2 { + ttls = readU64() + } if bi < 0 { os.Remove(fn) return errCorruptState @@ -1696,6 +1738,7 @@ func (fs *fileStore) recoverFullState() (rerr error) { atomic.StoreUint64(&mb.last.seq, lseq) mb.msgs, mb.bytes = lseq-fseq+1, nbytes mb.first.ts, mb.last.ts = fts+baseTime, lts+baseTime + mb.ttls = ttls if numDeleted > 0 { dmap, n, err := avl.Decode(buf[bi:]) if err != nil { @@ -1777,6 +1820,10 @@ func 
(fs *fileStore) recoverFullState() (rerr error) { var index uint32 for _, fi := range dirs { + // Ensure it's actually a block file, otherwise fmt.Sscanf also matches %d.blk.tmp + if !strings.HasSuffix(fi.Name(), blkSuffix) { + continue + } if n, err := fmt.Sscanf(fi.Name(), blkScan, &index); err == nil && n == 1 { if index > blkIndex { fs.warn("Stream state outdated, found extra blocks, will rebuild") @@ -1796,6 +1843,81 @@ func (fs *fileStore) recoverFullState() (rerr error) { return nil } +func (fs *fileStore) recoverTTLState() error { + // See if we have a timed hash wheel for TTLs. + <-dios + fn := filepath.Join(fs.fcfg.StoreDir, msgDir, ttlStreamStateFile) + buf, err := os.ReadFile(fn) + dios <- struct{}{} + + if err != nil && !os.IsNotExist(err) { + return err + } + + fs.ttls = thw.NewHashWheel() + + if err == nil { + fs.ttlseq, err = fs.ttls.Decode(buf) + if err != nil { + fs.warn("Error decoding TTL state: %s", err) + os.Remove(fn) + } + } + + if fs.ttlseq < fs.state.FirstSeq { + fs.ttlseq = fs.state.FirstSeq + } + + defer fs.resetAgeChk(0) + if fs.state.Msgs > 0 && fs.ttlseq <= fs.state.LastSeq { + fs.warn("TTL state is outdated; attempting to recover using linear scan (seq %d to %d)", fs.ttlseq, fs.state.LastSeq) + var sm StoreMsg + mb := fs.selectMsgBlock(fs.ttlseq) + if mb == nil { + return nil + } + mblseq := atomic.LoadUint64(&mb.last.seq) + for seq := fs.ttlseq; seq <= fs.state.LastSeq; seq++ { + retry: + if mb.ttls == 0 { + // None of the messages in the block have message TTLs so don't + // bother doing anything further with this block, skip to the end. + seq = atomic.LoadUint64(&mb.last.seq) + 1 + } + if seq > mblseq { + // We've reached the end of the loaded block, see if we can continue + // by loading the next one. + mb.tryForceExpireCache() + if mb = fs.selectMsgBlock(seq); mb == nil { + // TODO(nat): Deal with gaps properly. Right now this will be + // probably expensive on CPU. 
+ continue + } + mblseq = atomic.LoadUint64(&mb.last.seq) + // At this point we've loaded another block, so let's go back to the + // beginning and see if we need to skip this one too. + goto retry + } + msg, _, err := mb.fetchMsg(seq, &sm) + if err != nil { + fs.warn("Error loading msg seq %d for recovering TTL: %s", seq, err) + continue + } + if len(msg.hdr) == 0 { + continue + } + if ttl, _ := getMessageTTL(msg.hdr); ttl > 0 { + expires := time.Duration(msg.ts) + (time.Second * time.Duration(ttl)) + fs.ttls.Add(seq, int64(expires)) + if seq > fs.ttlseq { + fs.ttlseq = seq + } + } + } + } + return nil +} + // Grabs last checksum for the named block file. // Takes into account encryption etc. func (mb *msgBlock) lastChecksum() []byte { @@ -1890,6 +2012,10 @@ func (fs *fileStore) recoverMsgs() error { indices := make(sort.IntSlice, 0, len(dirs)) var index int for _, fi := range dirs { + // Ensure it's actually a block file, otherwise fmt.Sscanf also matches %d.blk.tmp + if !strings.HasSuffix(fi.Name(), blkSuffix) { + continue + } if n, err := fmt.Sscanf(fi.Name(), blkScan, &index); err == nil && n == 1 { indices = append(indices, index) } @@ -2007,7 +2133,7 @@ func (fs *fileStore) expireMsgsOnRecover() error { mb.fss.IterOrdered(func(bsubj []byte, ss *SimpleState) bool { subj := bytesToString(bsubj) for i := uint64(0); i < ss.Msgs; i++ { - fs.removePerSubject(subj) + fs.removePerSubject(subj, false) } return true }) @@ -2027,7 +2153,11 @@ func (fs *fileStore) expireMsgsOnRecover() error { break } // Can we remove whole block here? - if mb.last.ts <= minAge { + // TODO(nat): We can't do this with LimitsTTL as we have no way to know + // if we're throwing away real messages or other tombstones without + // loading them, so in this case we'll fall through to the "slow path". + // There might be a better way of handling this though. 
+ if mb.fs.cfg.SubjectDeleteMarkerTTL <= 0 && mb.last.ts <= minAge { purged += mb.msgs bytes += mb.bytes err := deleteEmptyBlock(mb) @@ -2097,7 +2227,7 @@ func (fs *fileStore) expireMsgsOnRecover() error { // Update fss // Make sure we have fss loaded. mb.removeSeqPerSubject(sm.subj, seq) - fs.removePerSubject(sm.subj) + fs.removePerSubject(sm.subj, fs.cfg.SubjectDeleteMarkerTTL > 0 && len(getHeader(JSMarkerReason, sm.hdr)) == 0) } // Make sure we have a proper next first sequence. if needNextFirst { @@ -2154,6 +2284,9 @@ func (fs *fileStore) expireMsgsOnRecover() error { fs.psim, fs.tsl = fs.psim.Empty(), 0 } + // If we have pending markers, then create them. + fs.subjectDeleteMarkersAfterOperation(JSMarkerReasonMaxAge) + // If we purged anything, make sure we kick flush state loop. if purged > 0 { fs.dirty++ @@ -2866,6 +2999,119 @@ func (fs *fileStore) SubjectsState(subject string) map[string]SimpleState { return fss } +// MultiLastSeqs will return a sorted list of sequences that match all subjects presented in filters. +// We will not exceed the maxSeq, which if 0 becomes the store's last sequence. +func (fs *fileStore) MultiLastSeqs(filters []string, maxSeq uint64, maxAllowed int) ([]uint64, error) { + fs.mu.RLock() + defer fs.mu.RUnlock() + + if fs.state.Msgs == 0 || fs.noTrackSubjects() { + return nil, nil + } + + lastBlkIndex := len(fs.blks) - 1 + lastMB := fs.blks[lastBlkIndex] + + // Implied last sequence. + if maxSeq == 0 { + maxSeq = fs.state.LastSeq + } else { + // Udate last mb index if not last seq. + lastBlkIndex, lastMB = fs.selectMsgBlockWithIndex(maxSeq) + } + //Make sure non-nil + if lastMB == nil { + return nil, nil + } + + // Grab our last mb index (not same as blk index). 
+ lastMB.mu.RLock() + lastMBIndex := lastMB.index + lastMB.mu.RUnlock() + + subs := make(map[string]*psi) + ltSeen := make(map[string]uint32) + for _, filter := range filters { + fs.psim.Match(stringToBytes(filter), func(subj []byte, psi *psi) { + s := string(subj) + subs[s] = psi + if psi.lblk < lastMBIndex { + ltSeen[s] = psi.lblk + } + }) + } + + // If all subjects have a lower last index, select the largest for our walk backwards. + if len(ltSeen) == len(subs) { + max := uint32(0) + for _, mbi := range ltSeen { + if mbi > max { + max = mbi + } + } + lastMB = fs.bim[max] + } + + // Collect all sequences needed. + seqs := make([]uint64, 0, len(subs)) + for i, lnf := lastBlkIndex, false; i >= 0; i-- { + if len(subs) == 0 { + break + } + mb := fs.blks[i] + if !lnf { + if mb != lastMB { + continue + } + lnf = true + } + // We can start properly looking here. + mb.mu.Lock() + mb.ensurePerSubjectInfoLoaded() + for subj, psi := range subs { + if ss, ok := mb.fss.Find(stringToBytes(subj)); ok && ss != nil { + if ss.Last <= maxSeq { + seqs = append(seqs, ss.Last) + delete(subs, subj) + } else { + // Need to search for it since last is > maxSeq. + if mb.cacheNotLoaded() { + mb.loadMsgsWithLock() + } + var smv StoreMsg + fseq := atomic.LoadUint64(&mb.first.seq) + for seq := maxSeq; seq >= fseq; seq-- { + sm, _ := mb.cacheLookup(seq, &smv) + if sm == nil || sm.subj != subj { + continue + } + seqs = append(seqs, sm.seq) + delete(subs, subj) + break + } + } + } else if mb.index <= psi.fblk { + // Track which subs are no longer applicable, meaning we will not find a valid msg at this point. + delete(subs, subj) + } + // TODO(dlc) we could track lblk like above in case some subs are very far apart. + // Not too bad if fss loaded since we will skip over quickly with it loaded, but might be worth it. + } + mb.mu.Unlock() + + // If maxAllowed was sepcified check that we will not exceed that. 
+ if maxAllowed > 0 && len(seqs) > maxAllowed { + return nil, ErrTooManyResults + } + + } + if len(seqs) == 0 { + return nil, nil + } + slices.Sort(seqs) + return seqs, nil +} + // NumPending will return the number of pending messages matching the filter subject starting at sequence. // Optimized for stream num pending calculations for consumers. func (fs *fileStore) NumPending(sseq uint64, filter string, lastPerSubject bool) (total, validThrough uint64) { @@ -3501,7 +3747,7 @@ func (fs *fileStore) NumPendingMulti(sseq uint64, sl *Sublist, lastPerSubject bo return total, validThrough } -// SubjectsTotal return message totals per subject. +// SubjectsTotals return message totals per subject. func (fs *fileStore) SubjectsTotals(filter string) map[string]uint64 { fs.mu.RLock() defer fs.mu.RUnlock() @@ -3533,6 +3779,13 @@ func (fs *fileStore) RegisterStorageUpdates(cb StorageUpdateHandler) { } } +// RegisterSubjectDeleteMarkerUpdates registers a callback for updates to new tombstones. +func (fs *fileStore) RegisterSubjectDeleteMarkerUpdates(cb SubjectDeleteMarkerUpdateHandler) { + fs.mu.Lock() + fs.sdmcb = cb + fs.mu.Unlock() +} + // Helper to get hash key for specific message block. // Lock should be held func (fs *fileStore) hashKeyForBlock(index uint32) []byte { @@ -3659,7 +3912,7 @@ func (fs *fileStore) genEncryptionKeysForBlock(mb *msgBlock) error { // Stores a raw message with expected sequence number and timestamp. // Lock should be held. -func (fs *fileStore) storeRawMsg(subj string, hdr, msg []byte, seq uint64, ts int64) (err error) { +func (fs *fileStore) storeRawMsg(subj string, hdr, msg []byte, seq uint64, ts, ttl int64) (err error) { if fs.closed { return ErrStoreClosed } @@ -3711,6 +3964,8 @@ func (fs *fileStore) storeRawMsg(subj string, hdr, msg []byte, seq uint64, ts in } // Write msg record. + // Add expiry bit to sequence if needed. This is so that if we need to + // rebuild, we know which messages to look at more quickly. 
n, err := fs.writeMsgRecord(seq, ts, subj, hdr, msg) if err != nil { return err @@ -3778,8 +4033,21 @@ func (fs *fileStore) storeRawMsg(subj string, hdr, msg []byte, seq uint64, ts in fs.enforceMsgLimit() fs.enforceBytesLimit() + // Per-message TTL. + if fs.ttls != nil && ttl > 0 { + expires := time.Duration(ts) + (time.Second * time.Duration(ttl)) + fs.ttls.Add(seq, int64(expires)) + fs.lmb.ttls++ + if seq > fs.ttlseq { + fs.ttlseq = seq + } + } + // Check if we have and need the age expiration timer running. - if fs.ageChk == nil && fs.cfg.MaxAge != 0 { + switch { + case fs.ttls != nil && ttl > 0: + fs.resetAgeChk(0) + case fs.ageChk == nil && (fs.cfg.MaxAge > 0 || fs.ttls != nil): fs.startAgeChk() } @@ -3787,9 +4055,9 @@ func (fs *fileStore) storeRawMsg(subj string, hdr, msg []byte, seq uint64, ts in } // StoreRawMsg stores a raw message with expected sequence number and timestamp. -func (fs *fileStore) StoreRawMsg(subj string, hdr, msg []byte, seq uint64, ts int64) error { +func (fs *fileStore) StoreRawMsg(subj string, hdr, msg []byte, seq uint64, ts, ttl int64) error { fs.mu.Lock() - err := fs.storeRawMsg(subj, hdr, msg, seq, ts) + err := fs.storeRawMsg(subj, hdr, msg, seq, ts, ttl) cb := fs.scb // Check if first message timestamp requires expiry // sooner than initial replica expiry timer set to MaxAge when initializing. @@ -3809,10 +4077,10 @@ func (fs *fileStore) StoreRawMsg(subj string, hdr, msg []byte, seq uint64, ts in } // Store stores a message. We hold the main filestore lock for any write operation. 
-func (fs *fileStore) StoreMsg(subj string, hdr, msg []byte) (uint64, int64, error) { +func (fs *fileStore) StoreMsg(subj string, hdr, msg []byte, ttl int64) (uint64, int64, error) { fs.mu.Lock() seq, ts := fs.state.LastSeq+1, time.Now().UnixNano() - err := fs.storeRawMsg(subj, hdr, msg, seq, ts) + err := fs.storeRawMsg(subj, hdr, msg, seq, ts, ttl) cb := fs.scb fs.mu.Unlock() @@ -4234,10 +4502,10 @@ func (fs *fileStore) EraseMsg(seq uint64) (bool, error) { } // Convenience function to remove per subject tracking at the filestore level. -// Lock should be held. -func (fs *fileStore) removePerSubject(subj string) { +// Lock should be held. Returns if we deleted the last message on the subject. +func (fs *fileStore) removePerSubject(subj string, marker bool) bool { if len(subj) == 0 || fs.psim == nil { - return + return false } // We do not update sense of fblk here but will do so when we resolve during lookup. bsubj := stringToBytes(subj) @@ -4248,9 +4516,14 @@ func (fs *fileStore) removePerSubject(subj string) { } else if info.total == 0 { if _, ok = fs.psim.Delete(bsubj); ok { fs.tsl -= len(subj) + if marker { + fs.markers = append(fs.markers, subj) + } + return true } } } + return false } // Remove a message, optionally rewriting the mb file. @@ -4362,7 +4635,7 @@ func (fs *fileStore) removeMsg(seq uint64, secure, viaLimits, needFSLock bool) ( // If we are tracking multiple subjects here make sure we update that accounting. mb.removeSeqPerSubject(sm.subj, seq) - fs.removePerSubject(sm.subj) + wasLast := fs.removePerSubject(sm.subj, false) if secure { // Grab record info. @@ -4418,6 +4691,15 @@ func (fs *fileStore) removeMsg(seq uint64, secure, viaLimits, needFSLock bool) ( } mb.mu.Unlock() + // If the deleted message was itself a delete marker then + // don't write out more of them or we'll churn endlessly. + var sdmcb func() + if wasLast && len(getHeader(JSMarkerReason, sm.hdr)) == 0 { // Not a marker. 
+ if viaLimits { + sdmcb = fs.subjectDeleteMarkerIfNeeded(sm.subj, JSMarkerReasonMaxAge) + } + } + // If we emptied the current message block and the seq was state.FirstSeq // then we need to jump message blocks. We will also write the index so // we don't lose track of the first sequence. @@ -4434,16 +4716,21 @@ func (fs *fileStore) removeMsg(seq uint64, secure, viaLimits, needFSLock bool) ( fs.writeTombstone(sm.seq, sm.ts) } - if cb := fs.scb; cb != nil { + if cb := fs.scb; cb != nil || sdmcb != nil { // If we have a callback registered we need to release lock regardless since cb might need it to lookup msg, etc. fs.mu.Unlock() // Storage updates. - var subj string - if sm != nil { - subj = sm.subj + if cb != nil { + var subj string + if sm != nil { + subj = sm.subj + } + delta := int64(msz) + cb(-1, -delta, seq, subj) + } + if sdmcb != nil { + sdmcb() } - delta := int64(msz) - cb(-1, -delta, seq, subj) if !needFSLock { fs.mu.Lock() @@ -4458,9 +4745,10 @@ func (fs *fileStore) removeMsg(seq uint64, secure, viaLimits, needFSLock bool) ( // Tests whether we should try to compact this block while inline removing msgs. // We will want rbytes to be over the minimum and have a 2x potential savings. +// If we compacted before but rbytes didn't improve much, guard against constantly compacting. // Lock should be held. func (mb *msgBlock) shouldCompactInline() bool { - return mb.rbytes > compactMinimum && mb.bytes*2 < mb.rbytes + return mb.rbytes > compactMinimum && mb.bytes*2 < mb.rbytes && (mb.cbytes == 0 || mb.bytes*2 < mb.cbytes) } // Tests whether we should try to compact this block while running periodic sync. @@ -4592,6 +4880,7 @@ func (mb *msgBlock) compactWithFloor(floor uint64) { } else { mb.rbytes = rbytes } + mb.cbytes = mb.bytes // Remove any seqs from the beginning of the blk. 
for seq, nfseq := fseq, atomic.LoadUint64(&mb.first.seq); seq < nfseq; seq++ { @@ -4613,7 +4902,7 @@ func (mb *msgBlock) slotInfo(slot int) (uint32, uint32, bool, error) { } bi := mb.cache.idx[slot] - ri, hashChecked := (bi &^ hbit), (bi&hbit) != 0 + ri, hashChecked := (bi &^ cbit), (bi&cbit) != 0 // If this is a deleted slot return here. if bi == dbit { @@ -4628,7 +4917,7 @@ func (mb *msgBlock) slotInfo(slot int) (uint32, uint32, bool, error) { // Need to account for dbit markers in idx. // So we will walk until we find valid idx slot to calculate rl. for i := 1; slot+i < len(mb.cache.idx); i++ { - ni := mb.cache.idx[slot+i] &^ hbit + ni := mb.cache.idx[slot+i] &^ cbit if ni == dbit { continue } @@ -5141,25 +5430,52 @@ func (mb *msgBlock) expireCacheLocked() { } func (fs *fileStore) startAgeChk() { - if fs.ageChk == nil && fs.cfg.MaxAge != 0 { + if fs.ageChk != nil { + return + } + if fs.cfg.MaxAge != 0 || fs.ttls != nil { fs.ageChk = time.AfterFunc(fs.cfg.MaxAge, fs.expireMsgs) } } // Lock should be held. func (fs *fileStore) resetAgeChk(delta int64) { - if fs.cfg.MaxAge == 0 { + var next int64 = math.MaxInt64 + if fs.ttls != nil { + next = fs.ttls.GetNextExpiration(next) + } + + // If there's no MaxAge and there's nothing waiting to be expired then + // don't bother continuing. The next storeRawMsg() will wake us up if + // needs be. + if fs.cfg.MaxAge <= 0 && next == math.MaxInt64 { + clearTimer(&fs.ageChk) return } + // Check to see if we should be firing sooner than MaxAge for an expiring TTL. fireIn := fs.cfg.MaxAge - if delta > 0 && time.Duration(delta) < fireIn { - if fireIn = time.Duration(delta); fireIn < 250*time.Millisecond { - // Only fire at most once every 250ms. - // Excessive firing can effect ingest performance. - fireIn = time.Second + if next < math.MaxInt64 { + // Looks like there's a next expiration, use it either if there's no + // MaxAge set or if it looks to be sooner than MaxAge is. 
+ if until := time.Until(time.Unix(0, next)); fireIn == 0 || until < fireIn { + fireIn = until } } + + // If not then look at the delta provided (usually gap to next age expiry). + if delta > 0 { + if fireIn == 0 || time.Duration(delta) < fireIn { + fireIn = time.Duration(delta) + } + } + + // Make sure we aren't firing too often either way, otherwise we can + // negatively impact stream ingest performance. + if fireIn < 250*time.Millisecond { + fireIn = 250 * time.Millisecond + } + if fs.ageChk != nil { fs.ageChk.Reset(fireIn) } else { @@ -5175,6 +5491,70 @@ func (fs *fileStore) cancelAgeChk() { } } +// Lock must be held so that nothing else can interleave and write a +// new message on this subject before we get the chance to write the +// delete marker. If the delete marker is written successfully then +// this function returns a callback func to call scb and sdmcb after +// the lock has been released. +func (fs *fileStore) subjectDeleteMarkerIfNeeded(subj string, reason string) func() { + if fs.cfg.SubjectDeleteMarkerTTL <= 0 { + return nil + } + if _, ok := fs.psim.Find(stringToBytes(subj)); ok { + // There are still messages left with this subject, + // therefore it wasn't the last message deleted. + return nil + } + // Build the subject delete marker. If no TTL is specified then + // we'll default to 15 minutes — by that time every possible condition + // should have cleared (i.e. ordered consumer timeout, client timeouts, + // route/gateway interruptions, even device/client restarts etc). 
+ ttl := int64(fs.cfg.SubjectDeleteMarkerTTL.Seconds()) + if ttl <= 0 { + return nil + } + var _hdr [128]byte + hdr := fmt.Appendf( + _hdr[:0], + "NATS/1.0\r\n%s: %s\r\n%s: %s\r\n%s: %d\r\n%s: %s\r\n\r\n\r\n", + JSMarkerReason, reason, + JSMessageTTL, time.Duration(ttl)*time.Second, + JSExpectedLastSubjSeq, 0, + JSExpectedLastSubjSeqSubj, subj, + ) + msg := &inMsg{ + subj: subj, + hdr: hdr, + } + sdmcb := fs.sdmcb + return func() { + if sdmcb != nil { + sdmcb(msg) + } + } +} + +// Filestore lock must be held but message block locks must not be. +// The caller should call the callback, if non-nil, after releasing +// the filestore lock. +func (fs *fileStore) subjectDeleteMarkersAfterOperation(reason string) func() { + if fs.cfg.SubjectDeleteMarkerTTL <= 0 || len(fs.markers) == 0 { + return nil + } + cbs := make([]func(), 0, len(fs.markers)) + for _, subject := range fs.markers { + if cb := fs.subjectDeleteMarkerIfNeeded(subject, reason); cb != nil { + cbs = append(cbs, cb) + } + } + fs.markers = nil + return func() { + for _, cb := range cbs { + cb() + } + } +} + // Will expire msgs that are too old. func (fs *fileStore) expireMsgs() { // We need to delete one by one here and can not optimize for the time being. @@ -5186,19 +5566,46 @@ func (fs *fileStore) expireMsgs() { minAge := time.Now().UnixNano() - maxAge fs.mu.RUnlock() - for sm, _ = fs.msgForSeq(0, &smv); sm != nil && sm.ts <= minAge; sm, _ = fs.msgForSeq(0, &smv) { - fs.mu.Lock() - fs.removeMsgViaLimits(sm.seq) - fs.mu.Unlock() - // Recalculate in case we are expiring a bunch. - minAge = time.Now().UnixNano() - maxAge + if maxAge > 0 { + var seq uint64 + for sm, seq, _ = fs.LoadNextMsg(fwcs, true, 0, &smv); sm != nil && sm.ts <= minAge; sm, seq, _ = fs.LoadNextMsg(fwcs, true, seq+1, &smv) { + if len(sm.hdr) > 0 { + if ttl, err := getMessageTTL(sm.hdr); err == nil && ttl < 0 { + // The message has a negative TTL, therefore it must "never expire". 
+ minAge = time.Now().UnixNano() - maxAge + continue + } + } + // Remove the message and then, if LimitsTTL is enabled, try and work out + // if it was the last message of that particular subject that we just deleted. + fs.mu.Lock() + fs.removeMsgViaLimits(sm.seq) + fs.mu.Unlock() + // Recalculate in case we are expiring a bunch. + minAge = time.Now().UnixNano() - maxAge + } } fs.mu.Lock() defer fs.mu.Unlock() - // Onky cancel if no message left, not on potential lookup error that would result in sm == nil. - if fs.state.Msgs == 0 { + // TODO: Not great that we're holding the lock here, but the timed hash wheel isn't thread-safe. + nextTTL := int64(math.MaxInt64) + if fs.ttls != nil { + fs.ttls.ExpireTasks(func(seq uint64, ts int64) { + fs.removeMsg(seq, false, false, false) + }) + if maxAge > 0 { + // Only check if we're expiring something in the next MaxAge interval, saves us a bit + // of work if MaxAge will beat us to the next expiry anyway. + nextTTL = fs.ttls.GetNextExpiration(time.Now().Add(time.Duration(maxAge)).UnixNano()) + } else { + nextTTL = fs.ttls.GetNextExpiration(math.MaxInt64) + } + } + + // Only cancel if no message left, not on potential lookup error that would result in sm == nil. + if fs.state.Msgs == 0 && nextTTL == math.MaxInt64 { fs.cancelAgeChk() } else { if sm == nil { @@ -5379,7 +5786,7 @@ func (mb *msgBlock) writeMsgRecord(rl, seq uint64, subj string, mhdr, msg []byte mb.cache.fseq = seq } // Write index - mb.cache.idx = append(mb.cache.idx, uint32(index)|hbit) + mb.cache.idx = append(mb.cache.idx, uint32(index)|cbit) } else { // Make sure to account for tombstones in rbytes. mb.rbytes += rl @@ -5965,9 +6372,11 @@ func (mb *msgBlock) indexCacheBuf(buf []byte) error { } // Mark fss activity. 
mb.lsts = time.Now().UnixNano() + mb.ttls = 0 lbuf := uint32(len(buf)) - var seq uint64 + var seq, ttls uint64 + var sm StoreMsg // Used for finding TTL headers for index < lbuf { if index+msgHdrSize > lbuf { return errCorruptState @@ -5977,6 +6386,7 @@ func (mb *msgBlock) indexCacheBuf(buf []byte) error { seq = le.Uint64(hdr[4:]) // Clear any headers bit that could be set. + hasHeaders := rl&hbit != 0 rl &^= hbit dlen := int(rl) - msgHdrSize @@ -6039,6 +6449,16 @@ func (mb *msgBlock) indexCacheBuf(buf []byte) error { }) } } + + // Count how many TTLs we think are in this message block. + // TODO(nat): Not terribly optimal... + if hasHeaders { + if fsm, err := mb.msgFromBuf(buf, &sm, nil); err == nil && fsm != nil { + if _, err = getMessageTTL(fsm.hdr); err == nil && len(fsm.hdr) > 0 { + ttls++ + } + } + } } index += rl } @@ -6060,6 +6480,7 @@ func (mb *msgBlock) indexCacheBuf(buf []byte) error { mb.cache.idx = idx mb.cache.fseq = fseq mb.cache.wp += int(lbuf) + mb.ttls = ttls return nil } @@ -6468,15 +6889,16 @@ var ( ) const ( - // Used for marking messages that have had their checksums checked. - // Used to signal a message record with headers. - hbit = 1 << 31 - // Used for marking erased messages sequences. - ebit = 1 << 63 - // Used for marking tombstone sequences. - tbit = 1 << 62 - // Used to mark an index as deleted and non-existent. + // "Checksum bit" is used in "mb.cache.idx" for marking messages that have had their checksums checked. + cbit = 1 << 31 + // "Delete bit" is used in "mb.cache.idx" to mark an index as deleted and non-existent. dbit = 1 << 30 + // "Header bit" is used in "rl" to signal a message record with headers. + hbit = 1 << 31 + // "Erase bit" is used in "seq" for marking erased messages sequences. + ebit = 1 << 63 + // "Tombstone bit" is used in "seq" for marking tombstone sequences. + tbit = 1 << 62 ) // Will do a lookup from cache. 
@@ -6559,7 +6981,7 @@ func (mb *msgBlock) cacheLookup(seq uint64, sm *StoreMsg) (*StoreMsg, error) { // Clear the check bit here after we know all is good. if !hashChecked { - mb.cache.idx[seq-mb.cache.fseq] = (bi | hbit) + mb.cache.idx[seq-mb.cache.fseq] = (bi | cbit) } return fsm, nil @@ -7092,7 +7514,7 @@ func (mb *msgBlock) sinceLastWriteActivity() time.Duration { } func checkNewHeader(hdr []byte) error { - if hdr == nil || len(hdr) < 2 || hdr[0] != magic || + if len(hdr) < 2 || hdr[0] != magic || (hdr[1] != version && hdr[1] != newVersion) { return errCorruptState } @@ -7259,13 +7681,17 @@ func compareFn(subject string) func(string, string) bool { // PurgeEx will remove messages based on subject filters, sequence and number of messages to keep. // Will return the number of purged messages. -func (fs *fileStore) PurgeEx(subject string, sequence, keep uint64) (purged uint64, err error) { +func (fs *fileStore) PurgeEx(subject string, sequence, keep uint64, _ /* noMarkers */ bool) (purged uint64, err error) { + // TODO: Don't write markers on purge until we have solved performance + // issues with them. + noMarkers := true + if subject == _EMPTY_ || subject == fwcs { if keep == 0 && sequence == 0 { - return fs.Purge() + return fs.purge(0, noMarkers) } if sequence > 1 { - return fs.Compact(sequence) + return fs.compact(sequence, noMarkers) } } @@ -7340,7 +7766,7 @@ func (fs *fileStore) PurgeEx(subject string, sequence, keep uint64) (purged uint } // PSIM and FSS updates. mb.removeSeqPerSubject(sm.subj, seq) - fs.removePerSubject(sm.subj) + fs.removePerSubject(sm.subj, !noMarkers && fs.cfg.SubjectDeleteMarkerTTL > 0) // Track tombstones we need to write. 
tombs = append(tombs, msgId{sm.seq, sm.ts}) @@ -7393,11 +7819,15 @@ func (fs *fileStore) PurgeEx(subject string, sequence, keep uint64) (purged uint os.Remove(filepath.Join(fs.fcfg.StoreDir, msgDir, streamStreamStateFile)) fs.dirty++ cb := fs.scb + sdmcb := fs.subjectDeleteMarkersAfterOperation(JSMarkerReasonPurge) fs.mu.Unlock() if cb != nil { cb(-int64(purged), -int64(bytes), 0, _EMPTY_) } + if sdmcb != nil { + sdmcb() + } return purged, nil } @@ -7405,10 +7835,14 @@ func (fs *fileStore) PurgeEx(subject string, sequence, keep uint64) (purged uint // Purge will remove all messages from this store. // Will return the number of purged messages. func (fs *fileStore) Purge() (uint64, error) { - return fs.purge(0) + return fs.purge(0, false) } -func (fs *fileStore) purge(fseq uint64) (uint64, error) { +func (fs *fileStore) purge(fseq uint64, _ /* noMarkers */ bool) (uint64, error) { + // TODO: Don't write markers on purge until we have solved performance + // issues with them. + noMarkers := true + fs.mu.Lock() if fs.closed { fs.mu.Unlock() @@ -7431,6 +7865,13 @@ func (fs *fileStore) purge(fseq uint64) (uint64, error) { fs.blks = nil fs.lmb = nil fs.bim = make(map[uint32]*msgBlock) + // Subject delete markers if needed. + if !noMarkers && fs.cfg.SubjectDeleteMarkerTTL > 0 { + fs.psim.IterOrdered(func(subject []byte, _ *psi) bool { + fs.markers = append(fs.markers, string(subject)) + return true + }) + } // Clear any per subject tracking. fs.psim, fs.tsl = fs.psim.Empty(), 0 // Mark dirty. @@ -7487,6 +7928,7 @@ func (fs *fileStore) purge(fseq uint64) (uint64, error) { } cb := fs.scb + sdmcb := fs.subjectDeleteMarkersAfterOperation(JSMarkerReasonPurge) fs.mu.Unlock() // Force a new index.db to be written. 
@@ -7497,16 +7939,29 @@ func (fs *fileStore) purge(fseq uint64) (uint64, error) { if cb != nil { cb(-int64(purged), -rbytes, 0, _EMPTY_) } + if sdmcb != nil { + sdmcb() + } return purged, nil } // Compact will remove all messages from this store up to // but not including the seq parameter. +// No subject delete markers will be left if they are enabled. If they are disabled, +// then this is functionally equivalent to a normal Compact() call. // Will return the number of purged messages. func (fs *fileStore) Compact(seq uint64) (uint64, error) { + return fs.compact(seq, false) +} + +func (fs *fileStore) compact(seq uint64, _ /* noMarkers */ bool) (uint64, error) { + // TODO: Don't write markers on compact until we have solved performance + // issues with them. + noMarkers := true + if seq == 0 { - return fs.purge(seq) + return fs.purge(seq, noMarkers) } var purged, bytes uint64 @@ -7515,7 +7970,7 @@ func (fs *fileStore) Compact(seq uint64) (uint64, error) { // Same as purge all. if lseq := fs.state.LastSeq; seq > lseq { fs.mu.Unlock() - return fs.purge(seq) + return fs.purge(seq, noMarkers) } // We have to delete interior messages. smb := fs.selectMsgBlock(seq) @@ -7538,7 +7993,7 @@ func (fs *fileStore) Compact(seq uint64) (uint64, error) { mb.fss.IterOrdered(func(bsubj []byte, ss *SimpleState) bool { subj := bytesToString(bsubj) for i := uint64(0); i < ss.Msgs; i++ { - fs.removePerSubject(subj) + fs.removePerSubject(subj, !noMarkers && fs.cfg.SubjectDeleteMarkerTTL > 0) } return true }) @@ -7584,7 +8039,7 @@ func (fs *fileStore) Compact(seq uint64) (uint64, error) { } // Update fss smb.removeSeqPerSubject(sm.subj, mseq) - fs.removePerSubject(sm.subj) + fs.removePerSubject(sm.subj, !noMarkers && fs.cfg.SubjectDeleteMarkerTTL > 0) } } @@ -7692,7 +8147,8 @@ SKIP: // after we release the lock. os.Remove(filepath.Join(fs.fcfg.StoreDir, msgDir, streamStreamStateFile)) fs.dirty++ - + // Subject delete markers if needed. 
+ sdmcb := fs.subjectDeleteMarkersAfterOperation(JSMarkerReasonPurge) cb := fs.scb fs.mu.Unlock() @@ -7704,6 +8160,9 @@ SKIP: if cb != nil && purged > 0 { cb(-int64(purged), -int64(bytes), 0, _EMPTY_) } + if sdmcb != nil { + sdmcb() + } return purged, err } @@ -8037,6 +8496,20 @@ func (mb *msgBlock) removeSeqPerSubject(subj string, seq uint64) { ss.Msgs-- + // Only one left. + if ss.Msgs == 1 { + if !ss.lastNeedsUpdate && seq != ss.Last { + ss.First = ss.Last + ss.firstNeedsUpdate = false + return + } + if !ss.firstNeedsUpdate && seq != ss.First { + ss.Last = ss.First + ss.lastNeedsUpdate = false + return + } + } + // We can lazily calculate the first/last sequence when needed. ss.firstNeedsUpdate = seq == ss.First || ss.firstNeedsUpdate ss.lastNeedsUpdate = seq == ss.Last || ss.lastNeedsUpdate @@ -8081,7 +8554,7 @@ func (mb *msgBlock) recalculateForSubj(subj string, ss *SimpleState) { fseq = mbFseq } for slot := startSlot; slot < len(mb.cache.idx); slot++ { - bi := mb.cache.idx[slot] &^ hbit + bi := mb.cache.idx[slot] &^ cbit if bi == dbit { // delete marker so skip. continue @@ -8122,7 +8595,7 @@ func (mb *msgBlock) recalculateForSubj(subj string, ss *SimpleState) { lseq = mbLseq } for slot := endSlot; slot >= startSlot; slot-- { - bi := mb.cache.idx[slot] &^ hbit + bi := mb.cache.idx[slot] &^ cbit if bi == dbit { // delete marker so skip. continue @@ -8418,9 +8891,13 @@ func (fs *fileStore) cancelSyncTimer() { } } +// The full state file is versioned. +// - 0x1: original binary index.db format +// - 0x2: adds support for TTL count field after num deleted const ( - fullStateMagic = uint8(11) - fullStateVersion = uint8(1) + fullStateMagic = uint8(11) + fullStateMinVersion = uint8(1) // What is the minimum version we know how to parse? + fullStateVersion = uint8(2) // What is the current version written out to index.db? ) // This go routine periodically writes out our full stream state index. 
@@ -8520,7 +8997,7 @@ func (fs *fileStore) _writeFullState(force bool) error { binary.MaxVarintLen64 + fs.tsl + // NumSubjects + total subject length numSubjects*(binary.MaxVarintLen64*4) + // psi record binary.MaxVarintLen64 + // Num blocks. - len(fs.blks)*((binary.MaxVarintLen64*7)+avgDmapLen) + // msg blocks, avgDmapLen is est for dmaps + len(fs.blks)*((binary.MaxVarintLen64*8)+avgDmapLen) + // msg blocks, avgDmapLen is est for dmaps binary.MaxVarintLen64 + 8 + 8 // last index + record checksum + full state checksum // Do 4k on stack if possible. @@ -8582,6 +9059,7 @@ func (fs *fileStore) _writeFullState(force bool) error { numDeleted := mb.dmap.Size() buf = binary.AppendUvarint(buf, uint64(numDeleted)) + buf = binary.AppendUvarint(buf, mb.ttls) // Field is new in version 2 if numDeleted > 0 { dmap, _ := mb.dmap.Encode(scratch[:0]) dmapTotalLen += len(dmap) @@ -8667,7 +9145,24 @@ func (fs *fileStore) _writeFullState(force bool) error { fs.mu.Unlock() } - return nil + return fs.writeTTLState() +} + +func (fs *fileStore) writeTTLState() error { + if fs.ttls == nil { + return nil + } + + fs.mu.RLock() + fn := filepath.Join(fs.fcfg.StoreDir, msgDir, ttlStreamStateFile) + buf := fs.ttls.Encode(fs.state.LastSeq) + fs.mu.RUnlock() + + <-dios + err := os.WriteFile(fn, buf, defaultFilePerms) + dios <- struct{}{} + + return err } // Stop the current filestore. @@ -8744,7 +9239,8 @@ func (fs *fileStore) stop(delete, writeState bool) error { const errFile = "errors.txt" // Stream our snapshot through S2 compression and tar. 
-func (fs *fileStore) streamSnapshot(w io.WriteCloser, includeConsumers bool) { +func (fs *fileStore) streamSnapshot(w io.WriteCloser, includeConsumers bool, errCh chan string) { + defer close(errCh) defer w.Close() enc := s2.NewWriter(w) @@ -8782,6 +9278,7 @@ func (fs *fileStore) streamSnapshot(w io.WriteCloser, includeConsumers bool) { writeErr := func(err string) { writeFile(errFile, []byte(err)) + errCh <- err } fs.mu.Lock() @@ -8958,9 +9455,10 @@ func (fs *fileStore) Snapshot(deadline time.Duration, checkMsgs, includeConsumer fs.FastState(&state) // Stream in separate Go routine. - go fs.streamSnapshot(pw, includeConsumers) + errCh := make(chan string, 1) + go fs.streamSnapshot(pw, includeConsumers, errCh) - return &SnapshotResult{pr, state}, nil + return &SnapshotResult{pr, state, errCh}, nil } // Helper to return the config. @@ -9415,9 +9913,30 @@ func (o *consumerFileStore) SetStarting(sseq uint64) error { return o.writeState(buf) } +// UpdateStarting updates our starting stream sequence. +func (o *consumerFileStore) UpdateStarting(sseq uint64) { + o.mu.Lock() + defer o.mu.Unlock() + + if sseq > o.state.Delivered.Stream { + o.state.Delivered.Stream = sseq + // For AckNone just update delivered and ackfloor at the same time. + if o.cfg.AckPolicy == AckNone { + o.state.AckFloor.Stream = sseq + } + } + // Make sure we flush to disk. + o.kickFlusher() +} + // HasState returns if this store has a recorded state. func (o *consumerFileStore) HasState() bool { o.mu.Lock() + // We have a running state, or stored on disk but not yet initialized. + if o.state.Delivered.Consumer != 0 || o.state.Delivered.Stream != 0 { + o.mu.Unlock() + return true + } _, err := os.Stat(o.ifn) o.mu.Unlock() return err == nil @@ -9470,7 +9989,7 @@ func (o *consumerFileStore) UpdateDelivered(dseq, sseq, dc uint64, ts int64) err if o.state.Redelivered == nil { o.state.Redelivered = make(map[uint64]uint64) } - // Only update if greater then what we already have. 
+ // Only update if greater than what we already have. if o.state.Redelivered[sseq] < dc-1 { o.state.Redelivered[sseq] = dc - 1 } @@ -9506,12 +10025,6 @@ func (o *consumerFileStore) UpdateAcks(dseq, sseq uint64) error { return nil } - // Match leader logic on checking if ack is ahead of delivered. - // This could happen on a cooperative takeover with high speed deliveries. - if sseq > o.state.Delivered.Stream { - o.state.Delivered.Stream = sseq + 1 - } - if len(o.state.Pending) == 0 || o.state.Pending[sseq] == nil { delete(o.state.Redelivered, sseq) return ErrStoreMsgNotFound @@ -9791,7 +10304,7 @@ func (cfs *consumerFileStore) writeConsumerMeta() error { // Consumer version. func checkConsumerHeader(hdr []byte) (uint8, error) { - if hdr == nil || len(hdr) < 2 || hdr[0] != magic { + if len(hdr) < 2 || hdr[0] != magic { return 0, errCorruptState } version := hdr[1] diff --git a/vendor/github.com/nats-io/nats-server/v2/server/fuzz.go b/vendor/github.com/nats-io/nats-server/v2/server/fuzz.go index 88f30350e9..361ab7c53e 100644 --- a/vendor/github.com/nats-io/nats-server/v2/server/fuzz.go +++ b/vendor/github.com/nats-io/nats-server/v2/server/fuzz.go @@ -12,7 +12,6 @@ // limitations under the License. //go:build gofuzz -// +build gofuzz package server diff --git a/vendor/github.com/nats-io/nats-server/v2/server/gateway.go b/vendor/github.com/nats-io/nats-server/v2/server/gateway.go index 22f0e417bd..6257e3f61d 100644 --- a/vendor/github.com/nats-io/nats-server/v2/server/gateway.go +++ b/vendor/github.com/nats-io/nats-server/v2/server/gateway.go @@ -19,12 +19,14 @@ import ( "crypto/sha256" "crypto/tls" "encoding/json" + "errors" "fmt" "math/rand" "net" "net/url" "slices" "strconv" + "strings" "sync" "sync/atomic" "time" @@ -217,6 +219,8 @@ type gateway struct { // interest-only mode "immediately", so the outbound should disregard // the optimistic mode when checking for interest. 
interestOnlyMode bool + // Name of the remote server + remoteName string } // Outbound subject interest entry. @@ -298,17 +302,20 @@ func (r *RemoteGatewayOpts) clone() *RemoteGatewayOpts { // Ensure that gateway is properly configured. func validateGatewayOptions(o *Options) error { - if o.Gateway.Name == "" && o.Gateway.Port == 0 { + if o.Gateway.Name == _EMPTY_ && o.Gateway.Port == 0 { return nil } - if o.Gateway.Name == "" { - return fmt.Errorf("gateway has no name") + if o.Gateway.Name == _EMPTY_ { + return errors.New("gateway has no name") + } + if strings.Contains(o.Gateway.Name, " ") { + return ErrGatewayNameHasSpaces } if o.Gateway.Port == 0 { return fmt.Errorf("gateway %q has no port specified (select -1 for random port)", o.Gateway.Name) } for i, g := range o.Gateway.Gateways { - if g.Name == "" { + if g.Name == _EMPTY_ { return fmt.Errorf("gateway in the list %d has no name", i) } if len(g.URLs) == 0 { @@ -528,6 +535,7 @@ func (s *Server) startGatewayAcceptLoop() { Gateway: opts.Gateway.Name, GatewayNRP: true, Headers: s.supportsHeaders(), + Proto: s.getServerProto(), } // Unless in some tests we want to keep the old behavior, we are now // (since v2.9.0) indicate that this server will switch all accounts @@ -1035,6 +1043,10 @@ func (c *client) processGatewayInfo(info *Info) { } if isFirstINFO { c.opts.Name = info.ID + // Get the protocol version from the INFO protocol. This will be checked + // to see if this connection supports message tracing for instance. 
+ c.opts.Protocol = info.Proto + c.gw.remoteName = info.Name } c.mu.Unlock() @@ -2400,7 +2412,7 @@ func (s *Server) gatewayUpdateSubInterest(accName string, sub *subscription, cha if change < 0 { return } - entry = &sitally{n: 1, q: sub.queue != nil} + entry = &sitally{n: change, q: sub.queue != nil} st[string(key)] = entry first = true } else { @@ -2528,6 +2540,14 @@ func (c *client) sendMsgToGateways(acc *Account, msg, subject, reply []byte, qgr if len(gws) == 0 { return false } + + mt, _ := c.isMsgTraceEnabled() + if mt != nil { + pa := c.pa + msg = mt.setOriginAccountHeaderIfNeeded(c, acc, msg) + defer func() { c.pa = pa }() + } + var ( queuesa = [512]byte{} queues = queuesa[:0] @@ -2635,6 +2655,11 @@ func (c *client) sendMsgToGateways(acc *Account, msg, subject, reply []byte, qgr mreply = append(mreply, reply...) } } + + if mt != nil { + msg = mt.setHopHeader(c, msg) + } + // Setup the message header. // Make sure we are an 'R' proto by default c.msgb[0] = 'R' diff --git a/vendor/github.com/nats-io/nats-server/v2/server/ipqueue.go b/vendor/github.com/nats-io/nats-server/v2/server/ipqueue.go index b362631b56..094c522ee2 100644 --- a/vendor/github.com/nats-io/nats-server/v2/server/ipqueue.go +++ b/vendor/github.com/nats-io/nats-server/v2/server/ipqueue.go @@ -14,6 +14,7 @@ package server import ( + "errors" "sync" "sync/atomic" ) @@ -28,36 +29,79 @@ type ipQueue[T any] struct { elts []T pos int pool *sync.Pool - mrs int + sz uint64 // Calculated size (only if calc != nil) name string m *sync.Map + ipQueueOpts[T] } -type ipQueueOpts struct { - maxRecycleSize int +type ipQueueOpts[T any] struct { + mrs int // Max recycle size + calc func(e T) uint64 // Calc function for tracking size + msz uint64 // Limit by total calculated size + mlen int // Limit by number of entries } -type ipQueueOpt func(*ipQueueOpts) +type ipQueueOpt[T any] func(*ipQueueOpts[T]) // This option allows to set the maximum recycle size when attempting // to put back a slice to the pool. 
-func ipQueue_MaxRecycleSize(max int) ipQueueOpt { - return func(o *ipQueueOpts) { - o.maxRecycleSize = max +func ipqMaxRecycleSize[T any](max int) ipQueueOpt[T] { + return func(o *ipQueueOpts[T]) { + o.mrs = max } } -func newIPQueue[T any](s *Server, name string, opts ...ipQueueOpt) *ipQueue[T] { - qo := ipQueueOpts{maxRecycleSize: ipQueueDefaultMaxRecycleSize} - for _, o := range opts { - o(&qo) +// This option enables total queue size counting by passing in a function +// that evaluates the size of each entry as it is pushed/popped. This option +// enables the size() function. +func ipqSizeCalculation[T any](calc func(e T) uint64) ipQueueOpt[T] { + return func(o *ipQueueOpts[T]) { + o.calc = calc } +} + +// This option allows setting the maximum queue size. Once the limit is +// reached, then push() will stop returning true and no more entries will +// be stored until some more are popped. The ipQueue_SizeCalculation must +// be provided for this to work. +func ipqLimitBySize[T any](max uint64) ipQueueOpt[T] { + return func(o *ipQueueOpts[T]) { + o.msz = max + } +} + +// This option allows setting the maximum queue length. Once the limit is +// reached, then push() will stop returning true and no more entries will +// be stored until some more are popped. 
+func ipqLimitByLen[T any](max int) ipQueueOpt[T] { + return func(o *ipQueueOpts[T]) { + o.mlen = max + } +} + +var errIPQLenLimitReached = errors.New("IPQ len limit reached") +var errIPQSizeLimitReached = errors.New("IPQ size limit reached") + +func newIPQueue[T any](s *Server, name string, opts ...ipQueueOpt[T]) *ipQueue[T] { q := &ipQueue[T]{ - ch: make(chan struct{}, 1), - mrs: qo.maxRecycleSize, - pool: &sync.Pool{}, + ch: make(chan struct{}, 1), + pool: &sync.Pool{ + New: func() any { + // Reason we use pointer to slice instead of slice is explained + // here: https://staticcheck.io/docs/checks#SA6002 + res := make([]T, 0, 32) + return &res + }, + }, name: name, m: &s.ipQueues, + ipQueueOpts: ipQueueOpts[T]{ + mrs: ipQueueDefaultMaxRecycleSize, + }, + } + for _, o := range opts { + o(&q.ipQueueOpts) } s.ipQueues.Store(name, q) return q @@ -66,32 +110,34 @@ func newIPQueue[T any](s *Server, name string, opts ...ipQueueOpt) *ipQueue[T] { // Add the element `e` to the queue, notifying the queue channel's `ch` if the // entry is the first to be added, and returns the length of the queue after // this element is added. -func (q *ipQueue[T]) push(e T) int { - var signal bool +func (q *ipQueue[T]) push(e T) (int, error) { q.Lock() l := len(q.elts) - q.pos - if l == 0 { - signal = true - eltsi := q.pool.Get() - if eltsi != nil { - // Reason we use pointer to slice instead of slice is explained - // here: https://staticcheck.io/docs/checks#SA6002 - q.elts = (*(eltsi.(*[]T)))[:0] - } - if cap(q.elts) == 0 { - q.elts = make([]T, 0, 32) + if q.mlen > 0 && l == q.mlen { + q.Unlock() + return l, errIPQLenLimitReached + } + if q.calc != nil { + sz := q.calc(e) + if q.msz > 0 && q.sz+sz > q.msz { + q.Unlock() + return l, errIPQSizeLimitReached } + q.sz += sz + } + if q.elts == nil { + // What comes out of the pool is already of size 0, so no need for [:0]. 
+ q.elts = *(q.pool.Get().(*[]T)) } q.elts = append(q.elts, e) - l++ q.Unlock() - if signal { + if l == 0 { select { case q.ch <- struct{}{}: default: } } - return l + return l + 1, nil } // Returns the whole list of elements currently present in the queue, @@ -107,24 +153,23 @@ func (q *ipQueue[T]) pop() []T { if q == nil { return nil } - var elts []T q.Lock() + if len(q.elts)-q.pos == 0 { + q.Unlock() + return nil + } + var elts []T if q.pos == 0 { elts = q.elts } else { elts = q.elts[q.pos:] } - q.elts, q.pos = nil, 0 + q.elts, q.pos, q.sz = nil, 0, 0 atomic.AddInt64(&q.inprogress, int64(len(elts))) q.Unlock() return elts } -func (q *ipQueue[T]) resetAndReturnToPool(elts *[]T) { - (*elts) = (*elts)[:0] - q.pool.Put(elts) -} - // Returns the first element from the queue, if any. See comment above // regarding calling after being notified that there is something and // the use of drain(). In short, the caller should always check the @@ -133,24 +178,30 @@ func (q *ipQueue[T]) resetAndReturnToPool(elts *[]T) { func (q *ipQueue[T]) popOne() (T, bool) { q.Lock() l := len(q.elts) - q.pos - if l < 1 { + if l == 0 { q.Unlock() var empty T return empty, false } e := q.elts[q.pos] - q.pos++ - l-- - if l > 0 { + if l--; l > 0 { + q.pos++ + if q.calc != nil { + q.sz -= q.calc(e) + } // We need to re-signal select { case q.ch <- struct{}{}: default: } } else { - // We have just emptied the queue, so we can recycle now. - q.resetAndReturnToPool(&q.elts) - q.elts, q.pos = nil, 0 + // We have just emptied the queue, so we can reuse unless it is too big. + if cap(q.elts) <= q.mrs { + q.elts = q.elts[:0] + } else { + q.elts = nil + } + q.pos, q.sz = 0, 0 } q.Unlock() return e, true @@ -160,8 +211,7 @@ func (q *ipQueue[T]) popOne() (T, bool) { // a first element is added to the queue. // This will also decrement the "in progress" count with the length // of the slice. 
-// Reason we use pointer to slice instead of slice is explained -// here: https://staticcheck.io/docs/checks#SA6002 +// WARNING: The caller MUST never reuse `elts`. func (q *ipQueue[T]) recycle(elts *[]T) { // If invoked with a nil list, nothing to do. if elts == nil || *elts == nil { @@ -169,24 +219,30 @@ func (q *ipQueue[T]) recycle(elts *[]T) { } // Update the in progress count. if len(*elts) > 0 { - if atomic.AddInt64(&q.inprogress, int64(-(len(*elts)))) < 0 { - atomic.StoreInt64(&q.inprogress, 0) - } + atomic.AddInt64(&q.inprogress, int64(-(len(*elts)))) } // We also don't want to recycle huge slices, so check against the max. // q.mrs is normally immutable but can be changed, in a safe way, in some tests. if cap(*elts) > q.mrs { return } - q.resetAndReturnToPool(elts) + (*elts) = (*elts)[:0] + q.pool.Put(elts) } // Returns the current length of the queue. func (q *ipQueue[T]) len() int { q.Lock() - l := len(q.elts) - q.pos - q.Unlock() - return l + defer q.Unlock() + return len(q.elts) - q.pos +} + +// Returns the calculated size of the queue (if ipQueue_SizeCalculation has been +// passed in), otherwise returns zero. +func (q *ipQueue[T]) size() uint64 { + q.Lock() + defer q.Unlock() + return q.sz } // Empty the queue and consumes the notification signal if present. @@ -199,11 +255,8 @@ func (q *ipQueue[T]) drain() int { return 0 } q.Lock() - olen := len(q.elts) - if q.elts != nil { - q.resetAndReturnToPool(&q.elts) - q.elts, q.pos = nil, 0 - } + olen := len(q.elts) - q.pos + q.elts, q.pos, q.sz = nil, 0, 0 // Consume the signal if it was present to reduce the chance of a reader // routine to be think that there is something in the queue... 
select { diff --git a/vendor/github.com/nats-io/nats-server/v2/server/jetstream.go b/vendor/github.com/nats-io/nats-server/v2/server/jetstream.go index ceb14663db..1dff7f2bd9 100644 --- a/vendor/github.com/nats-io/nats-server/v2/server/jetstream.go +++ b/vendor/github.com/nats-io/nats-server/v2/server/jetstream.go @@ -32,6 +32,7 @@ import ( "github.com/minio/highwayhash" "github.com/nats-io/nats-server/v2/server/sysmem" + "github.com/nats-io/nats-server/v2/server/tpm" "github.com/nats-io/nkeys" "github.com/nats-io/nuid" ) @@ -47,6 +48,7 @@ type JetStreamConfig struct { Domain string `json:"domain,omitempty"` CompressOK bool `json:"compress_ok,omitempty"` UniqueTag string `json:"unique_tag,omitempty"` + Strict bool `json:"strict,omitempty"` } // Statistics about JetStream for this server. @@ -90,6 +92,7 @@ type JetStreamAccountStats struct { } type JetStreamAPIStats struct { + Level int `json:"level"` Total uint64 `json:"total"` Errors uint64 `json:"errors"` Inflight uint64 `json:"inflight,omitempty"` @@ -173,6 +176,9 @@ type jsAccount struct { updatesSub *subscription lupdate time.Time utimer *time.Timer + + // Which account to send NRG traffic into. Empty string is system account. + nrgAccount string } // Track general usage for this account. @@ -370,6 +376,40 @@ func (s *Server) checkStoreDir(cfg *JetStreamConfig) error { return nil } +// This function sets/updates the jetstream encryption key and cipher based +// on options. If the TPM options have been specified, a key is generated +// and sealed by the TPM. +func (s *Server) initJetStreamEncryption() (err error) { + opts := s.getOpts() + + // The TPM settings and other encryption settings are mutually exclusive. + if opts.JetStreamKey != _EMPTY_ && opts.JetStreamTpm.KeysFile != _EMPTY_ { + return fmt.Errorf("JetStream encryption key may not be used with TPM options") + } + // if we are using the standard method to set the encryption key just return and carry on. 
+ if opts.JetStreamKey != _EMPTY_ { + return nil + } + // if the tpm options are not used then no encryption has been configured and return. + if opts.JetStreamTpm.KeysFile == _EMPTY_ { + return nil + } + + if opts.JetStreamTpm.Pcr == 0 { + // Default PCR to use in the TPM. Values can be 0-23, and most platforms + // reserve values 0-12 for the OS, boot locker, disc encryption, etc. + // 16 used for debugging. In sticking to NATS tradition, we'll use 22 + // as the default with the option being configurable. + opts.JetStreamTpm.Pcr = 22 + } + + // Using the TPM to generate or get the encryption key and update the encryption options. + opts.JetStreamKey, err = tpm.LoadJetStreamEncryptionKeyFromTPM(opts.JetStreamTpm.SrkPassword, + opts.JetStreamTpm.KeysFile, opts.JetStreamTpm.KeyPassword, opts.JetStreamTpm.Pcr) + + return err +} + // enableJetStream will start up the JetStream subsystem. func (s *Server) enableJetStream(cfg JetStreamConfig) error { js := &jetStream{srv: s, config: cfg, accounts: make(map[string]*jsAccount), apiSubs: NewSublistNoCache()} @@ -402,6 +442,10 @@ func (s *Server) enableJetStream(cfg JetStreamConfig) error { os.Remove(tmpfile.Name()) } + if err := s.initJetStreamEncryption(); err != nil { + return err + } + // JetStream is an internal service so we need to make sure we have a system account. // This system account will export the JetStream service endpoints. 
if s.SystemAccount() == nil { @@ -419,6 +463,11 @@ func (s *Server) enableJetStream(cfg JetStreamConfig) error { s.Noticef("") } s.Noticef("---------------- JETSTREAM ----------------") + + if cfg.Strict { + s.Noticef(" Strict: %t", cfg.Strict) + } + s.Noticef(" Max Memory: %s", friendlyBytes(cfg.MaxMemory)) s.Noticef(" Max Storage: %s", friendlyBytes(cfg.MaxStore)) s.Noticef(" Store Directory: \"%s\"", cfg.StoreDir) @@ -429,6 +478,11 @@ func (s *Server) enableJetStream(cfg JetStreamConfig) error { if ek := opts.JetStreamKey; ek != _EMPTY_ { s.Noticef(" Encryption: %s", opts.JetStreamCipher) } + if opts.JetStreamTpm.KeysFile != _EMPTY_ { + s.Noticef(" TPM File: %q, Pcr: %d", opts.JetStreamTpm.KeysFile, + opts.JetStreamTpm.Pcr) + } + s.Noticef(" API Level: %d", JSApiLevel) s.Noticef("-------------------------------------------") // Setup our internal subscriptions. @@ -508,6 +562,7 @@ func (s *Server) restartJetStream() error { MaxMemory: opts.JetStreamMaxMemory, MaxStore: opts.JetStreamMaxStore, Domain: opts.JetStreamDomain, + Strict: opts.JetStreamStrict, } s.Noticef("Restarting JetStream") err := s.EnableJetStream(&cfg) @@ -1408,7 +1463,7 @@ func (a *Account) EnableJetStream(limits map[string]JetStreamAccountLimits) erro // the consumer can reconnect. We will create it as a durable and switch it. 
cfg.ConsumerConfig.Durable = ofi.Name() } - obs, err := e.mset.addConsumerWithAssignment(&cfg.ConsumerConfig, _EMPTY_, nil, true, ActionCreateOrUpdate) + obs, err := e.mset.addConsumerWithAssignment(&cfg.ConsumerConfig, _EMPTY_, nil, true, ActionCreateOrUpdate, false) if err != nil { s.Warnf(" Error adding consumer %q: %v", cfg.Name, err) continue @@ -1652,6 +1707,7 @@ func (a *Account) JetStreamUsage() JetStreamAccountStats { stats.Memory, stats.Store = jsa.storageTotals() stats.Domain = js.config.Domain stats.API = JetStreamAPIStats{ + Level: JSApiLevel, Total: jsa.apiTotal, Errors: jsa.apiErrors, } @@ -2338,6 +2394,7 @@ func (js *jetStream) usageStats() *JetStreamStats { stats.ReservedStore = uint64(js.storeReserved) s := js.srv js.mu.RUnlock() + stats.API.Level = JSApiLevel stats.API.Total = uint64(atomic.LoadInt64(&js.apiTotal)) stats.API.Errors = uint64(atomic.LoadInt64(&js.apiErrors)) stats.API.Inflight = uint64(atomic.LoadInt64(&js.apiInflight)) @@ -2480,6 +2537,9 @@ func (s *Server) dynJetStreamConfig(storeDir string, maxStore, maxMem int64) *Je opts := s.getOpts() + // Strict mode. + jsc.Strict = opts.JetStreamStrict + // Sync options. jsc.SyncInterval = opts.SyncInterval jsc.SyncAlways = opts.SyncAlways @@ -2569,7 +2629,7 @@ func (a *Account) addStreamTemplate(tc *StreamTemplateConfig) (*streamTemplate, // FIXME(dlc) - Hacky tcopy := tc.deepCopy() tcopy.Config.Name = "_" - cfg, apiErr := s.checkStreamCfg(tcopy.Config, a) + cfg, apiErr := s.checkStreamCfg(tcopy.Config, a, false) if apiErr != nil { return nil, apiErr } @@ -2871,11 +2931,11 @@ func (s *Server) resourcesExceededError() { } s.rerrMu.Unlock() - // If we are meta leader we should relinguish that here. + // If we are meta leader we should relinquish that here. 
if didAlert { if js := s.getJetStream(); js != nil { js.mu.RLock() - if cc := js.cluster; cc != nil && cc.isLeader() { + if cc := js.cluster; cc != nil && cc.meta != nil { cc.meta.StepDown() } js.mu.RUnlock() diff --git a/vendor/github.com/nats-io/nats-server/v2/server/jetstream_api.go b/vendor/github.com/nats-io/nats-server/v2/server/jetstream_api.go index 4edc99bbd3..3d9882ad82 100644 --- a/vendor/github.com/nats-io/nats-server/v2/server/jetstream_api.go +++ b/vendor/github.com/nats-io/nats-server/v2/server/jetstream_api.go @@ -19,7 +19,7 @@ import ( "encoding/json" "errors" "fmt" - "math/rand" + "io" "os" "path/filepath" "runtime" @@ -166,9 +166,18 @@ const ( JSApiConsumerDelete = "$JS.API.CONSUMER.DELETE.*.*" JSApiConsumerDeleteT = "$JS.API.CONSUMER.DELETE.%s.%s" + // JSApiConsumerPause is the endpoint to pause or unpause consumers. + // Will return JSON response. + JSApiConsumerPause = "$JS.API.CONSUMER.PAUSE.*.*" + JSApiConsumerPauseT = "$JS.API.CONSUMER.PAUSE.%s.%s" + // JSApiRequestNextT is the prefix for the request next message(s) for a consumer in worker/pull mode. JSApiRequestNextT = "$JS.API.CONSUMER.MSG.NEXT.%s.%s" + // JSApiConsumerUnpinT is the prefix for unpinning subscription for a given consumer. + JSApiConsumerUnpin = "$JS.API.CONSUMER.UNPIN.*.*" + JSApiConsumerUnpinT = "$JS.API.CONSUMER.UNPIN.%s.%s" + // jsRequestNextPre jsRequestNextPre = "$JS.API.CONSUMER.MSG.NEXT." @@ -260,12 +269,21 @@ const ( // JSAdvisoryStreamUpdatedPre notification that a stream was updated. JSAdvisoryStreamUpdatedPre = "$JS.EVENT.ADVISORY.STREAM.UPDATED" - // JSAdvisoryConsumerCreatedPre notification that a template created. + // JSAdvisoryConsumerCreatedPre notification that a consumer was created. JSAdvisoryConsumerCreatedPre = "$JS.EVENT.ADVISORY.CONSUMER.CREATED" - // JSAdvisoryConsumerDeletedPre notification that a template deleted. + // JSAdvisoryConsumerDeletedPre notification that a consumer was deleted. 
JSAdvisoryConsumerDeletedPre = "$JS.EVENT.ADVISORY.CONSUMER.DELETED" + // JSAdvisoryConsumerPausePre notification that a consumer paused/unpaused. + JSAdvisoryConsumerPausePre = "$JS.EVENT.ADVISORY.CONSUMER.PAUSE" + + // JSAdvisoryConsumerPinnedPre notification that a consumer was pinned. + JSAdvisoryConsumerPinnedPre = "$JS.EVENT.ADVISORY.CONSUMER.PINNED" + + // JSAdvisoryConsumerUnpinnedPre notification that a consumer was unpinned. + JSAdvisoryConsumerUnpinnedPre = "$JS.EVENT.ADVISORY.CONSUMER.UNPINNED" + // JSAdvisoryStreamSnapshotCreatePre notification that a snapshot was created. JSAdvisoryStreamSnapshotCreatePre = "$JS.EVENT.ADVISORY.STREAM.SNAPSHOT_CREATE" @@ -495,6 +513,16 @@ type JSApiStreamPurgeResponse struct { const JSApiStreamPurgeResponseType = "io.nats.jetstream.api.v1.stream_purge_response" +type JSApiConsumerUnpinRequest struct { + Group string `json:"group"` +} + +type JSApiConsumerUnpinResponse struct { + ApiResponse +} + +const JSApiConsumerUnpinResponseType = "io.nats.jetstream.api.v1.consumer_unpin_response" + // JSApiStreamUpdateResponse for updating a stream. type JSApiStreamUpdateResponse struct { ApiResponse @@ -642,6 +670,22 @@ type JSApiMsgGetRequest struct { Seq uint64 `json:"seq,omitempty"` LastFor string `json:"last_by_subj,omitempty"` NextFor string `json:"next_by_subj,omitempty"` + + // Batch support. Used to request more then one msg at a time. + // Can be used with simple starting seq, but also NextFor with wildcards. + Batch int `json:"batch,omitempty"` + // This will make sure we limit how much data we blast out. If not set we will + // inherit the slow consumer default max setting of the server. Default is MAX_PENDING_SIZE. + MaxBytes int `json:"max_bytes,omitempty"` + // Return messages as of this start time. + StartTime *time.Time `json:"start_time,omitempty"` + + // Multiple response support. Will get the last msgs matching the subjects. These can include wildcards. 
+ MultiLastFor []string `json:"multi_last,omitempty"` + // Only return messages up to this sequence. If not set, will be last sequence for the stream. + UpToSeq uint64 `json:"up_to_seq,omitempty"` + // Only return messages up to this time. + UpToTime *time.Time `json:"up_to_time,omitempty"` } type JSApiMsgGetResponse struct { @@ -668,6 +712,19 @@ type JSApiConsumerDeleteResponse struct { const JSApiConsumerDeleteResponseType = "io.nats.jetstream.api.v1.consumer_delete_response" +type JSApiConsumerPauseRequest struct { + PauseUntil time.Time `json:"pause_until,omitempty"` +} + +const JSApiConsumerPauseResponseType = "io.nats.jetstream.api.v1.consumer_pause_response" + +type JSApiConsumerPauseResponse struct { + ApiResponse + Paused bool `json:"paused"` + PauseUntil time.Time `json:"pause_until"` + PauseRemaining time.Duration `json:"pause_remaining,omitempty"` +} + type JSApiConsumerInfoResponse struct { ApiResponse *ConsumerInfo @@ -703,6 +760,7 @@ type JSApiConsumerGetNextRequest struct { MaxBytes int `json:"max_bytes,omitempty"` NoWait bool `json:"no_wait,omitempty"` Heartbeat time.Duration `json:"idle_heartbeat,omitempty"` + PriorityGroup } // JSApiStreamTemplateCreateResponse for creating templates. @@ -832,7 +890,7 @@ func (js *jetStream) apiDispatch(sub *subscription, c *client, acc *Account, sub // Copy the state. Note the JSAPI only uses the hdr index to piece apart the // header from the msg body. No other references are needed. // Check pending and warn if getting backed up. 
- pending := s.jsAPIRoutedReqs.push(&jsAPIRoutedReq{jsub, sub, acc, subject, reply, copyBytes(rmsg), c.pa}) + pending, _ := s.jsAPIRoutedReqs.push(&jsAPIRoutedReq{jsub, sub, acc, subject, reply, copyBytes(rmsg), c.pa}) limit := atomic.LoadInt64(&js.queueLimit) if pending >= int(limit) { s.rateLimitFormatWarnf("JetStream API queue limit reached, dropping %d requests", pending) @@ -943,6 +1001,8 @@ func (s *Server) setJetStreamExportSubs() error { {JSApiConsumerList, s.jsConsumerListRequest}, {JSApiConsumerInfo, s.jsConsumerInfoRequest}, {JSApiConsumerDelete, s.jsConsumerDeleteRequest}, + {JSApiConsumerPause, s.jsConsumerPauseRequest}, + {JSApiConsumerUnpin, s.jsConsumerUnpinRequest}, } js.mu.Lock() @@ -976,30 +1036,128 @@ func (s *Server) sendAPIErrResponse(ci *ClientInfo, acc *Account, subject, reply const errRespDelay = 500 * time.Millisecond -func (s *Server) sendDelayedAPIErrResponse(ci *ClientInfo, acc *Account, subject, reply, request, response string, rg *raftGroup) { - js := s.getJetStream() - if js == nil { +type delayedAPIResponse struct { + ci *ClientInfo + acc *Account + subject string + reply string + request string + response string + rg *raftGroup + deadline time.Time + next *delayedAPIResponse +} + +// Add `r` in the list that is maintained ordered by the `delayedAPIResponse.deadline` time. +func addDelayedResponse(head, tail **delayedAPIResponse, r *delayedAPIResponse) { + // Check if list empty. + if *head == nil { + *head, *tail = r, r return } - var quitCh <-chan struct{} - js.mu.RLock() - if rg != nil && rg.node != nil { - quitCh = rg.node.QuitC() + // Check if it should be added at the end, which is if after or equal to the tail. 
+ if r.deadline.After((*tail).deadline) || r.deadline.Equal((*tail).deadline) { + (*tail).next, *tail = r, r + return } - js.mu.RUnlock() - - s.startGoRoutine(func() { - defer s.grWG.Done() - select { - case <-quitCh: - case <-s.quitCh: - case <-time.After(errRespDelay): - acc.trackAPIErr() - if reply != _EMPTY_ { - s.sendInternalAccountMsg(nil, reply, response) + // Find its spot in the list. + var prev *delayedAPIResponse + for c := *head; c != nil; c = c.next { + // We insert only if we are stricly before the current `c`. + if r.deadline.Before(c.deadline) { + r.next = c + if prev != nil { + prev.next = r + } else { + *head = r } - s.sendJetStreamAPIAuditAdvisory(ci, acc, subject, request, response) + return } + prev = c + } +} + +func (s *Server) delayedAPIResponder() { + defer s.grWG.Done() + var ( + head, tail *delayedAPIResponse // Linked list. + r *delayedAPIResponse // Updated by calling next(). + rqch <-chan struct{} // Quit channel of the Raft group (if present). + tm = time.NewTimer(time.Hour) + ) + next := func() { + r, rqch = nil, nil + // Check that JetStream is still on. Do not exit the go routine + // since JS can be enabled/disabled. The go routine will exit + // only if server is shutdown. + js := s.getJetStream() + if js == nil { + // Reset head and tail here. Also drain the ipQueue. + head, tail = nil, nil + s.delayedAPIResponses.drain() + // Fall back into next "if" that resets timer. + } + // If there are no delayed messages then delay the timer for + // a while. + if head == nil { + tm.Reset(time.Hour) + return + } + // Get the first expected message and then reset the timer. + r = head + js.mu.RLock() + if r.rg != nil && r.rg.node != nil { + // If there's an attached Raft group to the delayed response + // then pull out the quit channel, so that we don't bother + // sending responses for entities which are now no longer + // running. 
+ rqch = r.rg.node.QuitC() + } + js.mu.RUnlock() + tm.Reset(time.Until(r.deadline)) + } + pop := func() { + if head == nil { + return + } + head = head.next + if head == nil { + tail = nil + } + } + for { + select { + case <-s.delayedAPIResponses.ch: + v, ok := s.delayedAPIResponses.popOne() + if !ok { + continue + } + // Add it to the list, and if ends up being the head, set things up. + addDelayedResponse(&head, &tail, v) + if v == head { + next() + } + case <-s.quitCh: + return + case <-rqch: + // If we were the head, drop and setup things for next. + if r != nil && r == head { + pop() + } + next() + case <-tm.C: + if r != nil { + s.sendAPIErrResponse(r.ci, r.acc, r.subject, r.reply, r.request, r.response) + pop() + } + next() + } + } +} + +func (s *Server) sendDelayedAPIErrResponse(ci *ClientInfo, acc *Account, subject, reply, request, response string, rg *raftGroup, duration time.Duration) { + s.delayedAPIResponses.push(&delayedAPIResponse{ + ci, acc, subject, reply, request, response, rg, time.Now().Add(duration), nil, }) } @@ -1030,6 +1188,32 @@ func (s *Server) getRequestInfo(c *client, raw []byte) (pci *ClientInfo, acc *Ac return &ci, acc, hdr, msg, nil } +func (s *Server) unmarshalRequest(c *client, acc *Account, subject string, msg []byte, v any) error { + decoder := json.NewDecoder(bytes.NewReader(msg)) + decoder.DisallowUnknownFields() + + for { + if err := decoder.Decode(v); err != nil { + if err == io.EOF { + return nil + } + + var syntaxErr *json.SyntaxError + if errors.As(err, &syntaxErr) { + err = fmt.Errorf("%w at offset %d", err, syntaxErr.Offset) + } + + c.RateLimitWarnf("Invalid JetStream request '%s > %s': %s", acc, subject, err) + + if s.JetStreamConfig().Strict { + return err + } + + return json.Unmarshal(msg, v) + } + } +} + func (a *Account) trackAPI() { a.mu.RLock() jsa := a.js @@ -1159,8 +1343,8 @@ func (s *Server) jsTemplateCreateRequest(sub *subscription, c *client, _ *Accoun } var cfg StreamTemplateConfig - if err := 
json.Unmarshal(msg, &cfg); err != nil { - resp.Error = NewJSInvalidJSONError() + if err := s.unmarshalRequest(c, acc, subject, msg, &cfg); err != nil { + resp.Error = NewJSInvalidJSONError(err) s.sendAPIErrResponse(ci, acc, subject, reply, string(msg), s.jsonResponse(&resp)) return } @@ -1214,10 +1398,10 @@ func (s *Server) jsTemplateNamesRequest(sub *subscription, c *client, _ *Account } var offset int - if !isEmptyRequest(msg) { + if isJSONObjectOrArray(msg) { var req JSApiStreamTemplatesRequest - if err := json.Unmarshal(msg, &req); err != nil { - resp.Error = NewJSInvalidJSONError() + if err := s.unmarshalRequest(c, acc, subject, msg, &req); err != nil { + resp.Error = NewJSInvalidJSONError(err) s.sendAPIErrResponse(ci, acc, subject, reply, string(msg), s.jsonResponse(&resp)) return } @@ -1398,13 +1582,16 @@ func (s *Server) jsStreamCreateRequest(sub *subscription, c *client, _ *Account, return } - var cfg StreamConfig - if err := json.Unmarshal(msg, &cfg); err != nil { - resp.Error = NewJSInvalidJSONError() + var cfg StreamConfigRequest + if err := s.unmarshalRequest(c, acc, subject, msg, &cfg); err != nil { + resp.Error = NewJSInvalidJSONError(err) s.sendAPIErrResponse(ci, acc, subject, reply, string(msg), s.jsonResponse(&resp)) return } + // Initialize asset version metadata. 
+ setStaticStreamMetadata(&cfg.StreamConfig) + streamName := streamNameFromSubject(subject) if streamName != cfg.Name { resp.Error = NewJSStreamMismatchError() @@ -1439,13 +1626,13 @@ func (s *Server) jsStreamCreateRequest(sub *subscription, c *client, _ *Account, return } - if err := acc.jsNonClusteredStreamLimitsCheck(&cfg); err != nil { + if err := acc.jsNonClusteredStreamLimitsCheck(&cfg.StreamConfig); err != nil { resp.Error = err s.sendAPIErrResponse(ci, acc, subject, reply, string(msg), s.jsonResponse(&resp)) return } - mset, err := acc.addStream(&cfg) + mset, err := acc.addStreamPedantic(&cfg.StreamConfig, cfg.Pedantic) if err != nil { if IsNatsErr(err, JSStreamStoreFailedF) { s.Warnf("Stream create failed for '%s > %s': %v", acc, streamName, err) @@ -1455,10 +1642,11 @@ func (s *Server) jsStreamCreateRequest(sub *subscription, c *client, _ *Account, s.sendAPIErrResponse(ci, acc, subject, reply, string(msg), s.jsonResponse(&resp)) return } + msetCfg := mset.config() resp.StreamInfo = &StreamInfo{ Created: mset.createdTime(), State: mset.state(), - Config: mset.config(), + Config: *setDynamicStreamMetadata(&msetCfg), TimeStamp: time.Now().UTC(), Mirror: mset.mirrorInfo(), Sources: mset.sourcesInfo(), @@ -1505,14 +1693,14 @@ func (s *Server) jsStreamUpdateRequest(sub *subscription, c *client, _ *Account, } return } - var ncfg StreamConfig - if err := json.Unmarshal(msg, &ncfg); err != nil { - resp.Error = NewJSInvalidJSONError() + var ncfg StreamConfigRequest + if err := s.unmarshalRequest(c, acc, subject, msg, &ncfg); err != nil { + resp.Error = NewJSInvalidJSONError(err) s.sendAPIErrResponse(ci, acc, subject, reply, string(msg), s.jsonResponse(&resp)) return } - cfg, apiErr := s.checkStreamCfg(&ncfg, acc) + cfg, apiErr := s.checkStreamCfg(&ncfg.StreamConfig, acc, ncfg.Pedantic) if apiErr != nil { resp.Error = apiErr s.sendAPIErrResponse(ci, acc, subject, reply, string(msg), s.jsonResponse(&resp)) @@ -1529,7 +1717,7 @@ func (s *Server) 
jsStreamUpdateRequest(sub *subscription, c *client, _ *Account, // Handle clustered version here. if s.JetStreamIsClustered() { // Always do in separate Go routine. - go s.jsClusteredStreamUpdateRequest(ci, acc, subject, reply, copyBytes(rmsg), &cfg, nil) + go s.jsClusteredStreamUpdateRequest(ci, acc, subject, reply, copyBytes(rmsg), &cfg, nil, ncfg.Pedantic) return } @@ -1540,16 +1728,20 @@ func (s *Server) jsStreamUpdateRequest(sub *subscription, c *client, _ *Account, return } - if err := mset.update(&cfg); err != nil { + // Update asset version metadata. + setStaticStreamMetadata(&cfg) + + if err := mset.updatePedantic(&cfg, ncfg.Pedantic); err != nil { resp.Error = NewJSStreamUpdateError(err, Unless(err)) s.sendAPIErrResponse(ci, acc, subject, reply, string(msg), s.jsonResponse(&resp)) return } + msetCfg := mset.config() resp.StreamInfo = &StreamInfo{ Created: mset.createdTime(), State: mset.state(), - Config: mset.config(), + Config: *setDynamicStreamMetadata(&msetCfg), Domain: s.getOpts().JetStreamDomain, Mirror: mset.mirrorInfo(), Sources: mset.sourcesInfo(), @@ -1599,10 +1791,10 @@ func (s *Server) jsStreamNamesRequest(sub *subscription, c *client, _ *Account, var offset int var filter string - if !isEmptyRequest(msg) { + if isJSONObjectOrArray(msg) { var req JSApiStreamNamesRequest - if err := json.Unmarshal(msg, &req); err != nil { - resp.Error = NewJSInvalidJSONError() + if err := s.unmarshalRequest(c, acc, subject, msg, &req); err != nil { + resp.Error = NewJSInvalidJSONError(err) s.sendAPIErrResponse(ci, acc, subject, reply, string(msg), s.jsonResponse(&resp)) return } @@ -1729,10 +1921,10 @@ func (s *Server) jsStreamListRequest(sub *subscription, c *client, _ *Account, s var offset int var filter string - if !isEmptyRequest(msg) { + if isJSONObjectOrArray(msg) { var req JSApiStreamListRequest - if err := json.Unmarshal(msg, &req); err != nil { - resp.Error = NewJSInvalidJSONError() + if err := s.unmarshalRequest(c, acc, subject, msg, &req); err != 
nil { + resp.Error = NewJSInvalidJSONError(err) s.sendAPIErrResponse(ci, acc, subject, reply, string(msg), s.jsonResponse(&resp)) return } @@ -1843,12 +2035,12 @@ func (s *Server) jsStreamInfoRequest(sub *subscription, c *client, a *Account, s if js.isLeaderless() { resp.Error = NewJSClusterNotAvailError() // Delaying an error response gives the leader a chance to respond before us - s.sendDelayedAPIErrResponse(ci, acc, subject, reply, string(msg), s.jsonResponse(&resp), nil) + s.sendDelayedAPIErrResponse(ci, acc, subject, reply, string(msg), s.jsonResponse(&resp), nil, errRespDelay) } return } else if isLeader && offline { resp.Error = NewJSStreamOfflineError() - s.sendDelayedAPIErrResponse(ci, acc, subject, reply, string(msg), s.jsonResponse(&resp), nil) + s.sendDelayedAPIErrResponse(ci, acc, subject, reply, string(msg), s.jsonResponse(&resp), nil, errRespDelay) return } @@ -1860,7 +2052,7 @@ func (s *Server) jsStreamInfoRequest(sub *subscription, c *client, a *Account, s if js.isLeaderless() { resp.Error = NewJSClusterNotAvailError() // Delaying an error response gives the leader a chance to respond before us - s.sendDelayedAPIErrResponse(ci, acc, subject, reply, string(msg), s.jsonResponse(&resp), sa.Group) + s.sendDelayedAPIErrResponse(ci, acc, subject, reply, string(msg), s.jsonResponse(&resp), sa.Group, errRespDelay) return } @@ -1900,10 +2092,10 @@ func (s *Server) jsStreamInfoRequest(sub *subscription, c *client, a *Account, s var details bool var subjects string var offset int - if !isEmptyRequest(msg) { + if isJSONObjectOrArray(msg) { var req JSApiStreamInfoRequest - if err := json.Unmarshal(msg, &req); err != nil { - resp.Error = NewJSInvalidJSONError() + if err := s.unmarshalRequest(c, acc, subject, msg, &req); err != nil { + resp.Error = NewJSInvalidJSONError(err) s.sendAPIErrResponse(ci, acc, subject, reply, string(msg), s.jsonResponse(&resp)) return } @@ -1929,11 +2121,10 @@ func (s *Server) jsStreamInfoRequest(sub *subscription, c *client, a 
*Account, s } config := mset.config() - resp.StreamInfo = &StreamInfo{ Created: mset.createdTime(), State: mset.stateWithDetail(details), - Config: config, + Config: *setDynamicStreamMetadata(&config), Domain: s.getOpts().JetStreamDomain, Cluster: js.clusterInfo(mset.raftGroup()), Mirror: mset.mirrorInfo(), @@ -2047,11 +2238,6 @@ func (s *Server) jsStreamLeaderStepDownRequest(sub *subscription, c *client, _ * } return } - if !isEmptyRequest(msg) { - resp.Error = NewJSBadRequestError() - s.sendAPIErrResponse(ci, acc, subject, reply, string(msg), s.jsonResponse(&resp)) - return - } // Check to see if we are a member of the group and if the group has no leader. if js.isGroupLeaderless(sa.Group) { @@ -2078,18 +2264,35 @@ func (s *Server) jsStreamLeaderStepDownRequest(sub *subscription, c *client, _ * return } - // Call actual stepdown. Do this in a Go routine. - go func() { - if node := mset.raftNode(); node != nil { - mset.setLeader(false) - // TODO (mh) eventually make sure all go routines exited and all channels are cleared - time.Sleep(250 * time.Millisecond) - node.StepDown() - } - + node := mset.raftNode() + if node == nil { resp.Success = true s.sendAPIResponse(ci, acc, subject, reply, string(msg), s.jsonResponse(resp)) - }() + return + } + + var preferredLeader string + if isJSONObjectOrArray(msg) { + var req JSApiLeaderStepdownRequest + if err := s.unmarshalRequest(c, acc, subject, msg, &req); err != nil { + resp.Error = NewJSInvalidJSONError(err) + s.sendAPIErrResponse(ci, acc, subject, reply, string(msg), s.jsonResponse(&resp)) + return + } + if preferredLeader, resp.Error = s.getStepDownPreferredPlacement(node, req.Placement); resp.Error != nil { + s.sendAPIErrResponse(ci, acc, subject, reply, string(msg), s.jsonResponse(&resp)) + return + } + } + + // Call actual stepdown. 
+ err = node.StepDown(preferredLeader) + if err != nil { + resp.Error = NewJSRaftGeneralError(err, Unless(err)) + } else { + resp.Success = true + } + s.sendAPIResponse(ci, acc, subject, reply, string(msg), s.jsonResponse(resp)) } // Request to have a consumer leader stepdown. @@ -2165,11 +2368,6 @@ func (s *Server) jsConsumerLeaderStepDownRequest(sub *subscription, c *client, _ } return } - if !isEmptyRequest(msg) { - resp.Error = NewJSBadRequestError() - s.sendAPIErrResponse(ci, acc, subject, reply, string(msg), s.jsonResponse(&resp)) - return - } mset, err := acc.lookupStream(stream) if err != nil { @@ -2191,16 +2389,28 @@ func (s *Server) jsConsumerLeaderStepDownRequest(sub *subscription, c *client, _ return } - // Call actual stepdown. Do this in a Go routine. - go func() { - o.setLeader(false) - // TODO (mh) eventually make sure all go routines exited and all channels are cleared - time.Sleep(250 * time.Millisecond) - n.StepDown() + var preferredLeader string + if isJSONObjectOrArray(msg) { + var req JSApiLeaderStepdownRequest + if err := s.unmarshalRequest(c, acc, subject, msg, &req); err != nil { + resp.Error = NewJSInvalidJSONError(err) + s.sendAPIErrResponse(ci, acc, subject, reply, string(msg), s.jsonResponse(&resp)) + return + } + if preferredLeader, resp.Error = s.getStepDownPreferredPlacement(n, req.Placement); resp.Error != nil { + s.sendAPIErrResponse(ci, acc, subject, reply, string(msg), s.jsonResponse(&resp)) + return + } + } + // Call actual stepdown. + err = n.StepDown(preferredLeader) + if err != nil { + resp.Error = NewJSRaftGeneralError(err, Unless(err)) + } else { resp.Success = true - s.sendAPIResponse(ci, acc, subject, reply, string(msg), s.jsonResponse(resp)) - }() + } + s.sendAPIResponse(ci, acc, subject, reply, string(msg), s.jsonResponse(resp)) } // Request to remove a peer from a clustered stream. 
@@ -2260,8 +2470,8 @@ func (s *Server) jsStreamRemovePeerRequest(sub *subscription, c *client, _ *Acco } var req JSApiStreamRemovePeerRequest - if err := json.Unmarshal(msg, &req); err != nil { - resp.Error = NewJSInvalidJSONError() + if err := s.unmarshalRequest(c, acc, subject, msg, &req); err != nil { + resp.Error = NewJSInvalidJSONError(err) s.sendAPIErrResponse(ci, acc, subject, reply, string(msg), s.jsonResponse(&resp)) return } @@ -2340,8 +2550,8 @@ func (s *Server) jsLeaderServerRemoveRequest(sub *subscription, c *client, _ *Ac } var req JSApiMetaServerRemoveRequest - if err := json.Unmarshal(msg, &req); err != nil { - resp.Error = NewJSInvalidJSONError() + if err := s.unmarshalRequest(c, acc, subject, msg, &req); err != nil { + resp.Error = NewJSInvalidJSONError(err) s.sendAPIErrResponse(ci, acc, subject, reply, string(msg), s.jsonResponse(&resp)) return } @@ -2443,8 +2653,8 @@ func (s *Server) jsLeaderServerStreamMoveRequest(sub *subscription, c *client, _ var resp = JSApiStreamUpdateResponse{ApiResponse: ApiResponse{Type: JSApiStreamUpdateResponseType}} var req JSApiMetaServerStreamMoveRequest - if err := json.Unmarshal(msg, &req); err != nil { - resp.Error = NewJSInvalidJSONError() + if err := s.unmarshalRequest(c, acc, subject, msg, &req); err != nil { + resp.Error = NewJSInvalidJSONError(err) s.sendAPIErrResponse(ci, acc, subject, reply, string(msg), s.jsonResponse(&resp)) return } @@ -2470,7 +2680,7 @@ func (s *Server) jsLeaderServerStreamMoveRequest(sub *subscription, c *client, _ if ok { sa, ok := streams[streamName] if ok { - cfg = *sa.Config + cfg = *sa.Config.clone() streamFound = true currPeers = sa.Group.Peers currCluster = sa.Group.Cluster @@ -2562,7 +2772,8 @@ func (s *Server) jsLeaderServerStreamMoveRequest(sub *subscription, c *client, _ accName, streamName, cfg.Replicas, s.peerSetToNames(currPeers), s.peerSetToNames(peers)) // We will always have peers and therefore never do a callout, therefore it is safe to call inline - 
s.jsClusteredStreamUpdateRequest(&ciNew, targetAcc.(*Account), subject, reply, rmsg, &cfg, peers) + // We should be fine ignoring pedantic mode here. as we do not touch configuration. + s.jsClusteredStreamUpdateRequest(&ciNew, targetAcc.(*Account), subject, reply, rmsg, &cfg, peers, false) } // Request to have the metaleader move a stream on a peer to another @@ -2611,7 +2822,7 @@ func (s *Server) jsLeaderServerStreamCancelMoveRequest(sub *subscription, c *cli if ok { sa, ok := streams[streamName] if ok { - cfg = *sa.Config + cfg = *sa.Config.clone() streamFound = true currPeers = sa.Group.Peers } @@ -2668,7 +2879,7 @@ func (s *Server) jsLeaderServerStreamCancelMoveRequest(sub *subscription, c *cli cfg.Replicas, accName, streamName, s.peerSetToNames(currPeers), s.peerSetToNames(peers)) // We will always have peers and therefore never do a callout, therefore it is safe to call inline - s.jsClusteredStreamUpdateRequest(&ciNew, targetAcc.(*Account), subject, reply, rmsg, &cfg, peers) + s.jsClusteredStreamUpdateRequest(&ciNew, targetAcc.(*Account), subject, reply, rmsg, &cfg, peers, false) } // Request to have an account purged @@ -2790,41 +3001,16 @@ func (s *Server) jsLeaderStepDownRequest(sub *subscription, c *client, _ *Accoun var preferredLeader string var resp = JSApiLeaderStepDownResponse{ApiResponse: ApiResponse{Type: JSApiLeaderStepDownResponseType}} - if !isEmptyRequest(msg) { + if isJSONObjectOrArray(msg) { var req JSApiLeaderStepdownRequest - if err := json.Unmarshal(msg, &req); err != nil { - resp.Error = NewJSInvalidJSONError() + if err := s.unmarshalRequest(c, acc, subject, msg, &req); err != nil { + resp.Error = NewJSInvalidJSONError(err) s.sendAPIErrResponse(ci, acc, subject, reply, string(msg), s.jsonResponse(&resp)) return } - if req.Placement != nil { - if len(req.Placement.Tags) > 0 { - // Tags currently not supported. 
- resp.Error = NewJSClusterTagsError() - s.sendAPIErrResponse(ci, acc, subject, reply, string(msg), s.jsonResponse(&resp)) - return - } - cn := req.Placement.Cluster - var peers []string - ourID := cc.meta.ID() - for _, p := range cc.meta.Peers() { - if si, ok := s.nodeToInfo.Load(p.ID); ok && si != nil { - if ni := si.(nodeInfo); ni.offline || ni.cluster != cn || p.ID == ourID { - continue - } - peers = append(peers, p.ID) - } - } - if len(peers) == 0 { - resp.Error = NewJSClusterNoPeersError(fmt.Errorf("no replacement peer connected")) - s.sendAPIErrResponse(ci, acc, subject, reply, string(msg), s.jsonResponse(&resp)) - return - } - // Randomize and select. - if len(peers) > 1 { - rand.Shuffle(len(peers), func(i, j int) { peers[i], peers[j] = peers[j], peers[i] }) - } - preferredLeader = peers[0] + if preferredLeader, resp.Error = s.getStepDownPreferredPlacement(cc.meta, req.Placement); resp.Error != nil { + s.sendAPIErrResponse(ci, acc, subject, reply, string(msg), s.jsonResponse(&resp)) + return } } @@ -2838,6 +3024,25 @@ func (s *Server) jsLeaderStepDownRequest(sub *subscription, c *client, _ *Accoun s.sendAPIResponse(ci, acc, subject, reply, string(msg), s.jsonResponse(resp)) } +// Check if given []bytes is a JSON Object or Array. +// Technically, valid JSON can also be a plain string or number, but for our use case, +// we care only for JSON objects or arrays which starts with `[` or `{`. +// This function does not have to ensure valid JSON in its entirety. It is used merely +// to hint the codepath if it should attempt to parse the request as JSON or not. 
+func isJSONObjectOrArray(req []byte) bool { + // Skip leading JSON whitespace (space, tab, newline, carriage return) + i := 0 + for i < len(req) && (req[i] == ' ' || req[i] == '\t' || req[i] == '\n' || req[i] == '\r') { + i++ + } + // Check for empty input after trimming + if i >= len(req) { + return false + } + // Check if the first non-whitespace character is '{' or '[' + return req[i] == '{' || req[i] == '[' +} + func isEmptyRequest(req []byte) bool { if len(req) == 0 { return true @@ -2857,6 +3062,84 @@ func isEmptyRequest(req []byte) bool { return len(vm) == 0 } +// getStepDownPreferredPlacement attempts to work out what the best placement is +// for a stepdown request. The preferred server name always takes precedence, but +// if not specified, the placement will be used to filter by cluster. The caller +// should check for return API errors and return those to the requestor if needed. +func (s *Server) getStepDownPreferredPlacement(group RaftNode, placement *Placement) (string, *ApiError) { + if placement == nil { + return _EMPTY_, nil + } + var preferredLeader string + if placement.Preferred != _EMPTY_ { + for _, p := range group.Peers() { + si, ok := s.nodeToInfo.Load(p.ID) + if !ok || si == nil { + continue + } + if si.(nodeInfo).name == placement.Preferred { + preferredLeader = p.ID + break + } + } + if preferredLeader == group.ID() { + return _EMPTY_, NewJSClusterNoPeersError(fmt.Errorf("preferred server %q is already leader", placement.Preferred)) + } + if preferredLeader == _EMPTY_ { + return _EMPTY_, NewJSClusterNoPeersError(fmt.Errorf("preferred server %q not known", placement.Preferred)) + } + } else { + possiblePeers := make(map[*Peer]nodeInfo, len(group.Peers())) + ourID := group.ID() + for _, p := range group.Peers() { + if p == nil { + continue // ... shouldn't happen. 
+ } + si, ok := s.nodeToInfo.Load(p.ID) + if !ok || si == nil { + continue + } + ni := si.(nodeInfo) + if ni.offline || p.ID == ourID { + continue + } + possiblePeers[p] = ni + } + // If cluster is specified, filter out anything not matching the cluster name. + if placement.Cluster != _EMPTY_ { + for p, si := range possiblePeers { + if si.cluster != placement.Cluster { + delete(possiblePeers, p) + } + } + } + // If tags are specified, filter out anything not matching all supplied tags. + if len(placement.Tags) > 0 { + for p, si := range possiblePeers { + matchesAll := true + for _, tag := range placement.Tags { + if matchesAll = matchesAll && si.tags.Contains(tag); !matchesAll { + break + } + } + if !matchesAll { + delete(possiblePeers, p) + } + } + } + // If there are no possible peers, return an error. + if len(possiblePeers) == 0 { + return _EMPTY_, NewJSClusterNoPeersError(fmt.Errorf("no replacement peer connected")) + } + // Take advantage of random map iteration order to select the preferred. + for p := range possiblePeers { + preferredLeader = p.ID + break + } + } + return preferredLeader, nil +} + // Request to delete a stream. 
func (s *Server) jsStreamDeleteRequest(sub *subscription, c *client, _ *Account, subject, reply string, rmsg []byte) { if c == nil || !s.JetStreamEnabled() { @@ -3000,8 +3283,8 @@ func (s *Server) jsMsgDeleteRequest(sub *subscription, c *client, _ *Account, su return } var req JSApiMsgDeleteRequest - if err := json.Unmarshal(msg, &req); err != nil { - resp.Error = NewJSInvalidJSONError() + if err := s.unmarshalRequest(c, acc, subject, msg, &req); err != nil { + resp.Error = NewJSInvalidJSONError(err) s.sendAPIErrResponse(ci, acc, subject, reply, string(msg), s.jsonResponse(&resp)) return } @@ -3119,20 +3402,26 @@ func (s *Server) jsMsgGetRequest(sub *subscription, c *client, _ *Account, subje return } var req JSApiMsgGetRequest - if err := json.Unmarshal(msg, &req); err != nil { - resp.Error = NewJSInvalidJSONError() + if err := s.unmarshalRequest(c, acc, subject, msg, &req); err != nil { + resp.Error = NewJSInvalidJSONError(err) s.sendAPIErrResponse(ci, acc, subject, reply, string(msg), s.jsonResponse(&resp)) return } - // Check that we do not have both options set. - if req.Seq > 0 && req.LastFor != _EMPTY_ || req.Seq == 0 && req.LastFor == _EMPTY_ && req.NextFor == _EMPTY_ { + // This version does not support batch. + if req.Batch > 0 || req.MaxBytes > 0 { resp.Error = NewJSBadRequestError() s.sendAPIErrResponse(ci, acc, subject, reply, string(msg), s.jsonResponse(&resp)) return } - // Check that both last and next not both set. - if req.LastFor != _EMPTY_ && req.NextFor != _EMPTY_ { + + // Validate non-conflicting options. Seq, LastFor, and AsOfTime are mutually exclusive. + // NextFor can be paired with Seq or AsOfTime indicating a filter subject. 
+ if (req.Seq > 0 && req.LastFor != _EMPTY_) || + (req.Seq == 0 && req.LastFor == _EMPTY_ && req.NextFor == _EMPTY_ && req.StartTime == nil) || + (req.Seq > 0 && req.StartTime != nil) || + (req.StartTime != nil && req.LastFor != _EMPTY_) || + (req.LastFor != _EMPTY_ && req.NextFor != _EMPTY_) { resp.Error = NewJSBadRequestError() s.sendAPIErrResponse(ci, acc, subject, reply, string(msg), s.jsonResponse(&resp)) return @@ -3148,10 +3437,18 @@ func (s *Server) jsMsgGetRequest(sub *subscription, c *client, _ *Account, subje var svp StoreMsg var sm *StoreMsg - if req.Seq > 0 && req.NextFor == _EMPTY_ { - sm, err = mset.store.LoadMsg(req.Seq, &svp) + // If AsOfTime is set, perform this first to get the sequence. + var seq uint64 + if req.StartTime != nil { + seq = mset.store.GetSeqFromTime(*req.StartTime) + } else { + seq = req.Seq + } + + if seq > 0 && req.NextFor == _EMPTY_ { + sm, err = mset.store.LoadMsg(seq, &svp) } else if req.NextFor != _EMPTY_ { - sm, _, err = mset.store.LoadNextMsg(req.NextFor, subjectHasWildcard(req.NextFor), req.Seq, &svp) + sm, _, err = mset.store.LoadNextMsg(req.NextFor, subjectHasWildcard(req.NextFor), seq, &svp) } else { sm, err = mset.store.LoadLastMsg(req.LastFor, &svp) } @@ -3172,6 +3469,122 @@ func (s *Server) jsMsgGetRequest(sub *subscription, c *client, _ *Account, subje s.sendInternalAccountMsg(nil, reply, s.jsonResponse(resp)) } +func (s *Server) jsConsumerUnpinRequest(sub *subscription, c *client, _ *Account, subject, reply string, rmsg []byte) { + if c == nil || !s.JetStreamEnabled() { + return + } + + ci, acc, _, msg, err := s.getRequestInfo(c, rmsg) + if err != nil { + s.Warnf(badAPIRequestT, msg) + return + } + + stream := streamNameFromSubject(subject) + consumer := consumerNameFromSubject(subject) + + var req JSApiConsumerUnpinRequest + var resp = JSApiConsumerUnpinResponse{ApiResponse: ApiResponse{Type: JSApiConsumerUnpinResponseType}} + + if err := json.Unmarshal(msg, &req); err != nil { + resp.Error = 
NewJSInvalidJSONError(err) + s.sendAPIErrResponse(ci, acc, subject, reply, string(msg), s.jsonResponse(&resp)) + return + } + + if req.Group == _EMPTY_ { + resp.Error = NewJSInvalidJSONError(errors.New("consumer group not specified")) + s.sendAPIErrResponse(ci, acc, subject, reply, string(msg), s.jsonResponse(&resp)) + return + } + + if !validGroupName.MatchString(req.Group) { + resp.Error = NewJSConsumerInvalidGroupNameError() + s.sendAPIErrResponse(ci, acc, subject, reply, string(msg), s.jsonResponse(&resp)) + return + } + if s.JetStreamIsClustered() { + // Check to make sure the stream is assigned. + js, cc := s.getJetStreamCluster() + if js == nil || cc == nil { + return + } + + // First check if the stream and consumer is there. + js.mu.RLock() + sa := js.streamAssignment(acc.Name, stream) + if sa == nil { + js.mu.RUnlock() + resp.Error = NewJSStreamNotFoundError(Unless(err)) + s.sendAPIErrResponse(ci, acc, subject, reply, string(msg), s.jsonResponse(&resp)) + return + } + + ca, ok := sa.consumers[consumer] + if !ok || ca == nil { + js.mu.RUnlock() + resp.Error = NewJSConsumerNotFoundError() + s.sendAPIErrResponse(ci, acc, subject, reply, string(msg), s.jsonResponse(&resp)) + return + } + js.mu.RUnlock() + + // Then check if we are the leader. 
+ mset, err := acc.lookupStream(stream) + if err != nil { + return + } + + o := mset.lookupConsumer(consumer) + if o == nil { + return + } + if !o.isLeader() { + return + } + } + + if hasJS, doErr := acc.checkJetStream(); !hasJS { + if doErr { + resp.Error = NewJSNotEnabledForAccountError() + s.sendAPIErrResponse(ci, acc, subject, reply, string(msg), s.jsonResponse(&resp)) + } + return + } + + mset, err := acc.lookupStream(stream) + if err != nil { + resp.Error = NewJSStreamNotFoundError() + s.sendAPIErrResponse(ci, acc, subject, reply, string(msg), s.jsonResponse(&resp)) + return + } + o := mset.lookupConsumer(consumer) + if o == nil { + resp.Error = NewJSConsumerNotFoundError() + s.sendAPIErrResponse(ci, acc, subject, reply, string(msg), s.jsonResponse(&resp)) + return + } + + var foundPriority bool + for _, group := range o.config().PriorityGroups { + if group == req.Group { + foundPriority = true + break + } + } + if !foundPriority { + resp.Error = NewJSConsumerInvalidPriorityGroupError() + s.sendAPIErrResponse(ci, acc, subject, reply, string(msg), s.jsonResponse(&resp)) + return + } + + o.mu.Lock() + o.currentPinId = _EMPTY_ + o.sendUnpinnedAdvisoryLocked(req.Group, "admin") + o.mu.Unlock() + s.sendAPIResponse(ci, acc, subject, reply, string(msg), s.jsonResponse(resp)) +} + // Request to purge a stream. 
func (s *Server) jsStreamPurgeRequest(sub *subscription, c *client, _ *Account, subject, reply string, rmsg []byte) { if c == nil || !s.JetStreamEnabled() { @@ -3246,10 +3659,10 @@ func (s *Server) jsStreamPurgeRequest(sub *subscription, c *client, _ *Account, } var purgeRequest *JSApiStreamPurgeRequest - if !isEmptyRequest(msg) { + if isJSONObjectOrArray(msg) { var req JSApiStreamPurgeRequest - if err := json.Unmarshal(msg, &req); err != nil { - resp.Error = NewJSInvalidJSONError() + if err := s.unmarshalRequest(c, acc, subject, msg, &req); err != nil { + resp.Error = NewJSInvalidJSONError(err) s.sendAPIErrResponse(ci, acc, subject, reply, string(msg), s.jsonResponse(&resp)) return } @@ -3338,8 +3751,8 @@ func (s *Server) jsStreamRestoreRequest(sub *subscription, c *client, _ *Account } var req JSApiStreamRestoreRequest - if err := json.Unmarshal(msg, &req); err != nil { - resp.Error = NewJSInvalidJSONError() + if err := s.unmarshalRequest(c, acc, subject, msg, &req); err != nil { + resp.Error = NewJSInvalidJSONError(err) s.sendAPIErrResponse(ci, acc, subject, reply, string(msg), s.jsonResponse(&resp)) return } @@ -3351,7 +3764,7 @@ func (s *Server) jsStreamRestoreRequest(sub *subscription, c *client, _ *Account } // check stream config at the start of the restore process, not at the end - cfg, apiErr := s.checkStreamCfg(&req.Config, acc) + cfg, apiErr := s.checkStreamCfg(&req.Config, acc, false) if apiErr != nil { resp.Error = apiErr s.sendAPIErrResponse(ci, acc, subject, reply, string(msg), s.jsonResponse(&resp)) @@ -3562,10 +3975,11 @@ func (s *Server) processStreamRestore(ci *ClientInfo, acc *Account, cfg *StreamC s.Warnf("Restore failed for %s for stream '%s > %s' in %v", friendlyBytes(int64(total)), acc.Name, streamName, end.Sub(start)) } else { + msetCfg := mset.config() resp.StreamInfo = &StreamInfo{ Created: mset.createdTime(), State: mset.state(), - Config: mset.config(), + Config: *setDynamicStreamMetadata(&msetCfg), TimeStamp: time.Now().UTC(), } 
s.Noticef("Completed restore of %s for stream '%s > %s' in %v", @@ -3640,8 +4054,8 @@ func (s *Server) jsStreamSnapshotRequest(sub *subscription, c *client, _ *Accoun } var req JSApiStreamSnapshotRequest - if err := json.Unmarshal(msg, &req); err != nil { - resp.Error = NewJSInvalidJSONError() + if err := s.unmarshalRequest(c, acc, subject, msg, &req); err != nil { + resp.Error = NewJSInvalidJSONError(err) s.sendAPIErrResponse(ci, acc, subject, reply, smsg, s.jsonResponse(&resp)) return } @@ -3806,6 +4220,11 @@ func (s *Server) streamSnapshot(acc *Account, mset *stream, sr *SnapshotResult, mset.outq.send(newJSPubMsg(reply, _EMPTY_, ackReply, nil, chunk, nil, 0)) atomic.AddInt32(&out, int32(len(chunk))) } + + if err := <-sr.errCh; err != _EMPTY_ { + hdr = []byte(fmt.Sprintf("NATS/1.0 500 %s\r\n\r\n", err)) + } + done: // Send last EOF // TODO(dlc) - place hash in header @@ -3838,8 +4257,8 @@ func (s *Server) jsConsumerCreateRequest(sub *subscription, c *client, a *Accoun var resp = JSApiConsumerCreateResponse{ApiResponse: ApiResponse{Type: JSApiConsumerCreateResponseType}} var req CreateConsumerRequest - if err := json.Unmarshal(msg, &req); err != nil { - resp.Error = NewJSInvalidJSONError() + if err := s.unmarshalRequest(c, acc, subject, msg, &req); err != nil { + resp.Error = NewJSInvalidJSONError(err) s.sendAPIErrResponse(ci, acc, subject, reply, string(msg), s.jsonResponse(&resp)) return } @@ -3984,9 +4403,9 @@ func (s *Server) jsConsumerCreateRequest(sub *subscription, c *client, a *Accoun // during this call, so place in Go routine to not block client. // Router and Gateway API calls already in separate context. 
if c.kind != ROUTER && c.kind != GATEWAY { - go s.jsClusteredConsumerRequest(ci, acc, subject, reply, rmsg, req.Stream, &req.Config, req.Action) + go s.jsClusteredConsumerRequest(ci, acc, subject, reply, rmsg, req.Stream, &req.Config, req.Action, req.Pedantic) } else { - s.jsClusteredConsumerRequest(ci, acc, subject, reply, rmsg, req.Stream, &req.Config, req.Action) + s.jsClusteredConsumerRequest(ci, acc, subject, reply, rmsg, req.Stream, &req.Config, req.Action, req.Pedantic) } return } @@ -4005,7 +4424,16 @@ func (s *Server) jsConsumerCreateRequest(sub *subscription, c *client, a *Accoun return } - o, err := stream.addConsumerWithAction(&req.Config, req.Action) + if o := stream.lookupConsumer(consumerName); o != nil { + // If the consumer already exists then don't allow updating the PauseUntil, just set + // it back to whatever the current configured value is. + req.Config.PauseUntil = o.cfg.PauseUntil + } + + // Initialize/update asset version metadata. + setStaticConsumerMetadata(&req.Config) + + o, err := stream.addConsumerWithAction(&req.Config, req.Action, req.Pedantic) if err != nil { if IsNatsErr(err, JSConsumerStoreFailedErrF) { @@ -4017,8 +4445,12 @@ func (s *Server) jsConsumerCreateRequest(sub *subscription, c *client, a *Accoun s.sendAPIErrResponse(ci, acc, subject, reply, string(msg), s.jsonResponse(&resp)) return } - resp.ConsumerInfo = o.initialInfo() + resp.ConsumerInfo = setDynamicConsumerInfoMetadata(o.initialInfo()) s.sendAPIResponse(ci, acc, subject, reply, string(msg), s.jsonResponse(resp)) + + if o.cfg.PauseUntil != nil && !o.cfg.PauseUntil.IsZero() && time.Now().Before(*o.cfg.PauseUntil) { + o.sendPauseAdvisoryLocked(&o.cfg) + } } // Request for the list of all consumer names. 
@@ -4063,10 +4495,10 @@ func (s *Server) jsConsumerNamesRequest(sub *subscription, c *client, _ *Account } var offset int - if !isEmptyRequest(msg) { + if isJSONObjectOrArray(msg) { var req JSApiConsumersRequest - if err := json.Unmarshal(msg, &req); err != nil { - resp.Error = NewJSInvalidJSONError() + if err := s.unmarshalRequest(c, acc, subject, msg, &req); err != nil { + resp.Error = NewJSInvalidJSONError(err) s.sendAPIErrResponse(ci, acc, subject, reply, string(msg), s.jsonResponse(&resp)) return } @@ -4185,10 +4617,10 @@ func (s *Server) jsConsumerListRequest(sub *subscription, c *client, _ *Account, } var offset int - if !isEmptyRequest(msg) { + if isJSONObjectOrArray(msg) { var req JSApiConsumersRequest - if err := json.Unmarshal(msg, &req); err != nil { - resp.Error = NewJSInvalidJSONError() + if err := s.unmarshalRequest(c, acc, subject, msg, &req); err != nil { + resp.Error = NewJSInvalidJSONError(err) s.sendAPIErrResponse(ci, acc, subject, reply, string(msg), s.jsonResponse(&resp)) return } @@ -4316,12 +4748,12 @@ func (s *Server) jsConsumerInfoRequest(sub *subscription, c *client, _ *Account, if isLeaderLess { resp.Error = NewJSClusterNotAvailError() // Delaying an error response gives the leader a chance to respond before us - s.sendDelayedAPIErrResponse(ci, acc, subject, reply, string(msg), s.jsonResponse(&resp), nil) + s.sendDelayedAPIErrResponse(ci, acc, subject, reply, string(msg), s.jsonResponse(&resp), nil, errRespDelay) } return } else if isLeader && offline { resp.Error = NewJSConsumerOfflineError() - s.sendDelayedAPIErrResponse(ci, acc, subject, reply, string(msg), s.jsonResponse(&resp), nil) + s.sendDelayedAPIErrResponse(ci, acc, subject, reply, string(msg), s.jsonResponse(&resp), nil, errRespDelay) return } @@ -4337,7 +4769,7 @@ func (s *Server) jsConsumerInfoRequest(sub *subscription, c *client, _ *Account, if isLeaderLess { resp.Error = NewJSClusterNotAvailError() // Delaying an error response gives the leader a chance to respond before 
us - s.sendDelayedAPIErrResponse(ci, acc, subject, reply, string(msg), s.jsonResponse(&resp), ca.Group) + s.sendDelayedAPIErrResponse(ci, acc, subject, reply, string(msg), s.jsonResponse(&resp), ca.Group, errRespDelay) return } @@ -4347,7 +4779,7 @@ func (s *Server) jsConsumerInfoRequest(sub *subscription, c *client, _ *Account, // We have a consumer assignment. if isMember { js.mu.RLock() - if rg.node != nil { + if rg != nil && rg.node != nil { node = rg.node if gl := node.GroupLeader(); gl != _EMPTY_ && !rg.isMember(gl) { leaderNotPartOfGroup = true @@ -4367,7 +4799,7 @@ func (s *Server) jsConsumerInfoRequest(sub *subscription, c *client, _ *Account, Stream: ca.Stream, Name: ca.Name, Created: ca.Created, - Config: ca.Config, + Config: setDynamicConsumerMetadata(ca.Config), TimeStamp: time.Now().UTC(), } b := s.jsonResponse(resp) @@ -4377,14 +4809,14 @@ func (s *Server) jsConsumerInfoRequest(sub *subscription, c *client, _ *Account, return } // If we are a member and we have a group leader or we had a previous leader consider bailing out. - if !node.Leaderless() || node.HadPreviousLeader() { + if !node.Leaderless() || node.HadPreviousLeader() || (rg != nil && rg.Preferred != _EMPTY_ && rg.Preferred != ourID) { if leaderNotPartOfGroup { resp.Error = NewJSConsumerOfflineError() - s.sendDelayedAPIErrResponse(ci, acc, subject, reply, string(msg), s.jsonResponse(&resp), nil) + s.sendDelayedAPIErrResponse(ci, acc, subject, reply, string(msg), s.jsonResponse(&resp), nil, errRespDelay) } return } - // If we are here we are a member and this is just a new consumer that does not have a leader yet. + // If we are here we are a member and this is just a new consumer that does not have a (preferred) leader yet. // Will fall through and return what we have. All consumers can respond but this should be very rare // but makes more sense to clients when they try to create, get a consumer exists, and then do consumer info. 
} @@ -4410,7 +4842,7 @@ func (s *Server) jsConsumerInfoRequest(sub *subscription, c *client, _ *Account, return } - if resp.ConsumerInfo = obs.info(); resp.ConsumerInfo == nil { + if resp.ConsumerInfo = setDynamicConsumerInfoMetadata(obs.info()); resp.ConsumerInfo == nil { // This consumer returned nil which means it's closed. Respond with not found. resp.Error = NewJSConsumerNotFoundError() s.sendAPIErrResponse(ci, acc, subject, reply, string(msg), s.jsonResponse(&resp)) @@ -4491,7 +4923,142 @@ func (s *Server) jsConsumerDeleteRequest(sub *subscription, c *client, _ *Accoun s.sendAPIResponse(ci, acc, subject, reply, string(msg), s.jsonResponse(resp)) } -// sendJetStreamAPIAuditAdvisor will send the audit event for a given event. +// Request to pause or unpause a Consumer. +func (s *Server) jsConsumerPauseRequest(sub *subscription, c *client, _ *Account, subject, reply string, rmsg []byte) { + if c == nil || !s.JetStreamEnabled() { + return + } + ci, acc, _, msg, err := s.getRequestInfo(c, rmsg) + if err != nil { + s.Warnf(badAPIRequestT, msg) + return + } + + var req JSApiConsumerPauseRequest + var resp = JSApiConsumerPauseResponse{ApiResponse: ApiResponse{Type: JSApiConsumerPauseResponseType}} + + if isJSONObjectOrArray(msg) { + if err := s.unmarshalRequest(c, acc, subject, msg, &req); err != nil { + resp.Error = NewJSInvalidJSONError(err) + s.sendAPIErrResponse(ci, acc, subject, reply, string(msg), s.jsonResponse(&resp)) + return + } + } + + // Determine if we should proceed here when we are in clustered mode. + isClustered := s.JetStreamIsClustered() + js, cc := s.getJetStreamCluster() + if isClustered { + if js == nil || cc == nil { + return + } + if js.isLeaderless() { + resp.Error = NewJSClusterNotAvailError() + s.sendAPIErrResponse(ci, acc, subject, reply, string(msg), s.jsonResponse(&resp)) + return + } + // Make sure we are meta leader. 
+ if !s.JetStreamIsLeader() { + return + } + } + + if hasJS, doErr := acc.checkJetStream(); !hasJS { + if doErr { + resp.Error = NewJSNotEnabledForAccountError() + s.sendAPIErrResponse(ci, acc, subject, reply, string(msg), s.jsonResponse(&resp)) + } + return + } + + stream := streamNameFromSubject(subject) + consumer := consumerNameFromSubject(subject) + + if isClustered { + js.mu.RLock() + sa := js.streamAssignment(acc.Name, stream) + if sa == nil { + js.mu.RUnlock() + resp.Error = NewJSStreamNotFoundError(Unless(err)) + s.sendAPIErrResponse(ci, acc, subject, reply, string(msg), s.jsonResponse(&resp)) + return + } + + ca, ok := sa.consumers[consumer] + if !ok || ca == nil { + js.mu.RUnlock() + resp.Error = NewJSConsumerNotFoundError() + s.sendAPIErrResponse(ci, acc, subject, reply, string(msg), s.jsonResponse(&resp)) + return + } + + nca := *ca + ncfg := *ca.Config + nca.Config = &ncfg + js.mu.RUnlock() + pauseUTC := req.PauseUntil.UTC() + if !pauseUTC.IsZero() { + nca.Config.PauseUntil = &pauseUTC + } else { + nca.Config.PauseUntil = nil + } + + // Update asset version metadata due to updating pause/resume. + // Only PauseUntil is updated above, so reuse config for both. 
+ setStaticConsumerMetadata(nca.Config) + + eca := encodeAddConsumerAssignment(&nca) + cc.meta.Propose(eca) + + resp.PauseUntil = pauseUTC + if resp.Paused = time.Now().Before(pauseUTC); resp.Paused { + resp.PauseRemaining = time.Until(pauseUTC) + } + s.sendAPIResponse(ci, acc, subject, reply, string(msg), s.jsonResponse(resp)) + return + } + + mset, err := acc.lookupStream(stream) + if err != nil { + resp.Error = NewJSStreamNotFoundError(Unless(err)) + s.sendAPIErrResponse(ci, acc, subject, reply, string(msg), s.jsonResponse(&resp)) + return + } + + obs := mset.lookupConsumer(consumer) + if obs == nil { + resp.Error = NewJSConsumerNotFoundError() + s.sendAPIErrResponse(ci, acc, subject, reply, string(msg), s.jsonResponse(&resp)) + return + } + + ncfg := obs.cfg + pauseUTC := req.PauseUntil.UTC() + if !pauseUTC.IsZero() { + ncfg.PauseUntil = &pauseUTC + } else { + ncfg.PauseUntil = nil + } + + // Update asset version metadata due to updating pause/resume. + setStaticConsumerMetadata(&ncfg) + + if err := obs.updateConfig(&ncfg); err != nil { + // The only type of error that should be returned here is from o.store, + // so use a store failed error type. + resp.Error = NewJSConsumerStoreFailedError(err) + s.sendAPIErrResponse(ci, acc, subject, reply, string(msg), s.jsonResponse(&resp)) + return + } + + resp.PauseUntil = pauseUTC + if resp.Paused = time.Now().Before(pauseUTC); resp.Paused { + resp.PauseRemaining = time.Until(pauseUTC) + } + s.sendAPIResponse(ci, acc, subject, reply, string(msg), s.jsonResponse(resp)) +} + +// sendJetStreamAPIAuditAdvisory will send the audit event for a given event. 
func (s *Server) sendJetStreamAPIAuditAdvisory(ci *ClientInfo, acc *Account, subject, request, response string) { s.publishAdvisory(acc, JSAuditAdvisory, JSAPIAudit{ TypedEvent: TypedEvent{ diff --git a/vendor/github.com/nats-io/nats-server/v2/server/jetstream_cluster.go b/vendor/github.com/nats-io/nats-server/v2/server/jetstream_cluster.go index c1d7d255d9..83100e2379 100644 --- a/vendor/github.com/nats-io/nats-server/v2/server/jetstream_cluster.go +++ b/vendor/github.com/nats-io/nats-server/v2/server/jetstream_cluster.go @@ -79,8 +79,9 @@ type inflightInfo struct { // Used to guide placement of streams and meta controllers in clustered JetStream. type Placement struct { - Cluster string `json:"cluster,omitempty"` - Tags []string `json:"tags,omitempty"` + Cluster string `json:"cluster,omitempty"` + Tags []string `json:"tags,omitempty"` + Preferred string `json:"preferred,omitempty"` } // Define types of the entry. @@ -295,7 +296,7 @@ func (s *Server) JetStreamStepdownStream(account, stream string) error { return err } - if node := mset.raftNode(); node != nil && node.Leader() { + if node := mset.raftNode(); node != nil { node.StepDown() } @@ -326,7 +327,7 @@ func (s *Server) JetStreamStepdownConsumer(account, stream, consumer string) err return NewJSConsumerNotFoundError() } - if node := o.raftNode(); node != nil && node.Leader() { + if node := o.raftNode(); node != nil { node.StepDown() } @@ -1137,65 +1138,6 @@ func (js *jetStream) checkForOrphans() { } } -// Check and delete any orphans we may come across. -func (s *Server) checkForNRGOrphans() { - js, cc := s.getJetStreamCluster() - if js == nil || cc == nil || js.isMetaRecovering() { - // No cluster means no NRGs. Also return if still recovering. - return - } - - // Track which assets R>1 should be on this server. - nrgMap := make(map[string]struct{}) - trackGroup := func(rg *raftGroup) { - // If R>1 track this as a legit NRG. - if rg.node != nil { - nrgMap[rg.Name] = struct{}{} - } - } - // Register our meta. 
- js.mu.RLock() - meta := cc.meta - if meta == nil { - js.mu.RUnlock() - // Bail with no meta node. - return - } - - ourID := meta.ID() - nrgMap[meta.Group()] = struct{}{} - - // Collect all valid groups from our assignments. - for _, asa := range cc.streams { - for _, sa := range asa { - if sa.Group.isMember(ourID) && sa.Restore == nil { - trackGroup(sa.Group) - for _, ca := range sa.consumers { - if ca.Group.isMember(ourID) { - trackGroup(ca.Group) - } - } - } - } - } - js.mu.RUnlock() - - // Check NRGs that are running. - var needDelete []RaftNode - s.rnMu.RLock() - for name, n := range s.raftNodes { - if _, ok := nrgMap[name]; !ok { - needDelete = append(needDelete, n) - } - } - s.rnMu.RUnlock() - - for _, n := range needDelete { - s.Warnf("Detected orphaned NRG %q, will cleanup", n.Group()) - n.Delete() - } -} - func (js *jetStream) monitorCluster() { s, n := js.server(), js.getMetaGroup() qch, rqch, lch, aq := js.clusterQuitC(), n.QuitC(), n.LeadChangeC(), n.ApplyQ() @@ -1228,8 +1170,6 @@ func (js *jetStream) monitorCluster() { if hs := s.healthz(nil); hs.Error != _EMPTY_ { s.Warnf("%v", hs.Error) } - // Also check for orphaned NRGs. - s.checkForNRGOrphans() } var ( @@ -2204,15 +2144,6 @@ func (mset *stream) removeNode() { } } -func (mset *stream) clearRaftNode() { - if mset == nil { - return - } - mset.mu.Lock() - defer mset.mu.Unlock() - mset.node = nil -} - // Helper function to generate peer info. // lists and sets for old and new. func genPeerInfo(peers []string, split int) (newPeers, oldPeers []string, newPeerSet, oldPeerSet map[string]bool) { @@ -2330,9 +2261,7 @@ func (js *jetStream) monitorStream(mset *stream, sa *streamAssignment, sendSnaps if n.State() == Closed { return } - if n.Leader() { - n.StepDown() - } + n.StepDown() // Drain the commit queue... 
aq.drain() }() @@ -2443,8 +2372,7 @@ func (js *jetStream) monitorStream(mset *stream, sa *streamAssignment, sendSnaps var cist *time.Ticker var cistc <-chan time.Time - // 2 minutes plus up to 30s jitter. - checkInterestInterval := 2*time.Minute + time.Duration(rand.Intn(30))*time.Second + checkInterestInterval := checkInterestStateT + time.Duration(rand.Intn(checkInterestStateJ))*time.Second if mset != nil && mset.isInterestRetention() { // Wait on our consumers to be assigned and running before proceeding. @@ -2867,11 +2795,11 @@ func (mset *stream) isMigrating() bool { func (mset *stream) resetClusteredState(err error) bool { mset.mu.RLock() s, js, jsa, sa, acc, node := mset.srv, mset.js, mset.jsa, mset.sa, mset.acc, mset.node - stype, isLeader, tierName, replicas := mset.cfg.Storage, mset.isLeader(), mset.tier, mset.cfg.Replicas + stype, tierName, replicas := mset.cfg.Storage, mset.tier, mset.cfg.Replicas mset.mu.RUnlock() // Stepdown regardless if we are the leader here. - if isLeader && node != nil { + if node != nil { node.StepDown() } @@ -2976,7 +2904,7 @@ func (js *jetStream) applyStreamEntries(mset *stream, ce *CommittedEntry, isReco } } - subject, reply, hdr, msg, lseq, ts, err := decodeStreamMsg(mbuf) + subject, reply, hdr, msg, lseq, ts, sourced, err := decodeStreamMsg(mbuf) if err != nil { if node := mset.raftNode(); node != nil { s.Errorf("JetStream cluster could not decode stream msg for '%s > %s' [%s]", @@ -3025,8 +2953,15 @@ func (js *jetStream) applyStreamEntries(mset *stream, ce *CommittedEntry, isReco continue } + var mt *msgTrace + // If not recovering, see if we find a message trace object for this + // sequence. Only the leader that has proposed this entry will have + // stored the trace info. + if !isRecovering { + mt = mset.getAndDeleteMsgTrace(lseq) + } // Process the actual message here. 
- err = mset.processJetStreamMsg(subject, reply, hdr, msg, lseq, ts) + err = mset.processJetStreamMsg(subject, reply, hdr, msg, lseq, ts, mt, sourced) // If we have inflight make sure to clear after processing. // TODO(dlc) - technically check on inflight != nil could cause datarace. @@ -3037,6 +2972,16 @@ func (js *jetStream) applyStreamEntries(mset *stream, ce *CommittedEntry, isReco mset.clMu.Unlock() } + // Clear expected per subject state after processing. + if mset.expectedPerSubjectSequence != nil { + mset.clMu.Lock() + if subj, found := mset.expectedPerSubjectSequence[lseq]; found { + delete(mset.expectedPerSubjectSequence, lseq) + delete(mset.expectedPerSubjectInProcess, subj) + } + mset.clMu.Unlock() + } + if err != nil { if err == errLastSeqMismatch { @@ -3049,7 +2994,7 @@ func (js *jetStream) applyStreamEntries(mset *stream, ce *CommittedEntry, isReco if state.Msgs == 0 { mset.store.Compact(lseq + 1) // Retry - err = mset.processJetStreamMsg(subject, reply, hdr, msg, lseq, ts) + err = mset.processJetStreamMsg(subject, reply, hdr, msg, lseq, ts, mt, sourced) } // FIXME(dlc) - We could just run a catchup with a request defining the span between what we expected // and what we got. @@ -3203,7 +3148,7 @@ func (js *jetStream) applyStreamEntries(mset *stream, ce *CommittedEntry, isReco } if isRecovering || !mset.IsLeader() { - if err := mset.processSnapshot(ss); err != nil { + if err := mset.processSnapshot(ss, ce.Index); err != nil { return err } } @@ -3257,9 +3202,34 @@ func (js *jetStream) processStreamLeaderChange(mset *stream, isLeader bool) { return } - // Clear inflight if we have it. + // Clear inflight dedupe IDs, where seq=0. 
+ mset.mu.Lock() + var removed int + for i := len(mset.ddarr) - 1; i >= mset.ddindex; i-- { + dde := mset.ddarr[i] + if dde.seq != 0 { + break + } + removed++ + delete(mset.ddmap, dde.id) + } + if removed > 0 { + if len(mset.ddmap) > 0 { + mset.ddarr = mset.ddarr[:len(mset.ddarr)-removed] + } else { + mset.ddmap = nil + mset.ddarr = nil + mset.ddindex = 0 + } + } + mset.mu.Unlock() + mset.clMu.Lock() + // Clear inflight if we have it. mset.inflight = nil + // Clear expected per subject state. + mset.expectedPerSubjectSequence = nil + mset.expectedPerSubjectInProcess = nil mset.clMu.Unlock() js.mu.Lock() @@ -3267,7 +3237,6 @@ func (js *jetStream) processStreamLeaderChange(mset *stream, isLeader bool) { client, subject, reply := sa.Client, sa.Subject, sa.Reply hasResponded := sa.responded sa.responded = true - peers := copyStrings(sa.Group.Peers) js.mu.Unlock() streamName := mset.name() @@ -3275,7 +3244,6 @@ func (js *jetStream) processStreamLeaderChange(mset *stream, isLeader bool) { if isLeader { s.Noticef("JetStream cluster new stream leader for '%s > %s'", account, streamName) s.sendStreamLeaderElectAdvisory(mset) - mset.checkAllowMsgCompress(peers) } else { // We are stepping down. // Make sure if we are doing so because we have lost quorum that we send the appropriate advisories. @@ -3310,10 +3278,11 @@ func (js *jetStream) processStreamLeaderChange(mset *stream, isLeader bool) { resp.Error = NewJSStreamCreateError(err, Unless(err)) s.sendAPIErrResponse(client, acc, subject, reply, _EMPTY_, s.jsonResponse(&resp)) } else { + msetCfg := mset.config() resp.StreamInfo = &StreamInfo{ Created: mset.createdTime(), State: mset.state(), - Config: mset.config(), + Config: *setDynamicStreamMetadata(&msetCfg), Cluster: js.clusterInfo(mset.raftGroup()), Sources: mset.sourcesInfo(), Mirror: mset.mirrorInfo(), @@ -3591,9 +3560,7 @@ func (s *Server) removeStream(mset *stream, nsa *streamAssignment) { // Make sure to use the new stream assignment, not our own. 
s.Debugf("JetStream removing stream '%s > %s' from this server", nsa.Client.serviceAccount(), nsa.Config.Name) if node := mset.raftNode(); node != nil { - if node.Leader() { - node.StepDown(nsa.Group.Preferred) - } + node.StepDown(nsa.Group.Preferred) // shutdown monitor by shutting down raft. node.Delete() } @@ -3689,7 +3656,7 @@ func (js *jetStream) processClusterUpdateStream(acc *Account, osa, sa *streamAss mset.setStreamAssignment(sa) // Call update. - if err = mset.updateWithAdvisory(cfg, !recovering); err != nil { + if err = mset.updateWithAdvisory(cfg, !recovering, false); err != nil { s.Warnf("JetStream cluster error updating stream %q for account %q: %v", cfg.Name, acc.Name, err) } } @@ -3739,10 +3706,11 @@ func (js *jetStream) processClusterUpdateStream(acc *Account, osa, sa *streamAss // Send our response. var resp = JSApiStreamUpdateResponse{ApiResponse: ApiResponse{Type: JSApiStreamUpdateResponseType}} + msetCfg := mset.config() resp.StreamInfo = &StreamInfo{ Created: mset.createdTime(), State: mset.state(), - Config: mset.config(), + Config: *setDynamicStreamMetadata(&msetCfg), Cluster: js.clusterInfo(mset.raftGroup()), Mirror: mset.mirrorInfo(), Sources: mset.sourcesInfo(), @@ -3806,10 +3774,11 @@ func (js *jetStream) processClusterCreateStream(acc *Account, sa *streamAssignme if !recovering { var resp = JSApiStreamCreateResponse{ApiResponse: ApiResponse{Type: JSApiStreamCreateResponseType}} + msetCfg := mset.config() resp.StreamInfo = &StreamInfo{ Created: mset.createdTime(), State: mset.state(), - Config: mset.config(), + Config: *setDynamicStreamMetadata(&msetCfg), Cluster: js.clusterInfo(mset.raftGroup()), Sources: mset.sourcesInfo(), Mirror: mset.mirrorInfo(), @@ -3835,7 +3804,7 @@ func (js *jetStream) processClusterCreateStream(acc *Account, sa *streamAssignme // Check if our config has really been updated. 
cfg := mset.config() if !reflect.DeepEqual(&cfg, sa.Config) { - if err = mset.updateWithAdvisory(sa.Config, false); err != nil { + if err = mset.updateWithAdvisory(sa.Config, false, false); err != nil { s.Warnf("JetStream cluster error updating stream %q for account %q: %v", sa.Config.Name, acc.Name, err) if osa != nil { // Process the raft group and make sure it's running if needed. @@ -3854,7 +3823,7 @@ func (js *jetStream) processClusterCreateStream(acc *Account, sa *streamAssignme } } else if err == NewJSStreamNotFoundError() { // Add in the stream here. - mset, err = acc.addStreamWithAssignment(sa.Config, nil, sa) + mset, err = acc.addStreamWithAssignment(sa.Config, nil, sa, false) } if mset != nil { mset.setCreatedTime(sa.Created) @@ -4393,7 +4362,7 @@ func (js *jetStream) processClusterCreateConsumer(ca *consumerAssignment, state var didCreate, isConfigUpdate, needsLocalResponse bool if o == nil { // Add in the consumer if needed. - if o, err = mset.addConsumerWithAssignment(ca.Config, ca.Name, ca, js.isMetaRecovering(), ActionCreateOrUpdate); err == nil { + if o, err = mset.addConsumerWithAssignment(ca.Config, ca.Name, ca, js.isMetaRecovering(), ActionCreateOrUpdate, false); err == nil { didCreate = true } } else { @@ -4403,7 +4372,11 @@ func (js *jetStream) processClusterCreateConsumer(ca *consumerAssignment, state if isConfigUpdate = !reflect.DeepEqual(&cfg, ca.Config); isConfigUpdate { // Call into update, ignore consumer exists error here since this means an old deliver subject is bound // which can happen on restart etc. - if err := o.updateConfig(ca.Config); err != nil && err != NewJSConsumerNameExistError() { + // JS lock needed as this can mutate the consumer assignments and race with updateInactivityThreshold. + js.mu.Lock() + err := o.updateConfig(ca.Config) + js.mu.Unlock() + if err != nil && err != NewJSConsumerNameExistError() { // This is essentially an update that has failed. Respond back to metaleader if we are not recovering. 
js.mu.RLock() if !js.metaRecovering { @@ -4537,7 +4510,7 @@ func (js *jetStream) processClusterCreateConsumer(ca *consumerAssignment, state client, subject, reply := ca.Client, ca.Subject, ca.Reply js.mu.Unlock() var resp = JSApiConsumerCreateResponse{ApiResponse: ApiResponse{Type: JSApiConsumerCreateResponseType}} - resp.ConsumerInfo = o.info() + resp.ConsumerInfo = setDynamicConsumerInfoMetadata(o.info()) s.sendAPIResponse(client, acc, subject, reply, _EMPTY_, s.jsonResponse(&resp)) return } @@ -4575,7 +4548,7 @@ func (js *jetStream) processClusterCreateConsumer(ca *consumerAssignment, state js.mu.RUnlock() if !recovering { var resp = JSApiConsumerCreateResponse{ApiResponse: ApiResponse{Type: JSApiConsumerCreateResponseType}} - resp.ConsumerInfo = o.info() + resp.ConsumerInfo = setDynamicConsumerInfoMetadata(o.info()) s.sendAPIResponse(client, acc, subject, reply, _EMPTY_, s.jsonResponse(&resp)) } } @@ -4886,6 +4859,7 @@ func (js *jetStream) monitorConsumer(o *consumer, ca *consumerAssignment) { } } aq.recycle(&ces) + case isLeader = <-lch: if recovering && !isLeader { js.setConsumerAssignmentRecovering(ca) @@ -5082,11 +5056,13 @@ func (js *jetStream) applyConsumerEntries(o *consumer, ce *CommittedEntry, isLea } case updateSkipOp: o.mu.Lock() - if !o.isLeader() { - var le = binary.LittleEndian - if sseq := le.Uint64(buf[1:]); sseq > o.sseq { - o.sseq = sseq - } + var le = binary.LittleEndian + sseq := le.Uint64(buf[1:]) + if !o.isLeader() && sseq > o.sseq { + o.sseq = sseq + } + if o.store != nil { + o.store.UpdateStarting(sseq - 1) } o.mu.Unlock() case addPendingRequest: @@ -5121,6 +5097,16 @@ func (o *consumer) processReplicatedAck(dseq, sseq uint64) error { // Update activity. o.lat = time.Now() + var sagap uint64 + if o.cfg.AckPolicy == AckAll { + // Always use the store state, as o.asflr is skipped ahead already. + // Capture before updating store. 
+ state, err := o.store.BorrowState() + if err == nil { + sagap = sseq - state.AckFloor.Stream + } + } + // Do actual ack update to store. // Always do this to have it recorded. o.store.UpdateAcks(dseq, sseq) @@ -5145,21 +5131,6 @@ func (o *consumer) processReplicatedAck(dseq, sseq uint64) error { o.mu.Unlock() return nil } - - var sagap uint64 - if o.cfg.AckPolicy == AckAll { - if o.isLeader() { - sagap = sseq - o.asflr - } else { - // We are a follower so only have the store state, so read that in. - state, err := o.store.State() - if err != nil { - o.mu.Unlock() - return err - } - sagap = sseq - state.AckFloor.Stream - } - } o.mu.Unlock() if sagap > 1 { @@ -5262,11 +5233,18 @@ func (js *jetStream) processConsumerLeaderChange(o *consumer, isLeader bool) err resp.Error = NewJSConsumerCreateError(err, Unless(err)) s.sendAPIErrResponse(client, acc, subject, reply, _EMPTY_, s.jsonResponse(&resp)) } else { - resp.ConsumerInfo = o.initialInfo() + resp.ConsumerInfo = setDynamicConsumerInfoMetadata(o.initialInfo()) s.sendAPIResponse(client, acc, subject, reply, _EMPTY_, s.jsonResponse(&resp)) o.sendCreateAdvisory() } + // Only send a pause advisory on consumer create if we're + // actually paused. The timer would have been kicked by now + // by the call to o.setLeader() above. 
+ if isLeader && o.cfg.PauseUntil != nil && !o.cfg.PauseUntil.IsZero() && time.Now().Before(*o.cfg.PauseUntil) { + o.sendPauseAdvisoryLocked(&o.cfg) + } + return nil } @@ -6113,7 +6091,7 @@ func (js *jetStream) jsClusteredStreamLimitsCheck(acc *Account, cfg *StreamConfi return nil } -func (s *Server) jsClusteredStreamRequest(ci *ClientInfo, acc *Account, subject, reply string, rmsg []byte, config *StreamConfig) { +func (s *Server) jsClusteredStreamRequest(ci *ClientInfo, acc *Account, subject, reply string, rmsg []byte, config *StreamConfigRequest) { js, cc := s.getJetStreamCluster() if js == nil || cc == nil { return @@ -6121,7 +6099,7 @@ func (s *Server) jsClusteredStreamRequest(ci *ClientInfo, acc *Account, subject, var resp = JSApiStreamCreateResponse{ApiResponse: ApiResponse{Type: JSApiStreamCreateResponseType}} - ccfg, apiErr := s.checkStreamCfg(config, acc) + ccfg, apiErr := s.checkStreamCfg(&config.StreamConfig, acc, config.Pedantic) if apiErr != nil { resp.Error = apiErr s.sendAPIErrResponse(ci, acc, subject, reply, string(rmsg), s.jsonResponse(&resp)) @@ -6139,6 +6117,7 @@ func (s *Server) jsClusteredStreamRequest(ci *ClientInfo, acc *Account, subject, // Capture if we have existing assignment first. 
if osa := js.streamAssignment(acc.Name, cfg.Name); osa != nil { + copyStreamMetadata(cfg, osa.Config) if !reflect.DeepEqual(osa.Config, cfg) { resp.Error = NewJSStreamNameExistError() s.sendAPIErrResponse(ci, acc, subject, reply, string(rmsg), s.jsonResponse(&resp)) @@ -6269,7 +6248,7 @@ func sysRequest[T any](s *Server, subjFormat string, args ...any) (*T, error) { } } -func (s *Server) jsClusteredStreamUpdateRequest(ci *ClientInfo, acc *Account, subject, reply string, rmsg []byte, cfg *StreamConfig, peerSet []string) { +func (s *Server) jsClusteredStreamUpdateRequest(ci *ClientInfo, acc *Account, subject, reply string, rmsg []byte, cfg *StreamConfig, peerSet []string, pedantic bool) { js, cc := s.getJetStreamCluster() if js == nil || cc == nil { return @@ -6286,16 +6265,19 @@ func (s *Server) jsClusteredStreamUpdateRequest(ci *ClientInfo, acc *Account, su var resp = JSApiStreamUpdateResponse{ApiResponse: ApiResponse{Type: JSApiStreamUpdateResponseType}} osa := js.streamAssignment(acc.Name, cfg.Name) - if osa == nil { resp.Error = NewJSStreamNotFoundError() s.sendAPIErrResponse(ci, acc, subject, reply, string(rmsg), s.jsonResponse(&resp)) return } + + // Update asset version metadata. + setStaticStreamMetadata(cfg) + var newCfg *StreamConfig if jsa := js.accounts[acc.Name]; jsa != nil { js.mu.Unlock() - ncfg, err := jsa.configUpdateCheck(osa.Config, cfg, s) + ncfg, err := jsa.configUpdateCheck(osa.Config, cfg, s, pedantic) js.mu.Lock() if err != nil { resp.Error = NewJSStreamUpdateError(err, Unless(err)) @@ -7278,7 +7260,7 @@ func (cc *jetStreamCluster) createGroupForConsumer(cfg *ConsumerConfig, sa *stre } // jsClusteredConsumerRequest is first point of entry to create a consumer in clustered mode. 
-func (s *Server) jsClusteredConsumerRequest(ci *ClientInfo, acc *Account, subject, reply string, rmsg []byte, stream string, cfg *ConsumerConfig, action ConsumerAction) { +func (s *Server) jsClusteredConsumerRequest(ci *ClientInfo, acc *Account, subject, reply string, rmsg []byte, stream string, cfg *ConsumerConfig, action ConsumerAction, pedantic bool) { js, cc := s.getJetStreamCluster() if js == nil || cc == nil { return @@ -7300,7 +7282,11 @@ func (s *Server) jsClusteredConsumerRequest(ci *ClientInfo, acc *Account, subjec } srvLim := &s.getOpts().JetStreamLimits // Make sure we have sane defaults - setConsumerConfigDefaults(cfg, &streamCfg, srvLim, selectedLimits) + if err := setConsumerConfigDefaults(cfg, &streamCfg, srvLim, selectedLimits, pedantic); err != nil { + resp.Error = err + s.sendAPIErrResponse(ci, acc, subject, reply, string(rmsg), s.jsonResponse(&resp)) + return + } if err := checkConsumerCfg(cfg, srvLim, &streamCfg, acc, selectedLimits, false); err != nil { resp.Error = err @@ -7385,12 +7371,19 @@ func (s *Server) jsClusteredConsumerRequest(ci *ClientInfo, acc *Account, subjec cfg.MaxAckPending = JsDefaultMaxAckPending } + if cfg.PriorityPolicy == PriorityPinnedClient && cfg.PinnedTTL == 0 { + cfg.PinnedTTL = JsDefaultPinnedTTL + } + var ca *consumerAssignment // See if we have an existing one already under same durable name or // if name was set by the user. if oname != _EMPTY_ { if ca = sa.consumers[oname]; ca != nil && !ca.deleted { + // Provided config might miss metadata, copy from existing config. 
+ copyConsumerMetadata(cfg, ca.Config) + if action == ActionCreate && !reflect.DeepEqual(cfg, ca.Config) { resp.Error = NewJSConsumerAlreadyExistsError() s.sendAPIErrResponse(ci, acc, subject, reply, string(rmsg), s.jsonResponse(&resp)) @@ -7402,9 +7395,20 @@ func (s *Server) jsClusteredConsumerRequest(ci *ClientInfo, acc *Account, subjec s.sendAPIErrResponse(ci, acc, subject, reply, string(rmsg), s.jsonResponse(&resp)) return } + } else { + // Initialize/update asset version metadata. + // First time creating this consumer, or updating. + setStaticConsumerMetadata(cfg) } } + // Initialize/update asset version metadata. + // But only if we're not creating, should only update it the first time + // to be idempotent with versions where there's no versioning metadata. + if action != ActionCreate { + setStaticConsumerMetadata(cfg) + } + // If this is new consumer. if ca == nil { if action == ActionUpdate { @@ -7509,6 +7513,10 @@ func (s *Server) jsClusteredConsumerRequest(ci *ClientInfo, acc *Account, subjec Created: time.Now().UTC(), } } else { + // If the consumer already exists then don't allow updating the PauseUntil, just set + // it back to whatever the current configured value is. 
+ cfg.PauseUntil = ca.Config.PauseUntil + nca := ca.copyGroup() rBefore := nca.Config.replicas(sa.Config) @@ -7636,10 +7644,10 @@ func decodeConsumerAssignmentCompressed(buf []byte) (*consumerAssignment, error) var errBadStreamMsg = errors.New("jetstream cluster bad replicated stream msg") -func decodeStreamMsg(buf []byte) (subject, reply string, hdr, msg []byte, lseq uint64, ts int64, err error) { +func decodeStreamMsg(buf []byte) (subject, reply string, hdr, msg []byte, lseq uint64, ts int64, sourced bool, err error) { var le = binary.LittleEndian if len(buf) < 26 { - return _EMPTY_, _EMPTY_, nil, nil, 0, 0, errBadStreamMsg + return _EMPTY_, _EMPTY_, nil, nil, 0, 0, false, errBadStreamMsg } lseq = le.Uint64(buf) buf = buf[8:] @@ -7648,55 +7656,58 @@ func decodeStreamMsg(buf []byte) (subject, reply string, hdr, msg []byte, lseq u sl := int(le.Uint16(buf)) buf = buf[2:] if len(buf) < sl { - return _EMPTY_, _EMPTY_, nil, nil, 0, 0, errBadStreamMsg + return _EMPTY_, _EMPTY_, nil, nil, 0, 0, false, errBadStreamMsg } subject = string(buf[:sl]) buf = buf[sl:] if len(buf) < 2 { - return _EMPTY_, _EMPTY_, nil, nil, 0, 0, errBadStreamMsg + return _EMPTY_, _EMPTY_, nil, nil, 0, 0, false, errBadStreamMsg } rl := int(le.Uint16(buf)) buf = buf[2:] if len(buf) < rl { - return _EMPTY_, _EMPTY_, nil, nil, 0, 0, errBadStreamMsg + return _EMPTY_, _EMPTY_, nil, nil, 0, 0, false, errBadStreamMsg } reply = string(buf[:rl]) buf = buf[rl:] if len(buf) < 2 { - return _EMPTY_, _EMPTY_, nil, nil, 0, 0, errBadStreamMsg + return _EMPTY_, _EMPTY_, nil, nil, 0, 0, false, errBadStreamMsg } hl := int(le.Uint16(buf)) buf = buf[2:] if len(buf) < hl { - return _EMPTY_, _EMPTY_, nil, nil, 0, 0, errBadStreamMsg + return _EMPTY_, _EMPTY_, nil, nil, 0, 0, false, errBadStreamMsg } if hdr = buf[:hl]; len(hdr) == 0 { hdr = nil } buf = buf[hl:] if len(buf) < 4 { - return _EMPTY_, _EMPTY_, nil, nil, 0, 0, errBadStreamMsg + return _EMPTY_, _EMPTY_, nil, nil, 0, 0, false, errBadStreamMsg } ml := 
int(le.Uint32(buf)) buf = buf[4:] if len(buf) < ml { - return _EMPTY_, _EMPTY_, nil, nil, 0, 0, errBadStreamMsg + return _EMPTY_, _EMPTY_, nil, nil, 0, 0, false, errBadStreamMsg } if msg = buf[:ml]; len(msg) == 0 { msg = nil } - return subject, reply, hdr, msg, lseq, ts, nil + buf = buf[ml:] + if len(buf) > 0 { + flags, _ := binary.Uvarint(buf) + sourced = flags&msgFlagFromSourceOrMirror != 0 + } + return subject, reply, hdr, msg, lseq, ts, sourced, nil } -// Helper to return if compression allowed. -func (mset *stream) compressAllowed() bool { - mset.clMu.Lock() - defer mset.clMu.Unlock() - return mset.compressOK -} +// Flags for encodeStreamMsg/decodeStreamMsg. +const ( + msgFlagFromSourceOrMirror uint64 = 1 << iota +) -func encodeStreamMsg(subject, reply string, hdr, msg []byte, lseq uint64, ts int64) []byte { - return encodeStreamMsgAllowCompress(subject, reply, hdr, msg, lseq, ts, false) +func encodeStreamMsg(subject, reply string, hdr, msg []byte, lseq uint64, ts int64, sourced bool) []byte { + return encodeStreamMsgAllowCompress(subject, reply, hdr, msg, lseq, ts, sourced) } // Threshold for compression. @@ -7704,7 +7715,7 @@ func encodeStreamMsg(subject, reply string, hdr, msg []byte, lseq uint64, ts int const compressThreshold = 8192 // 8k // If allowed and contents over the threshold we will compress. -func encodeStreamMsgAllowCompress(subject, reply string, hdr, msg []byte, lseq uint64, ts int64, compressOK bool) []byte { +func encodeStreamMsgAllowCompress(subject, reply string, hdr, msg []byte, lseq uint64, ts int64, sourced bool) []byte { // Clip the subject, reply, header and msgs down. Operate on // uint64 lengths to avoid overflowing. 
slen := min(uint64(len(subject)), math.MaxUint16) @@ -7713,9 +7724,14 @@ func encodeStreamMsgAllowCompress(subject, reply string, hdr, msg []byte, lseq u mlen := min(uint64(len(msg)), math.MaxUint32) total := slen + rlen + hlen + mlen - shouldCompress := compressOK && total > compressThreshold + shouldCompress := total > compressThreshold elen := int(1 + 8 + 8 + total) - elen += (2 + 2 + 2 + 4) // Encoded lengths, 4bytes + elen += (2 + 2 + 2 + 4 + 8) // Encoded lengths, 4bytes, flags are up to 8 bytes + + var flags uint64 + if sourced { + flags |= msgFlagFromSourceOrMirror + } buf := make([]byte, 1, elen) buf[0] = byte(streamMsgOp) @@ -7731,6 +7747,7 @@ func encodeStreamMsgAllowCompress(subject, reply string, hdr, msg []byte, lseq u buf = append(buf, hdr[:hlen]...) buf = le.AppendUint32(buf, uint32(mlen)) buf = append(buf, msg[:mlen]...) + buf = binary.AppendUvarint(buf, flags) // Check if we should compress. if shouldCompress { @@ -7820,31 +7837,11 @@ func (mset *stream) stateSnapshotLocked() []byte { return b } -// Will check if we can do message compression in RAFT and catchup logic. -func (mset *stream) checkAllowMsgCompress(peers []string) { - allowed := true - for _, id := range peers { - sir, ok := mset.srv.nodeToInfo.Load(id) - if !ok || sir == nil { - allowed = false - break - } - // Check for capability. - if si := sir.(nodeInfo); si.cfg == nil || !si.cfg.CompressOK { - allowed = false - break - } - } - mset.mu.Lock() - mset.compressOK = allowed - mset.mu.Unlock() -} - // To warn when we are getting too far behind from what has been proposed vs what has been committed. const streamLagWarnThreshold = 10_000 -// processClusteredMsg will propose the inbound message to the underlying raft group. -func (mset *stream) processClusteredInboundMsg(subject, reply string, hdr, msg []byte) error { +// processClusteredInboundMsg will propose the inbound message to the underlying raft group. 
+func (mset *stream) processClusteredInboundMsg(subject, reply string, hdr, msg []byte, mt *msgTrace, sourced bool) (retErr error) { // For possible error response. var response []byte @@ -7854,12 +7851,27 @@ func (mset *stream) processClusteredInboundMsg(subject, reply string, hdr, msg [ s, js, jsa, st, r, tierName, outq, node := mset.srv, mset.js, mset.jsa, mset.cfg.Storage, mset.cfg.Replicas, mset.tier, mset.outq, mset.node maxMsgSize, lseq := int(mset.cfg.MaxMsgSize), mset.lseq interestPolicy, discard, maxMsgs, maxBytes := mset.cfg.Retention != LimitsPolicy, mset.cfg.Discard, mset.cfg.MaxMsgs, mset.cfg.MaxBytes - isLeader, isSealed, compressOK := mset.isLeader(), mset.cfg.Sealed, mset.compressOK + isLeader, isSealed, allowTTL := mset.isLeader(), mset.cfg.Sealed, mset.cfg.AllowMsgTTL mset.mu.RUnlock() // This should not happen but possible now that we allow scale up, and scale down where this could trigger. - if node == nil { - return mset.processJetStreamMsg(subject, reply, hdr, msg, 0, 0) + // + // We also invoke this in clustering mode for message tracing when not + // performing message delivery. + if node == nil || mt.traceOnly() { + return mset.processJetStreamMsg(subject, reply, hdr, msg, 0, 0, mt, sourced) + } + + // If message tracing (with message delivery), we will need to send the + // event on exit in case there was an error (if message was not proposed). + // Otherwise, the event will be sent from processJetStreamMsg when + // invoked by the leader (from applyStreamEntries). + if mt != nil { + defer func() { + if retErr != nil { + mt.sendEventFromJetStream(retErr) + } + }() } // Check that we are the leader. This can be false if we have scaled up from an R1 that had inbound queued messages. @@ -7933,29 +7945,6 @@ func (mset *stream) processClusteredInboundMsg(subject, reply string, hdr, msg [ } return err } - // Expected last sequence per subject. - // We can check for last sequence per subject but only if the expected seq <= lseq. 
- if seq, exists := getExpectedLastSeqPerSubject(hdr); exists && store != nil && seq <= lseq { - var smv StoreMsg - var fseq uint64 - sm, err := store.LoadLastMsg(subject, &smv) - if sm != nil { - fseq = sm.seq - } - if err == ErrStoreMsgNotFound && seq == 0 { - fseq, err = 0, nil - } - if err != nil || fseq != seq { - if canRespond { - var resp = &JSPubAckResponse{PubAck: &PubAck{Stream: name}} - resp.PubAck = &PubAck{Stream: name} - resp.Error = NewJSStreamWrongLastSequenceError(fseq) - b, _ := json.Marshal(resp) - outq.sendMsg(reply, b) - } - return fmt.Errorf("last sequence by subject mismatch: %d vs %d", seq, fseq) - } - } // Expected stream name can also be pre-checked. if sname := getExpectedStream(hdr); sname != _EMPTY_ && sname != name { if canRespond { @@ -7976,10 +7965,18 @@ func (mset *stream) processClusteredInboundMsg(subject, reply string, hdr, msg [ pubAck := append(buf[:0], mset.pubAck...) seq := dde.seq mset.mu.Unlock() + // Should not return an invalid sequence, in that case error. if canRespond { - response := append(pubAck, strconv.FormatUint(seq, 10)...) - response = append(response, ",\"duplicate\": true}"...) - outq.sendMsg(reply, response) + if seq > 0 { + response := append(pubAck, strconv.FormatUint(seq, 10)...) + response = append(response, ",\"duplicate\": true}"...) + outq.sendMsg(reply, response) + } else { + var resp = &JSPubAckResponse{PubAck: &PubAck{Stream: name}} + resp.Error = ApiErrors[JSStreamDuplicateMessageConflict] + b, _ := json.Marshal(resp) + outq.sendMsg(reply, b) + } } return errMsgIdDuplicate } @@ -7988,6 +7985,17 @@ func (mset *stream) processClusteredInboundMsg(subject, reply string, hdr, msg [ mset.storeMsgIdLocked(&ddentry{msgId, 0, time.Now().UnixNano()}) mset.mu.Unlock() } + + // TTL'd messages are rejected entirely if TTLs are not enabled on the stream. 
+ if ttl, _ := getMessageTTL(hdr); !sourced && ttl != 0 && !allowTTL { + if canRespond { + var resp = &JSPubAckResponse{PubAck: &PubAck{Stream: name}} + resp.Error = NewJSMessageTTLDisabledError() + b, _ := json.Marshal(resp) + outq.sendMsg(reply, b) + } + return errMsgTTLDisabled + } } // Proceed with proposing this message. @@ -8045,7 +8053,75 @@ func (mset *stream) processClusteredInboundMsg(subject, reply string, hdr, msg [ } } - esm := encodeStreamMsgAllowCompress(subject, reply, hdr, msg, mset.clseq, time.Now().UnixNano(), compressOK) + if len(hdr) > 0 { + // Expected last sequence per subject. + if seq, exists := getExpectedLastSeqPerSubject(hdr); exists && store != nil { + // Allow override of the subject used for the check. + seqSubj := subject + if optSubj := getExpectedLastSeqPerSubjectForSubject(hdr); optSubj != _EMPTY_ { + seqSubj = optSubj + } + + // If subject is already in process, block as otherwise we could have multiple messages inflight with same subject. + if _, found := mset.expectedPerSubjectInProcess[seqSubj]; found { + // Could have set inflight above, cleanup here. + delete(mset.inflight, mset.clseq) + mset.clMu.Unlock() + if canRespond { + var resp = &JSPubAckResponse{PubAck: &PubAck{Stream: name}} + resp.PubAck = &PubAck{Stream: name} + resp.Error = NewJSStreamWrongLastSequenceConstantError() + b, _ := json.Marshal(resp) + outq.sendMsg(reply, b) + } + return fmt.Errorf("last sequence by subject mismatch") + } + + var smv StoreMsg + var fseq uint64 + sm, err := store.LoadLastMsg(seqSubj, &smv) + if sm != nil { + fseq = sm.seq + } + if err == ErrStoreMsgNotFound && seq == 0 { + fseq, err = 0, nil + } + if err != nil || fseq != seq { + // Could have set inflight above, cleanup here. 
+ delete(mset.inflight, mset.clseq) + mset.clMu.Unlock() + if canRespond { + var resp = &JSPubAckResponse{PubAck: &PubAck{Stream: name}} + resp.PubAck = &PubAck{Stream: name} + resp.Error = NewJSStreamWrongLastSequenceError(fseq) + b, _ := json.Marshal(resp) + outq.sendMsg(reply, b) + } + return fmt.Errorf("last sequence by subject mismatch: %d vs %d", seq, fseq) + } + + // Track sequence and subject. + if mset.expectedPerSubjectSequence == nil { + mset.expectedPerSubjectSequence = make(map[uint64]string) + } + if mset.expectedPerSubjectInProcess == nil { + mset.expectedPerSubjectInProcess = make(map[string]struct{}) + } + mset.expectedPerSubjectSequence[mset.clseq] = seqSubj + mset.expectedPerSubjectInProcess[seqSubj] = struct{}{} + } + } + + esm := encodeStreamMsgAllowCompress(subject, reply, hdr, msg, mset.clseq, time.Now().UnixNano(), sourced) + var mtKey uint64 + if mt != nil { + mtKey = mset.clseq + if mset.mt == nil { + mset.mt = make(map[uint64]*msgTrace) + } + mset.mt[mtKey] = mt + } + // Do proposal. err := node.Propose(esm) if err == nil { @@ -8067,6 +8143,9 @@ func (mset *stream) processClusteredInboundMsg(subject, reply string, hdr, msg [ mset.clMu.Unlock() if err != nil { + if mt != nil { + mset.getAndDeleteMsgTrace(mtKey) + } if canRespond { var resp = &JSPubAckResponse{PubAck: &PubAck{Stream: mset.cfg.Name}} resp.Error = &ApiError{Code: 503, Description: err.Error()} @@ -8082,6 +8161,19 @@ func (mset *stream) processClusteredInboundMsg(subject, reply string, hdr, msg [ return err } +func (mset *stream) getAndDeleteMsgTrace(lseq uint64) *msgTrace { + if mset == nil { + return nil + } + mset.clMu.Lock() + mt, ok := mset.mt[lseq] + if ok { + delete(mset.mt, lseq) + } + mset.clMu.Unlock() + return mt +} + // For requesting messages post raft snapshot to catch up streams post server restart. // Any deleted msgs etc will be handled inline on catchup. 
type streamSyncRequest struct { @@ -8089,11 +8181,12 @@ type streamSyncRequest struct { FirstSeq uint64 `json:"first_seq"` LastSeq uint64 `json:"last_seq"` DeleteRangesOk bool `json:"delete_ranges"` + MinApplied uint64 `json:"min_applied"` } // Given a stream state that represents a snapshot, calculate the sync request based on our current state. // Stream lock must be held. -func (mset *stream) calculateSyncRequest(state *StreamState, snap *StreamReplicatedState) *streamSyncRequest { +func (mset *stream) calculateSyncRequest(state *StreamState, snap *StreamReplicatedState, index uint64) *streamSyncRequest { // Shouldn't happen, but consequences are pretty bad if we have the lock held and // our caller tries to take the lock again on panic defer, as in processSnapshot. if state == nil || snap == nil || mset.node == nil { @@ -8103,7 +8196,7 @@ func (mset *stream) calculateSyncRequest(state *StreamState, snap *StreamReplica if state.LastSeq >= snap.LastSeq { return nil } - return &streamSyncRequest{FirstSeq: state.LastSeq + 1, LastSeq: snap.LastSeq, Peer: mset.node.ID(), DeleteRangesOk: true} + return &streamSyncRequest{FirstSeq: state.LastSeq + 1, LastSeq: snap.LastSeq, Peer: mset.node.ID(), DeleteRangesOk: true, MinApplied: index} } // processSnapshotDeletes will update our current store based on the snapshot @@ -8239,7 +8332,7 @@ var ( ) // Process a stream snapshot. -func (mset *stream) processSnapshot(snap *StreamReplicatedState) (e error) { +func (mset *stream) processSnapshot(snap *StreamReplicatedState, index uint64) (e error) { // Update any deletes, etc. 
mset.processSnapshotDeletes(snap) mset.setCLFS(snap.Failed) @@ -8247,7 +8340,7 @@ func (mset *stream) processSnapshot(snap *StreamReplicatedState) (e error) { mset.mu.Lock() var state StreamState mset.store.FastState(&state) - sreq := mset.calculateSyncRequest(&state, snap) + sreq := mset.calculateSyncRequest(&state, snap, index) s, js, subject, n, st := mset.srv, mset.js, mset.sa.Sync, mset.node, mset.cfg.Storage qname := fmt.Sprintf("[ACC:%s] stream '%s' snapshot", mset.acc.Name, mset.cfg.Name) @@ -8385,7 +8478,7 @@ RETRY: mset.mu.RLock() var state StreamState mset.store.FastState(&state) - sreq = mset.calculateSyncRequest(&state, snap) + sreq = mset.calculateSyncRequest(&state, snap, index) mset.mu.RUnlock() if sreq == nil { return nil @@ -8450,6 +8543,7 @@ RETRY: } } else if isOutOfSpaceErr(err) { notifyLeaderStopCatchup(mrec, err) + msgsQ.recycle(&mrecs) return err } else if err == NewJSInsufficientResourcesError() { notifyLeaderStopCatchup(mrec, err) @@ -8545,7 +8639,7 @@ func (mset *stream) processCatchupMsg(msg []byte) (uint64, error) { } } - subj, _, hdr, msg, seq, ts, err := decodeStreamMsg(mbuf) + subj, _, hdr, msg, seq, ts, _, err := decodeStreamMsg(mbuf) if err != nil { return 0, errCatchupBadMsg } @@ -8553,9 +8647,6 @@ func (mset *stream) processCatchupMsg(msg []byte) (uint64, error) { mset.mu.Lock() st := mset.cfg.Storage ddloaded := mset.ddloaded - tierName := mset.tier - replicas := mset.cfg.Replicas - if mset.hasAllPreAcks(seq, subj) { mset.clearAllPreAcks(seq) // Mark this to be skipped @@ -8563,14 +8654,16 @@ func (mset *stream) processCatchupMsg(msg []byte) (uint64, error) { } mset.mu.Unlock() + // Since we're clustered we do not want to check limits based on tier here and possibly introduce skew. 
if mset.js.limitsExceeded(st) { return 0, NewJSInsufficientResourcesError() - } else if exceeded, apiErr := mset.jsa.limitsExceeded(st, tierName, replicas); apiErr != nil { - return 0, apiErr - } else if exceeded { - return 0, NewJSInsufficientResourcesError() } + // Find the message TTL if any. + // TODO(nat): If the TTL isn't valid by this stage then there isn't really a + // lot we can do about it, as we'd break the catchup if we reject the message. + ttl, _ := getMessageTTL(hdr) + // Put into our store // Messages to be skipped have no subject or timestamp. // TODO(dlc) - formalize with skipMsgOp @@ -8578,7 +8671,7 @@ func (mset *stream) processCatchupMsg(msg []byte) (uint64, error) { if lseq := mset.store.SkipMsg(); lseq != seq { return 0, errCatchupWrongSeqForSkip } - } else if err := mset.store.StoreRawMsg(subj, hdr, msg, seq, ts); err != nil { + } else if err := mset.store.StoreRawMsg(subj, hdr, msg, seq, ts, ttl); err != nil { return 0, err } @@ -8932,10 +9025,23 @@ func (mset *stream) runCatchup(sendSubject string, sreq *streamSyncRequest) { // Setup sequences to walk through. seq, last := sreq.FirstSeq, sreq.LastSeq - mset.setCatchupPeer(sreq.Peer, last-seq) - // Check if we can compress during this. - compressOk := mset.compressAllowed() + // The follower received a snapshot from another leader, and we've become leader since. + // We have an up-to-date log but could be behind on applies. We must wait until we've reached the minimum required. + // The follower will automatically retry after a timeout, so we can safely return here. + if node := mset.raftNode(); node != nil { + index, _, applied := node.Progress() + // Only skip if our log has enough entries, and they could be applied in the future. + if index >= sreq.MinApplied && applied < sreq.MinApplied { + return + } + // We know here we've either applied enough entries, or our log doesn't have enough entries. + // In the latter case the request expects us to have more. 
Just continue and value availability here. + // This should only be possible if the logs have already desynced, and we shouldn't have become leader + // in the first place. Not much we can do here in this (hypothetical) scenario. + } + + mset.setCatchupPeer(sreq.Peer, last-seq) var spb int const minWait = 5 * time.Second @@ -8999,7 +9105,7 @@ func (mset *stream) runCatchup(sendSubject string, sreq *streamSyncRequest) { sendDR := func() { if dr.Num == 1 { // Send like a normal skip msg. - sendEM(encodeStreamMsg(_EMPTY_, _EMPTY_, nil, nil, dr.First, 0)) + sendEM(encodeStreamMsg(_EMPTY_, _EMPTY_, nil, nil, dr.First, 0, false)) } else { // We have a run, send a gap record. We send these without reply or tracking. s.sendInternalMsgLocked(sendSubject, _EMPTY_, nil, encodeDeleteRange(&dr)) @@ -9079,7 +9185,7 @@ func (mset *stream) runCatchup(sendSubject string, sreq *streamSyncRequest) { sendDR() } // Send the normal message now. - sendEM(encodeStreamMsgAllowCompress(sm.subj, _EMPTY_, sm.hdr, sm.msg, sm.seq, sm.ts, compressOk)) + sendEM(encodeStreamMsgAllowCompress(sm.subj, _EMPTY_, sm.hdr, sm.msg, sm.seq, sm.ts, false)) } else { if drOk { if dr.First == 0 { @@ -9089,7 +9195,7 @@ func (mset *stream) runCatchup(sendSubject string, sreq *streamSyncRequest) { } } else { // Skip record for deleted msg. 
- sendEM(encodeStreamMsg(_EMPTY_, _EMPTY_, nil, nil, seq, 0)) + sendEM(encodeStreamMsg(_EMPTY_, _EMPTY_, nil, nil, seq, 0, false)) } } diff --git a/vendor/github.com/nats-io/nats-server/v2/server/jetstream_errors_generated.go b/vendor/github.com/nats-io/nats-server/v2/server/jetstream_errors_generated.go index 1543eeeb4b..30ed884e0f 100644 --- a/vendor/github.com/nats-io/nats-server/v2/server/jetstream_errors_generated.go +++ b/vendor/github.com/nats-io/nats-server/v2/server/jetstream_errors_generated.go @@ -95,6 +95,9 @@ const ( // JSConsumerEmptyFilter consumer filter in FilterSubjects cannot be empty JSConsumerEmptyFilter ErrorIdentifier = 10139 + // JSConsumerEmptyGroupName Group name cannot be an empty string + JSConsumerEmptyGroupName ErrorIdentifier = 10161 + // JSConsumerEphemeralWithDurableInSubjectErr consumer expected to be ephemeral but detected a durable name set in subject JSConsumerEphemeralWithDurableInSubjectErr ErrorIdentifier = 10019 @@ -119,9 +122,15 @@ const ( // JSConsumerInvalidDeliverSubject invalid push consumer deliver subject JSConsumerInvalidDeliverSubject ErrorIdentifier = 10112 + // JSConsumerInvalidGroupNameErr Valid priority group name must match A-Z, a-z, 0-9, -_/=)+ and may not exceed 16 characters + JSConsumerInvalidGroupNameErr ErrorIdentifier = 10162 + // JSConsumerInvalidPolicyErrF Generic delivery policy error ({err}) JSConsumerInvalidPolicyErrF ErrorIdentifier = 10094 + // JSConsumerInvalidPriorityGroupErr Provided priority group does not exist for this consumer + JSConsumerInvalidPriorityGroupErr ErrorIdentifier = 10160 + // JSConsumerInvalidSamplingErrF failed to parse consumer sampling configuration: {err} JSConsumerInvalidSamplingErrF ErrorIdentifier = 10095 @@ -173,10 +182,13 @@ const ( // JSConsumerOverlappingSubjectFilters consumer subject filters cannot overlap JSConsumerOverlappingSubjectFilters ErrorIdentifier = 10138 + // JSConsumerPriorityPolicyWithoutGroup Setting PriorityPolicy requires at least one 
PriorityGroup to be set + JSConsumerPriorityPolicyWithoutGroup ErrorIdentifier = 10159 + // JSConsumerPullNotDurableErr consumer in pull mode requires a durable name JSConsumerPullNotDurableErr ErrorIdentifier = 10085 - // JSConsumerPullRequiresAckErr consumer in pull mode requires ack policy + // JSConsumerPullRequiresAckErr consumer in pull mode requires explicit ack policy on workqueue stream JSConsumerPullRequiresAckErr ErrorIdentifier = 10084 // JSConsumerPullWithRateLimitErr consumer in pull mode can not have rate limit set @@ -218,7 +230,7 @@ const ( // JSInsufficientResourcesErr insufficient resources JSInsufficientResourcesErr ErrorIdentifier = 10023 - // JSInvalidJSONErr invalid JSON + // JSInvalidJSONErr invalid JSON: {err} JSInvalidJSONErr ErrorIdentifier = 10025 // JSMaximumConsumersLimitErr maximum consumers limit reached @@ -230,15 +242,24 @@ const ( // JSMemoryResourcesExceededErr insufficient memory resources available JSMemoryResourcesExceededErr ErrorIdentifier = 10028 + // JSMessageTTLDisabledErr per-message TTL is disabled + JSMessageTTLDisabledErr ErrorIdentifier = 10166 + + // JSMessageTTLInvalidErr invalid per-message TTL + JSMessageTTLInvalidErr ErrorIdentifier = 10165 + // JSMirrorConsumerSetupFailedErrF generic mirror consumer setup failure string ({err}) JSMirrorConsumerSetupFailedErrF ErrorIdentifier = 10029 // JSMirrorInvalidStreamName mirrored stream name is invalid JSMirrorInvalidStreamName ErrorIdentifier = 10142 - // JSMirrorInvalidSubjectFilter mirror subject filter is invalid + // JSMirrorInvalidSubjectFilter mirror transform source: {err} JSMirrorInvalidSubjectFilter ErrorIdentifier = 10151 + // JSMirrorInvalidTransformDestination mirror transform: {err} + JSMirrorInvalidTransformDestination ErrorIdentifier = 10154 + // JSMirrorMaxMessageSizeTooBigErr stream mirror must have max message size >= source JSMirrorMaxMessageSizeTooBigErr ErrorIdentifier = 10030 @@ -281,6 +302,9 @@ const ( // JSNotEnabledForAccountErr JetStream not 
enabled for account JSNotEnabledForAccountErr ErrorIdentifier = 10039 + // JSPedanticErrF pedantic mode: {err} + JSPedanticErrF ErrorIdentifier = 10157 + // JSPeerRemapErr peer remap failed JSPeerRemapErr ErrorIdentifier = 10075 @@ -308,10 +332,10 @@ const ( // JSSourceInvalidStreamName sourced stream name is invalid JSSourceInvalidStreamName ErrorIdentifier = 10141 - // JSSourceInvalidSubjectFilter source subject filter is invalid + // JSSourceInvalidSubjectFilter source transform source: {err} JSSourceInvalidSubjectFilter ErrorIdentifier = 10145 - // JSSourceInvalidTransformDestination source transform destination is invalid + // JSSourceInvalidTransformDestination source transform: {err} JSSourceInvalidTransformDestination ErrorIdentifier = 10146 // JSSourceMaxMessageSizeTooBigErr stream source must have max message size >= target @@ -335,6 +359,12 @@ const ( // JSStreamDeleteErrF General stream deletion error string ({err}) JSStreamDeleteErrF ErrorIdentifier = 10050 + // JSStreamDuplicateMessageConflict duplicate message id is in process + JSStreamDuplicateMessageConflict ErrorIdentifier = 10158 + + // JSStreamExpectedLastSeqPerSubjectNotReady expected last sequence per subject temporarily unavailable + JSStreamExpectedLastSeqPerSubjectNotReady ErrorIdentifier = 10163 + // JSStreamExternalApiOverlapErrF stream external api prefix {prefix} must not overlap with {subject} JSStreamExternalApiOverlapErrF ErrorIdentifier = 10021 @@ -446,12 +476,24 @@ const ( // JSStreamTemplateNotFoundErr template not found JSStreamTemplateNotFoundErr ErrorIdentifier = 10068 + // JSStreamTooManyRequests too many requests + JSStreamTooManyRequests ErrorIdentifier = 10167 + + // JSStreamTransformInvalidDestination stream transform: {err} + JSStreamTransformInvalidDestination ErrorIdentifier = 10156 + + // JSStreamTransformInvalidSource stream transform source: {err} + JSStreamTransformInvalidSource ErrorIdentifier = 10155 + // JSStreamUpdateErrF Generic stream update error string 
({err}) JSStreamUpdateErrF ErrorIdentifier = 10069 // JSStreamWrongLastMsgIDErrF wrong last msg ID: {id} JSStreamWrongLastMsgIDErrF ErrorIdentifier = 10070 + // JSStreamWrongLastSequenceConstantErr wrong last sequence + JSStreamWrongLastSequenceConstantErr ErrorIdentifier = 10164 + // JSStreamWrongLastSequenceErrF wrong last sequence: {seq} JSStreamWrongLastSequenceErrF ErrorIdentifier = 10071 @@ -494,6 +536,7 @@ var ( JSConsumerDurableNameNotMatchSubjectErr: {Code: 400, ErrCode: 10017, Description: "consumer name in subject does not match durable name in request"}, JSConsumerDurableNameNotSetErr: {Code: 400, ErrCode: 10018, Description: "consumer expected to be durable but a durable name was not set"}, JSConsumerEmptyFilter: {Code: 400, ErrCode: 10139, Description: "consumer filter in FilterSubjects cannot be empty"}, + JSConsumerEmptyGroupName: {Code: 400, ErrCode: 10161, Description: "Group name cannot be an empty string"}, JSConsumerEphemeralWithDurableInSubjectErr: {Code: 400, ErrCode: 10019, Description: "consumer expected to be ephemeral but detected a durable name set in subject"}, JSConsumerEphemeralWithDurableNameErr: {Code: 400, ErrCode: 10020, Description: "consumer expected to be ephemeral but a durable name was set in request"}, JSConsumerExistingActiveErr: {Code: 400, ErrCode: 10105, Description: "consumer already exists and is still active"}, @@ -502,7 +545,9 @@ var ( JSConsumerHBRequiresPushErr: {Code: 400, ErrCode: 10088, Description: "consumer idle heartbeat requires a push based consumer"}, JSConsumerInactiveThresholdExcess: {Code: 400, ErrCode: 10153, Description: "consumer inactive threshold exceeds system limit of {limit}"}, JSConsumerInvalidDeliverSubject: {Code: 400, ErrCode: 10112, Description: "invalid push consumer deliver subject"}, + JSConsumerInvalidGroupNameErr: {Code: 400, ErrCode: 10162, Description: "Valid priority group name must match A-Z, a-z, 0-9, -_/=)+ and may not exceed 16 characters"}, JSConsumerInvalidPolicyErrF: {Code: 
400, ErrCode: 10094, Description: "{err}"}, + JSConsumerInvalidPriorityGroupErr: {Code: 400, ErrCode: 10160, Description: "Provided priority group does not exist for this consumer"}, JSConsumerInvalidSamplingErrF: {Code: 400, ErrCode: 10095, Description: "failed to parse consumer sampling configuration: {err}"}, JSConsumerMaxDeliverBackoffErr: {Code: 400, ErrCode: 10116, Description: "max deliver is required to be > length of backoff values"}, JSConsumerMaxPendingAckExcessErrF: {Code: 400, ErrCode: 10121, Description: "consumer max ack pending exceeds system limit of {limit}"}, @@ -520,8 +565,9 @@ var ( JSConsumerOfflineErr: {Code: 500, ErrCode: 10119, Description: "consumer is offline"}, JSConsumerOnMappedErr: {Code: 400, ErrCode: 10092, Description: "consumer direct on a mapped consumer"}, JSConsumerOverlappingSubjectFilters: {Code: 400, ErrCode: 10138, Description: "consumer subject filters cannot overlap"}, + JSConsumerPriorityPolicyWithoutGroup: {Code: 400, ErrCode: 10159, Description: "Setting PriorityPolicy requires at least one PriorityGroup to be set"}, JSConsumerPullNotDurableErr: {Code: 400, ErrCode: 10085, Description: "consumer in pull mode requires a durable name"}, - JSConsumerPullRequiresAckErr: {Code: 400, ErrCode: 10084, Description: "consumer in pull mode requires ack policy"}, + JSConsumerPullRequiresAckErr: {Code: 400, ErrCode: 10084, Description: "consumer in pull mode requires explicit ack policy on workqueue stream"}, JSConsumerPullWithRateLimitErr: {Code: 400, ErrCode: 10086, Description: "consumer in pull mode can not have rate limit set"}, JSConsumerPushMaxWaitingErr: {Code: 400, ErrCode: 10080, Description: "consumer in push mode can not set max waiting"}, JSConsumerReplacementWithDifferentNameErr: {Code: 400, ErrCode: 10106, Description: "consumer replacement durable config not the same"}, @@ -535,13 +581,16 @@ var ( JSConsumerWQRequiresExplicitAckErr: {Code: 400, ErrCode: 10098, Description: "workqueue stream requires explicit ack"}, 
JSConsumerWithFlowControlNeedsHeartbeats: {Code: 400, ErrCode: 10108, Description: "consumer with flow control also needs heartbeats"}, JSInsufficientResourcesErr: {Code: 503, ErrCode: 10023, Description: "insufficient resources"}, - JSInvalidJSONErr: {Code: 400, ErrCode: 10025, Description: "invalid JSON"}, + JSInvalidJSONErr: {Code: 400, ErrCode: 10025, Description: "invalid JSON: {err}"}, JSMaximumConsumersLimitErr: {Code: 400, ErrCode: 10026, Description: "maximum consumers limit reached"}, JSMaximumStreamsLimitErr: {Code: 400, ErrCode: 10027, Description: "maximum number of streams reached"}, JSMemoryResourcesExceededErr: {Code: 500, ErrCode: 10028, Description: "insufficient memory resources available"}, + JSMessageTTLDisabledErr: {Code: 400, ErrCode: 10166, Description: "per-message TTL is disabled"}, + JSMessageTTLInvalidErr: {Code: 400, ErrCode: 10165, Description: "invalid per-message TTL"}, JSMirrorConsumerSetupFailedErrF: {Code: 500, ErrCode: 10029, Description: "{err}"}, JSMirrorInvalidStreamName: {Code: 400, ErrCode: 10142, Description: "mirrored stream name is invalid"}, - JSMirrorInvalidSubjectFilter: {Code: 400, ErrCode: 10151, Description: "mirror subject filter is invalid"}, + JSMirrorInvalidSubjectFilter: {Code: 400, ErrCode: 10151, Description: "mirror transform source: {err}"}, + JSMirrorInvalidTransformDestination: {Code: 400, ErrCode: 10154, Description: "mirror transform: {err}"}, JSMirrorMaxMessageSizeTooBigErr: {Code: 400, ErrCode: 10030, Description: "stream mirror must have max message size >= source"}, JSMirrorMultipleFiltersNotAllowed: {Code: 400, ErrCode: 10150, Description: "mirror with multiple subject transforms cannot also have a single subject filter"}, JSMirrorOverlappingSubjectFilters: {Code: 400, ErrCode: 10152, Description: "mirror subject filters can not overlap"}, @@ -556,6 +605,7 @@ var ( JSNotEmptyRequestErr: {Code: 400, ErrCode: 10038, Description: "expected an empty request payload"}, JSNotEnabledErr: {Code: 503, 
ErrCode: 10076, Description: "JetStream not enabled"}, JSNotEnabledForAccountErr: {Code: 503, ErrCode: 10039, Description: "JetStream not enabled for account"}, + JSPedanticErrF: {Code: 400, ErrCode: 10157, Description: "pedantic mode: {err}"}, JSPeerRemapErr: {Code: 503, ErrCode: 10075, Description: "peer remap failed"}, JSRaftGeneralErrF: {Code: 500, ErrCode: 10041, Description: "{err}"}, JSReplicasCountCannotBeNegative: {Code: 400, ErrCode: 10133, Description: "replicas count cannot be negative"}, @@ -565,8 +615,8 @@ var ( JSSourceConsumerSetupFailedErrF: {Code: 500, ErrCode: 10045, Description: "{err}"}, JSSourceDuplicateDetected: {Code: 400, ErrCode: 10140, Description: "duplicate source configuration detected"}, JSSourceInvalidStreamName: {Code: 400, ErrCode: 10141, Description: "sourced stream name is invalid"}, - JSSourceInvalidSubjectFilter: {Code: 400, ErrCode: 10145, Description: "source subject filter is invalid"}, - JSSourceInvalidTransformDestination: {Code: 400, ErrCode: 10146, Description: "source transform destination is invalid"}, + JSSourceInvalidSubjectFilter: {Code: 400, ErrCode: 10145, Description: "source transform source: {err}"}, + JSSourceInvalidTransformDestination: {Code: 400, ErrCode: 10146, Description: "source transform: {err}"}, JSSourceMaxMessageSizeTooBigErr: {Code: 400, ErrCode: 10046, Description: "stream source must have max message size >= target"}, JSSourceMultipleFiltersNotAllowed: {Code: 400, ErrCode: 10144, Description: "source with multiple subject transforms cannot also have a single subject filter"}, JSSourceOverlappingSubjectFilters: {Code: 400, ErrCode: 10147, Description: "source filters can not overlap"}, @@ -574,6 +624,8 @@ var ( JSStreamAssignmentErrF: {Code: 500, ErrCode: 10048, Description: "{err}"}, JSStreamCreateErrF: {Code: 500, ErrCode: 10049, Description: "{err}"}, JSStreamDeleteErrF: {Code: 500, ErrCode: 10050, Description: "{err}"}, + JSStreamDuplicateMessageConflict: {Code: 409, ErrCode: 10158, 
Description: "duplicate message id is in process"}, + JSStreamExpectedLastSeqPerSubjectNotReady: {Code: 503, ErrCode: 10163, Description: "expected last sequence per subject temporarily unavailable"}, JSStreamExternalApiOverlapErrF: {Code: 400, ErrCode: 10021, Description: "stream external api prefix {prefix} must not overlap with {subject}"}, JSStreamExternalDelPrefixOverlapsErrF: {Code: 400, ErrCode: 10022, Description: "stream external delivery prefix {prefix} overlaps with stream subject {subject}"}, JSStreamGeneralErrorF: {Code: 500, ErrCode: 10051, Description: "{err}"}, @@ -611,8 +663,12 @@ var ( JSStreamTemplateCreateErrF: {Code: 500, ErrCode: 10066, Description: "{err}"}, JSStreamTemplateDeleteErrF: {Code: 500, ErrCode: 10067, Description: "{err}"}, JSStreamTemplateNotFoundErr: {Code: 404, ErrCode: 10068, Description: "template not found"}, + JSStreamTooManyRequests: {Code: 429, ErrCode: 10167, Description: "too many requests"}, + JSStreamTransformInvalidDestination: {Code: 400, ErrCode: 10156, Description: "stream transform: {err}"}, + JSStreamTransformInvalidSource: {Code: 400, ErrCode: 10155, Description: "stream transform source: {err}"}, JSStreamUpdateErrF: {Code: 500, ErrCode: 10069, Description: "{err}"}, JSStreamWrongLastMsgIDErrF: {Code: 400, ErrCode: 10070, Description: "wrong last msg ID: {id}"}, + JSStreamWrongLastSequenceConstantErr: {Code: 400, ErrCode: 10164, Description: "wrong last sequence"}, JSStreamWrongLastSequenceErrF: {Code: 400, ErrCode: 10071, Description: "wrong last sequence: {seq}"}, JSTempStorageFailedErr: {Code: 500, ErrCode: 10072, Description: "JetStream unable to open temp storage for restore"}, JSTemplateNameNotMatchSubjectErr: {Code: 400, ErrCode: 10073, Description: "template name in subject does not match request"}, @@ -959,6 +1015,16 @@ func NewJSConsumerEmptyFilterError(opts ...ErrorOption) *ApiError { return ApiErrors[JSConsumerEmptyFilter] } +// NewJSConsumerEmptyGroupNameError creates a new JSConsumerEmptyGroupName 
error: "Group name cannot be an empty string" +func NewJSConsumerEmptyGroupNameError(opts ...ErrorOption) *ApiError { + eopts := parseOpts(opts) + if ae, ok := eopts.err.(*ApiError); ok { + return ae + } + + return ApiErrors[JSConsumerEmptyGroupName] +} + // NewJSConsumerEphemeralWithDurableInSubjectError creates a new JSConsumerEphemeralWithDurableInSubjectErr error: "consumer expected to be ephemeral but detected a durable name set in subject" func NewJSConsumerEphemeralWithDurableInSubjectError(opts ...ErrorOption) *ApiError { eopts := parseOpts(opts) @@ -1045,6 +1111,16 @@ func NewJSConsumerInvalidDeliverSubjectError(opts ...ErrorOption) *ApiError { return ApiErrors[JSConsumerInvalidDeliverSubject] } +// NewJSConsumerInvalidGroupNameError creates a new JSConsumerInvalidGroupNameErr error: "Valid priority group name must match A-Z, a-z, 0-9, -_/=)+ and may not exceed 16 characters" +func NewJSConsumerInvalidGroupNameError(opts ...ErrorOption) *ApiError { + eopts := parseOpts(opts) + if ae, ok := eopts.err.(*ApiError); ok { + return ae + } + + return ApiErrors[JSConsumerInvalidGroupNameErr] +} + // NewJSConsumerInvalidPolicyError creates a new JSConsumerInvalidPolicyErrF error: "{err}" func NewJSConsumerInvalidPolicyError(err error, opts ...ErrorOption) *ApiError { eopts := parseOpts(opts) @@ -1061,6 +1137,16 @@ func NewJSConsumerInvalidPolicyError(err error, opts ...ErrorOption) *ApiError { } } +// NewJSConsumerInvalidPriorityGroupError creates a new JSConsumerInvalidPriorityGroupErr error: "Provided priority group does not exist for this consumer" +func NewJSConsumerInvalidPriorityGroupError(opts ...ErrorOption) *ApiError { + eopts := parseOpts(opts) + if ae, ok := eopts.err.(*ApiError); ok { + return ae + } + + return ApiErrors[JSConsumerInvalidPriorityGroupErr] +} + // NewJSConsumerInvalidSamplingError creates a new JSConsumerInvalidSamplingErrF error: "failed to parse consumer sampling configuration: {err}" func NewJSConsumerInvalidSamplingError(err error, 
opts ...ErrorOption) *ApiError { eopts := parseOpts(opts) @@ -1261,6 +1347,16 @@ func NewJSConsumerOverlappingSubjectFiltersError(opts ...ErrorOption) *ApiError return ApiErrors[JSConsumerOverlappingSubjectFilters] } +// NewJSConsumerPriorityPolicyWithoutGroupError creates a new JSConsumerPriorityPolicyWithoutGroup error: "Setting PriorityPolicy requires at least one PriorityGroup to be set" +func NewJSConsumerPriorityPolicyWithoutGroupError(opts ...ErrorOption) *ApiError { + eopts := parseOpts(opts) + if ae, ok := eopts.err.(*ApiError); ok { + return ae + } + + return ApiErrors[JSConsumerPriorityPolicyWithoutGroup] +} + // NewJSConsumerPullNotDurableError creates a new JSConsumerPullNotDurableErr error: "consumer in pull mode requires a durable name" func NewJSConsumerPullNotDurableError(opts ...ErrorOption) *ApiError { eopts := parseOpts(opts) @@ -1271,7 +1367,7 @@ func NewJSConsumerPullNotDurableError(opts ...ErrorOption) *ApiError { return ApiErrors[JSConsumerPullNotDurableErr] } -// NewJSConsumerPullRequiresAckError creates a new JSConsumerPullRequiresAckErr error: "consumer in pull mode requires ack policy" +// NewJSConsumerPullRequiresAckError creates a new JSConsumerPullRequiresAckErr error: "consumer in pull mode requires explicit ack policy on workqueue stream" func NewJSConsumerPullRequiresAckError(opts ...ErrorOption) *ApiError { eopts := parseOpts(opts) if ae, ok := eopts.err.(*ApiError); ok { @@ -1417,14 +1513,20 @@ func NewJSInsufficientResourcesError(opts ...ErrorOption) *ApiError { return ApiErrors[JSInsufficientResourcesErr] } -// NewJSInvalidJSONError creates a new JSInvalidJSONErr error: "invalid JSON" -func NewJSInvalidJSONError(opts ...ErrorOption) *ApiError { +// NewJSInvalidJSONError creates a new JSInvalidJSONErr error: "invalid JSON: {err}" +func NewJSInvalidJSONError(err error, opts ...ErrorOption) *ApiError { eopts := parseOpts(opts) if ae, ok := eopts.err.(*ApiError); ok { return ae } - return ApiErrors[JSInvalidJSONErr] + e := 
ApiErrors[JSInvalidJSONErr] + args := e.toReplacerArgs([]interface{}{"{err}", err}) + return &ApiError{ + Code: e.Code, + ErrCode: e.ErrCode, + Description: strings.NewReplacer(args...).Replace(e.Description), + } } // NewJSMaximumConsumersLimitError creates a new JSMaximumConsumersLimitErr error: "maximum consumers limit reached" @@ -1457,6 +1559,26 @@ func NewJSMemoryResourcesExceededError(opts ...ErrorOption) *ApiError { return ApiErrors[JSMemoryResourcesExceededErr] } +// NewJSMessageTTLDisabledError creates a new JSMessageTTLDisabledErr error: "per-message TTL is disabled" +func NewJSMessageTTLDisabledError(opts ...ErrorOption) *ApiError { + eopts := parseOpts(opts) + if ae, ok := eopts.err.(*ApiError); ok { + return ae + } + + return ApiErrors[JSMessageTTLDisabledErr] +} + +// NewJSMessageTTLInvalidError creates a new JSMessageTTLInvalidErr error: "invalid per-message TTL" +func NewJSMessageTTLInvalidError(opts ...ErrorOption) *ApiError { + eopts := parseOpts(opts) + if ae, ok := eopts.err.(*ApiError); ok { + return ae + } + + return ApiErrors[JSMessageTTLInvalidErr] +} + // NewJSMirrorConsumerSetupFailedError creates a new JSMirrorConsumerSetupFailedErrF error: "{err}" func NewJSMirrorConsumerSetupFailedError(err error, opts ...ErrorOption) *ApiError { eopts := parseOpts(opts) @@ -1483,14 +1605,36 @@ func NewJSMirrorInvalidStreamNameError(opts ...ErrorOption) *ApiError { return ApiErrors[JSMirrorInvalidStreamName] } -// NewJSMirrorInvalidSubjectFilterError creates a new JSMirrorInvalidSubjectFilter error: "mirror subject filter is invalid" -func NewJSMirrorInvalidSubjectFilterError(opts ...ErrorOption) *ApiError { +// NewJSMirrorInvalidSubjectFilterError creates a new JSMirrorInvalidSubjectFilter error: "mirror transform source: {err}" +func NewJSMirrorInvalidSubjectFilterError(err error, opts ...ErrorOption) *ApiError { eopts := parseOpts(opts) if ae, ok := eopts.err.(*ApiError); ok { return ae } - return ApiErrors[JSMirrorInvalidSubjectFilter] + e := 
ApiErrors[JSMirrorInvalidSubjectFilter] + args := e.toReplacerArgs([]interface{}{"{err}", err}) + return &ApiError{ + Code: e.Code, + ErrCode: e.ErrCode, + Description: strings.NewReplacer(args...).Replace(e.Description), + } +} + +// NewJSMirrorInvalidTransformDestinationError creates a new JSMirrorInvalidTransformDestination error: "mirror transform: {err}" +func NewJSMirrorInvalidTransformDestinationError(err error, opts ...ErrorOption) *ApiError { + eopts := parseOpts(opts) + if ae, ok := eopts.err.(*ApiError); ok { + return ae + } + + e := ApiErrors[JSMirrorInvalidTransformDestination] + args := e.toReplacerArgs([]interface{}{"{err}", err}) + return &ApiError{ + Code: e.Code, + ErrCode: e.ErrCode, + Description: strings.NewReplacer(args...).Replace(e.Description), + } } // NewJSMirrorMaxMessageSizeTooBigError creates a new JSMirrorMaxMessageSizeTooBigErr error: "stream mirror must have max message size >= source" @@ -1633,6 +1777,22 @@ func NewJSNotEnabledForAccountError(opts ...ErrorOption) *ApiError { return ApiErrors[JSNotEnabledForAccountErr] } +// NewJSPedanticError creates a new JSPedanticErrF error: "pedantic mode: {err}" +func NewJSPedanticError(err error, opts ...ErrorOption) *ApiError { + eopts := parseOpts(opts) + if ae, ok := eopts.err.(*ApiError); ok { + return ae + } + + e := ApiErrors[JSPedanticErrF] + args := e.toReplacerArgs([]interface{}{"{err}", err}) + return &ApiError{ + Code: e.Code, + ErrCode: e.ErrCode, + Description: strings.NewReplacer(args...).Replace(e.Description), + } +} + // NewJSPeerRemapError creates a new JSPeerRemapErr error: "peer remap failed" func NewJSPeerRemapError(opts ...ErrorOption) *ApiError { eopts := parseOpts(opts) @@ -1747,24 +1907,36 @@ func NewJSSourceInvalidStreamNameError(opts ...ErrorOption) *ApiError { return ApiErrors[JSSourceInvalidStreamName] } -// NewJSSourceInvalidSubjectFilterError creates a new JSSourceInvalidSubjectFilter error: "source subject filter is invalid" -func 
NewJSSourceInvalidSubjectFilterError(opts ...ErrorOption) *ApiError { +// NewJSSourceInvalidSubjectFilterError creates a new JSSourceInvalidSubjectFilter error: "source transform source: {err}" +func NewJSSourceInvalidSubjectFilterError(err error, opts ...ErrorOption) *ApiError { eopts := parseOpts(opts) if ae, ok := eopts.err.(*ApiError); ok { return ae } - return ApiErrors[JSSourceInvalidSubjectFilter] + e := ApiErrors[JSSourceInvalidSubjectFilter] + args := e.toReplacerArgs([]interface{}{"{err}", err}) + return &ApiError{ + Code: e.Code, + ErrCode: e.ErrCode, + Description: strings.NewReplacer(args...).Replace(e.Description), + } } -// NewJSSourceInvalidTransformDestinationError creates a new JSSourceInvalidTransformDestination error: "source transform destination is invalid" -func NewJSSourceInvalidTransformDestinationError(opts ...ErrorOption) *ApiError { +// NewJSSourceInvalidTransformDestinationError creates a new JSSourceInvalidTransformDestination error: "source transform: {err}" +func NewJSSourceInvalidTransformDestinationError(err error, opts ...ErrorOption) *ApiError { eopts := parseOpts(opts) if ae, ok := eopts.err.(*ApiError); ok { return ae } - return ApiErrors[JSSourceInvalidTransformDestination] + e := ApiErrors[JSSourceInvalidTransformDestination] + args := e.toReplacerArgs([]interface{}{"{err}", err}) + return &ApiError{ + Code: e.Code, + ErrCode: e.ErrCode, + Description: strings.NewReplacer(args...).Replace(e.Description), + } } // NewJSSourceMaxMessageSizeTooBigError creates a new JSSourceMaxMessageSizeTooBigErr error: "stream source must have max message size >= target" @@ -1855,6 +2027,26 @@ func NewJSStreamDeleteError(err error, opts ...ErrorOption) *ApiError { } } +// NewJSStreamDuplicateMessageConflictError creates a new JSStreamDuplicateMessageConflict error: "duplicate message id is in process" +func NewJSStreamDuplicateMessageConflictError(opts ...ErrorOption) *ApiError { + eopts := parseOpts(opts) + if ae, ok := eopts.err.(*ApiError); 
ok { + return ae + } + + return ApiErrors[JSStreamDuplicateMessageConflict] +} + +// NewJSStreamExpectedLastSeqPerSubjectNotReadyError creates a new JSStreamExpectedLastSeqPerSubjectNotReady error: "expected last sequence per subject temporarily unavailable" +func NewJSStreamExpectedLastSeqPerSubjectNotReadyError(opts ...ErrorOption) *ApiError { + eopts := parseOpts(opts) + if ae, ok := eopts.err.(*ApiError); ok { + return ae + } + + return ApiErrors[JSStreamExpectedLastSeqPerSubjectNotReady] +} + // NewJSStreamExternalApiOverlapError creates a new JSStreamExternalApiOverlapErrF error: "stream external api prefix {prefix} must not overlap with {subject}" func NewJSStreamExternalApiOverlapError(prefix interface{}, subject interface{}, opts ...ErrorOption) *ApiError { eopts := parseOpts(opts) @@ -2315,6 +2507,48 @@ func NewJSStreamTemplateNotFoundError(opts ...ErrorOption) *ApiError { return ApiErrors[JSStreamTemplateNotFoundErr] } +// NewJSStreamTooManyRequestsError creates a new JSStreamTooManyRequests error: "too many requests" +func NewJSStreamTooManyRequestsError(opts ...ErrorOption) *ApiError { + eopts := parseOpts(opts) + if ae, ok := eopts.err.(*ApiError); ok { + return ae + } + + return ApiErrors[JSStreamTooManyRequests] +} + +// NewJSStreamTransformInvalidDestinationError creates a new JSStreamTransformInvalidDestination error: "stream transform: {err}" +func NewJSStreamTransformInvalidDestinationError(err error, opts ...ErrorOption) *ApiError { + eopts := parseOpts(opts) + if ae, ok := eopts.err.(*ApiError); ok { + return ae + } + + e := ApiErrors[JSStreamTransformInvalidDestination] + args := e.toReplacerArgs([]interface{}{"{err}", err}) + return &ApiError{ + Code: e.Code, + ErrCode: e.ErrCode, + Description: strings.NewReplacer(args...).Replace(e.Description), + } +} + +// NewJSStreamTransformInvalidSourceError creates a new JSStreamTransformInvalidSource error: "stream transform source: {err}" +func NewJSStreamTransformInvalidSourceError(err error, opts 
...ErrorOption) *ApiError { + eopts := parseOpts(opts) + if ae, ok := eopts.err.(*ApiError); ok { + return ae + } + + e := ApiErrors[JSStreamTransformInvalidSource] + args := e.toReplacerArgs([]interface{}{"{err}", err}) + return &ApiError{ + Code: e.Code, + ErrCode: e.ErrCode, + Description: strings.NewReplacer(args...).Replace(e.Description), + } +} + // NewJSStreamUpdateError creates a new JSStreamUpdateErrF error: "{err}" func NewJSStreamUpdateError(err error, opts ...ErrorOption) *ApiError { eopts := parseOpts(opts) @@ -2347,6 +2581,16 @@ func NewJSStreamWrongLastMsgIDError(id interface{}, opts ...ErrorOption) *ApiErr } } +// NewJSStreamWrongLastSequenceConstantError creates a new JSStreamWrongLastSequenceConstantErr error: "wrong last sequence" +func NewJSStreamWrongLastSequenceConstantError(opts ...ErrorOption) *ApiError { + eopts := parseOpts(opts) + if ae, ok := eopts.err.(*ApiError); ok { + return ae + } + + return ApiErrors[JSStreamWrongLastSequenceConstantErr] +} + // NewJSStreamWrongLastSequenceError creates a new JSStreamWrongLastSequenceErrF error: "wrong last sequence: {seq}" func NewJSStreamWrongLastSequenceError(seq uint64, opts ...ErrorOption) *ApiError { eopts := parseOpts(opts) diff --git a/vendor/github.com/nats-io/nats-server/v2/server/jetstream_events.go b/vendor/github.com/nats-io/nats-server/v2/server/jetstream_events.go index f813efc483..6b4ac9de8c 100644 --- a/vendor/github.com/nats-io/nats-server/v2/server/jetstream_events.go +++ b/vendor/github.com/nats-io/nats-server/v2/server/jetstream_events.go @@ -90,6 +90,18 @@ type JSConsumerActionAdvisory struct { const JSConsumerActionAdvisoryType = "io.nats.jetstream.advisory.v1.consumer_action" +// JSConsumerPauseAdvisory indicates that a consumer was paused or unpaused +type JSConsumerPauseAdvisory struct { + TypedEvent + Stream string `json:"stream"` + Consumer string `json:"consumer"` + Paused bool `json:"paused"` + PauseUntil time.Time `json:"pause_until,omitempty"` + Domain string 
`json:"domain,omitempty"` +} + +const JSConsumerPauseAdvisoryType = "io.nats.jetstream.advisory.v1.consumer_pause" + // JSConsumerAckMetric is a metric published when a user acknowledges a message, the // number of these that will be published is dependent on SampleFrequency type JSConsumerAckMetric struct { @@ -269,6 +281,33 @@ type JSConsumerQuorumLostAdvisory struct { Domain string `json:"domain,omitempty"` } +const JSConsumerGroupPinnedAdvisoryType = "io.nats.jetstream.advisory.v1.consumer_group_pinned" + +// JSConsumerGroupPinnedAdvisory that a group switched to a new pinned client +type JSConsumerGroupPinnedAdvisory struct { + TypedEvent + Account string `json:"account,omitempty"` + Stream string `json:"stream"` + Consumer string `json:"consumer"` + Domain string `json:"domain,omitempty"` + Group string `json:"group"` + PinnedClientId string `json:"pinned_id"` +} + +const JSConsumerGroupUnpinnedAdvisoryType = "io.nats.jetstream.advisory.v1.consumer_group_unpinned" + +// JSConsumerGroupUnpinnedAdvisory indicates that a pin was lost +type JSConsumerGroupUnpinnedAdvisory struct { + TypedEvent + Account string `json:"account,omitempty"` + Stream string `json:"stream"` + Consumer string `json:"consumer"` + Domain string `json:"domain,omitempty"` + Group string `json:"group"` + // one of "admin" or "timeout", could be an enum up to the implementor to decide + Reason string `json:"reason"` +} + // JSServerOutOfStorageAdvisoryType is sent when the server is out of storage space. 
const JSServerOutOfStorageAdvisoryType = "io.nats.jetstream.advisory.v1.server_out_of_space" diff --git a/vendor/github.com/nats-io/nats-server/v2/server/jetstream_versioning.go b/vendor/github.com/nats-io/nats-server/v2/server/jetstream_versioning.go new file mode 100644 index 0000000000..1192968b88 --- /dev/null +++ b/vendor/github.com/nats-io/nats-server/v2/server/jetstream_versioning.go @@ -0,0 +1,179 @@ +// Copyright 2024-2025 The NATS Authors +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package server + +import "strconv" + +const ( + // JSApiLevel is the maximum supported JetStream API level for this server. + JSApiLevel int = 1 + + JSRequiredLevelMetadataKey = "_nats.req.level" + JSServerVersionMetadataKey = "_nats.ver" + JSServerLevelMetadataKey = "_nats.level" +) + +// setStaticStreamMetadata sets JetStream stream metadata, like the server version and API level. +// Any dynamic metadata is removed, it must not be stored and only be added for responses. +func setStaticStreamMetadata(cfg *StreamConfig) { + if cfg.Metadata == nil { + cfg.Metadata = make(map[string]string) + } else { + deleteDynamicMetadata(cfg.Metadata) + } + + var requiredApiLevel int + requires := func(level int) { + if level > requiredApiLevel { + requiredApiLevel = level + } + } + + // TTLs were added in v2.11 and require API level 1. 
+ if cfg.AllowMsgTTL || cfg.SubjectDeleteMarkerTTL > 0 { + requires(1) + } + + cfg.Metadata[JSRequiredLevelMetadataKey] = strconv.Itoa(requiredApiLevel) +} + +// setDynamicStreamMetadata adds dynamic fields into the (copied) metadata. +func setDynamicStreamMetadata(cfg *StreamConfig) *StreamConfig { + newCfg := *cfg + newCfg.Metadata = make(map[string]string) + for key, value := range cfg.Metadata { + newCfg.Metadata[key] = value + } + newCfg.Metadata[JSServerVersionMetadataKey] = VERSION + newCfg.Metadata[JSServerLevelMetadataKey] = strconv.Itoa(JSApiLevel) + return &newCfg +} + +// copyConsumerMetadata copies versioning fields from metadata of prevCfg into cfg. +// Removes versioning fields if no previous metadata, updates if set, and removes fields if it doesn't exist in prevCfg. +// Any dynamic metadata is removed, it must not be stored and only be added for responses. +// +// Note: useful when doing equality checks on cfg and prevCfg, but ignoring any versioning metadata differences. +func copyStreamMetadata(cfg *StreamConfig, prevCfg *StreamConfig) { + if cfg.Metadata != nil { + deleteDynamicMetadata(cfg.Metadata) + } + setOrDeleteInStreamMetadata(cfg, prevCfg, JSRequiredLevelMetadataKey) +} + +// setOrDeleteInConsumerMetadata sets field with key/value in metadata of cfg if set, deletes otherwise. +func setOrDeleteInStreamMetadata(cfg *StreamConfig, prevCfg *StreamConfig, key string) { + if prevCfg != nil && prevCfg.Metadata != nil { + if value, ok := prevCfg.Metadata[key]; ok { + if cfg.Metadata == nil { + cfg.Metadata = make(map[string]string) + } + cfg.Metadata[key] = value + return + } + } + delete(cfg.Metadata, key) + if len(cfg.Metadata) == 0 { + cfg.Metadata = nil + } +} + +// setStaticConsumerMetadata sets JetStream consumer metadata, like the server version and API level. +// Any dynamic metadata is removed, it must not be stored and only be added for responses. 
+func setStaticConsumerMetadata(cfg *ConsumerConfig) { + if cfg.Metadata == nil { + cfg.Metadata = make(map[string]string) + } else { + deleteDynamicMetadata(cfg.Metadata) + } + + var requiredApiLevel int + requires := func(level int) { + if level > requiredApiLevel { + requiredApiLevel = level + } + } + + // Added in 2.11, absent | zero is the feature is not used. + // one could be stricter and say even if its set but the time + // has already passed it is also not needed to restore the consumer + if cfg.PauseUntil != nil && !cfg.PauseUntil.IsZero() { + requires(1) + } + + if cfg.PriorityPolicy != PriorityNone || cfg.PinnedTTL != 0 || len(cfg.PriorityGroups) > 0 { + requires(1) + } + + cfg.Metadata[JSRequiredLevelMetadataKey] = strconv.Itoa(requiredApiLevel) +} + +// setDynamicConsumerMetadata adds dynamic fields into the (copied) metadata. +func setDynamicConsumerMetadata(cfg *ConsumerConfig) *ConsumerConfig { + newCfg := *cfg + newCfg.Metadata = make(map[string]string) + for key, value := range cfg.Metadata { + newCfg.Metadata[key] = value + } + newCfg.Metadata[JSServerVersionMetadataKey] = VERSION + newCfg.Metadata[JSServerLevelMetadataKey] = strconv.Itoa(JSApiLevel) + return &newCfg +} + +// setDynamicConsumerInfoMetadata adds dynamic fields into the (copied) metadata. +func setDynamicConsumerInfoMetadata(info *ConsumerInfo) *ConsumerInfo { + if info == nil { + return nil + } + + newInfo := *info + cfg := setDynamicConsumerMetadata(info.Config) + newInfo.Config = cfg + return &newInfo +} + +// copyConsumerMetadata copies versioning fields from metadata of prevCfg into cfg. +// Removes versioning fields if no previous metadata, updates if set, and removes fields if it doesn't exist in prevCfg. +// Any dynamic metadata is removed, it must not be stored and only be added for responses. +// +// Note: useful when doing equality checks on cfg and prevCfg, but ignoring any versioning metadata differences. 
+func copyConsumerMetadata(cfg *ConsumerConfig, prevCfg *ConsumerConfig) { + if cfg.Metadata != nil { + deleteDynamicMetadata(cfg.Metadata) + } + setOrDeleteInConsumerMetadata(cfg, prevCfg, JSRequiredLevelMetadataKey) +} + +// setOrDeleteInConsumerMetadata sets field with key/value in metadata of cfg if set, deletes otherwise. +func setOrDeleteInConsumerMetadata(cfg *ConsumerConfig, prevCfg *ConsumerConfig, key string) { + if prevCfg != nil && prevCfg.Metadata != nil { + if value, ok := prevCfg.Metadata[key]; ok { + if cfg.Metadata == nil { + cfg.Metadata = make(map[string]string) + } + cfg.Metadata[key] = value + return + } + } + delete(cfg.Metadata, key) + if len(cfg.Metadata) == 0 { + cfg.Metadata = nil + } +} + +// deleteDynamicMetadata deletes dynamic fields from the metadata. +func deleteDynamicMetadata(metadata map[string]string) { + delete(metadata, JSServerVersionMetadataKey) + delete(metadata, JSServerLevelMetadataKey) +} diff --git a/vendor/github.com/nats-io/nats-server/v2/server/jwt.go b/vendor/github.com/nats-io/nats-server/v2/server/jwt.go index b900326561..ada2d91f43 100644 --- a/vendor/github.com/nats-io/nats-server/v2/server/jwt.go +++ b/vendor/github.com/nats-io/nats-server/v2/server/jwt.go @@ -1,4 +1,4 @@ -// Copyright 2018-2022 The NATS Authors +// Copyright 2018-2024 The NATS Authors // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. 
// You may obtain a copy of the License at @@ -106,8 +106,8 @@ func validateTrustedOperators(o *Options) error { return fmt.Errorf("using nats based account resolver - the system account needs to be specified in configuration or the operator jwt") } } - ver := strings.Split(strings.Split(strings.Split(VERSION, "-")[0], ".RC")[0], ".beta")[0] - srvMajor, srvMinor, srvUpdate, _ := jwt.ParseServerVersion(ver) + + srvMajor, srvMinor, srvUpdate, _ := versionComponents(VERSION) for _, opc := range o.TrustedOperators { if major, minor, update, err := jwt.ParseServerVersion(opc.AssertServerVersion); err != nil { return fmt.Errorf("operator %s expects version %s got error instead: %s", diff --git a/vendor/github.com/nats-io/nats-server/v2/server/leafnode.go b/vendor/github.com/nats-io/nats-server/v2/server/leafnode.go index 1ec4cc1849..4904aee2f2 100644 --- a/vendor/github.com/nats-io/nats-server/v2/server/leafnode.go +++ b/vendor/github.com/nats-io/nats-server/v2/server/leafnode.go @@ -20,6 +20,7 @@ import ( "encoding/base64" "encoding/json" "fmt" + "io" "math/rand" "net" "net/http" @@ -34,7 +35,6 @@ import ( "sync" "sync/atomic" "time" - "unicode" "github.com/klauspost/compress/s2" "github.com/nats-io/jwt/v2" @@ -105,13 +105,14 @@ type leaf struct { type leafNodeCfg struct { sync.RWMutex *RemoteLeafOpts - urls []*url.URL - curURL *url.URL - tlsName string - username string - password string - perms *Permissions - connDelay time.Duration // Delay before a connect, could be used while detecting loop condition, etc.. + urls []*url.URL + curURL *url.URL + tlsName string + username string + password string + perms *Permissions + connDelay time.Duration // Delay before a connect, could be used while detecting loop condition, etc.. + jsMigrateTimer *time.Timer } // Check to see if this is a solicited leafnode. We do special processing for solicited. 
@@ -493,14 +494,15 @@ func (s *Server) connectToRemoteLeafNode(remote *leafNodeCfg, firstConnect bool) opts := s.getOpts() reconnectDelay := opts.LeafNode.ReconnectInterval - s.mu.Lock() + s.mu.RLock() dialTimeout := s.leafNodeOpts.dialTimeout resolver := s.leafNodeOpts.resolver var isSysAcc bool if s.eventsEnabled() { isSysAcc = remote.LocalAccount == s.sys.account.Name } - s.mu.Unlock() + jetstreamMigrateDelay := remote.JetStreamClusterMigrateDelay + s.mu.RUnlock() // If we are sharing a system account and we are not standalone delay to gather some info prior. if firstConnect && isSysAcc && !s.standAloneMode() { @@ -522,6 +524,7 @@ func (s *Server) connectToRemoteLeafNode(remote *leafNodeCfg, firstConnect bool) const connErrFmt = "Error trying to connect as leafnode to remote server %q (attempt %v): %v" attempts := 0 + for s.isRunning() && s.remoteLeafNodeStillValid(remote) { rURL := remote.pickNextURL() url, err := s.getRandomIP(resolver, rURL.Host, nil) @@ -548,15 +551,28 @@ func (s *Server) connectToRemoteLeafNode(remote *leafNodeCfg, firstConnect bool) } else { s.Debugf(connErrFmt, rURL.Host, attempts, err) } + remote.Lock() + // if we are using a delay to start migrating assets, kick off a migrate timer. + if remote.jsMigrateTimer == nil && jetstreamMigrateDelay > 0 { + remote.jsMigrateTimer = time.AfterFunc(jetstreamMigrateDelay, func() { + s.checkJetStreamMigrate(remote) + }) + } + remote.Unlock() select { case <-s.quitCh: + remote.cancelMigrateTimer() return case <-time.After(delay): - // Check if we should migrate any JetStream assets while this remote is down. - s.checkJetStreamMigrate(remote) + // Check if we should migrate any JetStream assets immediately while this remote is down. 
+ // This will be used if JetStreamClusterMigrateDelay was not set + if jetstreamMigrateDelay == 0 { + s.checkJetStreamMigrate(remote) + } continue } } + remote.cancelMigrateTimer() if !s.remoteLeafNodeStillValid(remote) { conn.Close() return @@ -573,6 +589,12 @@ func (s *Server) connectToRemoteLeafNode(remote *leafNodeCfg, firstConnect bool) } } +func (cfg *leafNodeCfg) cancelMigrateTimer() { + cfg.Lock() + stopAndClearTimer(&cfg.jsMigrateTimer) + cfg.Unlock() +} + // This will clear any observer state such that stream or consumer assets on this server can become leaders again. func (s *Server) clearObserverState(remote *leafNodeCfg) { s.mu.RLock() @@ -637,17 +659,13 @@ func (s *Server) checkJetStreamMigrate(remote *leafNodeCfg) { // Collect any consumers for _, o := range mset.getConsumers() { if n := o.raftNode(); n != nil { - if n.Leader() { - n.StepDown() - } + n.StepDown() // Ensure we can not become a leader while in this state. n.SetObserver(true) } } // Stepdown if this stream was leader. - if node.Leader() { - node.StepDown() - } + node.StepDown() // Ensure we can not become a leader while in this state. node.SetObserver(true) } @@ -728,7 +746,7 @@ func (s *Server) startLeafNodeAcceptLoop() { Headers: s.supportsHeaders(), JetStream: opts.JetStream, Domain: opts.JetStreamDomain, - Proto: 1, // Fixed for now. + Proto: s.getServerProto(), InfoOnConnect: true, } // If we have selected a random port... @@ -792,6 +810,7 @@ func (c *client) sendLeafConnect(clusterName string, headers bool) error { DenyPub: c.leaf.remote.DenyImports, Compression: c.leaf.compression, RemoteAccount: c.acc.GetName(), + Proto: c.srv.getServerProto(), } // If a signature callback is specified, this takes precedence over anything else. 
@@ -996,8 +1015,11 @@ func (s *Server) createLeafNode(conn net.Conn, rURL *url.URL, remote *leafNodeCf c.initClient() c.Noticef("Leafnode connection created%s %s", remoteSuffix, c.opts.Name) - var tlsFirst bool - var infoTimeout time.Duration + var ( + tlsFirst bool + tlsFirstFallback time.Duration + infoTimeout time.Duration + ) if remote != nil { solicited = true remote.Lock() @@ -1015,6 +1037,10 @@ func (s *Server) createLeafNode(conn net.Conn, rURL *url.URL, remote *leafNodeCf if ws != nil { c.Debugf("Leafnode compression=%v", c.ws.compress) } + tlsFirst = opts.LeafNode.TLSHandshakeFirst + if f := opts.LeafNode.TLSHandshakeFirstFallback; f > 0 { + tlsFirstFallback = f + } } c.mu.Unlock() @@ -1078,7 +1104,33 @@ func (s *Server) createLeafNode(conn net.Conn, rURL *url.URL, remote *leafNodeCf info.Nonce = bytesToString(c.nonce) info.CID = c.cid proto := generateInfoJSON(info) - if !opts.LeafNode.TLSHandshakeFirst { + + var pre []byte + // We need first to check for "TLS First" fallback delay. + if tlsFirstFallback > 0 { + // We wait and see if we are getting any data. Since we did not send + // the INFO protocol yet, only clients that use TLS first should be + // sending data (the TLS handshake). We don't really check the content: + // if it is a rogue agent and not an actual client performing the + // TLS handshake, the error will be detected when performing the + // handshake on our side. + pre = make([]byte, 4) + c.nc.SetReadDeadline(time.Now().Add(tlsFirstFallback)) + n, _ := io.ReadFull(c.nc, pre[:]) + c.nc.SetReadDeadline(time.Time{}) + // If we get any data (regardless of possible timeout), we will proceed + // with the TLS handshake. + if n > 0 { + pre = pre[:n] + } else { + // We did not get anything so we will send the INFO protocol. + pre = nil + // Set the boolean to false for the rest of the function. 
+ tlsFirst = false + } + } + + if !tlsFirst { // We have to send from this go routine because we may // have to block for TLS handshake before we start our // writeLoop go routine. The other side needs to receive @@ -1095,6 +1147,10 @@ func (s *Server) createLeafNode(conn net.Conn, rURL *url.URL, remote *leafNodeCf // Check to see if we need to spin up TLS. if !c.isWebsocket() && info.TLSRequired { + // If we have a prebuffer create a multi-reader. + if len(pre) > 0 { + c.nc = &tlsMixConn{c.nc, bytes.NewBuffer(pre)} + } // Perform server-side TLS handshake. if err := c.doTLSServerHandshake(tlsHandshakeLeaf, opts.LeafNode.TLSConfig, opts.LeafNode.TLSTimeout, opts.LeafNode.TLSPinnedCerts); err != nil { c.mu.Unlock() @@ -1104,7 +1160,8 @@ func (s *Server) createLeafNode(conn net.Conn, rURL *url.URL, remote *leafNodeCf // If the user wants the TLS handshake to occur first, now that it is // done, send the INFO protocol. - if opts.LeafNode.TLSHandshakeFirst { + if tlsFirst { + c.flags.set(didTLSFirst) c.sendProtoNow(proto) if c.isClosed() { c.mu.Unlock() @@ -1297,6 +1354,13 @@ func (c *client) processLeafnodeInfo(info *Info) { c.closeConnection(WrongPort) return } + // Reject a cluster that contains spaces. + if info.Cluster != _EMPTY_ && strings.Contains(info.Cluster, " ") { + c.mu.Unlock() + c.sendErrAndErr(ErrClusterNameHasSpaces.Error()) + c.closeConnection(ProtocolViolation) + return + } // Capture a nonce here. c.nonce = []byte(info.Nonce) if info.TLSRequired && didSolicit { @@ -1316,6 +1380,10 @@ func (c *client) processLeafnodeInfo(info *Info) { } c.leaf.remoteDomain = info.Domain c.leaf.remoteCluster = info.Cluster + // We send the protocol version in the INFO protocol. + // Keep track of it, so we know if this connection supports message + // tracing for instance. 
+ c.opts.Protocol = info.Proto } // For both initial INFO and async INFO protocols, Possibly @@ -1757,6 +1825,14 @@ type leafConnectInfo struct { // Tells the accept side which account the remote is binding to. RemoteAccount string `json:"remote_account,omitempty"` + + // The accept side of a LEAF connection, unlike ROUTER and GATEWAY, receives + // only the CONNECT protocol, and no INFO. So we need to send the protocol + // version as part of the CONNECT. It will indicate if a connection supports + // some features, such as message tracing. + // We use `protocol` as the JSON tag, so this is automatically unmarshal'ed + // in the low level process CONNECT. + Proto int `json:"protocol,omitempty"` } // processLeafNodeConnect will process the inbound connect args. @@ -1777,8 +1853,8 @@ func (c *client) processLeafNodeConnect(s *Server, arg []byte, lang string) erro return err } - // Reject a cluster that contains spaces or line breaks. - if proto.Cluster != _EMPTY_ && strings.ContainsFunc(proto.Cluster, unicode.IsSpace) { + // Reject a cluster that contains spaces. + if proto.Cluster != _EMPTY_ && strings.Contains(proto.Cluster, " ") { c.sendErrAndErr(ErrClusterNameHasSpaces.Error()) c.closeConnection(ProtocolViolation) return ErrClusterNameHasSpaces diff --git a/vendor/github.com/nats-io/nats-server/v2/server/memstore.go b/vendor/github.com/nats-io/nats-server/v2/server/memstore.go index a72e1a4249..297787f9d3 100644 --- a/vendor/github.com/nats-io/nats-server/v2/server/memstore.go +++ b/vendor/github.com/nats-io/nats-server/v2/server/memstore.go @@ -17,12 +17,15 @@ import ( crand "crypto/rand" "encoding/binary" "fmt" + "math" + "slices" "sort" "sync" "time" "github.com/nats-io/nats-server/v2/server/avl" "github.com/nats-io/nats-server/v2/server/stree" + "github.com/nats-io/nats-server/v2/server/thw" ) // TODO(dlc) - This is a fairly simplistic approach but should do for now. 
@@ -35,9 +38,12 @@ type memStore struct { dmap avl.SequenceSet maxp int64 scb StorageUpdateHandler + sdmcb SubjectDeleteMarkerUpdateHandler ageChk *time.Timer consumers int receivedAny bool + ttls *thw.HashWheel + markers []string } func newMemStore(cfg *StreamConfig) (*memStore, error) { @@ -53,8 +59,12 @@ func newMemStore(cfg *StreamConfig) (*memStore, error) { maxp: cfg.MaxMsgsPer, cfg: *cfg, } + // Only create a THW if we're going to allow TTLs. + if cfg.AllowMsgTTL { + ms.ttls = thw.NewHashWheel() + } if cfg.FirstSeq > 0 { - if _, err := ms.purge(cfg.FirstSeq); err != nil { + if _, err := ms.purge(cfg.FirstSeq, true); err != nil { return nil, err } } @@ -109,7 +119,7 @@ func (ms *memStore) UpdateConfig(cfg *StreamConfig) error { // Stores a raw message with expected sequence number and timestamp. // Lock should be held. -func (ms *memStore) storeRawMsg(subj string, hdr, msg []byte, seq uint64, ts int64) error { +func (ms *memStore) storeRawMsg(subj string, hdr, msg []byte, seq uint64, ts, ttl int64) error { if ms.msgs == nil { return ErrStoreClosed } @@ -210,17 +220,27 @@ func (ms *memStore) storeRawMsg(subj string, hdr, msg []byte, seq uint64, ts int ms.enforceMsgLimit() ms.enforceBytesLimit() + // Per-message TTL. + if ms.ttls != nil && ttl > 0 { + expires := time.Duration(ts) + (time.Second * time.Duration(ttl)) + ms.ttls.Add(seq, int64(expires)) + } + // Check if we have and need the age expiration timer running. - if ms.ageChk == nil && ms.cfg.MaxAge != 0 { + switch { + case ms.ttls != nil && ttl > 0: + ms.resetAgeChk(0) + case ms.ageChk == nil && (ms.cfg.MaxAge > 0 || ms.ttls != nil): ms.startAgeChk() } + return nil } // StoreRawMsg stores a raw message with expected sequence number and timestamp. 
-func (ms *memStore) StoreRawMsg(subj string, hdr, msg []byte, seq uint64, ts int64) error { +func (ms *memStore) StoreRawMsg(subj string, hdr, msg []byte, seq uint64, ts, ttl int64) error { ms.mu.Lock() - err := ms.storeRawMsg(subj, hdr, msg, seq, ts) + err := ms.storeRawMsg(subj, hdr, msg, seq, ts, ttl) cb := ms.scb // Check if first message timestamp requires expiry // sooner than initial replica expiry timer set to MaxAge when initializing. @@ -239,10 +259,10 @@ func (ms *memStore) StoreRawMsg(subj string, hdr, msg []byte, seq uint64, ts int } // Store stores a message. -func (ms *memStore) StoreMsg(subj string, hdr, msg []byte) (uint64, int64, error) { +func (ms *memStore) StoreMsg(subj string, hdr, msg []byte, ttl int64) (uint64, int64, error) { ms.mu.Lock() seq, ts := ms.state.LastSeq+1, time.Now().UnixNano() - err := ms.storeRawMsg(subj, hdr, msg, seq, ts) + err := ms.storeRawMsg(subj, hdr, msg, seq, ts, ttl) cb := ms.scb ms.mu.Unlock() @@ -312,6 +332,13 @@ func (ms *memStore) RegisterStorageUpdates(cb StorageUpdateHandler) { ms.mu.Unlock() } +// RegisterSubjectDeleteMarkerUpdates registers a callback for updates to new subject delete markers. +func (ms *memStore) RegisterSubjectDeleteMarkerUpdates(cb SubjectDeleteMarkerUpdateHandler) { + ms.mu.Lock() + ms.sdmcb = cb + ms.mu.Unlock() +} + // GetSeqFromTime looks for the first sequence number that has the message // with >= timestamp. // FIXME(dlc) - inefficient. @@ -601,7 +628,55 @@ func (ms *memStore) SubjectsState(subject string) map[string]SimpleState { return fss } -// SubjectsTotal return message totals per subject. +func (ms *memStore) MultiLastSeqs(filters []string, maxSeq uint64, maxAllowed int) ([]uint64, error) { + ms.mu.RLock() + defer ms.mu.RUnlock() + + if len(ms.msgs) == 0 { + return nil, nil + } + + // Implied last sequence. 
+ if maxSeq == 0 { + maxSeq = ms.state.LastSeq + } + + //subs := make(map[string]*SimpleState) + seqs := make([]uint64, 0, 64) + seen := make(map[uint64]struct{}) + + addIfNotDupe := func(seq uint64) { + if _, ok := seen[seq]; !ok { + seqs = append(seqs, seq) + seen[seq] = struct{}{} + } + } + + for _, filter := range filters { + ms.fss.Match(stringToBytes(filter), func(subj []byte, ss *SimpleState) { + if ss.Last <= maxSeq { + addIfNotDupe(ss.Last) + } else if ss.Msgs > 1 { + // The last is greater than maxSeq. + s := bytesToString(subj) + for seq := maxSeq; seq > 0; seq-- { + if sm, ok := ms.msgs[seq]; ok && sm.subj == s { + addIfNotDupe(seq) + break + } + } + } + }) + // If maxAllowed was sepcified check that we will not exceed that. + if maxAllowed > 0 && len(seqs) > maxAllowed { + return nil, ErrTooManyResults + } + } + slices.Sort(seqs) + return seqs, nil +} + +// SubjectsTotals return message totals per subject. func (ms *memStore) SubjectsTotals(filterSubject string) map[string]uint64 { ms.mu.RLock() defer ms.mu.RUnlock() @@ -797,7 +872,7 @@ func (ms *memStore) enforcePerSubjectLimit(subj string, ss *SimpleState) { if ss.firstNeedsUpdate || ss.lastNeedsUpdate { ms.recalculateForSubj(subj, ss) } - if !ms.removeMsg(ss.First, false) { + if !ms.removeMsg(ss.First, false, _EMPTY_) { break } } @@ -834,25 +909,52 @@ func (ms *memStore) enforceBytesLimit() { // Will start the age check timer. // Lock should be held. func (ms *memStore) startAgeChk() { - if ms.ageChk == nil && ms.cfg.MaxAge != 0 { + if ms.ageChk != nil { + return + } + if ms.cfg.MaxAge != 0 || ms.ttls != nil { ms.ageChk = time.AfterFunc(ms.cfg.MaxAge, ms.expireMsgs) } } // Lock should be held. func (ms *memStore) resetAgeChk(delta int64) { - if ms.cfg.MaxAge == 0 { + var next int64 = math.MaxInt64 + if ms.ttls != nil { + next = ms.ttls.GetNextExpiration(next) + } + + // If there's no MaxAge and there's nothing waiting to be expired then + // don't bother continuing. 
The next storeRawMsg() will wake us up if + // needs be. + if ms.cfg.MaxAge <= 0 && next == math.MaxInt64 { + clearTimer(&ms.ageChk) return } + // Check to see if we should be firing sooner than MaxAge for an expiring TTL. fireIn := ms.cfg.MaxAge - if delta > 0 && time.Duration(delta) < fireIn { - if fireIn = time.Duration(delta); fireIn < 250*time.Millisecond { - // Only fire at most once every 250ms. - // Excessive firing can effect ingest performance. - fireIn = time.Second + if next < math.MaxInt64 { + // Looks like there's a next expiration, use it either if there's no + // MaxAge set or if it looks to be sooner than MaxAge is. + if until := time.Until(time.Unix(0, next)); fireIn == 0 || until < fireIn { + fireIn = until } } + + // If not then look at the delta provided (usually gap to next age expiry). + if delta > 0 { + if fireIn == 0 || time.Duration(delta) < fireIn { + fireIn = time.Duration(delta) + } + } + + // Make sure we aren't firing too often either way, otherwise we can + // negatively impact stream ingest performance. + if fireIn < 250*time.Millisecond { + fireIn = 250 * time.Millisecond + } + if ms.ageChk != nil { ms.ageChk.Reset(fireIn) } else { @@ -860,56 +962,147 @@ func (ms *memStore) resetAgeChk(delta int64) { } } +// Lock should be held. +func (ms *memStore) cancelAgeChk() { + if ms.ageChk != nil { + ms.ageChk.Stop() + ms.ageChk = nil + } +} + +// Lock must be held so that nothing else can interleave and write a +// new message on this subject before we get the chance to write the +// delete marker. If the delete marker is written successfully then +// this function returns a callback func to call scb and sdmcb after +// the lock has been released. 
+func (ms *memStore) subjectDeleteMarkerIfNeeded(subj string, reason string) func() { + if ms.cfg.SubjectDeleteMarkerTTL <= 0 { + return nil + } + if _, ok := ms.fss.Find(stringToBytes(subj)); ok { + // There are still messages left with this subject, + // therefore it wasn't the last message deleted. + return nil + } + // Build the subject delete marker. If no TTL is specified then + // we'll default to 15 minutes — by that time every possible condition + // should have cleared (i.e. ordered consumer timeout, client timeouts, + // route/gateway interruptions, even device/client restarts etc). + ttl := int64(ms.cfg.SubjectDeleteMarkerTTL.Seconds()) + if ttl <= 0 { + return nil + } + var _hdr [128]byte + hdr := fmt.Appendf( + _hdr[:0], + "NATS/1.0\r\n%s: %s\r\n%s: %s\r\n%s: %d\r\n%s: %s\r\n\r\n\r\n", + JSMarkerReason, reason, + JSMessageTTL, time.Duration(ttl)*time.Second, + JSExpectedLastSubjSeq, 0, + JSExpectedLastSubjSeqSubj, subj, + ) + msg := &inMsg{ + subj: subj, + hdr: hdr, + } + sdmcb := ms.sdmcb + return func() { + if sdmcb != nil { + sdmcb(msg) + } + } +} + +// Memstore lock must be held. The caller should call the callback, if non-nil, +// after releasing the memstore lock. +func (ms *memStore) subjectDeleteMarkersAfterOperation(reason string) func() { + if ms.cfg.SubjectDeleteMarkerTTL <= 0 || len(ms.markers) == 0 { + return nil + } + cbs := make([]func(), 0, len(ms.markers)) + for _, subject := range ms.markers { + if cb := ms.subjectDeleteMarkerIfNeeded(subject, reason); cb != nil { + cbs = append(cbs, cb) + } + } + ms.markers = nil + return func() { + for _, cb := range cbs { + cb() + } + } +} + // Will expire msgs that are too old. 
func (ms *memStore) expireMsgs() { + var smv StoreMsg + var sm *StoreMsg ms.mu.RLock() - now := time.Now().UnixNano() - minAge := now - int64(ms.cfg.MaxAge) + maxAge := int64(ms.cfg.MaxAge) + minAge := time.Now().UnixNano() - maxAge ms.mu.RUnlock() - for { - ms.mu.Lock() - if sm, ok := ms.msgs[ms.state.FirstSeq]; ok && sm.ts <= minAge { - ms.deleteFirstMsgOrPanic() - // Recalculate in case we are expiring a bunch. - now = time.Now().UnixNano() - minAge = now - int64(ms.cfg.MaxAge) - ms.mu.Unlock() - } else { - // We will exit here - if len(ms.msgs) == 0 { - if ms.ageChk != nil { - ms.ageChk.Stop() - ms.ageChk = nil - } - } else { - var fireIn time.Duration - if sm == nil { - fireIn = ms.cfg.MaxAge - } else { - fireIn = time.Duration(sm.ts - minAge) - } - if ms.ageChk != nil { - ms.ageChk.Reset(fireIn) - } else { - ms.ageChk = time.AfterFunc(fireIn, ms.expireMsgs) + if maxAge > 0 { + var seq uint64 + for sm, seq, _ = ms.LoadNextMsg(fwcs, true, 0, &smv); sm != nil && sm.ts <= minAge; sm, seq, _ = ms.LoadNextMsg(fwcs, true, seq+1, &smv) { + if len(sm.hdr) > 0 { + if ttl, err := getMessageTTL(sm.hdr); err == nil && ttl < 0 { + // The message has a negative TTL, therefore it must "never expire". + minAge = time.Now().UnixNano() - maxAge + continue } } + ms.mu.Lock() + ms.removeMsg(seq, false, JSMarkerReasonMaxAge) ms.mu.Unlock() - break + // Recalculate in case we are expiring a bunch. + minAge = time.Now().UnixNano() - maxAge + } + } + + ms.mu.Lock() + defer ms.mu.Unlock() + + // TODO: Not great that we're holding the lock here, but the timed hash wheel isn't thread-safe. + nextTTL := int64(math.MaxInt64) + if ms.ttls != nil { + ms.ttls.ExpireTasks(func(seq uint64, ts int64) { + ms.removeMsg(seq, false, _EMPTY_) + }) + if maxAge > 0 { + // Only check if we're expiring something in the next MaxAge interval, saves us a bit + // of work if MaxAge will beat us to the next expiry anyway. 
+ nextTTL = ms.ttls.GetNextExpiration(time.Now().Add(time.Duration(maxAge)).UnixNano()) + } else { + nextTTL = ms.ttls.GetNextExpiration(math.MaxInt64) + } + } + + // Only cancel if no message left, not on potential lookup error that would result in sm == nil. + if ms.state.Msgs == 0 && nextTTL == math.MaxInt64 { + ms.cancelAgeChk() + } else { + if sm == nil { + ms.resetAgeChk(0) + } else { + ms.resetAgeChk(sm.ts - minAge) } } } // PurgeEx will remove messages based on subject filters, sequence and number of messages to keep. // Will return the number of purged messages. -func (ms *memStore) PurgeEx(subject string, sequence, keep uint64) (purged uint64, err error) { +func (ms *memStore) PurgeEx(subject string, sequence, keep uint64, _ /* noMarkers */ bool) (purged uint64, err error) { + // TODO: Don't write markers on purge until we have solved performance + // issues with them. + noMarkers := true + if subject == _EMPTY_ || subject == fwcs { if keep == 0 && sequence == 0 { - return ms.Purge() + return ms.purge(0, noMarkers) } if sequence > 1 { - return ms.Compact(sequence) + return ms.compact(sequence, noMarkers) } else if keep > 0 { ms.mu.RLock() msgs, lseq := ms.state.Msgs, ms.state.LastSeq @@ -917,7 +1110,7 @@ func (ms *memStore) PurgeEx(subject string, sequence, keep uint64) (purged uint6 if keep >= msgs { return 0, nil } - return ms.Compact(lseq - keep + 1) + return ms.compact(lseq-keep+1, noMarkers) } return 0, nil @@ -935,9 +1128,13 @@ func (ms *memStore) PurgeEx(subject string, sequence, keep uint64) (purged uint6 last = sequence - 1 } ms.mu.Lock() + var removeReason string + if !noMarkers { + removeReason = JSMarkerReasonPurge + } for seq := ss.First; seq <= last; seq++ { if sm, ok := ms.msgs[seq]; ok && eq(sm.subj, subject) { - if ok := ms.removeMsg(sm.seq, false); ok { + if ok := ms.removeMsg(sm.seq, false, removeReason); ok { purged++ if purged >= ss.Msgs { break @@ -956,10 +1153,14 @@ func (ms *memStore) Purge() (uint64, error) { ms.mu.RLock() first 
:= ms.state.LastSeq + 1 ms.mu.RUnlock() - return ms.purge(first) + return ms.purge(first, false) } -func (ms *memStore) purge(fseq uint64) (uint64, error) { +func (ms *memStore) purge(fseq uint64, _ /* noMarkers */ bool) (uint64, error) { + // TODO: Don't write markers on purge until we have solved performance + // issues with them. + noMarkers := true + ms.mu.Lock() purged := uint64(len(ms.msgs)) cb := ms.scb @@ -974,12 +1175,23 @@ func (ms *memStore) purge(fseq uint64) (uint64, error) { ms.state.Bytes = 0 ms.state.Msgs = 0 ms.msgs = make(map[uint64]*StoreMsg) + // Subject delete markers if needed. + if !noMarkers && ms.cfg.SubjectDeleteMarkerTTL > 0 { + ms.fss.IterOrdered(func(bsubj []byte, ss *SimpleState) bool { + ms.markers = append(ms.markers, string(bsubj)) + return true + }) + } ms.fss = stree.NewSubjectTree[SimpleState]() + sdmcb := ms.subjectDeleteMarkersAfterOperation(JSMarkerReasonPurge) ms.mu.Unlock() if cb != nil { cb(-int64(purged), -bytes, 0, _EMPTY_) } + if sdmcb != nil { + sdmcb() + } return purged, nil } @@ -988,6 +1200,14 @@ func (ms *memStore) purge(fseq uint64) (uint64, error) { // but not including the seq parameter. // Will return the number of purged messages. func (ms *memStore) Compact(seq uint64) (uint64, error) { + return ms.compact(seq, false) +} + +func (ms *memStore) compact(seq uint64, _ /* noMarkers */ bool) (uint64, error) { + // TODO: Don't write markers on compact until we have solved performance + // issues with them. + noMarkers := true + if seq == 0 { return ms.Purge() } @@ -1010,7 +1230,7 @@ func (ms *memStore) Compact(seq uint64) (uint64, error) { if sm := ms.msgs[seq]; sm != nil { bytes += memStoreMsgSize(sm.subj, sm.hdr, sm.msg) purged++ - ms.removeSeqPerSubject(sm.subj, seq) + ms.removeSeqPerSubject(sm.subj, seq, !noMarkers && ms.cfg.SubjectDeleteMarkerTTL > 0) // Must delete message after updating per-subject info, to be consistent with file store. 
delete(ms.msgs, seq) } else if !ms.dmap.IsEmpty() { @@ -1035,16 +1255,28 @@ func (ms *memStore) Compact(seq uint64) (uint64, error) { ms.state.FirstSeq = seq ms.state.FirstTime = time.Time{} ms.state.LastSeq = seq - 1 + // Subject delete markers if needed. + if !noMarkers && ms.cfg.SubjectDeleteMarkerTTL > 0 { + ms.fss.IterOrdered(func(bsubj []byte, ss *SimpleState) bool { + ms.markers = append(ms.markers, string(bsubj)) + return true + }) + } // Reset msgs, fss and dmap. ms.msgs = make(map[uint64]*StoreMsg) ms.fss = stree.NewSubjectTree[SimpleState]() ms.dmap.Empty() } + // Subject delete markers if needed. + sdmcb := ms.subjectDeleteMarkersAfterOperation(JSMarkerReasonPurge) ms.mu.Unlock() if cb != nil { cb(-int64(purged), -int64(bytes), 0, _EMPTY_) } + if sdmcb != nil { + sdmcb() + } return purged, nil } @@ -1104,7 +1336,7 @@ func (ms *memStore) Truncate(seq uint64) error { if sm := ms.msgs[i]; sm != nil { purged++ bytes += memStoreMsgSize(sm.subj, sm.hdr, sm.msg) - ms.removeSeqPerSubject(sm.subj, i) + ms.removeSeqPerSubject(sm.subj, i, false) // Must delete message after updating per-subject info, to be consistent with file store. delete(ms.msgs, i) } else if !ms.dmap.IsEmpty() { @@ -1141,7 +1373,8 @@ func (ms *memStore) deleteFirstMsgOrPanic() { } func (ms *memStore) deleteFirstMsg() bool { - return ms.removeMsg(ms.state.FirstSeq, false) + // TODO: Currently no markers for these types of limits (max msgs or max bytes) + return ms.removeMsg(ms.state.FirstSeq, false, _EMPTY_) } // LoadMsg will lookup the message by sequence number and return it if found. @@ -1337,7 +1570,8 @@ func (ms *memStore) LoadPrevMsg(start uint64, smp *StoreMsg) (sm *StoreMsg, err // Will return the number of bytes removed. func (ms *memStore) RemoveMsg(seq uint64) (bool, error) { ms.mu.Lock() - removed := ms.removeMsg(seq, false) + // TODO: Don't write markers on removes via the API yet, only via limits. 
+ removed := ms.removeMsg(seq, false, _EMPTY_) ms.mu.Unlock() return removed, nil } @@ -1345,7 +1579,8 @@ func (ms *memStore) RemoveMsg(seq uint64) (bool, error) { // EraseMsg will remove the message and rewrite its contents. func (ms *memStore) EraseMsg(seq uint64) (bool, error) { ms.mu.Lock() - removed := ms.removeMsg(seq, true) + // TODO: Don't write markers on removes via the API yet, only via limits. + removed := ms.removeMsg(seq, true, _EMPTY_) ms.mu.Unlock() return removed, nil } @@ -1385,20 +1620,39 @@ func (ms *memStore) updateFirstSeq(seq uint64) { // Remove a seq from the fss and select new first. // Lock should be held. -func (ms *memStore) removeSeqPerSubject(subj string, seq uint64) { +func (ms *memStore) removeSeqPerSubject(subj string, seq uint64, marker bool) bool { ss, ok := ms.fss.Find(stringToBytes(subj)) if !ok { - return + return false } if ss.Msgs == 1 { ms.fss.Delete(stringToBytes(subj)) - return + if marker { + ms.markers = append(ms.markers, subj) + } + return true } ss.Msgs-- + // Only one left + if ss.Msgs == 1 { + if !ss.lastNeedsUpdate && seq != ss.Last { + ss.First = ss.Last + ss.firstNeedsUpdate = false + return false + } + if !ss.firstNeedsUpdate && seq != ss.First { + ss.Last = ss.First + ss.lastNeedsUpdate = false + return false + } + } + // We can lazily calculate the first/last sequence when needed. ss.firstNeedsUpdate = seq == ss.First || ss.firstNeedsUpdate ss.lastNeedsUpdate = seq == ss.Last || ss.lastNeedsUpdate + + return false } // Will recalculate the first and/or last sequence for this subject. @@ -1443,7 +1697,7 @@ func (ms *memStore) recalculateForSubj(subj string, ss *SimpleState) { // Removes the message referenced by seq. // Lock should be held. 
-func (ms *memStore) removeMsg(seq uint64, secure bool) bool { +func (ms *memStore) removeMsg(seq uint64, secure bool, marker string) bool { var ss uint64 sm, ok := ms.msgs[seq] if !ok { @@ -1475,15 +1729,29 @@ func (ms *memStore) removeMsg(seq uint64, secure bool) bool { } // Remove any per subject tracking. - ms.removeSeqPerSubject(sm.subj, seq) + needMarker := marker != _EMPTY_ && ms.cfg.SubjectDeleteMarkerTTL > 0 && len(getHeader(JSMarkerReason, sm.hdr)) == 0 + wasLast := ms.removeSeqPerSubject(sm.subj, seq, needMarker) + // Must delete message after updating per-subject info, to be consistent with file store. delete(ms.msgs, seq) - if ms.scb != nil { + // If the deleted message was itself a delete marker then + // don't write out more of them or we'll churn endlessly. + var sdmcb func() + if needMarker && wasLast { + sdmcb = ms.subjectDeleteMarkersAfterOperation(marker) + } + + if ms.scb != nil || sdmcb != nil { // We do not want to hold any locks here. ms.mu.Unlock() - delta := int64(ss) - ms.scb(-1, -delta, seq, sm.subj) + if ms.scb != nil { + delta := int64(ss) + ms.scb(-1, -delta, seq, sm.subj) + } + if sdmcb != nil { + sdmcb() + } ms.mu.Lock() } @@ -1677,7 +1945,7 @@ func (ms *memStore) SyncDeleted(dbs DeleteBlocks) { continue } db.Range(func(seq uint64) bool { - ms.removeMsg(seq, false) + ms.removeMsg(seq, false, _EMPTY_) return true }) } @@ -1736,9 +2004,26 @@ func (o *consumerMemStore) SetStarting(sseq uint64) error { return nil } +// UpdateStarting updates our starting stream sequence. +func (o *consumerMemStore) UpdateStarting(sseq uint64) { + o.mu.Lock() + defer o.mu.Unlock() + + if sseq > o.state.Delivered.Stream { + o.state.Delivered.Stream = sseq + // For AckNone just update delivered and ackfloor at the same time. + if o.cfg.AckPolicy == AckNone { + o.state.AckFloor.Stream = sseq + } + } +} + // HasState returns if this store has a recorded state. 
func (o *consumerMemStore) HasState() bool { - return false + o.mu.Lock() + defer o.mu.Unlock() + // We have a running state. + return o.state.Delivered.Consumer != 0 || o.state.Delivered.Stream != 0 } func (o *consumerMemStore) UpdateDelivered(dseq, sseq, dc uint64, ts int64) error { @@ -1749,6 +2034,7 @@ func (o *consumerMemStore) UpdateDelivered(dseq, sseq, dc uint64, ts int64) erro return ErrNoAckPolicy } + // On restarts the old leader may get a replay from the raft logs that are old. if dseq <= o.state.AckFloor.Consumer { return nil } @@ -1819,12 +2105,6 @@ func (o *consumerMemStore) UpdateAcks(dseq, sseq uint64) error { return nil } - // Match leader logic on checking if ack is ahead of delivered. - // This could happen on a cooperative takeover with high speed deliveries. - if sseq > o.state.Delivered.Stream { - o.state.Delivered.Stream = sseq + 1 - } - if len(o.state.Pending) == 0 || o.state.Pending[sseq] == nil { delete(o.state.Redelivered, sseq) return ErrStoreMsgNotFound diff --git a/vendor/github.com/nats-io/nats-server/v2/server/monitor.go b/vendor/github.com/nats-io/nats-server/v2/server/monitor.go index 7de9703b03..41270c73ff 100644 --- a/vendor/github.com/nats-io/nats-server/v2/server/monitor.go +++ b/vendor/github.com/nats-io/nats-server/v2/server/monitor.go @@ -415,14 +415,13 @@ func (s *Server) Connz(opts *ConnzOptions) (*Connz, error) { // Fill in user if auth requested. if auth { ci.AuthorizedUser = client.getRawAuthUser() - // Add in account iff not the global account. 
- if client.acc != nil && (client.acc.Name != globalAccountName) { - ci.Account = client.acc.Name + if name := client.acc.GetName(); name != globalAccountName { + ci.Account = name } ci.JWT = client.opts.JWT ci.IssuerKey = issuerForClient(client) ci.Tags = client.tags - ci.NameTag = client.nameTag + ci.NameTag = client.acc.getNameTag() } client.mu.Unlock() pconns[i] = ci @@ -465,9 +464,11 @@ func (s *Server) Connz(opts *ConnzOptions) (*Connz, error) { // Fill in user if auth requested. if auth { cc.AuthorizedUser = cc.user - // Add in account iff not the global account. if cc.acc != _EMPTY_ && (cc.acc != globalAccountName) { cc.Account = cc.acc + if acc, err := s.LookupAccount(cc.acc); err == nil { + cc.NameTag = acc.getNameTag() + } } } pconns[i] = &cc.ConnInfo @@ -926,21 +927,21 @@ type SubszOptions struct { // SubDetail is for verbose information for subscriptions. type SubDetail struct { - Account string `json:"account,omitempty"` - Subject string `json:"subject"` - Queue string `json:"qgroup,omitempty"` - Sid string `json:"sid"` - Msgs int64 `json:"msgs"` - Max int64 `json:"max,omitempty"` - Cid uint64 `json:"cid"` + Account string `json:"account,omitempty"` + AccountTag string `json:"account_tag,omitempty"` + Subject string `json:"subject"` + Queue string `json:"qgroup,omitempty"` + Sid string `json:"sid"` + Msgs int64 `json:"msgs"` + Max int64 `json:"max,omitempty"` + Cid uint64 `json:"cid"` } // Subscription client should be locked and guaranteed to be present. 
func newSubDetail(sub *subscription) SubDetail { sd := newClientSubDetail(sub) - if sub.client.acc != nil { - sd.Account = sub.client.acc.Name - } + sd.Account = sub.client.acc.GetName() + sd.AccountTag = sub.client.acc.getNameTag() return sd } @@ -1228,6 +1229,7 @@ type Varz struct { Subscriptions uint32 `json:"subscriptions"` HTTPReqStats map[string]uint64 `json:"http_req_stats"` ConfigLoadTime time.Time `json:"config_load_time"` + ConfigDigest string `json:"config_digest"` Tags jwt.TagList `json:"tags,omitempty"` TrustedOperatorsJwt []string `json:"trusted_operators_jwt,omitempty"` TrustedOperatorsClaim []*jwt.OperatorClaims `json:"trusted_operators_claim,omitempty"` @@ -1242,6 +1244,7 @@ type JetStreamVarz struct { Config *JetStreamConfig `json:"config,omitempty"` Stats *JetStreamStats `json:"stats,omitempty"` Meta *MetaClusterInfo `json:"meta,omitempty"` + Limits *JSLimitOpts `json:"limits,omitempty"` } // ClusterOptsVarz contains monitoring cluster information @@ -1467,6 +1470,7 @@ func (s *Server) updateJszVarz(js *jetStream, v *JetStreamVarz, doConfig bool) { js.mu.RUnlock() } v.Stats = js.usageStats() + v.Limits = &s.getOpts().JetStreamLimits if mg := js.getMetaGroup(); mg != nil { if ci := s.raftNodeToClusterInfo(mg); ci != nil { v.Meta = &MetaClusterInfo{Name: ci.Name, Leader: ci.Leader, Peer: getHash(ci.Leader), Size: mg.ClusterSize()} @@ -1601,6 +1605,11 @@ func (s *Server) createVarz(pcpu float64, rss int64) *Varz { TrustedOperatorsJwt: opts.operatorJWT, TrustedOperatorsClaim: opts.TrustedOperators, } + // If this is a leaf without cluster, reset the cluster name (that is otherwise + // set to the server name). 
+ if s.leafNoCluster { + varz.Cluster.Name = _EMPTY_ + } if len(opts.Routes) > 0 { varz.Cluster.URLs = urlsToStrings(opts.Routes) } @@ -1675,6 +1684,7 @@ func (s *Server) updateVarzConfigReloadableFields(v *Varz) { v.TLSTimeout = opts.TLSTimeout v.WriteDeadline = opts.WriteDeadline v.ConfigLoadTime = s.configTime.UTC() + v.ConfigDigest = opts.configDigest // Update route URLs if applicable if s.varzUpdateRouteURLs { v.Cluster.URLs = urlsToStrings(opts.Routes) @@ -1844,6 +1854,7 @@ func (s *Server) HandleVarz(w http.ResponseWriter, r *http.Request) { } sv.Stats = v.Stats sv.Meta = v.Meta + sv.Limits = v.Limits s.mu.RUnlock() } @@ -2747,27 +2758,27 @@ func (s *Server) accountInfo(accName string) (*AccountInfo, error) { mappings[src] = dests } return &AccountInfo{ - accName, - a.updated.UTC(), - isSys, - a.expired.Load(), - !a.incomplete, - a.js != nil, - a.numLocalLeafNodes(), - a.numLocalConnections(), - a.sl.Count(), - mappings, - exports, - imports, - a.claimJWT, - a.Issuer, - a.nameTag, - a.tags, - claim, - vrIssues, - collectRevocations(a.usersRevoked), - a.sl.Stats(), - responses, + AccountName: accName, + LastUpdate: a.updated.UTC(), + IsSystem: isSys, + Expired: a.expired.Load(), + Complete: !a.incomplete, + JetStream: a.js != nil, + LeafCnt: a.numLocalLeafNodes(), + ClientCnt: a.numLocalConnections(), + SubCnt: a.sl.Count(), + Mappings: mappings, + Exports: exports, + Imports: imports, + Jwt: a.claimJWT, + IssuerKey: a.Issuer, + NameTag: a.getNameTagLocked(), + Tags: a.tags, + Claim: claim, + Vr: vrIssues, + RevokedUser: collectRevocations(a.usersRevoked), + Sublist: a.sl.Stats(), + Responses: responses, }, nil } @@ -2791,6 +2802,7 @@ type HealthzOptions struct { JSEnabled bool `json:"js-enabled,omitempty"` JSEnabledOnly bool `json:"js-enabled-only,omitempty"` JSServerOnly bool `json:"js-server-only,omitempty"` + JSMetaOnly bool `json:"js-meta-only,omitempty"` Account string `json:"account,omitempty"` Stream string `json:"stream,omitempty"` Consumer string 
`json:"consumer,omitempty"` @@ -2859,6 +2871,7 @@ type JSInfo struct { Now time.Time `json:"now"` Disabled bool `json:"disabled,omitempty"` Config JetStreamConfig `json:"config,omitempty"` + Limits *JSLimitOpts `json:"limits,omitempty"` JetStreamStats Streams int `json:"streams"` Consumers int `json:"consumers"` @@ -3027,6 +3040,8 @@ func (s *Server) Jsz(opts *JSzOptions) (*JSInfo, error) { return jsi, nil } + jsi.Limits = &s.getOpts().JetStreamLimits + js.mu.RLock() isLeader := js.cluster == nil || js.cluster.isLeader() js.mu.RUnlock() @@ -3268,6 +3283,10 @@ func (s *Server) HandleHealthz(w http.ResponseWriter, r *http.Request) { if err != nil { return } + jsMetaOnly, err := decodeBool(w, r, "js-meta-only") + if err != nil { + return + } includeDetails, err := decodeBool(w, r, "details") if err != nil { @@ -3278,6 +3297,7 @@ func (s *Server) HandleHealthz(w http.ResponseWriter, r *http.Request) { JSEnabled: jsEnabled, JSEnabledOnly: jsEnabledOnly, JSServerOnly: jsServerOnly, + JSMetaOnly: jsMetaOnly, Account: r.URL.Query().Get("account"), Stream: r.URL.Query().Get("stream"), Consumer: r.URL.Query().Get("consumer"), @@ -3371,6 +3391,11 @@ func (s *Server) healthz(opts *HealthzOptions) *HealthStatus { return health } + // If JSServerOnly is true, then do not check further accounts, streams and consumers. + if opts.JSServerOnly { + return health + } + sopts := s.getOpts() // If JS is not enabled in the config, we stop. @@ -3549,6 +3574,7 @@ func (s *Server) healthz(opts *HealthzOptions) *HealthStatus { } return health } + // If we are not current with the meta leader. if !meta.Healthy() { if !details { @@ -3565,11 +3591,6 @@ func (s *Server) healthz(opts *HealthzOptions) *HealthStatus { return health } - // If JSServerOnly is true, then do not check further accounts, streams and consumers. - if opts.JSServerOnly { - return health - } - // Are we still recovering meta layer? 
if js.isMetaRecovering() { if !details { @@ -3587,6 +3608,11 @@ func (s *Server) healthz(opts *HealthzOptions) *HealthStatus { return health } + // Skips doing full healthz and only checks the meta leader. + if opts.JSMetaOnly { + return health + } + // Range across all accounts, the streams assigned to them, and the consumers. // If they are assigned to this server check their status. ourID := meta.ID() diff --git a/vendor/github.com/nats-io/nats-server/v2/server/mqtt.go b/vendor/github.com/nats-io/nats-server/v2/server/mqtt.go index a511c4f514..e76b7168f0 100644 --- a/vendor/github.com/nats-io/nats-server/v2/server/mqtt.go +++ b/vendor/github.com/nats-io/nats-server/v2/server/mqtt.go @@ -191,6 +191,18 @@ const ( mqttRetainedTransferTimeout = 10 * time.Second ) +const ( + sparkbNBIRTH = "NBIRTH" + sparkbDBIRTH = "DBIRTH" + sparkbNDEATH = "NDEATH" + sparkbDDEATH = "DDEATH" +) + +var ( + sparkbNamespaceTopicPrefix = []byte("spBv1.0/") + sparkbCertificatesTopicPrefix = []byte("$sparkplug/certificates/") +) + var ( mqttPingResponse = []byte{mqttPacketPingResp, 0x0} mqttProtoName = []byte("MQTT") @@ -456,8 +468,22 @@ type mqttPublish struct { // When we submit a PUBREL for delivery, we add a "Nmqtt-PubRel" header that // contains the PI. const ( - mqttNatsHeader = "Nmqtt-Pub" - mqttNatsPubRelHeader = "Nmqtt-PubRel" + // NATS header that indicates that the message originated from MQTT and + // stores the published message QOS. + mqttNatsHeader = "Nmqtt-Pub" + + // NATS headers to store retained message metadata (along with the original + // message as binary). + mqttNatsRetainedMessageTopic = "Nmqtt-RTopic" + mqttNatsRetainedMessageOrigin = "Nmqtt-ROrigin" + mqttNatsRetainedMessageFlags = "Nmqtt-RFlags" + mqttNatsRetainedMessageSource = "Nmqtt-RSource" + + // NATS header that indicates that the message is an MQTT PubRel and stores + // the PI. + mqttNatsPubRelHeader = "Nmqtt-PubRel" + + // NATS headers to store the original MQTT subject and the subject mapping. 
mqttNatsHeaderSubject = "Nmqtt-Subject" mqttNatsHeaderMapped = "Nmqtt-Mapped" ) @@ -1636,13 +1662,16 @@ func (jsa *mqttJSA) newRequestExMulti(kind, subject, cidHash string, hdrs []int, } func (jsa *mqttJSA) sendAck(ackSubject string) { - if ackSubject == _EMPTY_ { - return - } - // We pass -1 for the hdr so that the send loop does not need to // add the "client info" header. This is not a JS API request per se. - jsa.sendq.push(&mqttJSPubMsg{subj: ackSubject, hdr: -1}) + jsa.sendMsg(ackSubject, nil) +} + +func (jsa *mqttJSA) sendMsg(subj string, msg []byte) { + if subj == _EMPTY_ { + return + } + jsa.sendq.push(&mqttJSPubMsg{subj: subj, msg: msg, hdr: -1}) } func (jsa *mqttJSA) createEphemeralConsumer(cfg *CreateConsumerRequest) (*JSApiConsumerCreateResponse, error) { @@ -1673,13 +1702,6 @@ func (jsa *mqttJSA) createDurableConsumer(cfg *CreateConsumerRequest) (*JSApiCon return ccr, ccr.ToError() } -func (jsa *mqttJSA) sendMsg(subj string, msg []byte) { - if subj == _EMPTY_ { - return - } - jsa.sendq.push(&mqttJSPubMsg{subj: subj, msg: msg, hdr: -1}) -} - // if noWait is specified, does not wait for the JS response, returns nil func (jsa *mqttJSA) deleteConsumer(streamName, consName string, noWait bool) (*JSApiConsumerDeleteResponse, error) { subj := fmt.Sprintf(JSApiConsumerDeleteT, streamName, consName) @@ -1890,61 +1912,61 @@ func (as *mqttAccountSessionManager) processJSAPIReplies(_ *subscription, pc *cl case mqttJSAStreamCreate: var resp = &JSApiStreamCreateResponse{} if err := json.Unmarshal(msg, resp); err != nil { - resp.Error = NewJSInvalidJSONError() + resp.Error = NewJSInvalidJSONError(err) } out(resp) case mqttJSAStreamUpdate: var resp = &JSApiStreamUpdateResponse{} if err := json.Unmarshal(msg, resp); err != nil { - resp.Error = NewJSInvalidJSONError() + resp.Error = NewJSInvalidJSONError(err) } out(resp) case mqttJSAStreamLookup: var resp = &JSApiStreamInfoResponse{} if err := json.Unmarshal(msg, &resp); err != nil { - resp.Error = 
NewJSInvalidJSONError() + resp.Error = NewJSInvalidJSONError(err) } out(resp) case mqttJSAStreamDel: var resp = &JSApiStreamDeleteResponse{} if err := json.Unmarshal(msg, &resp); err != nil { - resp.Error = NewJSInvalidJSONError() + resp.Error = NewJSInvalidJSONError(err) } out(resp) case mqttJSAConsumerCreate: var resp = &JSApiConsumerCreateResponse{} if err := json.Unmarshal(msg, resp); err != nil { - resp.Error = NewJSInvalidJSONError() + resp.Error = NewJSInvalidJSONError(err) } out(resp) case mqttJSAConsumerDel: var resp = &JSApiConsumerDeleteResponse{} if err := json.Unmarshal(msg, resp); err != nil { - resp.Error = NewJSInvalidJSONError() + resp.Error = NewJSInvalidJSONError(err) } out(resp) case mqttJSAMsgStore, mqttJSASessPersist: var resp = &JSPubAckResponse{} if err := json.Unmarshal(msg, resp); err != nil { - resp.Error = NewJSInvalidJSONError() + resp.Error = NewJSInvalidJSONError(err) } out(resp) case mqttJSAMsgLoad: var resp = &JSApiMsgGetResponse{} if err := json.Unmarshal(msg, &resp); err != nil { - resp.Error = NewJSInvalidJSONError() + resp.Error = NewJSInvalidJSONError(err) } out(resp) case mqttJSAStreamNames: var resp = &JSApiStreamNamesResponse{} if err := json.Unmarshal(msg, resp); err != nil { - resp.Error = NewJSInvalidJSONError() + resp.Error = NewJSInvalidJSONError(err) } out(resp) case mqttJSAMsgDelete: var resp = &JSApiMsgDeleteResponse{} if err := json.Unmarshal(msg, resp); err != nil { - resp.Error = NewJSInvalidJSONError() + resp.Error = NewJSInvalidJSONError(err) } out(resp) default: @@ -1957,9 +1979,9 @@ func (as *mqttAccountSessionManager) processJSAPIReplies(_ *subscription, pc *cl // Run from various go routines (JS consumer, etc..). // No lock held on entry. 
func (as *mqttAccountSessionManager) processRetainedMsg(_ *subscription, c *client, _ *Account, subject, reply string, rmsg []byte) { - _, msg := c.msgParts(rmsg) - rm := &mqttRetainedMsg{} - if err := json.Unmarshal(msg, rm); err != nil { + h, m := c.msgParts(rmsg) + rm, err := mqttDecodeRetainedMessage(h, m) + if err != nil { return } // If lastSeq is 0 (nothing to recover, or done doing it) and this is @@ -2791,20 +2813,95 @@ func (as *mqttAccountSessionManager) loadRetainedMessages(subjects map[string]st w.Warnf("failed to load retained message for subject %q: %v", ss[i], err) continue } - var rm mqttRetainedMsg - if err := json.Unmarshal(result.Message.Data, &rm); err != nil { + rm, err := mqttDecodeRetainedMessage(result.Message.Header, result.Message.Data) + if err != nil { w.Warnf("failed to decode retained message for subject %q: %v", ss[i], err) continue } // Add the loaded retained message to the cache, and to the results map. key := ss[i][len(mqttRetainedMsgsStreamSubject):] - as.setCachedRetainedMsg(key, &rm, false, false) - rms[key] = &rm + as.setCachedRetainedMsg(key, rm, false, false) + rms[key] = rm } return rms } +// Composes a NATS message for a storeable mqttRetainedMsg. +func mqttEncodeRetainedMessage(rm *mqttRetainedMsg) (natsMsg []byte, headerLen int) { + // No need to encode the subject, we can restore it from topic. 
+ l := len(hdrLine) + l += len(mqttNatsRetainedMessageTopic) + 1 + len(rm.Topic) + 2 // 1 byte for ':', 2 bytes for CRLF + if rm.Origin != _EMPTY_ { + l += len(mqttNatsRetainedMessageOrigin) + 1 + len(rm.Origin) + 2 // 1 byte for ':', 2 bytes for CRLF + } + if rm.Source != _EMPTY_ { + l += len(mqttNatsRetainedMessageSource) + 1 + len(rm.Source) + 2 // 1 byte for ':', 2 bytes for CRLF + } + l += len(mqttNatsRetainedMessageFlags) + 1 + 2 + 2 // 1 byte for ':', 2 bytes for the flags, 2 bytes for CRLF + l += 2 // 2 bytes for the extra CRLF after the header + l += len(rm.Msg) + + buf := bytes.NewBuffer(make([]byte, 0, l)) + + buf.WriteString(hdrLine) + + buf.WriteString(mqttNatsRetainedMessageTopic) + buf.WriteByte(':') + buf.WriteString(rm.Topic) + buf.WriteString(_CRLF_) + + buf.WriteString(mqttNatsRetainedMessageFlags) + buf.WriteByte(':') + buf.WriteString(strconv.FormatUint(uint64(rm.Flags), 16)) + buf.WriteString(_CRLF_) + + if rm.Origin != _EMPTY_ { + buf.WriteString(mqttNatsRetainedMessageOrigin) + buf.WriteByte(':') + buf.WriteString(rm.Origin) + buf.WriteString(_CRLF_) + } + if rm.Source != _EMPTY_ { + buf.WriteString(mqttNatsRetainedMessageSource) + buf.WriteByte(':') + buf.WriteString(rm.Source) + buf.WriteString(_CRLF_) + } + + // End of header, finalize + buf.WriteString(_CRLF_) + headerLen = buf.Len() + buf.Write(rm.Msg) + return buf.Bytes(), headerLen +} + +func mqttDecodeRetainedMessage(h, m []byte) (*mqttRetainedMsg, error) { + fHeader := getHeader(mqttNatsRetainedMessageFlags, h) + if len(fHeader) > 0 { + flags, err := strconv.ParseUint(string(fHeader), 16, 8) + if err != nil { + return nil, fmt.Errorf("invalid retained message flags: %v", err) + } + topic := getHeader(mqttNatsRetainedMessageTopic, h) + subj, _ := mqttToNATSSubjectConversion(topic, false) + return &mqttRetainedMsg{ + Flags: byte(flags), + Subject: string(subj), + Topic: string(topic), + Origin: string(getHeader(mqttNatsRetainedMessageOrigin, h)), + Source: 
string(getHeader(mqttNatsRetainedMessageSource, h)), + Msg: m, + }, nil + } else { + var rm mqttRetainedMsg + if err := json.Unmarshal(m, &rm); err != nil { + return nil, err + } + return &rm, nil + } +} + // Creates the session stream (limit msgs of 1) for this client ID if it does // not already exist. If it exists, recover the single record to rebuild the // state of the session. If there is a session record but this session is not @@ -2951,7 +3048,9 @@ func (as *mqttAccountSessionManager) transferRetainedToPerKeySubjectStream(log * return err } - // Unmarshal the message so that we can obtain the subject name. + // Unmarshal the message so that we can obtain the subject name. Do not + // use mqttDecodeRetainedMessage() here because these messages are from + // older versions, and contain the full JSON encoding in payload. var rmsg mqttRetainedMsg if err = json.Unmarshal(smsg.Data, &rmsg); err == nil { // Store the message again, this time with the new per-key subject. @@ -3331,7 +3430,7 @@ func (sess *mqttSession) untrackPublish(pi uint16) (jsAckSubject string) { return ack.jsAckSubject } -// trackPubRel is invoked in 2 cases: (a) when we receive a PUBREC and we need +// trackAsPubRel is invoked in 2 cases: (a) when we receive a PUBREC and we need // to change from tracking the PI as a PUBLISH to a PUBREL; and (b) when we // attempt to deliver the PUBREL to record the JS ack subject for it. // @@ -3547,7 +3646,7 @@ func (c *client) mqttParseConnect(r *mqttReader, hasMappings bool) (byte, *mqttC cp.will.mapped = c.pa.mapped // We also now need to map the original MQTT topic to the new topic // based on the new subject. - topic = natsSubjectToMQTTTopic(string(cp.will.subject)) + topic = natsSubjectToMQTTTopic(cp.will.subject) } // Reset those now. 
c.pa.subject, c.pa.mapped = nil, nil @@ -3644,7 +3743,7 @@ func (s *Server) mqttProcessConnect(c *client, cp *mqttConnectProto, trace bool) c.authViolation() return ErrAuthentication } - // Now that we are are authenticated, we have the client bound to the account. + // Now that we are authenticated, we have the client bound to the account. // Get the account's level MQTT sessions manager. If it does not exists yet, // this will create it along with the streams where sessions and messages // are stored. @@ -3896,7 +3995,7 @@ func (c *client) mqttParsePub(r *mqttReader, pl int, pp *mqttPublish, hasMapping pp.mapped = c.pa.mapped // We also now need to map the original MQTT topic to the new topic // based on the new subject. - pp.topic = natsSubjectToMQTTTopic(string(pp.subject)) + pp.topic = natsSubjectToMQTTTopic(pp.subject) } // Reset those now. c.pa.subject, c.pa.mapped = nil, nil @@ -3952,7 +4051,7 @@ func mqttNewDeliverableMessage(pp *mqttPublish, encodePP bool) (natsMsg []byte, size := len(hdrLine) + len(mqttNatsHeader) + 2 + 2 + // 2 for ':', and 2 for CRLF 2 + // end-of-header CRLF - len(pp.msg) + pp.sz if encodePP { size += len(mqttNatsHeaderSubject) + 1 + // +1 for ':' len(pp.subject) + 2 // 2 for CRLF @@ -4160,7 +4259,7 @@ func (s *Server) mqttProcessPubRel(c *client, pi uint16, trace bool) error { } pp := &mqttPublish{ - topic: natsSubjectToMQTTTopic(string(h.subject)), + topic: natsSubjectToMQTTTopic(h.subject), subject: h.subject, mapped: h.mapped, msg: stored.Data, @@ -4179,50 +4278,106 @@ func (s *Server) mqttProcessPubRel(c *client, pi uint16, trace bool) error { // Invoked from the MQTT publisher's readLoop. No client lock is held on entry. 
func (c *client) mqttHandlePubRetain() { pp := c.mqtt.pp - if !mqttIsRetained(pp.flags) { + retainMQTT := mqttIsRetained(pp.flags) + isBirth, _, isCertificate := sparkbParseBirthDeathTopic(pp.topic) + retainSparkbBirth := isBirth && !isCertificate + + // [tck-id-topics-nbirth-mqtt] NBIRTH messages MUST be published with MQTT + // QoS equal to 0 and retain equal to false. + // + // [tck-id-conformance-mqtt-aware-nbirth-mqtt-retain] A Sparkplug Aware MQTT + // Server MUST make NBIRTH messages available on the topic: + // $sparkplug/certificates/namespace/group_id/NBIRTH/edge_node_id with the + // MQTT retain flag set to true. + if retainMQTT == retainSparkbBirth { + // (retainSparkbBirth && retainMQTT) : not valid, so ignore altogether. + // (!retainSparkbBirth && !retainMQTT) : nothing to do. return } - key := string(pp.subject) + asm := c.mqtt.asm - // Spec [MQTT-3.3.1-11]. Payload of size 0 removes the retained message, - // but should still be delivered as a normal message. + key := string(pp.subject) + + // Always clear the retain flag to deliver a normal published message. + defer func() { + pp.flags &= ^mqttPubFlagRetain + }() + + // Spec [MQTT-3.3.1-11]. Payload of size 0 removes the retained message, but + // should still be delivered as a normal message. if pp.sz == 0 { if seqToRemove := asm.handleRetainedMsgDel(key, 0); seqToRemove > 0 { asm.deleteRetainedMsg(seqToRemove) asm.notifyRetainedMsgDeleted(key, seqToRemove) } - } else { - // Spec [MQTT-3.3.1-5]. Store the retained message with its QoS. - // When coming from a publish protocol, `pp` is referencing a stack - // variable that itself possibly references the read buffer. 
- rm := &mqttRetainedMsg{ - Origin: asm.jsa.id, - Subject: key, - Topic: string(pp.topic), - Msg: pp.msg, - Flags: pp.flags, - Source: c.opts.Username, - } - rmBytes, _ := json.Marshal(rm) - smr, err := asm.jsa.storeMsg(mqttRetainedMsgsStreamSubject+key, -1, rmBytes) - if err == nil { - // Update the new sequence - rf := &mqttRetainedMsgRef{ - sseq: smr.Sequence, - } - // Add/update the map - asm.handleRetainedMsg(key, rf, rm, true) // will copy the payload bytes if needs to update rmsCache - } else { - c.mu.Lock() - acc := c.acc - c.mu.Unlock() - c.Errorf("unable to store retained message for account %q, subject %q: %v", - acc.GetName(), key, err) - } + return } - // Clear the retain flag for a normal published message. - pp.flags &= ^mqttPubFlagRetain + rm := &mqttRetainedMsg{ + Origin: asm.jsa.id, + Msg: pp.msg, // will copy these bytes later as we process rm. + Flags: pp.flags, + Source: c.opts.Username, + } + + if retainSparkbBirth { + // [tck-id-conformance-mqtt-aware-store] A Sparkplug Aware MQTT Server + // MUST store NBIRTH and DBIRTH messages as they pass through the MQTT + // Server. + // + // [tck-id-conformance-mqtt-aware-nbirth-mqtt-topic]. A Sparkplug Aware + // MQTT Server MUST make NBIRTH messages available on a topic of the + // form: $sparkplug/certificates/namespace/group_id/NBIRTH/edge_node_id + // + // [tck-id-conformance-mqtt-aware-dbirth-mqtt-topic] A Sparkplug Aware + // MQTT Server MUST make DBIRTH messages available on a topic of the + // form: + // $sparkplug/certificates/namespace/group_id/DBIRTH/edge_node_id/device_id + topic := append(sparkbCertificatesTopicPrefix, pp.topic...) + subject, _ := mqttTopicToNATSPubSubject(topic) + rm.Topic = string(topic) + rm.Subject = string(subject) + + // will use to save the retained message. + key = string(subject) + + // Store the retained message with the RETAIN flag set. + rm.Flags |= mqttPubFlagRetain + + // Copy the payload out of pp since we will be sending the message + // asynchronously. 
+ msg := make([]byte, pp.sz) + copy(msg, pp.msg[:pp.sz]) + asm.jsa.sendMsg(key, msg) + + } else { // isRetained + // Spec [MQTT-3.3.1-5]. Store the retained message with its QoS. + // + // When coming from a publish protocol, `pp` is referencing a stack + // variable that itself possibly references the read buffer. + rm.Topic = string(pp.topic) + } + + // Set the key to the subject of the message for retained, or the composed + // $sparkplug subject for sparkB. + rm.Subject = key + rmBytes, hdr := mqttEncodeRetainedMessage(rm) // will copy the payload bytes + smr, err := asm.jsa.storeMsg(mqttRetainedMsgsStreamSubject+key, hdr, rmBytes) + if err == nil { + // Update the new sequence. + rf := &mqttRetainedMsgRef{ + sseq: smr.Sequence, + } + // Add/update the map. `true` to copy the payload bytes if needs to + // update rmsCache. + asm.handleRetainedMsg(key, rf, rm, true) + } else { + c.mu.Lock() + acc := c.acc + c.mu.Unlock() + c.Errorf("unable to store retained message for account %q, subject %q: %v", + acc.GetName(), key, err) + } } // After a config reload, it is possible that the source of a publish retained @@ -4286,8 +4441,8 @@ func (s *Server) mqttCheckPubRetainedPerms() { if err != nil || jsm == nil { continue } - var rm mqttRetainedMsg - if err := json.Unmarshal(jsm.Data, &rm); err != nil { + rm, err := mqttDecodeRetainedMessage(jsm.Header, jsm.Data) + if err != nil { continue } if rm.Source == _EMPTY_ { @@ -4484,6 +4639,32 @@ func mqttIsRetained(flags byte) bool { return flags&mqttPubFlagRetain != 0 } +func sparkbParseBirthDeathTopic(topic []byte) (isBirth, isDeath, isCertificate bool) { + if bytes.HasPrefix(topic, sparkbCertificatesTopicPrefix) { + isCertificate = true + topic = topic[len(sparkbCertificatesTopicPrefix):] + } + if !bytes.HasPrefix(topic, sparkbNamespaceTopicPrefix) { + return false, false, false + } + topic = topic[len(sparkbNamespaceTopicPrefix):] + + parts := bytes.Split(topic, []byte{'/'}) + if len(parts) < 3 || len(parts) > 4 { + return 
false, false, false + } + typ := bytesToString(parts[1]) + switch typ { + case sparkbNBIRTH, sparkbDBIRTH: + isBirth = true + case sparkbNDEATH, sparkbDDEATH: + isDeath = true + default: + return false, false, false + } + return isBirth, isDeath, isCertificate +} + ////////////////////////////////////////////////////////////////////////////// // // SUBSCRIBE related functions @@ -4631,18 +4812,16 @@ func mqttDeliverMsgCbQoS0(sub *subscription, pc *client, _ *Account, subject, re topic = pc.mqtt.pp.topic // Check for service imports where subject mapping is in play. if len(pc.pa.mapped) > 0 && len(pc.pa.psi) > 0 { - topic = natsSubjectToMQTTTopic(subject) + topic = natsSubjectStrToMQTTTopic(subject) } } else { // Non MQTT client, could be NATS publisher, or ROUTER, etc.. h := mqttParsePublishNATSHeader(hdr) - // If the message does not have the MQTT header, it is not a MQTT and - // should be delivered here, at QOS0. If it does have the header, we - // need to lock the session to check the sub QoS, and then ignore the - // message if the Sub wants higher QOS delivery. It will be delivered by - // mqttDeliverMsgCbQoS12. + // Check the subscription's QoS. If the message was published with a + // QoS>0 (in the header) and the sub has the QoS>0 then the message will + // be delivered by mqttDeliverMsgCbQoS12. if subQoS > 0 && h != nil && h.qos > 0 { return } @@ -4652,7 +4831,7 @@ func mqttDeliverMsgCbQoS0(sub *subscription, pc *client, _ *Account, subject, re if len(msg) > mqttMaxPayloadSize { msg = msg[:mqttMaxPayloadSize] } - topic = natsSubjectToMQTTTopic(subject) + topic = natsSubjectStrToMQTTTopic(subject) } // Message never has a packet identifier nor is marked as duplicate. @@ -4714,7 +4893,7 @@ func mqttDeliverMsgCbQoS12(sub *subscription, pc *client, _ *Account, subject, r // Check for reserved subject violation. If so, we will send the ack to // remove the message, and do nothing else. 
- strippedSubj := string(subject[len(mqttStreamSubjectPrefix):]) + strippedSubj := subject[len(mqttStreamSubjectPrefix):] if mqttMustIgnoreForReservedSub(sub, strippedSubj) { sess.mu.Unlock() sess.jsa.sendAck(reply) @@ -4731,7 +4910,7 @@ func mqttDeliverMsgCbQoS12(sub *subscription, pc *client, _ *Account, subject, r return } - originalTopic := natsSubjectToMQTTTopic(strippedSubj) + originalTopic := natsSubjectStrToMQTTTopic(strippedSubj) pc.mqttEnqueuePublishMsgTo(cc, sub, pi, qos, dup, originalTopic, msg) } @@ -4789,10 +4968,78 @@ func isMQTTReservedSubscription(subject string) bool { return false } +func sparkbReplaceDeathTimestamp(msg []byte) []byte { + const VARINT = 0 + const TIMESTAMP = 1 + + orig := msg + buf := bytes.NewBuffer(make([]byte, 0, len(msg)+16)) // 16 bytes should be enough if we need to add a timestamp + writeDeathTimestamp := func() { + // [tck-id-conformance-mqtt-aware-ndeath-timestamp] A Sparkplug Aware + // MQTT Server MAY replace the timestamp of NDEATH messages. If it does, + // it MUST set the timestamp to the UTC time at which it attempts to + // deliver the NDEATH to subscribed clients + // + // sparkB spec: 6.4.1. Google Protocol Buffer Schema + // optional uint64 timestamp = 1; // Timestamp at message sending time + // + // SparkplugB timestamps are milliseconds since epoch, represented as + // uint64 in go, transmitted as protobuf varint. + ts := uint64(time.Now().UnixMilli()) + buf.Write(protoEncodeVarint(TIMESTAMP<<3 | VARINT)) + buf.Write(protoEncodeVarint(ts)) + } + + for len(msg) > 0 { + fieldNumericID, fieldType, size, err := protoScanField(msg) + if err != nil { + return orig + } + if fieldType != VARINT || fieldNumericID != TIMESTAMP { + // Add the field as is + buf.Write(msg[:size]) + msg = msg[size:] + continue + } + + writeDeathTimestamp() + + // Add the rest of the message as is, we are done + buf.Write(msg[size:]) + return buf.Bytes() + } + + // Add timestamp if we did not find one. 
+ writeDeathTimestamp() + + return buf.Bytes() +} + // Common function to mqtt delivery callbacks to serialize and send the message // to the `cc` client. func (c *client) mqttEnqueuePublishMsgTo(cc *client, sub *subscription, pi uint16, qos byte, dup bool, topic, msg []byte) { - flags, headerBytes := mqttMakePublishHeader(pi, qos, dup, false, topic, len(msg)) + // [tck-id-conformance-mqtt-aware-nbirth-mqtt-retain] A Sparkplug Aware + // MQTT Server MUST make NBIRTH messages available on the topic: + // $sparkplug/certificates/namespace/group_id/NBIRTH/edge_node_id with + // the MQTT retain flag set to true + // + // [tck-id-conformance-mqtt-aware-dbirth-mqtt-retain] A Sparkplug Aware + // MQTT Server MUST make DBIRTH messages available on the topic: + // $sparkplug/certificates/namespace/group_id/DBIRTH/edge_node_id/device_id + // with the MQTT retain flag set to true + // + // $sparkplug/certificates messages are sent as NATS messages, so we + // need to add the retain flag when sending them to MQTT clients. + + retain := false + isBirth, isDeath, isCertificate := sparkbParseBirthDeathTopic(topic) + if isBirth && qos == 0 { + retain = isCertificate + } else if isDeath && !isCertificate { + msg = sparkbReplaceDeathTimestamp(msg) + } + + flags, headerBytes := mqttMakePublishHeader(pi, qos, dup, retain, topic, len(msg)) cc.mu.Lock() if sub.mqtt.prm != nil { @@ -5351,8 +5598,12 @@ func mqttToNATSSubjectConversion(mt []byte, wcOk bool) ([]byte, error) { // Converts a NATS subject to MQTT topic. This is for publish // messages only, so there is no checking for wildcards. // Rules are reversed of mqttToNATSSubjectConversion. 
-func natsSubjectToMQTTTopic(subject string) []byte { - topic := []byte(subject) +func natsSubjectStrToMQTTTopic(subject string) []byte { + return natsSubjectToMQTTTopic(stringToBytes(subject)) +} + +func natsSubjectToMQTTTopic(subject []byte) []byte { + topic := make([]byte, len(subject)) end := len(subject) - 1 var j int for i := 0; i < len(subject); i++ { diff --git a/vendor/github.com/nats-io/nats-server/v2/server/msgtrace.go b/vendor/github.com/nats-io/nats-server/v2/server/msgtrace.go new file mode 100644 index 0000000000..e3e73421b1 --- /dev/null +++ b/vendor/github.com/nats-io/nats-server/v2/server/msgtrace.go @@ -0,0 +1,846 @@ +// Copyright 2024 The NATS Authors +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package server + +import ( + "bytes" + "encoding/json" + "fmt" + "math/rand" + "strconv" + "strings" + "sync/atomic" + "time" +) + +const ( + MsgTraceDest = "Nats-Trace-Dest" + MsgTraceHop = "Nats-Trace-Hop" + MsgTraceOriginAccount = "Nats-Trace-Origin-Account" + MsgTraceOnly = "Nats-Trace-Only" + + // External trace header. Note that this header is normally in lower + // case (https://www.w3.org/TR/trace-context/#header-name). Vendors + // MUST expect the header in any case (upper, lower, mixed), and + // SHOULD send the header name in lowercase. + traceParentHdr = "traceparent" +) + +type MsgTraceType string + +// Type of message trace events in the MsgTraceEvents list. +// This is needed to unmarshal the list. 
+const ( + MsgTraceIngressType = "in" + MsgTraceSubjectMappingType = "sm" + MsgTraceStreamExportType = "se" + MsgTraceServiceImportType = "si" + MsgTraceJetStreamType = "js" + MsgTraceEgressType = "eg" +) + +type MsgTraceEvent struct { + Server ServerInfo `json:"server"` + Request MsgTraceRequest `json:"request"` + Hops int `json:"hops,omitempty"` + Events MsgTraceEvents `json:"events"` +} + +type MsgTraceRequest struct { + // We are not making this an http.Header so that header name case is preserved. + Header map[string][]string `json:"header,omitempty"` + MsgSize int `json:"msgsize,omitempty"` +} + +type MsgTraceEvents []MsgTrace + +type MsgTrace interface { + new() MsgTrace + typ() MsgTraceType +} + +type MsgTraceBase struct { + Type MsgTraceType `json:"type"` + Timestamp time.Time `json:"ts"` +} + +type MsgTraceIngress struct { + MsgTraceBase + Kind int `json:"kind"` + CID uint64 `json:"cid"` + Name string `json:"name,omitempty"` + Account string `json:"acc"` + Subject string `json:"subj"` + Error string `json:"error,omitempty"` +} + +type MsgTraceSubjectMapping struct { + MsgTraceBase + MappedTo string `json:"to"` +} + +type MsgTraceStreamExport struct { + MsgTraceBase + Account string `json:"acc"` + To string `json:"to"` +} + +type MsgTraceServiceImport struct { + MsgTraceBase + Account string `json:"acc"` + From string `json:"from"` + To string `json:"to"` +} + +type MsgTraceJetStream struct { + MsgTraceBase + Stream string `json:"stream"` + Subject string `json:"subject,omitempty"` + NoInterest bool `json:"nointerest,omitempty"` + Error string `json:"error,omitempty"` +} + +type MsgTraceEgress struct { + MsgTraceBase + Kind int `json:"kind"` + CID uint64 `json:"cid"` + Name string `json:"name,omitempty"` + Hop string `json:"hop,omitempty"` + Account string `json:"acc,omitempty"` + Subscription string `json:"sub,omitempty"` + Queue string `json:"queue,omitempty"` + Error string `json:"error,omitempty"` + + // This is for applications that unmarshal the 
trace events + // and want to link an egress to route/leaf/gateway with + // the MsgTraceEvent from that server. + Link *MsgTraceEvent `json:"-"` +} + +// ------------------------------------------------------------- + +func (t MsgTraceBase) typ() MsgTraceType { return t.Type } +func (MsgTraceIngress) new() MsgTrace { return &MsgTraceIngress{} } +func (MsgTraceSubjectMapping) new() MsgTrace { return &MsgTraceSubjectMapping{} } +func (MsgTraceStreamExport) new() MsgTrace { return &MsgTraceStreamExport{} } +func (MsgTraceServiceImport) new() MsgTrace { return &MsgTraceServiceImport{} } +func (MsgTraceJetStream) new() MsgTrace { return &MsgTraceJetStream{} } +func (MsgTraceEgress) new() MsgTrace { return &MsgTraceEgress{} } + +var msgTraceInterfaces = map[MsgTraceType]MsgTrace{ + MsgTraceIngressType: MsgTraceIngress{}, + MsgTraceSubjectMappingType: MsgTraceSubjectMapping{}, + MsgTraceStreamExportType: MsgTraceStreamExport{}, + MsgTraceServiceImportType: MsgTraceServiceImport{}, + MsgTraceJetStreamType: MsgTraceJetStream{}, + MsgTraceEgressType: MsgTraceEgress{}, +} + +func (t *MsgTraceEvents) UnmarshalJSON(data []byte) error { + var raw []json.RawMessage + err := json.Unmarshal(data, &raw) + if err != nil { + return err + } + *t = make(MsgTraceEvents, len(raw)) + var tt MsgTraceBase + for i, r := range raw { + if err = json.Unmarshal(r, &tt); err != nil { + return err + } + tr, ok := msgTraceInterfaces[tt.Type] + if !ok { + return fmt.Errorf("unknown trace type %v", tt.Type) + } + te := tr.new() + if err := json.Unmarshal(r, te); err != nil { + return err + } + (*t)[i] = te + } + return nil +} + +func getTraceAs[T MsgTrace](e any) *T { + v, ok := e.(*T) + if ok { + return v + } + return nil +} + +func (t *MsgTraceEvent) Ingress() *MsgTraceIngress { + if len(t.Events) < 1 { + return nil + } + return getTraceAs[MsgTraceIngress](t.Events[0]) +} + +func (t *MsgTraceEvent) SubjectMapping() *MsgTraceSubjectMapping { + for _, e := range t.Events { + if e.typ() == 
MsgTraceSubjectMappingType { + return getTraceAs[MsgTraceSubjectMapping](e) + } + } + return nil +} + +func (t *MsgTraceEvent) StreamExports() []*MsgTraceStreamExport { + var se []*MsgTraceStreamExport + for _, e := range t.Events { + if e.typ() == MsgTraceStreamExportType { + se = append(se, getTraceAs[MsgTraceStreamExport](e)) + } + } + return se +} + +func (t *MsgTraceEvent) ServiceImports() []*MsgTraceServiceImport { + var si []*MsgTraceServiceImport + for _, e := range t.Events { + if e.typ() == MsgTraceServiceImportType { + si = append(si, getTraceAs[MsgTraceServiceImport](e)) + } + } + return si +} + +func (t *MsgTraceEvent) JetStream() *MsgTraceJetStream { + for _, e := range t.Events { + if e.typ() == MsgTraceJetStreamType { + return getTraceAs[MsgTraceJetStream](e) + } + } + return nil +} + +func (t *MsgTraceEvent) Egresses() []*MsgTraceEgress { + var eg []*MsgTraceEgress + for _, e := range t.Events { + if e.typ() == MsgTraceEgressType { + eg = append(eg, getTraceAs[MsgTraceEgress](e)) + } + } + return eg +} + +const ( + errMsgTraceOnlyNoSupport = "Not delivered because remote does not support message tracing" + errMsgTraceNoSupport = "Message delivered but remote does not support message tracing so no trace event generated from there" + errMsgTraceNoEcho = "Not delivered because of no echo" + errMsgTracePubViolation = "Not delivered because publish denied for this subject" + errMsgTraceSubDeny = "Not delivered because subscription denies this subject" + errMsgTraceSubClosed = "Not delivered because subscription is closed" + errMsgTraceClientClosed = "Not delivered because client is closed" + errMsgTraceAutoSubExceeded = "Not delivered because auto-unsubscribe exceeded" + errMsgTraceFastProdNoStall = "Not delivered because fast producer not stalled and consumer is slow" +) + +type msgTrace struct { + ready int32 + srv *Server + acc *Account + // Origin account name, set only if acc is nil when acc lookup failed. 
+ oan string + dest string + event *MsgTraceEvent + js *MsgTraceJetStream + hop string + nhop string + tonly bool // Will only trace the message, not do delivery. + ct compressionType +} + +// This will be false outside of the tests, so when building the server binary, +// any code where you see `if msgTraceRunInTests` statement will be compiled +// out, so this will have no performance penalty. +var ( + msgTraceRunInTests bool + msgTraceCheckSupport bool +) + +// Returns the message trace object, if message is being traced, +// and `true` if we want to only trace, not actually deliver the message. +func (c *client) isMsgTraceEnabled() (*msgTrace, bool) { + t := c.pa.trace + if t == nil { + return nil, false + } + return t, t.tonly +} + +// For LEAF/ROUTER/GATEWAY, return false if the remote does not support +// message tracing (important if the tracing requests trace-only). +func (c *client) msgTraceSupport() bool { + // Exclude client connection from the protocol check. + return c.kind == CLIENT || c.opts.Protocol >= MsgTraceProto +} + +func getConnName(c *client) string { + switch c.kind { + case ROUTER: + if n := c.route.remoteName; n != _EMPTY_ { + return n + } + case GATEWAY: + if n := c.gw.remoteName; n != _EMPTY_ { + return n + } + case LEAF: + if n := c.leaf.remoteServer; n != _EMPTY_ { + return n + } + } + return c.opts.Name +} + +func getCompressionType(cts string) compressionType { + if cts == _EMPTY_ { + return noCompression + } + cts = strings.ToLower(cts) + if strings.Contains(cts, "snappy") || strings.Contains(cts, "s2") { + return snappyCompression + } + if strings.Contains(cts, "gzip") { + return gzipCompression + } + return unsupportedCompression +} + +func (c *client) initMsgTrace() *msgTrace { + // The code in the "if" statement is only running in test mode. + if msgTraceRunInTests { + // Check the type of client that tries to initialize a trace struct. 
+ if !(c.kind == CLIENT || c.kind == ROUTER || c.kind == GATEWAY || c.kind == LEAF) { + panic(fmt.Sprintf("Unexpected client type %q trying to initialize msgTrace", c.kindString())) + } + // In some tests, we want to make a server behave like an old server + // and so even if a trace header is received, we want the server to + // simply ignore it. + if msgTraceCheckSupport { + if c.srv == nil || c.srv.getServerProto() < MsgTraceProto { + return nil + } + } + } + if c.pa.hdr <= 0 { + return nil + } + hdr := c.msgBuf[:c.pa.hdr] + headers, external := genHeaderMapIfTraceHeadersPresent(hdr) + if len(headers) == 0 { + return nil + } + // Little helper to give us the first value of a given header, or _EMPTY_ + // if key is not present. + getHdrVal := func(key string) string { + vv, ok := headers[key] + if !ok { + return _EMPTY_ + } + return vv[0] + } + ct := getCompressionType(getHdrVal(acceptEncodingHeader)) + var ( + dest string + traceOnly bool + ) + // Check for traceOnly only if not external. + if !external { + if to := getHdrVal(MsgTraceOnly); to != _EMPTY_ { + tos := strings.ToLower(to) + switch tos { + case "1", "true", "on": + traceOnly = true + } + } + dest = getHdrVal(MsgTraceDest) + // Check the destination to see if this is a valid public subject. + if !IsValidPublishSubject(dest) { + // We still have to return a msgTrace object (if traceOnly is set) + // because if we don't, the message will end-up being delivered to + // applications, which may break them. We report the error in any case. + c.Errorf("Destination %q is not valid, won't be able to trace events", dest) + if !traceOnly { + // We can bail, tracing will be disabled for this message. 
+ return nil + } + } + } + var ( + // Account to use when sending the trace event + acc *Account + // Ingress' account name + ian string + // Origin account name + oan string + // The hop "id", taken from headers only when not from CLIENT + hop string + ) + if c.kind == ROUTER || c.kind == GATEWAY || c.kind == LEAF { + // The ingress account name will always be c.pa.account, but `acc` may + // be different if we have an origin account header. + if c.kind == LEAF { + ian = c.acc.GetName() + } else { + ian = string(c.pa.account) + } + // The remote will have set the origin account header only if the + // message changed account (think of service imports). + oan = getHdrVal(MsgTraceOriginAccount) + if oan == _EMPTY_ { + // For LEAF or ROUTER with pinned-account, we can use the c.acc. + if c.kind == LEAF || (c.kind == ROUTER && len(c.route.accName) > 0) { + acc = c.acc + } else { + // We will lookup account with c.pa.account (or ian). + oan = ian + } + } + // Unless we already got the account, we need to look it up. + if acc == nil { + // We don't want to do account resolving here. + if acci, ok := c.srv.accounts.Load(oan); ok { + acc = acci.(*Account) + // Since we have looked-up the account, we don't need oan, so + // clear it in case it was set. + oan = _EMPTY_ + } else { + // We still have to return a msgTrace object (if traceOnly is set) + // because if we don't, the message will end-up being delivered to + // applications, which may break them. We report the error in any case. + c.Errorf("Account %q was not found, won't be able to trace events", oan) + if !traceOnly { + // We can bail, tracing will be disabled for this message. + return nil + } + } + } + // Check the hop header + hop = getHdrVal(MsgTraceHop) + } else { + acc = c.acc + ian = acc.GetName() + } + // If external, we need to have the account's trace destination set, + // otherwise, we are not enabling tracing. 
+ if external { + var sampling int + if acc != nil { + dest, sampling = acc.getTraceDestAndSampling() + } + if dest == _EMPTY_ { + // No account destination, no tracing for external trace headers. + return nil + } + // Check sampling, but only from origin server. + if c.kind == CLIENT && !sample(sampling) { + // Need to desactivate the traceParentHdr so that if the message + // is routed, it does possibly trigger a trace there. + disableTraceHeaders(c, hdr) + return nil + } + } + c.pa.trace = &msgTrace{ + srv: c.srv, + acc: acc, + oan: oan, + dest: dest, + ct: ct, + hop: hop, + event: &MsgTraceEvent{ + Request: MsgTraceRequest{ + Header: headers, + MsgSize: c.pa.size, + }, + Events: append(MsgTraceEvents(nil), &MsgTraceIngress{ + MsgTraceBase: MsgTraceBase{ + Type: MsgTraceIngressType, + Timestamp: time.Now(), + }, + Kind: c.kind, + CID: c.cid, + Name: getConnName(c), + Account: ian, + Subject: string(c.pa.subject), + }), + }, + tonly: traceOnly, + } + return c.pa.trace +} + +func sample(sampling int) bool { + // Option parsing should ensure that sampling is [1..100], but consider + // any value outside of this range to be 100%. + if sampling <= 0 || sampling >= 100 { + return true + } + return rand.Int31n(100) <= int32(sampling) +} + +// This function will return the header as a map (instead of http.Header because +// we want to preserve the header names' case) and a boolean that indicates if +// the headers have been lifted due to the presence of the external trace header +// only. +// Note that because of the traceParentHdr, the search is done in a case +// insensitive way, but if the header is found, it is rewritten in lower case +// as suggested by the spec, but also to make it easier to disable the header +// when needed. 
+func genHeaderMapIfTraceHeadersPresent(hdr []byte) (map[string][]string, bool) { + + var ( + _keys = [64][]byte{} + _vals = [64][]byte{} + m map[string][]string + traceDestHdrFound bool + traceParentHdrFound bool + ) + // Skip the hdrLine + if !bytes.HasPrefix(hdr, stringToBytes(hdrLine)) { + return nil, false + } + + traceDestHdrAsBytes := stringToBytes(MsgTraceDest) + traceParentHdrAsBytes := stringToBytes(traceParentHdr) + crLFAsBytes := stringToBytes(CR_LF) + dashAsBytes := stringToBytes("-") + + keys := _keys[:0] + vals := _vals[:0] + + for i := len(hdrLine); i < len(hdr); { + // Search for key/val delimiter + del := bytes.IndexByte(hdr[i:], ':') + if del < 0 { + break + } + keyStart := i + key := hdr[keyStart : keyStart+del] + i += del + 1 + valStart := i + nl := bytes.Index(hdr[valStart:], crLFAsBytes) + if nl < 0 { + break + } + if len(key) > 0 { + val := bytes.Trim(hdr[valStart:valStart+nl], " \t") + vals = append(vals, val) + + // Check for the external trace header. + if bytes.EqualFold(key, traceParentHdrAsBytes) { + // Rewrite the header using lower case if needed. + if !bytes.Equal(key, traceParentHdrAsBytes) { + copy(hdr[keyStart:], traceParentHdrAsBytes) + } + // We will now check if the value has sampling or not. + // TODO(ik): Not sure if this header can have multiple values + // or not, and if so, what would be the rule to check for + // sampling. What is done here is to check them all until we + // found one with sampling. + if !traceParentHdrFound { + tk := bytes.Split(val, dashAsBytes) + if len(tk) == 4 && len([]byte(tk[3])) == 2 { + if hexVal, err := strconv.ParseInt(bytesToString(tk[3]), 16, 8); err == nil { + if hexVal&0x1 == 0x1 { + traceParentHdrFound = true + } + } + } + } + // Add to the keys with the external trace header in lower case. + keys = append(keys, traceParentHdrAsBytes) + } else { + // Is the key the Nats-Trace-Dest header? 
+ if bytes.EqualFold(key, traceDestHdrAsBytes) { + traceDestHdrFound = true + } + // Add to the keys and preserve the key's case + keys = append(keys, key) + } + } + i += nl + 2 + } + if !traceDestHdrFound && !traceParentHdrFound { + return nil, false + } + m = make(map[string][]string, len(keys)) + for i, k := range keys { + hname := string(k) + m[hname] = append(m[hname], string(vals[i])) + } + return m, !traceDestHdrFound && traceParentHdrFound +} + +// Special case where we create a trace event before parsing the message. +// This is for cases where the connection will be closed when detecting +// an error during early message processing (for instance max payload). +func (c *client) initAndSendIngressErrEvent(hdr []byte, dest string, ingressError error) { + if ingressError == nil { + return + } + ct := getAcceptEncoding(hdr) + t := &msgTrace{ + srv: c.srv, + acc: c.acc, + dest: dest, + ct: ct, + event: &MsgTraceEvent{ + Request: MsgTraceRequest{MsgSize: c.pa.size}, + Events: append(MsgTraceEvents(nil), &MsgTraceIngress{ + MsgTraceBase: MsgTraceBase{ + Type: MsgTraceIngressType, + Timestamp: time.Now(), + }, + Kind: c.kind, + CID: c.cid, + Name: getConnName(c), + Error: ingressError.Error(), + }), + }, + } + t.sendEvent() +} + +// Returns `true` if message tracing is enabled and we are tracing only, +// that is, we are not going to deliver the inbound message, returns +// `false` otherwise (no tracing, or tracing and message delivery). +func (t *msgTrace) traceOnly() bool { + return t != nil && t.tonly +} + +func (t *msgTrace) setOriginAccountHeaderIfNeeded(c *client, acc *Account, msg []byte) []byte { + var oan string + // If t.acc is set, only check that, not t.oan. 
+ if t.acc != nil { + if t.acc != acc { + oan = t.acc.GetName() + } + } else if t.oan != acc.GetName() { + oan = t.oan + } + if oan != _EMPTY_ { + msg = c.setHeader(MsgTraceOriginAccount, oan, msg) + } + return msg +} + +func (t *msgTrace) setHopHeader(c *client, msg []byte) []byte { + e := t.event + e.Hops++ + if len(t.hop) > 0 { + t.nhop = fmt.Sprintf("%s.%d", t.hop, e.Hops) + } else { + t.nhop = fmt.Sprintf("%d", e.Hops) + } + return c.setHeader(MsgTraceHop, t.nhop, msg) +} + +// Will look for the MsgTraceSendTo and traceParentHdr headers and change the first +// character to an 'X' so that if this message is sent to a remote, the remote +// will not initialize tracing since it won't find the actual trace headers. +// The function returns the position of the headers so it can efficiently be +// re-enabled by calling enableTraceHeaders. +// Note that if `msg` can be either the header alone or the full message +// (header and payload). This function will use c.pa.hdr to limit the +// search to the header section alone. +func disableTraceHeaders(c *client, msg []byte) []int { + // Code largely copied from getHeader(), except that we don't need the value + if c.pa.hdr <= 0 { + return []int{-1, -1} + } + hdr := msg[:c.pa.hdr] + headers := [2]string{MsgTraceDest, traceParentHdr} + positions := [2]int{-1, -1} + for i := 0; i < 2; i++ { + key := stringToBytes(headers[i]) + pos := bytes.Index(hdr, key) + if pos < 0 { + continue + } + // Make sure this key does not have additional prefix. + if pos < 2 || hdr[pos-1] != '\n' || hdr[pos-2] != '\r' { + continue + } + index := pos + len(key) + if index >= len(hdr) { + continue + } + if hdr[index] != ':' { + continue + } + // Disable the trace by altering the first character of the header + hdr[pos] = 'X' + positions[i] = pos + } + // Return the positions of those characters so we can re-enable the headers. 
+ return positions[:2] +} + +// Changes back the character at the given position `pos` in the `msg` +// byte slice to the first character of the MsgTraceSendTo header. +func enableTraceHeaders(msg []byte, positions []int) { + firstChar := [2]byte{MsgTraceDest[0], traceParentHdr[0]} + for i, pos := range positions { + if pos == -1 { + continue + } + msg[pos] = firstChar[i] + } +} + +func (t *msgTrace) setIngressError(err string) { + if i := t.event.Ingress(); i != nil { + i.Error = err + } +} + +func (t *msgTrace) addSubjectMappingEvent(subj []byte) { + if t == nil { + return + } + t.event.Events = append(t.event.Events, &MsgTraceSubjectMapping{ + MsgTraceBase: MsgTraceBase{ + Type: MsgTraceSubjectMappingType, + Timestamp: time.Now(), + }, + MappedTo: string(subj), + }) +} + +func (t *msgTrace) addEgressEvent(dc *client, sub *subscription, err string) { + if t == nil { + return + } + e := &MsgTraceEgress{ + MsgTraceBase: MsgTraceBase{ + Type: MsgTraceEgressType, + Timestamp: time.Now(), + }, + Kind: dc.kind, + CID: dc.cid, + Name: getConnName(dc), + Hop: t.nhop, + Error: err, + } + t.nhop = _EMPTY_ + // Specific to CLIENT connections... + if dc.kind == CLIENT { + // Set the subscription's subject and possibly queue name. + e.Subscription = string(sub.subject) + if len(sub.queue) > 0 { + e.Queue = string(sub.queue) + } + } + if dc.kind == CLIENT || dc.kind == LEAF { + if i := t.event.Ingress(); i != nil { + // If the Ingress' account is different from the destination's + // account, add the account name into the Egress trace event. + // This would happen with service imports. 
+ if dcAccName := dc.acc.GetName(); dcAccName != i.Account { + e.Account = dcAccName + } + } + } + t.event.Events = append(t.event.Events, e) +} + +func (t *msgTrace) addStreamExportEvent(dc *client, to []byte) { + if t == nil { + return + } + dc.mu.Lock() + accName := dc.acc.GetName() + dc.mu.Unlock() + t.event.Events = append(t.event.Events, &MsgTraceStreamExport{ + MsgTraceBase: MsgTraceBase{ + Type: MsgTraceStreamExportType, + Timestamp: time.Now(), + }, + Account: accName, + To: string(to), + }) +} + +func (t *msgTrace) addServiceImportEvent(accName, from, to string) { + if t == nil { + return + } + t.event.Events = append(t.event.Events, &MsgTraceServiceImport{ + MsgTraceBase: MsgTraceBase{ + Type: MsgTraceServiceImportType, + Timestamp: time.Now(), + }, + Account: accName, + From: from, + To: to, + }) +} + +func (t *msgTrace) addJetStreamEvent(streamName string) { + if t == nil { + return + } + t.js = &MsgTraceJetStream{ + MsgTraceBase: MsgTraceBase{ + Type: MsgTraceJetStreamType, + Timestamp: time.Now(), + }, + Stream: streamName, + } + t.event.Events = append(t.event.Events, t.js) +} + +func (t *msgTrace) updateJetStreamEvent(subject string, noInterest bool) { + if t == nil { + return + } + // JetStream event should have been created in addJetStreamEvent + if t.js == nil { + return + } + t.js.Subject = subject + t.js.NoInterest = noInterest + // Update the timestamp since this is more accurate than when it + // was first added in addJetStreamEvent(). 
+ t.js.Timestamp = time.Now() +} + +func (t *msgTrace) sendEventFromJetStream(err error) { + if t == nil { + return + } + // JetStream event should have been created in addJetStreamEvent + if t.js == nil { + return + } + if err != nil { + t.js.Error = err.Error() + } + t.sendEvent() +} + +func (t *msgTrace) sendEvent() { + if t == nil { + return + } + if t.js != nil { + ready := atomic.AddInt32(&t.ready, 1) == 2 + if !ready { + return + } + } + t.srv.sendInternalAccountSysMsg(t.acc, t.dest, &t.event.Server, t.event, t.ct) +} diff --git a/vendor/github.com/nats-io/nats-server/v2/server/opts.go b/vendor/github.com/nats-io/nats-server/v2/server/opts.go index abd6d5774b..e6da09c6bd 100644 --- a/vendor/github.com/nats-io/nats-server/v2/server/opts.go +++ b/vendor/github.com/nats-io/nats-server/v2/server/opts.go @@ -1,4 +1,4 @@ -// Copyright 2012-2023 The NATS Authors +// Copyright 2012-2025 The NATS Authors // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at @@ -145,22 +145,33 @@ type RemoteGatewayOpts struct { // LeafNodeOpts are options for a given server to accept leaf node connections and/or connect to a remote cluster. 
type LeafNodeOpts struct { - Host string `json:"addr,omitempty"` - Port int `json:"port,omitempty"` - Username string `json:"-"` - Password string `json:"-"` - Nkey string `json:"-"` - Account string `json:"-"` - Users []*User `json:"-"` - AuthTimeout float64 `json:"auth_timeout,omitempty"` - TLSConfig *tls.Config `json:"-"` - TLSTimeout float64 `json:"tls_timeout,omitempty"` - TLSMap bool `json:"-"` - TLSPinnedCerts PinnedCertSet `json:"-"` - TLSHandshakeFirst bool `json:"-"` - Advertise string `json:"-"` - NoAdvertise bool `json:"-"` - ReconnectInterval time.Duration `json:"-"` + Host string `json:"addr,omitempty"` + Port int `json:"port,omitempty"` + Username string `json:"-"` + Password string `json:"-"` + Nkey string `json:"-"` + Account string `json:"-"` + Users []*User `json:"-"` + AuthTimeout float64 `json:"auth_timeout,omitempty"` + TLSConfig *tls.Config `json:"-"` + TLSTimeout float64 `json:"tls_timeout,omitempty"` + TLSMap bool `json:"-"` + TLSPinnedCerts PinnedCertSet `json:"-"` + // When set to true, the server will perform the TLS handshake before + // sending the INFO protocol. For remote leafnodes that are not configured + // with a similar option, their connection will fail with some sort + // of timeout or EOF error since they are expecting to receive an + // INFO protocol first. + TLSHandshakeFirst bool `json:"-"` + // If TLSHandshakeFirst is true and this value is strictly positive, + // the server will wait for that amount of time for the TLS handshake + // to start before falling back to previous behavior of sending the + // INFO protocol first. It allows for a mix of newer remote leafnodes + // that can require a TLS handshake first, and older that can't. 
+ TLSHandshakeFirstFallback time.Duration `json:"-"` + Advertise string `json:"-"` + NoAdvertise bool `json:"-"` + ReconnectInterval time.Duration `json:"-"` // Compression options Compression CompressionOpts `json:"-"` @@ -230,13 +241,24 @@ type RemoteLeafOpts struct { // not be able to work. This tells the system to migrate the leaders away from this server. // This only changes leader for R>1 assets. JetStreamClusterMigrate bool `json:"jetstream_cluster_migrate,omitempty"` + + // If JetStreamClusterMigrate is set to true, this is the time after which the leader + // will be migrated away from this server if still disconnected. + JetStreamClusterMigrateDelay time.Duration `json:"jetstream_cluster_migrate_delay,omitempty"` } type JSLimitOpts struct { - MaxRequestBatch int - MaxAckPending int - MaxHAAssets int - Duplicates time.Duration + MaxRequestBatch int `json:"max_request_batch,omitempty"` + MaxAckPending int `json:"max_ack_pending,omitempty"` + MaxHAAssets int `json:"max_ha_assets,omitempty"` + Duplicates time.Duration `json:"max_duplicate_window,omitempty"` +} + +type JSTpmOpts struct { + KeysFile string + KeyPassword string + SrkPassword string + Pcr int } // AuthCallout option used to map external AuthN to NATS based AuthZ. @@ -250,6 +272,9 @@ type AuthCallout struct { // XKey is a public xkey for the authorization service. // This will enable encryption for server requests and the authorization service responses. XKey string + // AllowedAccounts that will be delegated to the auth service. + // If empty then all accounts will be delegated. + AllowedAccounts []string } // Options block for nats-server. 
@@ -300,6 +325,7 @@ type Options struct { Gateway GatewayOpts `json:"gateway,omitempty"` LeafNode LeafNodeOpts `json:"leaf,omitempty"` JetStream bool `json:"jetstream"` + JetStreamStrict bool `json:"-"` JetStreamMaxMemory int64 `json:"-"` JetStreamMaxStore int64 `json:"-"` JetStreamDomain string `json:"-"` @@ -309,8 +335,11 @@ type Options struct { JetStreamCipher StoreCipher `json:"-"` JetStreamUniqueTag string JetStreamLimits JSLimitOpts + JetStreamTpm JSTpmOpts JetStreamMaxCatchup int64 JetStreamRequestQueueLimit int64 + StreamMaxBufferedMsgs int `json:"-"` + StreamMaxBufferedSize int64 `json:"-"` StoreDir string `json:"-"` SyncInterval time.Duration `json:"-"` SyncAlways bool `json:"-"` @@ -409,7 +438,7 @@ type Options struct { // private fields, used for testing gatewaysSolicitDelay time.Duration - routeProto int + overrideProto int // JetStream maxMemSet bool @@ -421,6 +450,9 @@ type Options struct { // Used to mark that we had a top level authorization block. authBlockDefined bool + + // configDigest represents the state of configuration. + configDigest string } // WebsocketOpts are options for websocket @@ -441,6 +473,26 @@ type WebsocketOpts struct { // "jwt" specified in the CONNECT options is missing or empty. JWTCookie string + // Name of the cookie, which if present in WebSocket upgrade headers, + // will be treated as Username during CONNECT phase as long as + // "user" specified in the CONNECT options is missing or empty. + UsernameCookie string + + // Name of the cookie, which if present in WebSocket upgrade headers, + // will be treated as Password during CONNECT phase as long as + // "pass" specified in the CONNECT options is missing or empty. + PasswordCookie string + + // Name of the cookie, which if present in WebSocket upgrade headers, + // will be treated as Token during CONNECT phase as long as + // "auth_token" specified in the CONNECT options is missing or empty. 
+ // Note that when this is useful for passing a JWT to an cuth callout + // when the server uses delegated authentication ("operator mode") or + // when using delegated authentication, but the auth callout validates some + // other JWT or string. Note that this does map to an actual server-wide + // "auth_token", note that using it for that purpose is greatly discouraged. + TokenCookie string + // Authentication section. If anything is configured in this section, // it will override the authorization configuration of regular clients. Username string @@ -485,6 +537,10 @@ type WebsocketOpts struct { // time needed for the TLS Handshake. HandshakeTimeout time.Duration + // Headers to be added to the upgrade response. + // Useful for adding custom headers like Strict-Transport-Security. + Headers map[string]string + // Snapshot of configured TLS options. tlsConfigOpts *TLSConfigOpts } @@ -839,10 +895,32 @@ func (o *Options) ProcessConfigFile(configFile string) error { if configFile == _EMPTY_ { return nil } - m, err := conf.ParseFileWithChecks(configFile) + m, digest, err := conf.ParseFileWithChecksDigest(configFile) if err != nil { return err } + o.configDigest = digest + + return o.processConfigFile(configFile, m) +} + +// ProcessConfigString is the same as ProcessConfigFile, but expects the +// contents of the config file to be passed in rather than the file name. +func (o *Options) ProcessConfigString(data string) error { + m, err := conf.ParseWithChecks(data) + if err != nil { + return err + } + + return o.processConfigFile(_EMPTY_, m) +} + +// ConfigDigest returns the digest representing the configuration. +func (o *Options) ConfigDigest() string { + return o.configDigest +} + +func (o *Options) processConfigFile(configFile string, m map[string]any) error { // Collect all errors and warnings and report them all together. 
errors := make([]error, 0) warnings := make([]error, 0) @@ -860,6 +938,21 @@ func (o *Options) ProcessConfigFile(configFile string) error { o.processConfigFileLine(k, v, &errors, &warnings) } + // Post-process: check auth callout allowed accounts against configured accounts. + if o.AuthCallout != nil { + accounts := make(map[string]struct{}) + for _, acc := range o.Accounts { + accounts[acc.Name] = struct{}{} + } + + for _, acc := range o.AuthCallout.AllowedAccounts { + if _, ok := accounts[acc]; !ok { + err := &configErr{nil, fmt.Sprintf("auth_callout allowed account %q not found in configured accounts", acc)} + errors = append(errors, err) + } + } + } + if len(errors) > 0 || len(warnings) > 0 { return &processConfigErr{ errors: errors, @@ -889,7 +982,13 @@ func (o *Options) processConfigFileLine(k string, v any, errors *[]error, warnin case "port": o.Port = int(v.(int64)) case "server_name": - o.ServerName = v.(string) + sn := v.(string) + if strings.Contains(sn, " ") { + err := &configErr{tk, ErrServerNameHasSpaces.Error()} + *errors = append(*errors, err) + return + } + o.ServerName = sn case "host", "net": o.Host = v.(string) case "debug": @@ -1682,7 +1781,13 @@ func parseCluster(v any, opts *Options, errors *[]error, warnings *[]error) erro tk, mv = unwrapValue(mv, <) switch strings.ToLower(mk) { case "name": - opts.Cluster.Name = mv.(string) + cn := mv.(string) + if strings.Contains(cn, " ") { + err := &configErr{tk, ErrClusterNameHasSpaces.Error()} + *errors = append(*errors, err) + continue + } + opts.Cluster.Name = cn case "listen": hp, err := parseListen(mv) if err != nil { @@ -1912,7 +2017,13 @@ func parseGateway(v any, o *Options, errors *[]error, warnings *[]error) error { tk, mv = unwrapValue(mv, <) switch strings.ToLower(mk) { case "name": - o.Gateway.Name = mv.(string) + gn := mv.(string) + if strings.Contains(gn, " ") { + err := &configErr{tk, ErrGatewayNameHasSpaces.Error()} + *errors = append(*errors, err) + continue + } + o.Gateway.Name = gn 
case "listen": hp, err := parseListen(mv) if err != nil { @@ -2068,6 +2179,19 @@ func parseJetStreamForAccount(v any, acc *Account, errors *[]error) error { return &configErr{tk, fmt.Sprintf("Expected a parseable size for %q, got %v", mk, mv)} } jsLimits.MaxAckPending = int(vv) + case "cluster_traffic": + vv, ok := mv.(string) + if !ok { + return &configErr{tk, fmt.Sprintf("Expected either 'system' or 'account' string value for %q, got %v", mk, mv)} + } + switch vv { + case "system", _EMPTY_: + acc.js.nrgAccount = _EMPTY_ + case "owner": + acc.js.nrgAccount = acc.Name + default: + return &configErr{tk, fmt.Sprintf("Expected 'system' or 'owner' string value for %q, got %v", mk, mv)} + } default: if !tk.IsUsedVariable() { err := &unknownConfigFieldErr{ @@ -2165,6 +2289,61 @@ func parseJetStreamLimits(v any, opts *Options, errors *[]error) error { return nil } +// Parse the JetStream TPM options. +func parseJetStreamTPM(v interface{}, opts *Options, errors *[]error) error { + var lt token + tk, v := unwrapValue(v, <) + + tpm := JSTpmOpts{} + + vv, ok := v.(map[string]interface{}) + if !ok { + return &configErr{tk, fmt.Sprintf("Expected a map to define JetStreamLimits, got %T", v)} + } + for mk, mv := range vv { + tk, mv = unwrapValue(mv, <) + switch strings.ToLower(mk) { + case "keys_file": + tpm.KeysFile = mv.(string) + case "encryption_password": + tpm.KeyPassword = mv.(string) + case "srk_password": + tpm.SrkPassword = mv.(string) + case "pcr": + tpm.Pcr = int(mv.(int64)) + case "cipher": + if err := setJetStreamEkCipher(opts, mv, tk); err != nil { + return err + } + default: + if !tk.IsUsedVariable() { + err := &unknownConfigFieldErr{ + field: mk, + configErr: configErr{ + token: tk, + }, + } + *errors = append(*errors, err) + continue + } + } + } + opts.JetStreamTpm = tpm + return nil +} + +func setJetStreamEkCipher(opts *Options, mv interface{}, tk token) error { + switch strings.ToLower(mv.(string)) { + case "chacha", "chachapoly": + opts.JetStreamCipher = 
ChaCha + case "aes": + opts.JetStreamCipher = AES + default: + return &configErr{tk, fmt.Sprintf("Unknown cipher type: %q", mv)} + } + return nil +} + // Parse enablement of jetstream for a server. func parseJetStream(v any, opts *Options, errors *[]error, warnings *[]error) error { var lt token @@ -2189,6 +2368,12 @@ func parseJetStream(v any, opts *Options, errors *[]error, warnings *[]error) er for mk, mv := range vv { tk, mv = unwrapValue(mv, <) switch strings.ToLower(mk) { + case "strict": + if v, ok := mv.(bool); ok { + opts.JetStreamStrict = v + } else { + return &configErr{tk, fmt.Sprintf("Expected 'true' or 'false' for bool value, got '%s'", mv)} + } case "store", "store_dir", "storedir": // StoreDir can be set at the top level as well so have to prevent ambiguous declarations. if opts.StoreDir != _EMPTY_ { @@ -2226,13 +2411,8 @@ func parseJetStream(v any, opts *Options, errors *[]error, warnings *[]error) er case "prev_key", "prev_ek", "prev_encryption_key": opts.JetStreamOldKey = mv.(string) case "cipher": - switch strings.ToLower(mv.(string)) { - case "chacha", "chachapoly": - opts.JetStreamCipher = ChaCha - case "aes": - opts.JetStreamCipher = AES - default: - return &configErr{tk, fmt.Sprintf("Unknown cipher type: %q", mv)} + if err := setJetStreamEkCipher(opts, mv, tk); err != nil { + return err } case "extension_hint": opts.JetStreamExtHint = mv.(string) @@ -2240,6 +2420,10 @@ func parseJetStream(v any, opts *Options, errors *[]error, warnings *[]error) er if err := parseJetStreamLimits(tk, opts, errors); err != nil { return err } + case "tpm": + if err := parseJetStreamTPM(tk, opts, errors); err != nil { + return err + } case "unique_tag": opts.JetStreamUniqueTag = strings.ToLower(strings.TrimSpace(mv.(string))) case "max_outstanding_catchup": @@ -2248,6 +2432,18 @@ func parseJetStream(v any, opts *Options, errors *[]error, warnings *[]error) er return &configErr{tk, fmt.Sprintf("%s %s", strings.ToLower(mk), err)} } opts.JetStreamMaxCatchup = s + 
case "max_buffered_size": + s, err := getStorageSize(mv) + if err != nil { + return &configErr{tk, fmt.Sprintf("%s %s", strings.ToLower(mk), err)} + } + opts.StreamMaxBufferedSize = s + case "max_buffered_msgs": + mlen, ok := mv.(int64) + if !ok { + return &configErr{tk, fmt.Sprintf("Expected a parseable size for %q, got %v", mk, mv)} + } + opts.StreamMaxBufferedMsgs = int(mlen) case "request_queue_limit": lim, ok := mv.(int64) if !ok { @@ -2345,6 +2541,7 @@ func parseLeafNodes(v any, opts *Options, errors *[]error, warnings *[]error) er opts.LeafNode.TLSMap = tc.Map opts.LeafNode.TLSPinnedCerts = tc.PinnedCerts opts.LeafNode.TLSHandshakeFirst = tc.HandshakeFirst + opts.LeafNode.TLSHandshakeFirstFallback = tc.FallbackDelay opts.LeafNode.tlsConfigOpts = tc case "leafnode_advertise", "advertise": opts.LeafNode.Advertise = mv.(string) @@ -2611,7 +2808,26 @@ func parseRemoteLeafNodes(v any, errors *[]error, warnings *[]error) ([]*RemoteL case "ws_no_masking", "websocket_no_masking": remote.Websocket.NoMasking = v.(bool) case "jetstream_cluster_migrate", "js_cluster_migrate": - remote.JetStreamClusterMigrate = true + var lt token + + tk, v := unwrapValue(v, <) + switch vv := v.(type) { + case bool: + remote.JetStreamClusterMigrate = vv + case map[string]any: + remote.JetStreamClusterMigrate = true + migrateConfig, ok := v.(map[string]any) + if !ok { + continue + } + val, ok := migrateConfig["leader_migrate_delay"] + tk, delay := unwrapValue(val, &tk) + if ok { + remote.JetStreamClusterMigrateDelay = parseDuration("leader_migrate_delay", tk, delay, errors, warnings) + } + default: + *errors = append(*errors, &configErr{tk, fmt.Sprintf("Expected boolean or map for jetstream_cluster_migrate, got %T", v)}) + } case "compression": if err := parseCompression(&remote.Compression, CompressionS2Auto, tk, k, v); err != nil { *errors = append(*errors, err) @@ -2747,14 +2963,16 @@ type export struct { lat *serviceLatency rthr time.Duration tPos uint + atrc bool // allow_trace } 
type importStream struct { - acc *Account - an string - sub string - to string - pre string + acc *Account + an string + sub string + to string + pre string + atrc bool // allow_trace } type importService struct { @@ -2932,6 +3150,69 @@ func parseAccountLimits(mv any, acc *Account, errors *[]error) error { return nil } +func parseAccountMsgTrace(mv any, topKey string, acc *Account) error { + processDest := func(tk token, k string, v any) error { + td, ok := v.(string) + if !ok { + return &configErr{tk, fmt.Sprintf("Field %q should be a string, got %T", k, v)} + } + if !IsValidPublishSubject(td) { + return &configErr{tk, fmt.Sprintf("Trace destination %q is not valid", td)} + } + acc.traceDest = td + return nil + } + processSampling := func(tk token, n int) error { + if n <= 0 || n > 100 { + return &configErr{tk, fmt.Sprintf("Ttrace destination sampling value %d is invalid, needs to be [1..100]", n)} + } + acc.traceDestSampling = n + return nil + } + + var lt token + tk, v := unwrapValue(mv, <) + switch vv := v.(type) { + case string: + return processDest(tk, topKey, v) + case map[string]any: + for k, v := range vv { + tk, v := unwrapValue(v, <) + switch strings.ToLower(k) { + case "dest": + if err := processDest(tk, k, v); err != nil { + return err + } + case "sampling": + switch vv := v.(type) { + case int64: + if err := processSampling(tk, int(vv)); err != nil { + return err + } + case string: + s := strings.TrimSuffix(vv, "%") + n, err := strconv.Atoi(s) + if err != nil { + return &configErr{tk, fmt.Sprintf("Invalid trace destination sampling value %q", vv)} + } + if err := processSampling(tk, n); err != nil { + return err + } + default: + return &configErr{tk, fmt.Sprintf("Trace destination sampling field %q should be an integer or a percentage, got %T", k, v)} + } + default: + if !tk.IsUsedVariable() { + return &configErr{tk, fmt.Sprintf("Unknown field %q parsing account message trace map/struct %q", k, topKey)} + } + } + } + default: + return &configErr{tk, 
fmt.Sprintf("Expected account message trace %q to be a string or a map/struct, got %T", topKey, v)} + } + return nil +} + // parseAccounts will parse the different accounts syntax. func parseAccounts(v any, opts *Options, errors *[]error, warnings *[]error) error { var ( @@ -3061,6 +3342,23 @@ func parseAccounts(v any, opts *Options, errors *[]error, warnings *[]error) err *errors = append(*errors, err) continue } + case "msg_trace", "trace_dest": + if err := parseAccountMsgTrace(tk, k, acc); err != nil { + *errors = append(*errors, err) + continue + } + // If trace destination is set but no sampling, set it to 100%. + if acc.traceDest != _EMPTY_ && acc.traceDestSampling == 0 { + acc.traceDestSampling = 100 + } else if acc.traceDestSampling > 0 && acc.traceDest == _EMPTY_ { + // If no trace destination is provided, no trace would be + // triggered, so if the user set a sampling value expecting + // something to happen, want and set the value to 0 for good + // measure. + *warnings = append(*warnings, + &configErr{tk, "Trace destination sampling ignored since no destination was set"}) + acc.traceDestSampling = 0 + } default: if !tk.IsUsedVariable() { err := &unknownConfigFieldErr{ @@ -3185,6 +3483,14 @@ func parseAccounts(v any, opts *Options, errors *[]error, warnings *[]error) err continue } } + + if service.atrc { + if err := service.acc.SetServiceExportAllowTrace(service.sub, true); err != nil { + msg := fmt.Sprintf("Error adding allow_trace for %q: %v", service.sub, err) + *errors = append(*errors, &configErr{tk, msg}) + continue + } + } } for _, stream := range importStreams { ta := am[stream.an] @@ -3194,13 +3500,13 @@ func parseAccounts(v any, opts *Options, errors *[]error, warnings *[]error) err continue } if stream.pre != _EMPTY_ { - if err := stream.acc.AddStreamImport(ta, stream.sub, stream.pre); err != nil { + if err := stream.acc.addStreamImportWithClaim(ta, stream.sub, stream.pre, stream.atrc, nil); err != nil { msg := fmt.Sprintf("Error adding 
stream import %q: %v", stream.sub, err) *errors = append(*errors, &configErr{tk, msg}) continue } } else { - if err := stream.acc.AddMappedStreamImport(ta, stream.sub, stream.to); err != nil { + if err := stream.acc.addMappedStreamImportWithClaim(ta, stream.sub, stream.to, stream.atrc, nil); err != nil { msg := fmt.Sprintf("Error adding stream import %q: %v", stream.sub, err) *errors = append(*errors, &configErr{tk, msg}) continue @@ -3358,6 +3664,9 @@ func parseExportStreamOrService(v any, errors *[]error) (*export, *export, error latToken token lt token accTokPos uint + atrc bool + atrcSeen bool + atrcToken token ) defer convertPanicToErrorList(<, errors) @@ -3385,6 +3694,11 @@ func parseExportStreamOrService(v any, errors *[]error) (*export, *export, error *errors = append(*errors, err) continue } + if atrcToken != nil { + err := &configErr{atrcToken, "Detected allow_trace directive on non-service"} + *errors = append(*errors, err) + continue + } mvs, ok := mv.(string) if !ok { err := &configErr{tk, fmt.Sprintf("Expected stream name to be string, got %T", mv)} @@ -3420,6 +3734,9 @@ func parseExportStreamOrService(v any, errors *[]error) (*export, *export, error if threshSeen { curService.rthr = thresh } + if atrcSeen { + curService.atrc = atrc + } case "response", "response_type": if rtSeen { err := &configErr{tk, "Duplicate response type definition"} @@ -3508,6 +3825,18 @@ func parseExportStreamOrService(v any, errors *[]error) (*export, *export, error } case "account_token_position": accTokPos = uint(mv.(int64)) + case "allow_trace": + atrcSeen = true + atrcToken = tk + atrc = mv.(bool) + if curStream != nil { + *errors = append(*errors, + &configErr{tk, "Detected allow_trace directive on non-service"}) + continue + } + if curService != nil { + curService.atrc = atrc + } default: if !tk.IsUsedVariable() { err := &unknownConfigFieldErr{ @@ -3618,6 +3947,9 @@ func parseImportStreamOrService(v any, errors *[]error) (*importStream, *importS pre, to string share 
bool lt token + atrc bool + atrcSeen bool + atrcToken token ) defer convertPanicToErrorList(<, errors) @@ -3659,13 +3991,21 @@ func parseImportStreamOrService(v any, errors *[]error) (*importStream, *importS if pre != _EMPTY_ { curStream.pre = pre } + if atrcSeen { + curStream.atrc = atrc + } case "service": if curStream != nil { err := &configErr{tk, "Detected service but already saw a stream"} *errors = append(*errors, err) continue } - ac, ok := mv.(map[string]interface{}) + if atrcToken != nil { + err := &configErr{atrcToken, "Detected allow_trace directive on a non-stream"} + *errors = append(*errors, err) + continue + } + ac, ok := mv.(map[string]any) if !ok { err := &configErr{tk, fmt.Sprintf("Service entry should be an account map, got %T", mv)} *errors = append(*errors, err) @@ -3712,6 +4052,18 @@ func parseImportStreamOrService(v any, errors *[]error) (*importStream, *importS if curService != nil { curService.share = share } + case "allow_trace": + if curService != nil { + err := &configErr{tk, "Detected allow_trace directive on a non-stream"} + *errors = append(*errors, err) + continue + } + atrcSeen = true + atrc = mv.(bool) + atrcToken = tk + if curStream != nil { + curStream.atrc = atrc + } default: if !tk.IsUsedVariable() { err := &unknownConfigFieldErr{ @@ -3964,6 +4316,15 @@ func parseAuthCallout(mv any, errors *[]error) (*AuthCallout, error) { if !nkeys.IsValidPublicCurveKey(ac.XKey) { return nil, &configErr{tk, fmt.Sprintf("Expected callout xkey to be a valid public xkey, got %q", ac.XKey)} } + case "allowed_accounts": + aua, ok := mv.([]any) + if !ok { + return nil, &configErr{tk, fmt.Sprintf("Expected allowed accounts field to be an array, got %T", v)} + } + for _, uv := range aua { + _, uv = unwrapValue(uv, <) + ac.AllowedAccounts = append(ac.AllowedAccounts, uv.(string)) + } default: if !tk.IsUsedVariable() { err := &configErr{tk, fmt.Sprintf("Unknown field %q parsing authorization callout", k)} @@ -4440,7 +4801,7 @@ func parseTLS(v any, 
isClientCtx bool) (t *TLSConfigOpts, retErr error) { rv = append(rv, mv) case []string: rv = append(rv, mv...) - case []interface{}: + case []any: for _, t := range mv { if token, ok := t.(token); ok { if ts, ok := token.Value().(string); ok { @@ -4713,8 +5074,31 @@ func parseWebsocket(v any, o *Options, errors *[]error) error { o.Websocket.AuthTimeout = auth.timeout case "jwt_cookie": o.Websocket.JWTCookie = mv.(string) + case "user_cookie": + o.Websocket.UsernameCookie = mv.(string) + case "pass_cookie": + o.Websocket.PasswordCookie = mv.(string) + case "token_cookie": + o.Websocket.TokenCookie = mv.(string) case "no_auth_user": o.Websocket.NoAuthUser = mv.(string) + case "headers": + m, ok := mv.(map[string]any) + if !ok { + err := &configErr{tk, fmt.Sprintf("error parsing headers: unsupported type %T", mv)} + *errors = append(*errors, err) + continue + } + o.Websocket.Headers = make(map[string]string) + for key, val := range m { + tk, val = unwrapValue(val, <) + if headerValue, ok := val.(string); !ok { + *errors = append(*errors, &configErr{tk, fmt.Sprintf("error parsing header key %s: unsupported type %T", key, val)}) + continue + } else { + o.Websocket.Headers[key] = headerValue + } + } default: if !tk.IsUsedVariable() { err := &unknownConfigFieldErr{ @@ -4985,7 +5369,10 @@ func MergeOptions(fileOpts, flagOpts *Options) *Options { mergeRoutes(&opts, flagOpts) } if flagOpts.JetStream { - fileOpts.JetStream = flagOpts.JetStream + opts.JetStream = flagOpts.JetStream + } + if flagOpts.StoreDir != _EMPTY_ { + opts.StoreDir = flagOpts.StoreDir } return &opts } @@ -5015,86 +5402,6 @@ func mergeRoutes(opts, flagOpts *Options) { opts.RoutesStr = flagOpts.RoutesStr } -// RemoveSelfReference removes this server from an array of routes -func RemoveSelfReference(clusterPort int, routes []*url.URL) ([]*url.URL, error) { - var cleanRoutes []*url.URL - cport := strconv.Itoa(clusterPort) - - selfIPs, err := getInterfaceIPs() - if err != nil { - return nil, err - } - for _, r 
:= range routes { - host, port, err := net.SplitHostPort(r.Host) - if err != nil { - return nil, err - } - - ipList, err := getURLIP(host) - if err != nil { - return nil, err - } - if cport == port && isIPInList(selfIPs, ipList) { - continue - } - cleanRoutes = append(cleanRoutes, r) - } - - return cleanRoutes, nil -} - -func isIPInList(list1 []net.IP, list2 []net.IP) bool { - for _, ip1 := range list1 { - for _, ip2 := range list2 { - if ip1.Equal(ip2) { - return true - } - } - } - return false -} - -func getURLIP(ipStr string) ([]net.IP, error) { - ipList := []net.IP{} - - ip := net.ParseIP(ipStr) - if ip != nil { - ipList = append(ipList, ip) - return ipList, nil - } - - hostAddr, err := net.LookupHost(ipStr) - if err != nil { - return nil, fmt.Errorf("Error looking up host with route hostname: %v", err) - } - for _, addr := range hostAddr { - ip = net.ParseIP(addr) - if ip != nil { - ipList = append(ipList, ip) - } - } - return ipList, nil -} - -func getInterfaceIPs() ([]net.IP, error) { - var localIPs []net.IP - - interfaceAddr, err := net.InterfaceAddrs() - if err != nil { - return nil, fmt.Errorf("Error getting self referencing address: %v", err) - } - - for i := 0; i < len(interfaceAddr); i++ { - interfaceIP, _, _ := net.ParseCIDR(interfaceAddr[i].String()) - if net.ParseIP(interfaceIP.String()) != nil { - localIPs = append(localIPs, interfaceIP) - } else { - return nil, fmt.Errorf("Error parsing self referencing address: %v", err) - } - } - return localIPs, nil -} - func setBaselineOptions(opts *Options) { // Setup non-standard Go defaults if opts.Host == _EMPTY_ { @@ -5450,6 +5757,8 @@ func ConfigureOptions(fs *flag.FlagSet, args []string, printVersion, printHelp, trackExplicitVal(&FlagSnapshot.inCmdLine, "Syslog", FlagSnapshot.Syslog) case "no_advertise": trackExplicitVal(&FlagSnapshot.inCmdLine, "Cluster.NoAdvertise", FlagSnapshot.Cluster.NoAdvertise) + case "js": + trackExplicitVal(&FlagSnapshot.inCmdLine, "JetStream", FlagSnapshot.JetStream) } }) diff 
--git a/vendor/github.com/nats-io/nats-server/v2/server/parser.go b/vendor/github.com/nats-io/nats-server/v2/server/parser.go index 50b504b7f6..ed3eaa153c 100644 --- a/vendor/github.com/nats-io/nats-server/v2/server/parser.go +++ b/vendor/github.com/nats-io/nats-server/v2/server/parser.go @@ -49,6 +49,7 @@ type pubArg struct { size int hdr int psi []*serviceImport + trace *msgTrace delivered bool // Only used for service imports } @@ -286,7 +287,11 @@ func (c *client) parse(buf []byte) error { if trace { c.traceInOp("HPUB", arg) } - if err := c.processHeaderPub(arg); err != nil { + var remaining []byte + if i < len(buf) { + remaining = buf[i+1:] + } + if err := c.processHeaderPub(arg, remaining); err != nil { return err } @@ -484,11 +489,19 @@ func (c *client) parse(buf []byte) error { c.msgBuf = buf[c.as : i+1] } + var mt *msgTrace + if c.pa.hdr > 0 { + mt = c.initMsgTrace() + } // Check for mappings. if (c.kind == CLIENT || c.kind == LEAF) && c.in.flags.isSet(hasMappings) { changed := c.selectMappedSubject() - if trace && changed { - c.traceInOp("MAPPING", []byte(fmt.Sprintf("%s -> %s", c.pa.mapped, c.pa.subject))) + if changed { + if trace { + c.traceInOp("MAPPING", []byte(fmt.Sprintf("%s -> %s", c.pa.mapped, c.pa.subject))) + } + // c.pa.subject is the subject the original is now mapped to. 
+ mt.addSubjectMappingEvent(c.pa.subject) } } if trace { @@ -496,11 +509,14 @@ func (c *client) parse(buf []byte) error { } c.processInboundMsg(c.msgBuf) + + mt.sendEvent() c.argBuf, c.msgBuf, c.header = nil, nil, nil c.drop, c.as, c.state = 0, i+1, OP_START // Drop all pub args c.pa.arg, c.pa.pacache, c.pa.origin, c.pa.account, c.pa.subject, c.pa.mapped = nil, nil, nil, nil, nil, nil c.pa.reply, c.pa.hdr, c.pa.size, c.pa.szb, c.pa.hdb, c.pa.queues = nil, -1, 0, nil, nil, nil + c.pa.trace = nil c.pa.delivered = false lmsg = false case OP_A: @@ -1273,7 +1289,7 @@ func (c *client) clonePubArg(lmsg bool) error { if c.pa.hdr < 0 { return c.processPub(c.argBuf) } else { - return c.processHeaderPub(c.argBuf) + return c.processHeaderPub(c.argBuf, nil) } } } diff --git a/vendor/github.com/nats-io/nats-server/v2/server/proto.go b/vendor/github.com/nats-io/nats-server/v2/server/proto.go new file mode 100644 index 0000000000..9843fff21b --- /dev/null +++ b/vendor/github.com/nats-io/nats-server/v2/server/proto.go @@ -0,0 +1,269 @@ +// Copyright 2024 The NATS Authors +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +// Inspired by https://github.com/protocolbuffers/protobuf-go/blob/master/encoding/protowire/wire.go + +package server + +import ( + "errors" + "fmt" + "math" +) + +var errProtoInsufficient = errors.New("insufficient data to read a value") +var errProtoOverflow = errors.New("too much data for a value") +var errProtoInvalidFieldNumber = errors.New("invalid field number") + +func protoScanField(b []byte) (num, typ, size int, err error) { + num, typ, sizeTag, err := protoScanTag(b) + if err != nil { + return 0, 0, 0, err + } + b = b[sizeTag:] + + sizeValue, err := protoScanFieldValue(typ, b) + if err != nil { + return 0, 0, 0, err + } + return num, typ, sizeTag + sizeValue, nil +} + +func protoScanTag(b []byte) (num, typ, size int, err error) { + tagint, size, err := protoScanVarint(b) + if err != nil { + return 0, 0, 0, err + } + + // NOTE: MessageSet allows for larger field numbers than normal. + if (tagint >> 3) > uint64(math.MaxInt32) { + return 0, 0, 0, errProtoInvalidFieldNumber + } + num = int(tagint >> 3) + if num < 1 { + return 0, 0, 0, errProtoInvalidFieldNumber + } + typ = int(tagint & 7) + + return num, typ, size, nil +} + +func protoScanFieldValue(typ int, b []byte) (size int, err error) { + switch typ { + case 0: + _, size, err = protoScanVarint(b) + case 5: // fixed32 + size = 4 + case 1: // fixed64 + size = 8 + case 2: // length-delimited + size, err = protoScanBytes(b) + default: + return 0, fmt.Errorf("unsupported type: %d", typ) + } + return size, err +} + +func protoScanVarint(b []byte) (v uint64, size int, err error) { + var y uint64 + if len(b) <= 0 { + return 0, 0, errProtoInsufficient + } + v = uint64(b[0]) + if v < 0x80 { + return v, 1, nil + } + v -= 0x80 + + if len(b) <= 1 { + return 0, 0, errProtoInsufficient + } + y = uint64(b[1]) + v += y << 7 + if y < 0x80 { + return v, 2, nil + } + v -= 0x80 << 7 + + if len(b) <= 2 { + return 0, 0, errProtoInsufficient + } + y = uint64(b[2]) + v += y << 14 + if y < 0x80 { + return v, 3, nil + } + v 
-= 0x80 << 14 + + if len(b) <= 3 { + return 0, 0, errProtoInsufficient + } + y = uint64(b[3]) + v += y << 21 + if y < 0x80 { + return v, 4, nil + } + v -= 0x80 << 21 + + if len(b) <= 4 { + return 0, 0, errProtoInsufficient + } + y = uint64(b[4]) + v += y << 28 + if y < 0x80 { + return v, 5, nil + } + v -= 0x80 << 28 + + if len(b) <= 5 { + return 0, 0, errProtoInsufficient + } + y = uint64(b[5]) + v += y << 35 + if y < 0x80 { + return v, 6, nil + } + v -= 0x80 << 35 + + if len(b) <= 6 { + return 0, 0, errProtoInsufficient + } + y = uint64(b[6]) + v += y << 42 + if y < 0x80 { + return v, 7, nil + } + v -= 0x80 << 42 + + if len(b) <= 7 { + return 0, 0, errProtoInsufficient + } + y = uint64(b[7]) + v += y << 49 + if y < 0x80 { + return v, 8, nil + } + v -= 0x80 << 49 + + if len(b) <= 8 { + return 0, 0, errProtoInsufficient + } + y = uint64(b[8]) + v += y << 56 + if y < 0x80 { + return v, 9, nil + } + v -= 0x80 << 56 + + if len(b) <= 9 { + return 0, 0, errProtoInsufficient + } + y = uint64(b[9]) + v += y << 63 + if y < 2 { + return v, 10, nil + } + return 0, 0, errProtoOverflow +} + +func protoScanBytes(b []byte) (size int, err error) { + l, lenSize, err := protoScanVarint(b) + if err != nil { + return 0, err + } + if l > uint64(len(b[lenSize:])) { + return 0, errProtoInsufficient + } + return lenSize + int(l), nil +} + +func protoEncodeVarint(v uint64) []byte { + b := make([]byte, 0, 10) + switch { + case v < 1<<7: + b = append(b, byte(v)) + case v < 1<<14: + b = append(b, + byte((v>>0)&0x7f|0x80), + byte(v>>7)) + case v < 1<<21: + b = append(b, + byte((v>>0)&0x7f|0x80), + byte((v>>7)&0x7f|0x80), + byte(v>>14)) + case v < 1<<28: + b = append(b, + byte((v>>0)&0x7f|0x80), + byte((v>>7)&0x7f|0x80), + byte((v>>14)&0x7f|0x80), + byte(v>>21)) + case v < 1<<35: + b = append(b, + byte((v>>0)&0x7f|0x80), + byte((v>>7)&0x7f|0x80), + byte((v>>14)&0x7f|0x80), + byte((v>>21)&0x7f|0x80), + byte(v>>28)) + case v < 1<<42: + b = append(b, + byte((v>>0)&0x7f|0x80), + 
byte((v>>7)&0x7f|0x80), + byte((v>>14)&0x7f|0x80), + byte((v>>21)&0x7f|0x80), + byte((v>>28)&0x7f|0x80), + byte(v>>35)) + case v < 1<<49: + b = append(b, + byte((v>>0)&0x7f|0x80), + byte((v>>7)&0x7f|0x80), + byte((v>>14)&0x7f|0x80), + byte((v>>21)&0x7f|0x80), + byte((v>>28)&0x7f|0x80), + byte((v>>35)&0x7f|0x80), + byte(v>>42)) + case v < 1<<56: + b = append(b, + byte((v>>0)&0x7f|0x80), + byte((v>>7)&0x7f|0x80), + byte((v>>14)&0x7f|0x80), + byte((v>>21)&0x7f|0x80), + byte((v>>28)&0x7f|0x80), + byte((v>>35)&0x7f|0x80), + byte((v>>42)&0x7f|0x80), + byte(v>>49)) + case v < 1<<63: + b = append(b, + byte((v>>0)&0x7f|0x80), + byte((v>>7)&0x7f|0x80), + byte((v>>14)&0x7f|0x80), + byte((v>>21)&0x7f|0x80), + byte((v>>28)&0x7f|0x80), + byte((v>>35)&0x7f|0x80), + byte((v>>42)&0x7f|0x80), + byte((v>>49)&0x7f|0x80), + byte(v>>56)) + default: + b = append(b, + byte((v>>0)&0x7f|0x80), + byte((v>>7)&0x7f|0x80), + byte((v>>14)&0x7f|0x80), + byte((v>>21)&0x7f|0x80), + byte((v>>28)&0x7f|0x80), + byte((v>>35)&0x7f|0x80), + byte((v>>42)&0x7f|0x80), + byte((v>>49)&0x7f|0x80), + byte((v>>56)&0x7f|0x80), + 1) + } + return b +} diff --git a/vendor/github.com/nats-io/nats-server/v2/server/pse/pse_freebsd.go b/vendor/github.com/nats-io/nats-server/v2/server/pse/pse_freebsd.go index 952ccb2dfd..21b5445c4f 100644 --- a/vendor/github.com/nats-io/nats-server/v2/server/pse/pse_freebsd.go +++ b/vendor/github.com/nats-io/nats-server/v2/server/pse/pse_freebsd.go @@ -12,7 +12,6 @@ // limitations under the License. //go:build !amd64 -// +build !amd64 package pse diff --git a/vendor/github.com/nats-io/nats-server/v2/server/pse/pse_rumprun.go b/vendor/github.com/nats-io/nats-server/v2/server/pse/pse_rumprun.go index 93f53a0775..d16e6ea95b 100644 --- a/vendor/github.com/nats-io/nats-server/v2/server/pse/pse_rumprun.go +++ b/vendor/github.com/nats-io/nats-server/v2/server/pse/pse_rumprun.go @@ -12,7 +12,6 @@ // limitations under the License. 
//go:build rumprun -// +build rumprun package pse diff --git a/vendor/github.com/nats-io/nats-server/v2/server/pse/pse_wasm.go b/vendor/github.com/nats-io/nats-server/v2/server/pse/pse_wasm.go index 4a6689d89f..e3db060c80 100644 --- a/vendor/github.com/nats-io/nats-server/v2/server/pse/pse_wasm.go +++ b/vendor/github.com/nats-io/nats-server/v2/server/pse/pse_wasm.go @@ -12,7 +12,6 @@ // limitations under the License. //go:build wasm -// +build wasm package pse diff --git a/vendor/github.com/nats-io/nats-server/v2/server/pse/pse_windows.go b/vendor/github.com/nats-io/nats-server/v2/server/pse/pse_windows.go index 88d7fb0763..09f84a0ccb 100644 --- a/vendor/github.com/nats-io/nats-server/v2/server/pse/pse_windows.go +++ b/vendor/github.com/nats-io/nats-server/v2/server/pse/pse_windows.go @@ -12,7 +12,6 @@ // limitations under the License. //go:build windows -// +build windows package pse diff --git a/vendor/github.com/nats-io/nats-server/v2/server/pse/pse_zos.go b/vendor/github.com/nats-io/nats-server/v2/server/pse/pse_zos.go index 232db2f086..df469f4e1e 100644 --- a/vendor/github.com/nats-io/nats-server/v2/server/pse/pse_zos.go +++ b/vendor/github.com/nats-io/nats-server/v2/server/pse/pse_zos.go @@ -12,7 +12,6 @@ // limitations under the License. 
//go:build zos -// +build zos package pse diff --git a/vendor/github.com/nats-io/nats-server/v2/server/raft.go b/vendor/github.com/nats-io/nats-server/v2/server/raft.go index 245e419492..dd2f345b42 100644 --- a/vendor/github.com/nats-io/nats-server/v2/server/raft.go +++ b/vendor/github.com/nats-io/nats-server/v2/server/raft.go @@ -78,17 +78,18 @@ type RaftNode interface { Stop() WaitForStop() Delete() + RecreateInternalSubs() error IsSystemAccount() bool } type WAL interface { Type() StorageType - StoreMsg(subj string, hdr, msg []byte) (uint64, int64, error) + StoreMsg(subj string, hdr, msg []byte, ttl int64) (uint64, int64, error) LoadMsg(index uint64, sm *StoreMsg) (*StoreMsg, error) RemoveMsg(index uint64) (bool, error) Compact(index uint64) (uint64, error) Purge() (uint64, error) - PurgeEx(subject string, seq, keep uint64) (uint64, error) + PurgeEx(subject string, seq, keep uint64, noMarkers bool) (uint64, error) Truncate(seq uint64) error State() StreamState FastState(*StreamState) @@ -132,6 +133,7 @@ type raft struct { created time.Time // Time that the group was created accName string // Account name of the asset this raft group is for + acc *Account // Account that NRG traffic will be sent/received in group string // Raft group sd string // Store directory id string // Node ID @@ -142,9 +144,10 @@ type raft struct { track bool // werr error // Last write error - state atomic.Int32 // RaftState - hh hash.Hash64 // Highwayhash, used for snapshots - snapfile string // Snapshot filename + state atomic.Int32 // RaftState + leaderState atomic.Bool // Is in (complete) leader state. 
+ hh hash.Hash64 // Highwayhash, used for snapshots + snapfile string // Snapshot filename csz int // Cluster size qn int // Number of nodes needed to establish quorum @@ -166,6 +169,8 @@ type raft struct { commit uint64 // Index of the most recent commit applied uint64 // Index of the most recently applied commit + aflr uint64 // Index when to signal initial messages have been applied after becoming leader. 0 means signaling is disabled. + leader string // The ID of the leader vote string // Our current vote state lxfer bool // Are we doing a leadership transfer? @@ -179,7 +184,9 @@ type raft struct { dflag bool // Debug flag hasleader atomic.Bool // Is there a group leader right now? pleader atomic.Bool // Has the group ever had a leader? - observer bool // The node is observing, i.e. not participating in voting + isSysAcc atomic.Bool // Are we utilizing the system account? + + observer bool // The node is observing, i.e. not participating in voting extSt extensionState // Extension state @@ -361,8 +368,6 @@ func (s *Server) initRaftNode(accName string, cfg *RaftConfig, labels pprofLabel s.mu.RUnlock() return nil, ErrNoSysAccount } - sq := s.sys.sq - sacc := s.sys.account hash := s.sys.shash s.mu.RUnlock() @@ -390,9 +395,7 @@ func (s *Server) initRaftNode(accName string, cfg *RaftConfig, labels pprofLabel acks: make(map[uint64]map[string]struct{}), pae: make(map[uint64]*appendEntry), s: s, - c: s.createInternalSystemClient(), js: s.getJetStream(), - sq: sq, quit: make(chan struct{}), reqs: newIPQueue[*voteRequest](s, qpfx+"vreq"), votes: newIPQueue[*voteResponse](s, qpfx+"vresp"), @@ -405,7 +408,14 @@ func (s *Server) initRaftNode(accName string, cfg *RaftConfig, labels pprofLabel observer: cfg.Observer, extSt: ps.domainExt, } - n.c.registerWithAccount(sacc) + + // Setup our internal subscriptions for proposals, votes and append entries. 
+ // If we fail to do this for some reason then this is fatal — we cannot + // continue setting up or the Raft node may be partially/totally isolated. + if err := n.RecreateInternalSubs(); err != nil { + n.shutdown() + return nil, err + } if atomic.LoadInt32(&s.logging.debug) > 0 { n.dflag = true @@ -498,14 +508,6 @@ func (s *Server) initRaftNode(accName string, cfg *RaftConfig, labels pprofLabel } } - // Setup our internal subscriptions for proposals, votes and append entries. - // If we fail to do this for some reason then this is fatal — we cannot - // continue setting up or the Raft node may be partially/totally isolated. - if err := n.createInternalSubs(); err != nil { - n.shutdown() - return nil, err - } - n.debug("Started") // Check if we need to start in observer mode due to lame duck status. @@ -545,10 +547,116 @@ func (s *Server) startRaftNode(accName string, cfg *RaftConfig, labels pprofLabe return n, nil } +// Returns whether peers within this group claim to support +// moving NRG traffic into the asset account. +// Lock must be held. +func (n *raft) checkAccountNRGStatus() bool { + if !n.s.accountNRGAllowed.Load() { + return false + } + enabled := true + for pn := range n.peers { + if si, ok := n.s.nodeToInfo.Load(pn); ok && si != nil { + enabled = enabled && si.(nodeInfo).accountNRG + } + } + return enabled +} + // Whether we are using the system account or not. -// In 2.10.x this is always true as there is no account NRG like in 2.11.x. func (n *raft) IsSystemAccount() bool { - return true + return n.isSysAcc.Load() +} + +func (n *raft) RecreateInternalSubs() error { + n.Lock() + defer n.Unlock() + return n.recreateInternalSubsLocked() +} + +func (n *raft) recreateInternalSubsLocked() error { + // Sanity check for system account, as it can disappear when + // the system is shutting down. 
+ if n.s == nil { + return fmt.Errorf("server not found") + } + n.s.mu.RLock() + sys := n.s.sys + n.s.mu.RUnlock() + if sys == nil { + return fmt.Errorf("system account not found") + } + + // Default is the system account. + nrgAcc := sys.account + n.isSysAcc.Store(true) + + // Is account NRG enabled in this account and do all group + // peers claim to also support account NRG? + if n.checkAccountNRGStatus() { + // Check whether the account that the asset belongs to + // has volunteered a different NRG account. + target := nrgAcc.Name + if a, _ := n.s.lookupAccount(n.accName); a != nil { + a.mu.RLock() + if a.js != nil { + target = a.js.nrgAccount + } + a.mu.RUnlock() + } + + // If the target account exists, then we'll use that. + if target != _EMPTY_ { + if a, _ := n.s.lookupAccount(target); a != nil { + nrgAcc = a + if a != sys.account { + n.isSysAcc.Store(false) + } + } + } + } + if n.aesub != nil && n.acc == nrgAcc { + // Subscriptions already exist and the account NRG state + // hasn't changed. + return nil + } + + // Need to cancel any in-progress catch-ups, otherwise the + // inboxes are about to be pulled out from underneath it in + // the next step... + n.cancelCatchup() + + // If we have an existing client then tear down any existing + // subscriptions and close the internal client. + if c := n.c; c != nil { + c.mu.Lock() + subs := make([]*subscription, 0, len(c.subs)) + for _, sub := range c.subs { + subs = append(subs, sub) + } + c.mu.Unlock() + for _, sub := range subs { + n.unsubscribe(sub) + } + c.closeConnection(InternalClient) + } + + if n.acc != nrgAcc { + n.debug("Subscribing in '%s'", nrgAcc.GetName()) + } + + c := n.s.createInternalSystemClient() + c.registerWithAccount(nrgAcc) + if nrgAcc.sq == nil { + nrgAcc.sq = n.s.newSendQ(nrgAcc) + } + n.c = c + n.sq = nrgAcc.sq + n.acc = nrgAcc + + // Recreate any internal subscriptions for voting, append + // entries etc in the new account. 
+ return n.createInternalSubs() } // outOfResources checks to see if we are out of resources. @@ -650,9 +758,7 @@ func (s *Server) stepdownRaftNodes() { s.rnMu.RUnlock() for _, node := range nodes { - if node.Leader() { - node.StepDown() - } + node.StepDown() node.SetObserver(true) } } @@ -701,8 +807,7 @@ func (s *Server) transferRaftLeaders() bool { var didTransfer bool for _, node := range nodes { - if node.Leader() { - node.StepDown() + if err := node.StepDown(); err == nil { didTransfer = true } node.SetObserver(true) @@ -753,7 +858,7 @@ func (n *raft) ProposeMulti(entries []*Entry) error { // ForwardProposal will forward the proposal to the leader if known. // If we are the leader this is the same as calling propose. func (n *raft) ForwardProposal(entry []byte) error { - if n.Leader() { + if n.State() == Leader { return n.Propose(entry) } @@ -952,7 +1057,7 @@ func (n *raft) ResumeApply() { } } -// Applied is a callback that must be be called by the upper layer when it +// Applied is a callback that must be called by the upper layer when it // has successfully applied the committed entries that it received from the // apply queue. It will return the number of entries and an estimation of the // byte size that could be removed with a snapshot/compact. @@ -970,6 +1075,17 @@ func (n *raft) Applied(index uint64) (entries uint64, bytes uint64) { n.applied = index } + // If it was set, and we reached the minimum applied index, reset and send signal to upper layer. + if n.aflr > 0 && index >= n.aflr { + n.aflr = 0 + // Quick sanity-check to confirm we're still leader. + // In which case we must signal, since switchToLeader would not have done so already. + if n.State() == Leader { + n.leaderState.Store(true) + n.updateLeadChange(true) + } + } + // Calculate the number of entries and estimate the byte size that // we can now remove with a compaction/snapshot. 
var state StreamState @@ -1262,7 +1378,7 @@ func (n *raft) Leader() bool { if n == nil { return false } - return n.State() == Leader + return n.leaderState.Load() } // stepdown immediately steps down the Raft node to the @@ -1435,18 +1551,16 @@ func (n *raft) selectNextLeader() string { // StepDown will have a leader stepdown and optionally do a leader transfer. func (n *raft) StepDown(preferred ...string) error { - n.Lock() + if n.State() != Leader { + return errNotLeader + } + n.Lock() if len(preferred) > 1 { n.Unlock() return errTooManyPrefs } - if n.State() != Leader { - n.Unlock() - return errNotLeader - } - n.debug("Being asked to stepdown") // See if we have up to date followers. @@ -1557,6 +1671,7 @@ func (n *raft) xferCampaign() error { } // State returns the current state for this node. +// Upper layers should not check State to check if we're Leader, use n.Leader() instead. func (n *raft) State() RaftState { return RaftState(n.state.Load()) } @@ -1565,7 +1680,7 @@ func (n *raft) State() RaftState { func (n *raft) Progress() (index, commit, applied uint64) { n.RLock() defer n.RUnlock() - return n.pindex + 1, n.commit, n.applied + return n.pindex, n.commit, n.applied } // Size returns number of entries and total bytes for our WAL. @@ -1674,6 +1789,7 @@ func (n *raft) shutdown() { // First call to Stop or Delete should close the quit chan // to notify the runAs goroutines to stop what they're doing. if n.state.Swap(int32(Closed)) != int32(Closed) { + n.leaderState.Store(false) close(n.quit) } } @@ -1725,9 +1841,8 @@ func (n *raft) unsubscribe(sub *subscription) { } } +// Lock should be held. 
func (n *raft) createInternalSubs() error { - n.Lock() - defer n.Unlock() n.vsubj, n.vreply = fmt.Sprintf(raftVoteSubj, n.group), n.newInbox() n.asubj, n.areply = fmt.Sprintf(raftAppendSubj, n.group), n.newInbox() n.psubj = fmt.Sprintf(raftPropSubj, n.group) @@ -2316,7 +2431,7 @@ func (n *raft) decodeAppendEntryResponse(msg []byte) *appendEntryResponse { func (n *raft) handleForwardedRemovePeerProposal(sub *subscription, c *client, _ *Account, _, reply string, msg []byte) { n.debug("Received forwarded remove peer proposal: %q", msg) - if !n.Leader() { + if n.State() != Leader { n.debug("Ignoring forwarded peer removal proposal, not leader") return } @@ -2341,7 +2456,7 @@ func (n *raft) handleForwardedRemovePeerProposal(sub *subscription, c *client, _ // Called when a peer has forwarded a proposal. func (n *raft) handleForwardedProposal(sub *subscription, c *client, _ *Account, _, reply string, msg []byte) { - if !n.Leader() { + if n.State() != Leader { n.debug("Ignoring forwarded proposal, not leader") return } @@ -2605,14 +2720,14 @@ func (n *raft) runCatchup(ar *appendEntryResponse, indexUpdatesQ *ipQueue[uint64 defer stepCheck.Stop() // Run as long as we are leader and still not caught up. - for n.Leader() { + for n.State() == Leader { select { case <-n.s.quitCh: return case <-n.quit: return case <-stepCheck.C: - if !n.Leader() { + if n.State() != Leader { n.debug("Catching up canceled, no longer leader") return } @@ -2904,18 +3019,23 @@ func (n *raft) trackResponse(ar *appendEntryResponse) { // See if we have items to apply. var sendHB bool - if results := n.acks[ar.index]; results != nil { - results[ar.peer] = struct{}{} - if nr := len(results); nr >= n.qn { - // We have a quorum. 
- for index := n.commit + 1; index <= ar.index; index++ { - if err := n.applyCommit(index); err != nil && err != errNodeClosed { - n.error("Got an error applying commit for %d: %v", index, err) - break - } + results := n.acks[ar.index] + if results == nil { + results = make(map[string]struct{}) + n.acks[ar.index] = results + } + results[ar.peer] = struct{}{} + + // We don't count ourselves to account for leader changes, so add 1. + if nr := len(results); nr+1 >= n.qn { + // We have a quorum. + for index := n.commit + 1; index <= ar.index; index++ { + if err := n.applyCommit(index); err != nil && err != errNodeClosed { + n.error("Got an error applying commit for %d: %v", index, err) + break } - sendHB = n.prop.len() == 0 } + sendHB = n.prop.len() == 0 } n.Unlock() @@ -2945,6 +3065,9 @@ func (n *raft) adjustClusterSizeAndQuorum() { go n.sendHeartbeat() } } + if ncsz != pcsz { + n.recreateInternalSubsLocked() + } } // Track interactions with this peer. @@ -3143,7 +3266,7 @@ func (n *raft) truncateWAL(term, index uint64) { n.wal.Truncate(0) // If our index is non-zero use PurgeEx to set us to the correct next index. if index > 0 { - n.wal.PurgeEx(fwcs, index+1, 0) + n.wal.PurgeEx(fwcs, index+1, 0, true) } } else { n.warn("Error truncating WAL: %v", err) @@ -3604,7 +3727,7 @@ func (n *raft) storeToWAL(ae *appendEntry) error { return n.werr } - seq, _, err := n.wal.StoreMsg(_EMPTY_, nil, ae.buf) + seq, _, err := n.wal.StoreMsg(_EMPTY_, nil, ae.buf, 0) if err != nil { n.setWriteErrLocked(err) return err @@ -3651,8 +3774,6 @@ func (n *raft) sendAppendEntry(entries []*Entry) { if err := n.storeToWAL(ae); err != nil { return } - // We count ourselves. - n.acks[n.pindex] = map[string]struct{}{n.id: {}} n.active = time.Now() // Save in memory for faster processing during applyCommit. @@ -4111,11 +4232,11 @@ func (n *raft) updateLeadChange(isLeader bool) { } // Lock should be held. 
-func (n *raft) switchState(state RaftState) { +func (n *raft) switchState(state RaftState) bool { retry: pstate := n.State() if pstate == Closed { - return + return false } // Set our state. If something else has changed our state @@ -4127,19 +4248,23 @@ retry: // Reset the election timer. n.resetElectionTimeout() + var leadChange bool if pstate == Leader && state != Leader { + leadChange = true n.updateLeadChange(false) // Drain the append entry response and proposal queues. n.resp.drain() n.prop.drain() } else if state == Leader && pstate != Leader { + // Don't updateLeadChange here, it will be done in switchToLeader or after initial messages are applied. + leadChange = true if len(n.pae) > 0 { n.pae = make(map[uint64]*appendEntry) } - n.updateLeadChange(true) } n.writeTermVote() + return leadChange } const ( @@ -4161,7 +4286,13 @@ func (n *raft) switchToFollowerLocked(leader string) { n.debug("Switching to follower") + n.aflr = 0 + n.leaderState.Store(false) n.lxfer = false + // Reset acks, we can't assume acks from a previous term are still valid in another term. + if len(n.acks) > 0 { + n.acks = make(map[uint64]map[string]struct{}) + } n.updateLeader(leader) n.switchState(Follower) } @@ -4214,7 +4345,22 @@ func (n *raft) switchToLeader() { n.lxfer = false n.updateLeader(n.id) - n.switchState(Leader) + leadChange := n.switchState(Leader) + + if leadChange { + // Wait for messages to be applied if we've stored more, otherwise signal immediately. + // It's important to wait signaling we're leader if we're not up-to-date yet, as that + // would mean we're in a consistent state compared with the previous leader. + if n.pindex > n.applied { + n.aflr = n.pindex + } else { + // We know we have applied all entries in our log and can signal immediately. + // For sanity reset applied floor back down to 0, so we aren't able to signal twice. 
+ n.aflr = 0 + n.leaderState.Store(true) + n.updateLeadChange(true) + } + } n.Unlock() if sendHB { diff --git a/vendor/github.com/nats-io/nats-server/v2/server/reload.go b/vendor/github.com/nats-io/nats-server/v2/server/reload.go index aea3348429..7bf940cded 100644 --- a/vendor/github.com/nats-io/nats-server/v2/server/reload.go +++ b/vendor/github.com/nats-io/nats-server/v2/server/reload.go @@ -848,6 +848,7 @@ func (l *leafNodeOption) Apply(s *Server) { opts := s.getOpts() if l.tlsFirstChanged { s.Noticef("Reloaded: LeafNode TLS HandshakeFirst value is: %v", opts.LeafNode.TLSHandshakeFirst) + s.Noticef("Reloaded: LeafNode TLS HandshakeFirstFallback value is: %v", opts.LeafNode.TLSHandshakeFirstFallback) for _, r := range opts.LeafNode.Remotes { s.Noticef("Reloaded: LeafNode Remote to %v TLS HandshakeFirst value is: %v", r.URLs, r.TLSHandshakeFirst) } @@ -1168,6 +1169,7 @@ func imposeOrder(value any) error { *OCSPConfig, map[string]string, JSLimitOpts, StoreCipher, *OCSPResponseCacheConfig: // explicitly skipped types case *AuthCallout: + case JSTpmOpts: default: // this will fail during unit tests return fmt.Errorf("OnReload, sort or explicitly skip type: %s", @@ -1367,14 +1369,16 @@ func (s *Server) diffOptions(newOpts *Options) ([]option, error) { tmpNew.TLSConfig = nil tmpOld.tlsConfigOpts = nil tmpNew.tlsConfigOpts = nil - // We will allow TLSHandshakeFirst to me config reloaded. First, + // We will allow TLSHandshakeFirst to be config reloaded. First, // we just want to detect if there was a change in the leafnodes{} // block, and if not, we will check the remotes. - handshakeFirstChanged := tmpOld.TLSHandshakeFirst != tmpNew.TLSHandshakeFirst + handshakeFirstChanged := tmpOld.TLSHandshakeFirst != tmpNew.TLSHandshakeFirst || + tmpOld.TLSHandshakeFirstFallback != tmpNew.TLSHandshakeFirstFallback // If changed, set them (in the temporary variables) to false so that the // rest of the comparison does not fail. 
if handshakeFirstChanged { tmpOld.TLSHandshakeFirst, tmpNew.TLSHandshakeFirst = false, false + tmpOld.TLSHandshakeFirstFallback, tmpNew.TLSHandshakeFirstFallback = 0, 0 } else if len(tmpOld.Remotes) == len(tmpNew.Remotes) { // Since we don't support changes in the remotes, we will do a // simple pass to see if there was a change of this field. @@ -1636,6 +1640,10 @@ func (s *Server) diffOptions(newOpts *Options) ([]option, error) { if new != old { diffOpts = append(diffOpts, &profBlockRateReload{newValue: new}) } + case "configdigest": + // skip changes in config digest, this is handled already while + // processing the config. + continue case "nofastproducerstall": diffOpts = append(diffOpts, &noFastProdStallReload{noStall: newValue.(bool)}) default: @@ -1785,8 +1793,11 @@ func (s *Server) applyOptions(ctx *reloadContext, opts []option) { if err := s.reloadOCSP(); err != nil { s.Warnf("Can't restart OCSP features: %v", err) } - - s.Noticef("Reloaded server configuration") + var cd string + if newOpts.configDigest != "" { + cd = fmt.Sprintf("(%s)", newOpts.configDigest) + } + s.Noticef("Reloaded server configuration %s", cd) } // This will send a reset to the internal send loop. diff --git a/vendor/github.com/nats-io/nats-server/v2/server/route.go b/vendor/github.com/nats-io/nats-server/v2/server/route.go index 51e7712352..d56cdf9728 100644 --- a/vendor/github.com/nats-io/nats-server/v2/server/route.go +++ b/vendor/github.com/nats-io/nats-server/v2/server/route.go @@ -42,17 +42,6 @@ const ( Explicit ) -const ( - // RouteProtoZero is the original Route protocol from 2009. - // http://nats.io/documentation/internals/nats-protocol/ - RouteProtoZero = iota - // RouteProtoInfo signals a route can receive more then the original INFO block. - // This can be used to update remote cluster permissions, etc... - RouteProtoInfo - // RouteProtoV2 is the new route/cluster protocol that provides account support. 
- RouteProtoV2 -) - // Include the space for the proto var ( aSubBytes = []byte{'A', '+', ' '} @@ -63,11 +52,6 @@ var ( lUnsubBytes = []byte{'L', 'S', '-', ' '} ) -// Used by tests -func setRouteProtoForTest(wantedProto int) int { - return (wantedProto + 1) * -1 -} - type route struct { remoteID string remoteName string @@ -99,8 +83,18 @@ type route struct { // Selected compression mode, which may be different from the // server configured mode. compression string + // Transient value used to set the Info.GossipMode when initiating + // an implicit route and sending to the remote. + gossipMode byte } +// Do not change the values/order since they are exchanged between servers. +const ( + gossipDefault = byte(iota) + gossipDisabled + gossipOverride +) + type connectInfo struct { Echo bool `json:"echo"` Verbose bool `json:"verbose"` @@ -578,6 +572,12 @@ func (c *client) processRouteInfo(info *Info) { return } s.mu.Lock() + // If running without system account and adding a dedicated + // route for an account for the first time, it could be that + // the map is nil. If so, create it. + if s.accRoutes == nil { + s.accRoutes = make(map[string]map[string]*client) + } if _, ok := s.accRoutes[an]; !ok { s.accRoutes[an] = make(map[string]*client) } @@ -702,12 +702,15 @@ func (c *client) processRouteInfo(info *Info) { return } + var sendDelayedInfo bool + // First INFO, check if this server is configured for compression because // if that is the case, we need to negotiate it with the remote server. if needsCompression(opts.Cluster.Compression.Mode) { accName := bytesToString(c.route.accName) // If we did not yet negotiate... - if !c.flags.isSet(compressionNegotiated) { + compNeg := c.flags.isSet(compressionNegotiated) + if !compNeg { // Prevent from getting back here. c.flags.set(compressionNegotiated) // Release client lock since following function will need server lock. 
@@ -724,24 +727,21 @@ func (c *client) processRouteInfo(info *Info) { } // No compression because one side does not want/can't, so proceed. c.mu.Lock() - } else if didSolicit { - // The other side has switched to compression, so we can now set - // the first ping timer and send the delayed INFO for situations - // where it was not already sent. - c.setFirstPingTimer() - if !routeShouldDelayInfo(accName, opts) { - cm := compressionModeForInfoProtocol(&opts.Cluster.Compression, c.route.compression) - // Need to release and then reacquire... + // Check that the connection did not close if the lock was released. + if c.isClosed() { c.mu.Unlock() - s.sendDelayedRouteInfo(c, accName, cm) - c.mu.Lock() + return } } - // Check that the connection did not close if the lock was released. - if c.isClosed() { - c.mu.Unlock() - return + // We can set the ping timer after we just negotiated compression above, + // or for solicited routes if we already negotiated. + if !compNeg || didSolicit { + c.setFirstPingTimer() } + // When compression is configured, we delay the initial INFO for any + // solicited route. So we need to send the delayed INFO simply based + // on the didSolicit boolean. + sendDelayedInfo = didSolicit } else { // Coming from an old server, the Compression field would be the empty // string. For servers that are configured with CompressionNotSupported, @@ -751,12 +751,17 @@ func (c *client) processRouteInfo(info *Info) { } else { c.route.compression = CompressionOff } + // When compression is not configured, we delay the initial INFO only + // for solicited pooled routes, so use the same check that we did when + // we decided to delay in createRoute(). + sendDelayedInfo = didSolicit && routeShouldDelayInfo(bytesToString(c.route.accName), opts) } // Mark that the INFO protocol has been received, so we can detect updates. c.flags.set(infoReceived) - // Get the route's proto version + // Get the route's proto version. 
It will be used to check if the connection + // supports certain features, such as message tracing. c.opts.Protocol = info.Proto // Headers @@ -828,11 +833,15 @@ func (c *client) processRouteInfo(info *Info) { } accName := string(c.route.accName) + // Capture the noGossip value and reset it here. + gossipMode := c.route.gossipMode + c.route.gossipMode = 0 + // Check to see if we have this remote already registered. // This can happen when both servers have routes to each other. c.mu.Unlock() - if added := s.addRoute(c, didSolicit, info, accName); added { + if added := s.addRoute(c, didSolicit, sendDelayedInfo, gossipMode, info, accName); added { if accName != _EMPTY_ { c.Debugf("Registering remote route %q for account %q", info.ID, accName) } else { @@ -866,7 +875,7 @@ func (s *Server) negotiateRouteCompression(c *client, didSolicit bool, accName, if needsCompression(cm) { // Generate an INFO with the chosen compression mode. s.mu.Lock() - infoProto := s.generateRouteInitialInfoJSON(accName, cm, 0) + infoProto := s.generateRouteInitialInfoJSON(accName, cm, 0, gossipDefault) s.mu.Unlock() // If we solicited, then send this INFO protocol BEFORE switching @@ -895,29 +904,9 @@ func (s *Server) negotiateRouteCompression(c *client, didSolicit bool, accName, c.mu.Unlock() return true, nil } - // We are not using compression, set the ping timer. - c.mu.Lock() - c.setFirstPingTimer() - c.mu.Unlock() - // If this is a solicited route, we need to send the INFO if it was not - // done during createRoute() and will not be done in addRoute(). 
- if didSolicit && !routeShouldDelayInfo(accName, opts) { - cm = compressionModeForInfoProtocol(&opts.Cluster.Compression, cm) - s.sendDelayedRouteInfo(c, accName, cm) - } return false, nil } -func (s *Server) sendDelayedRouteInfo(c *client, accName, cm string) { - s.mu.Lock() - infoProto := s.generateRouteInitialInfoJSON(accName, cm, 0) - s.mu.Unlock() - - c.mu.Lock() - c.enqueueProto(infoProto) - c.mu.Unlock() -} - // Possibly sends local subscriptions interest to this route // based on changes in the remote's Export permissions. func (s *Server) updateRemoteRoutePerms(c *client, info *Info) { @@ -1053,7 +1042,7 @@ func (s *Server) processImplicitRoute(info *Info, routeNoPool bool) { if info.AuthRequired { r.User = url.UserPassword(opts.Cluster.Username, opts.Cluster.Password) } - s.startGoRoutine(func() { s.connectToRoute(r, false, true, info.RouteAccount) }) + s.startGoRoutine(func() { s.connectToRoute(r, Implicit, true, info.GossipMode, info.RouteAccount) }) // If we are processing an implicit route from a route that does not // support pooling/pinned-accounts, we won't receive an INFO for each of // the pinned-accounts that we would normally receive. In that case, just @@ -1063,7 +1052,7 @@ func (s *Server) processImplicitRoute(info *Info, routeNoPool bool) { rURL := r for _, an := range opts.Cluster.PinnedAccounts { accName := an - s.startGoRoutine(func() { s.connectToRoute(rURL, false, true, accName) }) + s.startGoRoutine(func() { s.connectToRoute(rURL, Implicit, true, info.GossipMode, accName) }) } } } @@ -1102,26 +1091,89 @@ func (s *Server) hasThisRouteConfigured(info *Info) bool { return false } -// forwardNewRouteInfoToKnownServers sends the INFO protocol of the new route -// to all routes known by this server. In turn, each server will contact this -// new route. +// forwardNewRouteInfoToKnownServers possibly sends the INFO protocol of the +// new route to all routes known by this server. In turn, each server will +// contact this new route. 
// Server lock held on entry. -func (s *Server) forwardNewRouteInfoToKnownServers(info *Info) { +func (s *Server) forwardNewRouteInfoToKnownServers(info *Info, rtype RouteType, didSolicit bool, localGossipMode byte) { + // Determine if this connection is resulting from a gossip notification. + fromGossip := didSolicit && rtype == Implicit + // If from gossip (but we are not overriding it) or if the remote disabled gossip, bail out. + if (fromGossip && localGossipMode != gossipOverride) || info.GossipMode == gossipDisabled { + return + } + // Note: nonce is not used in routes. // That being said, the info we get is the initial INFO which // contains a nonce, but we now forward this to existing routes, // so clear it now. info.Nonce = _EMPTY_ - b, _ := json.Marshal(info) - infoJSON := []byte(fmt.Sprintf(InfoProto, b)) + + var ( + infoGMDefault []byte + infoGMDisabled []byte + infoGMOverride []byte + ) + + generateJSON := func(gm byte) []byte { + info.GossipMode = gm + b, _ := json.Marshal(info) + return []byte(fmt.Sprintf(InfoProto, b)) + } + + getJSON := func(r *client) []byte { + if (!didSolicit && r.route.routeType == Explicit) || (didSolicit && rtype == Explicit) { + if infoGMOverride == nil { + infoGMOverride = generateJSON(gossipOverride) + } + return infoGMOverride + } else if !didSolicit { + if infoGMDisabled == nil { + infoGMDisabled = generateJSON(gossipDisabled) + } + return infoGMDisabled + } + if infoGMDefault == nil { + infoGMDefault = generateJSON(0) + } + return infoGMDefault + } + + var accRemotes map[string]*client + pinnedAccount := info.RouteAccount != _EMPTY_ + // If this is for a pinned account, we will try to send the gossip + // through our pinned account routes, but fall back to the other + // routes in case we don't have one for a given remote. 
+ if pinnedAccount { + var ok bool + if accRemotes, ok = s.accRoutes[info.RouteAccount]; ok { + for remoteID, r := range accRemotes { + if r == nil { + continue + } + r.mu.Lock() + // Do not send to a remote that does not support pooling/pinned-accounts. + if remoteID != info.ID && !r.route.noPool { + r.enqueueProto(getJSON(r)) + } + r.mu.Unlock() + } + } + } s.forEachRemote(func(r *client) { r.mu.Lock() + remoteID := r.route.remoteID + if pinnedAccount { + if _, processed := accRemotes[remoteID]; processed { + r.mu.Unlock() + return + } + } // If this is a new route for a given account, do not send to a server // that does not support pooling/pinned-accounts. - if r.route.remoteID != info.ID && - (info.RouteAccount == _EMPTY_ || (info.RouteAccount != _EMPTY_ && !r.route.noPool)) { - r.enqueueProto(infoJSON) + if remoteID != info.ID && (!pinnedAccount || !r.route.noPool) { + r.enqueueProto(getJSON(r)) } r.mu.Unlock() }) @@ -1839,17 +1891,12 @@ func (c *client) sendRouteSubOrUnSubProtos(subs []*subscription, isSubProto, tra c.enqueueProto(buf) } -func (s *Server) createRoute(conn net.Conn, rURL *url.URL, accName string) *client { +func (s *Server) createRoute(conn net.Conn, rURL *url.URL, rtype RouteType, gossipMode byte, accName string) *client { // Snapshot server options. opts := s.getOpts() didSolicit := rURL != nil - r := &route{didSolicit: didSolicit, poolIdx: -1} - for _, route := range opts.Routes { - if rURL != nil && (strings.EqualFold(rURL.Host, route.Host)) { - r.routeType = Explicit - } - } + r := &route{routeType: rtype, didSolicit: didSolicit, poolIdx: -1, gossipMode: gossipMode} c := &client{srv: s, nc: conn, opts: ClientOpts{}, kind: ROUTER, msubs: -1, mpay: -1, route: r, start: time.Now()} @@ -1866,7 +1913,7 @@ func (s *Server) createRoute(conn net.Conn, rURL *url.URL, accName string) *clie // the incoming INFO from the remote. Also delay if configured for compression. 
delayInfo := didSolicit && (compressionConfigured || routeShouldDelayInfo(accName, opts)) if !delayInfo { - infoJSON = s.generateRouteInitialInfoJSON(accName, opts.Cluster.Compression.Mode, 0) + infoJSON = s.generateRouteInitialInfoJSON(accName, opts.Cluster.Compression.Mode, 0, gossipMode) } authRequired := s.routeInfo.AuthRequired tlsRequired := s.routeInfo.TLSRequired @@ -1999,7 +2046,7 @@ func routeShouldDelayInfo(accName string, opts *Options) bool { // To be used only when a route is created (to send the initial INFO protocol). // // Server lock held on entry. -func (s *Server) generateRouteInitialInfoJSON(accName, compression string, poolIdx int) []byte { +func (s *Server) generateRouteInitialInfoJSON(accName, compression string, poolIdx int, gossipMode byte) []byte { // New proto wants a nonce (although not used in routes, that is, not signed in CONNECT) var raw [nonceLen]byte nonce := raw[:] @@ -2009,11 +2056,11 @@ func (s *Server) generateRouteInitialInfoJSON(accName, compression string, poolI if s.getOpts().Cluster.Compression.Mode == CompressionS2Auto { compression = CompressionS2Auto } - ri.Nonce, ri.RouteAccount, ri.RoutePoolIdx, ri.Compression = string(nonce), accName, poolIdx, compression + ri.Nonce, ri.RouteAccount, ri.RoutePoolIdx, ri.Compression, ri.GossipMode = string(nonce), accName, poolIdx, compression, gossipMode infoJSON := generateInfoJSON(&s.routeInfo) // Clear now that it has been serialized. Will prevent nonce to be included in async INFO that we may send. // Same for some other fields. 
- ri.Nonce, ri.RouteAccount, ri.RoutePoolIdx, ri.Compression = _EMPTY_, _EMPTY_, 0, _EMPTY_ + ri.Nonce, ri.RouteAccount, ri.RoutePoolIdx, ri.Compression, ri.GossipMode = _EMPTY_, _EMPTY_, 0, _EMPTY_, 0 return infoJSON } @@ -2022,7 +2069,7 @@ const ( _EMPTY_ = "" ) -func (s *Server) addRoute(c *client, didSolicit bool, info *Info, accName string) bool { +func (s *Server) addRoute(c *client, didSolicit, sendDelayedInfo bool, gossipMode byte, info *Info, accName string) bool { id := info.ID var acc *Account @@ -2108,6 +2155,11 @@ func (s *Server) addRoute(c *client, didSolicit bool, info *Info, accName string c.mu.Lock() idHash := c.route.idHash cid := c.cid + rtype := c.route.routeType + if sendDelayedInfo { + cm := compressionModeForInfoProtocol(&opts.Cluster.Compression, c.route.compression) + c.enqueueProto(s.generateRouteInitialInfoJSON(accName, cm, 0, gossipMode)) + } if c.last.IsZero() { c.last = time.Now() } @@ -2122,8 +2174,10 @@ func (s *Server) addRoute(c *client, didSolicit bool, info *Info, accName string // Now that we have registered the route, we can remove from the temp map. s.removeFromTempClients(cid) - // Notify other routes about this new route - s.forwardNewRouteInfoToKnownServers(info) + // We don't need to send if the only route is the one we just accepted. + if len(conns) > 1 { + s.forwardNewRouteInfoToKnownServers(info, rtype, didSolicit, gossipMode) + } // Send subscription interest s.sendSubsToRoute(c, -1, accName) @@ -2204,9 +2258,9 @@ func (s *Server) addRoute(c *client, didSolicit bool, info *Info, accName string rHash := c.route.hash rn := c.route.remoteName url := c.route.url - // For solicited routes, we need now to send the INFO protocol. 
- if didSolicit { - c.enqueueProto(s.generateRouteInitialInfoJSON(_EMPTY_, c.route.compression, idx)) + if sendDelayedInfo { + cm := compressionModeForInfoProtocol(&opts.Cluster.Compression, c.route.compression) + c.enqueueProto(s.generateRouteInitialInfoJSON(_EMPTY_, cm, idx, gossipMode)) } if c.last.IsZero() { c.last = time.Now() @@ -2224,7 +2278,7 @@ func (s *Server) addRoute(c *client, didSolicit bool, info *Info, accName string // check to be consistent and future proof. but will be same domain if s.sameDomain(info.Domain) { s.nodeToInfo.Store(rHash, - nodeInfo{rn, s.info.Version, s.info.Cluster, info.Domain, id, nil, nil, nil, false, info.JetStream, false}) + nodeInfo{rn, s.info.Version, s.info.Cluster, info.Domain, id, nil, nil, nil, false, info.JetStream, false, false}) } } @@ -2243,10 +2297,9 @@ func (s *Server) addRoute(c *client, didSolicit bool, info *Info, accName string s.sendAsyncGatewayInfo() } - // we don't need to send if the only route is the one we just accepted. + // We don't need to send if the only route is the one we just accepted. if len(s.routes) > 1 { - // Now let the known servers know about this new route - s.forwardNewRouteInfoToKnownServers(info) + s.forwardNewRouteInfoToKnownServers(info, rtype, didSolicit, gossipMode) } // Send info about the known gateways to this route. 
@@ -2281,7 +2334,7 @@ func (s *Server) addRoute(c *client, didSolicit bool, info *Info, accName string s.grWG.Done() return } - s.connectToRoute(url, rtype == Explicit, true, _EMPTY_) + s.connectToRoute(url, rtype, true, gossipMode, _EMPTY_) }) } } @@ -2596,17 +2649,6 @@ func (s *Server) startRouteAcceptLoop() { s.Noticef("Listening for route connections on %s", net.JoinHostPort(opts.Cluster.Host, strconv.Itoa(l.Addr().(*net.TCPAddr).Port))) - proto := RouteProtoV2 - // For tests, we want to be able to make this server behave - // as an older server so check this option to see if we should override - if opts.routeProto < 0 { - // We have a private option that allows test to override the route - // protocol. We want this option initial value to be 0, however, - // since original proto is RouteProtoZero, tests call setRouteProtoForTest(), - // which sets as negative value the (desired proto + 1) * -1. - // Here we compute back the real value. - proto = (opts.routeProto * -1) - 1 - } // Check for TLSConfig tlsReq := opts.Cluster.TLSConfig != nil info := Info{ @@ -2619,7 +2661,7 @@ func (s *Server) startRouteAcceptLoop() { TLSVerify: tlsReq, MaxPayload: s.info.MaxPayload, JetStream: s.info.JetStream, - Proto: proto, + Proto: s.getServerProto(), GatewayURL: s.getGatewayURL(), Headers: s.supportsHeaders(), Cluster: s.info.Cluster, @@ -2696,7 +2738,7 @@ func (s *Server) startRouteAcceptLoop() { } // Start the accept loop in a different go routine. - go s.acceptConnections(l, "Route", func(conn net.Conn) { s.createRoute(conn, nil, _EMPTY_) }, nil) + go s.acceptConnections(l, "Route", func(conn net.Conn) { s.createRoute(conn, nil, Implicit, gossipDefault, _EMPTY_) }, nil) // Solicit Routes if applicable. This will not block. 
s.solicitRoutes(opts.Routes, opts.Cluster.PinnedAccounts) @@ -2728,7 +2770,7 @@ func (s *Server) setRouteInfoHostPortAndIP() error { func (s *Server) StartRouting(clientListenReady chan struct{}) { defer s.grWG.Done() - // Wait for the client and and leafnode listen ports to be opened, + // Wait for the client and leafnode listen ports to be opened, // and the possible ephemeral ports to be selected. <-clientListenReady @@ -2738,14 +2780,13 @@ func (s *Server) StartRouting(clientListenReady chan struct{}) { } func (s *Server) reConnectToRoute(rURL *url.URL, rtype RouteType, accName string) { - tryForEver := rtype == Explicit // If A connects to B, and B to A (regardless if explicit or // implicit - due to auto-discovery), and if each server first // registers the route on the opposite TCP connection, the // two connections will end-up being closed. // Add some random delay to reduce risk of repeated failures. delay := time.Duration(rand.Intn(100)) * time.Millisecond - if tryForEver { + if rtype == Explicit { delay += DEFAULT_ROUTE_RECONNECT } select { @@ -2754,7 +2795,7 @@ func (s *Server) reConnectToRoute(rURL *url.URL, rtype RouteType, accName string s.grWG.Done() return } - s.connectToRoute(rURL, tryForEver, false, accName) + s.connectToRoute(rURL, rtype, false, gossipDefault, accName) } // Checks to make sure the route is still valid. @@ -2767,21 +2808,26 @@ func (s *Server) routeStillValid(rURL *url.URL) bool { return false } -func (s *Server) connectToRoute(rURL *url.URL, tryForEver, firstConnect bool, accName string) { +func (s *Server) connectToRoute(rURL *url.URL, rtype RouteType, firstConnect bool, gossipMode byte, accName string) { + defer s.grWG.Done() + if rURL == nil { + return + } + // For explicit routes, we will try to connect until we succeed. For implicit + // we will try only based on the number of ConnectRetries optin. + tryForEver := rtype == Explicit + // Snapshot server options. 
opts := s.getOpts() - defer s.grWG.Done() - const connErrFmt = "Error trying to connect to route (attempt %v): %v" - s.mu.Lock() + s.mu.RLock() resolver := s.routeResolver excludedAddresses := s.routesToSelf - s.mu.Unlock() + s.mu.RUnlock() - attempts := 0 - for s.isRunning() && rURL != nil { + for attempts := 0; s.isRunning(); { if tryForEver { if !s.routeStillValid(rURL) { return @@ -2835,7 +2881,7 @@ func (s *Server) connectToRoute(rURL *url.URL, tryForEver, firstConnect bool, ac // We have a route connection here. // Go ahead and create it and exit this func. - s.createRoute(conn, rURL, accName) + s.createRoute(conn, rURL, rtype, gossipMode, accName) return } } @@ -2864,13 +2910,13 @@ func (s *Server) solicitRoutes(routes []*url.URL, accounts []string) { s.saveRouteTLSName(routes) for _, r := range routes { route := r - s.startGoRoutine(func() { s.connectToRoute(route, true, true, _EMPTY_) }) + s.startGoRoutine(func() { s.connectToRoute(route, Explicit, true, gossipDefault, _EMPTY_) }) } // Now go over possible per-account routes and create them. for _, an := range accounts { for _, r := range routes { route, accName := r, an - s.startGoRoutine(func() { s.connectToRoute(route, true, true, accName) }) + s.startGoRoutine(func() { s.connectToRoute(route, Explicit, true, gossipDefault, accName) }) } } } @@ -2993,7 +3039,7 @@ func (s *Server) removeRoute(c *client) { opts = s.getOpts() rURL *url.URL noPool bool - didSolicit bool + rtype RouteType ) c.mu.Lock() cid := c.cid @@ -3012,7 +3058,7 @@ func (s *Server) removeRoute(c *client) { connectURLs = r.connectURLs wsConnectURLs = r.wsConnURLs rURL = r.url - didSolicit = r.didSolicit + rtype = r.routeType } c.mu.Unlock() if accName != _EMPTY_ { @@ -3075,12 +3121,12 @@ func (s *Server) removeRoute(c *client) { // this remote was a "no pool" route, attempt to reconnect. 
if noPool { if s.routesPoolSize > 1 { - s.startGoRoutine(func() { s.connectToRoute(rURL, didSolicit, true, _EMPTY_) }) + s.startGoRoutine(func() { s.connectToRoute(rURL, rtype, true, gossipDefault, _EMPTY_) }) } if len(opts.Cluster.PinnedAccounts) > 0 { for _, an := range opts.Cluster.PinnedAccounts { accName := an - s.startGoRoutine(func() { s.connectToRoute(rURL, didSolicit, true, accName) }) + s.startGoRoutine(func() { s.connectToRoute(rURL, rtype, true, gossipDefault, accName) }) } } } diff --git a/vendor/github.com/nats-io/nats-server/v2/server/sendq.go b/vendor/github.com/nats-io/nats-server/v2/server/sendq.go index 178ec5d76c..5018482db5 100644 --- a/vendor/github.com/nats-io/nats-server/v2/server/sendq.go +++ b/vendor/github.com/nats-io/nats-server/v2/server/sendq.go @@ -29,10 +29,11 @@ type sendq struct { mu sync.Mutex q *ipQueue[*outMsg] s *Server + a *Account } -func (s *Server) newSendQ() *sendq { - sq := &sendq{s: s, q: newIPQueue[*outMsg](s, "SendQ")} +func (s *Server) newSendQ(acc *Account) *sendq { + sq := &sendq{s: s, q: newIPQueue[*outMsg](s, "SendQ"), a: acc} s.startGoRoutine(sq.internalLoop) return sq } @@ -45,7 +46,7 @@ func (sq *sendq) internalLoop() { defer s.grWG.Done() c := s.createInternalSystemClient() - c.registerWithAccount(s.SystemAccount()) + c.registerWithAccount(sq.a) c.noIcb = true defer c.closeConnection(ClientClosed) diff --git a/vendor/github.com/nats-io/nats-server/v2/server/server.go b/vendor/github.com/nats-io/nats-server/v2/server/server.go index 81013d1e1b..dd2867d62d 100644 --- a/vendor/github.com/nats-io/nats-server/v2/server/server.go +++ b/vendor/github.com/nats-io/nats-server/v2/server/server.go @@ -1,4 +1,4 @@ -// Copyright 2012-2024 The NATS Authors +// Copyright 2012-2025 The NATS Authors // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. 
// You may obtain a copy of the License at @@ -28,28 +28,26 @@ import ( "net" "net/http" "net/url" - "regexp" - "runtime/pprof" - "unicode" - - // Allow dynamic profiling. - _ "net/http/pprof" "os" "path" "path/filepath" + "regexp" "runtime" + "runtime/pprof" "strconv" "strings" "sync" "sync/atomic" "time" + // Allow dynamic profiling. + _ "net/http/pprof" + "github.com/klauspost/compress/s2" "github.com/nats-io/jwt/v2" + "github.com/nats-io/nats-server/v2/logger" "github.com/nats-io/nkeys" "github.com/nats-io/nuid" - - "github.com/nats-io/nats-server/v2/logger" ) const ( @@ -60,6 +58,49 @@ const ( firstClientPingInterval = 2 * time.Second ) +// These are protocol versions sent between server connections: ROUTER, LEAF and +// GATEWAY. We may have protocol versions that have a meaning only for a certain +// type of connections, but we don't have to have separate enums for that. +// However, it is CRITICAL to not change the order of those constants since they +// are exchanged between servers. When adding a new protocol version, add to the +// end of the list, don't try to group them by connection types. +const ( + // RouteProtoZero is the original Route protocol from 2009. + // http://nats.io/documentation/internals/nats-protocol/ + RouteProtoZero = iota + // RouteProtoInfo signals a route can receive more then the original INFO block. + // This can be used to update remote cluster permissions, etc... + RouteProtoInfo + // RouteProtoV2 is the new route/cluster protocol that provides account support. + RouteProtoV2 + // MsgTraceProto indicates that this server understands distributed message tracing. + MsgTraceProto +) + +// Will return the latest server-to-server protocol versions, unless the +// option to override it is set. +func (s *Server) getServerProto() int { + opts := s.getOpts() + // Initialize with the latest protocol version. 
+ proto := MsgTraceProto + // For tests, we want to be able to make this server behave + // as an older server so check this option to see if we should override. + if opts.overrideProto < 0 { + // The option overrideProto is set to 0 by default (when creating an + // Options structure). Since this is the same value than the original + // proto RouteProtoZero, tests call setServerProtoForTest() with the + // desired protocol level, which sets it as negative value equal to: + // (wantedProto + 1) * -1. Here we compute back the real value. + proto = (opts.overrideProto * -1) - 1 + } + return proto +} + +// Used by tests. +func setServerProtoForTest(wantedProto int) int { + return (wantedProto + 1) * -1 +} + // Info is the information sent to clients, routes, gateways, and leaf nodes, // to help them understand information about this server. type Info struct { @@ -101,6 +142,7 @@ type Info struct { RoutePoolIdx int `json:"route_pool_idx,omitempty"` RouteAccount string `json:"route_account,omitempty"` RouteAccReqID string `json:"route_acc_add_reqid,omitempty"` + GossipMode byte `json:"gossip_mode,omitempty"` // Gateways Specific Gateway string `json:"gateway,omitempty"` // Name of the origin Gateway (sent by gateway's INFO) @@ -189,6 +231,7 @@ type Server struct { leafRemoteAccounts sync.Map leafNodeEnabled bool leafDisableConnect bool // Used in test only + leafNoCluster bool // Indicate that this server has only remotes and no cluster defined quitCh chan struct{} startupComplete chan struct{} @@ -320,6 +363,14 @@ type Server struct { // Queue to process JS API requests that come from routes (or gateways) jsAPIRoutedReqs *ipQueue[*jsAPIRoutedReq] + + // Delayed API responses. + delayedAPIResponses *ipQueue[*delayedAPIResponse] + + // Whether moving NRG traffic into accounts is permitted on this server. + // Controls whether or not the account NRG capability is set in statsz. + // Currently used by unit tests to simulate nodes not supporting account NRG. 
+ accountNRGAllowed atomic.Bool } // For tracking JS nodes. @@ -335,6 +386,7 @@ type nodeInfo struct { offline bool js bool binarySnapshots bool + accountNRG bool } // Make sure all are 64bits for atomic use @@ -683,6 +735,14 @@ func NewServer(opts *Options) (*Server, error) { syncOutSem: make(chan struct{}, maxConcurrentSyncRequests), } + // Delayed API response queue. Create regardless if JetStream is configured + // or not (since it can be enabled/disabled with config reload, we want this + // queue to exist at all times). + s.delayedAPIResponses = newIPQueue[*delayedAPIResponse](s, "delayed API responses") + + // By default we'll allow account NRG. + s.accountNRGAllowed.Store(true) + // Fill up the maximum in flight syncRequests for this server. // Used in JetStream catchup semantics. for i := 0; i < maxConcurrentSyncRequests; i++ { @@ -701,6 +761,7 @@ func NewServer(opts *Options) (*Server, error) { // If we have solicited leafnodes but no clustering and no clustername. // However we may need a stable clustername so use the server name. if len(opts.LeafNode.Remotes) > 0 && opts.Cluster.Port == 0 && opts.Cluster.Name == _EMPTY_ { + s.leafNoCluster = true opts.Cluster.Name = opts.ServerName } @@ -726,7 +787,7 @@ func NewServer(opts *Options) (*Server, error) { opts.Tags, &JetStreamConfig{MaxMemory: opts.JetStreamMaxMemory, MaxStore: opts.JetStreamMaxStore, CompressOK: true}, nil, - false, true, true, + false, true, true, true, }) } @@ -968,6 +1029,9 @@ func (s *Server) ClientURL() string { } func validateCluster(o *Options) error { + if o.Cluster.Name != _EMPTY_ && strings.Contains(o.Cluster.Name, " ") { + return ErrClusterNameHasSpaces + } if o.Cluster.Compression.Mode != _EMPTY_ { if err := validateAndNormalizeCompressionOption(&o.Cluster.Compression, CompressionS2Fast); err != nil { return err @@ -977,8 +1041,9 @@ func validateCluster(o *Options) error { return fmt.Errorf("cluster: %v", err) } // Check that cluster name if defined matches any gateway name. 
- if o.Gateway.Name != "" && o.Gateway.Name != o.Cluster.Name { - if o.Cluster.Name != "" { + // Note that we have already verified that the gateway name does not have spaces. + if o.Gateway.Name != _EMPTY_ && o.Gateway.Name != o.Cluster.Name { + if o.Cluster.Name != _EMPTY_ { return ErrClusterNameConfigConflict } // Set this here so we do not consider it dynamic. @@ -1019,6 +1084,9 @@ func validateOptions(o *Options) error { return fmt.Errorf("max_payload (%v) cannot be higher than max_pending (%v)", o.MaxPayload, o.MaxPending) } + if o.ServerName != _EMPTY_ && strings.Contains(o.ServerName, " ") { + return errors.New("server name cannot contain spaces") + } // Check that the trust configuration is correct. if err := validateTrustedOperators(o); err != nil { return err @@ -1315,8 +1383,9 @@ func (s *Server) configureAccounts(reloading bool) (map[string]struct{}, error) // Add any required exports from system account. if s.sys != nil { + sysAcc := s.sys.account s.mu.Unlock() - s.addSystemAccountExports(s.sys.account) + s.addSystemAccountExports(sysAcc) s.mu.Lock() } @@ -1702,7 +1771,7 @@ func (s *Server) setSystemAccount(acc *Account) error { recvq: newIPQueue[*inSysMsg](s, "System recvQ"), recvqp: newIPQueue[*inSysMsg](s, "System recvQ Pings"), resetCh: make(chan struct{}), - sq: s.newSendQ(), + sq: s.newSendQ(acc), statsz: statsHBInterval, orphMax: 5 * eventsHBInterval, chkOrph: 3 * eventsHBInterval, @@ -2120,7 +2189,17 @@ func (s *Server) Start() { // Snapshot server options. opts := s.getOpts() - clusterName := s.ClusterName() + + // Capture if this server is a leaf that has no cluster, so we don't + // display the cluster name if that is the case. 
+ s.mu.RLock() + leafNoCluster := s.leafNoCluster + s.mu.RUnlock() + + var clusterName string + if !leafNoCluster { + clusterName = s.ClusterName() + } s.Noticef(" Version: %s", VERSION) s.Noticef(" Git: [%s]", gc) @@ -2165,7 +2244,11 @@ func (s *Server) Start() { } if opts.ConfigFile != _EMPTY_ { - s.Noticef("Using configuration file: %s", opts.ConfigFile) + var cd string + if opts.configDigest != "" { + cd = fmt.Sprintf("(%s)", opts.configDigest) + } + s.Noticef("Using configuration file: %s %s", opts.ConfigFile, cd) } hasOperators := len(opts.TrustedOperators) > 0 @@ -2280,6 +2363,7 @@ func (s *Server) Start() { StoreDir: opts.StoreDir, SyncInterval: opts.SyncInterval, SyncAlways: opts.SyncAlways, + Strict: opts.JetStreamStrict, MaxMemory: opts.JetStreamMaxMemory, MaxStore: opts.JetStreamMaxStore, Domain: opts.JetStreamDomain, @@ -2326,6 +2410,11 @@ func (s *Server) Start() { } } + // Delayed API response handling. Start regardless of JetStream being + // currently configured or not (since it can be enabled/disabled with + // configuration reload). + s.startGoRoutine(s.delayedAPIResponder) + // Start OCSP Stapling monitoring for TLS certificates if enabled. Hook TLS handshake for // OCSP check on peers (LEAF and CLIENT kind) if enabled. s.startOCSPMonitoring() @@ -2357,9 +2446,6 @@ func (s *Server) Start() { // Solicit remote servers for leaf node connections. 
if len(opts.LeafNode.Remotes) > 0 { s.solicitLeafNodeRemotes(opts.LeafNode.Remotes) - if opts.Cluster.Name == opts.ServerName && strings.ContainsFunc(opts.Cluster.Name, unicode.IsSpace) { - s.Warnf("Server name has spaces and used as the cluster name, leaf remotes may not connect properly") - } } // TODO (ik): I wanted to refactor this by starting the client @@ -3075,7 +3161,16 @@ func (s *Server) createClientEx(conn net.Conn, inProcess bool) *client { } now := time.Now() - c := &client{srv: s, nc: conn, opts: defaultOpts, mpay: maxPay, msubs: maxSubs, start: now, last: now} + c := &client{ + srv: s, + nc: conn, + opts: defaultOpts, + mpay: maxPay, + msubs: maxSubs, + start: now, + last: now, + iproc: inProcess, + } c.registerWithAccount(s.globalAccount()) diff --git a/vendor/github.com/nats-io/nats-server/v2/server/service.go b/vendor/github.com/nats-io/nats-server/v2/server/service.go index ab3239b380..7822206a62 100644 --- a/vendor/github.com/nats-io/nats-server/v2/server/service.go +++ b/vendor/github.com/nats-io/nats-server/v2/server/service.go @@ -12,7 +12,6 @@ // limitations under the License. //go:build !windows -// +build !windows package server diff --git a/vendor/github.com/nats-io/nats-server/v2/server/signal.go b/vendor/github.com/nats-io/nats-server/v2/server/signal.go index 18b37a0222..24c0827f38 100644 --- a/vendor/github.com/nats-io/nats-server/v2/server/signal.go +++ b/vendor/github.com/nats-io/nats-server/v2/server/signal.go @@ -1,4 +1,4 @@ -// Copyright 2012-2024 The NATS Authors +// Copyright 2012-2025 The NATS Authors // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at @@ -12,7 +12,6 @@ // limitations under the License. 
//go:build !windows && !wasm -// +build !windows,!wasm package server @@ -62,7 +61,7 @@ func (s *Server) handleSignals() { if !ldm { s.Shutdown() s.WaitForShutdown() - os.Exit(1) + os.Exit(0) } case syscall.SIGUSR1: // File log re-open for rotating file logs. diff --git a/vendor/github.com/nats-io/nats-server/v2/server/signal_wasm.go b/vendor/github.com/nats-io/nats-server/v2/server/signal_wasm.go index ce9088f72b..7788d3ffeb 100644 --- a/vendor/github.com/nats-io/nats-server/v2/server/signal_wasm.go +++ b/vendor/github.com/nats-io/nats-server/v2/server/signal_wasm.go @@ -12,7 +12,6 @@ // limitations under the License. //go:build wasm -// +build wasm package server diff --git a/vendor/github.com/nats-io/nats-server/v2/server/store.go b/vendor/github.com/nats-io/nats-server/v2/server/store.go index 03ef7b29cd..6f05561e13 100644 --- a/vendor/github.com/nats-io/nats-server/v2/server/store.go +++ b/vendor/github.com/nats-io/nats-server/v2/server/store.go @@ -66,6 +66,8 @@ var ( ErrSequenceMismatch = errors.New("expected sequence does not match store") // ErrCorruptStreamState ErrCorruptStreamState = errors.New("stream state snapshot is corrupt") + // ErrTooManyResults + ErrTooManyResults = errors.New("too many matching results for request") ) // StoreMsg is the stored message format for messages that are retained by the Store layer. @@ -82,9 +84,12 @@ type StoreMsg struct { // For the cases where its a single message we will also supply sequence number and subject. type StorageUpdateHandler func(msgs, bytes int64, seq uint64, subj string) +// Used to call back into the upper layers to report on newly created subject delete markers. 
+type SubjectDeleteMarkerUpdateHandler func(*inMsg) + type StreamStore interface { - StoreMsg(subject string, hdr, msg []byte) (uint64, int64, error) - StoreRawMsg(subject string, hdr, msg []byte, seq uint64, ts int64) error + StoreMsg(subject string, hdr, msg []byte, ttl int64) (uint64, int64, error) + StoreRawMsg(subject string, hdr, msg []byte, seq uint64, ts int64, ttl int64) error SkipMsg() uint64 SkipMsgs(seq uint64, num uint64) error LoadMsg(seq uint64, sm *StoreMsg) (*StoreMsg, error) @@ -95,13 +100,14 @@ type StreamStore interface { RemoveMsg(seq uint64) (bool, error) EraseMsg(seq uint64) (bool, error) Purge() (uint64, error) - PurgeEx(subject string, seq, keep uint64) (uint64, error) + PurgeEx(subject string, seq, keep uint64, noMarkers bool) (uint64, error) Compact(seq uint64) (uint64, error) Truncate(seq uint64) error GetSeqFromTime(t time.Time) uint64 FilteredState(seq uint64, subject string) SimpleState SubjectsState(filterSubject string) map[string]SimpleState SubjectsTotals(filterSubject string) map[string]uint64 + MultiLastSeqs(filters []string, maxSeq uint64, maxAllowed int) ([]uint64, error) NumPending(sseq uint64, filter string, lastPerSubject bool) (total, validThrough uint64) NumPendingMulti(sseq uint64, sl *Sublist, lastPerSubject bool) (total, validThrough uint64) State() StreamState @@ -110,6 +116,7 @@ type StreamStore interface { SyncDeleted(dbs DeleteBlocks) Type() StorageType RegisterStorageUpdates(StorageUpdateHandler) + RegisterSubjectDeleteMarkerUpdates(SubjectDeleteMarkerUpdateHandler) UpdateConfig(cfg *StreamConfig) error Delete() error Stop() error @@ -182,6 +189,7 @@ type LostStreamData struct { type SnapshotResult struct { Reader io.ReadCloser State StreamState + errCh chan string } const ( @@ -342,6 +350,7 @@ func (dbs DeleteBlocks) NumDeleted() (total uint64) { // ConsumerStore stores state on consumers for streams. 
type ConsumerStore interface { SetStarting(sseq uint64) error + UpdateStarting(sseq uint64) HasState() bool UpdateDelivered(dseq, sseq, dc uint64, ts int64) error UpdateAcks(dseq, sseq uint64) error diff --git a/vendor/github.com/nats-io/nats-server/v2/server/stream.go b/vendor/github.com/nats-io/nats-server/v2/server/stream.go index e7d7512e42..bf70f5c424 100644 --- a/vendor/github.com/nats-io/nats-server/v2/server/stream.go +++ b/vendor/github.com/nats-io/nats-server/v2/server/stream.go @@ -36,6 +36,14 @@ import ( "github.com/nats-io/nuid" ) +// StreamConfigRequest is used to create or update a stream. +type StreamConfigRequest struct { + StreamConfig + // This is not part of the StreamConfig, because its scoped to request, + // and not to the stream itself. + Pedantic bool `json:"pedantic,omitempty"` +} + // StreamConfig will determine the name, subjects and retention policy // for a given stream. If subjects is empty the name will be used. type StreamConfig struct { @@ -92,10 +100,54 @@ type StreamConfig struct { // TODO(nat): Can/should we name these better? ConsumerLimits StreamConsumerLimits `json:"consumer_limits"` + // AllowMsgTTL allows header initiated per-message TTLs. If disabled, + // then the `NATS-TTL` header will be ignored. + AllowMsgTTL bool `json:"allow_msg_ttl"` + + // SubjectDeleteMarkerTTL sets the TTL of delete marker messages left behind by + // subject delete markers. + SubjectDeleteMarkerTTL time.Duration `json:"subject_delete_marker_ttl,omitempty"` + // Metadata is additional metadata for the Stream. Metadata map[string]string `json:"metadata,omitempty"` } +// clone performs a deep copy of the StreamConfig struct, returning a new clone with +// all values copied. 
+func (cfg *StreamConfig) clone() *StreamConfig { + clone := *cfg + if cfg.Placement != nil { + placement := *cfg.Placement + clone.Placement = &placement + } + if cfg.Mirror != nil { + mirror := *cfg.Mirror + clone.Mirror = &mirror + } + if len(cfg.Sources) > 0 { + clone.Sources = make([]*StreamSource, len(cfg.Sources)) + for i, cfgSource := range cfg.Sources { + source := *cfgSource + clone.Sources[i] = &source + } + } + if cfg.SubjectTransform != nil { + transform := *cfg.SubjectTransform + clone.SubjectTransform = &transform + } + if cfg.RePublish != nil { + rePublish := *cfg.RePublish + clone.RePublish = &rePublish + } + if cfg.Metadata != nil { + clone.Metadata = make(map[string]string, len(cfg.Metadata)) + for k, v := range cfg.Metadata { + clone.Metadata[k] = v + } + } + return &clone +} + type StreamConsumerLimits struct { InactiveThreshold time.Duration `json:"inactive_threshold,omitempty"` MaxAckPending int `json:"max_ack_pending,omitempty"` @@ -210,6 +262,12 @@ type ExternalStream struct { DeliverPrefix string `json:"deliver"` } +// For managing stream ingest. +const ( + streamDefaultMaxQueueMsgs = 10_000 + streamDefaultMaxQueueBytes = 1024 * 1024 * 128 +) + // Stream is a jetstream stream of messages. When we receive a message internally destined // for a Stream we will direct link from the client to this structure. type stream struct { @@ -276,26 +334,32 @@ type stream struct { sigq *ipQueue[*cMsg] // Intra-process queue for the messages to signal to the consumers. csl *gsl.GenericSublist[*consumer] // Consumer subscription list. + // Leader will store seq/msgTrace in clustering mode. Used in applyStreamEntries + // to know if trace event should be sent after processing. + mt map[uint64]*msgTrace + // For non limits policy streams when they process an ack before the actual msg. // Can happen in stretch clusters, multi-cloud, or during catchup for a restarted server. 
preAcks map[uint64]map[*consumer]struct{} // TODO(dlc) - Hide everything below behind two pointers. // Clustered mode. - sa *streamAssignment // What the meta controller uses to assign streams to peers. - node RaftNode // Our RAFT node for the stream's group. - catchup atomic.Bool // Used to signal we are in catchup mode. - catchups map[string]uint64 // The number of messages that need to be caught per peer. - syncSub *subscription // Internal subscription for sync messages (on "$JSC.SYNC"). - infoSub *subscription // Internal subscription for stream info requests. - clMu sync.Mutex // The mutex for clseq and clfs. - clseq uint64 // The current last seq being proposed to the NRG layer. - clfs uint64 // The count (offset) of the number of failed NRG sequences used to compute clseq. - inflight map[uint64]uint64 // Inflight message sizes per clseq. - lqsent time.Time // The time at which the last lost quorum advisory was sent. Used to rate limit. - uch chan struct{} // The channel to signal updates to the monitor routine. - compressOK bool // True if we can do message compression in RAFT and catchup logic - inMonitor bool // True if the monitor routine has been started. + sa *streamAssignment // What the meta controller uses to assign streams to peers. + node RaftNode // Our RAFT node for the stream's group. + catchup atomic.Bool // Used to signal we are in catchup mode. + catchups map[string]uint64 // The number of messages that need to be caught per peer. + syncSub *subscription // Internal subscription for sync messages (on "$JSC.SYNC"). + infoSub *subscription // Internal subscription for stream info requests. + clMu sync.Mutex // The mutex for clseq and clfs. + clseq uint64 // The current last seq being proposed to the NRG layer. + clfs uint64 // The count (offset) of the number of failed NRG sequences used to compute clseq. + inflight map[uint64]uint64 // Inflight message sizes per clseq. 
+ lqsent time.Time // The time at which the last lost quorum advisory was sent. Used to rate limit. + uch chan struct{} // The channel to signal updates to the monitor routine. + inMonitor bool // True if the monitor routine has been started. + + expectedPerSubjectSequence map[uint64]string // Inflight 'expected per subject' subjects per clseq. + expectedPerSubjectInProcess map[string]struct{} // Current 'expected per subject' subjects in process. // Direct get subscription. directSub *subscription @@ -342,18 +406,21 @@ const ( // Headers for published messages. const ( - JSMsgId = "Nats-Msg-Id" - JSExpectedStream = "Nats-Expected-Stream" - JSExpectedLastSeq = "Nats-Expected-Last-Sequence" - JSExpectedLastSubjSeq = "Nats-Expected-Last-Subject-Sequence" - JSExpectedLastMsgId = "Nats-Expected-Last-Msg-Id" - JSStreamSource = "Nats-Stream-Source" - JSLastConsumerSeq = "Nats-Last-Consumer" - JSLastStreamSeq = "Nats-Last-Stream" - JSConsumerStalled = "Nats-Consumer-Stalled" - JSMsgRollup = "Nats-Rollup" - JSMsgSize = "Nats-Msg-Size" - JSResponseType = "Nats-Response-Type" + JSMsgId = "Nats-Msg-Id" + JSExpectedStream = "Nats-Expected-Stream" + JSExpectedLastSeq = "Nats-Expected-Last-Sequence" + JSExpectedLastSubjSeq = "Nats-Expected-Last-Subject-Sequence" + JSExpectedLastSubjSeqSubj = "Nats-Expected-Last-Subject-Sequence-Subject" + JSExpectedLastMsgId = "Nats-Expected-Last-Msg-Id" + JSStreamSource = "Nats-Stream-Source" + JSLastConsumerSeq = "Nats-Last-Consumer" + JSLastStreamSeq = "Nats-Last-Stream" + JSConsumerStalled = "Nats-Consumer-Stalled" + JSMsgRollup = "Nats-Rollup" + JSMsgSize = "Nats-Msg-Size" + JSResponseType = "Nats-Response-Type" + JSMessageTTL = "Nats-TTL" + JSMarkerReason = "Nats-Marker-Reason" ) // Headers for republished messages and direct gets. 
@@ -363,6 +430,8 @@ const ( JSTimeStamp = "Nats-Time-Stamp" JSSubject = "Nats-Subject" JSLastSequence = "Nats-Last-Sequence" + JSNumPending = "Nats-Num-Pending" + JSUpToSequence = "Nats-UpTo-Sequence" ) // Rollups, can be subject only or all messages. @@ -371,6 +440,13 @@ const ( JSMsgRollupAll = "all" ) +// Applied limits in the Nats-Applied-Limit header. +const ( + JSMarkerReasonMaxAge = "MaxAge" + JSMarkerReasonPurge = "Purge" + JSMarkerReasonRemove = "Remove" +) + const ( jsCreateResponse = "create" ) @@ -387,15 +463,19 @@ const StreamMaxReplicas = 5 // AddStream adds a stream for the given account. func (a *Account) addStream(config *StreamConfig) (*stream, error) { - return a.addStreamWithAssignment(config, nil, nil) + return a.addStreamWithAssignment(config, nil, nil, false) } // AddStreamWithStore adds a stream for the given account with custome store config options. func (a *Account) addStreamWithStore(config *StreamConfig, fsConfig *FileStoreConfig) (*stream, error) { - return a.addStreamWithAssignment(config, fsConfig, nil) + return a.addStreamWithAssignment(config, fsConfig, nil, false) } -func (a *Account) addStreamWithAssignment(config *StreamConfig, fsConfig *FileStoreConfig, sa *streamAssignment) (*stream, error) { +func (a *Account) addStreamPedantic(config *StreamConfig, pedantic bool) (*stream, error) { + return a.addStreamWithAssignment(config, nil, nil, pedantic) +} + +func (a *Account) addStreamWithAssignment(config *StreamConfig, fsConfig *FileStoreConfig, sa *streamAssignment, pedantic bool) (*stream, error) { s, jsa, err := a.checkForJetStream() if err != nil { return nil, err @@ -409,7 +489,7 @@ func (a *Account) addStreamWithAssignment(config *StreamConfig, fsConfig *FileSt } // Sensible defaults. 
- cfg, apiErr := s.checkStreamCfg(config, a) + cfg, apiErr := s.checkStreamCfg(config, a, pedantic) if apiErr != nil { return nil, apiErr } @@ -557,6 +637,16 @@ func (a *Account) addStreamWithAssignment(config *StreamConfig, fsConfig *FileSt c := s.createInternalJetStreamClient() ic := s.createInternalJetStreamClient() + // Work out the stream ingest limits. + mlen := s.opts.StreamMaxBufferedMsgs + msz := uint64(s.opts.StreamMaxBufferedSize) + if mlen == 0 { + mlen = streamDefaultMaxQueueMsgs + } + if msz == 0 { + msz = streamDefaultMaxQueueBytes + } + qpfx := fmt.Sprintf("[ACC:%s] stream '%s' ", a.Name, config.Name) mset := &stream{ acc: a, @@ -569,12 +659,18 @@ func (a *Account) addStreamWithAssignment(config *StreamConfig, fsConfig *FileSt tier: tier, stype: cfg.Storage, consumers: make(map[string]*consumer), - msgs: newIPQueue[*inMsg](s, qpfx+"messages"), - gets: newIPQueue[*directGetReq](s, qpfx+"direct gets"), - qch: make(chan struct{}), - mqch: make(chan struct{}), - uch: make(chan struct{}, 4), - sch: make(chan struct{}, 1), + msgs: newIPQueue[*inMsg](s, qpfx+"messages", + ipqSizeCalculation(func(msg *inMsg) uint64 { + return uint64(len(msg.hdr) + len(msg.msg) + len(msg.rply) + len(msg.subj)) + }), + ipqLimitByLen[*inMsg](mlen), + ipqLimitBySize[*inMsg](msz), + ), + gets: newIPQueue[*directGetReq](s, qpfx+"direct gets"), + qch: make(chan struct{}), + mqch: make(chan struct{}), + uch: make(chan struct{}, 4), + sch: make(chan struct{}, 1), } // Start our signaling routine to process consumers. @@ -844,7 +940,19 @@ func (mset *stream) isLeader() bool { return true } -// TODO(dlc) - Check to see if we can accept being the leader or we should should step down. +// isLeaderNodeState should NOT be used normally, use isLeader instead. +// Returns whether the node thinks it is the leader, regardless of whether applies are up-to-date yet +// (unlike isLeader, which requires applies to be caught up). 
+// May be used to respond to clients after a leader change, when applying entries from a former leader. +// Lock should be held. +func (mset *stream) isLeaderNodeState() bool { + if mset.isClustered() { + return mset.node.State() == Leader + } + return true +} + +// TODO(dlc) - Check to see if we can accept being the leader or we should step down. func (mset *stream) setLeader(isLeader bool) error { mset.mu.Lock() // If we are here we have a change in leader status. @@ -1173,7 +1281,7 @@ func (jsa *jsAccount) subjectsOverlap(subjects []string, self *stream) bool { // StreamDefaultDuplicatesWindow default duplicates window. const StreamDefaultDuplicatesWindow = 2 * time.Minute -func (s *Server) checkStreamCfg(config *StreamConfig, acc *Account) (StreamConfig, *ApiError) { +func (s *Server) checkStreamCfg(config *StreamConfig, acc *Account, pedantic bool) (StreamConfig, *ApiError) { lim := &s.getOpts().JetStreamLimits if config == nil { @@ -1230,9 +1338,15 @@ func (s *Server) checkStreamCfg(config *StreamConfig, acc *Account) (StreamConfi if cfg.Duplicates == 0 && cfg.Mirror == nil { maxWindow := StreamDefaultDuplicatesWindow if lim.Duplicates > 0 && maxWindow > lim.Duplicates { + if pedantic { + return StreamConfig{}, NewJSPedanticError(fmt.Errorf("pedantic mode: duplicate window limits are higher than current limits")) + } maxWindow = lim.Duplicates } if cfg.MaxAge != 0 && cfg.MaxAge < maxWindow { + if pedantic { + return StreamConfig{}, NewJSPedanticError(fmt.Errorf("pedantic mode: duplicate window cannot be bigger than max age")) + } cfg.Duplicates = cfg.MaxAge } else { cfg.Duplicates = maxWindow @@ -1270,6 +1384,17 @@ func (s *Server) checkStreamCfg(config *StreamConfig, acc *Account) (StreamConfi } } + if cfg.SubjectDeleteMarkerTTL > 0 { + if !cfg.AllowMsgTTL { + return StreamConfig{}, NewJSStreamInvalidConfigError(fmt.Errorf("subject marker delete cannot be set if message TTLs are disabled")) + } + if cfg.SubjectDeleteMarkerTTL < time.Second { + return 
StreamConfig{}, NewJSStreamInvalidConfigError(fmt.Errorf("subject marker delete TTL must be at least 1 second")) + } + } else if cfg.SubjectDeleteMarkerTTL < 0 { + return StreamConfig{}, NewJSStreamInvalidConfigError(fmt.Errorf("subject marker delete TTL must not be negative")) + } + getStream := func(streamName string) (bool, StreamConfig) { var exists bool var cfg StreamConfig @@ -1277,7 +1402,7 @@ func (s *Server) checkStreamCfg(config *StreamConfig, acc *Account) (StreamConfi if js, _ := s.getJetStreamCluster(); js != nil { js.mu.RLock() if sa := js.streamAssignment(acc.Name, streamName); sa != nil { - cfg = *sa.Config + cfg = *sa.Config.clone() exists = true } js.mu.RUnlock() @@ -1312,11 +1437,23 @@ func (s *Server) checkStreamCfg(config *StreamConfig, acc *Account) (StreamConfi if cfg.Mirror.FilterSubject != _EMPTY_ && len(cfg.Mirror.SubjectTransforms) != 0 { return StreamConfig{}, NewJSMirrorMultipleFiltersNotAllowedError() } + if cfg.SubjectDeleteMarkerTTL > 0 { + // Delete markers cannot be configured on a mirror as it would result in new + // tombstones which would use up sequence numbers, diverging from the origin + // stream. + return StreamConfig{}, NewJSStreamInvalidConfigError(fmt.Errorf("subject delete markers forbidden on mirrors")) + } // Check subject filters overlap. 
for outer, tr := range cfg.Mirror.SubjectTransforms { - if !IsValidSubject(tr.Source) { - return StreamConfig{}, NewJSMirrorInvalidSubjectFilterError() + if tr.Source != _EMPTY_ && !IsValidSubject(tr.Source) { + return StreamConfig{}, NewJSMirrorInvalidSubjectFilterError(fmt.Errorf("%w %s", ErrBadSubject, tr.Source)) } + + err := ValidateMapping(tr.Source, tr.Destination) + if err != nil { + return StreamConfig{}, NewJSMirrorInvalidTransformDestinationError(err) + } + for inner, innertr := range cfg.Mirror.SubjectTransforms { if inner != outer && SubjectsCollide(tr.Source, innertr.Source) { return StreamConfig{}, NewJSMirrorOverlappingSubjectFiltersError() @@ -1341,6 +1478,9 @@ func (s *Server) checkStreamCfg(config *StreamConfig, acc *Account) (StreamConfi } // Determine if we are inheriting direct gets. if exists, ocfg := getStream(cfg.Mirror.Name); exists { + if pedantic && cfg.MirrorDirect != ocfg.AllowDirect { + return StreamConfig{}, NewJSPedanticError(fmt.Errorf("origin stream has direct get set, mirror has it disabled")) + } cfg.MirrorDirect = ocfg.AllowDirect } else if js := s.getJetStream(); js != nil && js.isClustered() { // Could not find it here. If we are clustered we can look it up. 
@@ -1348,6 +1488,10 @@ func (s *Server) checkStreamCfg(config *StreamConfig, acc *Account) (StreamConfi if cc := js.cluster; cc != nil { if as := cc.streams[acc.Name]; as != nil { if sa := as[cfg.Mirror.Name]; sa != nil { + if pedantic && cfg.MirrorDirect != sa.Config.AllowDirect { + js.mu.RUnlock() + return StreamConfig{}, NewJSPedanticError(fmt.Errorf("origin stream has direct get set, mirror has it disabled")) + } cfg.MirrorDirect = sa.Config.AllowDirect } } @@ -1358,10 +1502,10 @@ func (s *Server) checkStreamCfg(config *StreamConfig, acc *Account) (StreamConfi if cfg.Mirror.External.DeliverPrefix != _EMPTY_ { deliveryPrefixes = append(deliveryPrefixes, cfg.Mirror.External.DeliverPrefix) } + if cfg.Mirror.External.ApiPrefix != _EMPTY_ { apiPrefixes = append(apiPrefixes, cfg.Mirror.External.ApiPrefix) } - } } @@ -1394,17 +1538,18 @@ func (s *Server) checkStreamCfg(config *StreamConfig, acc *Account) (StreamConfi } for _, tr := range src.SubjectTransforms { - err := ValidateMappingDestination(tr.Destination) + if tr.Source != _EMPTY_ && !IsValidSubject(tr.Source) { + return StreamConfig{}, NewJSSourceInvalidSubjectFilterError(fmt.Errorf("%w %s", ErrBadSubject, tr.Source)) + } + + err := ValidateMapping(tr.Source, tr.Destination) if err != nil { - return StreamConfig{}, NewJSSourceInvalidTransformDestinationError() + return StreamConfig{}, NewJSSourceInvalidTransformDestinationError(err) } } // Check subject filters overlap. for outer, tr := range src.SubjectTransforms { - if !IsValidSubject(tr.Source) { - return StreamConfig{}, NewJSSourceInvalidSubjectFilterError() - } for inner, innertr := range src.SubjectTransforms { if inner != outer && subjectIsSubsetMatch(tr.Source, innertr.Source) { return StreamConfig{}, NewJSSourceOverlappingSubjectFiltersError() @@ -1568,6 +1713,9 @@ func (s *Server) checkStreamCfg(config *StreamConfig, acc *Account) (StreamConfi // Also make sure it does not form a cycle. // Empty same as all. 
if cfg.RePublish.Source == _EMPTY_ { + if pedantic { + return StreamConfig{}, NewJSPedanticError(fmt.Errorf("pedantic mode: republish source can not be empty")) + } cfg.RePublish.Source = fwcs } var formsCycle bool @@ -1585,6 +1733,23 @@ func (s *Server) checkStreamCfg(config *StreamConfig, acc *Account) (StreamConfi } } + // Check the subject transform if any + if cfg.SubjectTransform != nil { + if cfg.SubjectTransform.Source != _EMPTY_ && !IsValidSubject(cfg.SubjectTransform.Source) { + return StreamConfig{}, NewJSStreamTransformInvalidSourceError(fmt.Errorf("%w %s", ErrBadSubject, cfg.SubjectTransform.Source)) + } + + err := ValidateMapping(cfg.SubjectTransform.Source, cfg.SubjectTransform.Destination) + if err != nil { + return StreamConfig{}, NewJSStreamTransformInvalidDestinationError(err) + } + } + + // For now don't allow preferred server in placement. + if cfg.Placement != nil && cfg.Placement.Preferred != _EMPTY_ { + return StreamConfig{}, NewJSStreamInvalidConfigError(fmt.Errorf("preferred server not permitted in placement")) + } + return cfg, nil } @@ -1606,8 +1771,8 @@ func (mset *stream) fileStoreConfig() (FileStoreConfig, error) { } // Do not hold jsAccount or jetStream lock -func (jsa *jsAccount) configUpdateCheck(old, new *StreamConfig, s *Server) (*StreamConfig, error) { - cfg, apiErr := s.checkStreamCfg(new, jsa.acc()) +func (jsa *jsAccount) configUpdateCheck(old, new *StreamConfig, s *Server, pedantic bool) (*StreamConfig, error) { + cfg, apiErr := s.checkStreamCfg(new, jsa.acc(), pedantic) if apiErr != nil { return nil, apiErr } @@ -1664,7 +1829,13 @@ func (jsa *jsAccount) configUpdateCheck(old, new *StreamConfig, s *Server) (*Str } } + // Check on the allowed message TTL status. + if cfg.AllowMsgTTL != old.AllowMsgTTL { + return nil, NewJSStreamInvalidConfigError(fmt.Errorf("message TTL status can not be changed after stream creation")) + } + // Do some adjustments for being sealed. 
+ // Pedantic mode will allow those changes to be made, as they are determinictic and important to get a sealed stream. if cfg.Sealed { cfg.MaxAge = 0 cfg.Discard = DiscardNew @@ -1733,11 +1904,15 @@ func (jsa *jsAccount) configUpdateCheck(old, new *StreamConfig, s *Server) (*Str // Update will allow certain configuration properties of an existing stream to be updated. func (mset *stream) update(config *StreamConfig) error { - return mset.updateWithAdvisory(config, true) + return mset.updateWithAdvisory(config, true, false) +} + +func (mset *stream) updatePedantic(config *StreamConfig, pedantic bool) error { + return mset.updateWithAdvisory(config, true, pedantic) } // Update will allow certain configuration properties of an existing stream to be updated. -func (mset *stream) updateWithAdvisory(config *StreamConfig, sendAdvisory bool) error { +func (mset *stream) updateWithAdvisory(config *StreamConfig, sendAdvisory bool, pedantic bool) error { _, jsa, err := mset.acc.checkForJetStream() if err != nil { return err @@ -1748,7 +1923,7 @@ func (mset *stream) updateWithAdvisory(config *StreamConfig, sendAdvisory bool) s := mset.srv mset.mu.RUnlock() - cfg, err := mset.jsa.configUpdateCheck(&ocfg, config, s) + cfg, err := mset.jsa.configUpdateCheck(&ocfg, config, s, pedantic) if err != nil { return NewJSStreamInvalidConfigError(err, Unless(err)) } @@ -2048,7 +2223,7 @@ func (mset *stream) purge(preq *JSApiStreamPurgeRequest) (purged uint64, err err mset.mu.RUnlock() if preq != nil { - purged, err = mset.store.PurgeEx(preq.Subject, preq.Sequence, preq.Keep) + purged, err = mset.store.PurgeEx(preq.Subject, preq.Sequence, preq.Keep, false /*preq.NoMarkers*/) } else { purged, err = mset.store.Purge() } @@ -2404,10 +2579,10 @@ func (mset *stream) processInboundMirrorMsg(m *inMsg) bool { s.resourcesExceededError() err = ApiErrors[JSInsufficientResourcesErr] } else { - err = node.Propose(encodeStreamMsg(m.subj, _EMPTY_, m.hdr, m.msg, sseq-1, ts)) + err = 
node.Propose(encodeStreamMsg(m.subj, _EMPTY_, m.hdr, m.msg, sseq-1, ts, true)) } } else { - err = mset.processJetStreamMsg(m.subj, _EMPTY_, m.hdr, m.msg, sseq-1, ts) + err = mset.processJetStreamMsg(m.subj, _EMPTY_, m.hdr, m.msg, sseq-1, ts, nil, true) } if err != nil { if strings.Contains(err.Error(), "no space left") { @@ -2486,7 +2661,7 @@ func (mset *stream) skipMsgs(start, end uint64) { // With syncRequest was easy to add bool into request. var entries []*Entry for seq := start; seq <= end; seq++ { - entries = append(entries, newEntry(EntryNormal, encodeStreamMsg(_EMPTY_, _EMPTY_, nil, nil, seq-1, 0))) + entries = append(entries, newEntry(EntryNormal, encodeStreamMsg(_EMPTY_, _EMPTY_, nil, nil, seq-1, 0, false))) // So a single message does not get too big. if len(entries) > 10_000 { node.ProposeMulti(entries) @@ -2574,12 +2749,13 @@ func (mset *stream) setupMirrorConsumer() error { } else { mset.cancelSourceInfo(mset.mirror) mset.mirror.sseq = mset.lseq - - // If we are no longer the leader stop trying. - if !mset.isLeader() { - return nil - } } + + // If we are no longer the leader stop trying. + if !mset.isLeader() { + return nil + } + mirror := mset.mirror // We want to throttle here in terms of how fast we request new consumers, @@ -2758,7 +2934,7 @@ func (mset *stream) setupMirrorConsumer() error { msgs := mirror.msgs sub, err := mset.subscribeInternal(deliverSubject, func(sub *subscription, c *client, _ *Account, subject, reply string, rmsg []byte) { hdr, msg := c.msgParts(copyBytes(rmsg)) // Need to copy. - mset.queueInbound(msgs, subject, reply, hdr, msg, nil) + mset.queueInbound(msgs, subject, reply, hdr, msg, nil, nil) mirror.last.Store(time.Now().UnixNano()) }) if err != nil { @@ -2780,7 +2956,7 @@ func (mset *stream) setupMirrorConsumer() error { // Check to see if delivered is past our last and we have no msgs. This will help the // case when mirroring a stream that has a very high starting sequence number. 
if state.Msgs == 0 && ccr.ConsumerInfo.Delivered.Stream > state.LastSeq { - mset.store.PurgeEx(_EMPTY_, ccr.ConsumerInfo.Delivered.Stream+1, 0) + mset.store.PurgeEx(_EMPTY_, ccr.ConsumerInfo.Delivered.Stream+1, 0, true) mset.lseq = ccr.ConsumerInfo.Delivered.Stream } else { mset.skipMsgs(state.LastSeq+1, ccr.ConsumerInfo.Delivered.Stream) @@ -3131,7 +3307,7 @@ func (mset *stream) trySetupSourceConsumer(iname string, seq uint64, startTime t msgs := mset.smsgs sub, err := mset.subscribeInternal(deliverSubject, func(sub *subscription, c *client, _ *Account, subject, reply string, rmsg []byte) { hdr, msg := c.msgParts(copyBytes(rmsg)) // Need to copy. - mset.queueInbound(msgs, subject, reply, hdr, msg, si) + mset.queueInbound(msgs, subject, reply, hdr, msg, si, nil) si.last.Store(time.Now().UnixNano()) }) if err != nil { @@ -3192,6 +3368,7 @@ func (mset *stream) processAllSourceMsgs() { if !mset.processInboundSourceMsg(im.si, im) { // If we are no longer leader bail. if !mset.IsLeader() { + msgs.recycle(&ims) cleanUp() return } @@ -3258,7 +3435,7 @@ func (mset *stream) sendFlowControlReply(reply string) { func (mset *stream) handleFlowControl(m *inMsg) { // If we are clustered we will send the flow control message through the replication stack. if mset.isClustered() { - mset.node.Propose(encodeStreamMsg(_EMPTY_, m.rply, m.hdr, nil, 0, 0)) + mset.node.Propose(encodeStreamMsg(_EMPTY_, m.rply, m.hdr, nil, 0, 0, false)) } else { mset.outq.sendMsg(m.rply, nil) } @@ -3364,9 +3541,9 @@ func (mset *stream) processInboundSourceMsg(si *sourceInfo, m *inMsg) bool { var err error // If we are clustered we need to propose this message to the underlying raft group. 
if node != nil { - err = mset.processClusteredInboundMsg(m.subj, _EMPTY_, hdr, msg) + err = mset.processClusteredInboundMsg(m.subj, _EMPTY_, hdr, msg, nil, true) } else { - err = mset.processJetStreamMsg(m.subj, _EMPTY_, hdr, msg, 0, 0) + err = mset.processJetStreamMsg(m.subj, _EMPTY_, hdr, msg, 0, 0, nil, true) } if err != nil { @@ -3382,7 +3559,7 @@ func (mset *stream) processInboundSourceMsg(si *sourceInfo, m *inMsg) bool { // Can happen temporarily all the time during normal operations when the sourcing stream // is working queue/interest with a limit and discard new. // TODO - Improve sourcing to WQ with limit and new to use flow control rather than re-creating the consumer. - if errors.Is(err, ErrMaxMsgs) { + if errors.Is(err, ErrMaxMsgs) || errors.Is(err, ErrMaxBytes) { // Do not need to do a full retry that includes finding the last sequence in the stream // for that source. Just re-create starting with the seq we couldn't store instead. mset.mu.Lock() @@ -3936,6 +4113,15 @@ func (mset *stream) setupStore(fsCfg *FileStoreConfig) error { } // This will fire the callback but we do not require the lock since md will be 0 here. mset.store.RegisterStorageUpdates(mset.storeUpdates) + mset.store.RegisterSubjectDeleteMarkerUpdates(func(im *inMsg) { + if mset.IsClustered() { + if mset.IsLeader() { + mset.processClusteredInboundMsg(im.subj, im.rply, im.hdr, im.msg, im.mt, false) + } + } else { + mset.processJetStreamMsg(im.subj, im.rply, im.hdr, im.msg, 0, 0, im.mt, false) + } + }) mset.mu.Unlock() return nil @@ -4098,6 +4284,48 @@ func getExpectedLastSeqPerSubject(hdr []byte) (uint64, bool) { return uint64(parseInt64(bseq)), true } +// Fast lookup of expected subject for the expected stream sequence per subject. +func getExpectedLastSeqPerSubjectForSubject(hdr []byte) string { + return string(getHeader(JSExpectedLastSubjSeqSubj, hdr)) +} + +// Fast lookup of the message TTL from headers: +// - Positive return value: duration in seconds. 
+// - Zero return value: no TTL or parse error. +// - Negative return value: never expires. +func getMessageTTL(hdr []byte) (int64, error) { + ttl := getHeader(JSMessageTTL, hdr) + if len(ttl) == 0 { + return 0, nil + } + return parseMessageTTL(bytesToString(ttl)) +} + +// - Positive return value: duration in seconds. +// - Zero return value: no TTL or parse error. +// - Negative return value: never expires. +func parseMessageTTL(ttl string) (int64, error) { + if strings.ToLower(ttl) == "never" { + return -1, nil + } + dur, err := time.ParseDuration(ttl) + if err == nil { + if dur < time.Second { + return 0, NewJSMessageTTLInvalidError() + } + return int64(dur.Seconds()), nil + } + t := parseInt64(stringToBytes(ttl)) + if t < 0 { + // This probably means a parse failure, hence why + // we have a special case "never" for returning -1. + // Otherwise we can't know if it's a genuine TTL + // that says never expire or if it's a parse error. + return 0, NewJSMessageTTLInvalidError() + } + return t, nil +} + // Signal if we are clustered. Will acquire rlock. 
func (mset *stream) IsClustered() bool { mset.mu.RLock() @@ -4117,6 +4345,7 @@ type inMsg struct { hdr []byte msg []byte si *sourceInfo + mt *msgTrace } var inMsgPool = sync.Pool{ @@ -4126,14 +4355,22 @@ var inMsgPool = sync.Pool{ } func (im *inMsg) returnToPool() { - im.subj, im.rply, im.hdr, im.msg, im.si = _EMPTY_, _EMPTY_, nil, nil, nil + im.subj, im.rply, im.hdr, im.msg, im.si, im.mt = _EMPTY_, _EMPTY_, nil, nil, nil, nil inMsgPool.Put(im) } -func (mset *stream) queueInbound(ib *ipQueue[*inMsg], subj, rply string, hdr, msg []byte, si *sourceInfo) { +func (mset *stream) queueInbound(ib *ipQueue[*inMsg], subj, rply string, hdr, msg []byte, si *sourceInfo, mt *msgTrace) { im := inMsgPool.Get().(*inMsg) - im.subj, im.rply, im.hdr, im.msg, im.si = subj, rply, hdr, msg, si - ib.push(im) + im.subj, im.rply, im.hdr, im.msg, im.si, im.mt = subj, rply, hdr, msg, si, mt + if _, err := ib.push(im); err != nil { + im.returnToPool() + mset.srv.RateLimitWarnf("Dropping messages due to excessive stream ingest rate on '%s' > '%s': %s", mset.acc.Name, mset.name(), err) + if rply != _EMPTY_ { + hdr := []byte("NATS/1.0 429 Too Many Requests\r\n\r\n") + b, _ := json.Marshal(&JSPubAckResponse{PubAck: &PubAck{Stream: mset.cfg.Name}, Error: NewJSStreamTooManyRequestsError()}) + mset.outq.send(newJSPubMsg(rply, _EMPTY_, _EMPTY_, hdr, b, nil, 0)) + } + } } var dgPool = sync.Pool{ @@ -4168,18 +4405,21 @@ func (mset *stream) processDirectGetRequest(_ *subscription, c *client, _ *Accou return } // Check if nothing set. - if req.Seq == 0 && req.LastFor == _EMPTY_ && req.NextFor == _EMPTY_ { + if req.Seq == 0 && req.LastFor == _EMPTY_ && req.NextFor == _EMPTY_ && len(req.MultiLastFor) == 0 && req.StartTime == nil { hdr := []byte("NATS/1.0 408 Empty Request\r\n\r\n") mset.outq.send(newJSPubMsg(reply, _EMPTY_, _EMPTY_, hdr, nil, nil, 0)) return } - // Check that we do not have both options set. 
- if req.Seq > 0 && req.LastFor != _EMPTY_ { - hdr := []byte("NATS/1.0 408 Bad Request\r\n\r\n") - mset.outq.send(newJSPubMsg(reply, _EMPTY_, _EMPTY_, hdr, nil, nil, 0)) - return - } - if req.LastFor != _EMPTY_ && req.NextFor != _EMPTY_ { + // Check we don't have conflicting options set. + // We do not allow batch mode for lastFor requests. + if (req.Seq > 0 && req.LastFor != _EMPTY_) || + (req.Seq > 0 && req.StartTime != nil) || + (req.StartTime != nil && req.LastFor != _EMPTY_) || + (req.LastFor != _EMPTY_ && req.NextFor != _EMPTY_) || + (req.LastFor != _EMPTY_ && req.Batch > 0) || + (req.LastFor != _EMPTY_ && len(req.MultiLastFor) > 0) || + (req.NextFor != _EMPTY_ && len(req.MultiLastFor) > 0) || + (req.UpToSeq > 0 && req.UpToTime != nil) { hdr := []byte("NATS/1.0 408 Bad Request\r\n\r\n") mset.outq.send(newJSPubMsg(reply, _EMPTY_, _EMPTY_, hdr, nil, nil, 0)) return @@ -4238,50 +4478,262 @@ func (mset *stream) processDirectGetLastBySubjectRequest(_ *subscription, c *cli } } -// Do actual work on a direct msg request. -// This could be called in a Go routine if we are inline for a non-client connection. -func (mset *stream) getDirectRequest(req *JSApiMsgGetRequest, reply string) { - var svp StoreMsg - var sm *StoreMsg - var err error +// For direct get batch and multi requests. +const ( + dg = "NATS/1.0\r\nNats-Stream: %s\r\nNats-Subject: %s\r\nNats-Sequence: %d\r\nNats-Time-Stamp: %s\r\n\r\n" + dgb = "NATS/1.0\r\nNats-Stream: %s\r\nNats-Subject: %s\r\nNats-Sequence: %d\r\nNats-Time-Stamp: %s\r\nNats-Num-Pending: %d\r\nNats-Last-Sequence: %d\r\n\r\n" + eob = "NATS/1.0 204 EOB\r\nNats-Num-Pending: %d\r\nNats-Last-Sequence: %d\r\n\r\n" + eobm = "NATS/1.0 204 EOB\r\nNats-Num-Pending: %d\r\nNats-Last-Sequence: %d\r\nNats-UpTo-Sequence: %d\r\n\r\n" +) +// Handle a multi request. +func (mset *stream) getDirectMulti(req *JSApiMsgGetRequest, reply string) { + // TODO(dlc) - Make configurable? 
+ const maxAllowedResponses = 1024 + + // We hold the lock here to try to avoid changes out from underneath of us. mset.mu.RLock() - store, name := mset.store, mset.cfg.Name - mset.mu.RUnlock() + defer mset.mu.RUnlock() + // Grab store and name. + store, name, s := mset.store, mset.cfg.Name, mset.srv - if req.Seq > 0 && req.NextFor == _EMPTY_ { - sm, err = store.LoadMsg(req.Seq, &svp) - } else if req.NextFor != _EMPTY_ { - sm, _, err = store.LoadNextMsg(req.NextFor, subjectHasWildcard(req.NextFor), req.Seq, &svp) - } else { - sm, err = store.LoadLastMsg(req.LastFor, &svp) + // Grab MaxBytes + mb := req.MaxBytes + if mb == 0 && s != nil { + // Fill in with the server's MaxPending. + mb = int(s.opts.MaxPending) } + + upToSeq := req.UpToSeq + // If we have UpToTime set get the proper sequence. + if req.UpToTime != nil { + upToSeq = store.GetSeqFromTime((*req.UpToTime).UTC()) + // We need to back off one since this is used to determine start sequence normally, + // were as here we want it to be the ceiling. + upToSeq-- + } + // If not set, set to the last sequence and remember that for EOB. 
+ if upToSeq == 0 { + var state StreamState + mset.store.FastState(&state) + upToSeq = state.LastSeq + } + + seqs, err := store.MultiLastSeqs(req.MultiLastFor, upToSeq, maxAllowedResponses) if err != nil { - hdr := []byte("NATS/1.0 404 Message Not Found\r\n\r\n") + var hdr []byte + if err == ErrTooManyResults { + hdr = []byte("NATS/1.0 413 Too Many Results\r\n\r\n") + } else { + hdr = []byte(fmt.Sprintf("NATS/1.0 500 %v\r\n\r\n", err)) + } + mset.outq.send(newJSPubMsg(reply, _EMPTY_, _EMPTY_, hdr, nil, nil, 0)) + return + } + if len(seqs) == 0 { + hdr := []byte("NATS/1.0 404 No Results\r\n\r\n") mset.outq.send(newJSPubMsg(reply, _EMPTY_, _EMPTY_, hdr, nil, nil, 0)) return } - hdr := sm.hdr - ts := time.Unix(0, sm.ts).UTC() + np, lseq, sentBytes, sent := uint64(len(seqs)-1), uint64(0), 0, 0 + for _, seq := range seqs { + if seq < req.Seq { + if np > 0 { + np-- + } + continue + } + var svp StoreMsg + sm, err := store.LoadMsg(seq, &svp) + if err != nil { + hdr := []byte("NATS/1.0 404 Message Not Found\r\n\r\n") + mset.outq.send(newJSPubMsg(reply, _EMPTY_, _EMPTY_, hdr, nil, nil, 0)) + return + } - if len(hdr) == 0 { - const ht = "NATS/1.0\r\nNats-Stream: %s\r\nNats-Subject: %s\r\nNats-Sequence: %d\r\nNats-Time-Stamp: %s\r\n\r\n" - hdr = fmt.Appendf(nil, ht, name, sm.subj, sm.seq, ts.Format(time.RFC3339Nano)) - } else { - hdr = copyBytes(hdr) - hdr = genHeader(hdr, JSStream, name) - hdr = genHeader(hdr, JSSubject, sm.subj) - hdr = genHeader(hdr, JSSequence, strconv.FormatUint(sm.seq, 10)) - hdr = genHeader(hdr, JSTimeStamp, ts.Format(time.RFC3339Nano)) + hdr := sm.hdr + ts := time.Unix(0, sm.ts).UTC() + + if len(hdr) == 0 { + hdr = fmt.Appendf(nil, dgb, name, sm.subj, sm.seq, ts.Format(time.RFC3339Nano), np, lseq) + } else { + hdr = copyBytes(hdr) + hdr = genHeader(hdr, JSStream, name) + hdr = genHeader(hdr, JSSubject, sm.subj) + hdr = genHeader(hdr, JSSequence, strconv.FormatUint(sm.seq, 10)) + hdr = genHeader(hdr, JSTimeStamp, ts.Format(time.RFC3339Nano)) + hdr = 
genHeader(hdr, JSNumPending, strconv.FormatUint(np, 10)) + hdr = genHeader(hdr, JSLastSequence, strconv.FormatUint(lseq, 10)) + } + // Decrement num pending. This is optimization and we do not continue to look it up for these operations. + if np > 0 { + np-- + } + // Track our lseq + lseq = sm.seq + // Send out our message. + mset.outq.send(newJSPubMsg(reply, _EMPTY_, _EMPTY_, hdr, sm.msg, nil, 0)) + // Check if we have exceeded max bytes. + sentBytes += len(sm.subj) + len(sm.hdr) + len(sm.msg) + if sentBytes >= mb { + break + } + sent++ + if req.Batch > 0 && sent >= req.Batch { + break + } + } + + // Send out EOB + hdr := fmt.Appendf(nil, eobm, np, lseq, upToSeq) + mset.outq.send(newJSPubMsg(reply, _EMPTY_, _EMPTY_, hdr, nil, nil, 0)) +} + +// Do actual work on a direct msg request. +// This could be called in a Go routine if we are inline for a non-client connection. +func (mset *stream) getDirectRequest(req *JSApiMsgGetRequest, reply string) { + // Handle multi in separate function. + if len(req.MultiLastFor) > 0 { + mset.getDirectMulti(req, reply) + return + } + + mset.mu.RLock() + store, name, s := mset.store, mset.cfg.Name, mset.srv + mset.mu.RUnlock() + + var seq uint64 + // Lookup start seq if AsOfTime is set. + if req.StartTime != nil { + seq = store.GetSeqFromTime(*req.StartTime) + } else { + seq = req.Seq + } + + wc := subjectHasWildcard(req.NextFor) + // For tracking num pending if we are batch. + var np, lseq, validThrough uint64 + var isBatchRequest bool + batch := req.Batch + if batch == 0 { + batch = 1 + } else { + // This is a batch request, capture initial numPending. + isBatchRequest = true + np, validThrough = store.NumPending(seq, req.NextFor, false) + } + + // Grab MaxBytes + mb := req.MaxBytes + if mb == 0 && s != nil { + // Fill in with the server's MaxPending. + mb = int(s.opts.MaxPending) + } + // Track what we have sent. + var sentBytes int + + // Loop over batch, which defaults to 1. 
+ for i := 0; i < batch; i++ { + var ( + svp StoreMsg + sm *StoreMsg + err error + ) + if seq > 0 && req.NextFor == _EMPTY_ { + // Only do direct lookup for first in a batch. + if i == 0 { + sm, err = store.LoadMsg(seq, &svp) + } else { + // We want to use load next with fwcs to step over deleted msgs. + sm, seq, err = store.LoadNextMsg(fwcs, true, seq, &svp) + } + // Bump for next loop if applicable. + seq++ + } else if req.NextFor != _EMPTY_ { + sm, seq, err = store.LoadNextMsg(req.NextFor, wc, seq, &svp) + seq++ + } else { + // Batch is not applicable here, this is checked before we get here. + sm, err = store.LoadLastMsg(req.LastFor, &svp) + } + if err != nil { + // For batches, if we stop early we want to do EOB logic below. + if batch > 1 && i > 0 { + break + } + hdr := []byte("NATS/1.0 404 Message Not Found\r\n\r\n") + mset.outq.send(newJSPubMsg(reply, _EMPTY_, _EMPTY_, hdr, nil, nil, 0)) + return + } + + hdr := sm.hdr + ts := time.Unix(0, sm.ts).UTC() + + if isBatchRequest { + if len(hdr) == 0 { + hdr = fmt.Appendf(nil, dgb, name, sm.subj, sm.seq, ts.Format(time.RFC3339Nano), np, lseq) + } else { + hdr = copyBytes(hdr) + hdr = genHeader(hdr, JSStream, name) + hdr = genHeader(hdr, JSSubject, sm.subj) + hdr = genHeader(hdr, JSSequence, strconv.FormatUint(sm.seq, 10)) + hdr = genHeader(hdr, JSTimeStamp, ts.Format(time.RFC3339Nano)) + hdr = genHeader(hdr, JSNumPending, strconv.FormatUint(np, 10)) + hdr = genHeader(hdr, JSLastSequence, strconv.FormatUint(lseq, 10)) + } + // Decrement num pending. This is optimization and we do not continue to look it up for these operations. 
+ np-- + } else { + if len(hdr) == 0 { + hdr = fmt.Appendf(nil, dg, name, sm.subj, sm.seq, ts.Format(time.RFC3339Nano)) + } else { + hdr = copyBytes(hdr) + hdr = genHeader(hdr, JSStream, name) + hdr = genHeader(hdr, JSSubject, sm.subj) + hdr = genHeader(hdr, JSSequence, strconv.FormatUint(sm.seq, 10)) + hdr = genHeader(hdr, JSTimeStamp, ts.Format(time.RFC3339Nano)) + } + } + // Track our lseq + lseq = sm.seq + // Send out our message. + mset.outq.send(newJSPubMsg(reply, _EMPTY_, _EMPTY_, hdr, sm.msg, nil, 0)) + // Check if we have exceeded max bytes. + sentBytes += len(sm.subj) + len(sm.hdr) + len(sm.msg) + if sentBytes >= mb { + break + } + } + + // If batch was requested send EOB. + if isBatchRequest { + // Update if the stream's lasts sequence has moved past our validThrough. + if mset.lastSeq() > validThrough { + np, _ = store.NumPending(seq, req.NextFor, false) + } + hdr := fmt.Appendf(nil, eob, np, lseq) + mset.outq.send(newJSPubMsg(reply, _EMPTY_, _EMPTY_, hdr, nil, nil, 0)) } - mset.outq.send(newJSPubMsg(reply, _EMPTY_, _EMPTY_, hdr, sm.msg, nil, 0)) } // processInboundJetStreamMsg handles processing messages bound for a stream. func (mset *stream) processInboundJetStreamMsg(_ *subscription, c *client, _ *Account, subject, reply string, rmsg []byte) { hdr, msg := c.msgParts(copyBytes(rmsg)) // Need to copy. - mset.queueInbound(mset.msgs, subject, reply, hdr, msg, nil) + if mt, traceOnly := c.isMsgTraceEnabled(); mt != nil { + // If message is delivered, we need to disable the message trace headers + // to prevent a trace event to be generated when a stored message + // is delivered to a consumer and routed. + if !traceOnly { + disableTraceHeaders(c, hdr) + } + // This will add the jetstream event while in the client read loop. + // Since the event will be updated in a different go routine, the + // tracing object will have a separate reference to the JS trace + // object. 
+ mt.addJetStreamEvent(mset.name()) + } + mset.queueInbound(mset.msgs, subject, reply, hdr, msg, nil, c.pa.trace) } var ( @@ -4290,10 +4742,19 @@ var ( errStreamClosed = errors.New("stream closed") errInvalidMsgHandler = errors.New("undefined message handler") errStreamMismatch = errors.New("expected stream does not match") + errMsgTTLDisabled = errors.New("message TTL disabled") ) // processJetStreamMsg is where we try to actually process the stream msg. -func (mset *stream) processJetStreamMsg(subject, reply string, hdr, msg []byte, lseq uint64, ts int64) error { +func (mset *stream) processJetStreamMsg(subject, reply string, hdr, msg []byte, lseq uint64, ts int64, mt *msgTrace, sourced bool) (retErr error) { + if mt != nil { + // Only the leader/standalone will have mt!=nil. On exit, send the + // message trace event. + defer func() { + mt.sendEventFromJetStream(retErr) + }() + } + if mset.closed.Load() { return errStreamClosed } @@ -4301,11 +4762,17 @@ func (mset *stream) processJetStreamMsg(subject, reply string, hdr, msg []byte, mset.mu.Lock() s, store := mset.srv, mset.store + traceOnly := mt.traceOnly() bumpCLFS := func() { + // Do not bump if tracing and not doing message delivery. + if traceOnly { + return + } mset.clMu.Lock() mset.clfs++ mset.clMu.Unlock() } + // Apply the input subject transform if any if mset.itr != nil { ts, err := mset.itr.Match(subject) @@ -4326,7 +4793,7 @@ func (mset *stream) processJetStreamMsg(subject, reply string, hdr, msg []byte, numConsumers := len(mset.consumers) interestRetention := mset.cfg.Retention == InterestPolicy // Snapshot if we are the leader and if we can respond. 
- isLeader, isSealed := mset.isLeader(), mset.cfg.Sealed + isLeader, isSealed := mset.isLeaderNodeState(), mset.cfg.Sealed canRespond := doAck && len(reply) > 0 && isLeader var resp = &JSPubAckResponse{} @@ -4397,7 +4864,9 @@ func (mset *stream) processJetStreamMsg(subject, reply string, hdr, msg []byte, outq := mset.outq // Certain checks have already been performed if in clustered mode, so only check if not. - if !isClustered { + // Note, for cluster mode but with message tracing (without message delivery), we need + // to do this check here since it was not done in processClusteredInboundMsg(). + if !isClustered || traceOnly { // Expected stream. if sname := getExpectedStream(hdr); sname != _EMPTY_ && sname != name { mset.mu.Unlock() @@ -4412,11 +4881,26 @@ func (mset *stream) processJetStreamMsg(subject, reply string, hdr, msg []byte, } } + // TTL'd messages are rejected entirely if TTLs are not enabled on the stream. + // Shouldn't happen in clustered mode since we should have already caught this + // in processClusteredInboundMsg, but needed here for non-clustered etc. + if ttl, _ := getMessageTTL(hdr); !sourced && ttl != 0 && !mset.cfg.AllowMsgTTL { + mset.mu.Unlock() + bumpCLFS() + if canRespond { + resp.PubAck = &PubAck{Stream: name} + resp.Error = NewJSMessageTTLDisabledError() + b, _ := json.Marshal(resp) + outq.sendMsg(reply, b) + } + return errMsgTTLDisabled + } + // Dedupe detection. This is done at the cluster level for dedupe detectiom above the // lower layers. But we still need to pull out the msgId. if msgId = getMsgId(hdr); msgId != _EMPTY_ { // Do real check only if not clustered or traceOnly flag is set. - if !isClustered { + if !isClustered || traceOnly { if dde := mset.checkMsgId(msgId); dde != nil { mset.mu.Unlock() bumpCLFS() @@ -4432,10 +4916,16 @@ func (mset *stream) processJetStreamMsg(subject, reply string, hdr, msg []byte, // Expected last sequence per subject. 
if seq, exists := getExpectedLastSeqPerSubject(hdr); exists { + // Allow override of the subject used for the check. + seqSubj := subject + if optSubj := getExpectedLastSeqPerSubjectForSubject(hdr); optSubj != _EMPTY_ { + seqSubj = optSubj + } + // TODO(dlc) - We could make a new store func that does this all in one. var smv StoreMsg var fseq uint64 - sm, err := store.LoadLastMsg(subject, &smv) + sm, err := store.LoadLastMsg(seqSubj, &smv) if sm != nil { fseq = sm.seq } @@ -4516,7 +5006,14 @@ func (mset *stream) processJetStreamMsg(subject, reply string, hdr, msg []byte, default: mset.mu.Unlock() bumpCLFS() - return fmt.Errorf("rollup value invalid: %q", rollup) + err := fmt.Errorf("rollup value invalid: %q", rollup) + if canRespond { + resp.PubAck = &PubAck{Stream: name} + resp.Error = NewJSStreamRollupFailedError(err) + b, _ := json.Marshal(resp) + outq.sendMsg(reply, b) + } + return err } } } @@ -4585,6 +5082,12 @@ func (mset *stream) processJetStreamMsg(subject, reply string, hdr, msg []byte, ts = time.Now().UnixNano() } + mt.updateJetStreamEvent(subject, noInterest) + if traceOnly { + mset.mu.Unlock() + return nil + } + // Skip msg here. if noInterest { mset.lseq = store.SkipMsg() @@ -4648,19 +5151,30 @@ func (mset *stream) processJetStreamMsg(subject, reply string, hdr, msg []byte, } } + // Find the message TTL if any. + ttl, err := getMessageTTL(hdr) + if err != nil { + if canRespond { + resp.PubAck = &PubAck{Stream: name} + resp.Error = NewJSMessageTTLInvalidError() + response, _ = json.Marshal(resp) + mset.outq.send(newJSPubMsg(reply, _EMPTY_, _EMPTY_, nil, response, nil, 0)) + } + mset.mu.Unlock() + return err + } + // Store actual msg. if lseq == 0 && ts == 0 { - seq, ts, err = store.StoreMsg(subject, hdr, msg) + seq, ts, err = store.StoreMsg(subject, hdr, msg, ttl) } else { // Make sure to take into account any message assignments that we had to skip (clfs). seq = lseq + 1 - clfs - // Check for preAcks and the need to skip vs store. 
+ // Check for preAcks and the need to clear it. if mset.hasAllPreAcks(seq, subject) { mset.clearAllPreAcks(seq) - store.SkipMsg() - } else { - err = store.StoreRawMsg(subject, hdr, msg, seq, ts) } + err = store.StoreRawMsg(subject, hdr, msg, seq, ts, ttl) } if err != nil { @@ -5068,9 +5582,9 @@ func (mset *stream) internalLoop() { for _, im := range ims { // If we are clustered we need to propose this message to the underlying raft group. if isClustered { - mset.processClusteredInboundMsg(im.subj, im.rply, im.hdr, im.msg) + mset.processClusteredInboundMsg(im.subj, im.rply, im.hdr, im.msg, im.mt, false) } else { - mset.processJetStreamMsg(im.subj, im.rply, im.hdr, im.msg, 0, 0) + mset.processJetStreamMsg(im.subj, im.rply, im.hdr, im.msg, 0, 0, im.mt, false) } im.returnToPool() } @@ -5108,9 +5622,7 @@ func (mset *stream) resetAndWaitOnConsumers() { for _, o := range consumers { if node := o.raftNode(); node != nil { - if o.IsLeader() { - node.StepDown() - } + node.StepDown() node.Delete() } if o.isMonitorRunning() { @@ -5338,6 +5850,17 @@ func (mset *stream) getPublicConsumers() []*consumer { return obs } +// 2 minutes plus up to 30s jitter. +const ( + defaultCheckInterestStateT = 2 * time.Minute + defaultCheckInterestStateJ = 30 +) + +var ( + checkInterestStateT = defaultCheckInterestStateT // Interval + checkInterestStateJ = defaultCheckInterestStateJ // Jitter (secs) +) + // Will check for interest retention and make sure messages // that have been acked are processed and removed. // This will check the ack floors of all consumers, and adjust our first sequence accordingly. @@ -5684,6 +6207,7 @@ func (mset *stream) clearPreAck(o *consumer, seq uint64) { // ackMsg is called into from a consumer when we have a WorkQueue or Interest Retention Policy. // Returns whether the message at seq was removed as a result of the ACK. 
+// (Or should be removed in the case of clustered streams, since it requires a message delete proposal) func (mset *stream) ackMsg(o *consumer, seq uint64) bool { if seq == 0 { return false @@ -5727,18 +6251,34 @@ func (mset *stream) ackMsg(o *consumer, seq uint64) bool { case InterestPolicy: shouldRemove = mset.noInterest(seq, o) } - mset.mu.Unlock() // If nothing else to do. if !shouldRemove { + mset.mu.Unlock() return false } - // If we are here we should attempt to remove. - if _, err := store.RemoveMsg(seq); err == ErrStoreEOF { - // This should not happen, but being pedantic. - mset.registerPreAckLock(o, seq) + if !mset.isClustered() { + mset.mu.Unlock() + // If we are here we should attempt to remove. + if _, err := store.RemoveMsg(seq); err == ErrStoreEOF { + // This should not happen, but being pedantic. + mset.registerPreAckLock(o, seq) + } + return true } + + // Only propose message deletion to the stream if we're consumer leader, otherwise all followers would also propose. + // We must be the consumer leader, since we know for sure we've stored the message and don't register as pre-ack. + if o != nil && !o.IsLeader() { + mset.mu.Unlock() + // Must still mark as removal if follower. If we become leader later, we must be able to retry the proposal. 
+ return true + } + + md := streamMsgDelete{Seq: seq, NoErase: true, Stream: mset.cfg.Name} + mset.node.ForwardProposal(encodeMsgDelete(&md)) + mset.mu.Unlock() return true } @@ -5764,7 +6304,7 @@ func (a *Account) RestoreStream(ncfg *StreamConfig, r io.Reader) (*stream, error return nil, err } - cfg, apiErr := s.checkStreamCfg(ncfg, a) + cfg, apiErr := s.checkStreamCfg(ncfg, a, false) if apiErr != nil { return nil, apiErr } diff --git a/vendor/github.com/nats-io/nats-server/v2/server/subject_transform.go b/vendor/github.com/nats-io/nats-server/v2/server/subject_transform.go index 41e42722d1..42cc17e067 100644 --- a/vendor/github.com/nats-io/nats-server/v2/server/subject_transform.go +++ b/vendor/github.com/nats-io/nats-server/v2/server/subject_transform.go @@ -459,7 +459,16 @@ func (tr *subjectTransform) TransformTokenizedSubject(tokens []string) string { } b.WriteString(tr.getHashPartition(keyForHashing, int(tr.dtokmfintargs[i]))) case Wildcard: // simple substitution - b.WriteString(tokens[tr.dtokmftokindexesargs[i][0]]) + switch { + case len(tr.dtokmftokindexesargs) < i: + break + case len(tr.dtokmftokindexesargs[i]) < 1: + break + case len(tokens) <= tr.dtokmftokindexesargs[i][0]: + break + default: + b.WriteString(tokens[tr.dtokmftokindexesargs[i][0]]) + } case SplitFromLeft: sourceToken := tokens[tr.dtokmftokindexesargs[i][0]] sourceTokenLen := len(sourceToken) diff --git a/vendor/github.com/nats-io/nats-server/v2/server/sublist.go b/vendor/github.com/nats-io/nats-server/v2/server/sublist.go index 9f79dfe18b..004150aa9d 100644 --- a/vendor/github.com/nats-io/nats-server/v2/server/sublist.go +++ b/vendor/github.com/nats-io/nats-server/v2/server/sublist.go @@ -1271,12 +1271,12 @@ func isValidLiteralSubject(tokens []string) bool { return true } -// ValidateMappingDestination returns nil error if the subject is a valid subject mapping destination subject -func ValidateMappingDestination(subject string) error { - if subject == _EMPTY_ { +// ValidateMapping 
returns nil error if the subject is a valid subject mapping destination subject +func ValidateMapping(src string, dest string) error { + if dest == _EMPTY_ { return nil } - subjectTokens := strings.Split(subject, tsep) + subjectTokens := strings.Split(dest, tsep) sfwc := false for _, t := range subjectTokens { length := len(t) @@ -1284,6 +1284,7 @@ func ValidateMappingDestination(subject string) error { return &mappingDestinationErr{t, ErrInvalidMappingDestinationSubject} } + // if it looks like it contains a mapping function, it should be a valid mapping function if length > 4 && t[0] == '{' && t[1] == '{' && t[length-2] == '}' && t[length-1] == '}' { if !partitionMappingFunctionRegEx.MatchString(t) && !wildcardMappingFunctionRegEx.MatchString(t) && @@ -1304,7 +1305,10 @@ func ValidateMappingDestination(subject string) error { return ErrInvalidMappingDestinationSubject } } - return nil + + // Finally, verify that the transform can actually be created from the source and destination + _, err := NewSubjectTransform(src, dest) + return err } // Will check tokens and report back if the have any partial or full wildcards. diff --git a/vendor/github.com/nats-io/nats-server/v2/server/sysmem/mem_bsd.go b/vendor/github.com/nats-io/nats-server/v2/server/sysmem/mem_bsd.go index 341b31a72d..6cc63b1d1a 100644 --- a/vendor/github.com/nats-io/nats-server/v2/server/sysmem/mem_bsd.go +++ b/vendor/github.com/nats-io/nats-server/v2/server/sysmem/mem_bsd.go @@ -12,7 +12,6 @@ // limitations under the License. 
//go:build freebsd || openbsd || dragonfly || netbsd -// +build freebsd openbsd dragonfly netbsd package sysmem diff --git a/vendor/github.com/nats-io/nats-server/v2/server/sysmem/mem_darwin.go b/vendor/github.com/nats-io/nats-server/v2/server/sysmem/mem_darwin.go index ae078443d8..f8e049b9a8 100644 --- a/vendor/github.com/nats-io/nats-server/v2/server/sysmem/mem_darwin.go +++ b/vendor/github.com/nats-io/nats-server/v2/server/sysmem/mem_darwin.go @@ -12,7 +12,6 @@ // limitations under the License. //go:build darwin -// +build darwin package sysmem diff --git a/vendor/github.com/nats-io/nats-server/v2/server/sysmem/mem_linux.go b/vendor/github.com/nats-io/nats-server/v2/server/sysmem/mem_linux.go index 6bfa73a0be..26e0bd1525 100644 --- a/vendor/github.com/nats-io/nats-server/v2/server/sysmem/mem_linux.go +++ b/vendor/github.com/nats-io/nats-server/v2/server/sysmem/mem_linux.go @@ -12,7 +12,6 @@ // limitations under the License. //go:build linux -// +build linux package sysmem diff --git a/vendor/github.com/nats-io/nats-server/v2/server/sysmem/mem_wasm.go b/vendor/github.com/nats-io/nats-server/v2/server/sysmem/mem_wasm.go index 806360640e..bbc43af7fe 100644 --- a/vendor/github.com/nats-io/nats-server/v2/server/sysmem/mem_wasm.go +++ b/vendor/github.com/nats-io/nats-server/v2/server/sysmem/mem_wasm.go @@ -12,7 +12,6 @@ // limitations under the License. //go:build wasm -// +build wasm package sysmem diff --git a/vendor/github.com/nats-io/nats-server/v2/server/sysmem/mem_windows.go b/vendor/github.com/nats-io/nats-server/v2/server/sysmem/mem_windows.go index bf02133e8f..3f070887d2 100644 --- a/vendor/github.com/nats-io/nats-server/v2/server/sysmem/mem_windows.go +++ b/vendor/github.com/nats-io/nats-server/v2/server/sysmem/mem_windows.go @@ -12,7 +12,6 @@ // limitations under the License. 
//go:build windows -// +build windows package sysmem diff --git a/vendor/github.com/nats-io/nats-server/v2/server/sysmem/mem_zos.go b/vendor/github.com/nats-io/nats-server/v2/server/sysmem/mem_zos.go index e798a80bfd..cc57620e85 100644 --- a/vendor/github.com/nats-io/nats-server/v2/server/sysmem/mem_zos.go +++ b/vendor/github.com/nats-io/nats-server/v2/server/sysmem/mem_zos.go @@ -12,7 +12,6 @@ // limitations under the License. //go:build zos -// +build zos package sysmem diff --git a/vendor/github.com/nats-io/nats-server/v2/server/sysmem/sysctl.go b/vendor/github.com/nats-io/nats-server/v2/server/sysmem/sysctl.go index 82946f0205..550961ae10 100644 --- a/vendor/github.com/nats-io/nats-server/v2/server/sysmem/sysctl.go +++ b/vendor/github.com/nats-io/nats-server/v2/server/sysmem/sysctl.go @@ -12,7 +12,6 @@ // limitations under the License. //go:build darwin || freebsd || openbsd || dragonfly || netbsd -// +build darwin freebsd openbsd dragonfly netbsd package sysmem diff --git a/vendor/github.com/nats-io/nats-server/v2/server/thw/thw.go b/vendor/github.com/nats-io/nats-server/v2/server/thw/thw.go new file mode 100644 index 0000000000..5a54470104 --- /dev/null +++ b/vendor/github.com/nats-io/nats-server/v2/server/thw/thw.go @@ -0,0 +1,257 @@ +// Copyright 2024 The NATS Authors +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package thw + +import ( + "encoding/binary" + "errors" + "io" + "math" + "time" +) + +// Error for when we can not locate a task for removal or updates. +var ErrTaskNotFound = errors.New("thw: task not found") + +// Error for when we try to decode a binary-encoded THW with an unknown version number. +var ErrInvalidVersion = errors.New("thw: encoded version not known") + +const ( + tickDuration = int64(time.Second) // Tick duration in nanoseconds. + wheelBits = 12 // 2^12 = 4096 slots. + wheelSize = 1 << wheelBits // Number of slots in the wheel. + wheelMask = wheelSize - 1 // Mask for calculating position. + headerLen = 17 // 1 byte magic + 2x uint64s +) + +// slot represents a single slot in the wheel. +type slot struct { + entries map[uint64]int64 // Map of sequence to expires. + lowest int64 // Lowest expiration time in this slot. +} + +// HashWheel represents the timing wheel. +type HashWheel struct { + wheel []*slot // Array of slots. + lowest int64 // Track the lowest expiration time across all slots. + count uint64 // How many entries are present? +} + +// NewHashWheel initializes a new HashWheel. +func NewHashWheel() *HashWheel { + return &HashWheel{ + wheel: make([]*slot, wheelSize), + lowest: math.MaxInt64, + } +} + +// getPosition calculates the slot position for a given expiration time. +func (hw *HashWheel) getPosition(expires int64) int64 { + return (expires / tickDuration) & wheelMask +} + +// updateLowestExpires finds the new lowest expiration time across all slots. +func (hw *HashWheel) updateLowestExpires() { + lowest := int64(math.MaxInt64) + for _, s := range hw.wheel { + if s != nil && s.lowest < lowest { + lowest = s.lowest + } + } + hw.lowest = lowest +} + +// newSlot creates a new slot. +func newSlot() *slot { + return &slot{ + entries: make(map[uint64]int64), + lowest: math.MaxInt64, + } +} + +// Add schedules a new timer task. 
+func (hw *HashWheel) Add(seq uint64, expires int64) error { + pos := hw.getPosition(expires) + // Initialize the slot lazily. + if hw.wheel[pos] == nil { + hw.wheel[pos] = newSlot() + } + if _, ok := hw.wheel[pos].entries[seq]; !ok { + hw.count++ + } + hw.wheel[pos].entries[seq] = expires + + // Update slot's lowest expiration if this is earlier. + if expires < hw.wheel[pos].lowest { + hw.wheel[pos].lowest = expires + // Update global lowest if this is now the earliest. + if expires < hw.lowest { + hw.lowest = expires + } + } + + return nil +} + +// Remove removes a timer task. +func (hw *HashWheel) Remove(seq uint64, expires int64) error { + pos := hw.getPosition(expires) + s := hw.wheel[pos] + if s == nil { + return ErrTaskNotFound + } + if _, exists := s.entries[seq]; !exists { + return ErrTaskNotFound + } + delete(s.entries, seq) + hw.count-- + + // If the slot is empty, we can set it to nil to free memory. + if len(s.entries) == 0 { + hw.wheel[pos] = nil + } else if expires == s.lowest { + // Find new lowest in this slot. + lowest := int64(math.MaxInt64) + for _, exp := range s.entries { + if exp < lowest { + lowest = exp + } + } + s.lowest = lowest + } + + // If we removed the global lowest, find the new one. + if expires == hw.lowest { + hw.updateLowestExpires() + } + + return nil +} + +// Update updates the expiration time of an existing timer task. +func (hw *HashWheel) Update(seq uint64, oldExpires int64, newExpires int64) error { + // Remove from old position. + if err := hw.Remove(seq, oldExpires); err != nil { + return err + } + // Add to new position. + return hw.Add(seq, newExpires) +} + +// ExpireTasks processes all expired tasks using a callback. +func (hw *HashWheel) ExpireTasks(callback func(seq uint64, expires int64)) { + now := time.Now().UnixNano() + + // Quick return if nothing is expired. + if hw.lowest > now { + return + } + + // Start from the slot containing the lowest expiration. 
+ startPos, exitPos := hw.getPosition(hw.lowest), hw.getPosition(now+tickDuration) + var updateLowest bool + + for offset := int64(0); ; offset++ { + pos := (startPos + offset) & wheelMask + if pos == exitPos { + if updateLowest { + hw.updateLowestExpires() + } + return + } + // Grab our slot. + slot := hw.wheel[pos] + if slot == nil || slot.lowest > now { + continue + } + + // Track new lowest while processing expirations + newLowest := int64(math.MaxInt64) + for seq, expires := range slot.entries { + if expires <= now { + callback(seq, expires) + delete(slot.entries, seq) + hw.count-- + updateLowest = true + } else if expires < newLowest { + newLowest = expires + } + } + + // Nil out if we are empty. + if len(slot.entries) == 0 { + hw.wheel[pos] = nil + } else { + slot.lowest = newLowest + } + } +} + +// GetNextExpiration returns the earliest expiration time before the given time. +// Returns math.MaxInt64 if no expirations exist before the specified time. +func (hw *HashWheel) GetNextExpiration(before int64) int64 { + if hw.lowest < before { + return hw.lowest + } + return math.MaxInt64 +} + +// AppendEncode writes out the contents of the THW into a binary snapshot +// and returns it. The high seq number is included in the snapshot and will +// be returned on decode. +func (hw *HashWheel) Encode(highSeq uint64) []byte { + b := make([]byte, 0, headerLen+(hw.count*(2*binary.MaxVarintLen64))) + b = append(b, 1) // Magic version + b = binary.LittleEndian.AppendUint64(b, hw.count) // Entry count + b = binary.LittleEndian.AppendUint64(b, highSeq) // Stamp + for _, slot := range hw.wheel { + if slot == nil || slot.entries == nil { + continue + } + for v, ts := range slot.entries { + b = binary.AppendVarint(b, ts) + b = binary.AppendUvarint(b, v) + } + } + return b +} + +// Decode snapshots a binary-encoded THW and replaces the contents of this +// THW with them. Returns the high seq number from the snapshot. 
+func (hw *HashWheel) Decode(b []byte) (uint64, error) { + if len(b) < headerLen { + return 0, io.ErrShortBuffer + } + if b[0] != 1 { + return 0, ErrInvalidVersion + } + hw.wheel = make([]*slot, wheelSize) + hw.lowest = math.MaxInt64 + count := binary.LittleEndian.Uint64(b[1:]) + stamp := binary.LittleEndian.Uint64(b[9:]) + b = b[headerLen:] + for i := uint64(0); i < count; i++ { + ts, tn := binary.Varint(b) + if tn < 0 { + return 0, io.ErrUnexpectedEOF + } + v, vn := binary.Uvarint(b[tn:]) + if vn < 0 { + return 0, io.ErrUnexpectedEOF + } + hw.Add(v, ts) + b = b[tn+vn:] + } + return stamp, nil +} diff --git a/vendor/github.com/nats-io/nats-server/v2/server/tpm/js_ek_tpm_other.go b/vendor/github.com/nats-io/nats-server/v2/server/tpm/js_ek_tpm_other.go new file mode 100644 index 0000000000..a1ed593ac0 --- /dev/null +++ b/vendor/github.com/nats-io/nats-server/v2/server/tpm/js_ek_tpm_other.go @@ -0,0 +1,23 @@ +// Copyright 2024 The NATS Authors +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +//go:build !windows + +package tpm + +import "fmt" + +// LoadJetStreamEncryptionKeyFromTPM here is a stub for unsupported platforms. 
+func LoadJetStreamEncryptionKeyFromTPM(srkPassword, jsKeyFile, jsKeyPassword string, pcr int) (string, error) { + return "", fmt.Errorf("TPM functionality is not supported on this platform") +} diff --git a/vendor/github.com/nats-io/nats-server/v2/server/tpm/js_ek_tpm_windows.go b/vendor/github.com/nats-io/nats-server/v2/server/tpm/js_ek_tpm_windows.go new file mode 100644 index 0000000000..5e401680a0 --- /dev/null +++ b/vendor/github.com/nats-io/nats-server/v2/server/tpm/js_ek_tpm_windows.go @@ -0,0 +1,281 @@ +// Copyright 2024 The NATS Authors +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +//go:build windows + +package tpm + +import ( + "encoding/base64" + "encoding/json" + "fmt" + "io" + "os" + "path/filepath" + + "github.com/google/go-tpm/legacy/tpm2" + "github.com/google/go-tpm/tpmutil" + "github.com/nats-io/nkeys" +) + +var ( + // Version of the NATS TPM JS implmentation + JsKeyTPMVersion = 1 +) + +// How this works: +// Create a Storage Root Key (SRK) in the TPM. +// If existing JS Encryption keys do not exist on disk. +// - Create a JetStream encryption key (js key) and seal it to the SRK +// using a provided js encryption key password. +// - Save the public and private blobs to a file on disk. 
+// - Return the new js encryption key (the private portion of the nkey) +// Otherwise (keys exist on disk) +// - Read the public and private blobs from disk +// - Load them into the TPM +// - Unseal the js key using the TPM, and the provided js encryption keys password. +// +// Note: a SRK password for the SRK is supported but not tested here. + +// Gets/Regenerates the Storage Root Key (SRK) from the TPM. Caller MUST flush this handle when done. +func regenerateSRK(rwc io.ReadWriteCloser, srkPassword string) (tpmutil.Handle, error) { + // Default EK template defined in: + // https://trustedcomputinggroup.org/wp-content/uploads/Credential_Profile_EK_V2.0_R14_published.pdf + // Shared SRK template based off of EK template and specified in: + // https://trustedcomputinggroup.org/wp-content/uploads/TCG-TPM-v2.0-Provisioning-Guidance-Published-v1r1.pdf + srkTemplate := tpm2.Public{ + Type: tpm2.AlgRSA, + NameAlg: tpm2.AlgSHA256, + Attributes: tpm2.FlagFixedTPM | tpm2.FlagFixedParent | tpm2.FlagSensitiveDataOrigin | tpm2.FlagUserWithAuth | tpm2.FlagRestricted | tpm2.FlagDecrypt | tpm2.FlagNoDA, + AuthPolicy: nil, + // We must use RSA 2048 for the intel TSS2 stack + RSAParameters: &tpm2.RSAParams{ + Symmetric: &tpm2.SymScheme{ + Alg: tpm2.AlgAES, + KeyBits: 128, + Mode: tpm2.AlgCFB, + }, + KeyBits: 2048, + ModulusRaw: make([]byte, 256), + }, + } + // Create the parent key against which to seal the data + srkHandle, _, err := tpm2.CreatePrimary(rwc, tpm2.HandleOwner, tpm2.PCRSelection{}, "", srkPassword, srkTemplate) + return srkHandle, err +} + +type natsTPMPersistedKeys struct { + Version int `json:"version"` + PrivateKey []byte `json:"private_key"` + PublicKey []byte `json:"public_key"` +} + +// Writes the private and public blobs to disk in a single file. If the directory does +// not exist, it will be created. If the file already exists it will be overwritten. 
+func writeTPMKeysToFile(filename string, privateBlob []byte, publicBlob []byte) error { + keyDir := filepath.Dir(filename) + if err := os.MkdirAll(keyDir, 0750); err != nil { + return fmt.Errorf("unable to create/access directory %q: %v", keyDir, err) + } + + // Create a new set of persisted keys. Note that the private key doesn't necessarily + // need to be protected as the TPM password is required to use unseal, although it's + // a good idea to put this in a secure location accessible to the server. + tpmKeys := natsTPMPersistedKeys{ + Version: JsKeyTPMVersion, + PrivateKey: make([]byte, base64.StdEncoding.EncodedLen(len(privateBlob))), + PublicKey: make([]byte, base64.StdEncoding.EncodedLen(len(publicBlob))), + } + base64.StdEncoding.Encode(tpmKeys.PrivateKey, privateBlob) + base64.StdEncoding.Encode(tpmKeys.PublicKey, publicBlob) + // Convert to JSON + keysJSON, err := json.Marshal(tpmKeys) + if err != nil { + return fmt.Errorf("unable to marshal keys to JSON: %v", err) + } + // Write the JSON to a file + if err := os.WriteFile(filename, keysJSON, 0640); err != nil { + return fmt.Errorf("unable to write keys file to %q: %v", filename, err) + } + return nil +} + +// Reads the private and public blobs from a single file. If the file does not exist, +// or the file cannot be read and the keys decoded, an error is returned. +func readTPMKeysFromFile(filename string) ([]byte, []byte, error) { + keysJSON, err := os.ReadFile(filename) + if err != nil { + return nil, nil, err + } + + var tpmKeys natsTPMPersistedKeys + if err := json.Unmarshal(keysJSON, &tpmKeys); err != nil { + return nil, nil, fmt.Errorf("unable to unmarshal TPM file keys JSON from %s: %v", filename, err) + } + + // Placeholder for future-proofing. Here is where we would + // check the current version against tpmKeys.Version and + // handle any changes. + + // Base64 decode the private and public blobs. 
+ privateBlob := make([]byte, base64.StdEncoding.DecodedLen(len(tpmKeys.PrivateKey))) + publicBlob := make([]byte, base64.StdEncoding.DecodedLen(len(tpmKeys.PublicKey))) + prn, err := base64.StdEncoding.Decode(privateBlob, tpmKeys.PrivateKey) + if err != nil { + return nil, nil, fmt.Errorf("unable to decode privateBlob from base64: %v", err) + } + pun, err := base64.StdEncoding.Decode(publicBlob, tpmKeys.PublicKey) + if err != nil { + return nil, nil, fmt.Errorf("unable to decode publicBlob from base64: %v", err) + } + return publicBlob[:pun], privateBlob[:prn], nil +} + +// Creates a new JetStream encryption key, seals it to the TPM, and saves the public and +// private blobs to disk in a JSON encoded file. The key is returned as a string. +func createAndSealJsEncryptionKey(rwc io.ReadWriteCloser, srkHandle tpmutil.Handle, srkPassword, jsKeyFile, jsKeyPassword string, pcr int) (string, error) { + // Get the authorization policy that will protect the data to be sealed + sessHandle, policy, err := policyPCRPasswordSession(rwc, pcr) + if err != nil { + return "", fmt.Errorf("unable to get policy: %v", err) + } + if err := tpm2.FlushContext(rwc, sessHandle); err != nil { + return "", fmt.Errorf("unable to flush session: %v", err) + } + // Seal the data to the parent key and the policy + user, err := nkeys.CreateUser() + if err != nil { + return "", fmt.Errorf("unable to create seed: %v", err) + } + // We'll use the seed to represent the encryption key. 
+ jsStoreKey, err := user.Seed() + if err != nil { + return "", fmt.Errorf("unable to get seed: %v", err) + } + privateArea, publicArea, err := tpm2.Seal(rwc, srkHandle, srkPassword, jsKeyPassword, policy, jsStoreKey) + if err != nil { + return "", fmt.Errorf("unable to seal data: %v", err) + } + err = writeTPMKeysToFile(jsKeyFile, privateArea, publicArea) + if err != nil { + return "", fmt.Errorf("unable to write key file: %v", err) + } + return string(jsStoreKey), nil +} + +// Unseals the JetStream encryption key from the TPM with the provided keys. +// The key is returned as a string. +func unsealJsEncrpytionKey(rwc io.ReadWriteCloser, pcr int, srkHandle tpmutil.Handle, srkPassword, objectPassword string, publicBlob, privateBlob []byte) (string, error) { + // Load the public/private blobs into the TPM for decryption. + objectHandle, _, err := tpm2.Load(rwc, srkHandle, srkPassword, publicBlob, privateBlob) + if err != nil { + return "", fmt.Errorf("unable to load data: %v", err) + } + defer tpm2.FlushContext(rwc, objectHandle) + + // Create the authorization session with TPM. + sessHandle, _, err := policyPCRPasswordSession(rwc, pcr) + if err != nil { + return "", fmt.Errorf("unable to get auth session: %v", err) + } + defer func() { + tpm2.FlushContext(rwc, sessHandle) + }() + // Unseal the data we've loaded into the TPM with the object (js key) password. + unsealedData, err := tpm2.UnsealWithSession(rwc, sessHandle, objectHandle, objectPassword) + if err != nil { + return "", fmt.Errorf("unable to unseal data: %v", err) + } + return string(unsealedData), nil +} + +// Returns session handle and policy digest. 
+func policyPCRPasswordSession(rwc io.ReadWriteCloser, pcr int) (sessHandle tpmutil.Handle, policy []byte, retErr error) { + sessHandle, _, err := tpm2.StartAuthSession( + rwc, + tpm2.HandleNull, /*tpmKey*/ + tpm2.HandleNull, /*bindKey*/ + make([]byte, 16), /*nonceCaller*/ + nil, /*secret*/ + tpm2.SessionPolicy, + tpm2.AlgNull, + tpm2.AlgSHA256) + if err != nil { + return tpm2.HandleNull, nil, fmt.Errorf("unable to start session: %v", err) + } + defer func() { + if sessHandle != tpm2.HandleNull && err != nil { + if err := tpm2.FlushContext(rwc, sessHandle); err != nil { + retErr = fmt.Errorf("%v\nunable to flush session: %v", retErr, err) + } + } + }() + + pcrSelection := tpm2.PCRSelection{ + Hash: tpm2.AlgSHA256, + PCRs: []int{pcr}, + } + if err := tpm2.PolicyPCR(rwc, sessHandle, nil, pcrSelection); err != nil { + return sessHandle, nil, fmt.Errorf("unable to bind PCRs to auth policy: %v", err) + } + if err := tpm2.PolicyPassword(rwc, sessHandle); err != nil { + return sessHandle, nil, fmt.Errorf("unable to require password for auth policy: %v", err) + } + policy, err = tpm2.PolicyGetDigest(rwc, sessHandle) + if err != nil { + return sessHandle, nil, fmt.Errorf("unable to get policy digest: %v", err) + } + return sessHandle, policy, nil +} + +// LoadJetStreamEncryptionKeyFromTPM loads the JetStream encryption key from the TPM. +// If the keyfile does not exist, a key will be created and sealed. Public and private blobs +// used to decrypt the key in future sessions will be saved to disk in the file provided. +// The key will be unsealed and returned only with the correct password and PCR value. 
+func LoadJetStreamEncryptionKeyFromTPM(srkPassword, jsKeyFile, jsKeyPassword string, pcr int) (string, error) { + rwc, err := tpm2.OpenTPM() + if err != nil { + return "", fmt.Errorf("could not open the TPM: %v", err) + } + defer rwc.Close() + + // Load the key from the TPM + srkHandle, err := regenerateSRK(rwc, srkPassword) + defer func() { + tpm2.FlushContext(rwc, srkHandle) + }() + if err != nil { + return "", fmt.Errorf("unable to regenerate SRK from the TPM: %v", err) + } + // Read the keys from the key file. If the filed doesn't exist it means we need to create + // a new js encrytpion key. + publicBlob, privateBlob, err := readTPMKeysFromFile(jsKeyFile) + if err != nil { + if os.IsNotExist(err) { + jsek, err := createAndSealJsEncryptionKey(rwc, srkHandle, srkPassword, jsKeyFile, jsKeyPassword, pcr) + if err != nil { + return "", fmt.Errorf("unable to generate new key from the TPM: %v", err) + } + // we've created and sealed the JS Encryption key, now we just return it. + return jsek, nil + } + return "", fmt.Errorf("unable to load key from TPM: %v", err) + } + + // Unseal the JetStream encryption key using the TPM. 
+ jsek, err := unsealJsEncrpytionKey(rwc, pcr, srkHandle, srkPassword, jsKeyPassword, publicBlob, privateBlob) + if err != nil { + return "", fmt.Errorf("unable to unseal key from the TPM: %v", err) + } + return jsek, nil +} diff --git a/vendor/github.com/nats-io/nats-server/v2/server/util.go b/vendor/github.com/nats-io/nats-server/v2/server/util.go index 10a8d8d67b..f9fd695c32 100644 --- a/vendor/github.com/nats-io/nats-server/v2/server/util.go +++ b/vendor/github.com/nats-io/nats-server/v2/server/util.go @@ -23,7 +23,6 @@ import ( "net" "net/url" "reflect" - "regexp" "strconv" "strings" "time" @@ -39,11 +38,9 @@ const ( asciiNine = 57 ) -var semVerRe = regexp.MustCompile(`\Av?([0-9]+)\.?([0-9]+)?\.?([0-9]+)?`) - func versionComponents(version string) (major, minor, patch int, err error) { m := semVerRe.FindStringSubmatch(version) - if m == nil { + if len(m) == 0 { return 0, 0, 0, errors.New("invalid semver") } major, err = strconv.Atoi(m[1]) diff --git a/vendor/github.com/nats-io/nats-server/v2/server/websocket.go b/vendor/github.com/nats-io/nats-server/v2/server/websocket.go index 90aa3d82f4..67239c3752 100644 --- a/vendor/github.com/nats-io/nats-server/v2/server/websocket.go +++ b/vendor/github.com/nats-io/nats-server/v2/server/websocket.go @@ -105,18 +105,21 @@ var wsGUID = []byte("258EAFA5-E914-47DA-95CA-C5AB0DC85B11") var wsTestRejectNoMasking = false type websocket struct { - frames net.Buffers - fs int64 - closeMsg []byte - compress bool - closeSent bool - browser bool - nocompfrag bool // No fragment for compressed frames - maskread bool - maskwrite bool - compressor *flate.Writer - cookieJwt string - clientIP string + frames net.Buffers + fs int64 + closeMsg []byte + compress bool + closeSent bool + browser bool + nocompfrag bool // No fragment for compressed frames + maskread bool + maskwrite bool + compressor *flate.Writer + cookieJwt string + cookieUsername string + cookiePassword string + cookieToken string + clientIP string } type srvWebsocket 
struct { @@ -128,7 +131,8 @@ type srvWebsocket struct { sameOrigin bool connectURLs []string connectURLsMap refCountedUrlSet - authOverride bool // indicate if there is auth override in websocket config + authOverride bool // indicate if there is auth override in websocket config + rawHeaders string // raw headers to be used in the upgrade response. // These are immutable and can be accessed without lock. // This is the case when generating the client INFO. @@ -791,6 +795,9 @@ func (s *Server) wsUpgrade(w http.ResponseWriter, r *http.Request) (*wsUpgradeRe if kind == MQTT { p = append(p, wsMQTTSecProto...) } + if s.websocket.rawHeaders != _EMPTY_ { + p = append(p, s.websocket.rawHeaders...) + } p = append(p, _CRLF_...) if _, err = conn.Write(p); err != nil { @@ -824,9 +831,19 @@ func (s *Server) wsUpgrade(w http.ResponseWriter, r *http.Request) (*wsUpgradeRe // So make the combination of the two. ws.nocompfrag = ws.compress && strings.Contains(ua, "Version/") && strings.Contains(ua, "Safari/") } - if opts.Websocket.JWTCookie != _EMPTY_ { - if c, err := r.Cookie(opts.Websocket.JWTCookie); err == nil && c != nil { - ws.cookieJwt = c.Value + + if cookies := r.Cookies(); len(cookies) > 0 { + ows := &opts.Websocket + for _, c := range cookies { + if ows.JWTCookie == c.Name { + ws.cookieJwt = c.Value + } else if ows.UsernameCookie == c.Name { + ws.cookieUsername = c.Value + } else if ows.PasswordCookie == c.Name { + ws.cookiePassword = c.Value + } else if ows.TokenCookie == c.Name { + ws.cookieToken = c.Value + } } } } @@ -1023,6 +1040,24 @@ func validateWebsocketOptions(o *Options) error { if err := validatePinnedCerts(wo.TLSPinnedCerts); err != nil { return fmt.Errorf("websocket: %v", err) } + + // Check for invalid headers here. 
+ for key := range wo.Headers { + k := strings.ToLower(key) + switch k { + case "host", + "content-length", + "connection", + "upgrade", + "nats-no-masking": + return fmt.Errorf("websocket: invalid header %q not allowed", key) + } + + if strings.HasPrefix(k, "sec-websocket-") { + return fmt.Errorf("websocket: invalid header %q, \"Sec-WebSocket-\" prefix not allowed", key) + } + } + return nil } @@ -1054,6 +1089,21 @@ func (s *Server) wsSetOriginOptions(o *WebsocketOpts) { } } +// Calculate the raw headers for websocket upgrade response. +func (s *Server) wsSetHeadersOptions(o *WebsocketOpts) { + var sb strings.Builder + for k, v := range o.Headers { + sb.WriteString(k) + sb.WriteString(": ") + sb.WriteString(v) + sb.WriteString(_CRLF_) + } + ws := &s.websocket + ws.mu.Lock() + defer ws.mu.Unlock() + ws.rawHeaders = sb.String() +} + // Given the websocket options, we check if any auth configuration // has been provided. If so, possibly create users/nkey users and // store them in s.websocket.users/nkeys. @@ -1075,6 +1125,7 @@ func (s *Server) startWebsocketServer() { o := &sopts.Websocket s.wsSetOriginOptions(o) + s.wsSetHeadersOptions(o) var hl net.Listener var proto string @@ -1149,7 +1200,7 @@ func (s *Server) startWebsocketServer() { if !hasLeaf { s.Errorf("Not configured to accept leaf node connections") // Silently close for now. If we want to send an error back, we would - // need to create the leafnode client anyway, so that is is handling websocket + // need to create the leafnode client anyway, so that is handling websocket // frames, then send the error to the remote. 
res.conn.Close() return diff --git a/vendor/github.com/onsi/ginkgo/v2/CHANGELOG.md b/vendor/github.com/onsi/ginkgo/v2/CHANGELOG.md index 50af9d0baa..056dba6ef4 100644 --- a/vendor/github.com/onsi/ginkgo/v2/CHANGELOG.md +++ b/vendor/github.com/onsi/ginkgo/v2/CHANGELOG.md @@ -1,3 +1,15 @@ +## 2.23.2 + +🎉🎉🎉 + +At long last, some long-standing performance gaps between `ginkgo` and `go test` have been resolved! + +Ginkgo operates by running `go test -c` to generate test binaries, and then running those binaries. It turns out that the compilation step of `go test -c` is slower than `go test`'s compilation step because `go test` strips out debug symbols (`ldflags=-w`) whereas `go test -c` does not. + +Ginkgo now passes the appropriate `ldflags` to `go test -c` when running specs to strip out symbols. This is only done when it is safe to do so and symbols are preferred when profiling is enabled and when `ginkgo build` is called explicitly. + +This, coupled, with the [instructions for disabling XProtect on MacOS](https://onsi.github.io/ginkgo/#if-you-are-running-on-macos) yields a much better performance experience with Ginkgo. 
+ ## 2.23.1 ## 🚨 For users on MacOS 🚨 diff --git a/vendor/github.com/onsi/ginkgo/v2/ginkgo/build/build_command.go b/vendor/github.com/onsi/ginkgo/v2/ginkgo/build/build_command.go index a071b8d091..2b36b2feb9 100644 --- a/vendor/github.com/onsi/ginkgo/v2/ginkgo/build/build_command.go +++ b/vendor/github.com/onsi/ginkgo/v2/ginkgo/build/build_command.go @@ -44,7 +44,7 @@ func buildSpecs(args []string, cliConfig types.CLIConfig, goFlagsConfig types.Go internal.VerifyCLIAndFrameworkVersion(suites) opc := internal.NewOrderedParallelCompiler(cliConfig.ComputedNumCompilers()) - opc.StartCompiling(suites, goFlagsConfig) + opc.StartCompiling(suites, goFlagsConfig, true) for { suiteIdx, suite := opc.Next() diff --git a/vendor/github.com/onsi/ginkgo/v2/ginkgo/command/command.go b/vendor/github.com/onsi/ginkgo/v2/ginkgo/command/command.go index a9cf9fcb2c..12d4a32c87 100644 --- a/vendor/github.com/onsi/ginkgo/v2/ginkgo/command/command.go +++ b/vendor/github.com/onsi/ginkgo/v2/ginkgo/command/command.go @@ -26,7 +26,7 @@ func (c Command) Run(args []string, additionalArgs []string) { } for _, arg := range args { if strings.HasPrefix(arg, "-") { - AbortWith("Malformed arguments - make sure all flags appear {{bold}}after{{/}} the Ginkgo subcommand and {{bold}}before{{/}} your list of packages.\n{{gray}}e.g. 
'ginkgo run -p my_package' is valid `ginkgo -p run my_package` is not.{{/}}") + AbortWith(types.GinkgoErrors.FlagAfterPositionalParameter().Error()) } } c.Command(args, additionalArgs) diff --git a/vendor/github.com/onsi/ginkgo/v2/ginkgo/internal/compile.go b/vendor/github.com/onsi/ginkgo/v2/ginkgo/internal/compile.go index 48827cc5ef..7bbe6be0fc 100644 --- a/vendor/github.com/onsi/ginkgo/v2/ginkgo/internal/compile.go +++ b/vendor/github.com/onsi/ginkgo/v2/ginkgo/internal/compile.go @@ -11,7 +11,7 @@ import ( "github.com/onsi/ginkgo/v2/types" ) -func CompileSuite(suite TestSuite, goFlagsConfig types.GoFlagsConfig) TestSuite { +func CompileSuite(suite TestSuite, goFlagsConfig types.GoFlagsConfig, preserveSymbols bool) TestSuite { if suite.PathToCompiledTest != "" { return suite } @@ -46,7 +46,7 @@ func CompileSuite(suite TestSuite, goFlagsConfig types.GoFlagsConfig) TestSuite suite.CompilationError = fmt.Errorf("Failed to get relative path from package to the current working directory:\n%s", err.Error()) return suite } - args, err := types.GenerateGoTestCompileArgs(goFlagsConfig, "./", pathToInvocationPath) + args, err := types.GenerateGoTestCompileArgs(goFlagsConfig, "./", pathToInvocationPath, preserveSymbols) if err != nil { suite.State = TestSuiteStateFailedToCompile suite.CompilationError = fmt.Errorf("Failed to generate go test compile flags:\n%s", err.Error()) @@ -120,7 +120,7 @@ func NewOrderedParallelCompiler(numCompilers int) *OrderedParallelCompiler { } } -func (opc *OrderedParallelCompiler) StartCompiling(suites TestSuites, goFlagsConfig types.GoFlagsConfig) { +func (opc *OrderedParallelCompiler) StartCompiling(suites TestSuites, goFlagsConfig types.GoFlagsConfig, preserveSymbols bool) { opc.stopped = false opc.idx = 0 opc.numSuites = len(suites) @@ -135,7 +135,7 @@ func (opc *OrderedParallelCompiler) StartCompiling(suites TestSuites, goFlagsCon stopped := opc.stopped opc.mutex.Unlock() if !stopped { - suite = CompileSuite(suite, goFlagsConfig) + suite = 
CompileSuite(suite, goFlagsConfig, preserveSymbols) } c <- suite } diff --git a/vendor/github.com/onsi/ginkgo/v2/ginkgo/run/run_command.go b/vendor/github.com/onsi/ginkgo/v2/ginkgo/run/run_command.go index b7d77390bb..03875b9796 100644 --- a/vendor/github.com/onsi/ginkgo/v2/ginkgo/run/run_command.go +++ b/vendor/github.com/onsi/ginkgo/v2/ginkgo/run/run_command.go @@ -107,7 +107,7 @@ OUTER_LOOP: } opc := internal.NewOrderedParallelCompiler(r.cliConfig.ComputedNumCompilers()) - opc.StartCompiling(suites, r.goFlagsConfig) + opc.StartCompiling(suites, r.goFlagsConfig, false) SUITE_LOOP: for { diff --git a/vendor/github.com/onsi/ginkgo/v2/ginkgo/watch/watch_command.go b/vendor/github.com/onsi/ginkgo/v2/ginkgo/watch/watch_command.go index bde4193ce7..fe1ca30519 100644 --- a/vendor/github.com/onsi/ginkgo/v2/ginkgo/watch/watch_command.go +++ b/vendor/github.com/onsi/ginkgo/v2/ginkgo/watch/watch_command.go @@ -153,7 +153,7 @@ func (w *SpecWatcher) WatchSpecs(args []string, additionalArgs []string) { } func (w *SpecWatcher) compileAndRun(suite internal.TestSuite, additionalArgs []string) internal.TestSuite { - suite = internal.CompileSuite(suite, w.goFlagsConfig) + suite = internal.CompileSuite(suite, w.goFlagsConfig, false) if suite.State.Is(internal.TestSuiteStateFailedToCompile) { fmt.Println(suite.CompilationError.Error()) return suite diff --git a/vendor/github.com/onsi/ginkgo/v2/types/config.go b/vendor/github.com/onsi/ginkgo/v2/types/config.go index 3d543c1219..ca837b0557 100644 --- a/vendor/github.com/onsi/ginkgo/v2/types/config.go +++ b/vendor/github.com/onsi/ginkgo/v2/types/config.go @@ -231,6 +231,10 @@ func (g GoFlagsConfig) BinaryMustBePreserved() bool { return g.BlockProfile != "" || g.CPUProfile != "" || g.MemProfile != "" || g.MutexProfile != "" } +func (g GoFlagsConfig) NeedsSymbols() bool { + return g.BinaryMustBePreserved() +} + // Configuration that were deprecated in 2.0 type deprecatedConfig struct { DebugParallel bool @@ -640,7 +644,7 @@ func 
VetAndInitializeCLIAndGoConfig(cliConfig CLIConfig, goFlagsConfig GoFlagsCo } // GenerateGoTestCompileArgs is used by the Ginkgo CLI to generate command line arguments to pass to the go test -c command when compiling the test -func GenerateGoTestCompileArgs(goFlagsConfig GoFlagsConfig, packageToBuild string, pathToInvocationPath string) ([]string, error) { +func GenerateGoTestCompileArgs(goFlagsConfig GoFlagsConfig, packageToBuild string, pathToInvocationPath string, preserveSymbols bool) ([]string, error) { // if the user has set the CoverProfile run-time flag make sure to set the build-time cover flag to make sure // the built test binary can generate a coverprofile if goFlagsConfig.CoverProfile != "" { @@ -663,6 +667,10 @@ func GenerateGoTestCompileArgs(goFlagsConfig GoFlagsConfig, packageToBuild strin goFlagsConfig.CoverPkg = strings.Join(adjustedCoverPkgs, ",") } + if !goFlagsConfig.NeedsSymbols() && goFlagsConfig.LDFlags == "" && !preserveSymbols { + goFlagsConfig.LDFlags = "-w -s" + } + args := []string{"test", "-c", packageToBuild} goArgs, err := GenerateFlagArgs( GoBuildFlags, diff --git a/vendor/github.com/onsi/ginkgo/v2/types/errors.go b/vendor/github.com/onsi/ginkgo/v2/types/errors.go index 854252ac2a..c3f562f776 100644 --- a/vendor/github.com/onsi/ginkgo/v2/types/errors.go +++ b/vendor/github.com/onsi/ginkgo/v2/types/errors.go @@ -636,6 +636,13 @@ func (g ginkgoErrors) ExpectFilenameNotPath(flag string, path string) error { } } +func (g ginkgoErrors) FlagAfterPositionalParameter() error { + return GinkgoError{ + Heading: "Malformed arguments - detected a flag after the package liste", + Message: "Make sure all flags appear {{bold}}after{{/}} the Ginkgo subcommand and {{bold}}before{{/}} your list of packages (or './...').\n{{gray}}e.g. 'ginkgo run -p my_package' is valid but `ginkgo -p run my_package` is not.\n{{gray}}e.g. 'ginkgo -p -vet ./...' is valid but 'ginkgo -p ./... 
-vet' is not{{/}}", + } +} + /* Stack-Trace parsing errors */ func (g ginkgoErrors) FailedToParseStackTrace(message string) error { diff --git a/vendor/github.com/onsi/ginkgo/v2/types/version.go b/vendor/github.com/onsi/ginkgo/v2/types/version.go index 8d38790bf2..48c1e88be3 100644 --- a/vendor/github.com/onsi/ginkgo/v2/types/version.go +++ b/vendor/github.com/onsi/ginkgo/v2/types/version.go @@ -1,3 +1,3 @@ package types -const VERSION = "2.23.1" +const VERSION = "2.23.2" diff --git a/vendor/github.com/opencloud-eu/reva/v2/pkg/storage/fs/posix/lookup/lookup.go b/vendor/github.com/opencloud-eu/reva/v2/pkg/storage/fs/posix/lookup/lookup.go index 01d0632ae7..11f75887c5 100644 --- a/vendor/github.com/opencloud-eu/reva/v2/pkg/storage/fs/posix/lookup/lookup.go +++ b/vendor/github.com/opencloud-eu/reva/v2/pkg/storage/fs/posix/lookup/lookup.go @@ -45,7 +45,7 @@ import ( var tracer trace.Tracer -const RevisionsDir = ".oc-nodes" +const MetadataDir = ".oc-nodes" var _spaceTypePersonal = "personal" var _spaceTypeProject = "project" @@ -288,7 +288,7 @@ func (lu *Lookup) InternalPath(spaceID, nodeID string) string { if len(spaceRoot) == 0 { return "" } - return filepath.Join(spaceRoot, RevisionsDir, Pathify(nodeID, 4, 2)) + return filepath.Join(spaceRoot, MetadataDir, Pathify(nodeID, 4, 2)) } path, _ := lu.IDCache.Get(context.Background(), spaceID, nodeID) @@ -303,7 +303,7 @@ func (lu *Lookup) VersionPath(spaceID, nodeID, version string) string { return "" } - return filepath.Join(spaceRoot, RevisionsDir, Pathify(nodeID, 4, 2)+node.RevisionIDDelimiter+version) + return filepath.Join(spaceRoot, MetadataDir, Pathify(nodeID, 4, 2)+node.RevisionIDDelimiter+version) } // VersionPath returns the "current" path of the node @@ -313,7 +313,7 @@ func (lu *Lookup) CurrentPath(spaceID, nodeID string) string { return "" } - return filepath.Join(spaceRoot, RevisionsDir, Pathify(nodeID, 4, 2)+node.CurrentIDDelimiter) + return filepath.Join(spaceRoot, MetadataDir, Pathify(nodeID, 4, 
2)+node.CurrentIDDelimiter) } // refFromCS3 creates a CS3 reference from a set of bytes. This method should remain private diff --git a/vendor/github.com/opencloud-eu/reva/v2/pkg/storage/fs/posix/posix.go b/vendor/github.com/opencloud-eu/reva/v2/pkg/storage/fs/posix/posix.go index 69977f4778..7a763c9aea 100644 --- a/vendor/github.com/opencloud-eu/reva/v2/pkg/storage/fs/posix/posix.go +++ b/vendor/github.com/opencloud-eu/reva/v2/pkg/storage/fs/posix/posix.go @@ -87,7 +87,7 @@ func New(m map[string]interface{}, stream events.Stream, log *zerolog.Logger) (s return "" } - return filepath.Join(spaceRoot, lookup.RevisionsDir, lookup.Pathify(n.GetID(), 4, 2)+".mpk") + return filepath.Join(spaceRoot, lookup.MetadataDir) }, o.FileMetadataCache), um, o, &timemanager.Manager{}) default: diff --git a/vendor/github.com/opencloud-eu/reva/v2/pkg/storage/fs/posix/tree/revisions.go b/vendor/github.com/opencloud-eu/reva/v2/pkg/storage/fs/posix/tree/revisions.go index 1ca30aded7..c0d14e897d 100644 --- a/vendor/github.com/opencloud-eu/reva/v2/pkg/storage/fs/posix/tree/revisions.go +++ b/vendor/github.com/opencloud-eu/reva/v2/pkg/storage/fs/posix/tree/revisions.go @@ -24,6 +24,7 @@ import ( "io" "os" "path/filepath" + "strconv" "strings" "time" @@ -63,7 +64,44 @@ func (tp *Tree) CreateRevision(ctx context.Context, n *node.Node, version string vf, err := os.OpenFile(versionPath, os.O_CREATE|os.O_WRONLY|os.O_EXCL, 0600) if err != nil { if os.IsExist(err) { - err := os.Remove(versionPath) + dir := filepath.Dir(versionPath) + base := filepath.Base(versionPath) + files, err := os.ReadDir(dir) + if err != nil { + return "", err + } + + // find revision with highest number + highest := 0 + for _, file := range files { + if file.IsDir() { + continue + } + name := file.Name() + if !strings.HasPrefix(name, base) { + continue + } + ext := strings.TrimPrefix(name, base+".") + if ext == "" || ext == base { + continue + } + num, err := strconv.Atoi(ext) + if err != nil { + continue + } + if num > 
highest { + highest = num + } + } + + // rename existing revision + oldNode := node.NewBaseNode(n.SpaceID, n.ID+node.RevisionIDDelimiter+version+"."+strconv.Itoa(highest+1), tp.lookup) + err = tp.lookup.MetadataBackend().Rename(revNode, oldNode) + if err != nil { + return "", err + } + newPath := versionPath + "." + strconv.Itoa(highest+1) + err = os.Rename(versionPath, newPath) if err != nil { return "", err } diff --git a/vendor/github.com/opencloud-eu/reva/v2/pkg/storage/fs/posix/tree/tree.go b/vendor/github.com/opencloud-eu/reva/v2/pkg/storage/fs/posix/tree/tree.go index 91f3d2e507..9014063872 100644 --- a/vendor/github.com/opencloud-eu/reva/v2/pkg/storage/fs/posix/tree/tree.go +++ b/vendor/github.com/opencloud-eu/reva/v2/pkg/storage/fs/posix/tree/tree.go @@ -662,7 +662,7 @@ func (t *Tree) isIndex(path string) bool { func (t *Tree) isInternal(path string) bool { return path == t.options.Root || path == filepath.Join(t.options.Root, "users") || - t.isIndex(path) || strings.Contains(path, lookup.RevisionsDir) + t.isIndex(path) || strings.Contains(path, lookup.MetadataDir) } func isLockFile(path string) bool { diff --git a/vendor/github.com/opencloud-eu/reva/v2/pkg/storage/pkg/decomposedfs/metadata/hybrid_backend.go b/vendor/github.com/opencloud-eu/reva/v2/pkg/storage/pkg/decomposedfs/metadata/hybrid_backend.go index 3835ca41be..7f3244b819 100644 --- a/vendor/github.com/opencloud-eu/reva/v2/pkg/storage/pkg/decomposedfs/metadata/hybrid_backend.go +++ b/vendor/github.com/opencloud-eu/reva/v2/pkg/storage/pkg/decomposedfs/metadata/hybrid_backend.go @@ -199,15 +199,11 @@ func (b HybridBackend) Set(ctx context.Context, n MetadataNode, key string, val func (b HybridBackend) SetMultiple(ctx context.Context, n MetadataNode, attribs map[string][]byte, acquireLock bool) (err error) { path := n.InternalPath() if acquireLock { - err := os.MkdirAll(filepath.Dir(path), 0600) + unlock, err := b.Lock(n) if err != nil { return err } - lockedFile, err := 
lockedfile.OpenFile(b.LockfilePath(n), os.O_CREATE|os.O_WRONLY, 0600) - if err != nil { - return err - } - defer cleanupLockfile(ctx, lockedFile) + defer func() { _ = unlock() }() } offloadAttr, err := xattr.Get(path, _metadataOffloadedAttr) @@ -476,17 +472,37 @@ func (b HybridBackend) Rename(oldNode, newNode MetadataNode) error { } // MetadataPath returns the path of the file holding the metadata for the given path -func (b HybridBackend) MetadataPath(n MetadataNode) string { return b.metadataPathFunc(n) } +func (b HybridBackend) MetadataPath(n MetadataNode) string { + base := b.metadataPathFunc(n) + + return filepath.Join(base, pathify(n.GetID(), 4, 2)+".mpk") +} // LockfilePath returns the path of the lock file -func (HybridBackend) LockfilePath(n MetadataNode) string { return n.InternalPath() + ".mlock" } +func (b HybridBackend) LockfilePath(n MetadataNode) string { + base := b.metadataPathFunc(n) + + return filepath.Join(base, "locks", n.GetID()+".mlock") +} // Lock locks the metadata for the given path func (b HybridBackend) Lock(n MetadataNode) (UnlockFunc, error) { metaLockPath := b.LockfilePath(n) mlock, err := lockedfile.OpenFile(metaLockPath, os.O_RDWR|os.O_CREATE, 0600) if err != nil { - return nil, err + if errors.Is(err, os.ErrNotExist) { + // create the parent directory + err = os.MkdirAll(filepath.Dir(metaLockPath), 0700) + if err != nil { + return nil, err + } + mlock, err = lockedfile.OpenFile(metaLockPath, os.O_RDWR|os.O_CREATE, 0600) + if err != nil { + return nil, err + } + } else { + return nil, err + } } return func() error { err := mlock.Close() @@ -513,3 +529,17 @@ func (b HybridBackend) cacheKey(n MetadataNode) string { func isOffloadingAttribute(key string) bool { return strings.HasPrefix(key, prefixes.GrantPrefix) || strings.HasPrefix(key, prefixes.MetadataPrefix) } + +func pathify(id string, depth, width int) string { + b := strings.Builder{} + i := 0 + for ; i < depth; i++ { + if len(id) <= i*width+width { + break + } + 
b.WriteString(id[i*width : i*width+width]) + b.WriteRune(filepath.Separator) + } + b.WriteString(id[i*width:]) + return b.String() +} diff --git a/vendor/github.com/opencloud-eu/reva/v2/pkg/storage/pkg/decomposedfs/metadata/messagepack_backend.go b/vendor/github.com/opencloud-eu/reva/v2/pkg/storage/pkg/decomposedfs/metadata/messagepack_backend.go index fc1c873bd9..40de764fe7 100644 --- a/vendor/github.com/opencloud-eu/reva/v2/pkg/storage/pkg/decomposedfs/metadata/messagepack_backend.go +++ b/vendor/github.com/opencloud-eu/reva/v2/pkg/storage/pkg/decomposedfs/metadata/messagepack_backend.go @@ -43,12 +43,6 @@ type MessagePackBackend struct { metaCache cache.FileMetadataCache } -type readWriteCloseSeekTruncater interface { - io.ReadWriteCloser - io.Seeker - Truncate(int64) error -} - // NewMessagePackBackend returns a new MessagePackBackend instance func NewMessagePackBackend(o cache.Config) MessagePackBackend { return MessagePackBackend{ @@ -148,7 +142,6 @@ func (b MessagePackBackend) AllWithLockedSource(ctx context.Context, n MetadataN func (b MessagePackBackend) saveAttributes(ctx context.Context, n MetadataNode, setAttribs map[string][]byte, deleteAttribs []string, acquireLock bool) error { var ( err error - f readWriteCloseSeekTruncater ) ctx, span := tracer.Start(ctx, "saveAttributes") defer func() { @@ -160,16 +153,13 @@ func (b MessagePackBackend) saveAttributes(ctx context.Context, n MetadataNode, span.End() }() - lockPath := b.LockfilePath(n) metaPath := b.MetadataPath(n) if acquireLock { - _, subspan := tracer.Start(ctx, "lockedfile.OpenFile") - f, err = lockedfile.OpenFile(lockPath, os.O_RDWR|os.O_CREATE, 0600) - subspan.End() + unlock, err := b.Lock(n) if err != nil { return err } - defer f.Close() + defer func() { _ = unlock() }() } // Read current state _, subspan := tracer.Start(ctx, "os.ReadFile") diff --git a/vendor/github.com/opencloud-eu/reva/v2/pkg/storage/pkg/decomposedfs/metadata/xattrs_backend.go 
b/vendor/github.com/opencloud-eu/reva/v2/pkg/storage/pkg/decomposedfs/metadata/xattrs_backend.go index 4460872098..8fbd4940e4 100644 --- a/vendor/github.com/opencloud-eu/reva/v2/pkg/storage/pkg/decomposedfs/metadata/xattrs_backend.go +++ b/vendor/github.com/opencloud-eu/reva/v2/pkg/storage/pkg/decomposedfs/metadata/xattrs_backend.go @@ -168,7 +168,7 @@ func (b XattrsBackend) Set(ctx context.Context, n MetadataNode, key string, val func (b XattrsBackend) SetMultiple(ctx context.Context, n MetadataNode, attribs map[string][]byte, acquireLock bool) (err error) { path := n.InternalPath() if acquireLock { - err := os.MkdirAll(filepath.Dir(path), 0600) + err := os.MkdirAll(filepath.Dir(path), 0700) if err != nil { return err } diff --git a/vendor/github.com/opencloud-eu/reva/v2/pkg/storage/pkg/decomposedfs/tree/revisions.go b/vendor/github.com/opencloud-eu/reva/v2/pkg/storage/pkg/decomposedfs/tree/revisions.go index 899cb95efa..2a6e180c9e 100644 --- a/vendor/github.com/opencloud-eu/reva/v2/pkg/storage/pkg/decomposedfs/tree/revisions.go +++ b/vendor/github.com/opencloud-eu/reva/v2/pkg/storage/pkg/decomposedfs/tree/revisions.go @@ -24,13 +24,13 @@ import ( "io" "os" "path/filepath" + "strconv" "strings" "time" provider "github.com/cs3org/go-cs3apis/cs3/storage/provider/v1beta1" "github.com/pkg/errors" "github.com/rogpeppe/go-internal/lockedfile" - "github.com/shamaton/msgpack/v2" "github.com/opencloud-eu/reva/v2/pkg/appctx" "github.com/opencloud-eu/reva/v2/pkg/errtypes" @@ -63,32 +63,48 @@ func (tp *Tree) CreateRevision(ctx context.Context, n *node.Node, version string vf, err := os.OpenFile(versionPath, os.O_CREATE|os.O_EXCL, 0600) if err != nil { if os.IsExist(err) { - revisionNode := node.NewBaseNode(n.SpaceID, n.ID+node.RevisionIDDelimiter+version, tp.lookup) - revisionPath := tp.lookup.MetadataBackend().MetadataPath(revisionNode) - b, err := os.ReadFile(revisionPath) + dir := filepath.Dir(versionPath) + base := filepath.Base(versionPath) + files, err := 
os.ReadDir(dir) if err != nil { return "", err } - m := map[string][]byte{} - if err := msgpack.Unmarshal(b, &m); err != nil { - return "", err - } - - bid := m["user.oc.blobid"] - if string(bid) != "" { - if err := tp.DeleteBlob(&node.Node{ - BaseNode: *revisionNode, - BlobID: string(bid), - }); err != nil { - return "", err + // find revision with highest number + highest := 0 + for _, file := range files { + if file.IsDir() { + continue + } + name := file.Name() + if !strings.HasPrefix(name, base) || strings.HasSuffix(name, ".mpk") { + continue + } + ext := strings.TrimPrefix(name, base+".") + if ext == "" || ext == base { + continue + } + num, err := strconv.Atoi(ext) + if err != nil { + continue + } + if num > highest { + highest = num } } - err = os.Remove(versionPath) + // rename existing revision + oldNode := node.NewBaseNode(n.SpaceID, n.ID+node.RevisionIDDelimiter+version+"."+strconv.Itoa(highest+1), tp.lookup) + err = tp.lookup.MetadataBackend().Rename(versionNode, oldNode) if err != nil { return "", err } + newPath := versionPath + "." + strconv.Itoa(highest+1) + err = os.Rename(versionPath, newPath) + if err != nil { + return "", err + } + vf, err = os.OpenFile(versionPath, os.O_CREATE|os.O_WRONLY|os.O_EXCL, 0600) if err != nil { return "", err diff --git a/vendor/golang.org/x/time/rate/rate.go b/vendor/golang.org/x/time/rate/rate.go index ec5f0cdd0c..794b2e32bf 100644 --- a/vendor/golang.org/x/time/rate/rate.go +++ b/vendor/golang.org/x/time/rate/rate.go @@ -85,7 +85,7 @@ func (lim *Limiter) Burst() int { // TokensAt returns the number of tokens available at time t. 
func (lim *Limiter) TokensAt(t time.Time) float64 { lim.mu.Lock() - _, tokens := lim.advance(t) // does not mutate lim + tokens := lim.advance(t) // does not mutate lim lim.mu.Unlock() return tokens } @@ -186,7 +186,7 @@ func (r *Reservation) CancelAt(t time.Time) { return } // advance time to now - t, tokens := r.lim.advance(t) + tokens := r.lim.advance(t) // calculate new number of tokens tokens += restoreTokens if burst := float64(r.lim.burst); tokens > burst { @@ -307,7 +307,7 @@ func (lim *Limiter) SetLimitAt(t time.Time, newLimit Limit) { lim.mu.Lock() defer lim.mu.Unlock() - t, tokens := lim.advance(t) + tokens := lim.advance(t) lim.last = t lim.tokens = tokens @@ -324,7 +324,7 @@ func (lim *Limiter) SetBurstAt(t time.Time, newBurst int) { lim.mu.Lock() defer lim.mu.Unlock() - t, tokens := lim.advance(t) + tokens := lim.advance(t) lim.last = t lim.tokens = tokens @@ -347,7 +347,7 @@ func (lim *Limiter) reserveN(t time.Time, n int, maxFutureReserve time.Duration) } } - t, tokens := lim.advance(t) + tokens := lim.advance(t) // Calculate the remaining number of tokens resulting from the request. tokens -= float64(n) @@ -380,10 +380,11 @@ func (lim *Limiter) reserveN(t time.Time, n int, maxFutureReserve time.Duration) return r } -// advance calculates and returns an updated state for lim resulting from the passage of time. +// advance calculates and returns an updated number of tokens for lim +// resulting from the passage of time. // lim is not changed. // advance requires that lim.mu is held. 
-func (lim *Limiter) advance(t time.Time) (newT time.Time, newTokens float64) { +func (lim *Limiter) advance(t time.Time) (newTokens float64) { last := lim.last if t.Before(last) { last = t @@ -396,7 +397,7 @@ func (lim *Limiter) advance(t time.Time) (newT time.Time, newTokens float64) { if burst := float64(lim.burst); tokens > burst { tokens = burst } - return t, tokens + return tokens } // durationFromTokens is a unit conversion function from the number of tokens to the duration diff --git a/vendor/modules.txt b/vendor/modules.txt index bb8c6133d4..96f399040a 100644 --- a/vendor/modules.txt +++ b/vendor/modules.txt @@ -703,6 +703,11 @@ github.com/google/go-querystring/query # github.com/google/go-tika v0.3.1 ## explicit; go 1.11 github.com/google/go-tika/tika +# github.com/google/go-tpm v0.9.3 +## explicit; go 1.22 +github.com/google/go-tpm/legacy/tpm2 +github.com/google/go-tpm/tpmutil +github.com/google/go-tpm/tpmutil/tbs # github.com/google/pprof v0.0.0-20241210010833-40e02aabc2ad ## explicit; go 1.22 github.com/google/pprof/profile @@ -988,7 +993,7 @@ github.com/munnerz/goautoneg # github.com/nats-io/jwt/v2 v2.7.3 ## explicit; go 1.22 github.com/nats-io/jwt/v2 -# github.com/nats-io/nats-server/v2 v2.10.26 +# github.com/nats-io/nats-server/v2 v2.11.0 ## explicit; go 1.23.0 github.com/nats-io/nats-server/v2/conf github.com/nats-io/nats-server/v2/internal/fastrand @@ -1002,6 +1007,8 @@ github.com/nats-io/nats-server/v2/server/gsl github.com/nats-io/nats-server/v2/server/pse github.com/nats-io/nats-server/v2/server/stree github.com/nats-io/nats-server/v2/server/sysmem +github.com/nats-io/nats-server/v2/server/thw +github.com/nats-io/nats-server/v2/server/tpm # github.com/nats-io/nats.go v1.39.1 ## explicit; go 1.22.0 github.com/nats-io/nats.go @@ -1057,7 +1064,7 @@ github.com/onsi/ginkgo/reporters/stenographer github.com/onsi/ginkgo/reporters/stenographer/support/go-colorable github.com/onsi/ginkgo/reporters/stenographer/support/go-isatty 
github.com/onsi/ginkgo/types -# github.com/onsi/ginkgo/v2 v2.23.1 +# github.com/onsi/ginkgo/v2 v2.23.2 ## explicit; go 1.23.0 github.com/onsi/ginkgo/v2 github.com/onsi/ginkgo/v2/config @@ -1191,7 +1198,7 @@ github.com/open-policy-agent/opa/v1/types github.com/open-policy-agent/opa/v1/util github.com/open-policy-agent/opa/v1/util/decoding github.com/open-policy-agent/opa/v1/version -# github.com/opencloud-eu/reva/v2 v2.28.1-0.20250320135948-a946c0d6d289 +# github.com/opencloud-eu/reva/v2 v2.28.1-0.20250321112659-61a430bfb4c5 ## explicit; go 1.24.1 github.com/opencloud-eu/reva/v2/cmd/revad/internal/grace github.com/opencloud-eu/reva/v2/cmd/revad/runtime @@ -2225,8 +2232,8 @@ golang.org/x/text/transform golang.org/x/text/unicode/bidi golang.org/x/text/unicode/norm golang.org/x/text/width -# golang.org/x/time v0.10.0 -## explicit; go 1.18 +# golang.org/x/time v0.11.0 +## explicit; go 1.23.0 golang.org/x/time/rate # golang.org/x/tools v0.31.0 ## explicit; go 1.23.0