build(deps): bump github.com/open-policy-agent/opa from 1.10.1 to 1.11.0

Bumps [github.com/open-policy-agent/opa](https://github.com/open-policy-agent/opa) from 1.10.1 to 1.11.0.
- [Release notes](https://github.com/open-policy-agent/opa/releases)
- [Changelog](https://github.com/open-policy-agent/opa/blob/main/CHANGELOG.md)
- [Commits](https://github.com/open-policy-agent/opa/compare/v1.10.1...v1.11.0)

---
updated-dependencies:
- dependency-name: github.com/open-policy-agent/opa
  dependency-version: 1.11.0
  dependency-type: direct:production
  update-type: version-update:semver-minor
...

Signed-off-by: dependabot[bot] <support@github.com>
This commit is contained in:
dependabot[bot]
2025-12-17 15:26:39 +00:00
committed by Ralf Haferkamp
parent 5ed944dfc3
commit 82c82f8ae2
94 changed files with 6802 additions and 1946 deletions

16
go.mod
View File

@@ -61,7 +61,7 @@ require (
github.com/onsi/ginkgo v1.16.5
github.com/onsi/ginkgo/v2 v2.27.2
github.com/onsi/gomega v1.38.2
github.com/open-policy-agent/opa v1.10.1
github.com/open-policy-agent/opa v1.11.1
github.com/opencloud-eu/icap-client v0.0.0-20250930132611-28a2afe62d89
github.com/opencloud-eu/libre-graph-api-go v1.0.8-0.20250724122329-41ba6b191e76
github.com/opencloud-eu/reva/v2 v2.41.0
@@ -172,7 +172,7 @@ require (
github.com/containerd/errdefs v1.0.0 // indirect
github.com/containerd/errdefs/pkg v0.3.0 // indirect
github.com/containerd/log v0.1.0 // indirect
github.com/containerd/platforms v1.0.0-rc.1 // indirect
github.com/containerd/platforms v1.0.0-rc.2 // indirect
github.com/coreos/go-semver v0.3.1 // indirect
github.com/coreos/go-systemd/v22 v22.6.0 // indirect
github.com/cornelk/hashmap v1.0.8 // indirect
@@ -237,7 +237,7 @@ require (
github.com/gofrs/uuid v4.4.0+incompatible // indirect
github.com/gogo/protobuf v1.3.2 // indirect
github.com/golang-jwt/jwt/v4 v4.5.2 // indirect
github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da // indirect
github.com/golang/groupcache v0.0.0-20241129210726-2c02b8208cf8 // indirect
github.com/golang/snappy v0.0.4 // indirect
github.com/gomodule/redigo v1.9.3 // indirect
github.com/google/go-querystring v1.1.0 // indirect
@@ -271,7 +271,7 @@ require (
github.com/lestrrat-go/dsig-secp256k1 v1.0.0 // indirect
github.com/lestrrat-go/httpcc v1.0.1 // indirect
github.com/lestrrat-go/httprc/v3 v3.0.1 // indirect
github.com/lestrrat-go/jwx/v3 v3.0.11 // indirect
github.com/lestrrat-go/jwx/v3 v3.0.12 // indirect
github.com/lestrrat-go/option v1.0.1 // indirect
github.com/lestrrat-go/option/v2 v2.0.0 // indirect
github.com/libregraph/oidc-go v1.1.0 // indirect
@@ -302,7 +302,7 @@ require (
github.com/moby/sys/userns v0.1.0 // indirect
github.com/moby/term v0.5.0 // indirect
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect
github.com/modern-go/reflect2 v1.0.2 // indirect
github.com/modern-go/reflect2 v1.0.3-0.20250322232337-35a7c28c31ee // indirect
github.com/morikuni/aec v1.0.0 // indirect
github.com/mschoch/smat v0.2.0 // indirect
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect
@@ -341,7 +341,7 @@ require (
github.com/samber/lo v1.51.0 // indirect
github.com/samber/slog-common v0.19.0 // indirect
github.com/samber/slog-zerolog/v2 v2.9.0 // indirect
github.com/segmentio/asm v1.2.0 // indirect
github.com/segmentio/asm v1.2.1 // indirect
github.com/segmentio/kafka-go v0.4.49 // indirect
github.com/segmentio/ksuid v1.0.4 // indirect
github.com/sercand/kuberesolver/v5 v5.1.1 // indirect
@@ -367,9 +367,9 @@ require (
github.com/tklauser/numcpus v0.8.0 // indirect
github.com/toorop/go-dkim v0.0.0-20201103131630-e1cd1a0a5208 // indirect
github.com/trustelem/zxcvbn v1.0.1 // indirect
github.com/urfave/cli/v2 v2.27.5 // indirect
github.com/urfave/cli/v2 v2.27.7 // indirect
github.com/valyala/fastjson v1.6.4 // indirect
github.com/vektah/gqlparser/v2 v2.5.30 // indirect
github.com/vektah/gqlparser/v2 v2.5.31 // indirect
github.com/vmihailenco/tagparser/v2 v2.0.0 // indirect
github.com/wk8/go-ordered-map v1.0.0 // indirect
github.com/xanzy/ssh-agent v0.3.3 // indirect

34
go.sum
View File

@@ -198,8 +198,8 @@ github.com/bufbuild/protocompile v0.14.1 h1:iA73zAf/fyljNjQKwYzUHD6AD4R8KMasmwa/
github.com/bufbuild/protocompile v0.14.1/go.mod h1:ppVdAIhbr2H8asPk6k4pY7t9zB1OU5DoEw9xY/FUi1c=
github.com/butonic/go-micro/v4 v4.11.1-0.20241115112658-b5d4de5ed9b3 h1:h8Z0hBv5tg/uZMKu8V47+DKWYVQg0lYP8lXDQq7uRpE=
github.com/butonic/go-micro/v4 v4.11.1-0.20241115112658-b5d4de5ed9b3/go.mod h1:eE/tD53n3KbVrzrWxKLxdkGw45Fg1qaNLWjpJMvIUF4=
github.com/bytecodealliance/wasmtime-go/v37 v37.0.0 h1:DPjdn2V3JhXHMoZ2ymRqGK+y1bDyr9wgpyYCvhjMky8=
github.com/bytecodealliance/wasmtime-go/v37 v37.0.0/go.mod h1:Pf1l2JCTUFMnOqDIwkjzx1qfVJ09xbaXETKgRVE4jZ0=
github.com/bytecodealliance/wasmtime-go/v39 v39.0.1 h1:RibaT47yiyCRxMOj/l2cvL8cWiWBSqDXHyqsa9sGcCE=
github.com/bytecodealliance/wasmtime-go/v39 v39.0.1/go.mod h1:miR4NYIEBXeDNamZIzpskhJ0z/p8al+lwMWylQ/ZJb4=
github.com/c-bata/go-prompt v0.2.5/go.mod h1:vFnjEGDIIA/Lib7giyE4E9c50Lvl8j0S+7FVlAwDAVw=
github.com/cenkalti/backoff v2.2.1+incompatible h1:tNowT99t7UNflLxfYYSlKYsBpXdEet03Pg2g16Swow4=
github.com/cenkalti/backoff v2.2.1+incompatible/go.mod h1:90ReRw6GdpyfrHakVjL/QHaoyV4aDUVVkXQJJJ3NXXM=
@@ -239,8 +239,8 @@ github.com/containerd/errdefs/pkg v0.3.0 h1:9IKJ06FvyNlexW690DXuQNx2KA2cUJXx151X
github.com/containerd/errdefs/pkg v0.3.0/go.mod h1:NJw6s9HwNuRhnjJhM7pylWwMyAkmCQvQ4GpJHEqRLVk=
github.com/containerd/log v0.1.0 h1:TCJt7ioM2cr/tfR8GPbGf9/VRAX8D2B4PjzCpfX540I=
github.com/containerd/log v0.1.0/go.mod h1:VRRf09a7mHDIRezVKTRCrOq78v577GXq3bSa3EhrzVo=
github.com/containerd/platforms v1.0.0-rc.1 h1:83KIq4yy1erSRgOVHNk1HYdPvzdJ5CnsWaRoJX4C41E=
github.com/containerd/platforms v1.0.0-rc.1/go.mod h1:J71L7B+aiM5SdIEqmd9wp6THLVRzJGXfNuWCZCllLA4=
github.com/containerd/platforms v1.0.0-rc.2 h1:0SPgaNZPVWGEi4grZdV8VRYQn78y+nm6acgLGv/QzE4=
github.com/containerd/platforms v1.0.0-rc.2/go.mod h1:J71L7B+aiM5SdIEqmd9wp6THLVRzJGXfNuWCZCllLA4=
github.com/coreos/bbolt v1.3.2/go.mod h1:iRUV2dpdMOn7Bo10OQBFzIJO9kkE559Wcmn+qkEiiKk=
github.com/coreos/etcd v3.3.13+incompatible/go.mod h1:uF7uidLiAD3TWHmW31ZFd/JWoc32PjwdhPthX9715RE=
github.com/coreos/go-oidc/v3 v3.17.0 h1:hWBGaQfbi0iVviX4ibC7bk8OKT5qNr4klBaCHVNvehc=
@@ -511,8 +511,9 @@ github.com/golang/groupcache v0.0.0-20190129154638-5b532d6fd5ef/go.mod h1:cIg4er
github.com/golang/groupcache v0.0.0-20190702054246-869f871628b6/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc=
github.com/golang/groupcache v0.0.0-20191227052852-215e87163ea7/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc=
github.com/golang/groupcache v0.0.0-20200121045136-8c9f03a8e57e/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc=
github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da h1:oI5xCqsCo564l8iNU+DwB5epxmsaqB+rhGL0m5jtYqE=
github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc=
github.com/golang/groupcache v0.0.0-20241129210726-2c02b8208cf8 h1:f+oWsMOmNPc8JmEHVZIycC7hBoQxHH9pNKQORJNozsQ=
github.com/golang/groupcache v0.0.0-20241129210726-2c02b8208cf8/go.mod h1:wcDNUvekVysuuOpQKo3191zZyTpiI6se1N1ULghS0sw=
github.com/golang/mock v1.1.1/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfbm0A=
github.com/golang/mock v1.2.0/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfbm0A=
github.com/golang/mock v1.3.1/go.mod h1:sBzyDLLjw3U8JLTeZvSv8jJB+tU5PVekmnlKIyFUx0Y=
@@ -776,8 +777,8 @@ github.com/lestrrat-go/httpcc v1.0.1 h1:ydWCStUeJLkpYyjLDHihupbn2tYmZ7m22BGkcvZZ
github.com/lestrrat-go/httpcc v1.0.1/go.mod h1:qiltp3Mt56+55GPVCbTdM9MlqhvzyuL6W/NMDA8vA5E=
github.com/lestrrat-go/httprc/v3 v3.0.1 h1:3n7Es68YYGZb2Jf+k//llA4FTZMl3yCwIjFIk4ubevI=
github.com/lestrrat-go/httprc/v3 v3.0.1/go.mod h1:2uAvmbXE4Xq8kAUjVrZOq1tZVYYYs5iP62Cmtru00xk=
github.com/lestrrat-go/jwx/v3 v3.0.11 h1:yEeUGNUuNjcez/Voxvr7XPTYNraSQTENJgtVTfwvG/w=
github.com/lestrrat-go/jwx/v3 v3.0.11/go.mod h1:XSOAh2SiXm0QgRe3DulLZLyt+wUuEdFo81zuKTLcvgQ=
github.com/lestrrat-go/jwx/v3 v3.0.12 h1:p25r68Y4KrbBdYjIsQweYxq794CtGCzcrc5dGzJIRjg=
github.com/lestrrat-go/jwx/v3 v3.0.12/go.mod h1:HiUSaNmMLXgZ08OmGBaPVvoZQgJVOQphSrGr5zMamS8=
github.com/lestrrat-go/option v1.0.1 h1:oAzP2fvZGQKWkvHa1/SAcFolBEca1oN+mQ7eooNBEYU=
github.com/lestrrat-go/option v1.0.1/go.mod h1:5ZHFbivi4xwXxhxY9XHDe2FHo6/Z7WWmtT7T5nBBp3I=
github.com/lestrrat-go/option/v2 v2.0.0 h1:XxrcaJESE1fokHy3FpaQ/cXW8ZsIdWcdFzzLOcID3Ss=
@@ -897,8 +898,9 @@ github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd h1:TRLaZ9cD/w
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
github.com/modern-go/reflect2 v0.0.0-20180701023420-4b7aa43c6742/go.mod h1:bx2lNnkwVCuqBIxFjflWJWanXIb3RllmbCylyMrvgv0=
github.com/modern-go/reflect2 v1.0.1/go.mod h1:bx2lNnkwVCuqBIxFjflWJWanXIb3RllmbCylyMrvgv0=
github.com/modern-go/reflect2 v1.0.2 h1:xBagoLtFs94CBntxluKeaWgTMpvLxC4ur3nMaC9Gz0M=
github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk=
github.com/modern-go/reflect2 v1.0.3-0.20250322232337-35a7c28c31ee h1:W5t00kpgFdJifH4BDsTlE89Zl93FEloxaWZfGcifgq8=
github.com/modern-go/reflect2 v1.0.3-0.20250322232337-35a7c28c31ee/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk=
github.com/mohae/deepcopy v0.0.0-20170929034955-c48cc78d4826 h1:RWengNIwukTxcDr9M+97sNutRR1RKhG96O6jWumTTnw=
github.com/mohae/deepcopy v0.0.0-20170929034955-c48cc78d4826/go.mod h1:TaXosZuwdSHYgviHp1DAtfrULt5eUgsSMsZf+YrPgl8=
github.com/morikuni/aec v1.0.0 h1:nP9CBfwrvYnBRgY6qfDQkygYDmYwOilePFkwzv4dU8A=
@@ -955,8 +957,8 @@ github.com/onsi/gomega v1.7.1/go.mod h1:XdKZgCCFLUoM/7CFJVPcG8C1xQ1AJ0vpAezJrB7J
github.com/onsi/gomega v1.10.1/go.mod h1:iN09h71vgCQne3DLsj+A5owkum+a2tYe+TOCB1ybHNo=
github.com/onsi/gomega v1.38.2 h1:eZCjf2xjZAqe+LeWvKb5weQ+NcPwX84kqJ0cZNxok2A=
github.com/onsi/gomega v1.38.2/go.mod h1:W2MJcYxRGV63b418Ai34Ud0hEdTVXq9NW9+Sx6uXf3k=
github.com/open-policy-agent/opa v1.10.1 h1:haIvxZSPky8HLjRrvQwWAjCPLg8JDFSZMbbG4yyUHgY=
github.com/open-policy-agent/opa v1.10.1/go.mod h1:7uPI3iRpOalJ0BhK6s1JALWPU9HvaV1XeBSSMZnr/PM=
github.com/open-policy-agent/opa v1.11.1 h1:4bMlG6DjRZTRAswRyF+KUCgxHu1Gsk0h9EbZ4W9REvM=
github.com/open-policy-agent/opa v1.11.1/go.mod h1:QimuJO4T3KYxWzrmAymqlFvsIanCjKrGjmmC8GgAdgE=
github.com/opencloud-eu/go-micro-plugins/v4/store/nats-js-kv v0.0.0-20250512152754-23325793059a h1:Sakl76blJAaM6NxylVkgSzktjo2dS504iDotEFJsh3M=
github.com/opencloud-eu/go-micro-plugins/v4/store/nats-js-kv v0.0.0-20250512152754-23325793059a/go.mod h1:pjcozWijkNPbEtX5SIQaxEW/h8VAVZYTLx+70bmB3LY=
github.com/opencloud-eu/icap-client v0.0.0-20250930132611-28a2afe62d89 h1:W1ms+lP5lUUIzjRGDg93WrQfZJZCaV1ZP3KeyXi8bzY=
@@ -1107,8 +1109,8 @@ github.com/samber/slog-zerolog/v2 v2.9.0 h1:6LkOabJmZdNLaUWkTC3IVVA+dq7b/V0FM6lz
github.com/samber/slog-zerolog/v2 v2.9.0/go.mod h1:gnQW9VnCfM34v2pRMUIGMsZOVbYLqY/v0Wxu6atSVGc=
github.com/scaleway/scaleway-sdk-go v1.0.0-beta.7.0.20210127161313-bd30bebeac4f/go.mod h1:CJJ5VAbozOl0yEw7nHB9+7BXTJbIn6h7W+f6Gau5IP8=
github.com/sean-/seed v0.0.0-20170313163322-e2103e2c3529/go.mod h1:DxrIzT+xaE7yg65j358z/aeFdxmN0P9QXhEzd20vsDc=
github.com/segmentio/asm v1.2.0 h1:9BQrFxC+YOHJlTlHGkTrFWf59nbL3XnCoFLTwDCI7ys=
github.com/segmentio/asm v1.2.0/go.mod h1:BqMnlJP91P8d+4ibuonYZw9mfnzI9HfxselHZr5aAcs=
github.com/segmentio/asm v1.2.1 h1:DTNbBqs57ioxAD4PrArqftgypG4/qNpXoJx8TVXxPR0=
github.com/segmentio/asm v1.2.1/go.mod h1:BqMnlJP91P8d+4ibuonYZw9mfnzI9HfxselHZr5aAcs=
github.com/segmentio/kafka-go v0.4.49 h1:GJiNX1d/g+kG6ljyJEoi9++PUMdXGAxb7JGPiDCuNmk=
github.com/segmentio/kafka-go v0.4.49/go.mod h1:Y1gn60kzLEEaW28YshXyk2+VCUKbJ3Qr6DrnT3i4+9E=
github.com/segmentio/ksuid v1.0.4 h1:sBo2BdShXjmcugAMwjugoGUdUV0pcxY5mW4xKRn3v4c=
@@ -1239,15 +1241,15 @@ github.com/tus/tusd/v2 v2.8.0/go.mod h1:3/zEOVQQIwmJhvNam8phV4x/UQt68ZmZiTzeuJUN
github.com/uber-go/atomic v1.3.2/go.mod h1:/Ct5t2lcmbJ4OSe/waGBoaVvVqtO0bmtfVNex1PFV8g=
github.com/urfave/cli v1.22.4/go.mod h1:Gos4lmkARVdJ6EkW0WaNv/tZAAMe9V7XWyB60NtXRu0=
github.com/urfave/cli/v2 v2.3.0/go.mod h1:LJmUH05zAU44vOAcrfzZQKsZbVcdbOG8rtL3/XcUArI=
github.com/urfave/cli/v2 v2.27.5 h1:WoHEJLdsXr6dDWoJgMq/CboDmyY/8HMMH1fTECbih+w=
github.com/urfave/cli/v2 v2.27.5/go.mod h1:3Sevf16NykTbInEnD0yKkjDAeZDS0A6bzhBH5hrMvTQ=
github.com/urfave/cli/v2 v2.27.7 h1:bH59vdhbjLv3LAvIu6gd0usJHgoTTPhCFib8qqOwXYU=
github.com/urfave/cli/v2 v2.27.7/go.mod h1:CyNAG/xg+iAOg0N4MPGZqVmv2rCoP267496AOXUZjA4=
github.com/valyala/bytebufferpool v1.0.0/go.mod h1:6bBcMArwyJ5K/AmCkWv1jt77kVWyCJ6HpOuEn7z0Csc=
github.com/valyala/fastjson v1.6.4 h1:uAUNq9Z6ymTgGhcm0UynUAB6tlbakBrz6CQFax3BXVQ=
github.com/valyala/fastjson v1.6.4/go.mod h1:CLCAqky6SMuOcxStkYQvblddUtoRxhYMGLrsQns1aXY=
github.com/valyala/fasttemplate v1.0.1/go.mod h1:UQGH1tvbgY+Nz5t2n7tXsz52dQxojPUpymEIMZ47gx8=
github.com/valyala/fasttemplate v1.1.0/go.mod h1:UQGH1tvbgY+Nz5t2n7tXsz52dQxojPUpymEIMZ47gx8=
github.com/vektah/gqlparser/v2 v2.5.30 h1:EqLwGAFLIzt1wpx1IPpY67DwUujF1OfzgEyDsLrN6kE=
github.com/vektah/gqlparser/v2 v2.5.30/go.mod h1:D1/VCZtV3LPnQrcPBeR/q5jkSQIPti0uYCP/RI0gIeo=
github.com/vektah/gqlparser/v2 v2.5.31 h1:YhWGA1mfTjID7qJhd1+Vxhpk5HTgydrGU9IgkWBTJ7k=
github.com/vektah/gqlparser/v2 v2.5.31/go.mod h1:c1I28gSOVNzlfc4WuDlqU7voQnsqI6OG2amkBAFmgts=
github.com/vinyldns/go-vinyldns v0.0.0-20200917153823-148a5f6b8f14/go.mod h1:RWc47jtnVuQv6+lY3c768WtXCas/Xi+U5UFc5xULmYg=
github.com/vmihailenco/msgpack/v5 v5.4.1 h1:cQriyiUvjTwOHg8QZaPihLWeRAAVoCpE00IUPn0Bjt8=
github.com/vmihailenco/msgpack/v5 v5.4.1/go.mod h1:GaZTsDaehaPpQVyxrf5mtQlH+pc21PIudVV/E3rRQok=

View File

@@ -38,5 +38,5 @@ func DefaultSpec() specs.Platform {
// Default returns the current platform's default platform specification.
func Default() MatchComparer {
return Only(DefaultSpec())
return &windowsMatchComparer{Matcher: NewMatcher(DefaultSpec())}
}

View File

@@ -42,18 +42,30 @@ const (
// rs5 (version 1809, codename "Redstone 5") corresponds to Windows Server
// 2019 (ltsc2019), and Windows 10 (October 2018 Update).
rs5 = 17763
// ltsc2019 (Windows Server 2019) is an alias for [RS5].
ltsc2019 = rs5
// v21H2Server corresponds to Windows Server 2022 (ltsc2022).
v21H2Server = 20348
// ltsc2022 (Windows Server 2022) is an alias for [v21H2Server]
ltsc2022 = v21H2Server
// v22H2Win11 corresponds to Windows 11 (2022 Update).
v22H2Win11 = 22621
// v23H2 is the 23H2 release in the Windows Server annual channel.
v23H2 = 25398
// Windows Server 2025 build 26100
v25H1Server = 26100
ltsc2025 = v25H1Server
)
// List of stable ABI compliant ltsc releases
// Note: List must be sorted in ascending order
var compatLTSCReleases = []uint16{
v21H2Server,
ltsc2022,
ltsc2025,
}
// CheckHostAndContainerCompat checks if given host and container
@@ -70,18 +82,27 @@ func checkWindowsHostAndContainerCompat(host, ctr windowsOSVersion) bool {
}
// If host is < WS 2022, exact version match is required
if host.Build < v21H2Server {
if host.Build < ltsc2022 {
return host.Build == ctr.Build
}
var supportedLtscRelease uint16
// Find the latest LTSC version that is earlier than the host version.
// This is the earliest version of container that the host can run.
//
// If the host version is an LTSC, then it supports compatibility with
// everything from the previous LTSC up to itself, so we want supportedLTSCRelease
// to be the previous entry.
//
// If no match is found, then we know that the host is LTSC2022 exactly,
// since we already checked that it's not less than LTSC2022.
var supportedLTSCRelease uint16 = ltsc2022
for i := len(compatLTSCReleases) - 1; i >= 0; i-- {
if host.Build >= compatLTSCReleases[i] {
supportedLtscRelease = compatLTSCReleases[i]
if host.Build > compatLTSCReleases[i] {
supportedLTSCRelease = compatLTSCReleases[i]
break
}
}
return ctr.Build >= supportedLtscRelease && ctr.Build <= host.Build
return supportedLTSCRelease <= ctr.Build && ctr.Build <= host.Build
}
func getWindowsOSVersion(osVersionPrefix string) windowsOSVersion {
@@ -114,18 +135,6 @@ func getWindowsOSVersion(osVersionPrefix string) windowsOSVersion {
}
}
func winRevision(v string) int {
parts := strings.Split(v, ".")
if len(parts) < 4 {
return 0
}
r, err := strconv.Atoi(parts[3])
if err != nil {
return 0
}
return r
}
type windowsVersionMatcher struct {
windowsOSVersion
}
@@ -149,8 +158,7 @@ type windowsMatchComparer struct {
func (c *windowsMatchComparer) Less(p1, p2 specs.Platform) bool {
m1, m2 := c.Match(p1), c.Match(p2)
if m1 && m2 {
r1, r2 := winRevision(p1.OSVersion), winRevision(p2.OSVersion)
return r1 > r2
return p1.OSVersion > p2.OSVersion
}
return m1 && !m2
}

View File

@@ -106,6 +106,9 @@ linters:
- revive
path: jwt/internal/types/
text: "var-naming: avoid meaningless package names"
- linters:
- godoclint
path: (^|/)internal/
paths:
- third_party$
- builtin$

View File

@@ -4,6 +4,49 @@ Changes
v3 has many incompatibilities with v2. To see the full list of differences between
v2 and v3, please read the Changes-v3.md file (https://github.com/lestrrat-go/jwx/blob/develop/v3/Changes-v3.md)
v3.0.12 20 Oct 2025
* [jwe] As part of the next change, now per-recipient headers that are empty
are no longer serialized in flattened JSON serialization.
* [jwe] Introduce `jwe.WithLegacyHeaderMerging(bool)` option to control header
merging behavior during JWE encryption. This only applies to flattened
JSON serialization.
Previously, when using flattened JSON serialization (i.e. you specified
JSON serialization via `jwe.WithJSON()` and only supplied one key), per-recipient
headers were merged into the protected headers during encryption, and then
were left to be included in the final serialization as-is. This caused duplicate
headers to be present in both the protected headers and the per-recipient headers.
Since there may be users who rely on this behavior already, instead of changing the
default behavior to fix this duplication, a new option to `jwe.Encrypt()` was added
to allow clearing the per-recipient headers after merging to leave the `"headers"`
field empty. This in effect makes the flattened JSON serialization more similar to
the compact serialization, where there are no per-recipient headers present, and
leaves the headers disjoint.
Note that in compact mode, there are no per-recipient headers and thus the
headers need to be merged regardless. In full JSON serialization, we never
merge the headers, so it is left up to the user to keep the headers disjoint.
* [jws] Calling the deprecated `jws.NewSigner()` function for the first time will cause
legacy signers to be loaded automatically. Previously, you had to explicitly
call `jws.Settings(jws.WithLegacySigners(true))` to enable legacy signers.
We incorrectly assumed that users would not be using `jws.NewSigner()`, and thus
disabled legacy signers by default. However, it turned out that some users
were using `jws.NewSigner()` in their code, which led to breakages in
existing code. In hindsight we should have known that any API made public before will
be used by _somebody_.
As a side effect, jws.Settings(jws.WithLegacySigners(...)) is now a no-op.
However, please do note that jws.Signer (and similar) objects were always intended to be
used for _registering_ new signing/verifying algorithms, and not for end users to actually
use them directly. If you are using them for other purposes, please consider changing
your code, as it is more than likely that we will somehow deprecate/remove/discourage
their use in the future.
v3.0.11 14 Sep 2025
* [jwk] Add `(jwk.Cache).Shutdown()` method that delegates to the httprc controller
object, to shutdown the cache.

View File

@@ -9,9 +9,9 @@ bazel_dep(name = "rules_go", version = "0.55.1")
bazel_dep(name = "gazelle", version = "0.44.0")
bazel_dep(name = "aspect_bazel_lib", version = "2.11.0")
# Go SDK setup - using Go 1.24.4 to match the toolchain in go.mod
# Go SDK setup from go.mod
go_sdk = use_extension("@rules_go//go:extensions.bzl", "go_sdk")
go_sdk.download(version = "1.24.4")
go_sdk.from_file(go_mod = "//:go.mod")
# Go dependencies from go.mod
go_deps = use_extension("@gazelle//:extensions.bzl", "go_deps")

View File

@@ -22,8 +22,9 @@ const _FormatKind_name = "InvalidFormatUnknownFormatJWEJWSJWKJWKSJWT"
var _FormatKind_index = [...]uint8{0, 13, 26, 29, 32, 35, 39, 42}
func (i FormatKind) String() string {
if i < 0 || i >= FormatKind(len(_FormatKind_index)-1) {
idx := int(i) - 0
if i < 0 || idx >= len(_FormatKind_index)-1 {
return "FormatKind(" + strconv.FormatInt(int64(i), 10) + ")"
}
return _FormatKind_name[_FormatKind_index[i]:_FormatKind_index[i+1]]
return _FormatKind_name[_FormatKind_index[idx]:_FormatKind_index[idx+1]]
}

View File

@@ -99,15 +99,20 @@ func (b *recipientBuilder) Build(r Recipient, cek []byte, calg jwa.ContentEncryp
rawKey = raw
}
// Extract ECDH-ES specific parameters if needed
// Extract ECDH-ES specific parameters if needed.
var apu, apv []byte
if b.headers != nil {
if val, ok := b.headers.AgreementPartyUInfo(); ok {
apu = val
}
if val, ok := b.headers.AgreementPartyVInfo(); ok {
apv = val
}
hdr := b.headers
if hdr == nil {
hdr = NewHeaders()
}
if val, ok := hdr.AgreementPartyUInfo(); ok {
apu = val
}
if val, ok := hdr.AgreementPartyVInfo(); ok {
apv = val
}
// Create the encrypter using the new jwebb pattern
@@ -116,20 +121,20 @@ func (b *recipientBuilder) Build(r Recipient, cek []byte, calg jwa.ContentEncryp
return nil, fmt.Errorf(`jwe.Encrypt: recipientBuilder: failed to create encrypter: %w`, err)
}
if hdrs := b.headers; hdrs != nil {
_ = r.SetHeaders(hdrs)
}
_ = r.SetHeaders(hdr)
if err := r.Headers().Set(AlgorithmKey, b.alg); err != nil {
// Populate headers with stuff that we automatically set
if err := hdr.Set(AlgorithmKey, b.alg); err != nil {
return nil, fmt.Errorf(`failed to set header: %w`, err)
}
if keyID != "" {
if err := r.Headers().Set(KeyIDKey, keyID); err != nil {
if err := hdr.Set(KeyIDKey, keyID); err != nil {
return nil, fmt.Errorf(`failed to set header: %w`, err)
}
}
// Handle the encrypted key
var rawCEK []byte
enckey, err := enc.EncryptKey(cek)
if err != nil {
@@ -143,8 +148,9 @@ func (b *recipientBuilder) Build(r Recipient, cek []byte, calg jwa.ContentEncryp
}
}
// finally, anything specific should go here
if hp, ok := enckey.(populater); ok {
if err := hp.Populate(r.Headers()); err != nil {
if err := hp.Populate(hdr); err != nil {
return nil, fmt.Errorf(`failed to populate: %w`, err)
}
}
@@ -154,7 +160,9 @@ func (b *recipientBuilder) Build(r Recipient, cek []byte, calg jwa.ContentEncryp
// Encrypt generates a JWE message for the given payload and returns
// it in serialized form, which can be in either compact or
// JSON format. Default is compact.
// JSON format. Default is compact. When JSON format is specified and
// there is only one recipient, the resulting serialization is
// automatically converted to flattened JSON serialization format.
//
// You must pass at least one key to `jwe.Encrypt()` by using `jwe.WithKey()`
// option.
@@ -172,6 +180,10 @@ func (b *recipientBuilder) Build(r Recipient, cek []byte, calg jwa.ContentEncryp
//
// Look for options that return `jwe.EncryptOption` or `jws.EncryptDecryptOption`
// for a complete list of options that can be passed to this function.
//
// As of v3.0.12, users can specify `jwe.WithLegacyHeaderMerging()` to
// disable header merging behavior that was the default prior to v3.0.12.
// Read the documentation for `jwe.WithLegacyHeaderMerging()` for more information.
func Encrypt(payload []byte, options ...EncryptOption) ([]byte, error) {
ec := encryptContextPool.Get()
defer encryptContextPool.Put(ec)
@@ -410,10 +422,26 @@ func (dc *decryptContext) decryptContent(msg *Message, alg jwa.KeyEncryptionAlgo
Tag(msg.tag).
CEK(dc.cek)
if v, ok := recipient.Headers().Algorithm(); !ok || v != alg {
// algorithms don't match
// The "alg" header can be in either protected/unprotected headers.
// prefer per-recipient headers (as it might be the case that the algorithm differs
// by each recipient), then look at protected headers.
var algMatched bool
for _, hdr := range []Headers{recipient.Headers(), protectedHeaders} {
v, ok := hdr.Algorithm()
if !ok {
continue
}
if v == alg {
algMatched = true
break
}
// if we found something but didn't match, it's a failure
return nil, fmt.Errorf(`jwe.Decrypt: key (%q) and recipient (%q) algorithms do not match`, alg, v)
}
if !algMatched {
return nil, fmt.Errorf(`jwe.Decrypt: failed to find "alg" header in either protected or per-recipient headers`)
}
h2, err := protectedHeaders.Clone()
if err != nil {
@@ -534,11 +562,12 @@ func (dc *decryptContext) decryptContent(msg *Message, alg jwa.KeyEncryptionAlgo
// encryptContext holds the state during JWE encryption, similar to JWS signContext
type encryptContext struct {
calg jwa.ContentEncryptionAlgorithm
compression jwa.CompressionAlgorithm
format int
builders []*recipientBuilder
protected Headers
calg jwa.ContentEncryptionAlgorithm
compression jwa.CompressionAlgorithm
format int
builders []*recipientBuilder
protected Headers
legacyHeaderMerging bool
}
var encryptContextPool = pool.New(allocEncryptContext, freeEncryptContext)
@@ -561,6 +590,7 @@ func freeEncryptContext(ec *encryptContext) *encryptContext {
}
func (ec *encryptContext) ProcessOptions(options []EncryptOption) error {
ec.legacyHeaderMerging = true
var mergeProtected bool
var useRawCEK bool
for _, option := range options {
@@ -577,7 +607,11 @@ func (ec *encryptContext) ProcessOptions(options []EncryptOption) error {
if v == jwa.DIRECT() || v == jwa.ECDH_ES() {
useRawCEK = true
}
ec.builders = append(ec.builders, &recipientBuilder{alg: v, key: wk.key, headers: wk.headers})
ec.builders = append(ec.builders, &recipientBuilder{
alg: v,
key: wk.key,
headers: wk.headers,
})
case identContentEncryptionAlgorithm{}:
var c jwa.ContentEncryptionAlgorithm
if err := option.Value(&c); err != nil {
@@ -616,6 +650,12 @@ func (ec *encryptContext) ProcessOptions(options []EncryptOption) error {
return err
}
ec.format = fmtOpt
case identLegacyHeaderMerging{}:
var v bool
if err := option.Value(&v); err != nil {
return err
}
ec.legacyHeaderMerging = v
}
}
@@ -732,7 +772,8 @@ func (ec *encryptContext) EncryptMessage(payload []byte, cek []byte) ([]byte, er
}
}
recipients := recipientSlicePool.GetCapacity(len(ec.builders))
lbuilders := len(ec.builders)
recipients := recipientSlicePool.GetCapacity(lbuilders)
defer recipientSlicePool.Put(recipients)
for i, builder := range ec.builders {
@@ -767,14 +808,55 @@ func (ec *encryptContext) EncryptMessage(payload []byte, cek []byte) ([]byte, er
}
}
// If there's only one recipient, you want to include that in the
// protected header
if len(recipients) == 1 {
// fmtCompact does not have per-recipient headers, nor a "header" field.
// In this mode, we're going to have to merge everything to the protected
// header.
if ec.format == fmtCompact {
// We have already established that the number of builders is 1 in
// ec.ProcessOptions(). But we're going to be pedantic
if lbuilders != 1 {
return nil, fmt.Errorf(`internal error: expected exactly one recipient builder (got %d)`, lbuilders)
}
// when we're using compact format, we can safely merge per-recipient
// headers into the protected header, if any
h, err := protected.Merge(recipients[0].Headers())
if err != nil {
return nil, fmt.Errorf(`failed to merge protected headers: %w`, err)
return nil, fmt.Errorf(`failed to merge protected headers for compact serialization: %w`, err)
}
protected = h
// per-recipient headers, if any, will be ignored in compact format
} else {
// If it got here, it's JSON (could be pretty mode, too).
if lbuilders == 1 {
// If it got here, then we're doing flattened JSON serialization.
// In this mode, we should merge per-recipient headers into the protected header,
// but we also need to make sure that the "header" field is reset so that
// it does not contain the same fields as the protected header.
//
// However, old behavior was to merge per-recipient headers into the
// protected header when there was only one recipient, AND leave the
// original "header" field as is, so we need to support that for backwards compatibility.
//
// The legacy merging only takes effect when there is exactly one recipient.
//
// This behavior can be disabled by passing jwe.WithLegacyHeaderMerging(false)
// If the user has explicitly asked for merging, do it
h, err := protected.Merge(recipients[0].Headers())
if err != nil {
return nil, fmt.Errorf(`failed to merge protected headers for flattenend JSON format: %w`, err)
}
protected = h
if !ec.legacyHeaderMerging {
// Clear per-recipient headers, since they have been merged.
// But we only do it when legacy merging is disabled.
// Note: we should probably introduce a Reset() method in v4
if err := recipients[0].SetHeaders(NewHeaders()); err != nil {
return nil, fmt.Errorf(`failed to clear per-recipient headers after merging: %w`, err)
}
}
}
}
aad, err := protected.Encode()

View File

@@ -265,14 +265,23 @@ func (m *Message) MarshalJSON() ([]byte, error) {
if recipients := m.Recipients(); len(recipients) > 0 {
if len(recipients) == 1 { // Use flattened format
if hdrs := recipients[0].Headers(); hdrs != nil {
buf.Reset()
if err := enc.Encode(hdrs); err != nil {
return nil, fmt.Errorf(`failed to encode %s field: %w`, HeadersKey, err)
var skipHeaders bool
if zeroer, ok := hdrs.(isZeroer); ok {
if zeroer.isZero() {
skipHeaders = true
}
}
if !skipHeaders {
buf.Reset()
if err := enc.Encode(hdrs); err != nil {
return nil, fmt.Errorf(`failed to encode %s field: %w`, HeadersKey, err)
}
fields = append(fields, jsonKV{
Key: HeadersKey,
Value: strings.TrimSpace(buf.String()),
})
}
fields = append(fields, jsonKV{
Key: HeadersKey,
Value: strings.TrimSpace(buf.String()),
})
}
if ek := recipients[0].EncryptedKey(); len(ek) > 0 {
@@ -369,13 +378,18 @@ func (m *Message) UnmarshalJSON(buf []byte) error {
// field. TODO: do both of these conditions need to meet, or just one?
if proxy.Headers != nil || len(proxy.EncryptedKey) > 0 {
recipient := NewRecipient()
hdrs := NewHeaders()
if err := json.Unmarshal(proxy.Headers, hdrs); err != nil {
return fmt.Errorf(`failed to decode headers field: %w`, err)
}
if err := recipient.SetHeaders(hdrs); err != nil {
return fmt.Errorf(`failed to set new headers: %w`, err)
// `"headers"` could be empty. If that's the case, just skip the
// following unmarshaling step
if proxy.Headers != nil {
hdrs := NewHeaders()
if err := json.Unmarshal(proxy.Headers, hdrs); err != nil {
return fmt.Errorf(`failed to decode headers field: %w`, err)
}
if err := recipient.SetHeaders(hdrs); err != nil {
return fmt.Errorf(`failed to set new headers: %w`, err)
}
}
if v := proxy.EncryptedKey; len(v) > 0 {

View File

@@ -6,8 +6,9 @@ import (
"github.com/lestrrat-go/option/v2"
)
// Specify contents of the protected header. Some fields such as
// "enc" and "zip" will be overwritten when encryption is performed.
// WithProtectedHeaders is used to specify contents of the protected header.
// Some fields such as "enc" and "zip" will be overwritten when encryption is
// performed.
//
// There is no equivalent for unprotected headers in this implementation
func WithProtectedHeaders(h Headers) EncryptOption {

View File

@@ -169,4 +169,42 @@ options:
If set to an invalid value, the default value is used.
In v2, this option was called MaxBufferSize.
This option has a global effect.
This option has a global effect.
- ident: LegacyHeaderMerging
interface: EncryptOption
argument_type: bool
option_name: WithLegacyHeaderMerging
comment: |
WithLegacyHeaderMerging specifies whether to perform legacy header merging
when encrypting a JWE message in JSON serialization, when there is a single recipient.
This behavior is enabled by default for backwards compatibility.
When a JWE message is encrypted in JSON serialization, and there is only
one recipient, this library automatically serializes the message in
flattened JSON serialization format. In older versions of this library,
the protected headers and the per-recipient headers were merged together
before computing the AAD (Additional Authenticated Data), but the per-recipient
headers were kept as-is in the `header` field of the recipient object.
This behavior is not compliant with the JWE specification, which states that
the headers must be disjoint.
Passing this option with a value of `false` disables this legacy behavior,
and while the per-recipient headers and protected headers are still merged
for the purpose of computing AAD, the per-recipient headers are cleared
after merging, so that the resulting JWE message is compliant with the
specification.
This option has no effect when there are multiple recipients, or when
the serialization format is compact serialization. For multiple recipients
(i.e. full JSON serialization), the protected headers and per-recipient
headers are never merged, and it is the caller's responsibility to ensure
that the headers are disjoint. In compact serialization, there are no per-recipient
headers; in fact, the protected headers are the only headers that exist,
and therefore there is no possibility of header collision after merging
(note: while per-recipient headers do not make sense in compact serialization,
this library does not prevent you from setting them -- they are all just
merged into the protected headers).
In future versions, the new behavior will be the default. New users are
encouraged to set this option to `false` now to avoid future issues.

View File

@@ -147,6 +147,7 @@ type identFS struct{}
type identKey struct{}
type identKeyProvider struct{}
type identKeyUsed struct{}
type identLegacyHeaderMerging struct{}
type identMaxDecompressBufferSize struct{}
type identMaxPBES2Count struct{}
type identMergeProtectedHeaders struct{}
@@ -193,6 +194,10 @@ func (identKeyUsed) String() string {
return "WithKeyUsed"
}
func (identLegacyHeaderMerging) String() string {
return "WithLegacyHeaderMerging"
}
func (identMaxDecompressBufferSize) String() string {
return "WithMaxDecompressBufferSize"
}
@@ -292,6 +297,43 @@ func WithKeyUsed(v any) DecryptOption {
return &decryptOption{option.New(identKeyUsed{}, v)}
}
// WithLegacyHeaderMerging specifies whether to perform legacy header merging
// when encrypting a JWE message in JSON serialization, when there is a single recipient.
// This behavior is enabled by default for backwards compatibility.
//
// When a JWE message is encrypted in JSON serialization, and there is only
// one recipient, this library automatically serializes the message in
// flattened JSON serialization format. In older versions of this library,
// the protected headers and the per-recipient headers were merged together
// before computing the AAD (Additional Authenticated Data), but the per-recipient
// headers were kept as-is in the `header` field of the recipient object.
//
// This behavior is not compliant with the JWE specification, which states that
// the headers must be disjoint.
//
// Passing this option with a value of `false` disables this legacy behavior,
// and while the per-recipient headers and protected headers are still merged
// for the purpose of computing AAD, the per-recipient headers are cleared
// after merging, so that the resulting JWE message is compliant with the
// specification.
//
// This option has no effect when there are multiple recipients, or when
// the serialization format is compact serialization. For multiple recipients
// (i.e. full JSON serialization), the protected headers and per-recipient
// headers are never merged, and it is the caller's responsibility to ensure
// that the headers are disjoint. In compact serialization, there are no per-recipient
// headers; in fact, the protected headers are the only headers that exist,
// and therefore there is no possibility of header collision after merging
// (note: while per-recipient headers do not make sense in compact serialization,
// this library does not prevent you from setting them -- they are all just
// merged into the protected headers).
//
// In future versions, the new behavior will be the default. New users are
// encouraged to set this option to `false` now to avoid future issues.
func WithLegacyHeaderMerging(v bool) EncryptOption {
return &encryptOption{option.New(identLegacyHeaderMerging{}, v)}
}
// WithMaxDecompressBufferSize specifies the maximum buffer size used when
// decompressing the payload of a JWE message. If a compressed JWE payload
// exceeds this amount when decompressed, jwe.Decrypt will return an error.

View File

@@ -270,7 +270,7 @@ func (cs *cachedSet) cached() (Set, error) {
return cs.r.Resource(), nil
}
// Add is a no-op for `jwk.CachedSet`, as the `jwk.Set` should be treated read-only
// AddKey is a no-op for `jwk.CachedSet`, as the `jwk.Set` should be treated read-only
func (*cachedSet) AddKey(_ Key) error {
return fmt.Errorf(`(jwk.Cachedset).AddKey: jwk.CachedSet is immutable`)
}

View File

@@ -40,7 +40,7 @@ type CachedFetcher struct {
cache *Cache
}
// Creates a new `jwk.CachedFetcher` object.
// NewCachedFetcher creates a new `jwk.CachedFetcher` object.
func NewCachedFetcher(cache *Cache) *CachedFetcher {
return &CachedFetcher{cache}
}

View File

@@ -118,7 +118,7 @@ func NewPEMDecoder() PEMDecoder {
type pemDecoder struct{}
// DecodePEM decodes a key in PEM encoded ASN.1 DER format.
// Decode decodes a key in PEM encoded ASN.1 DER format,
// and returns a raw key.
func (pemDecoder) Decode(src []byte) (any, []byte, error) {
block, rest := pem.Decode(src)

View File

@@ -586,11 +586,14 @@ func AlgorithmsForKey(key any) ([]jwa.SignatureAlgorithm, error) {
return algs, nil
}
// Settings allows you to set global settings for this JWS operations.
//
// Currently, the only setting available is `jws.WithLegacySigners()`,
// which for various reason is now a no-op.
func Settings(options ...GlobalOption) {
for _, option := range options {
switch option.Ident() {
case identLegacySigners{}:
enableLegacySigners()
}
}
}

View File

@@ -26,7 +26,7 @@ func (e headerNotFoundError) Is(target error) bool {
}
}
// ErrHeaderdNotFound returns an error that can be passed to `errors.Is` to check if the error is
// ErrHeaderNotFound returns an error that can be passed to `errors.Is` to check if the error is
// the result of the field not being found
func ErrHeaderNotFound() error {
return headerNotFoundError{}

View File

@@ -2,11 +2,14 @@ package jws
import (
"fmt"
"sync"
"github.com/lestrrat-go/jwx/v3/jwa"
"github.com/lestrrat-go/jwx/v3/jws/legacy"
)
var enableLegacySignersOnce = &sync.Once{}
func enableLegacySigners() {
for _, alg := range []jwa.SignatureAlgorithm{jwa.HS256(), jwa.HS384(), jwa.HS512()} {
if err := RegisterSigner(alg, func(alg jwa.SignatureAlgorithm) SignerFactory {
@@ -74,7 +77,7 @@ func legacySignerFor(alg jwa.SignatureAlgorithm) (Signer, error) {
muSigner.Lock()
s, ok := signers[alg]
if !ok {
v, err := NewSigner(alg)
v, err := newLegacySigner(alg)
if err != nil {
muSigner.Unlock()
return nil, fmt.Errorf(`failed to create payload signer: %w`, err)

View File

@@ -23,7 +23,7 @@ type Signer interface {
Algorithm() jwa.SignatureAlgorithm
}
// This is for legacy support only.
// Verifier is for legacy support only.
type Verifier interface {
// Verify checks whether the payload and signature are valid for
// the given key.

View File

@@ -38,7 +38,7 @@ type withKey struct {
public Headers
}
// This exists as an escape hatch to modify the header values after the fact
// Protected exists as an escape hatch to modify the header values after the fact
func (w *withKey) Protected(v Headers) Headers {
if w.protected == nil && v != nil {
w.protected = v
@@ -221,7 +221,7 @@ type withInsecureNoSignature struct {
protected Headers
}
// This exists as an escape hatch to modify the header values after the fact
// Protected exists as an escape hatch to modify the header values after the fact
func (w *withInsecureNoSignature) Protected(v Headers) Headers {
if w.protected == nil && v != nil {
w.protected = v

View File

@@ -227,8 +227,4 @@ options:
interface: GlobalOption
constant_value: true
comment: |
WithLegacySigners specifies whether the JWS package should use legacy
signers for signing JWS messages.
Usually there's no need to use this option, as the new signers and
verifiers are loaded by default.
WithLegacySigners is a no-op option that exists only for backwards compatibility.

View File

@@ -356,11 +356,7 @@ func WithKeyUsed(v any) VerifyOption {
return &verifyOption{option.New(identKeyUsed{}, v)}
}
// WithLegacySigners specifies whether the JWS package should use legacy
// signers for signing JWS messages.
//
// Usually there's no need to use this option, as the new signers and
// verifiers are loaded by default.
// WithLegacySigners is a no-op option that exists only for backwards compatibility.
func WithLegacySigners() GlobalOption {
return &globalOption{option.New(identLegacySigners{}, true)}
}

View File

@@ -2,6 +2,7 @@ package jws
import (
"fmt"
"strings"
"sync"
"github.com/lestrrat-go/jwx/v3/jwa"
@@ -33,6 +34,19 @@ func (fn SignerFactoryFn) Create() (Signer, error) {
return fn()
}
func init() {
// register the signers using jwsbb. These will be used by default.
for _, alg := range jwa.SignatureAlgorithms() {
if alg == jwa.NoSignature() {
continue
}
if err := RegisterSigner(alg, defaultSigner{alg: alg}); err != nil {
panic(fmt.Sprintf("RegisterSigner failed: %v", err))
}
}
}
// SignerFor returns a Signer2 for the given signature algorithm.
//
// Currently, this function will never fail. It will always return a
@@ -43,6 +57,9 @@ func (fn SignerFactoryFn) Create() (Signer, error) {
// 3. If no Signer2 or legacy Signer(Factory) is registered, it will return a
// default signer that uses jwsbb.Sign.
//
// 1 and 2 will take care of 99% of the cases. The only time 3 will happen is
// when you are using a custom algorithm that is not supported out of the box.
//
// jwsbb.Sign knows how to handle a static set of algorithms, so if the
// algorithm is not supported, it will return an error when you call
// `Sign` on the default signer.
@@ -80,6 +97,14 @@ var signerDB = make(map[jwa.SignatureAlgorithm]SignerFactory)
// Unlike the `UnregisterSigner` function, this function automatically
// calls `jwa.RegisterSignatureAlgorithm` to register the algorithm
// in this module's algorithm database.
//
// For backwards compatibility, this function also accepts
// `SignerFactory` implementations, but this usage is deprecated.
// You should use `Signer2` implementations instead.
//
// If you want to completely remove an algorithm, you must call
// `jwa.UnregisterSignatureAlgorithm` yourself after calling
// `UnregisterSigner`.
func RegisterSigner(alg jwa.SignatureAlgorithm, f any) error {
jwa.RegisterSignatureAlgorithm(alg)
switch s := f.(type) {
@@ -87,22 +112,10 @@ func RegisterSigner(alg jwa.SignatureAlgorithm, f any) error {
muSigner2DB.Lock()
signer2DB[alg] = s
muSigner2DB.Unlock()
// delete the other signer, if there was one
muSignerDB.Lock()
delete(signerDB, alg)
muSignerDB.Unlock()
case SignerFactory:
muSignerDB.Lock()
signerDB[alg] = s
muSignerDB.Unlock()
// Remove previous signer, if there was one
removeSigner(alg)
muSigner2DB.Lock()
delete(signer2DB, alg)
muSigner2DB.Unlock()
default:
return fmt.Errorf(`jws.RegisterSigner: unsupported type %T for algorithm %q`, f, alg)
}
@@ -132,11 +145,25 @@ func UnregisterSigner(alg jwa.SignatureAlgorithm) {
}
// NewSigner creates a signer that signs payloads using the given signature algorithm.
// This function is deprecated. You should use `SignerFor()` instead.
// This function is deprecated, and will either be removed or re-purposed using
// a different signature.
//
// This function only exists for backwards compatibility, but will not work
// unless you enable the legacy support mode by calling jws.Settings(jws.WithLegacySigners(true)).
// When you want to load a Signer object, you should use `SignerFor()` instead.
func NewSigner(alg jwa.SignatureAlgorithm) (Signer, error) {
s, err := newLegacySigner(alg)
if err == nil {
return s, nil
}
if strings.HasPrefix(err.Error(), `jws.NewSigner: unsupported signature algorithm`) {
// When newLegacySigner fails, automatically trigger to enable signers
enableLegacySignersOnce.Do(enableLegacySigners)
return newLegacySigner(alg)
}
return nil, err
}
func newLegacySigner(alg jwa.SignatureAlgorithm) (Signer, error) {
muSignerDB.RLock()
f, ok := signerDB[alg]
muSignerDB.RUnlock()

View File

@@ -66,7 +66,7 @@ func (o *TokenOptionSet) Enable(flag TokenOption) {
*o = TokenOptionSet(o.Value() | uint64(flag))
}
// Enable sets the appropriate value to disable the option in the
// Disable sets the appropriate value to disable the option in the
// option set
func (o *TokenOptionSet) Disable(flag TokenOption) {
*o = TokenOptionSet(o.Value() & ^uint64(flag))

View File

@@ -17,9 +17,9 @@ const _TokenOption_name = "FlattenAudienceMaxPerTokenOption"
var _TokenOption_index = [...]uint8{0, 15, 32}
func (i TokenOption) String() string {
i -= 1
if i >= TokenOption(len(_TokenOption_index)-1) {
return "TokenOption(" + strconv.FormatInt(int64(i+1), 10) + ")"
idx := int(i) - 1
if i < 1 || idx >= len(_TokenOption_index)-1 {
return "TokenOption(" + strconv.FormatInt(int64(i), 10) + ")"
}
return _TokenOption_name[_TokenOption_index[i]:_TokenOption_index[i+1]]
return _TokenOption_name[_TokenOption_index[idx]:_TokenOption_index[idx+1]]
}

View File

@@ -6,10 +6,12 @@ import (
)
type safeType struct {
reflect.Type
cfg *frozenConfig
Type reflect.Type
cfg *frozenConfig
}
var _ Type = &safeType{}
func (type2 *safeType) New() interface{} {
return reflect.New(type2.Type).Interface()
}
@@ -18,6 +20,22 @@ func (type2 *safeType) UnsafeNew() unsafe.Pointer {
panic("does not support unsafe operation")
}
func (type2 *safeType) Kind() reflect.Kind {
return type2.Type.Kind()
}
func (type2 *safeType) Len() int {
return type2.Type.Len()
}
func (type2 *safeType) NumField() int {
return type2.Type.NumField()
}
func (type2 *safeType) String() string {
return type2.Type.String()
}
func (type2 *safeType) Elem() Type {
return type2.cfg.Type2(type2.Type.Elem())
}

View File

@@ -10,16 +10,19 @@ import v1 "github.com/open-policy-agent/opa/v1/ast"
// can return a Visitor w which will be used to visit the children of the AST
// element v. If the Visit function returns nil, the children will not be
// visited.
//
// Deprecated: use GenericVisitor or another visitor implementation
type Visitor = v1.Visitor
// BeforeAndAfterVisitor wraps Visitor to provide hooks for being called before
// and after the AST has been visited.
//
// Deprecated: use GenericVisitor or another visitor implementation
type BeforeAndAfterVisitor = v1.BeforeAndAfterVisitor
// Walk iterates the AST by calling the Visit function on the Visitor
// v for x before recursing.
//
// Deprecated: use GenericVisitor.Walk
func Walk(v Visitor, x any) {
v1.Walk(v, x)
@@ -27,6 +30,7 @@ func Walk(v Visitor, x any) {
// WalkBeforeAndAfter iterates the AST by calling the Visit function on the
// Visitor v for x before recursing.
//
// Deprecated: use GenericVisitor.Walk
func WalkBeforeAndAfter(v BeforeAndAfterVisitor, x any) {
v1.WalkBeforeAndAfter(v, x)

View File

@@ -100,24 +100,28 @@ func Deactivate(opts *DeactivateOpts) error {
}
// LegacyWriteManifestToStore will write the bundle manifest to the older single (unnamed) bundle manifest location.
//
// Deprecated: Use WriteManifestToStore and named bundles instead.
func LegacyWriteManifestToStore(ctx context.Context, store storage.Store, txn storage.Transaction, manifest Manifest) error {
return v1.LegacyWriteManifestToStore(ctx, store, txn, manifest)
}
// LegacyEraseManifestFromStore will erase the bundle manifest from the older single (unnamed) bundle manifest location.
//
// Deprecated: Use WriteManifestToStore and named bundles instead.
func LegacyEraseManifestFromStore(ctx context.Context, store storage.Store, txn storage.Transaction) error {
return v1.LegacyEraseManifestFromStore(ctx, store, txn)
}
// LegacyReadRevisionFromStore will read the bundle manifest revision from the older single (unnamed) bundle manifest location.
//
// Deprecated: Use ReadBundleRevisionFromStore and named bundles instead.
func LegacyReadRevisionFromStore(ctx context.Context, store storage.Store, txn storage.Transaction) (string, error) {
return v1.LegacyReadRevisionFromStore(ctx, store, txn)
}
// ActivateLegacy calls Activate for the bundles but will also write their manifest to the older unnamed store location.
//
// Deprecated: Use Activate with named bundles instead.
func ActivateLegacy(opts *ActivateOpts) error {
return v1.ActivateLegacy(opts)

View File

@@ -40,7 +40,8 @@
"type": "boolean"
},
"type": "function"
}
},
"deprecated": true
},
{
"name": "and",
@@ -95,7 +96,8 @@
"type": "boolean"
},
"type": "function"
}
},
"deprecated": true
},
{
"name": "array.concat",
@@ -385,7 +387,8 @@
"type": "array"
},
"type": "function"
}
},
"deprecated": true
},
{
"name": "cast_boolean",
@@ -399,7 +402,8 @@
"type": "boolean"
},
"type": "function"
}
},
"deprecated": true
},
{
"name": "cast_null",
@@ -413,7 +417,8 @@
"type": "null"
},
"type": "function"
}
},
"deprecated": true
},
{
"name": "cast_object",
@@ -435,7 +440,8 @@
"type": "object"
},
"type": "function"
}
},
"deprecated": true
},
{
"name": "cast_set",
@@ -452,7 +458,8 @@
"type": "set"
},
"type": "function"
}
},
"deprecated": true
},
{
"name": "cast_string",
@@ -466,7 +473,8 @@
"type": "string"
},
"type": "function"
}
},
"deprecated": true
},
{
"name": "ceil",
@@ -2975,7 +2983,8 @@
"type": "boolean"
},
"type": "function"
}
},
"deprecated": true
},
{
"name": "net.lookup_ip_addr",
@@ -3493,7 +3502,8 @@
"type": "boolean"
},
"type": "function"
}
},
"deprecated": true
},
{
"name": "regex.find_all_string_submatch_n",
@@ -3808,7 +3818,8 @@
"type": "set"
},
"type": "function"
}
},
"deprecated": true
},
{
"name": "sort",

View File

File diff suppressed because it is too large Load Diff

View File

@@ -32,7 +32,7 @@ const (
opaWasmABIMinorVersionVar = "opa_wasm_abi_minor_version"
)
// nolint: deadcode,varcheck
// nolint: varcheck
const (
opaTypeNull int32 = iota + 1
opaTypeBoolean
@@ -414,7 +414,7 @@ func (c *Compiler) initModule() error {
},
},
},
Init: bytes.Repeat([]byte{0}, int(heapBase-offset)),
Init: make([]byte, int(heapBase-offset)),
})
return nil
@@ -1058,9 +1058,11 @@ func (c *Compiler) compileBlock(block *ir.Block) ([]instruction.Instruction, err
},
})
case *ir.AssignIntStmt:
instrs = append(instrs, instruction.GetLocal{Index: c.local(stmt.Target)})
instrs = append(instrs, instruction.I64Const{Value: stmt.Value})
instrs = append(instrs, instruction.Call{Index: c.function(opaValueNumberSetInt)})
instrs = append(instrs,
instruction.GetLocal{Index: c.local(stmt.Target)},
instruction.I64Const{Value: stmt.Value},
instruction.Call{Index: c.function(opaValueNumberSetInt)},
)
case *ir.ScanStmt:
if err := c.compileScan(stmt, &instrs); err != nil {
return nil, err
@@ -1073,12 +1075,14 @@ func (c *Compiler) compileBlock(block *ir.Block) ([]instruction.Instruction, err
}
case *ir.DotStmt:
if loc, ok := stmt.Source.Value.(ir.Local); ok {
instrs = append(instrs, instruction.GetLocal{Index: c.local(loc)})
instrs = append(instrs, c.instrRead(stmt.Key))
instrs = append(instrs, instruction.Call{Index: c.function(opaValueGet)})
instrs = append(instrs, instruction.TeeLocal{Index: c.local(stmt.Target)})
instrs = append(instrs, instruction.I32Eqz{})
instrs = append(instrs, instruction.BrIf{Index: 0})
instrs = append(instrs,
instruction.GetLocal{Index: c.local(loc)},
c.instrRead(stmt.Key),
instruction.Call{Index: c.function(opaValueGet)},
instruction.TeeLocal{Index: c.local(stmt.Target)},
instruction.I32Eqz{},
instruction.BrIf{Index: 0},
)
} else {
// Booleans and string sources would lead to the BrIf (since opa_value_get
// on them returns 0), so let's skip trying that.
@@ -1086,97 +1090,131 @@ func (c *Compiler) compileBlock(block *ir.Block) ([]instruction.Instruction, err
break
}
case *ir.LenStmt:
instrs = append(instrs, c.instrRead(stmt.Source))
instrs = append(instrs, instruction.Call{Index: c.function(opaValueLength)})
instrs = append(instrs, instruction.Call{Index: c.function(opaNumberSize)})
instrs = append(instrs, instruction.SetLocal{Index: c.local(stmt.Target)})
instrs = append(instrs,
c.instrRead(stmt.Source),
instruction.Call{Index: c.function(opaValueLength)},
instruction.Call{Index: c.function(opaNumberSize)},
instruction.SetLocal{Index: c.local(stmt.Target)},
)
case *ir.EqualStmt:
instrs = append(instrs, c.instrRead(stmt.A))
instrs = append(instrs, c.instrRead(stmt.B))
instrs = append(instrs, instruction.Call{Index: c.function(opaValueCompare)})
instrs = append(instrs, instruction.BrIf{Index: 0})
instrs = append(instrs,
c.instrRead(stmt.A),
c.instrRead(stmt.B),
instruction.Call{Index: c.function(opaValueCompare)},
instruction.BrIf{Index: 0},
)
case *ir.NotEqualStmt:
instrs = append(instrs, c.instrRead(stmt.A))
instrs = append(instrs, c.instrRead(stmt.B))
instrs = append(instrs, instruction.Call{Index: c.function(opaValueCompare)})
instrs = append(instrs, instruction.I32Eqz{})
instrs = append(instrs, instruction.BrIf{Index: 0})
instrs = append(instrs,
c.instrRead(stmt.A),
c.instrRead(stmt.B),
instruction.Call{Index: c.function(opaValueCompare)},
instruction.I32Eqz{},
instruction.BrIf{Index: 0},
)
case *ir.MakeNullStmt:
instrs = append(instrs, instruction.Call{Index: c.function(opaNull)})
instrs = append(instrs, instruction.SetLocal{Index: c.local(stmt.Target)})
instrs = append(instrs,
instruction.Call{Index: c.function(opaNull)},
instruction.SetLocal{Index: c.local(stmt.Target)},
)
case *ir.MakeNumberIntStmt:
instrs = append(instrs, instruction.I64Const{Value: stmt.Value})
instrs = append(instrs, instruction.Call{Index: c.function(opaNumberInt)})
instrs = append(instrs, instruction.SetLocal{Index: c.local(stmt.Target)})
instrs = append(instrs,
instruction.I64Const{Value: stmt.Value},
instruction.Call{Index: c.function(opaNumberInt)},
instruction.SetLocal{Index: c.local(stmt.Target)},
)
case *ir.MakeNumberRefStmt:
instrs = append(instrs, instruction.I32Const{Value: c.stringAddr(stmt.Index)})
instrs = append(instrs, instruction.I32Const{Value: int32(len(c.policy.Static.Strings[stmt.Index].Value))})
instrs = append(instrs, instruction.Call{Index: c.function(opaNumberRef)})
instrs = append(instrs, instruction.SetLocal{Index: c.local(stmt.Target)})
instrs = append(instrs,
instruction.I32Const{Value: c.stringAddr(stmt.Index)},
instruction.I32Const{Value: int32(len(c.policy.Static.Strings[stmt.Index].Value))},
instruction.Call{Index: c.function(opaNumberRef)},
instruction.SetLocal{Index: c.local(stmt.Target)},
)
case *ir.MakeArrayStmt:
instrs = append(instrs, instruction.I32Const{Value: stmt.Capacity})
instrs = append(instrs, instruction.Call{Index: c.function(opaArrayWithCap)})
instrs = append(instrs, instruction.SetLocal{Index: c.local(stmt.Target)})
instrs = append(instrs,
instruction.I32Const{Value: stmt.Capacity},
instruction.Call{Index: c.function(opaArrayWithCap)},
instruction.SetLocal{Index: c.local(stmt.Target)},
)
case *ir.MakeObjectStmt:
instrs = append(instrs, instruction.Call{Index: c.function(opaObject)})
instrs = append(instrs, instruction.SetLocal{Index: c.local(stmt.Target)})
instrs = append(instrs,
instruction.Call{Index: c.function(opaObject)},
instruction.SetLocal{Index: c.local(stmt.Target)},
)
case *ir.MakeSetStmt:
instrs = append(instrs, instruction.Call{Index: c.function(opaSet)})
instrs = append(instrs, instruction.SetLocal{Index: c.local(stmt.Target)})
instrs = append(instrs,
instruction.Call{Index: c.function(opaSet)},
instruction.SetLocal{Index: c.local(stmt.Target)},
)
case *ir.IsArrayStmt:
if loc, ok := stmt.Source.Value.(ir.Local); ok {
instrs = append(instrs, instruction.GetLocal{Index: c.local(loc)})
instrs = append(instrs, instruction.Call{Index: c.function(opaValueType)})
instrs = append(instrs, instruction.I32Const{Value: opaTypeArray})
instrs = append(instrs, instruction.I32Ne{})
instrs = append(instrs, instruction.BrIf{Index: 0})
instrs = append(instrs,
instruction.GetLocal{Index: c.local(loc)},
instruction.Call{Index: c.function(opaValueType)},
instruction.I32Const{Value: opaTypeArray},
instruction.I32Ne{},
instruction.BrIf{Index: 0},
)
} else {
instrs = append(instrs, instruction.Br{Index: 0})
break
}
case *ir.IsObjectStmt:
if loc, ok := stmt.Source.Value.(ir.Local); ok {
instrs = append(instrs, instruction.GetLocal{Index: c.local(loc)})
instrs = append(instrs, instruction.Call{Index: c.function(opaValueType)})
instrs = append(instrs, instruction.I32Const{Value: opaTypeObject})
instrs = append(instrs, instruction.I32Ne{})
instrs = append(instrs, instruction.BrIf{Index: 0})
instrs = append(instrs,
instruction.GetLocal{Index: c.local(loc)},
instruction.Call{Index: c.function(opaValueType)},
instruction.I32Const{Value: opaTypeObject},
instruction.I32Ne{},
instruction.BrIf{Index: 0},
)
} else {
instrs = append(instrs, instruction.Br{Index: 0})
break
}
case *ir.IsSetStmt:
if loc, ok := stmt.Source.Value.(ir.Local); ok {
instrs = append(instrs, instruction.GetLocal{Index: c.local(loc)})
instrs = append(instrs, instruction.Call{Index: c.function(opaValueType)})
instrs = append(instrs, instruction.I32Const{Value: opaTypeSet})
instrs = append(instrs, instruction.I32Ne{})
instrs = append(instrs, instruction.BrIf{Index: 0})
instrs = append(instrs,
instruction.GetLocal{Index: c.local(loc)},
instruction.Call{Index: c.function(opaValueType)},
instruction.I32Const{Value: opaTypeSet},
instruction.I32Ne{},
instruction.BrIf{Index: 0},
)
} else {
instrs = append(instrs, instruction.Br{Index: 0})
break
}
case *ir.IsUndefinedStmt:
instrs = append(instrs, instruction.GetLocal{Index: c.local(stmt.Source)})
instrs = append(instrs, instruction.I32Const{Value: 0})
instrs = append(instrs, instruction.I32Ne{})
instrs = append(instrs, instruction.BrIf{Index: 0})
instrs = append(instrs,
instruction.GetLocal{Index: c.local(stmt.Source)},
instruction.I32Const{Value: 0},
instruction.I32Ne{},
instruction.BrIf{Index: 0},
)
case *ir.ResetLocalStmt:
instrs = append(instrs, instruction.I32Const{Value: 0})
instrs = append(instrs, instruction.SetLocal{Index: c.local(stmt.Target)})
instrs = append(instrs,
instruction.I32Const{Value: 0},
instruction.SetLocal{Index: c.local(stmt.Target)},
)
case *ir.IsDefinedStmt:
instrs = append(instrs, instruction.GetLocal{Index: c.local(stmt.Source)})
instrs = append(instrs, instruction.I32Eqz{})
instrs = append(instrs, instruction.BrIf{Index: 0})
instrs = append(instrs,
instruction.GetLocal{Index: c.local(stmt.Source)},
instruction.I32Eqz{},
instruction.BrIf{Index: 0},
)
case *ir.ArrayAppendStmt:
instrs = append(instrs, instruction.GetLocal{Index: c.local(stmt.Array)})
instrs = append(instrs, c.instrRead(stmt.Value))
instrs = append(instrs, instruction.Call{Index: c.function(opaArrayAppend)})
instrs = append(instrs,
instruction.GetLocal{Index: c.local(stmt.Array)},
c.instrRead(stmt.Value),
instruction.Call{Index: c.function(opaArrayAppend)},
)
case *ir.ObjectInsertStmt:
instrs = append(instrs, instruction.GetLocal{Index: c.local(stmt.Object)})
instrs = append(instrs, c.instrRead(stmt.Key))
instrs = append(instrs, c.instrRead(stmt.Value))
instrs = append(instrs, instruction.Call{Index: c.function(opaObjectInsert)})
instrs = append(instrs,
instruction.GetLocal{Index: c.local(stmt.Object)},
c.instrRead(stmt.Key),
c.instrRead(stmt.Value),
instruction.Call{Index: c.function(opaObjectInsert)},
)
case *ir.ObjectInsertOnceStmt:
tmp := c.genLocal()
instrs = append(instrs, instruction.Block{
@@ -1203,14 +1241,18 @@ func (c *Compiler) compileBlock(block *ir.Block) ([]instruction.Instruction, err
},
})
case *ir.ObjectMergeStmt:
instrs = append(instrs, instruction.GetLocal{Index: c.local(stmt.A)})
instrs = append(instrs, instruction.GetLocal{Index: c.local(stmt.B)})
instrs = append(instrs, instruction.Call{Index: c.function(opaValueMerge)})
instrs = append(instrs, instruction.SetLocal{Index: c.local(stmt.Target)})
instrs = append(instrs,
instruction.GetLocal{Index: c.local(stmt.A)},
instruction.GetLocal{Index: c.local(stmt.B)},
instruction.Call{Index: c.function(opaValueMerge)},
instruction.SetLocal{Index: c.local(stmt.Target)},
)
case *ir.SetAddStmt:
instrs = append(instrs, instruction.GetLocal{Index: c.local(stmt.Set)})
instrs = append(instrs, c.instrRead(stmt.Value))
instrs = append(instrs, instruction.Call{Index: c.function(opaSetAdd)})
instrs = append(instrs,
instruction.GetLocal{Index: c.local(stmt.Set)},
c.instrRead(stmt.Value),
instruction.Call{Index: c.function(opaSetAdd)},
)
default:
var buf bytes.Buffer
err := ir.Pretty(&buf, stmt)
@@ -1226,8 +1268,7 @@ func (c *Compiler) compileBlock(block *ir.Block) ([]instruction.Instruction, err
func (c *Compiler) compileScan(scan *ir.ScanStmt, result *[]instruction.Instruction) error {
var instrs = *result
instrs = append(instrs, instruction.I32Const{Value: 0})
instrs = append(instrs, instruction.SetLocal{Index: c.local(scan.Key)})
instrs = append(instrs, instruction.I32Const{Value: 0}, instruction.SetLocal{Index: c.local(scan.Key)})
body, err := c.compileScanBlock(scan)
if err != nil {
return err
@@ -1242,23 +1283,21 @@ func (c *Compiler) compileScan(scan *ir.ScanStmt, result *[]instruction.Instruct
}
func (c *Compiler) compileScanBlock(scan *ir.ScanStmt) ([]instruction.Instruction, error) {
var instrs []instruction.Instruction
// Execute iterator.
instrs = append(instrs, instruction.GetLocal{Index: c.local(scan.Source)})
instrs = append(instrs, instruction.GetLocal{Index: c.local(scan.Key)})
instrs = append(instrs, instruction.Call{Index: c.function(opaValueIter)})
// Check for emptiness.
instrs = append(instrs, instruction.TeeLocal{Index: c.local(scan.Key)})
instrs = append(instrs, instruction.I32Eqz{})
instrs = append(instrs, instruction.BrIf{Index: 1})
// Load value.
instrs = append(instrs, instruction.GetLocal{Index: c.local(scan.Source)})
instrs = append(instrs, instruction.GetLocal{Index: c.local(scan.Key)})
instrs = append(instrs, instruction.Call{Index: c.function(opaValueGet)})
instrs = append(instrs, instruction.SetLocal{Index: c.local(scan.Value)})
instrs := []instruction.Instruction{
// Execute iterator.
instruction.GetLocal{Index: c.local(scan.Source)},
instruction.GetLocal{Index: c.local(scan.Key)},
instruction.Call{Index: c.function(opaValueIter)},
// Check for emptiness.
instruction.TeeLocal{Index: c.local(scan.Key)},
instruction.I32Eqz{},
instruction.BrIf{Index: 1},
// Load value.
instruction.GetLocal{Index: c.local(scan.Source)},
instruction.GetLocal{Index: c.local(scan.Key)},
instruction.Call{Index: c.function(opaValueGet)},
instruction.SetLocal{Index: c.local(scan.Value)},
}
// Loop body.
nested, err := c.compileBlock(scan.Block)
@@ -1278,8 +1317,7 @@ func (c *Compiler) compileNot(not *ir.NotStmt, result *[]instruction.Instruction
// generate and initialize condition variable
cond := c.genLocal()
instrs = append(instrs, instruction.I32Const{Value: 1})
instrs = append(instrs, instruction.SetLocal{Index: cond})
instrs = append(instrs, instruction.I32Const{Value: 1}, instruction.SetLocal{Index: cond})
nested, err := c.compileBlock(not.Block)
if err != nil {
@@ -1287,14 +1325,15 @@ func (c *Compiler) compileNot(not *ir.NotStmt, result *[]instruction.Instruction
}
// unset condition variable if end of block is reached
nested = append(nested, instruction.I32Const{Value: 0})
nested = append(nested, instruction.SetLocal{Index: cond})
instrs = append(instrs, instruction.Block{Instrs: nested})
// break out of block if condition variable was unset
instrs = append(instrs, instruction.GetLocal{Index: cond})
instrs = append(instrs, instruction.I32Eqz{})
instrs = append(instrs, instruction.BrIf{Index: 0})
instrs = append(instrs, instruction.Block{Instrs: append(nested,
instruction.I32Const{Value: 0},
instruction.SetLocal{Index: cond},
)},
// break out of block if condition variable was unset
instruction.GetLocal{Index: cond},
instruction.I32Eqz{},
instruction.BrIf{Index: 0},
)
*result = instrs
return nil
@@ -1304,34 +1343,36 @@ func (c *Compiler) compileWithStmt(with *ir.WithStmt, result *[]instruction.Inst
var instrs = *result
save := c.genLocal()
instrs = append(instrs, instruction.Call{Index: c.function(opaMemoizePush)})
instrs = append(instrs, instruction.GetLocal{Index: c.local(with.Local)})
instrs = append(instrs, instruction.SetLocal{Index: save})
instrs = append(instrs,
instruction.Call{Index: c.function(opaMemoizePush)},
instruction.GetLocal{Index: c.local(with.Local)},
instruction.SetLocal{Index: save},
)
if len(with.Path) == 0 {
instrs = append(instrs, c.instrRead(with.Value))
instrs = append(instrs, instruction.SetLocal{Index: c.local(with.Local)})
instrs = append(instrs, c.instrRead(with.Value), instruction.SetLocal{Index: c.local(with.Local)})
} else {
instrs = c.compileUpsert(with.Local, with.Path, with.Value, with.Location, instrs)
}
undefined := c.genLocal()
instrs = append(instrs, instruction.I32Const{Value: 1})
instrs = append(instrs, instruction.SetLocal{Index: undefined})
instrs = append(instrs, instruction.I32Const{Value: 1}, instruction.SetLocal{Index: undefined})
nested, err := c.compileBlock(with.Block)
if err != nil {
return err
}
nested = append(nested, instruction.I32Const{Value: 0})
nested = append(nested, instruction.SetLocal{Index: undefined})
instrs = append(instrs, instruction.Block{Instrs: nested})
instrs = append(instrs, instruction.GetLocal{Index: save})
instrs = append(instrs, instruction.SetLocal{Index: c.local(with.Local)})
instrs = append(instrs, instruction.Call{Index: c.function(opaMemoizePop)})
instrs = append(instrs, instruction.GetLocal{Index: undefined})
instrs = append(instrs, instruction.BrIf{Index: 0})
nested = append(nested, instruction.I32Const{Value: 0}, instruction.SetLocal{Index: undefined})
instrs = append(instrs,
instruction.Block{Instrs: nested},
instruction.GetLocal{Index: save},
instruction.SetLocal{Index: c.local(with.Local)},
instruction.Call{Index: c.function(opaMemoizePop)},
instruction.GetLocal{Index: undefined},
instruction.BrIf{Index: 0},
)
*result = instrs
@@ -1339,37 +1380,38 @@ func (c *Compiler) compileWithStmt(with *ir.WithStmt, result *[]instruction.Inst
}
func (c *Compiler) compileUpsert(local ir.Local, path []int, value ir.Operand, _ ir.Location, instrs []instruction.Instruction) []instruction.Instruction {
lcopy := c.genLocal() // holds copy of local
instrs = append(instrs, instruction.GetLocal{Index: c.local(local)})
instrs = append(instrs, instruction.SetLocal{Index: lcopy})
// Shallow copy the local if defined otherwise initialize to an empty object.
instrs = append(instrs, instruction.Block{
Instrs: []instruction.Instruction{
instruction.Block{Instrs: []instruction.Instruction{
instruction.GetLocal{Index: lcopy},
instruction.I32Eqz{},
instruction.BrIf{Index: 0},
instruction.GetLocal{Index: lcopy},
instruction.Call{Index: c.function(opaValueShallowCopy)},
instrs = append(instrs,
instruction.GetLocal{Index: c.local(local)},
instruction.SetLocal{Index: lcopy},
// Shallow copy the local if defined otherwise initialize to an empty object.
instruction.Block{
Instrs: []instruction.Instruction{
instruction.Block{Instrs: []instruction.Instruction{
instruction.GetLocal{Index: lcopy},
instruction.I32Eqz{},
instruction.BrIf{Index: 0},
instruction.GetLocal{Index: lcopy},
instruction.Call{Index: c.function(opaValueShallowCopy)},
instruction.TeeLocal{Index: lcopy},
instruction.SetLocal{Index: c.local(local)},
instruction.Br{Index: 1},
}},
instruction.Call{Index: c.function(opaObject)},
instruction.TeeLocal{Index: lcopy},
instruction.SetLocal{Index: c.local(local)},
instruction.Br{Index: 1},
}},
instruction.Call{Index: c.function(opaObject)},
instruction.TeeLocal{Index: lcopy},
instruction.SetLocal{Index: c.local(local)},
},
})
},
})
// Initialize the locals that specify the path of the upsert operation.
lpath := make(map[int]uint32, len(path))
for i := range path {
lpath[i] = c.genLocal()
instrs = append(instrs, instruction.I32Const{Value: c.opaStringAddr(path[i])})
instrs = append(instrs, instruction.SetLocal{Index: lpath[i]})
instrs = append(instrs,
instruction.I32Const{Value: c.opaStringAddr(path[i])},
instruction.SetLocal{Index: lpath[i]},
)
}
// Generate a block that traverses the path of the upsert operation,
@@ -1379,36 +1421,34 @@ func (c *Compiler) compileUpsert(local ir.Local, path []int, value ir.Operand, _
ltemp := c.genLocal()
for i := range len(path) - 1 {
// Lookup the next part of the path.
inner = append(inner, instruction.GetLocal{Index: lcopy})
inner = append(inner, instruction.GetLocal{Index: lpath[i]})
inner = append(inner, instruction.Call{Index: c.function(opaValueGet)})
inner = append(inner, instruction.SetLocal{Index: ltemp})
// If the next node is missing, break.
inner = append(inner, instruction.GetLocal{Index: ltemp})
inner = append(inner, instruction.I32Eqz{})
inner = append(inner, instruction.BrIf{Index: uint32(i)})
// If the next node is not an object, break.
inner = append(inner, instruction.GetLocal{Index: ltemp})
inner = append(inner, instruction.Call{Index: c.function(opaValueType)})
inner = append(inner, instruction.I32Const{Value: opaTypeObject})
inner = append(inner, instruction.I32Ne{})
inner = append(inner, instruction.BrIf{Index: uint32(i)})
// Otherwise, shallow copy the next node node and insert into the copy
// before continuing.
inner = append(inner, instruction.GetLocal{Index: ltemp})
inner = append(inner, instruction.Call{Index: c.function(opaValueShallowCopy)})
inner = append(inner, instruction.SetLocal{Index: ltemp})
inner = append(inner, instruction.GetLocal{Index: lcopy})
inner = append(inner, instruction.GetLocal{Index: lpath[i]})
inner = append(inner, instruction.GetLocal{Index: ltemp})
inner = append(inner, instruction.Call{Index: c.function(opaObjectInsert)})
inner = append(inner, instruction.GetLocal{Index: ltemp})
inner = append(inner, instruction.SetLocal{Index: lcopy})
inner = append(inner,
// Lookup the next part of the path.
instruction.GetLocal{Index: lcopy},
instruction.GetLocal{Index: lpath[i]},
instruction.Call{Index: c.function(opaValueGet)},
instruction.SetLocal{Index: ltemp},
// If the next node is missing, break.
instruction.GetLocal{Index: ltemp},
instruction.I32Eqz{},
instruction.BrIf{Index: uint32(i)},
// If the next node is not an object, break.
instruction.GetLocal{Index: ltemp},
instruction.Call{Index: c.function(opaValueType)},
instruction.I32Const{Value: opaTypeObject},
instruction.I32Ne{},
instruction.BrIf{Index: uint32(i)},
// Otherwise, shallow copy the next node node and insert into the copy
// before continuing.
instruction.GetLocal{Index: ltemp},
instruction.Call{Index: c.function(opaValueShallowCopy)},
instruction.SetLocal{Index: ltemp},
instruction.GetLocal{Index: lcopy},
instruction.GetLocal{Index: lpath[i]},
instruction.GetLocal{Index: ltemp},
instruction.Call{Index: c.function(opaObjectInsert)},
instruction.GetLocal{Index: ltemp},
instruction.SetLocal{Index: lcopy},
)
}
inner = append(inner, instruction.Br{Index: uint32(len(path) - 1)})
@@ -1418,27 +1458,29 @@ func (c *Compiler) compileUpsert(local ir.Local, path []int, value ir.Operand, _
lval := c.genLocal()
for i := range len(path) - 1 {
block = append(block, instruction.Block{Instrs: inner})
block = append(block, instruction.Call{Index: c.function(opaObject)})
block = append(block, instruction.SetLocal{Index: lval})
block = append(block, instruction.GetLocal{Index: lcopy})
block = append(block, instruction.GetLocal{Index: lpath[i]})
block = append(block, instruction.GetLocal{Index: lval})
block = append(block, instruction.Call{Index: c.function(opaObjectInsert)})
block = append(block, instruction.GetLocal{Index: lval})
block = append(block, instruction.SetLocal{Index: lcopy})
block = append(block,
instruction.Block{Instrs: inner},
instruction.Call{Index: c.function(opaObject)},
instruction.SetLocal{Index: lval},
instruction.GetLocal{Index: lcopy},
instruction.GetLocal{Index: lpath[i]},
instruction.GetLocal{Index: lval},
instruction.Call{Index: c.function(opaObjectInsert)},
instruction.GetLocal{Index: lval},
instruction.SetLocal{Index: lcopy},
)
inner = block
block = nil
}
// Finish by inserting the statement's value into the shallow copied node.
instrs = append(instrs, instruction.Block{Instrs: inner})
instrs = append(instrs, instruction.GetLocal{Index: lcopy})
instrs = append(instrs, instruction.GetLocal{Index: lpath[len(path)-1]})
instrs = append(instrs, c.instrRead(value))
instrs = append(instrs, instruction.Call{Index: c.function(opaObjectInsert)})
return instrs
return append(instrs,
instruction.Block{Instrs: inner},
instruction.GetLocal{Index: lcopy},
instruction.GetLocal{Index: lpath[len(path)-1]},
c.instrRead(value),
instruction.Call{Index: c.function(opaObjectInsert)},
)
}
func (c *Compiler) compileCallDynamicStmt(stmt *ir.CallDynamicStmt, result *[]instruction.Instruction) error {

View File

@@ -4,39 +4,72 @@ import (
"archive/tar"
"bytes"
"compress/gzip"
"encoding/json"
"errors"
"io"
"strings"
)
// MustWriteTarGz write the list of file names and content
// into a tarball.
func MustWriteTarGz(files [][2]string) *bytes.Buffer {
var buf bytes.Buffer
gw := gzip.NewWriter(&buf)
defer gw.Close()
tw := tar.NewWriter(gw)
defer tw.Close()
for _, file := range files {
if err := WriteFile(tw, file[0], []byte(file[1])); err != nil {
panic(err)
}
}
return &buf
type TarGzWriter struct {
*tar.Writer
gw *gzip.Writer
}
// WriteFile adds a file header with content to the given tar writer
func WriteFile(tw *tar.Writer, path string, bs []byte) error {
// NewTarGzWriter wraps w in a gzip stream and returns a TarGzWriter
// that writes tar entries into that compressed stream.
func NewTarGzWriter(w io.Writer) *TarGzWriter {
	zw := gzip.NewWriter(w)
	return &TarGzWriter{
		Writer: tar.NewWriter(zw),
		gw:     zw,
	}
}
func (tgw *TarGzWriter) WriteFile(path string, bs []byte) (err error) {
hdr := &tar.Header{
Name: "/" + strings.TrimLeft(path, "/"),
Name: path,
Mode: 0600,
Typeflag: tar.TypeReg,
Size: int64(len(bs)),
}
if err := tw.WriteHeader(hdr); err != nil {
if err = tgw.WriteHeader(hdr); err == nil {
_, err = tgw.Write(bs)
}
return err
}
func (tgw *TarGzWriter) WriteJSONFile(path string, v any) error {
buf := &bytes.Buffer{}
if err := json.NewEncoder(buf).Encode(v); err != nil {
return err
}
_, err := tw.Write(bs)
return err
return tgw.WriteFile(path, buf.Bytes())
}
// Close finalizes the archive. The tar writer is closed before the gzip
// writer so the tar trailer is flushed into the gzip stream before that
// stream itself is finalized; any errors from both are combined.
func (tgw *TarGzWriter) Close() error {
	tarErr := tgw.Writer.Close()
	gzipErr := tgw.gw.Close()
	return errors.Join(tarErr, gzipErr)
}
// MustWriteTarGz writes the list of file names and content into a tarball.
// Paths are prefixed with "/". It panics on any write error.
func MustWriteTarGz(files [][2]string) *bytes.Buffer {
	out := &bytes.Buffer{}
	tgw := NewTarGzWriter(out)
	// The deferred Close flushes the tar and gzip trailers into out
	// before the caller observes the buffer (same pointer is returned).
	defer tgw.Close()
	for _, entry := range files {
		// entry[0] is the path, entry[1] the file content; root the path at "/".
		name := entry[0]
		if !strings.HasPrefix(name, "/") {
			name = "/" + name
		}
		if err := tgw.WriteFile(name, []byte(entry[1])); err != nil {
			panic(err)
		}
	}
	return out
}

View File

@@ -23,7 +23,7 @@
//
// created 26-02-2013
// nolint: deadcode,unused,varcheck // Package in development (2021).
// nolint:unused,varcheck // Package in development (2021).
package gojsonschema
import (

View File

@@ -158,6 +158,8 @@ func SignV4(headers map[string][]string, method string, theURL *url.URL, body []
// include the values for the signed headers
orderedKeys := util.KeysSorted(headersToSign)
for _, k := range orderedKeys {
// TODO: fix later
//nolint:perfsprint
canonicalReq += k + ":" + strings.Join(headersToSign[k], ",") + "\n"
}
canonicalReq += "\n" // linefeed to terminate headers

View File

@@ -7,16 +7,16 @@ package ref
import (
"errors"
"strings"
"github.com/open-policy-agent/opa/v1/ast"
"github.com/open-policy-agent/opa/v1/storage"
"github.com/open-policy-agent/opa/v1/util"
)
// ParseDataPath returns a ref from the slash separated path s rooted at data.
// All path segments are treated as identifier strings.
func ParseDataPath(s string) (ast.Ref, error) {
path, ok := storage.ParsePath("/" + strings.TrimPrefix(s, "/"))
path, ok := storage.ParsePath(util.WithPrefix(s, "/"))
if !ok {
return nil, errors.New("invalid path")
}

View File

@@ -81,8 +81,6 @@ type GHResponse struct {
// New returns an instance of the Reporter
func New(opts Options) (Reporter, error) {
r := GHVersionCollector{}
url := cmp.Or(os.Getenv("OPA_TELEMETRY_SERVICE_URL"), ExternalServiceURL)
restConfig := fmt.Appendf(nil, `{
@@ -93,7 +91,7 @@ func New(opts Options) (Reporter, error) {
if err != nil {
return nil, err
}
r.client = client
r := GHVersionCollector{client: client}
// heap_usage_bytes is always present, so register it unconditionally
r.RegisterGatherer("heap_usage_bytes", readRuntimeMemStats)
@@ -135,19 +133,17 @@ func createDataResponse(ghResp GHResponse) (*DataResponse, error) {
return nil, errors.New("server response does not contain tag_name")
}
v := strings.TrimPrefix(version.Version, "v")
sv, err := semver.NewVersion(v)
sv, err := semver.Parse(version.Version)
if err != nil {
return nil, fmt.Errorf("failed to parse current version %q: %w", v, err)
return nil, fmt.Errorf("failed to parse current version %q: %w", version.Version, err)
}
latestV := strings.TrimPrefix(ghResp.TagName, "v")
latestSV, err := semver.NewVersion(latestV)
latestSV, err := semver.Parse(ghResp.TagName)
if err != nil {
return nil, fmt.Errorf("failed to parse latest version %q: %w", latestV, err)
return nil, fmt.Errorf("failed to parse latest version %q: %w", ghResp.TagName, err)
}
isLatest := sv.Compare(*latestSV) >= 0
isLatest := sv.Compare(latestSV) >= 0
// Note: alternatively, we could look through the assets in the GH API response to find a matching asset,
// and use its URL. However, this is not guaranteed to be more robust, and wouldn't use the 'openpolicyagent.org' domain.

View File

@@ -18,6 +18,7 @@ import (
"github.com/open-policy-agent/opa/v1/loader"
"github.com/open-policy-agent/opa/v1/metrics"
"github.com/open-policy-agent/opa/v1/storage"
"github.com/open-policy-agent/opa/v1/util"
)
// InsertAndCompileOptions contains the input for the operation.
@@ -246,13 +247,9 @@ func WalkPaths(paths []string, filter loader.Filter, asBundle bool) (*WalkPathsR
cleanedPath = fp
}
if !strings.HasPrefix(cleanedPath, "/") {
cleanedPath = "/" + cleanedPath
}
result.FileDescriptors = append(result.FileDescriptors, &Descriptor{
Root: path,
Path: cleanedPath,
Path: util.WithPrefix(cleanedPath, "/"),
})
}
}

View File

@@ -14,237 +14,234 @@
// Semantic Versions http://semver.org
// Package semver has been vendored from:
// This file was originally vendored from:
// https://github.com/coreos/go-semver/tree/e214231b295a8ea9479f11b70b35d5acf3556d9b/semver
// A number of the original functions of the package have been removed since
// they are not required for our built-ins.
// There isn't a single line left from the original source today, but being generous about
// attribution won't hurt.
package semver
import (
"bytes"
"fmt"
"regexp"
"strconv"
"strings"
"github.com/open-policy-agent/opa/v1/util"
)
// reMetaIdentifier matches pre-release and metadata identifiers against the spec requirements
var reMetaIdentifier = regexp.MustCompile(`^[0-9A-Za-z-]+(\.[0-9A-Za-z-]+)*$`)
// Version represents a parsed SemVer
type Version struct {
Major int64
Minor int64
Patch int64
PreRelease PreRelease
Metadata string
PreRelease string `json:"PreRelease,omitempty"`
Metadata string `json:"Metadata,omitempty"`
}
// PreRelease represents a pre-release suffix string
type PreRelease string
// Parse constructs new semver Version from version string.
func Parse(version string) (v Version, err error) {
version = strings.TrimPrefix(version, "v")
func splitOff(input *string, delim string) (val string) {
parts := strings.SplitN(*input, delim, 2)
if len(parts) == 2 {
*input = parts[0]
val = parts[1]
version, v.Metadata = cut(version, '+')
if v.Metadata != "" && !reMetaIdentifier.MatchString(v.Metadata) {
return v, fmt.Errorf("invalid metadata identifier: %s", v.Metadata)
}
return val
version, v.PreRelease = cut(version, '-')
if v.PreRelease != "" && !reMetaIdentifier.MatchString(v.PreRelease) {
return v, fmt.Errorf("invalid pre-release identifier: %s", v.PreRelease)
}
if strings.Count(version, ".") != 2 {
return v, fmt.Errorf("%s should contain major, minor, and patch versions", version)
}
major, after := cut(version, '.')
if v.Major, err = strconv.ParseInt(major, 10, 64); err != nil {
return v, err
}
minor, after := cut(after, '.')
if v.Minor, err = strconv.ParseInt(minor, 10, 64); err != nil {
return v, err
}
if v.Patch, err = strconv.ParseInt(after, 10, 64); err != nil {
return v, err
}
return v, nil
}
// NewVersion constructs new SemVers from strings
func NewVersion(version string) (*Version, error) {
v := Version{}
if err := v.Set(version); err != nil {
return nil, err
// MustParse is like Parse but panics if the version string is invalid instead of returning an error.
func MustParse(version string) Version {
v, err := Parse(version)
if err != nil {
panic(err)
}
return &v, nil
return v
}
// Set parses and updates v from the given version string. Implements flag.Value
func (v *Version) Set(version string) error {
metadata := splitOff(&version, "+")
preRelease := PreRelease(splitOff(&version, "-"))
dotParts := strings.SplitN(version, ".", 3)
if len(dotParts) != 3 {
return fmt.Errorf("%s is not in dotted-tri format", version)
}
if err := validateIdentifier(string(preRelease)); err != nil {
return fmt.Errorf("failed to validate pre-release: %v", err)
}
if err := validateIdentifier(metadata); err != nil {
return fmt.Errorf("failed to validate metadata: %v", err)
}
parsed := make([]int64, 3)
for i, v := range dotParts[:3] {
val, err := strconv.ParseInt(v, 10, 64)
parsed[i] = val
if err != nil {
return err
}
}
v.Metadata = metadata
v.PreRelease = preRelease
v.Major = parsed[0]
v.Minor = parsed[1]
v.Patch = parsed[2]
return nil
}
func (v Version) String() string {
var buffer bytes.Buffer
fmt.Fprintf(&buffer, "%d.%d.%d", v.Major, v.Minor, v.Patch)
if v.PreRelease != "" {
fmt.Fprintf(&buffer, "-%s", v.PreRelease)
}
if v.Metadata != "" {
fmt.Fprintf(&buffer, "+%s", v.Metadata)
}
return buffer.String()
}
// Compare tests if v is less than, equal to, or greater than versionB,
// returning -1, 0, or +1 respectively.
func (v Version) Compare(versionB Version) int {
if cmp := recursiveCompare(v.Slice(), versionB.Slice()); cmp != 0 {
return cmp
}
return preReleaseCompare(v, versionB)
}
// Slice converts the comparable parts of the semver into a slice of integers.
func (v Version) Slice() []int64 {
return []int64{v.Major, v.Minor, v.Patch}
}
// Slice splits the pre-release suffix string
func (p PreRelease) Slice() []string {
preRelease := string(p)
return strings.Split(preRelease, ".")
}
func preReleaseCompare(versionA Version, versionB Version) int {
a := versionA.PreRelease
b := versionB.PreRelease
/* Handle the case where if two versions are otherwise equal it is the
* one without a PreRelease that is greater */
if len(a) == 0 && (len(b) > 0) {
return 1
} else if len(b) == 0 && (len(a) > 0) {
return -1
}
// If there is a prerelease, check and compare each part.
return recursivePreReleaseCompare(a.Slice(), b.Slice())
}
func recursiveCompare(versionA []int64, versionB []int64) int {
if len(versionA) == 0 {
return 0
}
a := versionA[0]
b := versionB[0]
if a > b {
return 1
} else if a < b {
return -1
}
return recursiveCompare(versionA[1:], versionB[1:])
}
func recursivePreReleaseCompare(versionA []string, versionB []string) int {
// A larger set of pre-release fields has a higher precedence than a smaller set,
// if all of the preceding identifiers are equal.
if len(versionA) == 0 {
if len(versionB) > 0 {
return -1
}
return 0
} else if len(versionB) == 0 {
// We're longer than versionB so return 1.
return 1
}
a := versionA[0]
b := versionB[0]
aInt := false
bInt := false
aI, err := strconv.Atoi(versionA[0])
if err == nil {
aInt = true
}
bI, err := strconv.Atoi(versionB[0])
if err == nil {
bInt = true
}
// Numeric identifiers always have lower precedence than non-numeric identifiers.
if aInt && !bInt {
return -1
} else if !aInt && bInt {
return 1
}
// Handle Integer Comparison
if aInt && bInt {
if aI > bI {
return 1
} else if aI < bI {
return -1
}
}
// Handle String Comparison
if a > b {
return 1
} else if a < b {
return -1
}
return recursivePreReleaseCompare(versionA[1:], versionB[1:])
}
// validateIdentifier makes sure the provided identifier satisfies semver spec
func validateIdentifier(id string) error {
if id != "" && !reIdentifier.MatchString(id) {
return fmt.Errorf("%s is not a valid semver identifier", id)
}
return nil
}
// reIdentifier is a regular expression used to check that pre-release and metadata
// identifiers satisfy the spec requirements
var reIdentifier = regexp.MustCompile(`^[0-9A-Za-z-]+(\.[0-9A-Za-z-]+)*$`)
// Compare compares two semver strings.
func Compare(a, b string) int {
aV, err := NewVersion(strings.TrimPrefix(a, "v"))
aV, err := Parse(a)
if err != nil {
return -1
}
bV, err := NewVersion(strings.TrimPrefix(b, "v"))
bV, err := Parse(b)
if err != nil {
return 1
}
return aV.Compare(*bV)
return aV.Compare(bV)
}
// AppendText appends the textual representation of the version to b and returns the extended buffer.
// This method conforms to the encoding.TextAppender interface, and is useful for serializing the Version
// without allocating, provided the caller has pre-allocated sufficient space in b.
func (v Version) AppendText(b []byte) ([]byte, error) {
	if b == nil {
		// Size the buffer exactly so a single allocation suffices.
		b = make([]byte, 0, length(v))
	}
	b = strconv.AppendInt(b, v.Major, 10)
	b = append(b, '.')
	b = strconv.AppendInt(b, v.Minor, 10)
	b = append(b, '.')
	b = strconv.AppendInt(b, v.Patch, 10)
	if pre := v.PreRelease; pre != "" {
		b = append(b, '-')
		b = append(b, pre...)
	}
	if meta := v.Metadata; meta != "" {
		b = append(b, '+')
		b = append(b, meta...)
	}
	return b, nil
}
// String returns the string representation of the version,
// e.g. "1.2.3-beta.1+build.5".
func (v Version) String() string {
	// AppendText never returns an error; pre-size to avoid regrowth.
	out, _ := v.AppendText(make([]byte, 0, length(v)))
	return string(out)
}
// Compare tests if v is less than, equal to, or greater than other, returning -1, 0, or +1 respectively.
// Comparison is based on the SemVer specification (https://semver.org/#spec-item-11).
func (v Version) Compare(other Version) int {
	// Major, minor and patch are compared numerically, in that order.
	if v.Major > other.Major {
		return 1
	} else if v.Major < other.Major {
		return -1
	}
	if v.Minor > other.Minor {
		return 1
	} else if v.Minor < other.Minor {
		return -1
	}
	if v.Patch > other.Patch {
		return 1
	} else if v.Patch < other.Patch {
		return -1
	}
	// Identical pre-release strings (including both empty) mean the
	// versions are equal; build metadata is ignored for precedence.
	if v.PreRelease == other.PreRelease {
		return 0
	}
	// if two versions are otherwise equal it is the one without a pre-release that is greater
	if v.PreRelease == "" && other.PreRelease != "" {
		return 1
	}
	if other.PreRelease == "" && v.PreRelease != "" {
		return -1
	}
	// Both have different, non-empty pre-release strings: compare their
	// dot-separated identifiers pairwise, left to right.
	a, afterA := cut(v.PreRelease, '.')
	b, afterB := cut(other.PreRelease, '.')
	for {
		// Defensive: an empty identifier (e.g. from a trailing '.')
		// sorts below a non-empty one.
		if a == "" && b != "" {
			return -1
		}
		if a != "" && b == "" {
			return 1
		}
		aIsInt := isAllDecimals(a)
		bIsInt := isAllDecimals(b)
		// numeric identifiers have lower precedence than non-numeric
		if aIsInt && !bIsInt {
			return -1
		} else if !aIsInt && bIsInt {
			return 1
		}
		if aIsInt && bIsInt {
			// Numeric identifiers compare by value. Atoi errors are
			// ignored: isAllDecimals already guarantees all-decimal input
			// (overflow on absurdly long identifiers notwithstanding).
			aInt, _ := strconv.Atoi(a)
			bInt, _ := strconv.Atoi(b)
			if aInt > bInt {
				return 1
			} else if aInt < bInt {
				return -1
			}
		} else {
			// string comparison
			if a > b {
				return 1
			} else if a < b {
				return -1
			}
		}
		// a larger set of pre-release fields has a higher precedence than a
		// smaller set, if all of the preceding identifiers are equal.
		if afterA != "" && afterB == "" {
			return 1
		} else if afterA == "" && afterB != "" {
			return -1
		}
		// Advance to the next identifier pair. Termination: the two
		// pre-release strings differ (checked before the loop), so a
		// differing identifier or a length difference is eventually hit.
		a, afterA = cut(afterA, '.')
		b, afterB = cut(afterB, '.')
	}
}
// isAllDecimals reports whether s is non-empty and consists solely of
// ASCII decimal digits, i.e. is a numeric pre-release identifier.
func isAllDecimals(s string) bool {
	if len(s) == 0 {
		return false
	}
	for _, r := range s {
		if r < '0' || r > '9' {
			return false
		}
	}
	return true
}
// length computes the exact byte length of the textual form of v
// (major.minor.patch[-pre][+meta]) so callers can pre-allocate buffers.
func length(v Version) int {
	// Two bytes account for the dots separating major.minor.patch.
	size := util.NumDigitsInt64(v.Major) + util.NumDigitsInt64(v.Minor) + util.NumDigitsInt64(v.Patch) + 2
	if pre := len(v.PreRelease); pre > 0 {
		size += pre + 1 // '-' separator plus the identifier
	}
	if meta := len(v.Metadata); meta > 0 {
		size += meta + 1 // '+' separator plus the identifier
	}
	return size
}
// cut is a *slightly* faster version of strings.Cut only accepting
// single byte separators, and skipping the boolean return value.
func cut(s string, sep byte) (before, after string) {
	i := strings.IndexByte(s, sep)
	if i < 0 {
		// Separator absent: everything is "before", "after" is empty.
		return s, ""
	}
	return s[:i], s[i+1:]
}

View File

@@ -77,6 +77,7 @@ func Schemas(schemaPath string) (*ast.SchemaSet, error) {
}
// All returns a Result object loaded (recursively) from the specified paths.
//
// Deprecated: Use FileLoader.Filtered() instead.
func All(paths []string) (*Result, error) {
return NewFileLoader().Filtered(paths, nil)
@@ -85,6 +86,7 @@ func All(paths []string) (*Result, error) {
// Filtered returns a Result object loaded (recursively) from the specified
// paths while applying the given filters. If any filter returns true, the
// file/directory is excluded.
//
// Deprecated: Use FileLoader.Filtered() instead.
func Filtered(paths []string, filter Filter) (*Result, error) {
return NewFileLoader().Filtered(paths, filter)
@@ -93,6 +95,7 @@ func Filtered(paths []string, filter Filter) (*Result, error) {
// AsBundle loads a path as a bundle. If it is a single file
// it will be treated as a normal tarball bundle. If a directory
// is supplied it will be loaded as an unzipped bundle tree.
//
// Deprecated: Use FileLoader.AsBundle() instead.
func AsBundle(path string) (*bundle.Bundle, error) {
return NewFileLoader().AsBundle(path)

View File

@@ -68,6 +68,7 @@ func EvalInstrument(instrument bool) EvalOption {
}
// EvalTracer configures a tracer for a Prepared Query's evaluation
//
// Deprecated: Use EvalQueryTracer instead.
func EvalTracer(tracer topdown.Tracer) EvalOption {
return v1.EvalTracer(tracer)
@@ -441,6 +442,7 @@ func Trace(yes bool) func(r *Rego) {
}
// Tracer returns an argument that adds a query tracer to r.
//
// Deprecated: Use QueryTracer instead.
func Tracer(t topdown.Tracer) func(r *Rego) {
return v1.Tracer(t)

View File

@@ -752,10 +752,7 @@ func (c *CompileAnnotation) Compare(other *CompileAnnotation) int {
return -1
}
if cmp := slices.CompareFunc(c.Unknowns, other.Unknowns,
func(x, y Ref) int {
return x.Compare(y)
}); cmp != 0 {
if cmp := slices.CompareFunc(c.Unknowns, other.Unknowns, RefCompare); cmp != 0 {
return cmp
}
return c.MaskRule.Compare(other.MaskRule)

View File

@@ -26,11 +26,16 @@ func RegisterBuiltin(b *Builtin) {
BuiltinMap[b.Infix] = b
InternStringTerm(b.Infix)
InternVarValue(b.Infix)
}
InternStringTerm(b.Name)
if strings.Contains(b.Name, ".") {
InternStringTerm(strings.Split(b.Name, ".")...)
parts := strings.Split(b.Name, ".")
InternStringTerm(parts...)
InternVarValue(parts[0])
} else {
InternStringTerm(b.Name)
InternVarValue(b.Name)
}
}
@@ -3397,7 +3402,7 @@ var SetDiff = &Builtin{
),
types.SetOfAny,
),
deprecated: true,
Deprecated: true,
CanSkipBctx: true,
}
@@ -3411,7 +3416,7 @@ var NetCIDROverlap = &Builtin{
),
types.B,
),
deprecated: true,
Deprecated: true,
CanSkipBctx: true,
}
@@ -3423,7 +3428,7 @@ var CastArray = &Builtin{
types.Args(types.A),
types.NewArray(nil, types.A),
),
deprecated: true,
Deprecated: true,
CanSkipBctx: true,
}
@@ -3437,7 +3442,7 @@ var CastSet = &Builtin{
types.Args(types.A),
types.SetOfAny,
),
deprecated: true,
Deprecated: true,
CanSkipBctx: true,
}
@@ -3449,7 +3454,7 @@ var CastString = &Builtin{
types.Args(types.A),
types.S,
),
deprecated: true,
Deprecated: true,
CanSkipBctx: true,
}
@@ -3460,7 +3465,7 @@ var CastBoolean = &Builtin{
types.Args(types.A),
types.B,
),
deprecated: true,
Deprecated: true,
CanSkipBctx: true,
}
@@ -3471,7 +3476,7 @@ var CastNull = &Builtin{
types.Args(types.A),
types.Nl,
),
deprecated: true,
Deprecated: true,
CanSkipBctx: true,
}
@@ -3482,11 +3487,11 @@ var CastObject = &Builtin{
types.Args(types.A),
types.NewObject(nil, types.NewDynamicProperty(types.A, types.A)),
),
deprecated: true,
Deprecated: true,
CanSkipBctx: true,
}
// RegexMatchDeprecated declares `re_match` which has been deprecated. Use `regex.match` instead.
// RegexMatchDeprecated declares `re_match` which has been Deprecated. Use `regex.match` instead.
var RegexMatchDeprecated = &Builtin{
Name: "re_match",
Decl: types.NewFunction(
@@ -3496,7 +3501,7 @@ var RegexMatchDeprecated = &Builtin{
),
types.B,
),
deprecated: true,
Deprecated: true,
CanSkipBctx: false,
}
@@ -3513,7 +3518,7 @@ var All = &Builtin{
),
types.B,
),
deprecated: true,
Deprecated: true,
CanSkipBctx: true,
}
@@ -3530,7 +3535,7 @@ var Any = &Builtin{
),
types.B,
),
deprecated: true,
Deprecated: true,
CanSkipBctx: true,
}
@@ -3548,7 +3553,7 @@ type Builtin struct {
Decl *types.Function `json:"decl"` // Built-in function type declaration.
Infix string `json:"infix,omitempty"` // Unique name of infix operator. Default should be unset.
Relation bool `json:"relation,omitempty"` // Indicates if the built-in acts as a relation.
deprecated bool `json:"-"` // Indicates if the built-in has been deprecated.
Deprecated bool `json:"deprecated,omitempty"` // Indicates if the built-in has been deprecated.
CanSkipBctx bool `json:"-"` // Built-in needs no data from the built-in context.
Nondeterministic bool `json:"nondeterministic,omitempty"` // Indicates if the built-in returns non-deterministic results.
}
@@ -3573,12 +3578,12 @@ func (b *Builtin) Minimal() *Builtin {
return &cpy
}
// IsDeprecated returns true if the Builtin function is deprecated and will be removed in a future release.
// IsDeprecated returns true if the Builtin function is Deprecated and will be removed in a future release.
func (b *Builtin) IsDeprecated() bool {
return b.deprecated
return b.Deprecated
}
// IsDeterministic returns true if the Builtin function returns non-deterministic results.
// IsNondeterministic returns true if the Builtin function returns non-deterministic results.
func (b *Builtin) IsNondeterministic() bool {
return b.Nondeterministic
}

View File

@@ -228,13 +228,8 @@ func LoadCapabilitiesVersions() ([]string, error) {
// MinimumCompatibleVersion returns the minimum compatible OPA version based on
// the built-ins, features, and keywords in c.
func (c *Capabilities) MinimumCompatibleVersion() (string, bool) {
var maxVersion semver.Version
// this is the oldest OPA release that includes capabilities
if err := maxVersion.Set("0.17.0"); err != nil {
panic("unreachable")
}
maxVersion := semver.MustParse("0.17.0")
minVersionIndex := minVersionIndexOnce()
for _, bi := range c.Builtins {

View File

@@ -383,10 +383,6 @@ func (tc *typeChecker) checkExpr(env *TypeEnv, expr *Expr) *Error {
}
func (tc *typeChecker) checkExprBuiltin(env *TypeEnv, expr *Expr) *Error {
args := expr.Operands()
pre := getArgTypes(env, args)
// NOTE(tsandall): undefined functions will have been caught earlier in the
// compiler. We check for undefined functions before the safety check so
// that references to non-existent functions result in undefined function
@@ -424,12 +420,14 @@ func (tc *typeChecker) checkExprBuiltin(env *TypeEnv, expr *Expr) *Error {
namedFargs.Args = append(namedFargs.Args, ftpe.NamedResult())
}
args := expr.Operands()
if len(args) > len(fargs.Args) && fargs.Variadic == nil {
return newArgError(expr.Location, name, "too many arguments", pre, namedFargs)
return newArgError(expr.Location, name, "too many arguments", getArgTypes(env, args), namedFargs)
}
if len(args) < len(ftpe.FuncArgs().Args) {
return newArgError(expr.Location, name, "too few arguments", pre, namedFargs)
return newArgError(expr.Location, name, "too few arguments", getArgTypes(env, args), namedFargs)
}
for i := range args {

View File

@@ -440,6 +440,7 @@ func (c *Compiler) WithDebug(sink io.Writer) *Compiler {
}
// WithBuiltins is deprecated.
//
// Deprecated: Use WithCapabilities instead.
func (c *Compiler) WithBuiltins(builtins map[string]*Builtin) *Compiler {
c.customBuiltins = maps.Clone(builtins)
@@ -447,6 +448,7 @@ func (c *Compiler) WithBuiltins(builtins map[string]*Builtin) *Compiler {
}
// WithUnsafeBuiltins is deprecated.
//
// Deprecated: Use WithCapabilities instead.
func (c *Compiler) WithUnsafeBuiltins(unsafeBuiltins map[string]struct{}) *Compiler {
maps.Copy(c.unsafeBuiltinsMap, unsafeBuiltins)

View File

@@ -33,7 +33,8 @@ func CompileModulesWithOpt(modules map[string]string, opts CompileOpts) (*Compil
compiler := NewCompiler().
WithDefaultRegoVersion(opts.ParserOptions.RegoVersion).
WithEnablePrintStatements(opts.EnablePrintStatements)
WithEnablePrintStatements(opts.EnablePrintStatements).
WithCapabilities(opts.ParserOptions.Capabilities)
compiler.Compile(parsed)
if compiler.Failed() {

View File

@@ -29,6 +29,7 @@ func newTypeEnv(f func() *typeChecker) *TypeEnv {
}
// Get returns the type of x.
//
// Deprecated: Use GetByValue or GetByRef instead, as they are more efficient.
func (env *TypeEnv) Get(x any) types.Type {
if term, ok := x.(*Term); ok {

View File

@@ -99,19 +99,24 @@ func (e *Error) Error() string {
}
}
msg := fmt.Sprintf("%v: %v", e.Code, e.Message)
sb := strings.Builder{}
if len(prefix) > 0 {
msg = prefix + ": " + msg
sb.WriteString(prefix)
sb.WriteString(": ")
}
sb.WriteString(e.Code)
sb.WriteString(": ")
sb.WriteString(e.Message)
if e.Details != nil {
for _, line := range e.Details.Lines() {
msg += "\n\t" + line
sb.WriteString("\n\t")
sb.WriteString(line)
}
}
return msg
return sb.String()
}
// NewError returns a new Error object.

View File

@@ -884,7 +884,6 @@ func indexValue(b Value) (Value, bool) {
}
func globDelimiterToString(delim *Term) (string, bool) {
arr, ok := delim.Value.(*Array)
if !ok {
return "", false
@@ -895,14 +894,16 @@ func globDelimiterToString(delim *Term) (string, bool) {
if arr.Len() == 0 {
result = "."
} else {
sb := strings.Builder{}
for i := range arr.Len() {
term := arr.Elem(i)
s, ok := term.Value.(String)
if !ok {
return "", false
}
result += string(s)
sb.WriteString(string(s))
}
result = sb.String()
}
return result, true

View File

@@ -28,7 +28,7 @@ var (
InternedEmptyString = StringTerm("")
InternedEmptyObject = ObjectTerm()
InternedEmptyArray = ArrayTerm()
InternedEmptyArray = NewTerm(InternedEmptyArrayValue)
InternedEmptySet = SetTerm()
InternedEmptyArrayValue = NewArray()
@@ -40,6 +40,15 @@ var (
internedStringTerms = map[string]*Term{
"": InternedEmptyString,
}
internedVarValues = map[string]Value{
"input": Var("input"),
"data": Var("data"),
"key": Var("key"),
"value": Var("value"),
"i": Var("i"), "j": Var("j"), "k": Var("k"), "v": Var("v"), "x": Var("x"), "y": Var("y"), "z": Var("z"),
}
)
// InternStringTerm interns the given strings as terms. Note that Interning is
@@ -56,6 +65,20 @@ func InternStringTerm(str ...string) {
}
}
// InternVarValue interns the given variable names as Var Values. Note that Interning is
// considered experimental and should not be relied upon by external code.
// WARNING: This must **only** be called at initialization time, as the
// interned terms are shared globally, and the underlying map is not thread-safe.
func InternVarValue(names ...string) {
for _, name := range names {
if _, ok := internedVarValues[name]; ok {
continue
}
internedVarValues[name] = Var(name)
}
}
// HasInternedValue returns true if the given value is interned, otherwise false.
func HasInternedValue[T internable](v T) bool {
switch value := any(v).(type) {
@@ -94,6 +117,16 @@ func InternedValue[T internable](v T) Value {
return InternedValueOr(v, internedTermValue)
}
// InternedVarValue returns an interned Var Value for the given name. If the
// name is not interned, a new Var Value is returned.
func InternedVarValue(name string) Value {
if v, ok := internedVarValues[name]; ok {
return v
}
return Var(name)
}
// InternedValueOr returns an interned Value for scalar v. Calls supplier
// to produce a Value if the value is not interned.
func InternedValueOr[T internable](v T, supplier func(T) Value) Value {

View File

@@ -26,6 +26,7 @@ import (
"github.com/open-policy-agent/opa/v1/ast/internal/tokens"
astJSON "github.com/open-policy-agent/opa/v1/ast/json"
"github.com/open-policy-agent/opa/v1/ast/location"
"github.com/open-policy-agent/opa/v1/util"
)
// DefaultMaxParsingRecursionDepth is the default maximum recursion
@@ -57,6 +58,21 @@ const (
RegoV1
)
var (
// this is the name to use for instantiating an empty set, e.g., `set()`.
setConstructor = RefTerm(VarTerm("set"))
preAllocWildcards = [...]Value{
Var("$0"), Var("$1"), Var("$2"), Var("$3"), Var("$4"), Var("$5"),
Var("$6"), Var("$7"), Var("$8"), Var("$9"), Var("$10"),
}
// use static references to avoid allocations, and
// copy them to the call term only when needed
memberWithKeyRef = MemberWithKey.Ref()
memberRef = Member.Ref()
)
func (v RegoVersion) Int() int {
if v == RegoV1 {
return 1
@@ -88,17 +104,17 @@ func RegoVersionFromInt(i int) RegoVersion {
// can do efficient shallow copies of these values when doing a
// save() and restore().
type state struct {
s *scanner.Scanner
lastEnd int
skippedNL bool
tok tokens.Token
tokEnd int
lit string
loc Location
errors Errors
hints []string
comments []*Comment
hints []string
s *scanner.Scanner
loc Location
lit string
lastEnd int
tokEnd int
wildcard int
tok tokens.Token
skippedNL bool
}
func (s *state) String() string {
@@ -451,7 +467,6 @@ func (p *Parser) Parse() ([]Statement, []*Comment, Errors) {
// next type of statement. If a statement can be parsed, continue from that
// point trying to parse packages, imports, etc. in the same order.
for p.s.tok != tokens.EOF {
s := p.save()
if pkg := p.parsePackage(); pkg != nil {
@@ -512,12 +527,12 @@ func (p *Parser) Parse() ([]Statement, []*Comment, Errors) {
}
func (p *Parser) parseAnnotations(stmts []Statement) []Statement {
annotStmts, errs := parseAnnotations(p.s.comments)
for _, err := range errs {
p.error(err.Location, err.Message)
}
stmts = slices.Grow(stmts, len(annotStmts))
for _, annotStmt := range annotStmts {
stmts = append(stmts, annotStmt)
}
@@ -545,11 +560,11 @@ func parseAnnotations(comments []*Comment) ([]*Annotations, Errors) {
}
}
var stmts []*Annotations
stmts := make([]*Annotations, 0, len(blocks))
var errs Errors
for _, b := range blocks {
a, err := b.Parse()
if err != nil {
if a, err := b.Parse(); err != nil {
errs = append(errs, &Error{
Code: ParseErr,
Message: err.Error(),
@@ -564,14 +579,13 @@ func parseAnnotations(comments []*Comment) ([]*Annotations, Errors) {
}
func (p *Parser) parsePackage() *Package {
var pkg Package
pkg.SetLoc(p.s.Loc())
if p.s.tok != tokens.Package {
return nil
}
var pkg Package
pkg.SetLoc(p.s.Loc())
p.scanWS()
// Make sure we allow the first term of refs to be the 'package' keyword.
@@ -633,14 +647,13 @@ func (p *Parser) parsePackage() *Package {
}
func (p *Parser) parseImport() *Import {
var imp Import
imp.SetLoc(p.s.Loc())
if p.s.tok != tokens.Import {
return nil
}
var imp Import
imp.SetLoc(p.s.Loc())
p.scanWS()
// Make sure we allow the first term of refs to be the 'import' keyword.
@@ -952,7 +965,7 @@ func (p *Parser) parseRules() []*Rule {
next.Head.keywords = rule.Head.keywords
for i := range next.Head.Args {
if v, ok := next.Head.Args[i].Value.(Var); ok && v.IsWildcard() {
next.Head.Args[i].Value = Var(p.genwildcard())
next.Head.Args[i].Value = p.genwildcard()
}
}
setLocRecursive(next.Head, loc)
@@ -972,7 +985,7 @@ func (p *Parser) parseElse(head *Head) *Rule {
rule.Head.generatedValue = false
for i := range rule.Head.Args {
if v, ok := rule.Head.Args[i].Value.(Var); ok && v.IsWildcard() {
rule.Head.Args[i].Value = Var(p.genwildcard())
rule.Head.Args[i].Value = p.genwildcard()
}
}
rule.Head.SetLoc(p.s.Loc())
@@ -1281,14 +1294,11 @@ func (p *Parser) parseLiteralExpr(negated bool) *Expr {
}
func (p *Parser) parseWith() []*With {
withs := []*With{}
for {
with := With{Location: p.s.Loc()}
with := With{
Location: p.s.Loc(),
}
p.scan()
if p.s.tok != tokens.Ident {
@@ -1525,11 +1535,6 @@ func (p *Parser) parseTermInfixCallInList() *Term {
return p.parseTermIn(nil, false, p.s.loc.Offset)
}
// use static references to avoid allocations, and
// copy them to the call term only when needed
var memberWithKeyRef = MemberWithKey.Ref()
var memberRef = Member.Ref()
func (p *Parser) parseTermIn(lhs *Term, keyVal bool, offset int) *Term {
if !p.enter() {
return nil
@@ -1898,9 +1903,6 @@ func (p *Parser) parseRawString() *Term {
return StringTerm(p.s.lit[1 : len(p.s.lit)-1]).SetLocation(p.s.Loc())
}
// this is the name to use for instantiating an empty set, e.g., `set()`.
var setConstructor = RefTerm(VarTerm("set"))
func (p *Parser) parseCall(operator *Term, offset int) (term *Term) {
if !p.enter() {
return nil
@@ -1978,7 +1980,7 @@ func (p *Parser) parseRef(head *Term, offset int) (term *Term) {
term = p.parseRef(term, offset)
}
}
end = p.s.tokEnd
end = p.s.lastEnd
return term
case tokens.LBrack:
p.scan()
@@ -2042,7 +2044,6 @@ func (p *Parser) parseArray() (term *Term) {
// Does this represent a set comprehension or a set containing binary OR
// call? We resolve the ambiguity by prioritizing comprehensions.
head := p.parseTerm()
if head == nil {
return nil
}
@@ -2286,7 +2287,7 @@ func (p *Parser) parseTermList(end tokens.Token, r []*Term) []*Term {
}
continue
default:
p.illegal(fmt.Sprintf("expected %q or %q", tokens.Comma, end))
p.illegal("expected %q or %q", tokens.Comma, end)
return nil
}
}
@@ -2316,12 +2317,12 @@ func (p *Parser) parseTermPairList(end tokens.Token, r [][2]*Term) [][2]*Term {
}
continue
default:
p.illegal(fmt.Sprintf("expected %q or %q", tokens.Comma, end))
p.illegal("expected %q or %q", tokens.Comma, end)
return nil
}
}
default:
p.illegal(fmt.Sprintf("expected %q", tokens.Colon))
p.illegal("expected %q", tokens.Colon)
return nil
}
}
@@ -2353,48 +2354,69 @@ func (p *Parser) parseTermOpName(ref Ref, values ...tokens.Token) *Term {
}
func (p *Parser) parseVar() *Term {
s := p.s.lit
term := VarTerm(s).SetLocation(p.s.Loc())
// Update wildcard values with unique identifiers
if term.Equal(Wildcard) {
term.Value = Var(p.genwildcard())
if p.s.lit == WildcardString {
// Update wildcard values with unique identifiers
return NewTerm(p.genwildcard()).SetLocation(p.s.Loc())
}
return term
return VarTerm(p.s.lit).SetLocation(p.s.Loc())
}
func (p *Parser) genwildcard() string {
c := p.s.wildcard
func (p *Parser) genwildcard() Value {
var v Value
if p.s.wildcard < len(preAllocWildcards) {
v = preAllocWildcards[p.s.wildcard]
} else {
v = Var(WildcardPrefix + strconv.Itoa(p.s.wildcard))
}
p.s.wildcard++
return fmt.Sprintf("%v%d", WildcardPrefix, c)
return v
}
func (p *Parser) error(loc *location.Location, reason string) {
p.errorf(loc, "%s", reason)
}
func (p *Parser) errorf(loc *location.Location, f string, a ...any) {
msg := strings.Builder{}
msg.WriteString(fmt.Sprintf(f, a...))
switch len(p.s.hints) {
func writeHints(msg *strings.Builder, hints []string) {
switch len(hints) {
case 0: // nothing to do
case 1:
msg.WriteString(" (hint: ")
msg.WriteString(p.s.hints[0])
msg.WriteRune(')')
msg.WriteString(hints[0])
msg.WriteByte(')')
default:
msg.WriteString(" (hints: ")
for i, h := range p.s.hints {
for i, h := range hints {
if i > 0 {
msg.WriteString(", ")
}
msg.WriteString(h)
}
msg.WriteRune(')')
msg.WriteByte(')')
}
}
func (p *Parser) error(loc *location.Location, reason string) {
msg := reason
if len(p.s.hints) > 0 {
sb := &strings.Builder{}
sb.WriteString(reason)
writeHints(sb, p.s.hints)
msg = sb.String()
}
p.s.errors = append(p.s.errors, &Error{
Code: ParseErr,
Message: msg,
Location: loc,
Details: newParserErrorDetail(p.s.s.Bytes(), loc.Offset),
})
p.s.hints = nil
}
func (p *Parser) errorf(loc *location.Location, f string, a ...any) {
msg := &strings.Builder{}
fmt.Fprintf(msg, f, a...)
if len(p.s.hints) > 0 {
writeHints(msg, p.s.hints)
}
p.s.errors = append(p.s.errors, &Error{
@@ -2406,28 +2428,25 @@ func (p *Parser) errorf(loc *location.Location, f string, a ...any) {
p.s.hints = nil
}
func (p *Parser) hint(f string, a ...any) {
p.s.hints = append(p.s.hints, fmt.Sprintf(f, a...))
func (p *Parser) hint(s string) {
p.s.hints = append(p.s.hints, s)
}
func (p *Parser) illegal(note string, a ...any) {
tok := p.s.tok.String()
if p.s.tok == tokens.Illegal {
p.errorf(p.s.Loc(), "illegal token")
return
}
tok := p.s.tok.String()
tokType := "token"
if tokens.IsKeyword(p.s.tok) {
tokType = "keyword"
} else if _, ok := allFutureKeywords[p.s.tok.String()]; ok {
if _, ok := allFutureKeywords[tok]; ok || tokens.IsKeyword(p.s.tok) {
tokType = "keyword"
}
note = fmt.Sprintf(note, a...)
if len(note) > 0 {
p.errorf(p.s.Loc(), "unexpected %s %s: %s", tok, tokType, note)
p.errorf(p.s.Loc(), "unexpected %s %s: %s", tok, tokType, fmt.Sprintf(note, a...))
} else {
p.errorf(p.s.Loc(), "unexpected %s %s", tok, tokType)
}
@@ -2999,10 +3018,7 @@ func (p *Parser) futureImport(imp *Import, allowedFutureKeywords map[string]toke
return
}
kwds := make([]string, 0, len(allowedFutureKeywords))
for k := range allowedFutureKeywords {
kwds = append(kwds, k)
}
kwds := util.Keys(allowedFutureKeywords)
switch len(path) {
case 2: // all keywords imported, nothing to do
@@ -3052,10 +3068,7 @@ func (p *Parser) regoV1Import(imp *Import) {
}
// import all future keywords with the rego.v1 import
kwds := make([]string, 0, len(futureKeywordsV0))
for k := range futureKeywordsV0 {
kwds = append(kwds, k)
}
kwds := util.Keys(futureKeywordsV0)
p.s.s.SetRegoV1Compatible()
for _, kw := range kwds {

View File

@@ -86,7 +86,11 @@ var ReservedVars = NewVarSet(
)
// Wildcard represents the wildcard variable as defined in the language.
var Wildcard = &Term{Value: Var("_")}
var (
WildcardString = "_"
WildcardValue Value = Var(WildcardString)
Wildcard = &Term{Value: WildcardValue}
)
// WildcardPrefix is the special character that all wildcard variables are
// prefixed with when the statement they are contained in is parsed.
@@ -375,8 +379,10 @@ func (mod *Module) String() string {
appendAnnotationStrings := func(buf []string, node Node) []string {
if as, ok := byNode[node]; ok {
for i := range as {
buf = append(buf, "# METADATA")
buf = append(buf, "# "+as[i].String())
buf = append(buf,
"# METADATA",
"# "+as[i].String(),
)
}
}
return buf
@@ -726,6 +732,7 @@ func (rule *Rule) SetLoc(loc *Location) {
// Path returns a ref referring to the document produced by this rule. If rule
// is not contained in a module, this function panics.
//
// Deprecated: Poor handling of ref rules. Use `(*Rule).Ref()` instead.
func (rule *Rule) Path() Ref {
if rule.Module == nil {

View File

@@ -1,33 +1,40 @@
package ast
import (
"bytes"
"strings"
"sync"
"github.com/open-policy-agent/opa/v1/util"
)
type termPtrPool struct {
pool sync.Pool
}
var (
TermPtrPool = util.NewSyncPool[Term]()
BytesReaderPool = util.NewSyncPool[bytes.Reader]()
IndexResultPool = util.NewSyncPool[IndexResult]()
bbPool = util.NewSyncPool[bytes.Buffer]()
// Needs custom pool because of custom Put logic.
sbPool = &stringBuilderPool{
pool: sync.Pool{
New: func() any {
return &strings.Builder{}
},
},
}
// Needs custom pool because of custom Put logic.
varVisitorPool = &vvPool{
pool: sync.Pool{
New: func() any {
return NewVarVisitor()
},
},
}
)
type stringBuilderPool struct {
pool sync.Pool
}
type indexResultPool struct {
pool sync.Pool
}
type vvPool struct {
pool sync.Pool
}
func (p *termPtrPool) Get() *Term {
return p.pool.Get().(*Term)
}
func (p *termPtrPool) Put(t *Term) {
p.pool.Put(t)
}
type (
stringBuilderPool struct{ pool sync.Pool }
vvPool struct{ pool sync.Pool }
)
func (p *stringBuilderPool) Get() *strings.Builder {
return p.pool.Get().(*strings.Builder)
@@ -38,16 +45,6 @@ func (p *stringBuilderPool) Put(sb *strings.Builder) {
p.pool.Put(sb)
}
func (p *indexResultPool) Get() *IndexResult {
return p.pool.Get().(*IndexResult)
}
func (p *indexResultPool) Put(x *IndexResult) {
if x != nil {
p.pool.Put(x)
}
}
func (p *vvPool) Get() *VarVisitor {
return p.pool.Get().(*VarVisitor)
}
@@ -58,35 +55,3 @@ func (p *vvPool) Put(vv *VarVisitor) {
p.pool.Put(vv)
}
}
var TermPtrPool = &termPtrPool{
pool: sync.Pool{
New: func() any {
return &Term{}
},
},
}
var sbPool = &stringBuilderPool{
pool: sync.Pool{
New: func() any {
return &strings.Builder{}
},
},
}
var varVisitorPool = &vvPool{
pool: sync.Pool{
New: func() any {
return NewVarVisitor()
},
},
}
var IndexResultPool = &indexResultPool{
pool: sync.Pool{
New: func() any {
return &IndexResult{}
},
},
}

View File

@@ -2,7 +2,6 @@
// Use of this source code is governed by an Apache2
// license that can be found in the LICENSE file.
// nolint: deadcode // Public API.
package ast
import (
@@ -824,7 +823,7 @@ type Var string
// VarTerm creates a new Term with a Variable value.
func VarTerm(v string) *Term {
return &Term{Value: Var(v)}
return &Term{Value: InternedVarValue(v)}
}
// Equal returns true if the other Value is a Variable and has the same value
@@ -881,7 +880,7 @@ func (v Var) String() string {
// illegal variable name character (WildcardPrefix) to avoid conflicts. When
// we serialize the variable here, we need to make sure it's parseable.
if v.IsWildcard() {
return Wildcard.String()
return WildcardString
}
return string(v)
}
@@ -1154,12 +1153,6 @@ func IsVarCompatibleString(s string) bool {
return varRegexp.MatchString(s)
}
var bbPool = &sync.Pool{
New: func() any {
return new(bytes.Buffer)
},
}
func (ref Ref) String() string {
// Note(anderseknert):
// Options tried in the order of cheapness, where after some effort,
@@ -1181,7 +1174,7 @@ func (ref Ref) String() string {
_var := ref[0].Value.String()
bb := bbPool.Get().(*bytes.Buffer)
bb := bbPool.Get()
bb.Reset()
defer bbPool.Put(bb)

View File

File diff suppressed because it is too large Load Diff

View File

@@ -8,6 +8,7 @@ package ast
// can return a Visitor w which will be used to visit the children of the AST
// element v. If the Visit function returns nil, the children will not be
// visited.
//
// Deprecated: use GenericVisitor or another visitor implementation
type Visitor interface {
Visit(v any) (w Visitor)
@@ -15,6 +16,7 @@ type Visitor interface {
// BeforeAndAfterVisitor wraps Visitor to provide hooks for being called before
// and after the AST has been visited.
//
// Deprecated: use GenericVisitor or another visitor implementation
type BeforeAndAfterVisitor interface {
Visitor
@@ -24,6 +26,7 @@ type BeforeAndAfterVisitor interface {
// Walk iterates the AST by calling the Visit function on the Visitor
// v for x before recursing.
//
// Deprecated: use GenericVisitor.Walk
func Walk(v Visitor, x any) {
if bav, ok := v.(BeforeAndAfterVisitor); !ok {
@@ -37,6 +40,7 @@ func Walk(v Visitor, x any) {
// WalkBeforeAndAfter iterates the AST by calling the Visit function on the
// Visitor v for x before recursing.
//
// Deprecated: use GenericVisitor.Walk
func WalkBeforeAndAfter(v BeforeAndAfterVisitor, x any) {
Walk(v, x)

View File

@@ -6,9 +6,7 @@
package bundle
import (
"archive/tar"
"bytes"
"compress/gzip"
"encoding/hex"
"encoding/json"
"errors"
@@ -24,6 +22,8 @@ import (
"sync"
"github.com/gobwas/glob"
"golang.org/x/sync/errgroup"
"github.com/open-policy-agent/opa/internal/file/archive"
"github.com/open-policy-agent/opa/internal/merge"
"github.com/open-policy-agent/opa/v1/ast"
@@ -51,6 +51,10 @@ const (
SnapshotBundleType = "snapshot"
)
var (
empty Bundle
)
// Bundle represents a loaded bundle. The bundle can contain data and policies.
type Bundle struct {
Signatures SignaturesConfig
@@ -96,7 +100,7 @@ type SignaturesConfig struct {
// isEmpty returns if the SignaturesConfig is empty.
func (s SignaturesConfig) isEmpty() bool {
return reflect.DeepEqual(s, SignaturesConfig{})
return s.Signatures == nil && s.Plugin == ""
}
// DecodedSignature represents the decoded JWT payload.
@@ -186,7 +190,6 @@ func (m *Manifest) SetRegoVersion(v ast.RegoVersion) {
// Equal returns true if m is semantically equivalent to other.
func (m Manifest) Equal(other Manifest) bool {
// This is safe since both are passed by value.
m.Init()
other.Init()
@@ -323,7 +326,6 @@ func (ss stringSet) Equal(other stringSet) bool {
}
func (m *Manifest) validateAndInjectDefaults(b Bundle) error {
m.Init()
// Validate roots in bundle.
@@ -337,7 +339,7 @@ func (m *Manifest) validateAndInjectDefaults(b Bundle) error {
for i := range len(roots) - 1 {
for j := i + 1; j < len(roots); j++ {
if RootPathsOverlap(roots[i], roots[j]) {
return fmt.Errorf("manifest has overlapped roots: '%v' and '%v'", roots[i], roots[j])
return fmt.Errorf("manifest has overlapped roots: '%s' and '%s'", roots[i], roots[j])
}
}
}
@@ -349,7 +351,7 @@ func (m *Manifest) validateAndInjectDefaults(b Bundle) error {
found = RootPathsContain(roots, path)
}
if !found {
return fmt.Errorf("manifest roots %v do not permit '%v' in module '%v'", roots, module.Parsed.Package, module.Path)
return fmt.Errorf("manifest roots %v do not permit '%v' in module '%s'", roots, module.Parsed.Package, module.Path)
}
}
@@ -368,7 +370,7 @@ func (m *Manifest) validateAndInjectDefaults(b Bundle) error {
// Ensure wasm module entrypoint in within bundle roots
if !RootPathsContain(roots, wmConfig.Entrypoint) {
return fmt.Errorf("manifest roots %v do not permit '%v' entrypoint for wasm module '%v'", roots, wmConfig.Entrypoint, wmConfig.Module)
return fmt.Errorf("manifest roots %v do not permit '%s' entrypoint for wasm module '%s'", roots, wmConfig.Entrypoint, wmConfig.Module)
}
if _, ok := seenEps[wmConfig.Entrypoint]; ok {
@@ -504,14 +506,13 @@ func NewReader(r io.Reader) *Reader {
// NewCustomReader returns a new Reader configured to use the
// specified DirectoryLoader.
func NewCustomReader(loader DirectoryLoader) *Reader {
nr := Reader{
return &Reader{
loader: loader,
metrics: metrics.New(),
metrics: metrics.NoOp(),
files: make(map[string]FileInfo),
sizeLimitBytes: DefaultSizeLimitBytes + 1,
lazyLoadingMode: HasExtension(),
}
return &nr
}
// IncludeManifestInData sets whether the manifest metadata should be
@@ -620,24 +621,17 @@ func (r *Reader) ParserOptions() ast.ParserOptions {
// Read returns a new Bundle loaded from the reader.
func (r *Reader) Read() (Bundle, error) {
var bundle Bundle
var descriptors []*Descriptor
var err error
var raw []Raw
bundle.Signatures, bundle.Patch, descriptors, err = preProcessBundle(r.loader, r.skipVerify, r.sizeLimitBytes)
bundle, descriptors, err := preProcessBundle(r.loader, r.skipVerify, r.sizeLimitBytes)
if err != nil {
return bundle, err
return empty, err
}
bundle.lazyLoadingMode = r.lazyLoadingMode
bundle.sizeLimitBytes = r.sizeLimitBytes
if bundle.Type() == SnapshotBundleType {
err = r.checkSignaturesAndDescriptors(bundle.Signatures)
if err != nil {
return bundle, err
if err := r.checkSignaturesAndDescriptors(bundle.Signatures); err != nil {
return empty, err
}
bundle.Data = map[string]any{}
@@ -647,7 +641,7 @@ func (r *Reader) Read() (Bundle, error) {
for _, f := range descriptors {
buf, err := readFile(f, r.sizeLimitBytes)
if err != nil {
return bundle, err
return empty, err
}
// verify the file content
@@ -663,7 +657,7 @@ func (r *Reader) Read() (Bundle, error) {
delete(r.files, path)
} else {
if err = r.verifyBundleFile(path, buf); err != nil {
return bundle, err
return empty, err
}
}
}
@@ -690,7 +684,7 @@ func (r *Reader) Read() (Bundle, error) {
p = modulePathWithPrefix(r.name, fullPath)
}
raw = append(raw, Raw{Path: p, Value: bs, module: &mf})
bundle.Raw = append(bundle.Raw, Raw{Path: p, Value: bs, module: &mf})
}
} else if filepath.Base(path) == WasmFile {
bundle.WasmModules = append(bundle.WasmModules, WasmModuleFile{
@@ -706,7 +700,7 @@ func (r *Reader) Read() (Bundle, error) {
})
} else if filepath.Base(path) == dataFile {
if r.lazyLoadingMode {
raw = append(raw, Raw{Path: path, Value: buf.Bytes()})
bundle.Raw = append(bundle.Raw, Raw{Path: path, Value: buf.Bytes()})
continue
}
@@ -717,16 +711,16 @@ func (r *Reader) Read() (Bundle, error) {
r.metrics.Timer(metrics.RegoDataParse).Stop()
if err != nil {
return bundle, fmt.Errorf("bundle load failed on %v: %w", r.fullPath(path), err)
return empty, fmt.Errorf("bundle load failed on %v: %w", r.fullPath(path), err)
}
if err := insertValue(&bundle, path, value); err != nil {
return bundle, err
if err := insertValue(bundle, path, value); err != nil {
return empty, err
}
} else if filepath.Base(path) == yamlDataFile || filepath.Base(path) == ymlDataFile {
if r.lazyLoadingMode {
raw = append(raw, Raw{Path: path, Value: buf.Bytes()})
bundle.Raw = append(bundle.Raw, Raw{Path: path, Value: buf.Bytes()})
continue
}
@@ -737,16 +731,16 @@ func (r *Reader) Read() (Bundle, error) {
r.metrics.Timer(metrics.RegoDataParse).Stop()
if err != nil {
return bundle, fmt.Errorf("bundle load failed on %v: %w", r.fullPath(path), err)
return empty, fmt.Errorf("bundle load failed on %v: %w", r.fullPath(path), err)
}
if err := insertValue(&bundle, path, value); err != nil {
return bundle, err
if err := insertValue(bundle, path, value); err != nil {
return empty, err
}
} else if strings.HasSuffix(path, ManifestExt) {
if err := util.NewJSONDecoder(&buf).Decode(&bundle.Manifest); err != nil {
return bundle, fmt.Errorf("bundle load failed on manifest decode: %w", err)
return empty, fmt.Errorf("bundle load failed on manifest decode: %w", err)
}
}
}
@@ -754,52 +748,63 @@ func (r *Reader) Read() (Bundle, error) {
// Parse modules
popts := r.ParserOptions()
popts.RegoVersion = bundle.RegoVersion(popts.EffectiveRegoVersion())
for _, mf := range modules {
modulePopts := popts
g := &errgroup.Group{}
r.metrics.Timer(metrics.RegoModuleParse).Start()
for i, mf := range modules {
mpopts := popts
if regoVersion, err := bundle.RegoVersionForFile(mf.RelativePath, popts.EffectiveRegoVersion()); err != nil {
return bundle, err
return *bundle, err
} else if regoVersion != ast.RegoUndefined {
// We don't expect ast.RegoUndefined here, but don't override configured rego-version if we do just to be extra protective
modulePopts.RegoVersion = regoVersion
// We don't expect ast.RegoUndefined here, but don't override
// configured rego-version if we do just to be extra protective
mpopts.RegoVersion = regoVersion
}
r.metrics.Timer(metrics.RegoModuleParse).Start()
mf.Parsed, err = ast.ParseModuleWithOpts(mf.Path, util.ByteSliceToString(mf.Raw), modulePopts)
r.metrics.Timer(metrics.RegoModuleParse).Stop()
if err != nil {
return bundle, err
}
bundle.Modules = append(bundle.Modules, mf)
g.Go(func() (err error) {
if mf.Parsed, err = ast.ParseModuleWithOpts(mf.Path, util.ByteSliceToString(mf.Raw), mpopts); err == nil {
modules[i] = mf
}
return err
})
}
err = g.Wait()
r.metrics.Timer(metrics.RegoModuleParse).Stop()
if err != nil {
return empty, err
}
bundle.Modules = modules
if bundle.Type() == DeltaBundleType {
if len(bundle.Data) != 0 {
return bundle, errors.New("delta bundle expected to contain only patch file but data files found")
return empty, errors.New("delta bundle expected to contain only patch file but data files found")
}
if len(bundle.Modules) != 0 {
return bundle, errors.New("delta bundle expected to contain only patch file but policy files found")
return empty, errors.New("delta bundle expected to contain only patch file but policy files found")
}
if len(bundle.WasmModules) != 0 {
return bundle, errors.New("delta bundle expected to contain only patch file but wasm files found")
return empty, errors.New("delta bundle expected to contain only patch file but wasm files found")
}
if r.persist {
return bundle, errors.New("'persist' property is true in config. persisting delta bundle to disk is not supported")
return empty, errors.New(
"'persist' property is true in config. persisting delta bundle to disk is not supported")
}
}
// check if the bundle signatures specify any files that weren't found in the bundle
if bundle.Type() == SnapshotBundleType && len(r.files) != 0 {
extra := []string{}
for k := range r.files {
extra = append(extra, k)
}
return bundle, fmt.Errorf("file(s) %v specified in bundle signatures but not found in the target bundle", extra)
return empty, fmt.Errorf(
"file(s) %v specified in bundle signatures but not found in the target bundle", util.Keys(r.files))
}
if err := bundle.Manifest.validateAndInjectDefaults(bundle); err != nil {
return bundle, err
if err := bundle.Manifest.validateAndInjectDefaults(*bundle); err != nil {
return empty, err
}
// Inject the wasm module entrypoint refs into the WasmModuleFile structs
@@ -812,36 +817,33 @@ func (r *Reader) Read() (Bundle, error) {
for _, entrypoint := range entrypoints {
ref, err := ast.PtrRef(ast.DefaultRootDocument, entrypoint)
if err != nil {
return bundle, fmt.Errorf("failed to parse wasm module entrypoint '%s': %s", entrypoint, err)
return empty, fmt.Errorf("failed to parse wasm module entrypoint '%s': %s", entrypoint, err)
}
bundle.WasmModules[i].Entrypoints = append(bundle.WasmModules[i].Entrypoints, ref)
}
}
if r.includeManifestInData {
var metadata map[string]any
b, err := json.Marshal(&bundle.Manifest)
if err != nil {
return bundle, fmt.Errorf("bundle load failed on manifest marshal: %w", err)
return empty, fmt.Errorf("bundle load failed on manifest marshal: %w", err)
}
err = util.UnmarshalJSON(b, &metadata)
if err != nil {
return bundle, fmt.Errorf("bundle load failed on manifest unmarshal: %w", err)
var metadata map[string]any
if err := util.UnmarshalJSON(b, &metadata); err != nil {
return empty, fmt.Errorf("bundle load failed on manifest unmarshal: %w", err)
}
// For backwards compatibility always write to the old unnamed manifest path
// This will *not* be correct if >1 bundle is in use...
if err := bundle.insertData(legacyManifestStoragePath, metadata); err != nil {
return bundle, fmt.Errorf("bundle load failed on %v: %w", legacyRevisionStoragePath, err)
return empty, fmt.Errorf("bundle load failed on %v: %w", legacyRevisionStoragePath, err)
}
}
bundle.Etag = r.etag
bundle.Raw = raw
return bundle, nil
return *bundle, nil
}
func (r *Reader) isFileExcluded(path string) bool {
@@ -869,10 +871,9 @@ func (r *Reader) checkSignaturesAndDescriptors(signatures SignaturesConfig) erro
}
// verify the JWT signatures included in the `.signatures.json` file
if err := r.verifyBundleSignature(signatures); err != nil {
return err
}
return r.verifyBundleSignature(signatures)
}
return nil
}
@@ -931,19 +932,10 @@ func (w *Writer) DisableFormat(yes bool) *Writer {
// Write writes the bundle to the writer's output stream.
func (w *Writer) Write(bundle Bundle) error {
gw := gzip.NewWriter(w.w)
tw := tar.NewWriter(gw)
tw := archive.NewTarGzWriter(w.w)
bundleType := bundle.Type()
if bundleType == SnapshotBundleType {
var buf bytes.Buffer
if err := json.NewEncoder(&buf).Encode(bundle.Data); err != nil {
return err
}
if err := archive.WriteFile(tw, "data.json", buf.Bytes()); err != nil {
if bundle.Type() == SnapshotBundleType {
if err := tw.WriteJSONFile("/data.json", bundle.Data); err != nil {
return err
}
@@ -953,7 +945,7 @@ func (w *Writer) Write(bundle Bundle) error {
path = module.Path
}
if err := archive.WriteFile(tw, path, module.Raw); err != nil {
if err := tw.WriteFile(util.WithPrefix(path, "/"), module.Raw); err != nil {
return err
}
}
@@ -969,55 +961,48 @@ func (w *Writer) Write(bundle Bundle) error {
if err := w.writePlan(tw, bundle); err != nil {
return err
}
} else if bundleType == DeltaBundleType {
if err := writePatch(tw, bundle); err != nil {
} else if bundle.Type() == DeltaBundleType {
if err := tw.WriteJSONFile("/patch.json", bundle.Patch); err != nil {
return err
}
}
if err := writeManifest(tw, bundle); err != nil {
return err
if !bundle.Manifest.Empty() {
if err := tw.WriteJSONFile("/.manifest", bundle.Manifest); err != nil {
return err
}
}
if err := tw.Close(); err != nil {
return err
}
return gw.Close()
return tw.Close()
}
func (w *Writer) writeWasm(tw *tar.Writer, bundle Bundle) error {
func (w *Writer) writeWasm(tw *archive.TarGzWriter, bundle Bundle) error {
for _, wm := range bundle.WasmModules {
path := wm.URL
if w.usePath {
path = wm.Path
}
err := archive.WriteFile(tw, path, wm.Raw)
if err != nil {
if err := tw.WriteFile(util.WithPrefix(path, "/"), wm.Raw); err != nil {
return err
}
}
if len(bundle.Wasm) > 0 {
err := archive.WriteFile(tw, "/"+WasmFile, bundle.Wasm)
if err != nil {
return err
}
if len(bundle.Wasm) == 0 {
return nil
}
return nil
return tw.WriteFile(util.WithPrefix(WasmFile, "/"), bundle.Wasm)
}
func (w *Writer) writePlan(tw *tar.Writer, bundle Bundle) error {
func (w *Writer) writePlan(tw *archive.TarGzWriter, bundle Bundle) error {
for _, wm := range bundle.PlanModules {
path := wm.URL
if w.usePath {
path = wm.Path
}
err := archive.WriteFile(tw, path, wm.Raw)
if err != nil {
if err := tw.WriteFile(util.WithPrefix(path, "/"), wm.Raw); err != nil {
return err
}
}
@@ -1025,34 +1010,7 @@ func (w *Writer) writePlan(tw *tar.Writer, bundle Bundle) error {
return nil
}
func writeManifest(tw *tar.Writer, bundle Bundle) error {
if bundle.Manifest.Empty() {
return nil
}
var buf bytes.Buffer
if err := json.NewEncoder(&buf).Encode(bundle.Manifest); err != nil {
return err
}
return archive.WriteFile(tw, ManifestExt, buf.Bytes())
}
func writePatch(tw *tar.Writer, bundle Bundle) error {
var buf bytes.Buffer
if err := json.NewEncoder(&buf).Encode(bundle.Patch); err != nil {
return err
}
return archive.WriteFile(tw, patchFile, buf.Bytes())
}
func writeSignatures(tw *tar.Writer, bundle Bundle) error {
func writeSignatures(tw *archive.TarGzWriter, bundle Bundle) error {
if bundle.Signatures.isEmpty() {
return nil
}
@@ -1062,7 +1020,7 @@ func writeSignatures(tw *tar.Writer, bundle Bundle) error {
return err
}
return archive.WriteFile(tw, fmt.Sprintf(".%v", SignaturesFile), bs)
return tw.WriteFile(util.WithPrefix(SignaturesFile, "/."), bs)
}
func hashBundleFiles(hash SignatureHasher, b *Bundle) ([]FileInfo, error) {
@@ -1115,8 +1073,7 @@ func hashBundleFiles(hash SignatureHasher, b *Bundle) ([]FileInfo, error) {
return files, err
}
bs, err = hash.HashFile(result)
if err != nil {
if bs, err = hash.HashFile(result); err != nil {
return files, err
}
@@ -1227,10 +1184,6 @@ func (b *Bundle) GenerateSignature(signingConfig *SigningConfig, keyID string, u
return err
}
if b.Signatures.isEmpty() {
b.Signatures = SignaturesConfig{}
}
if signingConfig.Plugin != "" {
b.Signatures.Plugin = signingConfig.Plugin
}
@@ -1243,7 +1196,6 @@ func (b *Bundle) GenerateSignature(signingConfig *SigningConfig, keyID string, u
// ParsedModules returns a map of parsed modules with names that are
// unique and human readable for the given a bundle name.
func (b *Bundle) ParsedModules(bundleName string) map[string]*ast.Module {
mods := make(map[string]*ast.Module, len(b.Modules))
for _, mf := range b.Modules {
@@ -1255,9 +1207,10 @@ func (b *Bundle) ParsedModules(bundleName string) map[string]*ast.Module {
func (b *Bundle) RegoVersion(def ast.RegoVersion) ast.RegoVersion {
if v := b.Manifest.RegoVersion; v != nil {
if *v == 0 {
switch *v {
case 0:
return ast.RegoV0
} else if *v == 1 {
case 1:
return ast.RegoV1
}
}
@@ -1328,10 +1281,6 @@ func (m *Manifest) numericRegoVersionForFile(path string) (*int, error) {
// Equal returns true if this bundle's contents equal the other bundle's
// contents.
func (b Bundle) Equal(other Bundle) bool {
if !reflect.DeepEqual(b.Data, other.Data) {
return false
}
if len(b.Modules) != len(other.Modules) {
return false
}
@@ -1357,6 +1306,10 @@ func (b Bundle) Equal(other Bundle) bool {
return false
}
if !reflect.DeepEqual(b.Data, other.Data) {
return false
}
return bytes.Equal(b.Wasm, other.Wasm)
}
@@ -1487,7 +1440,6 @@ func Merge(bundles []*Bundle) (*Bundle, error) {
// If usePath is true, per-file rego-versions will be calculated using the file's ModuleFile.Path; otherwise, the file's
// ModuleFile.URL will be used.
func MergeWithRegoVersion(bundles []*Bundle, regoVersion ast.RegoVersion, usePath bool) (*Bundle, error) {
if len(bundles) == 0 {
return nil, errors.New("expected at least one bundle")
}
@@ -1512,7 +1464,6 @@ func MergeWithRegoVersion(bundles []*Bundle, regoVersion ast.RegoVersion, usePat
var result Bundle
for _, b := range bundles {
if b.Manifest.Roots == nil {
return nil, errors.New("bundle manifest not initialized")
}
@@ -1607,16 +1558,11 @@ func bundleRelativePath(m ModuleFile, usePath bool) string {
}
func bundleAbsolutePath(m ModuleFile, usePath bool) string {
var p string
p := m.URL
if usePath {
p = m.Path
} else {
p = m.URL
}
if !path.IsAbs(p) {
p = "/" + p
}
return path.Clean(p)
return path.Clean(util.WithPrefix(p, "/"))
}
// RootPathsOverlap takes in two bundle root paths and returns true if they overlap.
@@ -1642,7 +1588,6 @@ func rootPathSegments(path string) []string {
}
func rootContains(root []string, other []string) bool {
// A single segment, empty string root always contains the other.
if len(root) == 1 && root[0] == "" {
return true
@@ -1674,7 +1619,7 @@ func getNormalizedPath(path string) []string {
// other hand, if the path is empty, filepath.Dir will return '.'.
// Note: filepath.Dir can return paths with '\' separators, always use
// filepath.ToSlash to keep them normalized.
dirpath := strings.TrimLeft(normalizePath(filepath.Dir(path)), "/.")
dirpath := strings.TrimLeft(filepath.ToSlash(filepath.Dir(path)), "/.")
var key []string
if dirpath != "" {
key = strings.Split(dirpath, "/")
@@ -1701,56 +1646,52 @@ func dfs(value any, path string, fn func(string, any) (bool, error)) error {
}
func modulePathWithPrefix(bundleName string, modulePath string) string {
// Default prefix is just the bundle name
prefix := bundleName
// Bundle names are sometimes just file paths, some of which
// are full urls (file:///foo/). Parse these and only use the path.
parsed, err := url.Parse(bundleName)
if err == nil {
prefix = filepath.Join(parsed.Host, parsed.Path)
return path.Join(parsed.Host, parsed.Path, modulePath)
}
// Note: filepath.Join can return paths with '\' separators, always use
// filepath.ToSlash to keep them normalized.
return normalizePath(filepath.Join(prefix, modulePath))
return path.Join(bundleName, modulePath)
}
// IsStructuredDoc checks if the file name equals a structured file extension ex. ".json"
func IsStructuredDoc(name string) bool {
return filepath.Base(name) == dataFile || filepath.Base(name) == yamlDataFile ||
filepath.Base(name) == SignaturesFile || filepath.Base(name) == ManifestExt
base := filepath.Base(name)
return base == dataFile || base == yamlDataFile || base == SignaturesFile || base == ManifestExt
}
func preProcessBundle(loader DirectoryLoader, skipVerify bool, sizeLimitBytes int64) (SignaturesConfig, Patch, []*Descriptor, error) {
func preProcessBundle(loader DirectoryLoader, skipVerify bool, sizeLimitBytes int64) (*Bundle, []*Descriptor, error) {
bundle := &Bundle{}
descriptors := []*Descriptor{}
var signatures SignaturesConfig
var patch Patch
for {
f, err := loader.NextFile()
if err == io.EOF {
break
}
if err != nil {
return signatures, patch, nil, fmt.Errorf("bundle read failed: %w", err)
if err == io.EOF {
break
}
return bundle, nil, fmt.Errorf("bundle read failed: %w", err)
}
// check for the signatures file
if !skipVerify && strings.HasSuffix(f.Path(), SignaturesFile) {
isSignaturesFile := strings.HasSuffix(f.Path(), SignaturesFile)
if !skipVerify && isSignaturesFile {
buf, err := readFile(f, sizeLimitBytes)
if err != nil {
return signatures, patch, nil, err
return bundle, nil, err
}
if err := util.NewJSONDecoder(&buf).Decode(&signatures); err != nil {
return signatures, patch, nil, fmt.Errorf("bundle load failed on signatures decode: %w", err)
if err := util.NewJSONDecoder(&buf).Decode(&bundle.Signatures); err != nil {
return bundle, nil, fmt.Errorf("bundle load failed on signatures decode: %w", err)
}
} else if !strings.HasSuffix(f.Path(), SignaturesFile) {
} else if !isSignaturesFile {
descriptors = append(descriptors, f)
if filepath.Base(f.Path()) == patchFile {
base := filepath.Base(f.Path())
if base == patchFile {
var b bytes.Buffer
tee := io.TeeReader(f.reader, &b)
@@ -1758,18 +1699,19 @@ func preProcessBundle(loader DirectoryLoader, skipVerify bool, sizeLimitBytes in
buf, err := readFile(f, sizeLimitBytes)
if err != nil {
return signatures, patch, nil, err
return bundle, nil, err
}
if err := util.NewJSONDecoder(&buf).Decode(&patch); err != nil {
return signatures, patch, nil, fmt.Errorf("bundle load failed on patch decode: %w", err)
if err := util.NewJSONDecoder(&buf).Decode(&bundle.Patch); err != nil {
return bundle, nil, fmt.Errorf("bundle load failed on patch decode: %w", err)
}
f.reader = &b
}
}
}
return signatures, patch, descriptors, nil
return bundle, descriptors, nil
}
func readFile(f *Descriptor, sizeLimitBytes int64) (bytes.Buffer, error) {
@@ -1839,7 +1781,3 @@ func fstatFileSize(f *os.File) (int64, error) {
}
return fileInfo.Size(), nil
}
func normalizePath(p string) string {
return filepath.ToSlash(p)
}

View File

@@ -352,12 +352,10 @@ func (t *tarballLoader) NextFile() (*Descriptor, error) {
for {
header, err := t.tr.Next()
if err == io.EOF {
break
}
if err != nil {
if err == io.EOF {
break
}
return nil, err
}
@@ -365,7 +363,6 @@ func (t *tarballLoader) NextFile() (*Descriptor, error) {
if header.Typeflag == tar.TypeReg {
if t.filter != nil {
if t.filter(filepath.ToSlash(header.Name), header.FileInfo(), getdepth(header.Name, false)) {
continue
}
@@ -504,9 +501,9 @@ func getdepth(path string, isDir bool) int {
}
func getFileStoragePath(path string) (storage.Path, error) {
fpath := strings.TrimLeft(normalizePath(filepath.Dir(path)), "/.")
fpath := strings.TrimLeft(filepath.ToSlash(filepath.Dir(path)), "/.")
if strings.HasSuffix(path, RegoExt) {
fpath = strings.Trim(normalizePath(path), "/")
fpath = strings.Trim(filepath.ToSlash(path), "/")
}
p, ok := storage.ParsePathEscaped("/" + fpath)

View File

@@ -571,12 +571,11 @@ func doDFS(obj map[string]json.RawMessage, path string, roots []string) error {
}
for key := range obj {
newPath := filepath.Join(strings.Trim(path, "/"), key)
// Note: filepath.Join can return paths with '\' separators, always use
// filepath.ToSlash to keep them normalized.
newPath = strings.TrimLeft(normalizePath(newPath), "/.")
newPath = strings.TrimLeft(filepath.ToSlash(newPath), "/.")
contains := false
prefix := false
@@ -1191,17 +1190,20 @@ func applyPatches(ctx context.Context, store storage.Store, txn storage.Transact
// Helpers for the older single (unnamed) bundle style manifest storage.
// LegacyManifestStoragePath is the older unnamed bundle path for manifests to be stored.
//
// Deprecated: Use ManifestStoragePath and named bundles instead.
var legacyManifestStoragePath = storage.MustParsePath("/system/bundle/manifest")
var legacyRevisionStoragePath = append(legacyManifestStoragePath, "revision")
// LegacyWriteManifestToStore will write the bundle manifest to the older single (unnamed) bundle manifest location.
//
// Deprecated: Use WriteManifestToStore and named bundles instead.
func LegacyWriteManifestToStore(ctx context.Context, store storage.Store, txn storage.Transaction, manifest Manifest) error {
return write(ctx, store, txn, legacyManifestStoragePath, manifest)
}
// LegacyEraseManifestFromStore will erase the bundle manifest from the older single (unnamed) bundle manifest location.
//
// Deprecated: Use WriteManifestToStore and named bundles instead.
func LegacyEraseManifestFromStore(ctx context.Context, store storage.Store, txn storage.Transaction) error {
err := store.Write(ctx, txn, storage.RemoveOp, legacyManifestStoragePath, nil)
@@ -1212,12 +1214,14 @@ func LegacyEraseManifestFromStore(ctx context.Context, store storage.Store, txn
}
// LegacyReadRevisionFromStore will read the bundle manifest revision from the older single (unnamed) bundle manifest location.
//
// Deprecated: Use ReadBundleRevisionFromStore and named bundles instead.
func LegacyReadRevisionFromStore(ctx context.Context, store storage.Store, txn storage.Transaction) (string, error) {
return readRevisionFromStore(ctx, store, txn, legacyRevisionStoragePath)
}
// ActivateLegacy calls Activate for the bundles but will also write their manifest to the older unnamed store location.
//
// Deprecated: Use Activate with named bundles instead.
func ActivateLegacy(opts *ActivateOpts) error {
opts.legacy = true

View File

@@ -495,6 +495,7 @@ func loadOneSchema(path string) (any, error) {
}
// All returns a Result object loaded (recursively) from the specified paths.
//
// Deprecated: Use FileLoader.Filtered() instead.
func All(paths []string) (*Result, error) {
return NewFileLoader().Filtered(paths, nil)
@@ -503,6 +504,7 @@ func All(paths []string) (*Result, error) {
// Filtered returns a Result object loaded (recursively) from the specified
// paths while applying the given filters. If any filter returns true, the
// file/directory is excluded.
//
// Deprecated: Use FileLoader.Filtered() instead.
func Filtered(paths []string, filter Filter) (*Result, error) {
return NewFileLoader().Filtered(paths, filter)
@@ -511,6 +513,7 @@ func Filtered(paths []string, filter Filter) (*Result, error) {
// AsBundle loads a path as a bundle. If it is a single file
// it will be treated as a normal tarball bundle. If a directory
// is supplied it will be loaded as an unzipped bundle tree.
//
// Deprecated: Use FileLoader.AsBundle() instead.
func AsBundle(path string) (*bundle.Bundle, error) {
return NewFileLoader().AsBundle(path)
@@ -631,11 +634,10 @@ func (l *Result) mergeDocument(path string, doc any) error {
}
func (l *Result) withParent(p string) *Result {
path := append(l.path, p)
return &Result{
Documents: l.Documents,
Modules: l.Modules,
path: path,
path: append(l.path, p),
}
}

View File

@@ -250,9 +250,8 @@ func convertPointsToBase64(alg string, r, s []byte) (string, error) {
copy(rBytesPadded[keyBytes-len(r):], r)
sBytesPadded := make([]byte, keyBytes)
copy(sBytesPadded[keyBytes-len(s):], s)
signatureEnc := append(rBytesPadded, sBytesPadded...)
return base64.RawURLEncoding.EncodeToString(signatureEnc), nil
return base64.RawURLEncoding.EncodeToString(append(rBytesPadded, sBytesPadded...)), nil
}
func retrieveCurveBits(alg string) (int, error) {

View File

@@ -25,6 +25,7 @@ import (
"github.com/open-policy-agent/opa/v1/bundle"
"github.com/open-policy-agent/opa/v1/ir"
"github.com/open-policy-agent/opa/v1/loader"
"github.com/open-policy-agent/opa/v1/loader/filter"
"github.com/open-policy-agent/opa/v1/metrics"
"github.com/open-policy-agent/opa/v1/plugins"
"github.com/open-policy-agent/opa/v1/resolver"
@@ -44,7 +45,7 @@ const (
wasmVarPrefix = "^"
)
// nolint: deadcode,varcheck
// nolint:varcheck
const (
targetWasm = "wasm"
targetRego = "rego"
@@ -235,6 +236,7 @@ func EvalInstrument(instrument bool) EvalOption {
}
// EvalTracer configures a tracer for a Prepared Query's evaluation
//
// Deprecated: Use EvalQueryTracer instead.
func EvalTracer(tracer topdown.Tracer) EvalOption {
return func(e *EvalContext) {
@@ -670,6 +672,7 @@ type Rego struct {
regoVersion ast.RegoVersion
compilerHook func(*ast.Compiler)
evalMode *ast.CompilerEvalMode
filter filter.LoaderFilter
}
func (r *Rego) RegoVersion() ast.RegoVersion {
@@ -1046,6 +1049,12 @@ func LoadBundle(path string) func(r *Rego) {
}
}
func WithFilter(f filter.LoaderFilter) func(r *Rego) {
return func(r *Rego) {
r.filter = f
}
}
// ParsedBundle returns an argument that adds a bundle to be loaded.
func ParsedBundle(name string, b *bundle.Bundle) func(r *Rego) {
return func(r *Rego) {
@@ -1115,6 +1124,7 @@ func Trace(yes bool) func(r *Rego) {
}
// Tracer returns an argument that adds a query tracer to r.
//
// Deprecated: Use QueryTracer instead.
func Tracer(t topdown.Tracer) func(r *Rego) {
return func(r *Rego) {
@@ -2044,6 +2054,7 @@ func (r *Rego) loadBundles(_ context.Context, _ storage.Transaction, m metrics.M
WithSkipBundleVerification(r.skipBundleVerification).
WithRegoVersion(r.regoVersion).
WithCapabilities(r.capabilities).
WithFilter(r.filter).
AsBundle(path)
if err != nil {
return fmt.Errorf("loading error: %s", err)

View File

@@ -349,10 +349,25 @@ func (h *handle) Unregister(_ context.Context, txn storage.Transaction) {
}
func (db *store) runOnCommitTriggers(ctx context.Context, txn storage.Transaction, event storage.TriggerEvent) {
if db.returnASTValuesOnRead && len(db.triggers) > 0 {
// FIXME: Not very performant for large data.
// While it's unlikely, the API allows one trigger to be configured to want
// data conversion, and another that doesn't. So let's handle that properly.
var wantsDataConversion bool
if db.returnASTValuesOnRead && len(event.Data) > 0 {
for _, t := range db.triggers {
if !t.SkipDataConversion {
wantsDataConversion = true
break
}
}
}
dataEvents := make([]storage.DataEvent, 0, len(event.Data))
var converted storage.TriggerEvent
if wantsDataConversion {
converted = storage.TriggerEvent{
Policy: event.Policy,
Data: make([]storage.DataEvent, 0, len(event.Data)),
Context: event.Context,
}
for _, dataEvent := range event.Data {
if astData, ok := dataEvent.Data.(ast.Value); ok {
@@ -360,25 +375,21 @@ func (db *store) runOnCommitTriggers(ctx context.Context, txn storage.Transactio
if err != nil {
panic(err)
}
dataEvents = append(dataEvents, storage.DataEvent{
converted.Data = append(converted.Data, storage.DataEvent{
Path: dataEvent.Path,
Data: jsn,
Removed: dataEvent.Removed,
})
} else {
dataEvents = append(dataEvents, dataEvent)
}
}
event = storage.TriggerEvent{
Policy: event.Policy,
Data: dataEvents,
Context: event.Context,
}
}
for _, t := range db.triggers {
t.OnCommit(ctx, txn, event)
if wantsDataConversion && !t.SkipDataConversion {
t.OnCommit(ctx, txn, converted)
} else {
t.OnCommit(ctx, txn, event)
}
}
}

View File

@@ -210,6 +210,10 @@ func (e TriggerEvent) DataChanged() bool {
// TriggerConfig contains the trigger registration configuration.
type TriggerConfig struct {
// SkipDataConversion when set to true, avoids converting data passed to
// trigger functions from the store to Go types, and instead passes the
// original representation (e.g., ast.Value).
SkipDataConversion bool
// OnCommit is invoked when a transaction is successfully committed. The
// callback is invoked with a handle to the write transaction that

View File

@@ -209,8 +209,7 @@ func SetOperand(x ast.Value, pos int) (ast.Set, error) {
return s, nil
}
// StringOperand converts x to a string. If the cast fails, a descriptive error is
// returned.
// StringOperand returns x as [ast.String], or a descriptive error if the conversion fails.
func StringOperand(x ast.Value, pos int) (ast.String, error) {
s, ok := x.(ast.String)
if !ok {
@@ -219,6 +218,17 @@ func StringOperand(x ast.Value, pos int) (ast.String, error) {
return s, nil
}
// StringOperandByteSlice returns x a []byte, assuming x is [ast.String], or a descriptive error
// if that is not the case. The returned byte slice points directly at the underlying array backing
// the string, and should not be modified.
func StringOperandByteSlice(x ast.Value, pos int) ([]byte, error) {
s, err := StringOperand(x, pos)
if err != nil {
return nil, err
}
return util.StringToByteSlice(string(s)), nil
}
// ObjectOperand converts x to an object. If the cast fails, a descriptive
// error is returned.
func ObjectOperand(x ast.Value, pos int) (ast.Object, error) {

View File

@@ -255,17 +255,17 @@ func extractVerifyOpts(options ast.Object) (verifyOpt x509.VerifyOptions, err er
}
func builtinCryptoX509ParseKeyPair(_ BuiltinContext, operands []*ast.Term, iter func(*ast.Term) error) error {
certificate, err := builtins.StringOperand(operands[0].Value, 1)
certificate, err := builtins.StringOperandByteSlice(operands[0].Value, 1)
if err != nil {
return err
}
key, err := builtins.StringOperand(operands[1].Value, 1)
key, err := builtins.StringOperandByteSlice(operands[1].Value, 1)
if err != nil {
return err
}
certs, err := getTLSx509KeyPairFromString([]byte(certificate), []byte(key))
certs, err := getTLSx509KeyPairFromString(certificate, key)
if err != nil {
return err
}
@@ -326,10 +326,7 @@ func builtinCryptoX509ParseCertificateRequest(_ BuiltinContext, operands []*ast.
}
func builtinCryptoJWKFromPrivateKey(_ BuiltinContext, operands []*ast.Term, iter func(*ast.Term) error) error {
var x any
a := operands[0].Value
input, err := builtins.StringOperand(a, 1)
input, err := builtins.StringOperand(operands[0].Value, 1)
if err != nil {
return err
}
@@ -371,6 +368,7 @@ func builtinCryptoJWKFromPrivateKey(_ BuiltinContext, operands []*ast.Term, iter
return err
}
var x any
if err := util.UnmarshalJSON(jsonKey, &x); err != nil {
return err
}
@@ -430,53 +428,51 @@ func toHexEncodedString(src []byte) string {
}
func builtinCryptoMd5(_ BuiltinContext, operands []*ast.Term, iter func(*ast.Term) error) error {
s, err := builtins.StringOperand(operands[0].Value, 1)
bs, err := builtins.StringOperandByteSlice(operands[0].Value, 1)
if err != nil {
return err
}
md5sum := md5.Sum([]byte(s))
md5sum := md5.Sum(bs)
return iter(ast.StringTerm(toHexEncodedString(md5sum[:])))
}
func builtinCryptoSha1(_ BuiltinContext, operands []*ast.Term, iter func(*ast.Term) error) error {
s, err := builtins.StringOperand(operands[0].Value, 1)
bs, err := builtins.StringOperandByteSlice(operands[0].Value, 1)
if err != nil {
return err
}
sha1sum := sha1.Sum([]byte(s))
sha1sum := sha1.Sum(bs)
return iter(ast.StringTerm(toHexEncodedString(sha1sum[:])))
}
func builtinCryptoSha256(_ BuiltinContext, operands []*ast.Term, iter func(*ast.Term) error) error {
s, err := builtins.StringOperand(operands[0].Value, 1)
bs, err := builtins.StringOperandByteSlice(operands[0].Value, 1)
if err != nil {
return err
}
sha256sum := sha256.Sum256([]byte(s))
sha256sum := sha256.Sum256(bs)
return iter(ast.StringTerm(toHexEncodedString(sha256sum[:])))
}
func hmacHelper(operands []*ast.Term, iter func(*ast.Term) error, h func() hash.Hash) error {
a1 := operands[0].Value
message, err := builtins.StringOperand(a1, 1)
message, err := builtins.StringOperandByteSlice(operands[0].Value, 1)
if err != nil {
return err
}
a2 := operands[1].Value
key, err := builtins.StringOperand(a2, 2)
key, err := builtins.StringOperandByteSlice(operands[1].Value, 2)
if err != nil {
return err
}
mac := hmac.New(h, []byte(key))
mac.Write([]byte(message))
mac := hmac.New(h, key)
mac.Write(message)
messageDigest := mac.Sum(nil)
return iter(ast.StringTerm(hex.EncodeToString(messageDigest)))
@@ -499,21 +495,17 @@ func builtinCryptoHmacSha512(_ BuiltinContext, operands []*ast.Term, iter func(*
}
func builtinCryptoHmacEqual(_ BuiltinContext, operands []*ast.Term, iter func(*ast.Term) error) error {
a1 := operands[0].Value
mac1, err := builtins.StringOperand(a1, 1)
mac1, err := builtins.StringOperandByteSlice(operands[0].Value, 1)
if err != nil {
return err
}
a2 := operands[1].Value
mac2, err := builtins.StringOperand(a2, 2)
mac2, err := builtins.StringOperandByteSlice(operands[1].Value, 2)
if err != nil {
return err
}
res := hmac.Equal([]byte(mac1), []byte(mac2))
return iter(ast.InternedTerm(res))
return iter(ast.InternedTerm(hmac.Equal(mac1, mac2)))
}
func init() {
@@ -668,7 +660,7 @@ func addCACertsFromFile(pool *x509.CertPool, filePath string) (*x509.CertPool, e
pool = x509.NewCertPool()
}
caCert, err := readCertFromFile(filePath)
caCert, err := os.ReadFile(filePath)
if err != nil {
return nil, err
}
@@ -703,17 +695,7 @@ func addCACertsFromEnv(pool *x509.CertPool, envName string) (*x509.CertPool, err
return nil, fmt.Errorf("could not add CA certificates from envvar %q: %w", envName, err)
}
return pool, err
}
// ReadCertFromFile reads a cert from file
func readCertFromFile(localCertFile string) ([]byte, error) {
// Read in the cert file
certPEM, err := os.ReadFile(localCertFile)
if err != nil {
return nil, err
}
return certPEM, nil
return pool, nil
}
var beginPrefix = []byte("-----BEGIN ")
@@ -771,13 +753,3 @@ func getTLSx509KeyPairFromString(certPemBlock []byte, keyPemBlock []byte) (*tls.
return &cert, nil
}
// ReadKeyFromFile reads a key from file
func readKeyFromFile(localKeyFile string) ([]byte, error) {
// Read in the cert file
key, err := os.ReadFile(localKeyFile)
if err != nil {
return nil, err
}
return key, nil
}

View File

@@ -5,7 +5,6 @@
package topdown
import (
"bytes"
"encoding/base64"
"encoding/hex"
"encoding/json"
@@ -21,7 +20,6 @@ import (
)
func builtinJSONMarshal(_ BuiltinContext, operands []*ast.Term, iter func(*ast.Term) error) error {
asJSON, err := ast.JSON(operands[0].Value)
if err != nil {
return err
@@ -32,11 +30,10 @@ func builtinJSONMarshal(_ BuiltinContext, operands []*ast.Term, iter func(*ast.T
return err
}
return iter(ast.StringTerm(string(bs)))
return iter(ast.StringTerm(util.ByteSliceToString(bs)))
}
func builtinJSONMarshalWithOpts(_ BuiltinContext, operands []*ast.Term, iter func(*ast.Term) error) error {
asJSON, err := ast.JSON(operands[0].Value)
if err != nil {
return err
@@ -101,36 +98,34 @@ func builtinJSONMarshalWithOpts(_ BuiltinContext, operands []*ast.Term, iter fun
}
var bs []byte
if shouldPrettyPrint {
bs, err = json.MarshalIndent(asJSON, prefixWith, indentWith)
} else {
bs, err = json.Marshal(asJSON)
}
if err != nil {
return err
}
s := util.ByteSliceToString(bs)
if shouldPrettyPrint {
// json.MarshalIndent() function will not prefix the first line of emitted JSON
return iter(ast.StringTerm(prefixWith + string(bs)))
return iter(ast.StringTerm(prefixWith + s))
}
return iter(ast.StringTerm(string(bs)))
return iter(ast.StringTerm(s))
}
func builtinJSONUnmarshal(_ BuiltinContext, operands []*ast.Term, iter func(*ast.Term) error) error {
str, err := builtins.StringOperand(operands[0].Value, 1)
bs, err := builtins.StringOperandByteSlice(operands[0].Value, 1)
if err != nil {
return err
}
var x any
if err := util.UnmarshalJSON([]byte(str), &x); err != nil {
if err := util.UnmarshalJSON(bs, &x); err != nil {
return err
}
v, err := ast.InterfaceToValue(x)
@@ -141,22 +136,21 @@ func builtinJSONUnmarshal(_ BuiltinContext, operands []*ast.Term, iter func(*ast
}
func builtinJSONIsValid(_ BuiltinContext, operands []*ast.Term, iter func(*ast.Term) error) error {
str, err := builtins.StringOperand(operands[0].Value, 1)
bs, err := builtins.StringOperandByteSlice(operands[0].Value, 1)
if err != nil {
return iter(ast.InternedTerm(false))
}
return iter(ast.InternedTerm(json.Valid([]byte(str))))
return iter(ast.InternedTerm(json.Valid(bs)))
}
func builtinBase64Encode(_ BuiltinContext, operands []*ast.Term, iter func(*ast.Term) error) error {
str, err := builtins.StringOperand(operands[0].Value, 1)
bs, err := builtins.StringOperandByteSlice(operands[0].Value, 1)
if err != nil {
return err
}
return iter(ast.StringTerm(base64.StdEncoding.EncodeToString([]byte(str))))
return iter(ast.StringTerm(base64.StdEncoding.EncodeToString(bs)))
}
func builtinBase64Decode(_ BuiltinContext, operands []*ast.Term, iter func(*ast.Term) error) error {
@@ -183,20 +177,20 @@ func builtinBase64IsValid(_ BuiltinContext, operands []*ast.Term, iter func(*ast
}
func builtinBase64UrlEncode(_ BuiltinContext, operands []*ast.Term, iter func(*ast.Term) error) error {
str, err := builtins.StringOperand(operands[0].Value, 1)
bs, err := builtins.StringOperandByteSlice(operands[0].Value, 1)
if err != nil {
return err
}
return iter(ast.StringTerm(base64.URLEncoding.EncodeToString([]byte(str))))
return iter(ast.StringTerm(base64.URLEncoding.EncodeToString(bs)))
}
func builtinBase64UrlEncodeNoPad(_ BuiltinContext, operands []*ast.Term, iter func(*ast.Term) error) error {
str, err := builtins.StringOperand(operands[0].Value, 1)
bs, err := builtins.StringOperandByteSlice(operands[0].Value, 1)
if err != nil {
return err
}
return iter(ast.StringTerm(base64.RawURLEncoding.EncodeToString([]byte(str))))
return iter(ast.StringTerm(base64.RawURLEncoding.EncodeToString(bs)))
}
func builtinBase64UrlDecode(_ BuiltinContext, operands []*ast.Term, iter func(*ast.Term) error) error {
@@ -306,45 +300,39 @@ func builtinURLQueryDecodeObject(_ BuiltinContext, operands []*ast.Term, iter fu
}
func builtinYAMLMarshal(_ BuiltinContext, operands []*ast.Term, iter func(*ast.Term) error) error {
asJSON, err := ast.JSON(operands[0].Value)
if err != nil {
return err
}
var buf bytes.Buffer
encoder := json.NewEncoder(&buf)
if err := encoder.Encode(asJSON); err != nil {
return err
}
bs, err := yaml.JSONToYAML(buf.Bytes())
bs, err := yaml.Marshal(asJSON)
if err != nil {
return err
}
return iter(ast.StringTerm(string(bs)))
return iter(ast.StringTerm(util.ByteSliceToString(bs)))
}
func builtinYAMLUnmarshal(_ BuiltinContext, operands []*ast.Term, iter func(*ast.Term) error) error {
str, err := builtins.StringOperand(operands[0].Value, 1)
bs, err := builtins.StringOperandByteSlice(operands[0].Value, 1)
if err != nil {
return err
}
bs, err := yaml.YAMLToJSON([]byte(str))
js, err := yaml.YAMLToJSON(bs)
if err != nil {
return err
}
buf := bytes.NewBuffer(bs)
decoder := util.NewJSONDecoder(buf)
reader := ast.BytesReaderPool.Get()
defer ast.BytesReaderPool.Put(reader)
reader.Reset(js)
var val any
err = decoder.Decode(&val)
if err != nil {
if err = util.NewJSONDecoder(reader).Decode(&val); err != nil {
return err
}
v, err := ast.InterfaceToValue(val)
if err != nil {
return err
@@ -353,22 +341,22 @@ func builtinYAMLUnmarshal(_ BuiltinContext, operands []*ast.Term, iter func(*ast
}
func builtinYAMLIsValid(_ BuiltinContext, operands []*ast.Term, iter func(*ast.Term) error) error {
str, err := builtins.StringOperand(operands[0].Value, 1)
bs, err := builtins.StringOperandByteSlice(operands[0].Value, 1)
if err != nil {
return iter(ast.InternedTerm(false))
}
var x any
err = yaml.Unmarshal([]byte(str), &x)
err = yaml.Unmarshal(bs, &x)
return iter(ast.InternedTerm(err == nil))
}
func builtinHexEncode(_ BuiltinContext, operands []*ast.Term, iter func(*ast.Term) error) error {
str, err := builtins.StringOperand(operands[0].Value, 1)
bs, err := builtins.StringOperandByteSlice(operands[0].Value, 1)
if err != nil {
return err
}
return iter(ast.StringTerm(hex.EncodeToString([]byte(str))))
return iter(ast.StringTerm(hex.EncodeToString(bs)))
}
func builtinHexDecode(_ BuiltinContext, operands []*ast.Term, iter func(*ast.Term) error) error {

View File

@@ -119,25 +119,52 @@ type eval struct {
defined bool
}
type evp struct {
pool sync.Pool
type (
evfp struct{ pool sync.Pool }
evbp struct{ pool sync.Pool }
)
func (ep *evfp) Put(e *evalFunc) {
if e != nil {
e.e, e.terms, e.ir = nil, nil, nil
ep.pool.Put(e)
}
}
func (ep *evp) Put(e *eval) {
ep.pool.Put(e)
func (ep *evfp) Get() *evalFunc {
return ep.pool.Get().(*evalFunc)
}
func (ep *evp) Get() *eval {
return ep.pool.Get().(*eval)
func (ep *evbp) Put(e *evalBuiltin) {
if e != nil {
e.e, e.bi, e.bctx, e.f, e.terms = nil, nil, nil, nil, nil
ep.pool.Put(e)
}
}
var evalPool = evp{
pool: sync.Pool{
New: func() any {
return &eval{}
func (ep *evbp) Get() *evalBuiltin {
return ep.pool.Get().(*evalBuiltin)
}
var (
evalPool = util.NewSyncPool[eval]()
deecPool = util.NewSyncPool[deferredEarlyExitContainer]()
resolverPool = util.NewSyncPool[evalResolver]()
evalFuncPool = &evfp{
pool: sync.Pool{
New: func() any {
return &evalFunc{}
},
},
},
}
}
evalBuiltinPool = &evbp{
pool: sync.Pool{
New: func() any {
return &evalBuiltin{}
},
},
}
)
func (e *eval) Run(iter evalIterator) error {
if !e.traceEnabled {
@@ -401,9 +428,11 @@ func (e *eval) evalExpr(iter evalIterator) error {
}
return nil
}
expr := e.query[e.index]
e.traceEval(expr)
expr := e.query[e.index]
if e.traceEnabled {
e.traceEval(expr)
}
if len(expr.With) > 0 {
return e.evalWith(iter)
@@ -521,7 +550,7 @@ func (e *eval) evalStep(iter evalIterator) error {
// generateVar inlined here to avoid extra allocations in hot path
rterm := ast.VarTerm(e.fmtVarTerm())
err = e.unify(terms, rterm, func() error {
if e.saveSet.Contains(rterm, e.bindings) {
if e.saveSet != nil && e.saveSet.Contains(rterm, e.bindings) {
return e.saveExpr(ast.NewExpr(rterm), e.bindings, func() error {
return iter(e)
})
@@ -888,7 +917,6 @@ func (e *eval) evalNotPartialSupport(negationID uint64, expr *ast.Expr, unknowns
}
func (e *eval) evalCall(terms []*ast.Term, iter unifyIterator) error {
ref := terms[0].Value.(ast.Ref)
mock, mocked := e.functionMocks.Get(ref)
@@ -912,8 +940,8 @@ func (e *eval) evalCall(terms []*ast.Term, iter unifyIterator) error {
if ref[0].Equal(ast.DefaultRootDocument) {
if mocked {
f := e.compiler.TypeEnv.Get(ref).(*types.Function)
return e.evalCallValue(f.Arity(), terms, mock, iter)
arity := e.compiler.TypeEnv.GetByRef(ref).(*types.Function).Arity()
return e.evalCallValue(arity, terms, mock, iter)
}
var ir *ast.IndexResult
@@ -928,11 +956,13 @@ func (e *eval) evalCall(terms []*ast.Term, iter unifyIterator) error {
return err
}
eval := evalFunc{
e: e,
terms: terms,
ir: ir,
}
eval := evalFuncPool.Get()
defer evalFuncPool.Put(eval)
eval.e = e
eval.terms = terms
eval.ir = ir
return eval.eval(iter)
}
@@ -991,13 +1021,14 @@ func (e *eval) evalCall(terms []*ast.Term, iter unifyIterator) error {
}
}
eval := evalBuiltin{
e: e,
bi: bi,
bctx: bctx,
f: f,
terms: terms[1:],
}
eval := evalBuiltinPool.Get()
defer evalBuiltinPool.Put(eval)
eval.e = e
eval.bi = bi
eval.bctx = bctx
eval.f = f
eval.terms = terms[1:]
return eval.eval(iter)
}
@@ -1054,7 +1085,9 @@ func (e *eval) biunify(a, b *ast.Term, b1, b2 *bindings, iter unifyIterator) err
case ast.Var, ast.Ref, *ast.ArrayComprehension:
return e.biunifyValues(a, b, b1, b2, iter)
case *ast.Array:
return e.biunifyArrays(vA, vB, b1, b2, iter)
if vA.Len() == vB.Len() {
return e.biunifyArraysRec(vA, vB, b1, b2, iter, 0)
}
}
case ast.Object:
switch vB := b.Value.(type) {
@@ -1069,13 +1102,6 @@ func (e *eval) biunify(a, b *ast.Term, b1, b2 *bindings, iter unifyIterator) err
return nil
}
func (e *eval) biunifyArrays(a, b *ast.Array, b1, b2 *bindings, iter unifyIterator) error {
if a.Len() != b.Len() {
return nil
}
return e.biunifyArraysRec(a, b, b1, b2, iter, 0)
}
func (e *eval) biunifyArraysRec(a, b *ast.Array, b1, b2 *bindings, iter unifyIterator, idx int) error {
if idx == a.Len() {
return iter()
@@ -1643,7 +1669,7 @@ func (e *eval) getRules(ref ast.Ref, args []*ast.Term) (*ast.IndexResult, error)
return nil, nil
}
resolver := resolverPool.Get().(*evalResolver)
resolver := resolverPool.Get()
defer func() {
resolver.e = nil
resolver.args = nil
@@ -1698,14 +1724,6 @@ type evalResolver struct {
args []*ast.Term
}
var (
resolverPool = sync.Pool{
New: func() any {
return &evalResolver{}
},
}
)
func (e *evalResolver) Resolve(ref ast.Ref) (ast.Value, error) {
e.e.instr.startTimer(evalOpResolve)
@@ -2052,8 +2070,7 @@ type evalFunc struct {
terms []*ast.Term
}
func (e evalFunc) eval(iter unifyIterator) error {
func (e *evalFunc) eval(iter unifyIterator) error {
if e.ir.Empty() {
return nil
}
@@ -2065,13 +2082,13 @@ func (e evalFunc) eval(iter unifyIterator) error {
argCount = len(e.ir.Default.Head.Args)
}
if len(e.ir.Else) > 0 && e.e.unknown(e.e.query[e.e.index], e.e.bindings) {
// Partial evaluation of ordered rules is not supported currently. Save the
// expression and continue. This could be revisited in the future.
return e.e.saveCall(argCount, e.terms, iter)
}
if e.e.partial() {
if len(e.ir.Else) > 0 && e.e.unknown(e.e.query[e.e.index], e.e.bindings) {
// Partial evaluation of ordered rules is not supported currently. Save the
// expression and continue. This could be revisited in the future.
return e.e.saveCall(argCount, e.terms, iter)
}
var mustGenerateSupport bool
if defRule := e.ir.Default; defRule != nil {
@@ -2109,7 +2126,7 @@ func (e evalFunc) eval(iter unifyIterator) error {
return e.evalValue(iter, argCount, e.ir.EarlyExit)
}
func (e evalFunc) evalValue(iter unifyIterator, argCount int, findOne bool) error {
func (e *evalFunc) evalValue(iter unifyIterator, argCount int, findOne bool) error {
var cacheKey ast.Ref
if !e.e.partial() {
var hit bool
@@ -2194,7 +2211,7 @@ func (e evalFunc) evalValue(iter unifyIterator, argCount int, findOne bool) erro
})
}
func (e evalFunc) evalCache(argCount int, iter unifyIterator) (ast.Ref, bool, error) {
func (e *evalFunc) evalCache(argCount int, iter unifyIterator) (ast.Ref, bool, error) {
plen := len(e.terms)
if plen == argCount+2 { // func name + output = 2
plen -= 1
@@ -2226,7 +2243,7 @@ func (e evalFunc) evalCache(argCount int, iter unifyIterator) (ast.Ref, bool, er
return cacheKey, false, nil
}
func (e evalFunc) evalOneRule(iter unifyIterator, rule *ast.Rule, args []*ast.Term, cacheKey ast.Ref, prev *ast.Term, findOne bool) (*ast.Term, error) {
func (e *evalFunc) evalOneRule(iter unifyIterator, rule *ast.Rule, args []*ast.Term, cacheKey ast.Ref, prev *ast.Term, findOne bool) (*ast.Term, error) {
child := evalPool.Get()
defer evalPool.Put(child)
@@ -2288,7 +2305,7 @@ func (e evalFunc) evalOneRule(iter unifyIterator, rule *ast.Rule, args []*ast.Te
return result, err
}
func (e evalFunc) partialEvalSupport(declArgsLen int, iter unifyIterator) error {
func (e *evalFunc) partialEvalSupport(declArgsLen int, iter unifyIterator) error {
path := e.e.namespaceRef(e.terms[0].Value.(ast.Ref))
if !e.e.saveSupport.Exists(path) {
@@ -2316,7 +2333,7 @@ func (e evalFunc) partialEvalSupport(declArgsLen int, iter unifyIterator) error
return e.e.saveCall(declArgsLen, append([]*ast.Term{term}, e.terms[1:]...), iter)
}
func (e evalFunc) partialEvalSupportRule(rule *ast.Rule, path ast.Ref) error {
func (e *evalFunc) partialEvalSupportRule(rule *ast.Rule, path ast.Ref) error {
child := evalPool.Get()
defer evalPool.Put(child)
@@ -2395,12 +2412,6 @@ func (dc *deferredEarlyExitContainer) copyError() *deferredEarlyExitError {
return &cpy
}
var deecPool = sync.Pool{
New: func() any {
return &deferredEarlyExitContainer{}
},
}
type evalTree struct {
e *eval
bindings *bindings
@@ -2486,7 +2497,7 @@ func (e evalTree) enumerate(iter unifyIterator) error {
return err
}
dc := deecPool.Get().(*deferredEarlyExitContainer)
dc := deecPool.Get()
dc.deferred = nil
defer deecPool.Put(dc)

View File

@@ -314,7 +314,6 @@ func validateHTTPRequestOperand(term *ast.Term, pos int) (ast.Object, error) {
}
return obj, nil
}
// canonicalizeHeaders returns a copy of the headers where the keys are in
@@ -333,7 +332,7 @@ func canonicalizeHeaders(headers map[string]any) map[string]any {
// a DialContext that opens a socket (specified in the http call).
// The url is expected to contain socket=/path/to/socket (url encoded)
// Ex. "unix://localhost/end/point?socket=%2Ftmp%2Fhttp.sock"
func useSocket(rawURL string, tlsConfig *tls.Config) (bool, string, *http.Transport) {
func useSocket(rawURL string) (bool, string, *http.Transport) {
u, err := url.Parse(rawURL)
if err != nil {
return false, "", nil
@@ -362,7 +361,6 @@ func useSocket(rawURL string, tlsConfig *tls.Config) (bool, string, *http.Transp
tr.DialContext = func(ctx context.Context, _, _ string) (net.Conn, error) {
return http.DefaultTransport.(*http.Transport).DialContext(ctx, "unix", socket)
}
tr.TLSClientConfig = tlsConfig
tr.DisableKeepAlives = true
return true, u.String(), tr
@@ -533,6 +531,10 @@ func createHTTPRequest(bctx BuiltinContext, obj ast.Object) (*http.Request, *htt
}
}
if len(customHeaders) != 0 {
customHeaders = canonicalizeHeaders(customHeaders)
}
isTLS := false
client := &http.Client{
Timeout: timeout,
@@ -579,13 +581,6 @@ func createHTTPRequest(bctx BuiltinContext, obj ast.Object) (*http.Request, *htt
tlsConfig.Certificates = append(tlsConfig.Certificates, cert)
}
// Use system certs if no CA cert is provided
// or system certs flag is not set
if len(tlsCaCert) == 0 && tlsCaCertFile == "" && tlsCaCertEnvVar == "" && tlsUseSystemCerts == nil {
trueValue := true
tlsUseSystemCerts = &trueValue
}
// Check the system certificates config first so that we
// load additional certificated into the correct pool.
if tlsUseSystemCerts != nil && *tlsUseSystemCerts && runtime.GOOS != "windows" {
@@ -629,21 +624,31 @@ func createHTTPRequest(bctx BuiltinContext, obj ast.Object) (*http.Request, *htt
tlsConfig.RootCAs = pool
}
// If Host header is set, use it for TLS server name.
if host, hasHost := customHeaders["Host"]; hasHost {
// Only default the ServerName if the caller has
// specified the host. If we don't specify anything,
// Go will default to the target hostname. This name
// is not the same as the default that Go populates
// `req.Host` with, which is why we don't just set
// this unconditionally.
isTLS = true
tlsConfig.ServerName, _ = host.(string)
}
if tlsServerName != "" {
isTLS = true
tlsConfig.ServerName = tlsServerName
}
var transport *http.Transport
if isTLS {
if ok, parsedURL, tr := useSocket(url, &tlsConfig); ok {
transport = tr
url = parsedURL
} else {
transport = http.DefaultTransport.(*http.Transport).Clone()
transport.TLSClientConfig = &tlsConfig
transport.DisableKeepAlives = true
}
} else {
if ok, parsedURL, tr := useSocket(url, nil); ok {
transport = tr
url = parsedURL
}
if ok, parsedURL, tr := useSocket(url); ok {
transport = tr
url = parsedURL
} else if isTLS {
transport = http.DefaultTransport.(*http.Transport).Clone()
transport.TLSClientConfig = &tlsConfig
transport.DisableKeepAlives = true
}
if bctx.RoundTripper != nil {
@@ -676,8 +681,6 @@ func createHTTPRequest(bctx BuiltinContext, obj ast.Object) (*http.Request, *htt
// Add custom headers
if len(customHeaders) != 0 {
customHeaders = canonicalizeHeaders(customHeaders)
for k, v := range customHeaders {
header, ok := v.(string)
if !ok {
@@ -697,21 +700,9 @@ func createHTTPRequest(bctx BuiltinContext, obj ast.Object) (*http.Request, *htt
if host, hasHost := customHeaders["Host"]; hasHost {
host := host.(string) // We already checked that it's a string.
req.Host = host
// Only default the ServerName if the caller has
// specified the host. If we don't specify anything,
// Go will default to the target hostname. This name
// is not the same as the default that Go populates
// `req.Host` with, which is why we don't just set
// this unconditionally.
tlsConfig.ServerName = host
}
}
if tlsServerName != "" {
tlsConfig.ServerName = tlsServerName
}
if len(bctx.DistributedTracingOpts) > 0 {
client.Transport = tracing.NewTransport(client.Transport, bctx.DistributedTracingOpts)
}
@@ -1192,7 +1183,8 @@ func newInterQueryCacheData(bctx BuiltinContext, resp *http.Response, respBody [
RespBody: respBody,
Status: resp.Status,
StatusCode: resp.StatusCode,
Headers: resp.Header}
Headers: resp.Header,
}
return &cv, nil
}
@@ -1222,7 +1214,8 @@ func (c *interQueryCacheData) Clone() (cache.InterQueryCacheValue, error) {
RespBody: dup,
Status: c.Status,
StatusCode: c.StatusCode,
Headers: c.Headers.Clone()}, nil
Headers: c.Headers.Clone(),
}, nil
}
type responseHeaders struct {
@@ -1384,7 +1377,6 @@ func parseMaxAgeCacheDirective(cc map[string]string) (deltaSeconds, error) {
}
func formatHTTPResponseToAST(resp *http.Response, forceJSONDecode, forceYAMLDecode bool) (ast.Value, []byte, error) {
resultRawBody, err := io.ReadAll(resp.Body)
if err != nil {
return nil, nil, err

View File

@@ -121,6 +121,7 @@ func (q *Query) WithInput(input *ast.Term) *Query {
}
// WithTracer adds a query tracer to use during evaluation. This is optional.
//
// Deprecated: Use WithQueryTracer instead.
func (q *Query) WithTracer(tracer Tracer) *Query {
qt, ok := tracer.(QueryTracer)

View File

@@ -7,6 +7,7 @@ package topdown
import (
"fmt"
"regexp"
"regexp/syntax"
"sync"
gintersect "github.com/yashtewari/glob-intersection"
@@ -15,25 +16,24 @@ import (
"github.com/open-policy-agent/opa/v1/topdown/builtins"
)
const regexCacheMaxSize = 100
const regexInterQueryValueCacheHits = "rego_builtin_regex_interquery_value_cache_hits"
const (
regexCacheMaxSize = 100
regexInterQueryValueCacheHits = "rego_builtin_regex_interquery_value_cache_hits"
)
var regexpCacheLock = sync.Mutex{}
var regexpCache map[string]*regexp.Regexp
var (
regexpCacheLock = sync.RWMutex{}
regexpCache = make(map[string]*regexp.Regexp)
)
func builtinRegexIsValid(_ BuiltinContext, operands []*ast.Term, iter func(*ast.Term) error) error {
s, err := builtins.StringOperand(operands[0].Value, 1)
if err != nil {
return iter(ast.InternedTerm(false))
if s, err := builtins.StringOperand(operands[0].Value, 1); err == nil {
if _, err = syntax.Parse(string(s), syntax.Perl); err == nil {
return iter(ast.InternedTerm(true))
}
}
_, err = regexp.Compile(string(s))
if err != nil {
return iter(ast.InternedTerm(false))
}
return iter(ast.InternedTerm(true))
return iter(ast.InternedTerm(false))
}
func builtinRegexMatch(bctx BuiltinContext, operands []*ast.Term, iter func(*ast.Term) error) error {
@@ -107,7 +107,8 @@ func builtinRegexSplit(bctx BuiltinContext, operands []*ast.Term, iter func(*ast
func getRegexp(bctx BuiltinContext, pat string) (*regexp.Regexp, error) {
if bctx.InterQueryBuiltinValueCache != nil {
// TODO: Use named cache
val, ok := bctx.InterQueryBuiltinValueCache.Get(ast.String(pat))
var key ast.Value = ast.String(pat)
val, ok := bctx.InterQueryBuiltinValueCache.Get(key)
if ok {
res, valid := val.(*regexp.Regexp)
if !valid {
@@ -124,20 +125,23 @@ func getRegexp(bctx BuiltinContext, pat string) (*regexp.Regexp, error) {
if err != nil {
return nil, err
}
bctx.InterQueryBuiltinValueCache.Insert(ast.String(pat), re)
bctx.InterQueryBuiltinValueCache.Insert(key, re)
return re, nil
}
regexpCacheLock.Lock()
defer regexpCacheLock.Unlock()
regexpCacheLock.RLock()
re, ok := regexpCache[pat]
numCached := len(regexpCache)
regexpCacheLock.RUnlock()
if !ok {
var err error
re, err = regexp.Compile(pat)
if err != nil {
return nil, err
}
if len(regexpCache) >= regexCacheMaxSize {
regexpCacheLock.Lock()
if numCached >= regexCacheMaxSize {
// Delete a (semi-)random key to make room for the new one.
for k := range regexpCache {
delete(regexpCache, k)
@@ -145,21 +149,24 @@ func getRegexp(bctx BuiltinContext, pat string) (*regexp.Regexp, error) {
}
}
regexpCache[pat] = re
regexpCacheLock.Unlock()
}
return re, nil
}
func getRegexpTemplate(pat string, delimStart, delimEnd byte) (*regexp.Regexp, error) {
regexpCacheLock.Lock()
defer regexpCacheLock.Unlock()
regexpCacheLock.RLock()
re, ok := regexpCache[pat]
regexpCacheLock.RUnlock()
if !ok {
var err error
re, err = compileRegexTemplate(pat, delimStart, delimEnd)
if err != nil {
return nil, err
}
regexpCacheLock.Lock()
regexpCache[pat] = re
regexpCacheLock.Unlock()
}
return re, nil
}
@@ -268,7 +275,6 @@ func builtinRegexReplace(bctx BuiltinContext, operands []*ast.Term, iter func(*a
}
func init() {
regexpCache = map[string]*regexp.Regexp{}
RegisterBuiltinFunc(ast.RegexIsValid.Name, builtinRegexIsValid)
RegisterBuiltinFunc(ast.RegexMatch.Name, builtinRegexMatch)
RegisterBuiltinFunc(ast.RegexMatchDeprecated.Name, builtinRegexMatch)

View File

@@ -23,34 +23,25 @@ func builtinSemVerCompare(_ BuiltinContext, operands []*ast.Term, iter func(*ast
return err
}
versionA, err := semver.NewVersion(string(versionStringA))
versionA, err := semver.Parse(string(versionStringA))
if err != nil {
return fmt.Errorf("operand 1: string %s is not a valid SemVer", versionStringA)
}
versionB, err := semver.NewVersion(string(versionStringB))
versionB, err := semver.Parse(string(versionStringB))
if err != nil {
return fmt.Errorf("operand 2: string %s is not a valid SemVer", versionStringB)
}
result := versionA.Compare(*versionB)
return iter(ast.InternedTerm(result))
return iter(ast.InternedTerm(versionA.Compare(versionB)))
}
func builtinSemVerIsValid(_ BuiltinContext, operands []*ast.Term, iter func(*ast.Term) error) error {
versionString, err := builtins.StringOperand(operands[0].Value, 1)
if err != nil {
return iter(ast.InternedTerm(false))
if err == nil {
_, err = semver.Parse(string(versionString))
}
result := true
_, err = semver.NewVersion(string(versionString))
if err != nil {
result = false
}
return iter(ast.InternedTerm(result))
return iter(ast.InternedTerm(err == nil))
}
func init() {

View File

@@ -428,7 +428,7 @@ func builtinJWTVerify(bctx BuiltinContext, jwt ast.Value, keyStr ast.Value, hash
// If a match is found, verify using only that key. Only applicable when a JWKS was provided.
if header.kid != "" {
if key := getKeyByKid(header.kid, keys); key != nil {
err = verify(key.key, getInputSHA([]byte(token.header+"."+token.payload), hasher), []byte(signature))
err = verify(key.key, getInputSHA([]byte(token.header+"."+token.payload), hasher), signature)
return done(err == nil)
}
@@ -440,7 +440,7 @@ func builtinJWTVerify(bctx BuiltinContext, jwt ast.Value, keyStr ast.Value, hash
if key.alg == "" {
// No algorithm provided for the key - this is likely a certificate and not a JWKS, so
// we'll need to verify to find out
err = verify(key.key, getInputSHA([]byte(token.header+"."+token.payload), hasher), []byte(signature))
err = verify(key.key, getInputSHA([]byte(token.header+"."+token.payload), hasher), signature)
if err == nil {
return done(true)
}
@@ -448,7 +448,7 @@ func builtinJWTVerify(bctx BuiltinContext, jwt ast.Value, keyStr ast.Value, hash
if header.alg != key.alg {
continue
}
err = verify(key.key, getInputSHA([]byte(token.header+"."+token.payload), hasher), []byte(signature))
err = verify(key.key, getInputSHA([]byte(token.header+"."+token.payload), hasher), signature)
if err == nil {
return done(true)
}
@@ -509,7 +509,7 @@ func builtinJWTVerifyHS(bctx BuiltinContext, operands []*ast.Term, hashF func()
return err
}
valid := hmac.Equal([]byte(signature), mac.Sum(nil))
valid := hmac.Equal(signature, mac.Sum(nil))
putTokenInCache(bctx, jwt, astSecret, nil, nil, valid)
@@ -662,7 +662,7 @@ func (constraints *tokenConstraints) validate() error {
}
// verify verifies a JWT using the constraints and the algorithm from the header
func (constraints *tokenConstraints) verify(kid, alg, header, payload, signature string) error {
func (constraints *tokenConstraints) verify(kid, alg, header, payload string, signature []byte) error {
// Construct the payload
plaintext := append(append([]byte(header), '.'), []byte(payload)...)
@@ -670,7 +670,7 @@ func (constraints *tokenConstraints) verify(kid, alg, header, payload, signature
if constraints.keys != nil {
if kid != "" {
if key := getKeyByKid(kid, constraints.keys); key != nil {
err := jwsbb.Verify(key.key, alg, plaintext, []byte(signature))
err := jwsbb.Verify(key.key, alg, plaintext, signature)
if err != nil {
return errSignatureNotVerified
}
@@ -681,7 +681,7 @@ func (constraints *tokenConstraints) verify(kid, alg, header, payload, signature
verified := false
for _, key := range constraints.keys {
if key.alg == "" {
err := jwsbb.Verify(key.key, alg, plaintext, []byte(signature))
err := jwsbb.Verify(key.key, alg, plaintext, signature)
if err == nil {
verified = true
break
@@ -690,7 +690,7 @@ func (constraints *tokenConstraints) verify(kid, alg, header, payload, signature
if alg != key.alg {
continue
}
err := jwsbb.Verify(key.key, alg, plaintext, []byte(signature))
err := jwsbb.Verify(key.key, alg, plaintext, signature)
if err == nil {
verified = true
break
@@ -704,7 +704,7 @@ func (constraints *tokenConstraints) verify(kid, alg, header, payload, signature
return nil
}
if constraints.secret != "" {
err := jwsbb.Verify([]byte(constraints.secret), alg, plaintext, []byte(signature))
err := jwsbb.Verify([]byte(constraints.secret), alg, plaintext, signature)
if err != nil {
return errSignatureNotVerified
}
@@ -1170,17 +1170,17 @@ func decodeJWT(a ast.Value) (*JSONWebToken, error) {
return &JSONWebToken{header: parts[0], payload: parts[1], signature: parts[2]}, nil
}
func (token *JSONWebToken) decodeSignature() (string, error) {
func (token *JSONWebToken) decodeSignature() ([]byte, error) {
decodedSignature, err := getResult(builtinBase64UrlDecode, ast.StringTerm(token.signature))
if err != nil {
return "", err
return nil, err
}
signatureAst, err := builtins.StringOperand(decodedSignature.Value, 1)
signatureBs, err := builtins.StringOperandByteSlice(decodedSignature.Value, 1)
if err != nil {
return "", err
return nil, err
}
return string(signatureAst), err
return signatureBs, nil
}
// Extract, validate and return the JWT header as an ast.Object.

View File

@@ -170,6 +170,7 @@ func (evt *Event) equalNodes(other *Event) bool {
}
// Tracer defines the interface for tracing in the top-down evaluation engine.
//
// Deprecated: Use QueryTracer instead.
type Tracer interface {
Enabled() bool
@@ -230,6 +231,7 @@ func (b *BufferTracer) Enabled() bool {
}
// Trace adds the event to the buffer.
//
// Deprecated: Use TraceEvent instead.
func (b *BufferTracer) Trace(evt *Event) {
*b = append(*b, evt)
@@ -806,7 +808,7 @@ func printPrettyVars(w *bytes.Buffer, exprVars map[string]varInfo) {
w.WriteString("\n\nWhere:\n")
for _, info := range byName {
w.WriteString(fmt.Sprintf("\n%s: %s", info.Title(), iStrs.Truncate(info.Value(), maxPrettyExprVarWidth)))
fmt.Fprintf(w, "\n%s: %s", info.Title(), iStrs.Truncate(info.Value(), maxPrettyExprVarWidth))
}
return
@@ -878,7 +880,7 @@ func printArrows(w *bytes.Buffer, l []varInfo, printValueAt int) {
valueStr := iStrs.Truncate(info.Value(), maxPrettyExprVarWidth)
if (i > 0 && col == l[i-1].col) || (i < len(l)-1 && col == l[i+1].col) {
// There is another var on this column, so we need to include the name to differentiate them.
w.WriteString(fmt.Sprintf("%s: %s", info.Title(), valueStr))
fmt.Fprintf(w, "%s: %s", info.Title(), valueStr)
} else {
w.WriteString(valueStr)
}

View File

@@ -716,6 +716,7 @@ func (t *Function) NamedFuncArgs() FuncArgs {
}
// Args returns the function's arguments as a slice, ignoring variadic arguments.
//
// Deprecated: Use FuncArgs instead.
func (t *Function) Args() []Type {
cpy := make([]Type, len(t.args))

View File

@@ -77,13 +77,10 @@ func dfsRecursive(t Traversal, eq Equals, u, z T, path []T) []T {
}
for _, v := range t.Edges(u) {
if eq(v, z) {
path = append(path, z)
path = append(path, u)
return path
return append(path, z, u)
}
if p := dfsRecursive(t, eq, v, z, path); len(p) > 0 {
path = append(p, u)
return path
return append(p, u)
}
}
return path

View File

@@ -1,13 +1,40 @@
package util
import (
"math"
"slices"
"strings"
"sync"
"unsafe"
)
// SyncPool wraps sync.Pool with a typed interface for *T values: Get
// performs the type assertion internally so call sites do not have to,
// and Put silently drops nil pointers rather than storing them.
type SyncPool[T any] struct {
	pool sync.Pool
}

// NewSyncPool constructs a SyncPool whose Get allocates a zero-valued T
// whenever the underlying pool has nothing to hand out.
func NewSyncPool[T any]() *SyncPool[T] {
	sp := &SyncPool[T]{}
	sp.pool.New = func() any { return new(T) }
	return sp
}

// Get returns a pointer to a (possibly recycled) T; never nil.
func (p *SyncPool[T]) Get() *T {
	v := p.pool.Get()
	return v.(*T)
}

// Put stores x for later reuse; nil pointers are ignored.
func (p *SyncPool[T]) Put(x *T) {
	if x == nil {
		return
	}
	p.pool.Put(x)
}
// NewPtrSlice returns a slice of pointers to T with length n,
// with only 2 allocations performed no matter the size of n.
// See:
@@ -44,6 +71,12 @@ func StringToByteSlice[T ~string](s T) []byte {
// NumDigitsInt returns the number of digits in n.
// This is useful for pre-allocating buffers for string conversion.
func NumDigitsInt(n int) int {
return NumDigitsInt64(int64(n))
}
// NumDigitsInt64 returns the number of digits in n.
// This is useful for pre-allocating buffers for string conversion.
func NumDigitsInt64(n int64) int {
if n == 0 {
return 1
}
@@ -52,7 +85,12 @@ func NumDigitsInt(n int) int {
n = -n
}
return int(math.Log10(float64(n))) + 1
count := 0
for n > 0 {
n /= 10
count++
}
return count
}
// NumDigitsUint returns the number of digits in n.
@@ -62,16 +100,10 @@ func NumDigitsUint(n uint64) int {
return 1
}
return int(math.Log10(float64(n))) + 1
}
// KeysCount returns the number of keys in m that satisfy predicate p.
func KeysCount[K comparable, V any](m map[K]V, p func(K) bool) int {
count := 0
for k := range m {
if p(k) {
count++
}
for n > 0 {
n /= 10
count++
}
return count
}
@@ -129,5 +161,7 @@ func (sp *SlicePool[T]) Get(length int) *[]T {
// Put returns a pointer to a slice of type T to the pool.
func (sp *SlicePool[T]) Put(s *[]T) {
sp.pool.Put(s)
if s != nil {
sp.pool.Put(s)
}
}

View File

@@ -3,29 +3,21 @@ package util
import (
"bytes"
"compress/gzip"
"encoding/binary"
"errors"
"io"
"net/http"
"strings"
"sync"
"github.com/open-policy-agent/opa/v1/util/decoding"
)
var gzipReaderPool = sync.Pool{
New: func() any {
reader := new(gzip.Reader)
return reader
},
}
var gzipReaderPool = NewSyncPool[gzip.Reader]()
// Note(philipc): Originally taken from server/server.go
// The DecodingLimitHandler handles validating that the gzip payload is within the
// allowed max size limit. Thus, in the event of a forged payload size trailer,
// the worst that can happen is that we waste memory up to the allowed max gzip
// payload size, but not an unbounded amount of memory, as was potentially
// possible before.
// The DecodingLimitHandler handles setting the max size limits in the context.
// This function enforces those limits. For gzip payloads, we use a LimitReader
// to ensure we don't decompress more than the allowed maximum, preventing
// memory exhaustion from forged gzip trailers.
func ReadMaybeCompressedBody(r *http.Request) ([]byte, error) {
length := r.ContentLength
if maxLenConf, ok := decoding.GetServerDecodingMaxLen(r.Context()); ok {
@@ -40,16 +32,7 @@ func ReadMaybeCompressedBody(r *http.Request) ([]byte, error) {
if strings.Contains(r.Header.Get("Content-Encoding"), "gzip") {
gzipMaxLength, _ := decoding.GetServerDecodingGzipMaxLen(r.Context())
// Note(philipc): The last 4 bytes of a well-formed gzip blob will
// always be a little-endian uint32, representing the decompressed
// content size, modulo 2^32. We validate that the size is safe,
// earlier in DecodingLimitHandler.
sizeDecompressed := int64(binary.LittleEndian.Uint32(content[len(content)-4:]))
if sizeDecompressed > gzipMaxLength {
return nil, errors.New("gzip payload too large")
}
gzReader := gzipReaderPool.Get().(*gzip.Reader)
gzReader := gzipReaderPool.Get()
defer func() {
gzReader.Close()
gzipReaderPool.Put(gzReader)
@@ -59,11 +42,16 @@ func ReadMaybeCompressedBody(r *http.Request) ([]byte, error) {
return nil, err
}
decompressed := bytes.NewBuffer(make([]byte, 0, sizeDecompressed))
if _, err = io.CopyN(decompressed, gzReader, sizeDecompressed); err != nil {
decompressed := bytes.NewBuffer(make([]byte, 0, len(content)))
limitReader := io.LimitReader(gzReader, gzipMaxLength+1)
if _, err := decompressed.ReadFrom(limitReader); err != nil {
return nil, err
}
if int64(decompressed.Len()) > gzipMaxLength {
return nil, errors.New("gzip payload too large")
}
return decompressed.Bytes(), nil
}

View File

@@ -0,0 +1,13 @@
package util
import "strings"
// WithPrefix returns s guaranteed to begin with prefix: the input is
// returned unchanged when it already starts with prefix, otherwise
// prefix is prepended.
func WithPrefix(s, prefix string) string {
	if !strings.HasPrefix(s, prefix) {
		return prefix + s
	}
	return s
}

View File

@@ -10,7 +10,7 @@ import (
"runtime/debug"
)
var Version = "1.10.1"
var Version = "1.11.1"
// GoVersion is the version of Go this was built with
var GoVersion = runtime.Version()

View File

@@ -1,21 +1,16 @@
MIT License
MIT No Attribution
Copyright (c) 2021 Segment
Copyright 2023 Segment
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
Permission is hereby granted, free of charge, to any person obtaining a copy of this
software and associated documentation files (the "Software"), to deal in the Software
without restriction, including without limitation the rights to use, copy, modify,
merge, publish, distribute, sublicense, and/or sell copies of the Software, and to
permit persons to whom the Software is furnished to do so.
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED,
INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A
PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT
HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.

View File

@@ -130,7 +130,12 @@ loop:
ADVANCE_LOOP(loop) // Store results and continue
done:
RETURN()
// RETURN() replacing the macro to please go vet.
SUB R0, R3;
SUB R1, R4;
MOVD R3, ret+56(FP);
MOVD R4, ret1+64(FP);
RET
// func decodeStdARM64(dst []byte, src []byte, lut *int8) (int, int)
@@ -145,7 +150,12 @@ loop:
ADVANCE_LOOP(loop) // Store results and continue
done:
RETURN()
// RETURN() replacing the macro to please go vet.
SUB R0, R3;
SUB R1, R4;
MOVD R3, ret+56(FP);
MOVD R4, ret1+64(FP);
RET
DATA ·mask_lut+0x00(SB)/1, $0xa8

View File

@@ -150,8 +150,8 @@ func (f *StringSliceFlag) Apply(set *flag.FlagSet) error {
setValue = f.Value.clone()
default:
setValue = new(StringSlice)
setValue.WithSeparatorSpec(f.separator)
}
setValue.WithSeparatorSpec(f.separator)
setValue.keepSpace = f.KeepSpace

View File

@@ -136,7 +136,10 @@ var SubcommandHelpTemplate = `NAME:
{{template "helpNameTemplate" .}}
USAGE:
{{if .UsageText}}{{wrap .UsageText 3}}{{else}}{{.HelpName}} {{if .VisibleFlags}}command [command options]{{end}}{{if .ArgsUsage}} {{.ArgsUsage}}{{else}}{{if .Args}} [arguments...]{{end}}{{end}}{{end}}{{if .Description}}
{{template "usageTemplate" .}}{{if .Category}}
CATEGORY:
{{.Category}}{{end}}{{if .Description}}
DESCRIPTION:
{{template "descriptionTemplate" .}}{{end}}{{if .VisibleCommands}}

View File

@@ -54,7 +54,7 @@ var helpCommand = &Command{
cCtx = cCtx.parentContext
}
// Case 4. $ app hello foo
// Case 4. $ app help foo
// foo is the command for which help needs to be shown
if argsPresent {
return ShowCommandHelp(cCtx, firstArg)

View File

@@ -83,7 +83,10 @@ var SubcommandHelpTemplate = `NAME:
{{template "helpNameTemplate" .}}
USAGE:
{{if .UsageText}}{{wrap .UsageText 3}}{{else}}{{.HelpName}} {{if .VisibleFlags}}command [command options]{{end}}{{if .ArgsUsage}} {{.ArgsUsage}}{{else}}{{if .Args}} [arguments...]{{end}}{{end}}{{end}}{{if .Description}}
{{template "usageTemplate" .}}{{if .Category}}
CATEGORY:
{{.Category}}{{end}}{{if .Description}}
DESCRIPTION:
{{template "descriptionTemplate" .}}{{end}}{{if .VisibleCommands}}

View File

@@ -39,5 +39,8 @@ type Directive struct {
}
// ArgumentMap resolves the directive's arguments against vars using the
// argument definitions attached to its schema Definition. A directive
// with no schema definition yields nil.
func (d *Directive) ArgumentMap(vars map[string]interface{}) map[string]interface{} {
	if def := d.Definition; def != nil {
		return arg2map(def.Arguments, d.Arguments, vars)
	}
	return nil
}

View File

@@ -37,5 +37,8 @@ type Argument struct {
}
// ArgumentMap resolves the field's arguments against vars using the
// argument definitions attached to its schema Definition. A field with
// no schema definition yields nil.
func (f *Field) ArgumentMap(vars map[string]interface{}) map[string]interface{} {
	if def := f.Definition; def != nil {
		return arg2map(def.Arguments, f.Arguments, vars)
	}
	return nil
}

View File

@@ -29,9 +29,10 @@ type Value struct {
Comment *CommentGroup
// Require validation
Definition *Definition
VariableDefinition *VariableDefinition
ExpectedType *Type
Definition *Definition
VariableDefinition *VariableDefinition
ExpectedType *Type
ExpectedTypeHasDefault bool
}
type ChildValue struct {

View File

@@ -182,6 +182,7 @@ func (w *Walker) walkValue(value *ast.Value) {
fieldDef := value.Definition.Fields.ForName(child.Name)
if fieldDef != nil {
child.Value.ExpectedType = fieldDef.Type
child.Value.ExpectedTypeHasDefault = fieldDef.DefaultValue != nil && fieldDef.DefaultValue.Kind != ast.NullValue
child.Value.Definition = w.Schema.Types[fieldDef.Type.Name()]
}
}
@@ -208,6 +209,7 @@ func (w *Walker) walkValue(value *ast.Value) {
func (w *Walker) walkArgument(argDef *ast.ArgumentDefinition, arg *ast.Argument) {
if argDef != nil {
arg.Value.ExpectedType = argDef.Type
arg.Value.ExpectedTypeHasDefault = argDef.DefaultValue != nil && argDef.DefaultValue.Kind != ast.NullValue
arg.Value.Definition = w.Schema.Types[argDef.Type.Name()]
}

View File

@@ -77,6 +77,7 @@ func (r *Rules) AddRule(name string, ruleFunc core.RuleFunc) {
// GetInner returns the internal rule map.
// If the map is not initialized, it returns an empty map.
// This returns a copy of the rules map, not the original map.
func (r *Rules) GetInner() map[string]core.RuleFunc {
if r == nil {
return nil // impossible nonsense, hopefully
@@ -84,7 +85,13 @@ func (r *Rules) GetInner() map[string]core.RuleFunc {
if r.rules == nil {
return make(map[string]core.RuleFunc)
}
return r.rules
rules := make(map[string]core.RuleFunc)
for k, v := range r.rules {
rules[k] = v
}
return rules
}
// RemoveRule removes a rule with the specified name from the rule set.

View File

@@ -25,6 +25,11 @@ var VariablesInAllowedPositionRule = Rule{
}
}
// If the expected type has a default, the given variable can be null
if value.ExpectedTypeHasDefault {
tmp.NonNull = false
}
if !value.VariableDefinition.Type.IsCompatible(&tmp) {
addError(
Message(

20
vendor/modules.txt vendored
View File

@@ -288,7 +288,7 @@ github.com/containerd/errdefs/pkg/internal/cause
# github.com/containerd/log v0.1.0
## explicit; go 1.20
github.com/containerd/log
# github.com/containerd/platforms v1.0.0-rc.1
# github.com/containerd/platforms v1.0.0-rc.2
## explicit; go 1.20
github.com/containerd/platforms
# github.com/coreos/go-oidc/v3 v3.17.0
@@ -716,8 +716,8 @@ github.com/golang-jwt/jwt/v4
# github.com/golang-jwt/jwt/v5 v5.3.0
## explicit; go 1.21
github.com/golang-jwt/jwt/v5
# github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da
## explicit
# github.com/golang/groupcache v0.0.0-20241129210726-2c02b8208cf8
## explicit; go 1.20
github.com/golang/groupcache/lru
# github.com/golang/protobuf v1.5.4
## explicit; go 1.17
@@ -942,8 +942,8 @@ github.com/lestrrat-go/httprc/v3
github.com/lestrrat-go/httprc/v3/errsink
github.com/lestrrat-go/httprc/v3/proxysink
github.com/lestrrat-go/httprc/v3/tracesink
# github.com/lestrrat-go/jwx/v3 v3.0.11
## explicit; go 1.24.4
# github.com/lestrrat-go/jwx/v3 v3.0.12
## explicit; go 1.24.0
github.com/lestrrat-go/jwx/v3
github.com/lestrrat-go/jwx/v3/cert
github.com/lestrrat-go/jwx/v3/internal/base64
@@ -1130,7 +1130,7 @@ github.com/moby/term/windows
# github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd
## explicit
github.com/modern-go/concurrent
# github.com/modern-go/reflect2 v1.0.2
# github.com/modern-go/reflect2 v1.0.3-0.20250322232337-35a7c28c31ee
## explicit; go 1.12
github.com/modern-go/reflect2
# github.com/mohae/deepcopy v0.0.0-20170929034955-c48cc78d4826
@@ -1275,7 +1275,7 @@ github.com/onsi/gomega/matchers/support/goraph/edge
github.com/onsi/gomega/matchers/support/goraph/node
github.com/onsi/gomega/matchers/support/goraph/util
github.com/onsi/gomega/types
# github.com/open-policy-agent/opa v1.10.1
# github.com/open-policy-agent/opa v1.11.1
## explicit; go 1.24.6
github.com/open-policy-agent/opa/ast
github.com/open-policy-agent/opa/ast/json
@@ -1928,7 +1928,7 @@ github.com/samber/slog-common
# github.com/samber/slog-zerolog/v2 v2.9.0
## explicit; go 1.21
github.com/samber/slog-zerolog/v2
# github.com/segmentio/asm v1.2.0
# github.com/segmentio/asm v1.2.1
## explicit; go 1.18
github.com/segmentio/asm/base64
github.com/segmentio/asm/cpu
@@ -2148,14 +2148,14 @@ github.com/tus/tusd/v2/pkg/handler
## explicit; go 1.13
github.com/unrolled/secure
github.com/unrolled/secure/cspbuilder
# github.com/urfave/cli/v2 v2.27.5
# github.com/urfave/cli/v2 v2.27.7
## explicit; go 1.18
github.com/urfave/cli/v2
# github.com/valyala/fastjson v1.6.4
## explicit; go 1.12
github.com/valyala/fastjson
github.com/valyala/fastjson/fastfloat
# github.com/vektah/gqlparser/v2 v2.5.30
# github.com/vektah/gqlparser/v2 v2.5.31
## explicit; go 1.22
github.com/vektah/gqlparser/v2/ast
github.com/vektah/gqlparser/v2/gqlerror