Skip to content

Commit 533a94b

Browse files
fix: async engine side-effect column propagation and collision resolution (#509)
* fix: async engine side-effect column propagation and collision resolution

  ExecutionGraph.set_side_effect() now uses first-writer-wins instead of last-writer-wins, matching sync engine semantics where earlier consumers see the first producer's value. This prevents false DAGCircularDependencyError when multiple generators declare the same side-effect column at different pipeline stages.

  AsyncTaskScheduler now includes side-effect columns in _instance_to_columns so their values are written to the RowGroupBufferManager and available to downstream prompt templates.

  Fixes #508

* fix: separate side-effect columns from completion tracking in async scheduler

  Side-effect columns added to _instance_to_columns caused KeyError in CompletionTracker._validate_strategy() because they are not registered in the execution graph. Split into _instance_to_write_columns (buffer writes, includes side-effects) and _instance_to_columns (completion tracking, real columns only).

* fix: warn on side-effect collision and clarify scheduler column maps

  Log a warning when multiple producers register the same side-effect column (first-writer-wins still applies). Rename _instance_to_columns and _instance_to_write_columns per review feedback for clarity.

* fix: raise ConfigCompilationError on duplicate side-effect producers

  Replace first-writer-wins collision handling with a hard error. Each side-effect column must have exactly one producer; duplicates are a configuration issue to be fixed at the source.

* fix: reject duplicate side-effect producers in sync DAG path

  Mirror the async path check: raise ConfigCompilationError when two custom columns declare the same side-effect column name during topological sort.
1 parent aee3d3f commit 533a94b

6 files changed

Lines changed: 142 additions & 22 deletions

File tree

packages/data-designer-engine/src/data_designer/engine/dataset_builders/async_scheduler.py

Lines changed: 30 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -131,11 +131,24 @@ def __init__(
131131
self._disable_early_shutdown = disable_early_shutdown
132132
self._early_shutdown = False
133133

134-
# Multi-column dedup: group output columns by generator identity
135-
instance_to_columns: dict[int, list[str]] = {}
134+
# Multi-column dedup: group output columns by generator identity.
135+
# _gen_instance_to_columns holds only real (graph-registered) columns
136+
# and is used for completion tracking.
137+
# _gen_instance_to_columns_including_side_effects extends that with
138+
# side-effect columns for buffer writes only.
139+
gen_instance_to_columns: dict[int, list[str]] = {}
136140
for col, gen in generators.items():
137-
instance_to_columns.setdefault(id(gen), []).append(col)
138-
self._instance_to_columns = instance_to_columns
141+
gen_instance_to_columns.setdefault(id(gen), []).append(col)
142+
self._gen_instance_to_columns = gen_instance_to_columns
143+
144+
seen_cols: set[str] = {col for col in generators}
145+
gen_instance_to_columns_incl_se: dict[int, list[str]] = {k: list(v) for k, v in gen_instance_to_columns.items()}
146+
for col, gen in generators.items():
147+
for side_effect_col in getattr(gen.config, "side_effect_columns", []):
148+
if side_effect_col not in seen_cols:
149+
gen_instance_to_columns_incl_se.setdefault(id(gen), []).append(side_effect_col)
150+
seen_cols.add(side_effect_col)
151+
self._gen_instance_to_columns_including_side_effects = gen_instance_to_columns_incl_se
139152

140153
# Stateful generator tracking: instance_id → asyncio.Lock
141154
self._stateful_locks: dict[int, asyncio.Lock] = {}
@@ -356,7 +369,7 @@ async def _salvage_rounds(
356369
self._dispatched.discard(
357370
Task(column=task.column, row_group=task.row_group, row_index=None, task_type="batch")
358371
)
359-
for sibling in self._instance_to_columns.get(gid, []):
372+
for sibling in self._gen_instance_to_columns.get(gid, []):
360373
if sibling != task.column:
361374
self._dispatched.discard(
362375
Task(column=sibling, row_group=task.row_group, row_index=None, task_type="from_scratch")
@@ -377,7 +390,7 @@ async def _salvage_rounds(
377390
)
378391
# Re-mark sibling columns as dispatched to mirror _dispatch_seeds
379392
# and prevent _drain_frontier from re-dispatching them.
380-
for sibling in self._instance_to_columns.get(gid, []):
393+
for sibling in self._gen_instance_to_columns.get(gid, []):
381394
if sibling != task.column:
382395
self._dispatched.add(
383396
Task(column=sibling, row_group=task.row_group, row_index=None, task_type="from_scratch")
@@ -620,7 +633,7 @@ async def _dispatch_seeds(self, rg_id: int, rg_size: int) -> None:
620633
self._dispatched.add(task)
621634
self._dispatched.add(batch_alias)
622635
# Also mark all sibling output columns as dispatched (multi-column dedup)
623-
for sibling_col in self._instance_to_columns.get(gid, []):
636+
for sibling_col in self._gen_instance_to_columns.get(gid, []):
624637
if sibling_col != col:
625638
self._dispatched.add(
626639
Task(column=sibling_col, row_group=rg_id, row_index=None, task_type="from_scratch")
@@ -665,7 +678,7 @@ async def _execute_task_inner_impl(self, task: Task) -> None:
665678
trace.dispatched_at = time.perf_counter()
666679

667680
generator = self._generators[task.column]
668-
output_cols = self._instance_to_columns.get(id(generator), [task.column])
681+
output_cols = self._gen_instance_to_columns.get(id(generator), [task.column])
669682
retryable = False
670683
# When True, skip removing from _dispatched so the task isn't re-dispatched
671684
# from the frontier (it was never completed, so it stays in the frontier).
@@ -765,10 +778,10 @@ async def _run_from_scratch(self, task: Task, generator: ColumnGenerator) -> Any
765778
else:
766779
result_df = await generator.agenerate(lazy.pd.DataFrame())
767780

768-
# Write results to buffer
781+
# Write results to buffer (include side-effect columns)
769782
if self._buffer_manager is not None:
770-
output_cols = self._instance_to_columns.get(id(generator), [task.column])
771-
for col in output_cols:
783+
write_cols = self._gen_instance_to_columns_including_side_effects.get(id(generator), [task.column])
784+
for col in write_cols:
772785
if col in result_df.columns:
773786
values = result_df[col].tolist()
774787
self._buffer_manager.update_batch(task.row_group, col, values)
@@ -791,10 +804,10 @@ async def _run_cell(self, task: Task, generator: ColumnGenerator) -> Any:
791804

792805
result = await generator.agenerate(row_data)
793806

794-
# Write back to buffer
807+
# Write back to buffer (include side-effect columns)
795808
if self._buffer_manager is not None and not self._tracker.is_dropped(task.row_group, task.row_index):
796-
output_cols = self._instance_to_columns.get(id(generator), [task.column])
797-
for col in output_cols:
809+
write_cols = self._gen_instance_to_columns_including_side_effects.get(id(generator), [task.column])
810+
for col in write_cols:
798811
if col in result:
799812
self._buffer_manager.update_cell(task.row_group, task.row_index, col, result[col])
800813

@@ -815,9 +828,9 @@ async def _run_batch(self, task: Task, generator: ColumnGenerator) -> Any:
815828

816829
result_df = await generator.agenerate(batch_df)
817830

818-
# Merge result columns back to buffer
831+
# Merge result columns back to buffer (include side-effect columns)
819832
if self._buffer_manager is not None:
820-
output_cols = self._instance_to_columns.get(id(generator), [task.column])
833+
write_cols = self._gen_instance_to_columns_including_side_effects.get(id(generator), [task.column])
821834
active_rows = rg_size - len(pre_dropped)
822835
if len(result_df) != active_rows:
823836
raise ValueError(
@@ -830,7 +843,7 @@ async def _run_batch(self, task: Task, generator: ColumnGenerator) -> Any:
830843
continue
831844
# Skip writing to rows dropped by concurrent tasks during the await
832845
if not self._buffer_manager.is_dropped(task.row_group, ri):
833-
for col in output_cols:
846+
for col in write_cols:
834847
if col in result_df.columns:
835848
self._buffer_manager.update_cell(task.row_group, ri, col, result_df.iloc[result_idx][col])
836849
result_idx += 1

packages/data-designer-engine/src/data_designer/engine/dataset_builders/utils/dag.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
import data_designer.lazy_heavy_imports as lazy
99
from data_designer.config.column_types import ColumnConfigT
1010
from data_designer.engine.column_generators.utils.generator_classification import column_type_used_in_execution_dag
11-
from data_designer.engine.dataset_builders.utils.errors import DAGCircularDependencyError
11+
from data_designer.engine.dataset_builders.utils.errors import ConfigCompilationError, DAGCircularDependencyError
1212
from data_designer.logging import LOG_INDENT
1313

1414
logger = logging.getLogger(__name__)
@@ -29,6 +29,18 @@ def topologically_sort_column_configs(column_configs: list[ColumnConfigT]) -> li
2929

3030
side_effect_dict = {n: list(c.side_effect_columns) for n, c in dag_column_config_dict.items()}
3131

32+
side_effect_to_producer: dict[str, str] = {}
33+
for producer, cols in side_effect_dict.items():
34+
for col in cols:
35+
existing = side_effect_to_producer.get(col)
36+
if existing is not None and existing != producer:
37+
raise ConfigCompilationError(
38+
f"Side-effect column {col!r} is already produced by {existing!r}; "
39+
f"cannot register a second producer {producer!r}. "
40+
f"Use distinct side-effect column names for each pipeline stage."
41+
)
42+
side_effect_to_producer[col] = producer
43+
3244
logger.info("⛓️ Sorting column configs into a Directed Acyclic Graph")
3345
for name, col in dag_column_config_dict.items():
3446
dag.add_node(name)

packages/data-designer-engine/src/data_designer/engine/dataset_builders/utils/execution_graph.py

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
from __future__ import annotations
55

6+
import logging
67
import math
78
from collections import deque
89

@@ -11,9 +12,11 @@
1112
DatasetBuilderColumnConfigT,
1213
MultiColumnConfig,
1314
)
14-
from data_designer.engine.dataset_builders.utils.errors import DAGCircularDependencyError
15+
from data_designer.engine.dataset_builders.utils.errors import ConfigCompilationError, DAGCircularDependencyError
1516
from data_designer.engine.dataset_builders.utils.task_model import SliceRef
1617

18+
logger = logging.getLogger(__name__)
19+
1720

1821
class ExecutionGraph:
1922
"""Column-level static execution graph built from column configs.
@@ -105,7 +108,19 @@ def add_edge(self, upstream: str, downstream: str) -> None:
105108
self._downstream.setdefault(upstream, set()).add(downstream)
106109

107110
def set_side_effect(self, side_effect_col: str, producer: str) -> None:
108-
"""Map a side-effect column name to its producing column."""
111+
"""Map a side-effect column name to its producing column.
112+
113+
Each side-effect column must have exactly one producer. Duplicate
114+
registrations from a different producer are a configuration error -
115+
use distinct column names for each pipeline stage instead.
116+
"""
117+
existing = self._side_effect_map.get(side_effect_col)
118+
if existing is not None and existing != producer:
119+
raise ConfigCompilationError(
120+
f"Side-effect column {side_effect_col!r} is already produced by {existing!r}; "
121+
f"cannot register a second producer {producer!r}. "
122+
f"Use distinct side-effect column names for each pipeline stage."
123+
)
109124
self._side_effect_map[side_effect_col] = producer
110125

111126
def resolve_side_effect(self, column: str) -> str:

packages/data-designer-engine/tests/engine/dataset_builders/test_async_scheduler.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1432,6 +1432,51 @@ async def test_scheduler_rg_semaphore_deadlock_with_transient_failures() -> None
14321432
assert tracker.is_row_group_complete(1, 2, ["seed", "col"])
14331433

14341434

1435+
def test_side_effect_columns_separated_from_completion_tracking() -> None:
1436+
"""Side-effect columns must appear in _gen_instance_to_columns_including_side_effects
1437+
(buffer writes) but NOT in _gen_instance_to_columns (completion tracking), because
1438+
they are not registered in the execution graph and would cause KeyError in
1439+
CompletionTracker.
1440+
"""
1441+
graph = ExecutionGraph()
1442+
graph.add_column("seed", GenerationStrategy.FULL_COLUMN)
1443+
graph.add_column("primary", GenerationStrategy.CELL_BY_CELL)
1444+
graph.add_edge(upstream="seed", downstream="primary")
1445+
1446+
row_groups = [(0, 2)]
1447+
tracker = CompletionTracker.with_graph(graph, row_groups)
1448+
1449+
provider = _mock_provider()
1450+
seed_gen = MockSeedGenerator(config=_expr_config("seed"), resource_provider=provider)
1451+
cell_gen = MockCellGenerator(config=_expr_config("primary"), resource_provider=provider)
1452+
# Replace the config with a mock that reports side-effect columns.
1453+
mock_config = MagicMock()
1454+
mock_config.side_effect_columns = ["side_a", "side_b"]
1455+
object.__setattr__(cell_gen, "_config", mock_config)
1456+
1457+
generators: dict[str, ColumnGenerator] = {"seed": seed_gen, "primary": cell_gen}
1458+
1459+
scheduler = AsyncTaskScheduler(
1460+
generators=generators,
1461+
graph=graph,
1462+
tracker=tracker,
1463+
row_groups=row_groups,
1464+
)
1465+
1466+
cell_id = id(cell_gen)
1467+
1468+
# Completion tracking dict: only real columns
1469+
assert "side_a" not in scheduler._gen_instance_to_columns.get(cell_id, [])
1470+
assert "side_b" not in scheduler._gen_instance_to_columns.get(cell_id, [])
1471+
assert "primary" in scheduler._gen_instance_to_columns.get(cell_id, [])
1472+
1473+
# Buffer write dict: includes side-effect columns
1474+
write_cols = scheduler._gen_instance_to_columns_including_side_effects.get(cell_id, [])
1475+
assert "primary" in write_cols
1476+
assert "side_a" in write_cols
1477+
assert "side_b" in write_cols
1478+
1479+
14351480
# -- TrackingSemaphore tests ---------------------------------------------------
14361481

14371482

packages/data-designer-engine/tests/engine/dataset_builders/utils/test_dag.py

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,12 @@
11
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
22
# SPDX-License-Identifier: Apache-2.0
33

4+
from typing import Any
5+
46
import pytest
57

68
from data_designer.config.column_configs import (
9+
CustomColumnConfig,
710
ExpressionColumnConfig,
811
LLMCodeColumnConfig,
912
LLMJudgeColumnConfig,
@@ -13,12 +16,13 @@
1316
ValidationColumnConfig,
1417
)
1518
from data_designer.config.column_types import DataDesignerColumnType
19+
from data_designer.config.custom_column import custom_column_generator
1620
from data_designer.config.sampler_params import SamplerType
1721
from data_designer.config.utils.code_lang import CodeLang
1822
from data_designer.config.validator_params import CodeValidatorParams
1923
from data_designer.engine.dataset_builders.multi_column_configs import SamplerMultiColumnConfig
2024
from data_designer.engine.dataset_builders.utils.dag import topologically_sort_column_configs
21-
from data_designer.engine.dataset_builders.utils.errors import DAGCircularDependencyError
25+
from data_designer.engine.dataset_builders.utils.errors import ConfigCompilationError, DAGCircularDependencyError
2226

2327
MODEL_ALIAS = "stub-model-alias"
2428

@@ -111,3 +115,23 @@ def test_circular_dependencies():
111115
)
112116
with pytest.raises(DAGCircularDependencyError, match="cyclic dependencies"):
113117
topologically_sort_column_configs(column_configs)
118+
119+
120+
def test_duplicate_side_effect_producers_raises() -> None:
121+
"""Two custom columns declaring the same side-effect column is a configuration error."""
122+
123+
@custom_column_generator(required_columns=["text"], side_effect_columns=["shared_col"])
124+
def gen_a(row: dict[str, Any]) -> dict[str, Any]:
125+
return row
126+
127+
@custom_column_generator(required_columns=["text"], side_effect_columns=["shared_col"])
128+
def gen_b(row: dict[str, Any]) -> dict[str, Any]:
129+
return row
130+
131+
column_configs = [
132+
LLMTextColumnConfig(name="text", prompt="hello", model_alias=MODEL_ALIAS),
133+
CustomColumnConfig(name="col_a", generator_function=gen_a),
134+
CustomColumnConfig(name="col_b", generator_function=gen_b),
135+
]
136+
with pytest.raises(ConfigCompilationError, match="already produced by"):
137+
topologically_sort_column_configs(column_configs)

packages/data-designer-engine/tests/engine/dataset_builders/utils/test_execution_graph.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
from data_designer.config.utils.code_lang import CodeLang
2020
from data_designer.config.validator_params import CodeValidatorParams
2121
from data_designer.engine.dataset_builders.multi_column_configs import SamplerMultiColumnConfig
22-
from data_designer.engine.dataset_builders.utils.errors import DAGCircularDependencyError
22+
from data_designer.engine.dataset_builders.utils.errors import ConfigCompilationError, DAGCircularDependencyError
2323
from data_designer.engine.dataset_builders.utils.execution_graph import ExecutionGraph
2424
from data_designer.engine.dataset_builders.utils.task_model import SliceRef
2525

@@ -156,6 +156,17 @@ def test_side_effect_name_collision_prefers_real_column() -> None:
156156
assert graph.get_downstream_columns("summary") == set()
157157

158158

159+
def test_side_effect_collision_raises() -> None:
160+
"""Two producers for the same side-effect column is a configuration error."""
161+
graph = ExecutionGraph()
162+
graph.add_column("producer_a", GenerationStrategy.CELL_BY_CELL)
163+
graph.add_column("producer_b", GenerationStrategy.CELL_BY_CELL)
164+
165+
graph.set_side_effect("shared_se", "producer_a")
166+
with pytest.raises(ConfigCompilationError, match="already produced by 'producer_a'"):
167+
graph.set_side_effect("shared_se", "producer_b")
168+
169+
159170
# -- Validation tests -------------------------------------------------------
160171

161172

0 commit comments

Comments (0)