Compare commits

..

1 Commit

Author SHA1 Message Date
dependabot[bot]
ab183f7248 build(deps): bump github.com/open-policy-agent/opa from 1.11.1 to 1.12.3
Bumps [github.com/open-policy-agent/opa](https://github.com/open-policy-agent/opa) from 1.11.1 to 1.12.3.
- [Release notes](https://github.com/open-policy-agent/opa/releases)
- [Changelog](https://github.com/open-policy-agent/opa/blob/main/CHANGELOG.md)
- [Commits](https://github.com/open-policy-agent/opa/compare/v1.11.1...v1.12.3)

---
updated-dependencies:
- dependency-name: github.com/open-policy-agent/opa
  dependency-version: 1.12.3
  dependency-type: direct:production
  update-type: version-update:semver-minor
...

Signed-off-by: dependabot[bot] <support@github.com>
2026-01-15 16:38:01 +00:00
46 changed files with 21523 additions and 694 deletions

go.mod (2 changed lines)
View File

@@ -61,7 +61,7 @@ require (
github.com/onsi/ginkgo v1.16.5
github.com/onsi/ginkgo/v2 v2.27.5
github.com/onsi/gomega v1.39.0
github.com/open-policy-agent/opa v1.11.1
github.com/open-policy-agent/opa v1.12.3
github.com/opencloud-eu/icap-client v0.0.0-20250930132611-28a2afe62d89
github.com/opencloud-eu/libre-graph-api-go v1.0.8-0.20250724122329-41ba6b191e76
github.com/opencloud-eu/reva/v2 v2.41.1-0.20260107152322-93760b632993
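
A quick way to sanity-check that the bump is what actually gets linked into a built binary is to read the module list Go embeds at build time. A minimal sketch using only the standard library:

package main

import (
	"fmt"
	"runtime/debug"
)

func main() {
	info, ok := debug.ReadBuildInfo()
	if !ok {
		fmt.Println("no build info embedded in this binary")
		return
	}
	for _, dep := range info.Deps {
		if dep.Path == "github.com/open-policy-agent/opa" {
			// After this bump the reported version should be v1.12.3.
			fmt.Println(dep.Path, dep.Version)
		}
	}
}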

go.sum (4 changed lines)
View File

@@ -957,8 +957,8 @@ github.com/onsi/gomega v1.7.1/go.mod h1:XdKZgCCFLUoM/7CFJVPcG8C1xQ1AJ0vpAezJrB7J
github.com/onsi/gomega v1.10.1/go.mod h1:iN09h71vgCQne3DLsj+A5owkum+a2tYe+TOCB1ybHNo=
github.com/onsi/gomega v1.39.0 h1:y2ROC3hKFmQZJNFeGAMeHZKkjBL65mIZcvrLQBF9k6Q=
github.com/onsi/gomega v1.39.0/go.mod h1:ZCU1pkQcXDO5Sl9/VVEGlDyp+zm0m1cmeG5TOzLgdh4=
github.com/open-policy-agent/opa v1.11.1 h1:4bMlG6DjRZTRAswRyF+KUCgxHu1Gsk0h9EbZ4W9REvM=
github.com/open-policy-agent/opa v1.11.1/go.mod h1:QimuJO4T3KYxWzrmAymqlFvsIanCjKrGjmmC8GgAdgE=
github.com/open-policy-agent/opa v1.12.3 h1:qe3m/w52baKC/HJtippw+hYBUKCzuBCPjB+D5P9knfc=
github.com/open-policy-agent/opa v1.12.3/go.mod h1:RnDgm04GA1RjEXJvrsG9uNT/+FyBNmozcPvA2qz60M4=
github.com/opencloud-eu/go-micro-plugins/v4/store/nats-js-kv v0.0.0-20250512152754-23325793059a h1:Sakl76blJAaM6NxylVkgSzktjo2dS504iDotEFJsh3M=
github.com/opencloud-eu/go-micro-plugins/v4/store/nats-js-kv v0.0.0-20250512152754-23325793059a/go.mod h1:pjcozWijkNPbEtX5SIQaxEW/h8VAVZYTLx+70bmB3LY=
github.com/opencloud-eu/icap-client v0.0.0-20250930132611-28a2afe62d89 h1:W1ms+lP5lUUIzjRGDg93WrQfZJZCaV1ZP3KeyXi8bzY=

View File

File diff suppressed because it is too large.

View File

File diff suppressed because it is too large.

View File

File diff suppressed because it is too large.

View File

File diff suppressed because it is too large.

View File

@@ -719,6 +719,18 @@ opa_strings_upper,opa_abort
opa_strings_upper,opa_unicode_to_upper
opa_strings_upper,opa_realloc
opa_strings_upper,opa_unicode_encode_utf8
to_string,opa_value_type
to_string,opa_string_terminated
to_string,opa_value_dump
to_string,opa_strlen
to_string,opa_string_allocated
opa_template_string,opa_value_type
opa_template_string,opa_array_with_cap
opa_template_string,to_string
opa_template_string,opa_array_append
opa_template_string,opa_malloc
opa_template_string,memcpy
opa_template_string,opa_string_allocated
opa_types_is_number,opa_value_type
opa_types_is_number,opa_boolean
opa_types_is_string,opa_value_type

View File

Binary file not shown.

View File

@@ -162,6 +162,7 @@ var builtinsFunctions = map[string]string{
ast.TrimRight.Name: "opa_strings_trim_right",
ast.TrimSuffix.Name: "opa_strings_trim_suffix",
ast.TrimSpace.Name: "opa_strings_trim_space",
ast.InternalTemplateString.Name: "opa_template_string",
ast.NumbersRange.Name: "opa_numbers_range",
ast.ToNumber.Name: "opa_to_number",
ast.WalkBuiltin.Name: "opa_value_transitive_closure",

View File

@@ -1768,7 +1768,7 @@ func (p *Planner) planRef(ref ast.Ref, iter planiter) error {
return errors.New("illegal ref: non-var head")
}
if head.Compare(ast.DefaultRootDocument.Value) == 0 {
if head.Equal(ast.DefaultRootDocument.Value) {
virtual := p.rules.Get(ref[0].Value)
base := &baseptr{local: p.vars.GetOrEmpty(ast.DefaultRootDocument.Value.(ast.Var))}
return p.planRefData(virtual, base, ref, 1, iter)

View File

@@ -83,7 +83,7 @@ func readModule(r io.Reader) (*module.Module, error) {
var m module.Module
if err := readSections(r, &m); err != nil && err != io.EOF {
if err := readSections(r, &m); err != io.EOF {
return nil, err
}

View File

@@ -433,18 +433,7 @@ func (a *Annotations) toObject() (*Object, *Error) {
}
if len(a.Scope) > 0 {
switch a.Scope {
case annotationScopeDocument:
obj.Insert(InternedTerm("scope"), InternedTerm("document"))
case annotationScopePackage:
obj.Insert(InternedTerm("scope"), InternedTerm("package"))
case annotationScopeRule:
obj.Insert(InternedTerm("scope"), InternedTerm("rule"))
case annotationScopeSubpackages:
obj.Insert(InternedTerm("scope"), InternedTerm("subpackages"))
default:
obj.Insert(InternedTerm("scope"), StringTerm(a.Scope))
}
obj.Insert(InternedTerm("scope"), InternedTerm(a.Scope))
}
if len(a.Title) > 0 {

View File

@@ -151,6 +151,7 @@ var DefaultBuiltins = [...]*Builtin{
Sprintf,
StringReverse,
RenderTemplate,
InternalTemplateString,
// Numbers
NumbersRange,
@@ -1109,7 +1110,7 @@ var Concat = &Builtin{
types.Named("output", types.S).Description("the joined string"),
),
Categories: stringsCat,
CanSkipBctx: true,
CanSkipBctx: false,
}
var FormatInt = &Builtin{
@@ -1277,7 +1278,7 @@ var Replace = &Builtin{
types.Named("y", types.S).Description("string with replaced substrings"),
),
Categories: stringsCat,
CanSkipBctx: true,
CanSkipBctx: false,
}
var ReplaceN = &Builtin{
@@ -1297,7 +1298,7 @@ The old string comparisons are done in argument order.`,
),
types.Named("output", types.S).Description("string with replaced substrings"),
),
CanSkipBctx: true,
CanSkipBctx: false,
}
var RegexReplace = &Builtin{
@@ -3388,6 +3389,11 @@ var InternalTestCase = &Builtin{
Decl: types.NewFunction([]types.Type{types.NewArray(nil, types.A)}, nil),
}
var InternalTemplateString = &Builtin{
Name: "internal.template_string",
Decl: types.NewFunction([]types.Type{types.NewArray(nil, types.A)}, types.S),
}
/**
* Deprecated built-ins.
*/
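
The declaration added for internal.template_string reads as "one array of arbitrary parts in, one string out". The following is a small illustrative sketch, not part of the diff, that rebuilds the same type with the public types package (v1 import path assumed from this vendor tree) and prints the signature:

package main

import (
	"fmt"

	"github.com/open-policy-agent/opa/v1/types"
)

func main() {
	// Mirrors the Decl above: the builtin takes a single array holding the
	// interleaved template parts and yields the rendered string.
	decl := types.NewFunction([]types.Type{types.NewArray(nil, types.A)}, types.S)
	fmt.Println(decl) // human-readable form of the declared signature
}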

View File

@@ -58,12 +58,14 @@ const FeatureRefHeads = "rule_head_refs"
const FeatureRegoV1 = "rego_v1"
const FeatureRegoV1Import = "rego_v1_import"
const FeatureKeywordsInRefs = "keywords_in_refs"
const FeatureTemplateStrings = "template_strings"
// Features carries the default features supported by this version of OPA.
// Use RegisterFeatures to add to them.
var Features = []string{
FeatureRegoV1,
FeatureKeywordsInRefs,
FeatureTemplateStrings,
}
// RegisterFeatures lets applications wrapping OPA register features, to be
@@ -269,6 +271,12 @@ func (c *Capabilities) ContainsFeature(feature string) bool {
return slices.Contains(c.Features, feature)
}
func (c *Capabilities) ContainsBuiltin(name string) bool {
return slices.ContainsFunc(c.Builtins, func(builtin *Builtin) bool {
return builtin.Name == name
})
}
// addBuiltinSorted inserts a built-in into c in sorted order. An existing built-in with the same name
// will be overwritten.
func (c *Capabilities) addBuiltinSorted(bi *Builtin) {
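
The two additions above (the template_strings feature and the ContainsBuiltin helper) are observable from embedding code; a minimal sketch, assuming the v1 import path used throughout this vendor tree:

package main

import (
	"fmt"

	"github.com/open-policy-agent/opa/v1/ast"
)

func main() {
	caps := ast.CapabilitiesForThisVersion()
	// Both should report true with the capabilities shipped in this version.
	fmt.Println(caps.ContainsFeature("template_strings"))
	fmt.Println(caps.ContainsBuiltin("internal.template_string"))
}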

View File

@@ -7,7 +7,6 @@ package ast
import (
"fmt"
"slices"
"sort"
"strings"
"github.com/open-policy-agent/opa/v1/types"
@@ -16,11 +15,6 @@ import (
type varRewriter func(Ref) Ref
// exprChecker defines the interface for executing type checking on a single
// expression. The exprChecker must update the provided TypeEnv with inferred
// types of vars.
type exprChecker func(*TypeEnv, *Expr) *Error
// typeChecker implements type checking on queries and rules. Errors are
// accumulated on the typeChecker so that a single run can report multiple
// issues.
@@ -28,7 +22,6 @@ type typeChecker struct {
builtins map[string]*Builtin
required *Capabilities
errs Errors
exprCheckers map[string]exprChecker
varRewriter varRewriter
ss *SchemaSet
allowNet []string
@@ -39,11 +32,7 @@ type typeChecker struct {
// newTypeChecker returns a new typeChecker object that has no errors.
func newTypeChecker() *typeChecker {
return &typeChecker{
exprCheckers: map[string]exprChecker{
"eq": checkExprEq,
},
}
return &typeChecker{}
}
func (tc *typeChecker) newEnv(exist *TypeEnv) *TypeEnv {
@@ -126,43 +115,39 @@ func (tc *typeChecker) Env(builtins map[string]*Builtin) *TypeEnv {
// are found. The resulting TypeEnv wraps the provided one. The resulting
// TypeEnv will be able to resolve types of vars contained in the body.
func (tc *typeChecker) CheckBody(env *TypeEnv, body Body) (*TypeEnv, Errors) {
var errors []*Error
errors := []*Error{}
env = tc.newEnv(env)
vis := newRefChecker(env, tc.varRewriter)
gv := NewGenericVisitor(vis.Visit)
WalkExprs(body, func(expr *Expr) bool {
for _, bexpr := range body {
WalkExprs(bexpr, func(expr *Expr) bool {
closureErrs := tc.checkClosures(env, expr)
errors = append(errors, closureErrs...)
closureErrs := tc.checkClosures(env, expr)
for _, err := range closureErrs {
errors = append(errors, err)
}
// reset errors from previous iteration
vis.errs = nil
gv.Walk(expr)
errors = append(errors, vis.errs...)
hasClosureErrors := len(closureErrs) > 0
// reset errors from previous iteration
vis.errs = nil
NewGenericVisitor(vis.Visit).Walk(expr)
for _, err := range vis.errs {
errors = append(errors, err)
}
hasRefErrors := len(vis.errs) > 0
if err := tc.checkExpr(env, expr); err != nil {
// Suppress this error if a more actionable one has occurred. In
// this case, if an error occurred in a ref or closure contained in
// this expression, and the error is due to a nil type, then it's
// likely to be the result of the more specific error.
skip := (hasClosureErrors || hasRefErrors) && causedByNilType(err)
if !skip {
errors = append(errors, err)
if err := tc.checkExpr(env, expr); err != nil {
hasClosureErrors := len(closureErrs) > 0
hasRefErrors := len(vis.errs) > 0
// Suppress this error if a more actionable one has occurred. In
// this case, if an error occurred in a ref or closure contained in
// this expression, and the error is due to a nil type, then it's
// likely to be the result of the more specific error.
skip := (hasClosureErrors || hasRefErrors) && causedByNilType(err)
if !skip {
errors = append(errors, err)
}
}
}
return true
})
return true
})
}
tc.err(errors)
tc.err(errors...)
return env, errors
}
@@ -243,7 +228,7 @@ func (tc *typeChecker) checkRule(env *TypeEnv, as *AnnotationSet, rule *Rule) {
for _, schemaAnnot := range schemaAnnots {
refType, err := tc.getSchemaType(schemaAnnot, rule)
if err != nil {
tc.err([]*Error{err})
tc.err(err)
continue
}
@@ -259,7 +244,7 @@ func (tc *typeChecker) checkRule(env *TypeEnv, as *AnnotationSet, rule *Rule) {
} else {
newType, err := override(ref[len(prefixRef):], t, refType, rule)
if err != nil {
tc.err([]*Error{err})
tc.err(err)
continue
}
env.tree.Put(prefixRef, newType)
@@ -281,23 +266,25 @@ func (tc *typeChecker) checkRule(env *TypeEnv, as *AnnotationSet, rule *Rule) {
var tpe types.Type
if len(rule.Head.Args) > 0 {
// If args are not referred to in body, infer as any.
WalkVars(rule.Head.Args, func(v Var) bool {
if cpy.GetByValue(v) == nil {
cpy.tree.PutOne(v, types.A)
}
return false
})
for _, arg := range rule.Head.Args {
// If args are not referred to in body, infer as any.
WalkTerms(arg, func(t *Term) bool {
if _, ok := t.Value.(Var); ok {
if cpy.GetByValue(t.Value) == nil {
cpy.tree.PutOne(t.Value, types.A)
}
}
return false
})
}
// Construct function type.
args := make([]types.Type, len(rule.Head.Args))
for i := range len(rule.Head.Args) {
for i := range rule.Head.Args {
args[i] = cpy.GetByValue(rule.Head.Args[i].Value)
}
f := types.NewFunction(args, cpy.Get(rule.Head.Value))
tpe = f
tpe = types.NewFunction(args, cpy.GetByValue(rule.Head.Value.Value))
} else {
switch rule.Head.RuleKind() {
case SingleValue:
@@ -310,7 +297,7 @@ func (tc *typeChecker) checkRule(env *TypeEnv, as *AnnotationSet, rule *Rule) {
var err error
tpe, err = nestedObject(cpy, objPath, typeV)
if err != nil {
tc.err([]*Error{NewError(TypeErr, rule.Head.Location, "%s", err.Error())})
tc.err(NewError(TypeErr, rule.Head.Location, "%s", err.Error()))
tpe = nil
}
} else if typeV != nil {
@@ -374,9 +361,8 @@ func (tc *typeChecker) checkExpr(env *TypeEnv, expr *Expr) *Error {
}
}
checker := tc.exprCheckers[operator]
if checker != nil {
return checker(env, expr)
if operator == "eq" {
return checkExprEq(env, expr)
}
return tc.checkExprBuiltin(env, expr)
@@ -599,7 +585,7 @@ func unify1(env *TypeEnv, term *Term, tpe types.Type, union bool) bool {
return unifies
}
return false
case Set:
case *set:
switch tpe := tpe.(type) {
case *types.Set:
return unify1Set(env, v, tpe, union)
@@ -674,14 +660,14 @@ func unify1Object(env *TypeEnv, val Object, tpe *types.Object, union bool) bool
return !stop
}
func unify1Set(env *TypeEnv, val Set, tpe *types.Set, union bool) bool {
func unify1Set(env *TypeEnv, val *set, tpe *types.Set, union bool) bool {
of := types.Values(tpe)
return !val.Until(func(elem *Term) bool {
return !unify1(env, elem, of, union)
})
}
func (tc *typeChecker) err(errors []*Error) {
func (tc *typeChecker) err(errors ...*Error) {
tc.errs = append(tc.errs, errors...)
}
@@ -702,7 +688,6 @@ func newRefChecker(env *TypeEnv, f varRewriter) *refChecker {
return &refChecker{
env: env,
errs: nil,
varRewriter: f,
}
}
@@ -714,8 +699,9 @@ func (rc *refChecker) Visit(x any) bool {
case *Expr:
switch terms := x.Terms.(type) {
case []*Term:
vis := NewGenericVisitor(rc.Visit)
for i := 1; i < len(terms); i++ {
NewGenericVisitor(rc.Visit).Walk(terms[i])
vis.Walk(terms[i])
}
return true
case *Term:
@@ -805,7 +791,6 @@ func (rc *refChecker) checkRef(curr *TypeEnv, node *typeTreeNode, ref Ref, idx i
}
func (rc *refChecker) checkRefLeaf(tpe types.Type, ref Ref, idx int) *Error {
if idx == len(ref) {
return nil
}
@@ -820,16 +805,16 @@ func (rc *refChecker) checkRefLeaf(tpe types.Type, ref Ref, idx int) *Error {
switch value := head.Value.(type) {
case Var:
if exist := rc.env.GetByValue(value); exist != nil {
if exist := rc.env.GetByValue(head.Value); exist != nil {
if !unifies(exist, keys) {
return newRefErrInvalid(ref[0].Location, rc.varRewriter(ref), idx, exist, keys, getOneOfForType(tpe))
}
} else {
rc.env.tree.PutOne(value, types.Keys(tpe))
rc.env.tree.PutOne(head.Value, types.Keys(tpe))
}
case Ref:
if exist := rc.env.Get(value); exist != nil {
if exist := rc.env.GetByRef(value); exist != nil {
if !unifies(exist, keys) {
return newRefErrInvalid(ref[0].Location, rc.varRewriter(ref), idx, exist, keys, getOneOfForType(tpe))
}
@@ -1130,7 +1115,7 @@ func getOneOfForNode(node *typeTreeNode) (result []Value) {
return false
})
sortValueSlice(result)
slices.SortFunc(result, Value.Compare)
return result
}
@@ -1153,16 +1138,10 @@ func getOneOfForType(tpe types.Type) (result []Value) {
}
result = removeDuplicate(result)
sortValueSlice(result)
slices.SortFunc(result, Value.Compare)
return result
}
func sortValueSlice(sl []Value) {
sort.Slice(sl, func(i, j int) bool {
return sl[i].Compare(sl[j]) < 0
})
}
func removeDuplicate(list []Value) []Value {
seen := make(map[Value]bool)
var newResult []Value
@@ -1186,13 +1165,13 @@ func getArgTypes(env *TypeEnv, args []*Term) []types.Type {
// getPrefix returns the shortest prefix of ref that exists in env
func getPrefix(env *TypeEnv, ref Ref) (Ref, types.Type) {
if len(ref) == 1 {
t := env.Get(ref)
t := env.GetByRef(ref)
if t != nil {
return ref, t
}
}
for i := 1; i < len(ref); i++ {
t := env.Get(ref[:i])
t := env.GetByRef(ref[:i])
if t != nil {
return ref[:i], t
}
@@ -1200,12 +1179,14 @@ func getPrefix(env *TypeEnv, ref Ref) (Ref, types.Type) {
return nil, nil
}
var dynamicAnyAny = types.NewDynamicProperty(types.A, types.A)
// override takes a type t and returns a type obtained from t where the path represented by ref within it has type o (overriding the original type of that path)
func override(ref Ref, t types.Type, o types.Type, rule *Rule) (types.Type, *Error) {
var newStaticProps []*types.StaticProperty
obj, ok := t.(*types.Object)
if !ok {
newType, err := getObjectType(ref, o, rule, types.NewDynamicProperty(types.A, types.A))
newType, err := getObjectType(ref, o, rule, dynamicAnyAny)
if err != nil {
return nil, err
}

View File

@@ -96,6 +96,9 @@ func Compare(a, b any) int {
return -1
}
return 1
case *TemplateString:
b := b.(*TemplateString)
return a.Compare(b)
case Var:
return VarCompare(a, b.(Var))
case Ref:
@@ -179,26 +182,28 @@ func sortOrder(x any) int {
return 2
case String:
return 3
case Var:
case *TemplateString:
return 4
case Ref:
case Var:
return 5
case *Array:
case Ref:
return 6
case Object:
case *Array:
return 7
case Set:
case Object:
return 8
case *ArrayComprehension:
case Set:
return 9
case *ObjectComprehension:
case *ArrayComprehension:
return 10
case *SetComprehension:
case *ObjectComprehension:
return 11
case Call:
case *SetComprehension:
return 12
case Args:
case Call:
return 13
case Args:
return 14
case *Expr:
return 100
case *SomeDecl:
@@ -322,14 +327,6 @@ func TermValueEqual(a, b *Term) bool {
}
func ValueEqual(a, b Value) bool {
// TODO(ae): why doesn't this work the same?
//
// case interface{ Equal(Value) bool }:
// return v.Equal(b)
//
// When put on top, golangci-lint even flags the other cases as unreachable..
// but TestTopdownVirtualCache will have failing test cases when we replace
// the other cases with the above one.. 🤔
switch v := a.(type) {
case Null:
return v.Equal(b)
@@ -345,6 +342,8 @@ func ValueEqual(a, b Value) bool {
return v.Equal(b)
case *Array:
return v.Equal(b)
case *TemplateString:
return v.Equal(b)
}
return a.Compare(b) == 0

View File

File diff suppressed because it is too large.

View File

@@ -54,15 +54,14 @@ func (env *TypeEnv) GetByValue(v Value) types.Type {
return types.B
case Number:
return types.N
case String:
case String, *TemplateString:
return types.S
// Composites.
case *Array:
static := make([]types.Type, x.Len())
for i := range static {
tpe := env.GetByValue(x.Elem(i).Value)
static[i] = tpe
static[i] = env.GetByValue(x.Elem(i).Value)
}
var dynamic types.Type
@@ -80,17 +79,13 @@ func (env *TypeEnv) GetByValue(v Value) types.Type {
x.Foreach(func(k, v *Term) {
if IsConstant(k.Value) {
kjson, err := JSON(k.Value)
if err == nil {
tpe := env.GetByValue(v.Value)
static = append(static, types.NewStaticProperty(kjson, tpe))
if kjson, err := JSON(k.Value); err == nil {
static = append(static, types.NewStaticProperty(kjson, env.GetByValue(v.Value)))
return
}
}
// Can't handle it as a static property, fallback to dynamic
typeK := env.GetByValue(k.Value)
typeV := env.GetByValue(v.Value)
dynamic = types.NewDynamicProperty(typeK, typeV)
dynamic = types.NewDynamicProperty(env.GetByValue(k.Value), env.GetByValue(v.Value))
})
if len(static) == 0 && dynamic == nil {
@@ -99,7 +94,7 @@ func (env *TypeEnv) GetByValue(v Value) types.Type {
return types.NewObject(static, dynamic)
case Set:
case *set:
var tpe types.Type
x.Foreach(func(elem *Term) {
tpe = types.Or(tpe, env.GetByValue(elem.Value))
@@ -162,7 +157,6 @@ func (env *TypeEnv) GetByRef(ref Ref) types.Type {
}
func (env *TypeEnv) getRefFallback(ref Ref) types.Type {
if env.next != nil {
return env.next.GetByRef(ref)
}
@@ -299,15 +293,11 @@ func (n *typeTreeNode) PutOne(key Value, tpe types.Type) {
func (n *typeTreeNode) Put(path Ref, tpe types.Type) {
curr := n
for _, term := range path {
c, ok := curr.children.Get(term.Value)
var child *typeTreeNode
child, ok := curr.children.Get(term.Value)
if !ok {
child = newTypeTree()
child.key = term.Value
curr.children.Put(child.key, child)
} else {
child = c
}
curr = child
@@ -321,23 +311,18 @@ func (n *typeTreeNode) Put(path Ref, tpe types.Type) {
func (n *typeTreeNode) Insert(path Ref, tpe types.Type, env *TypeEnv) {
curr := n
for i, term := range path {
c, ok := curr.children.Get(term.Value)
var child *typeTreeNode
child, ok := curr.children.Get(term.Value)
if !ok {
child = newTypeTree()
child.key = term.Value
curr.children.Put(child.key, child)
} else {
child = c
if child.value != nil && i+1 < len(path) {
// If child has an object value, merge the new value into it.
if o, ok := child.value.(*types.Object); ok {
var err error
child.value, err = insertIntoObject(o, path[i+1:], tpe, env)
if err != nil {
panic(fmt.Errorf("unreachable, insertIntoObject: %w", err))
}
} else if child.value != nil && i+1 < len(path) {
// If child has an object value, merge the new value into it.
if o, ok := child.value.(*types.Object); ok {
var err error
child.value, err = insertIntoObject(o, path[i+1:], tpe, env)
if err != nil {
panic(fmt.Errorf("unreachable, insertIntoObject: %w", err))
}
}
}
@@ -349,8 +334,7 @@ func (n *typeTreeNode) Insert(path Ref, tpe types.Type, env *TypeEnv) {
if _, ok := tpe.(*types.Object); ok && curr.children.Len() > 0 {
// merge all leafs into the inserted object
leafs := curr.Leafs()
for p, t := range leafs {
for p, t := range curr.Leafs() {
var err error
curr.value, err = insertIntoObject(curr.value.(*types.Object), *p, t, env)
if err != nil {
@@ -388,7 +372,8 @@ func mergeTypes(a, b types.Type) types.Type {
bDynProps := bObj.DynamicProperties()
dynProps := types.NewDynamicProperty(
types.Or(aDynProps.Key, bDynProps.Key),
mergeTypes(aDynProps.Value, bDynProps.Value))
mergeTypes(aDynProps.Value, bDynProps.Value),
)
return types.NewObject(nil, dynProps)
} else if bAny, ok := b.(types.Any); ok && len(a.StaticProperties()) == 0 {
// If a is an object type with no static components ...
@@ -417,14 +402,14 @@ func mergeTypes(a, b types.Type) types.Type {
}
func (n *typeTreeNode) String() string {
b := strings.Builder{}
b := &strings.Builder{}
key := "-"
if k := n.key; k != nil {
b.WriteString(k.String())
} else {
b.WriteString("-")
key = k.String()
}
b.WriteString(key)
if v := n.value; v != nil {
b.WriteString(": ")
b.WriteString(v.String())
@@ -432,9 +417,7 @@ func (n *typeTreeNode) String() string {
n.children.Iter(func(_ Value, child *typeTreeNode) bool {
b.WriteString("\n\t+ ")
s := child.String()
s = strings.ReplaceAll(s, "\n", "\n\t")
b.WriteString(s)
b.WriteString(strings.ReplaceAll(child.String(), "\n", "\n\t"))
return false
})
@@ -485,7 +468,8 @@ func (n *typeTreeNode) Leafs() map[*Ref]types.Type {
func collectLeafs(n *typeTreeNode, path Ref, leafs map[*Ref]types.Type) {
nPath := append(path, NewTerm(n.key))
if n.Leaf() {
leafs[&nPath] = n.Value()
npc := nPath // copy, or else nPath escapes to heap even if !n.Leaf()
leafs[&npc] = n.Value()
return
}
n.children.Iter(func(_ Value, v *typeTreeNode) bool {
@@ -513,7 +497,6 @@ func selectConstant(tpe types.Type, term *Term) types.Type {
// contains vars or refs, then the returned type will be a union of the
// possible types.
func selectRef(tpe types.Type, ref Ref) types.Type {
if tpe == nil || len(ref) == 0 {
return tpe
}

View File

@@ -121,9 +121,13 @@ func (e *Error) Error() string {
// NewError returns a new Error object.
func NewError(code string, loc *Location, f string, a ...any) *Error {
return newErrorString(code, loc, fmt.Sprintf(f, a...))
}
func newErrorString(code string, loc *Location, m string) *Error {
return &Error{
Code: code,
Location: loc,
Message: fmt.Sprintf(f, a...),
Message: m,
}
}

View File

@@ -412,7 +412,7 @@ func (i *refindices) updateGlobMatch(rule *Rule, expr *Expr) {
if _, ok := match.Value.(Var); ok {
var ref Ref
for _, other := range i.rules[rule] {
if _, ok := other.Value.(Var); ok && other.Value.Compare(match.Value) == 0 {
if ov, ok := other.Value.(Var); ok && ov.Equal(match.Value) {
ref = other.Ref
}
}

View File

@@ -158,18 +158,42 @@ func (s *Scanner) WithoutKeywords(kws map[string]tokens.Token) (*Scanner, map[st
return &cpy, kw
}
type ScanOptions struct {
continueTemplateString bool
rawTemplateString bool
}
type ScanOption func(*ScanOptions)
// ContinueTemplateString will continue scanning a template string
func ContinueTemplateString(raw bool) ScanOption {
return func(opts *ScanOptions) {
opts.continueTemplateString = true
opts.rawTemplateString = raw
}
}
// Scan will increment the scanners position in the source
// code until the next token is found. The token, starting position
// of the token, string literal, and any errors encountered are
// returned. A token will always be returned, the caller must check
// for any errors before using the other values.
func (s *Scanner) Scan() (tokens.Token, Position, string, []Error) {
func (s *Scanner) Scan(opts ...ScanOption) (tokens.Token, Position, string, []Error) {
scanOpts := &ScanOptions{}
for _, opt := range opts {
opt(scanOpts)
}
pos := Position{Offset: s.offset - s.width, Row: s.row, Col: s.col, Tabs: s.tabs}
var tok tokens.Token
var lit string
if s.isWhitespace() {
if scanOpts.continueTemplateString {
if scanOpts.rawTemplateString {
lit, tok = s.scanRawTemplateString()
} else {
lit, tok = s.scanTemplateString()
}
} else if s.isWhitespace() {
// string(rune) is an unnecessary heap allocation in this case as we know all
// the possible whitespace values, and can simply translate to string ourselves
switch s.curr {
@@ -275,6 +299,17 @@ func (s *Scanner) Scan() (tokens.Token, Position, string, []Error) {
tok = tokens.Semicolon
case '.':
tok = tokens.Dot
case '$':
switch s.curr {
case '`':
s.next()
lit, tok = s.scanRawTemplateString()
case '"':
s.next()
lit, tok = s.scanTemplateString()
default:
s.error("illegal $ character")
}
}
}
@@ -395,6 +430,116 @@ func (s *Scanner) scanRawString() string {
return util.ByteSliceToString(s.bs[start : s.offset-1])
}
func (s *Scanner) scanTemplateString() (string, tokens.Token) {
tok := tokens.TemplateStringPart
start := s.literalStart()
var escapes []int
for {
ch := s.curr
if ch == '\n' || ch < 0 {
s.error("non-terminated string")
break
}
s.next()
if ch == '"' {
tok = tokens.TemplateStringEnd
break
}
if ch == '{' {
break
}
if ch == '\\' {
switch s.curr {
case '\\', '"', '/', 'b', 'f', 'n', 'r', 't':
s.next()
case '{':
escapes = append(escapes, s.offset-1)
s.next()
case 'u':
s.next()
s.next()
s.next()
s.next()
default:
s.error("illegal escape sequence")
}
}
}
// Lazily remove escapes to not unnecessarily allocate a new byte slice
if len(escapes) > 0 {
return util.ByteSliceToString(removeEscapes(s, escapes, start)), tok
}
return util.ByteSliceToString(s.bs[start : s.offset-1]), tok
}
func (s *Scanner) scanRawTemplateString() (string, tokens.Token) {
tok := tokens.RawTemplateStringPart
start := s.literalStart()
var escapes []int
for {
ch := s.curr
if ch < 0 {
s.error("non-terminated string")
break
}
s.next()
if ch == '`' {
tok = tokens.RawTemplateStringEnd
break
}
if ch == '{' {
break
}
if ch == '\\' {
switch s.curr {
case '{':
escapes = append(escapes, s.offset-1)
s.next()
}
}
}
// Lazily remove escapes to not unnecessarily allocate a new byte slice
if len(escapes) > 0 {
return util.ByteSliceToString(removeEscapes(s, escapes, start)), tok
}
return util.ByteSliceToString(s.bs[start : s.offset-1]), tok
}
func removeEscapes(s *Scanner, escapes []int, start int) []byte {
from := start
bs := make([]byte, 0, s.offset-start-len(escapes))
for _, escape := range escapes {
// Append the bytes before the escape sequence.
if escape > from {
bs = append(bs, s.bs[from:escape-1]...)
}
// Skip the escape character.
from = escape
}
// Append the remaining bytes after the last escape sequence.
if from < s.offset-1 {
bs = append(bs, s.bs[from:s.offset-1]...)
}
return bs
}
func (s *Scanner) scanComment() string {
start := s.literalStart()
for s.curr != '\n' && s.curr != -1 {
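
The scanner above records the offsets of \{ escapes while walking the literal and only rebuilds the byte slice when at least one escape was seen. A standalone sketch of that splice-out idea follows (hypothetical stripBraceEscapes helper, not the internal scanner API, simplified to handle only the \{ case):

package main

import "fmt"

// stripBraceEscapes (illustrative helper, not OPA API) removes the backslash
// from each \{ sequence, allocating a new slice only when an escape is present.
func stripBraceEscapes(bs []byte) []byte {
	var escapes []int
	for i := 0; i+1 < len(bs); i++ {
		if bs[i] == '\\' && bs[i+1] == '{' {
			escapes = append(escapes, i)
		}
	}
	if len(escapes) == 0 {
		return bs // common case: nothing to rewrite, no allocation
	}
	out := make([]byte, 0, len(bs)-len(escapes))
	from := 0
	for _, e := range escapes {
		out = append(out, bs[from:e]...) // bytes before the backslash
		from = e + 1                     // drop the backslash, keep the '{'
	}
	return append(out, bs[from:]...)
}

func main() {
	fmt.Println(string(stripBraceEscapes([]byte(`plain \{ brace`)))) // plain { brace
}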

View File

@@ -39,6 +39,10 @@ const (
Number
String
TemplateStringPart
TemplateStringEnd
RawTemplateStringPart
RawTemplateStringEnd
LBrack
RBrack
@@ -67,6 +71,7 @@ const (
Lte
Dot
Semicolon
Dollar
Every
Contains
@@ -74,53 +79,58 @@ const (
)
var strings = [...]string{
Illegal: "illegal",
EOF: "eof",
Whitespace: "whitespace",
Comment: "comment",
Ident: "identifier",
Package: "package",
Import: "import",
As: "as",
Default: "default",
Else: "else",
Not: "not",
Some: "some",
With: "with",
Null: "null",
True: "true",
False: "false",
Number: "number",
String: "string",
LBrack: "[",
RBrack: "]",
LBrace: "{",
RBrace: "}",
LParen: "(",
RParen: ")",
Comma: ",",
Colon: ":",
Add: "plus",
Sub: "minus",
Mul: "mul",
Quo: "div",
Rem: "rem",
And: "and",
Or: "or",
Unify: "eq",
Equal: "equal",
Assign: "assign",
In: "in",
Neq: "neq",
Gt: "gt",
Lt: "lt",
Gte: "gte",
Lte: "lte",
Dot: ".",
Semicolon: ";",
Every: "every",
Contains: "contains",
If: "if",
Illegal: "illegal",
EOF: "eof",
Whitespace: "whitespace",
Comment: "comment",
Ident: "identifier",
Package: "package",
Import: "import",
As: "as",
Default: "default",
Else: "else",
Not: "not",
Some: "some",
With: "with",
Null: "null",
True: "true",
False: "false",
Number: "number",
String: "string",
TemplateStringPart: "template-string-part",
TemplateStringEnd: "template-string-end",
RawTemplateStringPart: "raw-template-string-part",
RawTemplateStringEnd: "raw-template-string-end",
LBrack: "[",
RBrack: "]",
LBrace: "{",
RBrace: "}",
LParen: "(",
RParen: ")",
Comma: ",",
Colon: ":",
Add: "plus",
Sub: "minus",
Mul: "mul",
Quo: "div",
Rem: "rem",
And: "and",
Or: "or",
Unify: "eq",
Equal: "equal",
Assign: "assign",
In: "in",
Neq: "neq",
Gt: "gt",
Lt: "lt",
Gte: "gte",
Lte: "lte",
Dot: ".",
Semicolon: ";",
Dollar: "dollar",
Every: "every",
Contains: "contains",
If: "if",
}
var keywords = map[string]Token{
@@ -147,3 +157,7 @@ func IsKeyword(tok Token) bool {
_, ok := keywords[strings[tok]]
return ok
}
func KeywordFor(tok Token) string {
return strings[tok]
}

View File

@@ -42,10 +42,17 @@ var (
}
internedVarValues = map[string]Value{
"input": Var("input"),
"data": Var("data"),
"key": Var("key"),
"value": Var("value"),
"input": Var("input"),
"data": Var("data"),
"args": Var("args"),
"schema": Var("schema"),
"key": Var("key"),
"value": Var("value"),
"future": Var("future"),
"rego": Var("rego"),
"set": Var("set"),
"internal": Var("internal"),
"else": Var("else"),
"i": Var("i"), "j": Var("j"), "k": Var("k"), "v": Var("v"), "x": Var("x"), "y": Var("y"), "z": Var("z"),
}
@@ -190,6 +197,13 @@ func InternedTerm[T internable](v T) *Term {
}
}
// InternedItem works just like [Item] but returns interned terms for both
// key and value where possible. This is mostly useful for making tests less
// verbose.
func InternedItem[K, V internable](key K, value V) [2]*Term {
return [2]*Term{InternedTerm(key), InternedTerm(value)}
}
// InternedIntFromString returns a term with the given integer value if the string
// maps to an interned term. If the string does not map to an interned term, nil is
// returned.
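
A short usage sketch for the InternedItem helper added above (new API in this vendored version; v1 import path assumed):

package main

import (
	"fmt"

	"github.com/open-policy-agent/opa/v1/ast"
)

func main() {
	// InternedItem pairs an interned key term with an interned value term,
	// and ast.NewObject accepts such [2]*Term pairs directly.
	obj := ast.NewObject(
		ast.InternedItem("enabled", true),
		ast.InternedItem("retries", 3),
	)
	fmt.Println(obj) // {"enabled": true, "retries": 3}
}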

View File

@@ -1736,6 +1736,10 @@ func (p *Parser) parseTerm() *Term {
term = p.parseNumber()
case tokens.String:
term = p.parseString()
case tokens.TemplateStringPart, tokens.TemplateStringEnd:
term = p.parseTemplateString(false)
case tokens.RawTemplateStringPart, tokens.RawTemplateStringEnd:
term = p.parseTemplateString(true)
case tokens.Ident, tokens.Contains: // NOTE(sr): contains anywhere BUT in rule heads gets no special treatment
term = p.parseVar()
case tokens.LBrack:
@@ -1767,7 +1771,7 @@ func (p *Parser) parseTermFinish(head *Term, skipws bool) *Term {
return nil
}
offset := p.s.loc.Offset
p.doScan(skipws)
p.doScan(skipws, noScanOptions...)
switch p.s.tok {
case tokens.LParen, tokens.Dot, tokens.LBrack:
@@ -1788,7 +1792,7 @@ func (p *Parser) parseHeadFinish(head *Term, skipws bool) *Term {
return nil
}
offset := p.s.loc.Offset
p.doScan(false)
p.scanWS()
switch p.s.tok {
case tokens.Add, tokens.Sub, tokens.Mul, tokens.Quo, tokens.Rem,
@@ -1796,7 +1800,7 @@ func (p *Parser) parseHeadFinish(head *Term, skipws bool) *Term {
tokens.Equal, tokens.Neq, tokens.Gt, tokens.Gte, tokens.Lt, tokens.Lte:
p.illegalToken()
case tokens.Whitespace:
p.doScan(skipws)
p.doScan(skipws, noScanOptions...)
}
switch p.s.tok {
@@ -1886,6 +1890,11 @@ func (p *Parser) parseString() *Term {
return NewTerm(InternedEmptyString.Value).SetLocation(p.s.Loc())
}
inner := p.s.lit[1 : len(p.s.lit)-1]
if !strings.ContainsRune(inner, '\\') { // nothing to un-escape
return StringTerm(inner).SetLocation(p.s.Loc())
}
var s string
if err := json.Unmarshal([]byte(p.s.lit), &s); err != nil {
p.errorf(p.s.Loc(), "illegal string literal: %s", p.s.lit)
@@ -1903,6 +1912,120 @@ func (p *Parser) parseRawString() *Term {
return StringTerm(p.s.lit[1 : len(p.s.lit)-1]).SetLocation(p.s.Loc())
}
func templateStringPartToStringLiteral(tok tokens.Token, lit string) (string, error) {
switch tok {
case tokens.TemplateStringPart, tokens.TemplateStringEnd:
inner := lit[1 : len(lit)-1]
if !strings.ContainsRune(inner, '\\') { // nothing to un-escape
return inner, nil
}
buf := make([]byte, 0, len(inner)+2)
buf = append(buf, '"')
buf = append(buf, inner...)
buf = append(buf, '"')
var s string
if err := json.Unmarshal(buf, &s); err != nil {
return "", fmt.Errorf("illegal template-string part: %s", lit)
}
return s, nil
case tokens.RawTemplateStringPart, tokens.RawTemplateStringEnd:
return lit[1 : len(lit)-1], nil
default:
return "", errors.New("expected template-string part")
}
}
func (p *Parser) parseTemplateString(multiLine bool) *Term {
loc := p.s.Loc()
if !p.po.Capabilities.ContainsFeature(FeatureTemplateStrings) {
p.errorf(loc, "template strings are not supported by current capabilities")
return nil
}
var parts []Node
for {
s, err := templateStringPartToStringLiteral(p.s.tok, p.s.lit)
if err != nil {
p.error(p.s.Loc(), err.Error())
return nil
}
// Don't add empty strings
if len(s) > 0 {
parts = append(parts, StringTerm(s).SetLocation(p.s.Loc()))
}
if p.s.tok == tokens.TemplateStringEnd || p.s.tok == tokens.RawTemplateStringEnd {
break
}
numCommentsBefore := len(p.s.comments)
p.scan()
numCommentsAfter := len(p.s.comments)
expr := p.parseLiteral()
if expr == nil {
p.error(p.s.Loc(), "invalid template-string expression")
return nil
}
if expr.Negated {
p.errorf(expr.Loc(), "unexpected negation ('%s') in template-string expression", tokens.KeywordFor(tokens.Not))
return nil
}
// Note: Actually unification
if expr.IsEquality() {
p.errorf(expr.Loc(), "unexpected unification ('=') in template-string expression")
return nil
}
if expr.IsAssignment() {
p.errorf(expr.Loc(), "unexpected assignment (':=') in template-string expression")
return nil
}
if expr.IsEvery() {
p.errorf(expr.Loc(), "unexpected '%s' in template-string expression", tokens.KeywordFor(tokens.Every))
return nil
}
if expr.IsSome() {
p.errorf(expr.Loc(), "unexpected '%s' in template-string expression", tokens.KeywordFor(tokens.Some))
return nil
}
// FIXME: Can we optimize for collections and comprehensions too? To qualify, they must not contain refs or calls.
var nonOptional bool
if term, ok := expr.Terms.(*Term); ok && numCommentsAfter == numCommentsBefore {
switch term.Value.(type) {
case String, Number, Boolean, Null:
nonOptional = true
parts = append(parts, term)
}
}
if !nonOptional {
parts = append(parts, expr)
}
if p.s.tok != tokens.RBrace {
p.errorf(p.s.Loc(), "expected %s to end template string expression", tokens.RBrace)
return nil
}
p.doScan(false, scanner.ContinueTemplateString(multiLine))
}
// When there are template-expressions, the initial location will only contain the text up to the first expression
loc.Text = p.s.Text(loc.Offset, p.s.tokEnd)
return TemplateStringTerm(multiLine, parts...).SetLocation(loc)
}
func (p *Parser) parseCall(operator *Term, offset int) (term *Term) {
if !p.enter() {
return nil
@@ -2456,15 +2579,17 @@ func (p *Parser) illegalToken() {
p.illegal("")
}
var noScanOptions []scanner.ScanOption
func (p *Parser) scan() {
p.doScan(true)
p.doScan(true, noScanOptions...)
}
func (p *Parser) scanWS() {
p.doScan(false)
p.doScan(false, noScanOptions...)
}
func (p *Parser) doScan(skipws bool) {
func (p *Parser) doScan(skipws bool, scanOpts ...scanner.ScanOption) {
// NOTE(tsandall): the last position is used to compute the "text" field for
// complex AST nodes. Whitespace never affects the last position of an AST
@@ -2477,7 +2602,7 @@ func (p *Parser) doScan(skipws bool) {
var errs []scanner.Error
for {
var pos scanner.Position
p.s.tok, pos, p.s.lit, errs = p.s.s.Scan()
p.s.tok, pos, p.s.lit, errs = p.s.s.Scan(scanOpts...)
p.s.tokEnd = pos.End
p.s.loc.Row = pos.Row
@@ -2532,12 +2657,10 @@ func (p *Parser) restore(s *state) {
}
func setLocRecursive(x any, loc *location.Location) {
NewGenericVisitor(func(x any) bool {
if node, ok := x.(Node); ok {
node.SetLoc(loc)
}
WalkNodes(x, func(n Node) bool {
n.SetLoc(loc)
return false
}).Walk(x)
})
}
func (p *Parser) setLoc(term *Term, loc *location.Location, offset, end int) *Term {
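
With the default capabilities of this version, the parser above accepts the new $"..." syntax. A minimal end-to-end sketch (v1 import path assumed; the printed form is an expectation, not a guarantee):

package main

import (
	"fmt"

	"github.com/open-policy-agent/opa/v1/ast"
)

func main() {
	module := `package example

greeting := $"hello, {input.name}!"
`
	m, err := ast.ParseModule("example.rego", module)
	if err != nil {
		fmt.Println("parse error:", err)
		return
	}
	// The rule value is a term whose Value is a *TemplateString holding
	// interleaved string and expression parts.
	fmt.Println(m.Rules[0].Head.Value)
}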

View File

@@ -11,7 +11,6 @@
package ast
import (
"bytes"
"errors"
"fmt"
"slices"
@@ -625,10 +624,9 @@ func ParseStatements(filename, input string) ([]Statement, []*Comment, error) {
// ParseStatementsWithOpts returns a slice of parsed statements. This is the
// default return value from the parser.
func ParseStatementsWithOpts(filename, input string, popts ParserOptions) ([]Statement, []*Comment, error) {
parser := NewParser().
WithFilename(filename).
WithReader(bytes.NewBufferString(input)).
WithReader(strings.NewReader(input)).
WithProcessAnnotation(popts.ProcessAnnotation).
WithFutureKeywords(popts.FutureKeywords...).
WithAllFutureKeywords(popts.AllFutureKeywords).
@@ -638,7 +636,6 @@ func ParseStatementsWithOpts(filename, input string, popts ParserOptions) ([]Sta
withUnreleasedKeywords(popts.unreleasedKeywords)
stmts, comments, errs := parser.Parse()
if len(errs) > 0 {
return nil, nil, errs
}
@@ -647,7 +644,6 @@ func ParseStatementsWithOpts(filename, input string, popts ParserOptions) ([]Sta
}
func parseModule(filename string, stmts []Statement, comments []*Comment, regoCompatibilityMode RegoVersion) (*Module, error) {
if len(stmts) == 0 {
return nil, NewError(ParseErr, &Location{File: filename}, "empty module")
}
@@ -662,23 +658,21 @@ func parseModule(filename string, stmts []Statement, comments []*Comment, regoCo
mod := &Module{
Package: pkg,
stmts: stmts,
// The comments slice only holds comments that were not their own statements.
Comments: comments,
stmts: stmts,
}
// The comments slice only holds comments that were not their own statements.
mod.Comments = append(mod.Comments, comments...)
mod.regoVersion = regoCompatibilityMode
if regoCompatibilityMode == RegoUndefined {
mod.regoVersion = DefaultRegoVersion
} else {
mod.regoVersion = regoCompatibilityMode
}
for i, stmt := range stmts[1:] {
switch stmt := stmt.(type) {
case *Import:
mod.Imports = append(mod.Imports, stmt)
if mod.regoVersion == RegoV0 && Compare(stmt.Path.Value, RegoV1CompatibleRef) == 0 {
if mod.regoVersion == RegoV0 && RegoV1CompatibleRef.Equal(stmt.Path.Value) {
mod.regoVersion = RegoV0CompatV1
}
case *Rule:

View File

@@ -621,7 +621,7 @@ func (imp *Import) SetLoc(loc *Location) {
// document. This is the alias if defined otherwise the last element in the
// path.
func (imp *Import) Name() Var {
if len(imp.Alias) != 0 {
if imp.Alias != "" {
return imp.Alias
}
switch v := imp.Path.Value.(type) {
@@ -988,6 +988,7 @@ func (head *Head) Copy() *Head {
cpy.Key = head.Key.Copy()
cpy.Value = head.Value.Copy()
cpy.keywords = nil
cpy.Assign = head.Assign
return &cpy
}

View File

@@ -27,13 +27,12 @@ func checkRootDocumentOverrides(node any) Errors {
errors := Errors{}
WalkRules(node, func(rule *Rule) bool {
var name string
name := rule.Head.Name
if len(rule.Head.Reference) > 0 {
name = rule.Head.Reference[0].Value.(Var).String()
} else {
name = rule.Head.Name.String()
name = rule.Head.Reference[0].Value.(Var)
}
if RootDocumentRefs.Contains(RefTerm(VarTerm(name))) {
if ReservedVars.Contains(name) {
errors = append(errors, NewError(CompileErr, rule.Location, "rules must not shadow %v (use a different rule name)", name))
}
@@ -52,8 +51,8 @@ func checkRootDocumentOverrides(node any) Errors {
if expr.IsAssignment() {
// assign() can be called directly, so we need to assert its given first operand exists before checking its name.
if nameOp := expr.Operand(0); nameOp != nil {
name := nameOp.String()
if RootDocumentRefs.Contains(RefTerm(VarTerm(name))) {
name := Var(nameOp.String())
if ReservedVars.Contains(name) {
errors = append(errors, NewError(CompileErr, expr.Location, "variables must not shadow %v (use a different variable name)", name))
}
}
@@ -65,26 +64,24 @@ func checkRootDocumentOverrides(node any) Errors {
}
func walkCalls(node any, f func(any) bool) {
vis := &GenericVisitor{func(x any) bool {
switch x := x.(type) {
vis := NewGenericVisitor(func(x any) bool {
switch y := x.(type) {
case Call:
return f(x)
case *Expr:
if x.IsCall() {
if y.IsCall() {
return f(x)
}
case *Head:
// GenericVisitor doesn't walk the rule head ref
walkCalls(x.Reference, f)
walkCalls(y.Reference, f)
}
return false
}}
})
vis.Walk(node)
}
func checkDeprecatedBuiltins(deprecatedBuiltinsMap map[string]struct{}, node any) Errors {
errs := make(Errors, 0)
func checkDeprecatedBuiltins(deprecatedBuiltinsMap map[string]struct{}, node any) (errs Errors) {
walkCalls(node, func(x any) bool {
var operator string
var loc *Location

View File

@@ -48,6 +48,8 @@ func ValueName(x Value) string {
return "objectcomprehension"
case *SetComprehension:
return "setcomprehension"
case *TemplateString:
return "templatestring"
}
return TypeName(x)

View File

@@ -25,7 +25,13 @@ import (
"github.com/open-policy-agent/opa/v1/util"
)
var errFindNotFound = errors.New("find: not found")
var (
NullValue Value = Null{}
errFindNotFound = errors.New("find: not found")
varRegexp = regexp.MustCompile("^[[:alpha:]_][[:alpha:][:digit:]_]*$")
)
// Location records a position in source code.
type Location = location.Location
@@ -43,6 +49,7 @@ func NewLocation(text []byte, file string, row int, col int) *Location {
// - Variables, References
// - Array, Set, and Object Comprehensions
// - Calls
// - Template Strings
type Value interface {
Compare(other Value) int // Compare returns <0, 0, or >0 if this Value is less than, equal to, or greater than other, respectively.
Find(path Ref) (Value, error) // Find returns value referred to by path or an error if path is not found.
@@ -351,6 +358,8 @@ func (term *Term) Copy() *Term {
cpy.Value = v.Copy()
case *SetComprehension:
cpy.Value = v.Copy()
case *TemplateString:
cpy.Value = v.Copy()
case Call:
cpy.Value = v.Copy()
}
@@ -456,7 +465,17 @@ func (term *Term) Vars() VarSet {
}
// IsConstant returns true if the AST value is constant.
// Note that this is only a shallow check as we currently don't have a real
// notion of constant "vars" in the AST implementation. Meaning that while we could
// derive that a reference to a constant value is also constant, we currently don't.
func IsConstant(v Value) bool {
switch v.(type) {
case Null, Boolean, Number, String:
return true
case Var, Ref, *ArrayComprehension, *ObjectComprehension, *SetComprehension, Call:
return false
}
found := false
vis := GenericVisitor{
func(x any) bool {
@@ -531,8 +550,6 @@ func IsScalar(v Value) bool {
// Null represents the null value defined by JSON.
type Null struct{}
var NullValue Value = Null{}
// NullTerm creates a new Term with a Null value.
func NullTerm() *Term {
return &Term{Value: NullValue}
@@ -818,6 +835,173 @@ func (str String) Hash() int {
return int(xxhash.Sum64String(string(str)))
}
type TemplateString struct {
Parts []Node `json:"parts"`
MultiLine bool `json:"multi_line"`
}
func (ts *TemplateString) Copy() *TemplateString {
cpy := &TemplateString{MultiLine: ts.MultiLine, Parts: make([]Node, len(ts.Parts))}
for i, p := range ts.Parts {
switch v := p.(type) {
case *Expr:
cpy.Parts[i] = v.Copy()
case *Term:
cpy.Parts[i] = v.Copy()
}
}
return cpy
}
func (ts *TemplateString) Equal(other Value) bool {
if o, ok := other.(*TemplateString); ok && ts.MultiLine == o.MultiLine && len(ts.Parts) == len(o.Parts) {
for i, p := range ts.Parts {
switch v := p.(type) {
case *Expr:
if ope, ok := o.Parts[i].(*Expr); !ok || !v.Equal(ope) {
return false
}
case *Term:
if opt, ok := o.Parts[i].(*Term); !ok || !v.Equal(opt) {
return false
}
default:
return false
}
}
return true
}
return false
}
func (ts *TemplateString) Compare(other Value) int {
if ots, ok := other.(*TemplateString); ok {
if ts.MultiLine != ots.MultiLine {
if !ts.MultiLine {
return -1
}
return 1
}
if len(ts.Parts) != len(ots.Parts) {
return len(ts.Parts) - len(ots.Parts)
}
for i := range ts.Parts {
if cmp := Compare(ts.Parts[i], ots.Parts[i]); cmp != 0 {
return cmp
}
}
return 0
}
return Compare(ts, other)
}
func (ts *TemplateString) Find(path Ref) (Value, error) {
if len(path) == 0 {
return ts, nil
}
return nil, errFindNotFound
}
func (ts *TemplateString) Hash() int {
hash := 0
for _, p := range ts.Parts {
switch x := p.(type) {
case *Expr:
hash += x.Hash()
case *Term:
hash += x.Value.Hash()
default:
panic(fmt.Sprintf("invalid template part type %T", p))
}
}
return hash
}
func (*TemplateString) IsGround() bool {
return false
}
func (ts *TemplateString) String() string {
str := strings.Builder{}
str.WriteString("$\"")
for _, p := range ts.Parts {
switch x := p.(type) {
case *Expr:
str.WriteByte('{')
str.WriteString(p.String())
str.WriteByte('}')
case *Term:
s := p.String()
if _, ok := x.Value.(String); ok {
s = strings.TrimPrefix(s, "\"")
s = strings.TrimSuffix(s, "\"")
s = EscapeTemplateStringStringPart(s)
}
str.WriteString(s)
default:
str.WriteString("<invalid>")
}
}
str.WriteByte('"')
return str.String()
}
func TemplateStringTerm(multiLine bool, parts ...Node) *Term {
return &Term{Value: &TemplateString{MultiLine: multiLine, Parts: parts}}
}
// EscapeTemplateStringStringPart escapes unescaped left curly braces in s - i.e "{" becomes "\{".
// The internal representation of string terms within a template string does **NOT**
// treat '{' as special, but expects code dealing with template strings to escape them when
// required, such as when serializing the complete template string. Code that programmatically
// constructs template strings should not pre-escape left curly braces in string term parts.
//
// // TODO(anders): a future optimization would be to combine this with the other escaping done
// // for strings (e.g. escaping quotes, backslashes, and JSON control characters) in a single operation
// // to avoid multiple passes and allocations over the same string. That's currently done by
// // strconv.Quote, so we would need to re-implement that logic in code of our own.
// // NOTE(anders): I would love to come up with a better name for this component than
// // "TemplateStringStringPart"..
func EscapeTemplateStringStringPart(s string) string {
numUnescaped := countUnescapedLeftCurly(s)
if numUnescaped == 0 {
return s
}
l := len(s)
escaped := make([]byte, 0, l+numUnescaped)
if s[0] == '{' {
escaped = append(escaped, '\\', s[0])
} else {
escaped = append(escaped, s[0])
}
for i := 1; i < l; i++ {
if s[i] == '{' && s[i-1] != '\\' {
escaped = append(escaped, '\\', s[i])
} else {
escaped = append(escaped, s[i])
}
}
return util.ByteSliceToString(escaped)
}
func countUnescapedLeftCurly(s string) (n int) {
// Note(anders): while not the functions I'd intuitively reach for to solve this,
// they are hands down the fastest option here, as they're done in assembly, which
// performs about an order of magnitude better than a manual loop in Go.
if n = strings.Count(s, "{"); n > 0 {
n -= strings.Count(s, `\{`)
}
return n
}
// Var represents a variable as defined by the language.
type Var string
@@ -951,14 +1135,14 @@ func (ref Ref) Insert(x *Term, pos int) Ref {
// Extend returns a copy of ref with the terms from other appended. The head of
// other will be converted to a string.
func (ref Ref) Extend(other Ref) Ref {
dst := make(Ref, len(ref)+len(other))
offset := len(ref)
dst := make(Ref, offset+len(other))
copy(dst, ref)
head := other[0].Copy()
head.Value = String(head.Value.(Var))
offset := len(ref)
dst[offset] = head
dst[offset] = head
copy(dst[offset+1:], other[1:])
return dst
}
@@ -1070,42 +1254,38 @@ func (ref Ref) HasPrefix(other Ref) bool {
func (ref Ref) ConstantPrefix() Ref {
i := ref.Dynamic()
if i < 0 {
return ref.Copy()
return ref
}
return ref[:i].Copy()
return ref[:i]
}
// StringPrefix returns the string portion of the ref starting from the head.
func (ref Ref) StringPrefix() Ref {
for i := 1; i < len(ref); i++ {
switch ref[i].Value.(type) {
case String: // pass
default: // cut off
return ref[:i].Copy()
return ref[:i]
}
}
return ref.Copy()
return ref
}
// GroundPrefix returns the ground portion of the ref starting from the head. By
// definition, the head of the reference is always ground.
func (ref Ref) GroundPrefix() Ref {
if ref.IsGround() {
return ref
}
prefix := make(Ref, 0, len(ref))
for i, x := range ref {
if i > 0 && !x.IsGround() {
break
for i := range ref {
if i > 0 && !ref[i].IsGround() {
return ref[:i]
}
prefix = append(prefix, x)
}
return prefix
return ref
}
// DynamicSuffix returns the dynamic portion of the ref.
// If the ref is not dynamic, nil is returned.
func (ref Ref) DynamicSuffix() Ref {
i := ref.Dynamic()
if i < 0 {
@@ -1116,7 +1296,7 @@ func (ref Ref) DynamicSuffix() Ref {
// IsGround returns true if all of the parts of the Ref are ground.
func (ref Ref) IsGround() bool {
if len(ref) == 0 {
if len(ref) < 2 {
return true
}
return termSliceIsGround(ref[1:])
@@ -1136,18 +1316,29 @@ func (ref Ref) IsNested() bool {
// contains non-string terms this function returns an error. Path
// components are escaped.
func (ref Ref) Ptr() (string, error) {
parts := make([]string, 0, len(ref)-1)
for _, term := range ref[1:] {
if str, ok := term.Value.(String); ok {
parts = append(parts, url.PathEscape(string(str)))
} else {
buf := &strings.Builder{}
tail := ref[1:]
l := max(len(tail)-1, 0) // number of '/' to add
for i := range tail {
str, ok := tail[i].Value.(String)
if !ok {
return "", errors.New("invalid path value type")
}
l += len(str)
}
return strings.Join(parts, "/"), nil
}
buf.Grow(l)
var varRegexp = regexp.MustCompile("^[[:alpha:]_][[:alpha:][:digit:]_]*$")
for i := range tail {
if i > 0 {
buf.WriteByte('/')
}
str := string(tail[i].Value.(String))
// Sadly, the url package does not expose an appender for this.
buf.WriteString(url.PathEscape(str))
}
return buf.String(), nil
}
func IsVarCompatibleString(s string) bool {
return varRegexp.MatchString(s)
@@ -1263,13 +1454,12 @@ type Array struct {
// Copy returns a deep copy of arr.
func (arr *Array) Copy() *Array {
cpy := make([]int, len(arr.elems))
copy(cpy, arr.hashs)
return &Array{
elems: termSliceCopy(arr.elems),
hashs: cpy,
hashs: slices.Clone(arr.hashs),
hash: arr.hash,
ground: arr.IsGround()}
ground: arr.ground,
}
}
// Equal returns true if arr is equal to other.
@@ -1548,13 +1738,19 @@ type set struct {
// Copy returns a deep copy of s.
func (s *set) Copy() Set {
terms := make([]*Term, len(s.keys))
for i := range s.keys {
terms[i] = s.keys[i].Copy()
cpy := &set{
hash: s.hash,
ground: s.ground,
sortGuard: sync.Once{},
elems: make(map[int]*Term, len(s.elems)),
keys: make([]*Term, 0, len(s.keys)),
}
cpy := NewSet(terms...).(*set)
cpy.hash = s.hash
cpy.ground = s.ground
for hash := range s.elems {
cpy.elems[hash] = s.elems[hash].Copy()
cpy.keys = append(cpy.keys, cpy.elems[hash])
}
return cpy
}
@@ -2309,19 +2505,21 @@ func (obj *object) Merge(other Object) (Object, bool) {
// is called. The conflictResolver can return a merged value and a boolean
// indicating if the merge has failed and should stop.
func (obj *object) MergeWith(other Object, conflictResolver func(v1, v2 *Term) (*Term, bool)) (Object, bool) {
result := NewObject()
// Might overallocate assuming no conflicts is the common case,
// but that's typically faster than iterating over each object twice.
result := newobject(obj.Len() + other.Len())
stop := obj.Until(func(k, v *Term) bool {
v2 := other.Get(k)
// The key didn't exist in other, keep the original value
if v2 == nil {
result.Insert(k, v)
result.insert(k, v, false)
return false
}
// The key exists in both, resolve the conflict if possible
merged, stop := conflictResolver(v, v2)
if !stop {
result.Insert(k, merged)
result.insert(k, merged, false)
}
return stop
})
@@ -2333,7 +2531,7 @@ func (obj *object) MergeWith(other Object, conflictResolver func(v1, v2 *Term) (
// Copy in any values from other for keys that don't exist in obj
other.Foreach(func(k, v *Term) {
if v2 := obj.Get(k); v2 == nil {
result.Insert(k, v)
result.insert(k, v, false)
}
})
return result, true
@@ -2733,12 +2931,28 @@ func (c Call) IsGround() bool {
return termSliceIsGround(c)
}
// MakeExpr returns an ew Expr from this call.
// MakeExpr returns a new Expr from this call.
func (c Call) MakeExpr(output *Term) *Expr {
terms := []*Term(c)
return NewExpr(append(terms, output))
}
func (c Call) Operator() Ref {
if len(c) == 0 {
return nil
}
return c[0].Value.(Ref)
}
func (c Call) Operands() []*Term {
if len(c) < 1 {
return nil
}
return c[1:]
}
func (c Call) String() string {
args := make([]string, len(c)-1)
for i := 1; i < len(c); i++ {
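
The new TemplateString value can also be constructed programmatically; a hedged sketch using the constructors added above (the commented output is what the String() logic suggests, not a guarantee):

package main

import (
	"fmt"

	"github.com/open-policy-agent/opa/v1/ast"
)

func main() {
	// Parts alternate between plain string terms and expressions; String()
	// re-serializes them in the $"..." form, escaping unescaped '{' in the
	// string parts as described above.
	ts := ast.TemplateStringTerm(false,
		ast.StringTerm("hello, "),
		ast.MustParseExpr("input.name"),
		ast.StringTerm("!"),
	)
	fmt.Println(ts) // expected: $"hello, {input.name}!"
}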

View File

@@ -19,7 +19,6 @@ type Transformer interface {
// Transform iterates the AST and calls the Transform function on the
// Transformer t for x before recursing.
func Transform(t Transformer, x any) (any, error) {
if term, ok := x.(*Term); ok {
return Transform(t, term.Value)
}
@@ -284,6 +283,19 @@ func Transform(t Transformer, x any) (any, error) {
}
}
return y, nil
case *TemplateString:
for i := range y.Parts {
if expr, ok := y.Parts[i].(*Expr); ok {
transformed, err := Transform(t, expr)
if err != nil {
return nil, err
}
if y.Parts[i], ok = transformed.(*Expr); !ok {
return nil, fmt.Errorf("illegal transform: %T != %T", expr, transformed)
}
}
}
return y, nil
default:
return y, nil
}
@@ -291,29 +303,29 @@ func Transform(t Transformer, x any) (any, error) {
// TransformRefs calls the function f on all references under x.
func TransformRefs(x any, f func(Ref) (Value, error)) (any, error) {
t := &GenericTransformer{func(x any) (any, error) {
t := NewGenericTransformer(func(x any) (any, error) {
if r, ok := x.(Ref); ok {
return f(r)
}
return x, nil
}}
})
return Transform(t, x)
}
// TransformVars calls the function f on all vars under x.
func TransformVars(x any, f func(Var) (Value, error)) (any, error) {
t := &GenericTransformer{func(x any) (any, error) {
t := NewGenericTransformer(func(x any) (any, error) {
if v, ok := x.(Var); ok {
return f(v)
}
return x, nil
}}
})
return Transform(t, x)
}
// TransformComprehensions calls the functio nf on all comprehensions under x.
// TransformComprehensions calls the function f on all comprehensions under x.
func TransformComprehensions(x any, f func(any) (Value, error)) (any, error) {
t := &GenericTransformer{func(x any) (any, error) {
t := NewGenericTransformer(func(x any) (any, error) {
switch x := x.(type) {
case *ArrayComprehension:
return f(x)
@@ -323,7 +335,7 @@ func TransformComprehensions(x any, f func(any) (Value, error)) (any, error) {
return f(x)
}
return x, nil
}}
})
return Transform(t, x)
}
@@ -387,11 +399,7 @@ func transformTerm(t Transformer, term *Term) (*Term, error) {
if err != nil {
return nil, err
}
r := &Term{
Value: v,
Location: term.Location,
}
return r, nil
return &Term{Value: v, Location: term.Location}, nil
}
func transformValue(t Transformer, v Value) (Value, error) {
@@ -407,13 +415,18 @@ func transformValue(t Transformer, v Value) (Value, error) {
}
func transformVar(t Transformer, v Var) (Var, error) {
v1, err := Transform(t, v)
tv, err := t.Transform(v)
if err != nil {
return "", err
}
r, ok := v1.(Var)
if tv == nil {
return "", nil
}
r, ok := tv.(Var)
if !ok {
return "", fmt.Errorf("illegal transform: %T != %T", v, v1)
return "", fmt.Errorf("illegal transform: %T != %T", v, tv)
}
return r, nil
}
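
For context, the exported transformer helpers touched above are typically used like this (long-standing ast API, shown here only as a usage sketch):

package main

import (
	"fmt"

	"github.com/open-policy-agent/opa/v1/ast"
)

func main() {
	body := ast.MustParseBody("x := input.a; y := x")
	// TransformVars visits every Var under the body and lets the callback
	// substitute a new value for it.
	out, err := ast.TransformVars(body, func(v ast.Var) (ast.Value, error) {
		if v == ast.Var("x") {
			return ast.Var("renamed_x"), nil
		}
		return v, nil
	})
	if err != nil {
		panic(err)
	}
	fmt.Println(out) // both occurrences of x are rewritten to renamed_x
}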

View File

@@ -11,12 +11,11 @@ func isRefSafe(ref Ref, safe VarSet) bool {
case Call:
return isCallSafe(head, safe)
default:
for v := range ref[0].Vars() {
if !safe.Contains(v) {
return false
}
}
return true
vis := varVisitorPool.Get().WithParams(SafetyCheckVisitorParams)
vis.Walk(ref[0])
isSafe := vis.Vars().DiffCount(safe) == 0
varVisitorPool.Put(vis)
return isSafe
}
}

View File

@@ -358,6 +358,11 @@
"Minor": 34,
"Patch": 0
},
"internal.template_string": {
"Major": 1,
"Minor": 12,
"Patch": 0
},
"internal.test_case": {
"Major": 1,
"Minor": 2,
@@ -1037,6 +1042,11 @@
"Major": 0,
"Minor": 59,
"Patch": 0
},
"template_strings": {
"Major": 1,
"Minor": 12,
"Patch": 0
}
},
"keywords": {

View File

@@ -4,44 +4,108 @@
package ast
// Visitor defines the interface for iterating AST elements. The Visit function
// can return a Visitor w which will be used to visit the children of the AST
// element v. If the Visit function returns nil, the children will not be
// visited.
//
// Deprecated: use GenericVisitor or another visitor implementation
type Visitor interface {
Visit(v any) (w Visitor)
}
var (
termTypeVisitor = newTypeVisitor[*Term]()
varTypeVisitor = newTypeVisitor[Var]()
exprTypeVisitor = newTypeVisitor[*Expr]()
ruleTypeVisitor = newTypeVisitor[*Rule]()
refTypeVisitor = newTypeVisitor[Ref]()
bodyTypeVisitor = newTypeVisitor[Body]()
withTypeVisitor = newTypeVisitor[*With]()
)
// BeforeAndAfterVisitor wraps Visitor to provide hooks for being called before
// and after the AST has been visited.
//
// Deprecated: use GenericVisitor or another visitor implementation
type BeforeAndAfterVisitor interface {
Visitor
Before(x any)
After(x any)
}
type (
// GenericVisitor provides a utility to walk over AST nodes using a
// closure. If the closure returns true, the visitor will not walk
// over AST nodes under x.
GenericVisitor struct {
f func(x any) bool
}
// Walk iterates the AST by calling the Visit function on the Visitor
// BeforeAfterVisitor provides a utility to walk over AST nodes using
// closures. If the before closure returns true, the visitor will not
// walk over AST nodes under x. The after closure is always invoked
// after visiting a node.
BeforeAfterVisitor struct {
before func(x any) bool
after func(x any)
}
// VarVisitor walks AST nodes under a given node and collects all encountered
// variables. The collected variables can be controlled by specifying
// VarVisitorParams when creating the visitor.
VarVisitor struct {
params VarVisitorParams
vars VarSet
}
// VarVisitorParams contains settings for a VarVisitor.
VarVisitorParams struct {
SkipRefHead bool
SkipRefCallHead bool
SkipObjectKeys bool
SkipClosures bool
SkipWithTarget bool
SkipSets bool
}
// Visitor defines the interface for iterating AST elements. The Visit function
// can return a Visitor w which will be used to visit the children of the AST
// element v. If the Visit function returns nil, the children will not be
// visited.
//
// Deprecated: use [GenericVisitor] or another visitor implementation
Visitor interface {
Visit(v any) (w Visitor)
}
// BeforeAndAfterVisitor wraps Visitor to provide hooks for being called before
// and after the AST has been visited.
//
// Deprecated: use [GenericVisitor] or another visitor implementation
BeforeAndAfterVisitor interface {
Visitor
Before(x any)
After(x any)
}
// typeVisitor is a generic visitor for a specific type T (the "generic" name was
// however taken). Contrary to the [GenericVisitor], the typeVisitor only invokes
// the visit function for nodes of type T, saving both CPU cycles and type assertions.
// typeVisitor implementations carry no state, and can be shared freely across
// goroutines. Access is private for the time being, as there is already inflation
// in visitor types exposed in the AST package. The various WalkXXX functions however
// now leverage typeVisitor under the hood.
//
// While a typeVisitor is generally a more performant option over a GenericVisitor,
// it is not as flexible: a type visitor can only visit nodes of a single type T,
// whereas a GenericVisitor visits all nodes. Adding to that, a typeVisitor can only
// be instantiated for **concrete types** — not interfaces (e.g., [*Expr], not [Node]),
// as reflection would be required to determine the concrete type at runtime, thus
// nullifying the performance benefits of the typeVisitor in the first place.
typeVisitor[T any] struct {
typ any
}
)
// Walk iterates the AST by calling the Visit function on the [Visitor]
// v for x before recursing.
//
// Deprecated: use GenericVisitor.Walk
// Deprecated: use [GenericVisitor.Walk]
func Walk(v Visitor, x any) {
if bav, ok := v.(BeforeAndAfterVisitor); !ok {
walk(v, x)
} else {
bav.Before(x)
defer bav.After(x)
walk(bav, x)
bav.After(x)
}
}
// WalkBeforeAndAfter iterates the AST by calling the Visit function on the
// Visitor v for x before recursing.
//
// Deprecated: use GenericVisitor.Walk
// Deprecated: use [GenericVisitor.Walk]
func WalkBeforeAndAfter(v BeforeAndAfterVisitor, x any) {
Walk(v, x)
}
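
A usage sketch (not part of the vendored diff): as the typeVisitor comment above explains, the exported WalkXXX helpers now run on type-specific visitors rather than a GenericVisitor plus type assertions, while their public API is unchanged. The query text below is illustrative.

package main

import (
	"fmt"

	"github.com/open-policy-agent/opa/v1/ast"
)

func main() {
	body := ast.MustParseBody(`x := input.users[_]; x.active == true`)

	// Collect every variable under the body. No filtering is applied, so this
	// also picks up operator names (assign, equal) and generated wildcards.
	var vars []string
	ast.WalkVars(body, func(v ast.Var) bool {
		vars = append(vars, string(v))
		return false // keep walking below this node
	})

	// Count references; returning true would skip the subtree under a ref.
	refs := 0
	ast.WalkRefs(body, func(ast.Ref) bool {
		refs++
		return false
	})

	fmt.Println(vars, refs)
}
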
@@ -153,132 +217,258 @@ func walk(v Visitor, x any) {
for i := range x.Symbols {
Walk(w, x.Symbols[i])
}
case *TemplateString:
for i := range x.Parts {
Walk(w, x.Parts[i])
}
}
}
// WalkVars calls the function f on all vars under x. If the function f
// returns true, AST nodes under the last node will not be visited.
func WalkVars(x any, f func(Var) bool) {
vis := &GenericVisitor{func(x any) bool {
if v, ok := x.(Var); ok {
return f(v)
}
return false
}}
vis.Walk(x)
varTypeVisitor.walk(x, f)
}
// WalkClosures calls the function f on all closures under x. If the function f
// returns true, AST nodes under the last node will not be visited.
func WalkClosures(x any, f func(any) bool) {
vis := &GenericVisitor{func(x any) bool {
vis := NewGenericVisitor(func(x any) bool {
switch x := x.(type) {
case *ArrayComprehension, *ObjectComprehension, *SetComprehension, *Every:
return f(x)
}
return false
}}
})
vis.Walk(x)
}
// WalkRefs calls the function f on all references under x. If the function f
// returns true, AST nodes under the last node will not be visited.
func WalkRefs(x any, f func(Ref) bool) {
vis := &GenericVisitor{func(x any) bool {
if r, ok := x.(Ref); ok {
return f(r)
}
return false
}}
vis.Walk(x)
refTypeVisitor.walk(x, f)
}
// WalkTerms calls the function f on all terms under x. If the function f
// returns true, AST nodes under the last node will not be visited.
func WalkTerms(x any, f func(*Term) bool) {
vis := &GenericVisitor{func(x any) bool {
if term, ok := x.(*Term); ok {
return f(term)
}
return false
}}
vis.Walk(x)
termTypeVisitor.walk(x, f)
}
// WalkWiths calls the function f on all with modifiers under x. If the function f
// returns true, AST nodes under the last node will not be visited.
func WalkWiths(x any, f func(*With) bool) {
vis := &GenericVisitor{func(x any) bool {
if w, ok := x.(*With); ok {
return f(w)
}
return false
}}
vis.Walk(x)
withTypeVisitor.walk(x, f)
}
// WalkExprs calls the function f on all expressions under x. If the function f
// returns true, AST nodes under the last node will not be visited.
func WalkExprs(x any, f func(*Expr) bool) {
vis := &GenericVisitor{func(x any) bool {
if r, ok := x.(*Expr); ok {
return f(r)
}
return false
}}
vis.Walk(x)
exprTypeVisitor.walk(x, f)
}
// WalkBodies calls the function f on all bodies under x. If the function f
// returns true, AST nodes under the last node will not be visited.
func WalkBodies(x any, f func(Body) bool) {
vis := &GenericVisitor{func(x any) bool {
if b, ok := x.(Body); ok {
return f(b)
}
return false
}}
vis.Walk(x)
bodyTypeVisitor.walk(x, f)
}
// WalkRules calls the function f on all rules under x. If the function f
// returns true, AST nodes under the last node will not be visited.
func WalkRules(x any, f func(*Rule) bool) {
vis := &GenericVisitor{func(x any) bool {
if r, ok := x.(*Rule); ok {
stop := f(r)
// NOTE(tsandall): since rules cannot be embedded inside of queries
// we can stop early if there is no else block.
if stop || r.Else == nil {
return true
switch x := x.(type) {
case *Module:
for i := range x.Rules {
if !f(x.Rules[i]) && x.Rules[i].Else != nil {
WalkRules(x.Rules[i].Else, f)
}
}
return false
}}
vis.Walk(x)
case *Rule:
if !f(x) && x.Else != nil {
WalkRules(x.Else, f)
}
default:
ruleTypeVisitor.walk(x, f)
}
}
// WalkNodes calls the function f on all nodes under x. If the function f
// returns true, AST nodes under the last node will not be visited.
func WalkNodes(x any, f func(Node) bool) {
vis := &GenericVisitor{func(x any) bool {
vis := NewGenericVisitor(func(x any) bool {
if n, ok := x.(Node); ok {
return f(n)
}
return false
}}
})
vis.Walk(x)
}
// GenericVisitor provides a utility to walk over AST nodes using a
// closure. If the closure returns true, the visitor will not walk
// over AST nodes under x.
type GenericVisitor struct {
f func(x any) bool
func newTypeVisitor[T any]() *typeVisitor[T] {
var t T
return &typeVisitor[T]{typ: any(t)}
}
func (tv *typeVisitor[T]) walkArgs(args Args, visit func(x T) bool) {
// If T is not Args, avoid allocation by inlining the walk.
if _, ok := tv.typ.(Args); !ok {
for i := range args {
tv.walk(args[i], visit)
}
} else {
tv.walk(args, visit) // allocates
}
}
func (tv *typeVisitor[T]) walkBody(body Body, visit func(x T) bool) {
if _, ok := tv.typ.(Body); !ok {
for i := range body {
tv.walk(body[i], visit)
}
} else {
tv.walk(body, visit) // allocates
}
}
func (tv *typeVisitor[T]) walkRef(ref Ref, visit func(x T) bool) {
if _, ok := tv.typ.(Ref); !ok {
for i := range ref {
tv.walk(ref[i], visit)
}
} else {
tv.walk(ref, visit) // allocates
}
}
func (tv *typeVisitor[T]) walk(x any, visit func(x T) bool) {
if v, ok := x.(T); ok && visit(v) {
return
}
switch x := x.(type) {
case *Module:
tv.walk(x.Package, visit)
for i := range x.Imports {
tv.walk(x.Imports[i], visit)
}
for i := range x.Rules {
tv.walk(x.Rules[i], visit)
}
for i := range x.Annotations {
tv.walk(x.Annotations[i], visit)
}
for i := range x.Comments {
tv.walk(x.Comments[i], visit)
}
case *Package:
tv.walkRef(x.Path, visit)
case *Import:
tv.walk(x.Path, visit)
if _, ok := tv.typ.(Var); ok {
tv.walk(x.Alias, visit)
}
case *Rule:
tv.walk(x.Head, visit)
tv.walkBody(x.Body, visit)
if x.Else != nil {
tv.walk(x.Else, visit)
}
case *Head:
if _, ok := tv.typ.(Var); ok {
tv.walk(x.Name, visit)
}
tv.walkArgs(x.Args, visit)
if x.Key != nil {
tv.walk(x.Key, visit)
}
if x.Value != nil {
tv.walk(x.Value, visit)
}
case Body:
for i := range x {
tv.walk(x[i], visit)
}
case Args:
for i := range x {
tv.walk(x[i], visit)
}
case *Expr:
switch ts := x.Terms.(type) {
case *Term, *SomeDecl, *Every:
tv.walk(ts, visit)
case []*Term:
for i := range ts {
tv.walk(ts[i], visit)
}
}
for i := range x.With {
tv.walk(x.With[i], visit)
}
case *With:
tv.walk(x.Target, visit)
tv.walk(x.Value, visit)
case *Term:
tv.walk(x.Value, visit)
case Ref:
for i := range x {
tv.walk(x[i], visit)
}
case *object:
x.Foreach(func(k, v *Term) {
tv.walk(k, visit)
tv.walk(v, visit)
})
case Object:
for _, k := range x.Keys() {
tv.walk(k, visit)
tv.walk(x.Get(k), visit)
}
case *Array:
for i := range x.Len() {
tv.walk(x.Elem(i), visit)
}
case Set:
xSlice := x.Slice()
for i := range xSlice {
tv.walk(xSlice[i], visit)
}
case *ArrayComprehension:
tv.walk(x.Term, visit)
tv.walkBody(x.Body, visit)
case *ObjectComprehension:
tv.walk(x.Key, visit)
tv.walk(x.Value, visit)
tv.walkBody(x.Body, visit)
case *SetComprehension:
tv.walk(x.Term, visit)
tv.walkBody(x.Body, visit)
case Call:
for i := range x {
tv.walk(x[i], visit)
}
case *Every:
if x.Key != nil {
tv.walk(x.Key, visit)
}
tv.walk(x.Value, visit)
tv.walk(x.Domain, visit)
tv.walkBody(x.Body, visit)
case *SomeDecl:
for i := range x.Symbols {
tv.walk(x.Symbols[i], visit)
}
case *TemplateString:
for i := range x.Parts {
tv.walk(x.Parts[i], visit)
}
}
}
// NewGenericVisitor returns a new GenericVisitor that will invoke the function
// f on AST nodes.
// f on AST nodes. Note that while it returns a pointer, creating a GenericVisitor
// doesn't commonly allocate it on the heap, as long as it doesn't escape the function
// in which it is created and used (as it's trivially inlined).
func NewGenericVisitor(f func(x any) bool) *GenericVisitor {
return &GenericVisitor{f}
}
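
A usage sketch (not part of the vendored diff) of the pattern the doc comment above describes: a GenericVisitor created and consumed inside one function normally stays off the heap. The containsComprehension helper is illustrative, not part of OPA.

package main

import (
	"fmt"

	"github.com/open-policy-agent/opa/v1/ast"
)

// containsComprehension reports whether any comprehension appears under node.
func containsComprehension(node any) bool {
	found := false
	vis := ast.NewGenericVisitor(func(x any) bool {
		switch x.(type) {
		case *ast.ArrayComprehension, *ast.ObjectComprehension, *ast.SetComprehension:
			found = true
			return true // stop descending below this node
		}
		return false
	})
	vis.Walk(node)
	return found
}

func main() {
	body := ast.MustParseBody(`xs := [x | some x in input.items; x > 1]`)
	fmt.Println(containsComprehension(body))
}
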
@@ -310,7 +500,9 @@ func (vis *GenericVisitor) Walk(x any) {
vis.Walk(x.Path)
case *Import:
vis.Walk(x.Path)
vis.Walk(x.Alias)
if x.Alias != "" {
vis.f(x.Alias)
}
case *Rule:
vis.Walk(x.Head)
vis.Walk(x.Body)
@@ -318,8 +510,12 @@ func (vis *GenericVisitor) Walk(x any) {
vis.Walk(x.Else)
}
case *Head:
vis.Walk(x.Name)
vis.Walk(x.Args)
if x.Name != "" {
vis.f(x.Name)
}
if x.Args != nil {
vis.Walk(x.Args)
}
if x.Key != nil {
vis.Walk(x.Key)
}
@@ -399,18 +595,13 @@ func (vis *GenericVisitor) Walk(x any) {
for i := range x.Symbols {
vis.Walk(x.Symbols[i])
}
case *TemplateString:
for i := range x.Parts {
vis.Walk(x.Parts[i])
}
}
}
// BeforeAfterVisitor provides a utility to walk over AST nodes using
// closures. If the before closure returns true, the visitor will not
// walk over AST nodes under x. The after closure is invoked always
// after visiting a node.
type BeforeAfterVisitor struct {
before func(x any) bool
after func(x any)
}
// NewBeforeAfterVisitor returns a new BeforeAndAfterVisitor that
// will invoke the functions before and after AST nodes.
func NewBeforeAfterVisitor(before func(x any) bool, after func(x any)) *BeforeAfterVisitor {
@@ -542,31 +733,29 @@ func (vis *BeforeAfterVisitor) Walk(x any) {
}
}
// VarVisitor walks AST nodes under a given node and collects all encountered
// variables. The collected variables can be controlled by specifying
// VarVisitorParams when creating the visitor.
type VarVisitor struct {
params VarVisitorParams
vars VarSet
}
// VarVisitorParams contains settings for a VarVisitor.
type VarVisitorParams struct {
SkipRefHead bool
SkipRefCallHead bool
SkipObjectKeys bool
SkipClosures bool
SkipWithTarget bool
SkipSets bool
}
// NewVarVisitor returns a new VarVisitor object.
// NewVarVisitor returns a new [VarVisitor] object.
func NewVarVisitor() *VarVisitor {
return &VarVisitor{
vars: NewVarSet(),
}
}
// ClearOrNewVarVisitor clears a non-nil [VarVisitor] or returns a new one.
func ClearOrNewVarVisitor(vis *VarVisitor) *VarVisitor {
if vis == nil {
return NewVarVisitor()
}
return vis.Clear()
}
// ClearOrNew resets the visitor to its initial state, or returns a new one if nil.
//
// Deprecated: use [ClearOrNewVarVisitor] instead.
func (vis *VarVisitor) ClearOrNew() *VarVisitor {
return ClearOrNewVarVisitor(vis)
}
// Clear resets the visitor to its initial state, and returns it for chaining.
func (vis *VarVisitor) Clear() *VarVisitor {
vis.params = VarVisitorParams{}
@@ -575,14 +764,6 @@ func (vis *VarVisitor) Clear() *VarVisitor {
return vis
}
// ClearOrNew returns a new VarVisitor if vis is nil, or else a cleared VarVisitor.
func (vis *VarVisitor) ClearOrNew() *VarVisitor {
if vis == nil {
return NewVarVisitor()
}
return vis.Clear()
}
// WithParams sets the parameters in params on vis.
func (vis *VarVisitor) WithParams(params VarVisitorParams) *VarVisitor {
vis.params = params
@@ -598,7 +779,7 @@ func (vis *VarVisitor) Add(v Var) {
}
}
// Vars returns a VarSet that contains collected vars.
// Vars returns a [VarSet] that contains collected vars.
func (vis *VarVisitor) Vars() VarSet {
return vis.vars
}
@@ -625,7 +806,7 @@ func (vis *VarVisitor) visit(v any) bool {
}
if vis.params.SkipClosures {
switch v := v.(type) {
case *ArrayComprehension, *ObjectComprehension, *SetComprehension:
case *ArrayComprehension, *ObjectComprehension, *SetComprehension, *TemplateString:
return true
case *Expr:
if ev, ok := v.Terms.(*Every); ok {
@@ -695,9 +876,8 @@ func (vis *VarVisitor) visit(v any) bool {
return false
}
// Walk iterates the AST by calling the function f on the
// GenericVisitor before recursing. Contrary to the generic Walk, this
// does not require allocating the visitor from heap.
// Walk iterates the AST by calling the function f on the [VarVisitor] before recursing.
// Contrary to the deprecated [Walk] function, this does not require allocating the visitor from heap.
func (vis *VarVisitor) Walk(x any) {
if vis.visit(x) {
return
@@ -705,16 +885,9 @@ func (vis *VarVisitor) Walk(x any) {
switch x := x.(type) {
case *Module:
vis.Walk(x.Package)
for i := range x.Imports {
vis.Walk(x.Imports[i])
}
for i := range x.Rules {
vis.Walk(x.Rules[i])
}
for i := range x.Comments {
vis.Walk(x.Comments[i])
}
case *Package:
vis.WalkRef(x.Path)
case *Import:
@@ -767,9 +940,9 @@ func (vis *VarVisitor) Walk(x any) {
vis.Walk(x[i].Value)
}
case *object:
x.Foreach(func(k, _ *Term) {
x.Foreach(func(k, v *Term) {
vis.Walk(k)
vis.Walk(x.Get(k))
vis.Walk(v)
})
case *Array:
x.Foreach(func(t *Term) {
@@ -805,6 +978,10 @@ func (vis *VarVisitor) Walk(x any) {
for i := range x.Symbols {
vis.Walk(x.Symbols[i])
}
case *TemplateString:
for i := range x.Parts {
vis.Walk(x.Parts[i])
}
}
}
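
A usage sketch (not part of the vendored diff) of the VarVisitor reuse pattern that ClearOrNewVarVisitor above is meant to support; the rules and the parameter choice are illustrative.

package main

import (
	"fmt"

	"github.com/open-policy-agent/opa/v1/ast"
)

func main() {
	var vis *ast.VarVisitor // reused across iterations; nil on the first pass

	rules := []*ast.Rule{
		ast.MustParseRule(`p contains x if { x := input.a }`),
		ast.MustParseRule(`q contains y if { y := input.b }`),
	}

	for _, rule := range rules {
		vis = ast.ClearOrNewVarVisitor(vis).WithParams(ast.VarVisitorParams{
			SkipRefHead: true, // ignore ref heads such as input and the assign operator
		})
		vis.Walk(rule.Body)
		fmt.Println(vis.Vars().Sorted())
	}
}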

View File

@@ -970,7 +970,7 @@ func compileModules(compiler *ast.Compiler, m metrics.Metrics, bundles map[strin
m.Timer(metrics.RegoModuleCompile).Start()
defer m.Timer(metrics.RegoModuleCompile).Stop()
modules := map[string]*ast.Module{}
modules := make(map[string]*ast.Module, len(compiler.Modules)+len(extraModules)+len(bundles))
// preserve any modules already on the compiler
maps.Copy(modules, compiler.Modules)

View File

@@ -27,8 +27,6 @@ import (
const defaultLocationFile = "__format_default__"
var (
elseVar ast.Value = ast.Var("else")
expandedConst = ast.NewBody(ast.NewExpr(ast.InternedTerm(true)))
commentsSlicePool = util.NewSlicePool[*ast.Comment](50)
varRegexp = regexp.MustCompile("^[[:alpha:]_][[:alpha:][:digit:]_]*$")
@@ -732,7 +730,7 @@ func (w *writer) writeElse(rule *ast.Rule, comments []*ast.Comment) ([]*ast.Comm
rule.Else.Head.Name = "else" // NOTE(sr): whaaat
elseHeadReference := ast.NewTerm(elseVar) // construct a reference for the term
elseHeadReference := ast.VarTerm("else") // construct a reference for the term
elseHeadReference.Location = rule.Else.Head.Location // and set the location to match the rule location
rule.Else.Head.Reference = ast.Ref{elseHeadReference}
@@ -1284,6 +1282,11 @@ func (w *writer) writeTermParens(parens bool, term *ast.Term, comments []*ast.Co
}
}
case *ast.TemplateString:
comments, err = w.writeTemplateString(x, comments)
if err != nil {
return nil, err
}
case ast.Var:
w.write(w.formatVar(x))
case ast.Call:
@@ -1301,6 +1304,91 @@ func (w *writer) writeTermParens(parens bool, term *ast.Term, comments []*ast.Co
return comments, nil
}
func (w *writer) writeTemplateString(ts *ast.TemplateString, comments []*ast.Comment) ([]*ast.Comment, error) {
w.write("$")
if ts.MultiLine {
w.write("`")
} else {
w.write(`"`)
}
for i, p := range ts.Parts {
switch x := p.(type) {
case *ast.Expr:
w.write("{")
w.up()
if w.beforeEnd != nil {
// We have a comment on the same line as the opening template-expression brace '{'
w.endLine()
w.startLine()
} else {
// We might have comments to write, the first of which should be on the same line as the opening template-expression brace '{'
before, _, _ := partitionComments(comments, x.Location)
if len(before) > 0 {
w.write(" ")
w.inline = true
if err := w.writeComments(before); err != nil {
return nil, err
}
comments = comments[len(before):]
}
}
var err error
comments, err = w.writeExpr(x, comments)
if err != nil {
return comments, err
}
// write trailing comments
if i+1 < len(ts.Parts) {
before, _, _ := partitionComments(comments, ts.Parts[i+1].Loc())
if len(before) > 0 {
w.endLine()
if err := w.writeComments(before); err != nil {
return nil, err
}
comments = comments[len(before):]
w.startLine()
}
}
w.write("}")
if err := w.down(); err != nil {
return nil, err
}
case *ast.Term:
if s, ok := x.Value.(ast.String); ok {
if ts.MultiLine {
w.write(ast.EscapeTemplateStringStringPart(string(s)))
} else {
str := ast.EscapeTemplateStringStringPart(s.String())
w.write(str[1 : len(str)-1])
}
} else {
s := x.String()
s = strings.TrimPrefix(s, "\"")
s = strings.TrimSuffix(s, "\"")
w.write(s)
}
default:
w.write("<invalid>")
}
}
if ts.MultiLine {
w.write("`")
} else {
w.write(`"`)
}
return comments, nil
}
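
A usage sketch (not part of the vendored diff): writeTemplateString above teaches the formatter to print template strings, so a module using one can be formatted from Go. This assumes template strings parse under the default 1.12 capabilities and that format.Source is used as in other OPA versions; the file name and policy are illustrative.

package main

import (
	"fmt"

	"github.com/open-policy-agent/opa/v1/format"
)

func main() {
	src := []byte(`package example

greeting := $"hello, {input.name}!"
`)

	// Assumption: format.Source(filename, src) formats Rego source with default options.
	formatted, err := format.Source("example.rego", src)
	if err != nil {
		panic(err)
	}
	fmt.Print(string(formatted))
}
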
func (w *writer) writeRef(x ast.Ref, comments []*ast.Comment) ([]*ast.Comment, error) {
if len(x) > 0 {
parens := false
@@ -1931,7 +2019,7 @@ func partitionComments(comments []*ast.Comment, l *ast.Location) ([]*ast.Comment
var at *ast.Comment
before := make([]*ast.Comment, 0, numBefore)
after := comments[0 : 0 : len(comments)-numBefore]
after := make([]*ast.Comment, 0, numAfter)
for _, c := range comments {
switch cmp := c.Location.Row - l.Row; {

View File

@@ -2212,7 +2212,7 @@ func (r *Rego) compileQuery(query ast.Body, imports []*ast.Import, _ metrics.Met
if r.pkg != "" {
var err error
pkg, err = ast.ParsePackage(fmt.Sprintf("package %v", r.pkg))
pkg, err = ast.ParsePackage("package " + r.pkg)
if err != nil {
return nil, nil, err
}

View File

@@ -216,16 +216,25 @@ func (vis namespacingVisitor) Visit(x any) bool {
switch x := x.(type) {
case *ast.ArrayComprehension:
x.Term = vis.namespaceTerm(x.Term)
ast.NewGenericVisitor(vis.Visit).Walk(x.Body)
vis := ast.NewGenericVisitor(vis.Visit)
for _, expr := range x.Body {
vis.Walk(expr)
}
return true
case *ast.SetComprehension:
x.Term = vis.namespaceTerm(x.Term)
ast.NewGenericVisitor(vis.Visit).Walk(x.Body)
vis := ast.NewGenericVisitor(vis.Visit)
for _, expr := range x.Body {
vis.Walk(expr)
}
return true
case *ast.ObjectComprehension:
x.Key = vis.namespaceTerm(x.Key)
x.Value = vis.namespaceTerm(x.Value)
ast.NewGenericVisitor(vis.Visit).Walk(x.Body)
vis := ast.NewGenericVisitor(vis.Visit)
for _, expr := range x.Body {
vis.Walk(expr)
}
return true
case *ast.Expr:
switch terms := x.Terms.(type) {

View File

@@ -344,7 +344,7 @@ func (p *CopyPropagator) livevarRef(a *ast.Term) bool {
}
for _, v := range p.sorted {
if ref[0].Value.Compare(v) == 0 {
if v.Equal(ref[0].Value) {
return true
}
}
@@ -403,7 +403,7 @@ func containedIn(value ast.Value, x any) bool {
if v, ok := value.(ast.Ref); ok {
match = x.HasPrefix(v)
} else {
match = x.Compare(value) == 0
match = x.Equal(value)
}
if stop || match {
stop = true

View File

@@ -28,7 +28,6 @@ func (h printHook) Print(_ print.Context, msg string) error {
}
func builtinPrint(bctx BuiltinContext, operands []*ast.Term, iter func(*ast.Term) error) error {
if bctx.PrintHook == nil {
return iter(nil)
}
@@ -40,7 +39,7 @@ func builtinPrint(bctx BuiltinContext, operands []*ast.Term, iter func(*ast.Term
buf := make([]string, arr.Len())
err = builtinPrintCrossProductOperands(bctx, buf, arr, 0, func(buf []string) error {
err = builtinPrintCrossProductOperands(bctx.Location, buf, arr, 0, func(buf []string) error {
pctx := print.Context{
Context: bctx.Context,
Location: bctx.Location,
@@ -54,20 +53,32 @@ func builtinPrint(bctx BuiltinContext, operands []*ast.Term, iter func(*ast.Term
return iter(nil)
}
func builtinPrintCrossProductOperands(bctx BuiltinContext, buf []string, operands *ast.Array, i int, f func([]string) error) error {
func builtinPrintCrossProductOperands(loc *ast.Location, buf []string, operands *ast.Array, i int, f func([]string) error) error {
if i >= operands.Len() {
return f(buf)
}
xs, ok := operands.Elem(i).Value.(ast.Set)
operand := operands.Elem(i)
// We allow primitives ...
switch x := operand.Value.(type) {
case ast.String:
buf[i] = string(x)
return builtinPrintCrossProductOperands(loc, buf, operands, i+1, f)
case ast.Number, ast.Boolean, ast.Null:
buf[i] = x.String()
return builtinPrintCrossProductOperands(loc, buf, operands, i+1, f)
}
// ... but all other operand types must be sets.
xs, ok := operand.Value.(ast.Set)
if !ok {
return Halt{Err: internalErr(bctx.Location, fmt.Sprintf("illegal argument type: %v", ast.ValueName(operands.Elem(i).Value)))}
return Halt{Err: internalErr(loc, "illegal argument type: "+ast.ValueName(operand.Value))}
}
if xs.Len() == 0 {
buf[i] = "<undefined>"
return builtinPrintCrossProductOperands(bctx, buf, operands, i+1, f)
return builtinPrintCrossProductOperands(loc, buf, operands, i+1, f)
}
return xs.Iter(func(x *ast.Term) error {
@@ -77,7 +88,7 @@ func builtinPrintCrossProductOperands(bctx BuiltinContext, buf []string, operand
default:
buf[i] = v.String()
}
return builtinPrintCrossProductOperands(bctx, buf, operands, i+1, f)
return builtinPrintCrossProductOperands(loc, buf, operands, i+1, f)
})
}
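
A usage sketch (not part of the vendored diff): the print built-in above only emits output when a print hook is configured, so a minimal wiring through the public rego API looks roughly like this; the query and input are illustrative.

package main

import (
	"context"
	"os"

	"github.com/open-policy-agent/opa/v1/rego"
	"github.com/open-policy-agent/opa/v1/topdown"
)

func main() {
	r := rego.New(
		rego.Query(`print("user:", input.user)`),
		rego.Input(map[string]any{"user": "alice"}),
		rego.EnablePrintStatements(true),              // compile print() calls instead of erasing them
		rego.PrintHook(topdown.NewPrintHook(os.Stderr)), // route messages to stderr
	)
	if _, err := r.Eval(context.Background()); err != nil {
		panic(err)
	}
}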

View File

@@ -134,7 +134,7 @@ func (q *Query) WithTracer(tracer Tracer) *Query {
// WithQueryTracer adds a query tracer to use during evaluation. This is optional.
// Disabled QueryTracers will be ignored.
func (q *Query) WithQueryTracer(tracer QueryTracer) *Query {
if !tracer.Enabled() {
if tracer == nil || !tracer.Enabled() {
return q
}

View File

@@ -0,0 +1,73 @@
package topdown
import (
"bytes"
"io"
)
var _ io.Writer = (*sinkW)(nil)
type sinkWriter interface {
io.Writer
String() string
Grow(int)
WriteByte(byte) error
WriteString(string) (int, error)
}
type sinkW struct {
buf *bytes.Buffer
cancel Cancel
err error
}
func newSink(name string, hint int, c Cancel) sinkWriter {
b := &bytes.Buffer{}
if hint > 0 {
b.Grow(hint)
}
if c == nil {
return b
}
return &sinkW{
cancel: c,
buf: b,
err: Halt{
Err: &Error{
Code: CancelErr,
Message: name + ": timed out before finishing",
},
},
}
}
func (sw *sinkW) Grow(n int) {
sw.buf.Grow(n)
}
func (sw *sinkW) Write(bs []byte) (int, error) {
if sw.cancel.Cancelled() {
return 0, sw.err
}
return sw.buf.Write(bs)
}
func (sw *sinkW) WriteByte(b byte) error {
if sw.cancel.Cancelled() {
return sw.err
}
return sw.buf.WriteByte(b)
}
func (sw *sinkW) WriteString(s string) (int, error) {
if sw.cancel.Cancelled() {
return 0, sw.err
}
return sw.buf.WriteString(s)
}
func (sw *sinkW) String() string {
return sw.buf.String()
}
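
A sketch (not part of the vendored diff) restating the guard that the unexported sinkW above applies, using only exported types: once the topdown.Cancel fires, further writes fail instead of letting a large string build run to completion. The cancellableBuffer type is illustrative, not part of OPA.

package main

import (
	"bytes"
	"errors"
	"fmt"

	"github.com/open-policy-agent/opa/v1/topdown"
)

type cancellableBuffer struct {
	buf    bytes.Buffer
	cancel topdown.Cancel
}

func (b *cancellableBuffer) WriteString(s string) (int, error) {
	if b.cancel.Cancelled() {
		return 0, errors.New("timed out before finishing")
	}
	return b.buf.WriteString(s)
}

func main() {
	c := topdown.NewCancel()
	b := &cancellableBuffer{cancel: c}

	_, _ = b.WriteString("written before cancellation; ")
	c.Cancel() // e.g. fired by a deadline elsewhere
	if _, err := b.WriteString("never appended"); err != nil {
		fmt.Println(b.buf.String(), "->", err)
	}
}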

View File

@@ -152,7 +152,7 @@ func builtinFormatInt(_ BuiltinContext, operands []*ast.Term, iter func(*ast.Ter
return iter(ast.InternedTerm(fmt.Sprintf(format, i)))
}
func builtinConcat(_ BuiltinContext, operands []*ast.Term, iter func(*ast.Term) error) error {
func builtinConcat(bctx BuiltinContext, operands []*ast.Term, iter func(*ast.Term) error) error {
join, err := builtins.StringOperand(operands[0].Value, 1)
if err != nil {
return err
@@ -163,11 +163,13 @@ func builtinConcat(_ BuiltinContext, operands []*ast.Term, iter func(*ast.Term)
return iter(term)
}
sb := newSink(ast.Concat.Name, 0, bctx.Cancel)
// NOTE(anderseknert):
// More or less Go's strings.Join implementation, but where we avoid
// creating an intermediate []string slice to pass to that function,
// as that's expensive (3.5x more space allocated). Instead we build
// the string directly using a strings.Builder to concatenate the string
// the string directly using the sink to concatenate the string
// values from the array/set with the separator.
n := 0
switch b := operands[1].Value.(type) {
@@ -182,25 +184,36 @@ func builtinConcat(_ BuiltinContext, operands []*ast.Term, iter func(*ast.Term)
}
sep := string(join)
n += len(sep) * (l - 1)
var sb strings.Builder
sb.Grow(n)
sb.WriteString(string(b.Elem(0).Value.(ast.String)))
if _, err := sb.WriteString(string(b.Elem(0).Value.(ast.String))); err != nil {
return err
}
if sep == "" {
for i := 1; i < l; i++ {
sb.WriteString(string(b.Elem(i).Value.(ast.String)))
if _, err := sb.WriteString(string(b.Elem(i).Value.(ast.String))); err != nil {
return err
}
}
} else if len(sep) == 1 {
// when the separator is a single byte, sb.WriteByte is substantially faster
bsep := sep[0]
for i := 1; i < l; i++ {
sb.WriteByte(bsep)
sb.WriteString(string(b.Elem(i).Value.(ast.String)))
if err := sb.WriteByte(bsep); err != nil {
return err
}
if _, err := sb.WriteString(string(b.Elem(i).Value.(ast.String))); err != nil {
return err
}
}
} else {
// for longer separators, there is no such difference between WriteString and Write
for i := 1; i < l; i++ {
sb.WriteString(sep)
sb.WriteString(string(b.Elem(i).Value.(ast.String)))
if _, err := sb.WriteString(sep); err != nil {
return err
}
if _, err := sb.WriteString(string(b.Elem(i).Value.(ast.String))); err != nil {
return err
}
}
}
return iter(ast.InternedTerm(sb.String()))
@@ -215,12 +228,15 @@ func builtinConcat(_ BuiltinContext, operands []*ast.Term, iter func(*ast.Term)
sep := string(join)
l := b.Len()
n += len(sep) * (l - 1)
var sb strings.Builder
sb.Grow(n)
for i, v := range b.Slice() {
sb.WriteString(string(v.Value.(ast.String)))
if _, err := sb.WriteString(string(v.Value.(ast.String))); err != nil {
return err
}
if i < l-1 {
sb.WriteString(sep)
if _, err := sb.WriteString(sep); err != nil {
return err
}
}
}
return iter(ast.InternedTerm(sb.String()))
@@ -523,7 +539,7 @@ func builtinSplit(_ BuiltinContext, operands []*ast.Term, iter func(*ast.Term) e
return iter(ast.ArrayTerm(util.SplitMap(text, delim, ast.InternedTerm)...))
}
func builtinReplace(_ BuiltinContext, operands []*ast.Term, iter func(*ast.Term) error) error {
func builtinReplace(bctx BuiltinContext, operands []*ast.Term, iter func(*ast.Term) error) error {
s, err := builtins.StringOperand(operands[0].Value, 1)
if err != nil {
return err
@@ -539,7 +555,12 @@ func builtinReplace(_ BuiltinContext, operands []*ast.Term, iter func(*ast.Term)
return err
}
replaced := strings.ReplaceAll(string(s), string(old), string(n))
sink := newSink(ast.Replace.Name, len(s), bctx.Cancel)
replacer := strings.NewReplacer(string(old), string(n))
if _, err := replacer.WriteString(sink, string(s)); err != nil {
return err
}
replaced := sink.String()
if replaced == string(s) {
return iter(operands[0])
}
@@ -547,7 +568,7 @@ func builtinReplace(_ BuiltinContext, operands []*ast.Term, iter func(*ast.Term)
return iter(ast.InternedTerm(replaced))
}
func builtinReplaceN(_ BuiltinContext, operands []*ast.Term, iter func(*ast.Term) error) error {
func builtinReplaceN(bctx BuiltinContext, operands []*ast.Term, iter func(*ast.Term) error) error {
patterns, err := builtins.ObjectOperand(operands[0].Value, 1)
if err != nil {
return err
@@ -574,7 +595,12 @@ func builtinReplaceN(_ BuiltinContext, operands []*ast.Term, iter func(*ast.Term
oldnewArr = append(oldnewArr, string(keyVal), string(strVal))
}
return iter(ast.InternedTerm(strings.NewReplacer(oldnewArr...).Replace(string(s))))
sink := newSink(ast.ReplaceN.Name, len(s), bctx.Cancel)
replacer := strings.NewReplacer(oldnewArr...)
if _, err := replacer.WriteString(sink, string(s)); err != nil {
return err
}
return iter(ast.InternedTerm(sink.String()))
}
func builtinTrim(_ BuiltinContext, operands []*ast.Term, iter func(*ast.Term) error) error {
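
A usage sketch (not part of the vendored diff) exercising concat, which the change above routes through the cancellable sink, via the public rego API; the query is illustrative.

package main

import (
	"context"
	"fmt"

	"github.com/open-policy-agent/opa/v1/rego"
)

func main() {
	rs, err := rego.New(
		rego.Query(`x := concat("/", ["usr", "local", "bin"])`),
	).Eval(context.Background())
	if err != nil {
		panic(err)
	}
	fmt.Println(rs[0].Bindings["x"]) // usr/local/bin
}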

View File

@@ -0,0 +1,45 @@
// Copyright 2025 The OPA Authors. All rights reserved.
// Use of this source code is governed by an Apache2
// license that can be found in the LICENSE file.
package topdown
import (
"strings"
"github.com/open-policy-agent/opa/v1/ast"
"github.com/open-policy-agent/opa/v1/topdown/builtins"
)
func builtinTemplateString(bctx BuiltinContext, operands []*ast.Term, iter func(*ast.Term) error) error {
arr, err := builtins.ArrayOperand(operands[0].Value, 1)
if err != nil {
return err
}
buf := make([]string, arr.Len())
var count int
err = builtinPrintCrossProductOperands(bctx.Location, buf, arr, 0, func(buf []string) error {
count += 1
// Precautionary run-time assertion that template-strings can't produce multiple outputs; e.g. for custom relation type built-ins not known at compile-time.
if count > 1 {
return Halt{Err: &Error{
Code: ConflictErr,
Location: bctx.Location,
Message: "template-strings must not produce multiple outputs",
}}
}
return nil
})
if err != nil {
return err
}
return iter(ast.StringTerm(strings.Join(buf, "")))
}
func init() {
RegisterBuiltinFunc(ast.InternalTemplateString.Name, builtinTemplateString)
}
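
A usage sketch (not part of the vendored diff) of a template string end to end, which the compiler presumably lowers to the internal.template_string built-in registered above. This assumes the feature is available under the default 1.12 capabilities; the module, package name, and input are illustrative.

package main

import (
	"context"
	"fmt"

	"github.com/open-policy-agent/opa/v1/rego"
)

func main() {
	rs, err := rego.New(
		rego.Query(`data.example.greeting`),
		rego.Module("example.rego", `package example

greeting := $"hello, {input.name}!"
`),
		rego.Input(map[string]any{"name": "world"}),
	).Eval(context.Background())
	if err != nil {
		panic(err)
	}
	fmt.Println(rs[0].Expressions[0].Value) // hello, world!
}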

View File

@@ -21,6 +21,7 @@ import (
"fmt"
"hash"
"math/big"
"strconv"
"strings"
"github.com/lestrrat-go/jwx/v3/jwk"
@@ -1131,8 +1132,8 @@ func builtinJWTDecodeVerify(bctx BuiltinContext, operands []*ast.Term, iter func
switch v := nbf.Value.(type) {
case ast.Number:
// constraints.time is in nanoseconds but nbf Value is in seconds
compareTime := ast.FloatNumberTerm(constraints.time / 1000000000)
if ast.Compare(compareTime, v) == -1 {
compareTime := ast.Number(strconv.FormatFloat(constraints.time/1000000000, 'g', -1, 64))
if compareTime.Compare(v) == -1 {
return iter(unverified)
}
default:

View File

@@ -10,7 +10,7 @@ import (
"runtime/debug"
)
var Version = "1.11.1"
var Version = "1.12.3"
// GoVersion is the version of Go this was built with
var GoVersion = runtime.Version()

2
vendor/modules.txt vendored
View File

@@ -1277,7 +1277,7 @@ github.com/onsi/gomega/matchers/support/goraph/edge
github.com/onsi/gomega/matchers/support/goraph/node
github.com/onsi/gomega/matchers/support/goraph/util
github.com/onsi/gomega/types
# github.com/open-policy-agent/opa v1.11.1
# github.com/open-policy-agent/opa v1.12.3
## explicit; go 1.24.6
github.com/open-policy-agent/opa/ast
github.com/open-policy-agent/opa/ast/json