Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

import com.google.common.collect.ImmutableList;
import com.google.common.primitives.Ints;
import com.google.common.annotations.VisibleForTesting;
import io.trino.operator.HashGenerator;
import io.trino.operator.JoinOperatorType;
import io.trino.operator.NullSafeHashCompiler;
Expand All @@ -37,6 +38,7 @@
import java.util.List;
import java.util.Optional;
import java.util.OptionalInt;
import java.util.concurrent.atomic.AtomicReference;

import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.base.Preconditions.checkState;
Expand All @@ -59,7 +61,7 @@ public class LookupJoinOperatorFactory
private final JoinProbeFactory joinProbeFactory;
private final Optional<OperatorFactory> outerOperatorFactory;
private final JoinBridgeManager<? extends LookupSourceFactory> joinBridgeManager;
private final OptionalInt totalOperatorsCount;
private final AtomicReference<OptionalInt> totalOperatorsCount;
private final HashGenerator probeHashGenerator;
private final PartitioningSpillerFactory partitioningSpillerFactory;

Expand Down Expand Up @@ -102,7 +104,7 @@ public LookupJoinOperatorFactory(
buildOutputTypes,
lookupSourceFactoryManager));
}
this.totalOperatorsCount = requireNonNull(totalOperatorsCount, "totalOperatorsCount is null");
this.totalOperatorsCount = new AtomicReference<>(requireNonNull(totalOperatorsCount, "totalOperatorsCount is null"));

requireNonNull(probeJoinChannels, "probeJoinChannels is null");
List<Type> hashTypes = probeJoinChannels.stream()
Expand Down Expand Up @@ -176,7 +178,7 @@ public WorkProcessorOperator create(ProcessorContext processorContext, WorkProce
lookupSourceFactory,
joinProbeFactory,
joinBridgeManager::probeOperatorClosed,
totalOperatorsCount,
totalOperatorsCount.get(),
probeHashGenerator,
partitioningSpillerFactory,
processorContext,
Expand All @@ -196,4 +198,16 @@ public LookupJoinOperatorFactory duplicate()
{
return new LookupJoinOperatorFactory(this);
}

public void incrementTotalOperatorsCount()
{
totalOperatorsCount.updateAndGet(current ->
current.isPresent() ? OptionalInt.of(current.getAsInt() + 1) : current);
}

@VisibleForTesting
public OptionalInt getTotalOperatorsCount()
{
return totalOperatorsCount.get();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -746,6 +746,17 @@ private void addLookupOuterDrivers(boolean isOutputDriver, List<OperatorFactory>
operatorFactories.subList(i + 1, operatorFactories.size()).stream()
.map(OperatorFactory::duplicate)
.forEach(newOperators::add);

// each duplicated spilling join operator in the outer driver pipeline will call.
// finishProbeOperator, which requires incrementing the totalOperatorsCount of the corresponding original factory by 1.
// ensure that PartitionedLookupSourceFactory waits for the correct number of operators to complete before releasing resources.
operatorFactories.subList(i + 1, operatorFactories.size()).stream()
.filter(WorkProcessorOperatorAdapter.Factory.class::isInstance)
.map(WorkProcessorOperatorAdapter.Factory.class::cast)
.map(WorkProcessorOperatorAdapter.Factory::getWorkProcessorOperatorFactory)
.filter(LookupJoinOperatorFactory.class::isInstance)
.map(LookupJoinOperatorFactory.class::cast)
.forEach(LookupJoinOperatorFactory::incrementTotalOperatorsCount);

addDriverFactory(false, isOutputDriver, newOperators.build(), OptionalInt.of(1));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1801,4 +1801,81 @@ private static <T> List<T> concat(List<T> initialElements, List<T> moreElements)
{
return ImmutableList.copyOf(Iterables.concat(initialElements, moreElements));
}


/**
* Regression test for the bug where {@code addLookupOuterDrivers} duplicates a spilling
* {@link LookupJoinOperatorFactory} into the outer driver pipeline without incrementing
* {@code totalOperatorsCount}. This caused two failures in fault-tolerant execution when
* {@code AdaptiveReorderPartitionedJoin} flipped a LEFT JOIN to a RIGHT JOIN (LOOKUP_OUTER)
* and the probe pipeline contained another spilling join downstream:
*
* <ul>
* <li>{@code IllegalStateException}: "N+1 probe operators finished out of N declared"
* <li>{@code NullPointerException}: {@code lookupSourceSupplier} is null inside
* {@code SpillAwareLookupSourceProvider.withLease}
* </ul>
*
* The fix stores {@code totalOperatorsCount} in a shared {@link java.util.concurrent.atomic.AtomicReference}
* so that {@code incrementTotalOperatorsCount()} is visible to both the original factory and
* every duplicated factory.
*/
@Test
public void testIncrementTotalOperatorsCountOnDuplicate()
{
int driverCount = 4;
TaskContext taskContext = createTaskContext();

// Build side setup
RowPagesBuilder buildPages = rowPagesBuilder(true, Ints.asList(0), ImmutableList.of(VARCHAR));
BuildSideSetup buildSideSetup = setupBuildSide(
nodePartitioningManager,
false,
taskContext,
buildPages,
Optional.empty(),
false,
SINGLE_STREAM_SPILLER_FACTORY);
JoinBridgeManager<PartitionedLookupSourceFactory> lookupSourceFactoryManager =
buildSideSetup.getLookupSourceFactoryManager();

// Create the spilling join factory with totalOperatorsCount = driverCount
RowPagesBuilder probePages = rowPagesBuilder(true, Ints.asList(0), ImmutableList.of(VARCHAR));
OperatorFactory joinOperatorFactory = spillingJoin(
innerJoin(false, false),
0,
new PlanNodeId("test"),
lookupSourceFactoryManager,
probePages.getTypes(),
Ints.asList(0),
getHashChannelAsInt(probePages),
Optional.empty(),
OptionalInt.of(driverCount),
PARTITIONING_SPILLER_FACTORY,
TYPE_OPERATORS);

// Simulate addLookupOuterDrivers: duplicate then increment.
// Before the fix, the duplicate held the stale value N because OptionalInt was copied by
// value and incrementTotalOperatorsCount only updated the original. After the fix both
// share an AtomicReference so both see N+1.
OperatorFactory duplicatedFactory = joinOperatorFactory.duplicate();

LookupJoinOperatorFactory originalInner = (LookupJoinOperatorFactory)
((io.trino.operator.WorkProcessorOperatorAdapter.Factory) joinOperatorFactory)
.getWorkProcessorOperatorFactory();
LookupJoinOperatorFactory duplicatedInner = (LookupJoinOperatorFactory)
((io.trino.operator.WorkProcessorOperatorAdapter.Factory) duplicatedFactory)
.getWorkProcessorOperatorFactory();

// Before increment: both see driverCount
assertThat(originalInner.getTotalOperatorsCount()).isEqualTo(OptionalInt.of(driverCount));
assertThat(duplicatedInner.getTotalOperatorsCount()).isEqualTo(OptionalInt.of(driverCount));

// After increment: BOTH must see driverCount + 1 (shared AtomicReference)
originalInner.incrementTotalOperatorsCount();
assertThat(originalInner.getTotalOperatorsCount()).isEqualTo(OptionalInt.of(driverCount + 1));
assertThat(duplicatedInner.getTotalOperatorsCount())
.as("duplicated factory must see the incremented count via shared AtomicReference")
.isEqualTo(OptionalInt.of(driverCount + 1));
}
}
Loading