Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -176,29 +176,25 @@ public Void visitLogicalProject(OptExpression optExpression, AggregatePushDownCo
return processChild(optExpression, context);
}

// rewrite
ReplaceColumnRefRewriter rewriter = new ReplaceColumnRefRewriter(project.getColumnRefMap());
context.aggregations.replaceAll((k, v) -> (CallOperator) rewriter.rewrite(v));
context.groupBys.replaceAll((k, v) -> rewriter.rewrite(v));
ColumnRefSet aggUsedColumns = new ColumnRefSet();
context.aggregations.values().forEach(v -> aggUsedColumns.union(v.getUsedColumns()));

if (project.getColumnRefMap().values().stream().allMatch(ScalarOperator::isColumnRef)) {
return processChild(optExpression, context);
}
Map<ColumnRefOperator, ScalarOperator> columnRefMap = project.getColumnRefMap();
Map<ColumnRefOperator, ScalarOperator> aggRewriteMap = columnRefMap;

// handle special functions case-when/if:
// split into groupBys and mock new aggregations by values; no need to save the
// origin predicate, we only check it in the collect phase
for (Map.Entry<ColumnRefOperator, CallOperator> entry : context.aggregations.entrySet()) {
CallOperator aggFn = entry.getValue();
ScalarOperator aggInput = aggFn.getChild(0);
for (Map.Entry<ColumnRefOperator, ScalarOperator> entry : columnRefMap.entrySet()) {
ColumnRefOperator key = entry.getKey();
ScalarOperator value = entry.getValue();

if (!(aggInput instanceof CallOperator)) {
if (!aggUsedColumns.contains(key) || !(value instanceof CallOperator call)) {
continue;
}

CallOperator callInput = (CallOperator) aggInput;
if (aggInput instanceof CaseWhenOperator) {
CaseWhenOperator caseWhen = (CaseWhenOperator) aggInput;
if (call instanceof CaseWhenOperator) {
CaseWhenOperator caseWhen = (CaseWhenOperator) value;
for (ScalarOperator condition : caseWhen.getAllConditionClause()) {
condition.getUsedColumns().getStream().map(factory::getColumnRef)
.forEach(v -> context.groupBys.put(v, v));
Expand All @@ -218,21 +214,40 @@ public Void visitLogicalProject(OptExpression optExpression, AggregatePushDownCo
CaseWhenOperator newCaseWhen = new CaseWhenOperator(caseWhen.getType(), null,
caseWhen.hasElse() ? caseWhen.getElseClause() : null, newWhenThen);

// replace origin
aggFn.setChild(0, newCaseWhen);
} else if (callInput.getFunction() != null &&
FunctionSet.IF.equals(callInput.getFunction().getFunctionName().getFunction())) {
if (aggInput.getChildren().stream().skip(1).anyMatch(c -> c.isConstant() && !c.isConstantNull())) {
if (aggRewriteMap == columnRefMap) {
aggRewriteMap = Maps.newHashMap(columnRefMap);
}
aggRewriteMap.put(key, newCaseWhen);
} else if (call.getFunction() != null &&
FunctionSet.IF.equals(call.getFunction().getFunctionName().getFunction())) {
if (call.getChildren().stream().skip(1).anyMatch(c -> c.isConstant() && !c.isConstantNull())) {
// forbidden push down
return visit(optExpression, context);
}

aggInput.getChild(0).getUsedColumns().getStream().map(factory::getColumnRef)
call.getChild(0).getUsedColumns().getStream().map(factory::getColumnRef)
.forEach(v -> context.groupBys.put(v, v));
aggInput.setChild(0, ConstantOperator.createBoolean(false));

CallOperator newIf = new CallOperator(call.getFnName(), call.getType(), Lists.newArrayList(call.getArguments()),
call.getFunction());
newIf.setChild(0, ConstantOperator.createBoolean(false));

if (aggRewriteMap == columnRefMap) {
aggRewriteMap = Maps.newHashMap(columnRefMap);
}
aggRewriteMap.put(key, newIf);
}
}

ReplaceColumnRefRewriter rewriter = new ReplaceColumnRefRewriter(aggRewriteMap);
context.aggregations.replaceAll((k, v) -> (CallOperator) rewriter.rewrite(v));
if (aggRewriteMap != columnRefMap) {
ReplaceColumnRefRewriter originalRewriter = new ReplaceColumnRefRewriter(columnRefMap);
context.groupBys.replaceAll((k, v) -> originalRewriter.rewrite(v));
} else {
context.groupBys.replaceAll((k, v) -> rewriter.rewrite(v));
}

// check has constant aggregate, forbidden
if (!context.aggregations.isEmpty() &&
context.aggregations.values().stream().allMatch(ScalarOperator::isConstant)) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -192,7 +192,10 @@ private void rewriteProject(AggregatePushDownContext context,
&& FunctionSet.IF.equals(((CallOperator) aggExpr).getFunction().getFunctionName().getFunction());

if (isCaseWhen) {
CaseWhenOperator caseWhen = (CaseWhenOperator) aggExpr;
// Clone to avoid mutating the shared object in originProjectMap/project's columnRefMap.
// Without clone, when multiple aggregations reference the same CASE WHEN column,
// the first aggregation's setThenClause/setElseClause corrupts the shared operator.
CaseWhenOperator caseWhen = (CaseWhenOperator) aggExpr.clone();
for (ScalarOperator condition : caseWhen.getAllConditionClause()) {
condition.getUsedColumns().getStream().map(factory::getColumnRef)
.forEach(v -> context.groupBys.put(v, v));
Expand Down Expand Up @@ -221,7 +224,8 @@ private void rewriteProject(AggregatePushDownContext context,
context.aggregations.remove(key);
originProjectMap.put(key, new CaseWhenOperator(key.getType(), caseWhen));
} else if (isIfFn) {
CallOperator ifFn = (CallOperator) aggExpr;
// Clone to avoid mutating the shared object (same reason as CaseWhen above).
CallOperator ifFn = (CallOperator) aggExpr.clone();
ifFn.getChild(0).getUsedColumns().getStream().map(factory::getColumnRef)
.forEach(v -> context.groupBys.put(v, v));

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,6 @@ public void testPushDownPreAggEnableOnBroadcastJoin() {
}
}


@Test
public void testPushDownDistinctAggBelowWindow()
throws Exception {
Expand Down Expand Up @@ -294,4 +293,91 @@ public void testPruneDistinctWindow() throws Exception {
" args: DECIMAL128; result: DECIMAL128(38,2); args nullable: true; result nullable: true], ]");
assertContains(plan, "2:AGGREGATE (update finalize)");
}

// Plans a SUM over a CASE WHEN column (gval) that wraps another CASE WHEN column
// (cval) defined two CTEs earlier, with a LEFT JOIN on t1 in between, and checks
// that the verbose plan contains the expected aggregate/project/scan fragment.
@Test
public void testPushDownWithNestedCaseWhenIfs() throws Exception {
// cte1 builds cval from a CASE WHEN on t1b; cte2 joins it with t1; cte3 wraps cval
// in a second CASE WHEN on cat; the final query sums gval grouped by fk.
String sql = """
WITH cte1 AS (
SELECT
t.t1d AS fk,
t.t1a AS cat,
CASE WHEN t.t1b = 1 THEN t.t1e ELSE t.t1f END AS cval
FROM test_all_type t
),
cte2 AS (
SELECT a.cval, a.fk, a.cat
FROM cte1 a
LEFT JOIN t1 ON a.fk = t1.v4
),
cte3 AS (
SELECT CASE WHEN c.cat THEN c.cval ELSE NULL END gval, c.fk
FROM cte2 c
)
SELECT SUM(gval)
FROM cte3
GROUP BY fk;
""";
String plan = getVerboseExplain(sql);
// Expected fragment: node 2 is an aggregate with two sum() calls (over column 21,
// the cast of t1e to DOUBLE, and over t1f) grouped by t1a, t1b and t1d, fed by
// project node 1 over the test_all_type scan at node 0.
// NOTE(review): presumably this is the aggregation pushed below the join — confirm
// against the full plan; the join itself is not part of the asserted fragment.
assertContains(plan, " 2:AGGREGATE (update finalize)\n" +
" | aggregate: sum[([21: cast, DOUBLE, true]); args: DOUBLE; result: DOUBLE; args nullable: true; result" +
" nullable: true], sum[([6: t1f, DOUBLE, true]); args: DOUBLE; result: DOUBLE; args nullable: true; result" +
" nullable: true]\n" +
" | group by: [1: t1a, VARCHAR, true], [2: t1b, SMALLINT, true], [4: t1d, BIGINT, true]\n" +
" | cardinality: 1\n" +
" | \n" +
" 1:Project\n" +
" | output columns:\n" +
" | 1 <-> [1: t1a, VARCHAR, true]\n" +
" | 2 <-> [2: t1b, SMALLINT, true]\n" +
" | 4 <-> [4: t1d, BIGINT, true]\n" +
" | 6 <-> [6: t1f, DOUBLE, true]\n" +
" | 21 <-> cast([5: t1e, FLOAT, true] as DOUBLE)\n" +
" | cardinality: 1\n" +
" | \n" +
" 0:OlapScanNode\n" +
" table: test_all_type, rollup: test_all_type\n" +
" preAggregation: on\n" +
" partitionsRatio=1/1, tabletsRatio=3/3\n" +
" tabletList=10140,10142,10144\n" +
" actualRows=0, avgRowSize=6.0\n" +
" cardinality: 1");

}

@Test
public void testRewriterSharedMutationWithCaseWhen() throws Exception {
    // Regression scenario for PushDownAggregateRewriter.rewriteProject(): SUM and MIN
    // both read the same CASE WHEN output column. The rewriter used to modify that
    // shared CaseWhenOperator in place (setThenClause()), so whichever aggregation was
    // handled second saw the already pushed-down column refs rather than the original
    // columns.
    String query = "SELECT SUM(sub.cval), MIN(sub.cval), sub.fk " +
            "FROM ( " +
            " SELECT t1d AS fk, " +
            " CASE WHEN t1b = 1 THEN t1e ELSE NULL END AS cval " +
            " FROM test_all_type " +
            ") sub " +
            "JOIN t0 ON sub.fk = t0.v1 " +
            "GROUP BY sub.fk";
    String explain = getVerboseExplain(query);

    // Both aggregate functions must still appear in the resulting plan.
    for (String aggName : new String[] {"sum", "min"}) {
        assertContains(explain, aggName);
    }
}

@Test
public void testRewriterSharedMutationWithIf() throws Exception {
    // Regression scenario for PushDownAggregateRewriter.rewriteProject() on the IF
    // path: SUM and MIN both read the same IF(...) output column, and the rewriter
    // used to modify that shared CallOperator in place via setChild().
    String query = "SELECT SUM(sub.cval), MIN(sub.cval), sub.fk " +
            "FROM ( " +
            " SELECT t1d AS fk, " +
            " IF(t1b = 1, t1e, NULL) AS cval " +
            " FROM test_all_type " +
            ") sub " +
            "JOIN t0 ON sub.fk = t0.v1 " +
            "GROUP BY sub.fk";
    String explain = getVerboseExplain(query);

    // Both aggregate functions must still appear in the resulting plan.
    for (String aggName : new String[] {"sum", "min"}) {
        assertContains(explain, aggName);
    }
}
}
Loading