Skip to content

Commit 0e89407

Browse files
committed
fix(filter): enforce CEL syntax semantics
Reject non-standard truthy numeric expressions in filters and document the parser as a supported subset of standard CEL syntax. - remove legacy filter rewrites - support standard equality in tag exists predicates - add regression coverage for accepted and rejected expressions
1 parent d3f6e8e commit 0e89407

11 files changed

Lines changed: 141 additions & 89 deletions

File tree

plugin/filter/README.md

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,14 @@
11
# Memo Filter Engine
22

3-
This package houses the memo-only filter engine that turns CEL expressions into
4-
SQL fragments. The engine follows a three phase pipeline inspired by systems
3+
This package houses the memo-only filter engine that turns standard CEL syntax
4+
into SQL fragments for the subset of expressions supported by the memo schema.
5+
The engine follows a three phase pipeline inspired by systems
56
such as Calcite or Prisma:
67

78
1. **Parsing** – CEL expressions are parsed with `cel-go` and validated against
89
the memo-specific environment declared in `schema.go`. Only fields that
9-
exist in the schema can surface in the filter.
10+
exist in the schema can surface in the filter, and non-standard legacy
11+
coercions are rejected.
1012
2. **Normalization** – the raw CEL AST is converted into an intermediate
1113
representation (IR) defined in `ir.go`. The IR is a dialect-agnostic tree of
1214
conditions (logical operators, comparisons, list membership, etc.). This

plugin/filter/engine.go

Lines changed: 0 additions & 73 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@ package filter
22

33
import (
44
"context"
5-
"fmt"
65
"strings"
76
"sync"
87

@@ -45,8 +44,6 @@ func (e *Engine) Compile(_ context.Context, filter string) (*Program, error) {
4544
return nil, errors.New("filter expression is empty")
4645
}
4746

48-
filter = normalizeLegacyFilter(filter)
49-
5047
ast, issues := e.env.Compile(filter)
5148
if issues != nil && issues.Err() != nil {
5249
return nil, errors.Wrap(issues.Err(), "failed to compile filter")
@@ -119,73 +116,3 @@ func DefaultAttachmentEngine() (*Engine, error) {
119116
})
120117
return defaultAttachmentInst, defaultAttachmentErr
121118
}
122-
123-
func normalizeLegacyFilter(expr string) string {
124-
expr = rewriteNumericLogicalOperand(expr, "&&")
125-
expr = rewriteNumericLogicalOperand(expr, "||")
126-
return expr
127-
}
128-
129-
func rewriteNumericLogicalOperand(expr, op string) string {
130-
var builder strings.Builder
131-
n := len(expr)
132-
i := 0
133-
var inQuote rune
134-
135-
for i < n {
136-
ch := expr[i]
137-
138-
if inQuote != 0 {
139-
builder.WriteByte(ch)
140-
if ch == '\\' && i+1 < n {
141-
builder.WriteByte(expr[i+1])
142-
i += 2
143-
continue
144-
}
145-
if ch == byte(inQuote) {
146-
inQuote = 0
147-
}
148-
i++
149-
continue
150-
}
151-
152-
if ch == '\'' || ch == '"' {
153-
inQuote = rune(ch)
154-
builder.WriteByte(ch)
155-
i++
156-
continue
157-
}
158-
159-
if strings.HasPrefix(expr[i:], op) {
160-
builder.WriteString(op)
161-
i += len(op)
162-
163-
// Preserve whitespace following the operator.
164-
wsStart := i
165-
for i < n && (expr[i] == ' ' || expr[i] == '\t') {
166-
i++
167-
}
168-
builder.WriteString(expr[wsStart:i])
169-
170-
signStart := i
171-
if i < n && (expr[i] == '+' || expr[i] == '-') {
172-
i++
173-
}
174-
for i < n && expr[i] >= '0' && expr[i] <= '9' {
175-
i++
176-
}
177-
if i > signStart {
178-
numLiteral := expr[signStart:i]
179-
fmt.Fprintf(&builder, "(%s != 0)", numLiteral)
180-
} else {
181-
builder.WriteString(expr[signStart:i])
182-
}
183-
continue
184-
}
185-
186-
builder.WriteByte(ch)
187-
i++
188-
}
189-
190-
return builder.String()
191-
}

plugin/filter/engine_test.go

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
package filter
2+
3+
import (
4+
"context"
5+
"testing"
6+
7+
"github.com/stretchr/testify/require"
8+
)
9+
10+
func TestCompileAcceptsStandardTagEqualityPredicate(t *testing.T) {
11+
t.Parallel()
12+
13+
engine, err := NewEngine(NewSchema())
14+
require.NoError(t, err)
15+
16+
_, err = engine.Compile(context.Background(), `tags.exists(t, t == "1231")`)
17+
require.NoError(t, err)
18+
}
19+
20+
func TestCompileRejectsLegacyNumericLogicalOperand(t *testing.T) {
21+
t.Parallel()
22+
23+
engine, err := NewEngine(NewSchema())
24+
require.NoError(t, err)
25+
26+
_, err = engine.Compile(context.Background(), `pinned && 1`)
27+
require.Error(t, err)
28+
require.Contains(t, err.Error(), "failed to compile filter")
29+
}
30+
31+
func TestCompileRejectsNonBooleanTopLevelConstant(t *testing.T) {
32+
t.Parallel()
33+
34+
engine, err := NewEngine(NewSchema())
35+
require.NoError(t, err)
36+
37+
_, err = engine.Compile(context.Background(), `1`)
38+
require.EqualError(t, err, "filter must evaluate to a boolean value")
39+
}

plugin/filter/ir.go

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -157,3 +157,10 @@ type ContainsPredicate struct {
157157
}
158158

159159
func (*ContainsPredicate) isPredicateExpr() {}
160+
161+
// EqualsPredicate represents t == "value".
162+
type EqualsPredicate struct {
163+
Value string
164+
}
165+
166+
func (*EqualsPredicate) isPredicateExpr() {}

plugin/filter/parser.go

Lines changed: 39 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -16,16 +16,10 @@ func buildCondition(expr *exprv1.Expr, schema Schema) (Condition, error) {
1616
if err != nil {
1717
return nil, err
1818
}
19-
switch v := val.(type) {
20-
case bool:
19+
if v, ok := val.(bool); ok {
2120
return &ConstantCondition{Value: v}, nil
22-
case int64:
23-
return &ConstantCondition{Value: v != 0}, nil
24-
case float64:
25-
return &ConstantCondition{Value: v != 0}, nil
26-
default:
27-
return nil, errors.New("filter must evaluate to a boolean value")
2821
}
22+
return nil, errors.New("filter must evaluate to a boolean value")
2923
case *exprv1.Expr_IdentExpr:
3024
name := v.IdentExpr.GetName()
3125
field, ok := schema.Field(name)
@@ -504,15 +498,51 @@ func extractPredicate(comp *exprv1.Expr_Comprehension, _ Schema) (PredicateExpr,
504498

505499
// Handle different predicate functions
506500
switch predicateCall.Function {
501+
case "_==_":
502+
return buildEqualsPredicate(predicateCall, comp.IterVar)
507503
case "startsWith":
508504
return buildStartsWithPredicate(predicateCall, comp.IterVar)
509505
case "endsWith":
510506
return buildEndsWithPredicate(predicateCall, comp.IterVar)
511507
case "contains":
512508
return buildContainsPredicate(predicateCall, comp.IterVar)
513509
default:
514-
return nil, errors.Errorf("unsupported predicate function %q in comprehension (supported: startsWith, endsWith, contains)", predicateCall.Function)
510+
return nil, errors.Errorf(`unsupported predicate function %q in comprehension (supported: ==, startsWith, endsWith, contains)`, predicateCall.Function)
511+
}
512+
}
513+
514+
// buildEqualsPredicate extracts the value from t == "value".
515+
func buildEqualsPredicate(call *exprv1.Expr_Call, iterVar string) (PredicateExpr, error) {
516+
if len(call.Args) != 2 {
517+
return nil, errors.New("equality predicate expects exactly two arguments")
518+
}
519+
520+
var constExpr *exprv1.Expr
521+
switch {
522+
case isIterVarExpr(call.Args[0], iterVar):
523+
constExpr = call.Args[1]
524+
case isIterVarExpr(call.Args[1], iterVar):
525+
constExpr = call.Args[0]
526+
default:
527+
return nil, errors.Errorf("equality predicate must compare against the iteration variable %q", iterVar)
528+
}
529+
530+
value, err := getConstValue(constExpr)
531+
if err != nil {
532+
return nil, errors.Wrap(err, "equality argument must be a constant string")
533+
}
534+
535+
valueStr, ok := value.(string)
536+
if !ok {
537+
return nil, errors.New("equality argument must be a string")
515538
}
539+
540+
return &EqualsPredicate{Value: valueStr}, nil
541+
}
542+
543+
func isIterVarExpr(expr *exprv1.Expr, iterVar string) bool {
544+
target := expr.GetIdentExpr()
545+
return target != nil && target.GetName() == iterVar
516546
}
517547

518548
// buildStartsWithPredicate extracts the pattern from t.startsWith("prefix").

plugin/filter/render.go

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -480,6 +480,8 @@ func (r *renderer) renderListComprehension(cond *ListComprehensionCondition) (re
480480

481481
// Render based on predicate type
482482
switch pred := cond.Predicate.(type) {
483+
case *EqualsPredicate:
484+
return r.renderTagEquals(field, pred.Value, cond.Kind)
483485
case *StartsWithPredicate:
484486
return r.renderTagStartsWith(field, pred.Prefix, cond.Kind)
485487
case *EndsWithPredicate:
@@ -491,6 +493,22 @@ func (r *renderer) renderListComprehension(cond *ListComprehensionCondition) (re
491493
}
492494
}
493495

496+
// renderTagEquals generates SQL for tags.exists(t, t == "value").
497+
func (r *renderer) renderTagEquals(field Field, value string, _ ComprehensionKind) (renderResult, error) {
498+
arrayExpr := jsonArrayExpr(r.dialect, field)
499+
500+
switch r.dialect {
501+
case DialectSQLite, DialectMySQL:
502+
exactMatch := r.buildJSONArrayLike(arrayExpr, fmt.Sprintf(`%%"%s"%%`, value))
503+
return renderResult{sql: r.wrapWithNullCheck(arrayExpr, exactMatch)}, nil
504+
case DialectPostgres:
505+
exactMatch := fmt.Sprintf("%s @> jsonb_build_array(%s::json)", arrayExpr, r.addArg(fmt.Sprintf(`"%s"`, value)))
506+
return renderResult{sql: r.wrapWithNullCheck(arrayExpr, exactMatch)}, nil
507+
default:
508+
return renderResult{}, errors.Errorf("unsupported dialect %s", r.dialect)
509+
}
510+
}
511+
494512
// renderTagStartsWith generates SQL for tags.exists(t, t.startsWith("prefix")).
495513
func (r *renderer) renderTagStartsWith(field Field, prefix string, _ ComprehensionKind) (renderResult, error) {
496514
arrayExpr := jsonArrayExpr(r.dialect, field)

server/router/mcp/README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ For Streamable HTTP safety, requests with an `Origin` header must be same-origin
4848

4949
| Tool | Description | Required params | Optional params |
5050
|---|---|---|---|
51-
| `list_memos` | List memos || `page_size`, `page`, `state`, `order_by_pinned`, `filter` (CEL) |
51+
| `list_memos` | List memos || `page_size`, `page`, `state`, `order_by_pinned`, `filter` (supported subset of standard CEL syntax) |
5252
| `get_memo` | Get a single memo | `name` ||
5353
| `search_memos` | Full-text search | `query` ||
5454
| `create_memo` | Create a memo | `content` | `visibility` |

server/router/mcp/access.go

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,10 @@ func checkMemoOwnership(memo *store.Memo, userID int32) error {
4141
return nil
4242
}
4343

44+
func hasMemoOwnership(memo *store.Memo, userID int32) bool {
45+
return memo.CreatorID == userID
46+
}
47+
4448
// applyVisibilityFilter restricts find to memos the caller may see.
4549
func applyVisibilityFilter(find *store.FindMemo, userID int32, rowStatus *store.RowStatus) {
4650
if rowStatus != nil && *rowStatus == store.Archived {

server/router/mcp/tools_memo.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -223,7 +223,7 @@ func (s *MCPService) registerMemoTools(mcpSrv *mcpserver.MCPServer) {
223223
mcp.Description("Filter by state: NORMAL (default) or ARCHIVED"),
224224
),
225225
mcp.WithBoolean("order_by_pinned", mcp.Description("When true, pinned memos appear first (default false)")),
226-
mcp.WithString("filter", mcp.Description(`Optional CEL filter, e.g. content.contains("keyword") or tags.exists(t, t == "work")`)),
226+
mcp.WithString("filter", mcp.Description(`Optional CEL filter (supported subset of standard CEL syntax), e.g. content.contains("keyword") or tags.exists(t, t == "work")`)),
227227
), s.handleListMemos)
228228

229229
mcpSrv.AddTool(mcp.NewTool("get_memo",

server/router/mcp/tools_relation.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -142,7 +142,7 @@ func (s *MCPService) handleCreateMemoRelation(ctx context.Context, req mcp.CallT
142142
if srcMemo == nil {
143143
return mcp.NewToolResultError("source memo not found"), nil
144144
}
145-
if err := checkMemoOwnership(srcMemo, userID); err != nil {
145+
if !hasMemoOwnership(srcMemo, userID) {
146146
return mcp.NewToolResultError("permission denied: must own the source memo"), nil
147147
}
148148

@@ -199,7 +199,7 @@ func (s *MCPService) handleDeleteMemoRelation(ctx context.Context, req mcp.CallT
199199
if srcMemo == nil {
200200
return mcp.NewToolResultError("source memo not found"), nil
201201
}
202-
if err := checkMemoOwnership(srcMemo, userID); err != nil {
202+
if !hasMemoOwnership(srcMemo, userID) {
203203
return mcp.NewToolResultError("permission denied: must own the source memo"), nil
204204
}
205205

0 commit comments

Comments
 (0)