Skip to content

Commit 633d2d8

Browse files
committed
Preserve source information during policy composition.
The old implementation outputs the wrong source information for policy files derived from the dummy AST created by RuleComposer.Compose. This change fixes that by preserving the correct source and merging the offset ranges of the rule match expressions inserted by the composer optimizer into the final AST.
1 parent 52280ba commit 633d2d8

File tree

11 files changed

+512
-36
lines changed

11 files changed

+512
-36
lines changed

cel/folding_test.go

Lines changed: 20 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -366,7 +366,10 @@ func TestConstantFoldingOptimizer(t *testing.T) {
366366
if err != nil {
367367
t.Fatalf("NewConstantFoldingOptimizer() failed: %v", err)
368368
}
369-
opt := NewStaticOptimizer(folder)
369+
opt, err := NewStaticOptimizer(folder)
370+
if err != nil {
371+
t.Fatalf("NewStaticOptimizer() failed: %v", err)
372+
}
370373
optimized, iss := opt.Optimize(e, checked)
371374
if iss.Err() != nil {
372375
t.Fatalf("Optimize() generated an invalid AST: %v", iss.Err())
@@ -441,7 +444,10 @@ func TestConstantFoldingCallsWithSideEffects(t *testing.T) {
441444
if err != nil {
442445
t.Fatalf("NewConstantFoldingOptimizer() failed: %v", err)
443446
}
444-
opt := NewStaticOptimizer(folder)
447+
opt, err := NewStaticOptimizer(folder)
448+
if err != nil {
449+
t.Fatalf("NewStaticOptimizer() failed: %v", err)
450+
}
445451
optimized, iss := opt.Optimize(e, checked)
446452
if tc.error != "" {
447453
if iss.Err() == nil {
@@ -508,7 +514,10 @@ func TestConstantFoldingOptimizerMacroElimination(t *testing.T) {
508514
if err != nil {
509515
t.Fatalf("NewConstantFoldingOptimizer() failed: %v", err)
510516
}
511-
opt := NewStaticOptimizer(folder)
517+
opt, err := NewStaticOptimizer(folder)
518+
if err != nil {
519+
t.Fatalf("NewStaticOptimizer() failed: %v", err)
520+
}
512521
optimized, iss := opt.Optimize(e, checked)
513522
if iss.Err() != nil {
514523
t.Fatalf("Optimize() generated an invalid AST: %v", iss.Err())
@@ -570,7 +579,10 @@ func TestConstantFoldingOptimizerWithLimit(t *testing.T) {
570579
if err != nil {
571580
t.Fatalf("NewConstantFoldingOptimizer() failed: %v", err)
572581
}
573-
opt := NewStaticOptimizer(folder)
582+
opt, err := NewStaticOptimizer(folder)
583+
if err != nil {
584+
t.Fatalf("NewStaticOptimizer() failed: %v", err)
585+
}
574586
optimized, iss := opt.Optimize(e, checked)
575587
if iss.Err() != nil {
576588
t.Fatalf("Optimize() generated an invalid AST: %v", iss.Err())
@@ -828,7 +840,10 @@ func TestConstantFoldingNormalizeIDs(t *testing.T) {
828840
if err != nil {
829841
t.Fatalf("NewConstantFoldingOptimizer() failed: %v", err)
830842
}
831-
opt := NewStaticOptimizer(folder)
843+
opt, err := NewStaticOptimizer(folder)
844+
if err != nil {
845+
t.Fatalf("NewStaticOptimizer() failed: %v", err)
846+
}
832847
optimized, iss := opt.Optimize(e, checked)
833848
if iss.Err() != nil {
834849
t.Fatalf("Optimize() generated an invalid AST: %v", iss.Err())

cel/inlining_test.go

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -220,7 +220,10 @@ func TestInliningOptimizer(t *testing.T) {
220220
t.Fatalf("Compile() failed: %v", iss.Err())
221221
}
222222

223-
opt := cel.NewStaticOptimizer(cel.NewInliningOptimizer(inlinedVars...))
223+
opt, err := cel.NewStaticOptimizer(cel.NewInliningOptimizer(inlinedVars...))
224+
if err != nil {
225+
t.Fatalf("NewStaticOptimizer() failed: %v", err)
226+
}
224227
optimized, iss := opt.Optimize(e, checked)
225228
if iss.Err() != nil {
226229
t.Fatalf("Optimize() generated an invalid AST: %v", iss.Err())
@@ -236,7 +239,10 @@ func TestInliningOptimizer(t *testing.T) {
236239
if err != nil {
237240
t.Fatalf("NewConstantFoldingOptimizer() failed: %v", err)
238241
}
239-
opt = cel.NewStaticOptimizer(folder)
242+
opt, err = cel.NewStaticOptimizer(folder)
243+
if err != nil {
244+
t.Fatalf("NewStaticOptimizer() failed: %v", err)
245+
}
240246
optimized, iss = opt.Optimize(e, optimized)
241247
if iss.Err() != nil {
242248
t.Fatalf("Optimize() generated an invalid AST: %v", iss.Err())
@@ -727,7 +733,10 @@ func TestInliningOptimizerMultiStage(t *testing.T) {
727733
t.Fatalf("Compile() failed: %v", iss.Err())
728734
}
729735

730-
opt := cel.NewStaticOptimizer(cel.NewInliningOptimizer(inlinedVars...))
736+
opt, err := cel.NewStaticOptimizer(cel.NewInliningOptimizer(inlinedVars...))
737+
if err != nil {
738+
t.Fatalf("NewStaticOptimizer() failed: %v", err)
739+
}
731740
optimized, iss := opt.Optimize(e, checked)
732741
if iss.Err() != nil {
733742
t.Fatalf("Optimize() generated an invalid AST: %v", iss.Err())
@@ -743,7 +752,10 @@ func TestInliningOptimizerMultiStage(t *testing.T) {
743752
if err != nil {
744753
t.Fatalf("NewConstantFoldingOptimizer() failed: %v", err)
745754
}
746-
opt = cel.NewStaticOptimizer(folder)
755+
opt, err = cel.NewStaticOptimizer(folder)
756+
if err != nil {
757+
t.Fatalf("NewStaticOptimizer() failed: %v", err)
758+
}
747759
optimized, iss = opt.Optimize(e, optimized)
748760
if iss.Err() != nil {
749761
t.Fatalf("Optimize() generated an invalid AST: %v", iss.Err())

cel/optimizer.go

Lines changed: 46 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
package cel
1616

1717
import (
18+
"fmt"
1819
"sort"
1920

2021
"github.com/google/cel-go/common"
@@ -29,17 +30,43 @@ import (
2930
// passes to ensure that the final optimized output is a valid expression with metadata consistent
3031
// with what would have been generated from a parsed and checked expression.
3132
//
32-
// Note: source position information is best-effort and likely wrong, but optimized expressions
33+
// Note: source position information is best-effort and incomplete, but optimized expressions
3334
// should be suitable for calls to parser.Unparse.
3435
type StaticOptimizer struct {
3536
optimizers []ASTOptimizer
37+
// If set, Optimize() will use this Source instead of the one from the AST.
38+
sourceOverride *Source
3639
}
3740

41+
type OptimizerOption func(*StaticOptimizer) (*StaticOptimizer, error)
42+
3843
// NewStaticOptimizer creates a StaticOptimizer with a sequence of ASTOptimizer's to be applied
3944
// to a checked expression.
40-
func NewStaticOptimizer(optimizers ...ASTOptimizer) *StaticOptimizer {
41-
return &StaticOptimizer{
42-
optimizers: optimizers,
45+
func NewStaticOptimizer(options ...any) (*StaticOptimizer, error) {
46+
so := &StaticOptimizer{}
47+
var err error
48+
for _, opt := range options {
49+
switch v := opt.(type) {
50+
case ASTOptimizer:
51+
so.optimizers = append(so.optimizers, v)
52+
case OptimizerOption:
53+
so, err = v(so)
54+
if err != nil {
55+
return nil, err
56+
}
57+
default:
58+
return nil, fmt.Errorf("unsupported option: %v", v)
59+
}
60+
}
61+
return so, nil
62+
}
63+
64+
// OptimizeWithSource overrides the source used by the optimizer.
65+
// Note this will cause the source info from the AST passed to Optimize() to be discarded.
66+
func OptimizeWithSource(source Source) OptimizerOption {
67+
return func(so *StaticOptimizer) (*StaticOptimizer, error) {
68+
so.sourceOverride = &source
69+
return so, nil
4370
}
4471
}
4572

@@ -49,15 +76,21 @@ func NewStaticOptimizer(optimizers ...ASTOptimizer) *StaticOptimizer {
4976
func (opt *StaticOptimizer) Optimize(env *Env, a *Ast) (*Ast, *Issues) {
5077
// Make a copy of the AST to be optimized.
5178
optimized := ast.Copy(a.NativeRep())
79+
source := a.Source()
80+
sourceInfo := optimized.SourceInfo()
81+
if opt.sourceOverride != nil {
82+
source = *opt.sourceOverride
83+
sourceInfo = ast.NewSourceInfo(*opt.sourceOverride)
84+
}
5285
ids := newIDGenerator(ast.MaxID(a.NativeRep()))
5386

5487
// Create the optimizer context, could be pooled in the future.
55-
issues := NewIssues(common.NewErrors(a.Source()))
88+
issues := NewIssues(common.NewErrors(source))
5689
baseFac := ast.NewExprFactory()
5790
exprFac := &optimizerExprFactory{
5891
idGenerator: ids,
5992
fac: baseFac,
60-
sourceInfo: optimized.SourceInfo(),
93+
sourceInfo: sourceInfo,
6194
}
6295
ctx := &OptimizerContext{
6396
optimizerExprFactory: exprFac,
@@ -80,7 +113,7 @@ func (opt *StaticOptimizer) Optimize(env *Env, a *Ast) (*Ast, *Issues) {
80113

81114
// Recheck the updated expression for any possible type-agreement or validation errors.
82115
parsed := &Ast{
83-
source: a.Source(),
116+
source: source,
84117
impl: ast.NewAST(expr, info)}
85118
checked, iss := ctx.Check(parsed)
86119
if iss.Err() != nil {
@@ -91,7 +124,7 @@ func (opt *StaticOptimizer) Optimize(env *Env, a *Ast) (*Ast, *Issues) {
91124

92125
// Return the optimized result.
93126
return &Ast{
94-
source: a.Source(),
127+
source: source,
95128
impl: optimized,
96129
}, nil
97130
}
@@ -100,6 +133,8 @@ func (opt *StaticOptimizer) Optimize(env *Env, a *Ast) (*Ast, *Issues) {
100133
// that the ids within the expression correspond to the ids within macros.
101134
func normalizeIDs(idGen ast.IDGenerator, optimized ast.Expr, info *ast.SourceInfo) {
102135
optimized.RenumberIDs(idGen)
136+
info.RenumberIDs(idGen)
137+
103138
if len(info.MacroCalls()) == 0 {
104139
return
105140
}
@@ -260,6 +295,9 @@ func (opt *optimizerExprFactory) CopyASTAndMetadata(a *ast.AST) ast.Expr {
260295
for macroID, call := range copyInfo.MacroCalls() {
261296
opt.SetMacroCall(macroID, call)
262297
}
298+
for id, offset := range copyInfo.OffsetRanges() {
299+
opt.sourceInfo.SetOffsetRange(id, offset)
300+
}
263301
return copyExpr
264302
}
265303

cel/optimizer_test.go

Lines changed: 82 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,10 @@ func TestStaticOptimizerUpdateExpr(t *testing.T) {
4545
if iss.Err() != nil {
4646
t.Fatalf("Compile() failed: %v", iss.Err())
4747
}
48-
opt := cel.NewStaticOptimizer(&testOptimizer{t: t, inlineExpr: inlinedAST.NativeRep()})
48+
opt, err := cel.NewStaticOptimizer(&testOptimizer{t: t, inlineExpr: inlinedAST.NativeRep()})
49+
if err != nil {
50+
t.Fatalf("NewStaticOptimizer() failed: %v", err)
51+
}
4952
optAST, iss := opt.Optimize(e, exprAST)
5053
if iss.Err() != nil {
5154
t.Fatalf("Optimize() generated an invalid AST: %v", iss.Err())
@@ -59,28 +62,17 @@ func TestStaticOptimizerUpdateExpr(t *testing.T) {
5962
if err != nil {
6063
t.Fatalf("cel.AstToCheckedExpr() failed: %v", err)
6164
}
65+
sourceInfoPB.Positions = nil
6266
wantTextPB := `
6367
location: "<input>"
6468
line_offsets: 9
65-
positions: {
66-
key: 2
67-
value: 4
68-
}
69-
positions: {
70-
key: 3
71-
value: 5
72-
}
73-
positions: {
74-
key: 4
75-
value: 3
76-
}
7769
macro_calls: {
7870
key: 1
7971
value: {
8072
call_expr: {
8173
function: "has"
8274
args: {
83-
id: 21
75+
id: 24
8476
select_expr: {
8577
operand: {
8678
id: 2
@@ -186,7 +178,10 @@ func TestStaticOptimizerNewAST(t *testing.T) {
186178
if iss.Err() != nil {
187179
t.Fatalf("Compile(%q) failed: %v", tc, iss.Err())
188180
}
189-
opt := cel.NewStaticOptimizer(&identityOptimizer{t: t})
181+
opt, err := cel.NewStaticOptimizer(&identityOptimizer{t: t})
182+
if err != nil {
183+
t.Fatalf("NewStaticOptimizer() failed: %v", err)
184+
}
190185
optAST, iss := opt.Optimize(e, exprAST)
191186
if iss.Err() != nil {
192187
t.Fatalf("Optimize() generated an invalid AST: %v", iss.Err())
@@ -202,9 +197,69 @@ func TestStaticOptimizerNewAST(t *testing.T) {
202197
}
203198
}
204199

200+
func TestOptimizeWithSource(t *testing.T) {
201+
initial := `has(a.b)`
202+
replacement := `x["a"]`
203+
e := optimizerEnv(t)
204+
initialAST, iss := e.Compile(initial)
205+
if iss.Err() != nil {
206+
t.Fatalf("Compile(%q) failed: %v", initial, iss.Err())
207+
}
208+
replacementAST, iss := e.Compile(replacement)
209+
if iss.Err() != nil {
210+
t.Fatalf("Compile(%q) failed: %v", replacement, iss.Err())
211+
}
212+
213+
opt, err := cel.NewStaticOptimizer(
214+
&replaceOptimizer{t: t, targetAST: replacementAST.NativeRep()},
215+
cel.OptimizeWithSource(replacementAST.Source()),
216+
)
217+
if err != nil {
218+
t.Fatalf("NewStaticOptimizer() failed: %v", err)
219+
}
220+
optAST, iss := opt.Optimize(e, initialAST)
221+
if iss.Err() != nil {
222+
t.Fatalf("Optimize() returned an error: %v", iss.Err())
223+
}
224+
225+
if optAST.Source().Content() != replacement {
226+
t.Errorf("got source content %q, wanted %q", optAST.Source().Content(), replacement)
227+
}
228+
sourceInfoPB, err := ast.SourceInfoToProto(optAST.NativeRep().SourceInfo())
229+
if err != nil {
230+
t.Fatalf("cel.AstToCheckedExpr() failed: %v", err)
231+
}
232+
wantTextPB := `
233+
location: "<input>"
234+
line_offsets: 7
235+
positions: {
236+
key: 1
237+
value: 1
238+
}
239+
positions: {
240+
key: 2
241+
value: 0
242+
}
243+
positions: {
244+
key: 3
245+
value: 2
246+
}
247+
`
248+
var wantSourceInfoPB exprpb.SourceInfo
249+
if err := prototext.Unmarshal([]byte(wantTextPB), &wantSourceInfoPB); err != nil {
250+
t.Fatalf("prototext.Unmarshal() failed: %v", err)
251+
}
252+
if !proto.Equal(&wantSourceInfoPB, sourceInfoPB) {
253+
t.Errorf("got source info: %s, wanted %s", prototext.Format(sourceInfoPB), wantTextPB)
254+
}
255+
}
256+
205257
func TestStaticOptimizerNilAST(t *testing.T) {
206258
env := optimizerEnv(t)
207-
opt := cel.NewStaticOptimizer(&identityOptimizer{t: t})
259+
opt, err := cel.NewStaticOptimizer(&identityOptimizer{t: t})
260+
if err != nil {
261+
t.Fatalf("NewStaticOptimizer() failed: %v", err)
262+
}
208263
optAST, iss := opt.Optimize(env, nil)
209264
if iss.Err() == nil || !strings.Contains(iss.Err().Error(), "unexpected unspecified type") {
210265
t.Errorf("opt.Optimize(env, nil) got (%v, %v), wanted unexpected unspecified type", optAST, iss)
@@ -245,6 +300,17 @@ func (opt *testOptimizer) Optimize(ctx *cel.OptimizerContext, a *ast.AST) *ast.A
245300
return ctx.NewAST(a.Expr())
246301
}
247302

303+
type replaceOptimizer struct {
304+
t *testing.T
305+
targetAST *ast.AST
306+
}
307+
308+
func (opt *replaceOptimizer) Optimize(ctx *cel.OptimizerContext, a *ast.AST) *ast.AST {
309+
opt.t.Helper()
310+
copy := ctx.CopyASTAndMetadata(opt.targetAST)
311+
return ctx.NewAST(copy)
312+
}
313+
248314
func getMacroKeys(macroCalls map[int64]ast.Expr) []int {
249315
keys := []int{}
250316
for k := range macroCalls {

0 commit comments

Comments
 (0)