Skip to content

Commit 4d6fb1d

Browse files
Fix non-Var GROUP BY expression decomposition in distributed aggregates (#8527)
DESCRIPTION: process non-Var GROUP BY expressions correctly in GROUP BY on distributed tables.. When Citus distributes a query whose target list mixes a non-Var GROUP BY expression (e.g., f(col)) with an aggregate (e.g., sum(val)), WorkerAggregateWalker would recurse into the GROUP BY expression and decompose it into bare Var nodes. The worker query then contained a bare column reference not in GROUP BY, causing PostgreSQL to reject it. Add non-Var GROUP BY expression recognition to both WorkerAggregateWalker and MasterAggregateMutator. Each walker now checks incoming nodes against the groupByTargetEntryList via equal() and emits matched expressions as atomic units rather than recursing into their children. A hasNonVarGrouping fast-path flag avoids the matching loop when all GROUP BY expressions are simple column references. Co-authored-by: Colm <colm.mchugh@gmail.com>
1 parent 029f381 commit 4d6fb1d

File tree

3 files changed

+543
-2
lines changed

3 files changed

+543
-2
lines changed

src/backend/distributed/planner/multi_logical_optimizer.c

Lines changed: 115 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -73,13 +73,17 @@ typedef struct MasterAggregateWalkerContext
7373
{
7474
const ExtendedOpNodeProperties *extendedOpNodeProperties;
7575
AttrNumber columnId;
76+
List *groupByTargetEntryList;
77+
bool haveNonVarGrouping;
7678
} MasterAggregateWalkerContext;
7779

7880
typedef struct WorkerAggregateWalkerContext
7981
{
8082
const ExtendedOpNodeProperties *extendedOpNodeProperties;
8183
List *expressionList;
8284
bool createGroupByClause;
85+
List *groupByTargetEntryList;
86+
bool haveNonVarGrouping;
8387
} WorkerAggregateWalkerContext;
8488

8589

@@ -227,11 +231,14 @@ static MultiExtendedOp * WorkerExtendedOpNode(MultiExtendedOp *originalOpNode,
227231
static void ProcessTargetListForWorkerQuery(List *targetEntryList,
228232
ExtendedOpNodeProperties *
229233
extendedOpNodeProperties,
234+
List *groupClauseList,
230235
QueryTargetList *queryTargetList,
231236
QueryGroupClause *queryGroupClause);
232237
static void ProcessHavingClauseForWorkerQuery(Node *havingQual,
233238
ExtendedOpNodeProperties *
234239
extendedOpNodeProperties,
240+
List *groupClauseList,
241+
List *targetEntryList,
235242
Node **workerHavingQual,
236243
QueryTargetList *queryTargetList,
237244
QueryGroupClause *queryGroupClause);
@@ -324,6 +331,7 @@ static List * WorkerSortClauseList(Node *limitCount,
324331
List *groupClauseList, List *sortClauseList,
325332
OrderByLimitReference orderByLimitReference);
326333
static bool CanPushDownLimitApproximate(List *sortClauseList, List *targetList);
334+
static bool HaveNonVarGrouping(List *groupByTargetEntryList);
327335
static bool HasOrderByAggregate(List *sortClauseList, List *targetList);
328336
static bool HasOrderByNonCommutativeAggregate(List *sortClauseList, List *targetList);
329337
static bool HasOrderByComplexExpression(List *sortClauseList, List *targetList);
@@ -1424,9 +1432,22 @@ MasterExtendedOpNode(MultiExtendedOp *originalOpNode,
14241432
List *newGroupClauseList = NIL;
14251433
Node *originalHavingQual = originalOpNode->havingQual;
14261434
Node *newHavingQual = NULL;
1435+
1436+
/*
1437+
* Build GROUP BY target entry list for the master-side mutator so it
1438+
* can recognize GROUP BY subexpressions and map them to a single
1439+
* worker output column instead of recursing into their children.
1440+
* This must match what WorkerAggregateWalker does on the worker side.
1441+
*/
1442+
List *groupByTargetEntryList = GroupTargetEntryList(
1443+
originalOpNode->groupClauseList, targetEntryList);
1444+
bool haveNonVarGrouping = HaveNonVarGrouping(groupByTargetEntryList);
1445+
14271446
MasterAggregateWalkerContext walkerContext = {
14281447
.extendedOpNodeProperties = extendedOpNodeProperties,
14291448
.columnId = 1,
1449+
.groupByTargetEntryList = groupByTargetEntryList,
1450+
.haveNonVarGrouping = haveNonVarGrouping,
14301451
};
14311452

14321453
/* iterate over original target entries */
@@ -1573,6 +1594,34 @@ MasterAggregateMutator(Node *originalNode, MasterAggregateWalkerContext *walkerC
15731594
}
15741595
else
15751596
{
1597+
/*
1598+
* If the current node matches a non-Var GROUP BY expression, map it
1599+
* to a single worker output column reference. The worker emits this
1600+
* expression intact, so the master must consume exactly one column
1601+
* for it rather than recursing into its children.
1602+
*/
1603+
if (walkerContext->haveNonVarGrouping)
1604+
{
1605+
TargetEntry *groupByTargetEntry = NULL;
1606+
foreach_declared_ptr(groupByTargetEntry,
1607+
walkerContext->groupByTargetEntryList)
1608+
{
1609+
if (equal(originalNode, groupByTargetEntry->expr))
1610+
{
1611+
Oid nodeType = exprType(originalNode);
1612+
int32 nodeTypmod = exprTypmod(originalNode);
1613+
Oid nodeColl = exprCollation(originalNode);
1614+
Var *column = makeVar(masterTableId,
1615+
walkerContext->columnId,
1616+
nodeType, nodeTypmod,
1617+
nodeColl, 0);
1618+
walkerContext->columnId++;
1619+
newNode = (Node *) column;
1620+
return newNode;
1621+
}
1622+
}
1623+
}
1624+
15761625
newNode = expression_tree_mutator(originalNode, MasterAggregateMutator,
15771626
(void *) walkerContext);
15781627
}
@@ -2407,9 +2456,12 @@ WorkerExtendedOpNode(MultiExtendedOp *originalOpNode,
24072456

24082457
/* process each part of the query in order to generate the worker query's parts */
24092458
ProcessTargetListForWorkerQuery(originalTargetEntryList, extendedOpNodeProperties,
2459+
originalGroupClauseList,
24102460
&queryTargetList, &queryGroupClause);
24112461

24122462
ProcessHavingClauseForWorkerQuery(originalHavingQual, extendedOpNodeProperties,
2463+
originalGroupClauseList,
2464+
originalTargetEntryList,
24132465
&queryHavingQual, &queryTargetList,
24142466
&queryGroupClause);
24152467

@@ -2522,17 +2574,30 @@ WorkerExtendedOpNode(MultiExtendedOp *originalOpNode,
25222574
* list of worker extended operator. This approach guarantees the distinctness
25232575
* in the worker queries.
25242576
*
2525-
* inputs: targetEntryList, extendedOpNodeProperties
2577+
* inputs: targetEntryList, extendedOpNodeProperties, groupClauseList
25262578
* outputs: queryTargetList, queryGroupClause
25272579
*/
25282580
static void
25292581
ProcessTargetListForWorkerQuery(List *targetEntryList,
25302582
ExtendedOpNodeProperties *extendedOpNodeProperties,
2583+
List *groupClauseList,
25312584
QueryTargetList *queryTargetList,
25322585
QueryGroupClause *queryGroupClause)
25332586
{
2587+
/*
2588+
* Build the list of GROUP BY target entries and check whether any are
2589+
* non-Var expressions. WorkerAggregateWalker needs this so it can
2590+
* recognize GROUP BY subexpressions inside complex target entries and
2591+
* emit them intact instead of decomposing them into bare Var references.
2592+
*/
2593+
List *groupByTargetEntryList = GroupTargetEntryList(
2594+
groupClauseList, targetEntryList);
2595+
bool haveNonVarGrouping = HaveNonVarGrouping(groupByTargetEntryList);
2596+
25342597
WorkerAggregateWalkerContext workerAggContext = {
25352598
.extendedOpNodeProperties = extendedOpNodeProperties,
2599+
.groupByTargetEntryList = groupByTargetEntryList,
2600+
.haveNonVarGrouping = haveNonVarGrouping,
25362601
};
25372602

25382603
/* iterate over original target entries */
@@ -2578,12 +2643,14 @@ ProcessTargetListForWorkerQuery(List *targetEntryList,
25782643
* having clause is safe to pushdown to the workers, workerHavingQual is set to
25792644
* be the original having clause.
25802645
*
2581-
* inputs: originalHavingQual, extendedOpNodeProperties
2646+
* inputs: originalHavingQual, extendedOpNodeProperties, groupClauseList, targetEntryList
25822647
* outputs: workerHavingQual, queryTargetList, queryGroupClause
25832648
*/
25842649
static void
25852650
ProcessHavingClauseForWorkerQuery(Node *originalHavingQual,
25862651
ExtendedOpNodeProperties *extendedOpNodeProperties,
2652+
List *groupClauseList,
2653+
List *targetEntryList,
25872654
Node **workerHavingQual,
25882655
QueryTargetList *queryTargetList,
25892656
QueryGroupClause *queryGroupClause)
@@ -2619,8 +2686,12 @@ ProcessHavingClauseForWorkerQuery(Node *originalHavingQual,
26192686
* If the GROUP BY or PARTITION BY is not on the distribution column
26202687
* then we need to combine the aggregates in the HAVING across shards.
26212688
*/
2689+
List *groupByTargetEntryList = GroupTargetEntryList(
2690+
groupClauseList, targetEntryList);
26222691
WorkerAggregateWalkerContext workerAggContext = {
26232692
.extendedOpNodeProperties = extendedOpNodeProperties,
2693+
.groupByTargetEntryList = groupByTargetEntryList,
2694+
.haveNonVarGrouping = HaveNonVarGrouping(groupByTargetEntryList),
26242695
};
26252696

26262697
WorkerAggregateWalker(originalHavingQual, &workerAggContext);
@@ -3047,6 +3118,29 @@ WorkerAggregateWalker(Node *node, WorkerAggregateWalkerContext *walkerContext)
30473118
}
30483119
else
30493120
{
3121+
/*
3122+
* If the GROUP BY contains non-Var expressions, check whether the
3123+
* current node matches one of them. If so, emit it as-is rather
3124+
* than descending into its children. Without this, a GROUP BY
3125+
* expression like f(col) would be decomposed into the bare col
3126+
* Var, which then fails on the worker because col is not in the
3127+
* GROUP BY clause.
3128+
*/
3129+
if (walkerContext->haveNonVarGrouping)
3130+
{
3131+
TargetEntry *groupByTargetEntry = NULL;
3132+
foreach_declared_ptr(groupByTargetEntry,
3133+
walkerContext->groupByTargetEntryList)
3134+
{
3135+
if (equal(node, groupByTargetEntry->expr))
3136+
{
3137+
walkerContext->expressionList =
3138+
lappend(walkerContext->expressionList, node);
3139+
return false;
3140+
}
3141+
}
3142+
}
3143+
30503144
walkerResult = expression_tree_walker(node, WorkerAggregateWalker,
30513145
(void *) walkerContext);
30523146
}
@@ -4585,6 +4679,25 @@ SubqueryMultiTableList(MultiNode *multiNode)
45854679
}
45864680

45874681

4682+
/*
4683+
* HaveNonVarGrouping returns true if any GROUP BY expression in the given
4684+
* target entry list is not a simple Var (column reference).
4685+
*/
4686+
static bool
4687+
HaveNonVarGrouping(List *groupByTargetEntryList)
4688+
{
4689+
TargetEntry *gte = NULL;
4690+
foreach_declared_ptr(gte, groupByTargetEntryList)
4691+
{
4692+
if (!IsA(gte->expr, Var))
4693+
{
4694+
return true;
4695+
}
4696+
}
4697+
return false;
4698+
}
4699+
4700+
45884701
/*
45894702
* GroupTargetEntryList walks over group clauses in the given list, finds
45904703
* matching target entries and return them in a new list.

0 commit comments

Comments
 (0)