Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import os
import time
import uuid
from dataclasses import dataclass
from pathlib import Path
from typing import TYPE_CHECKING, Any, Callable

Expand Down Expand Up @@ -83,6 +84,13 @@
_CLIENT_VERSION: str = get_library_version()


@dataclass
class _ResumeState:
    """Progress snapshot of an interrupted run, read back from metadata.json.

    Used to validate run parameters and restart generation from the last
    completed batch (sync engine) or row group (async engine).
    """

    # Number of batches / row groups fully completed before the interruption,
    # as recorded in metadata.json.
    num_completed_batches: int
    # Records generated so far, as recorded in metadata.json.
    actual_num_records: int
    # Batch/row-group size of the original run; validated to match the resuming run.
    buffer_size: int


class DatasetBuilder:
def __init__(
self,
Expand Down Expand Up @@ -146,6 +154,7 @@ def build(
num_records: int,
on_batch_complete: Callable[[Path], None] | None = None,
save_multimedia_to_disk: bool = True,
resume: bool = False,
) -> Path:
"""Build the dataset.

Expand All @@ -155,6 +164,10 @@ def build(
save_multimedia_to_disk: Whether to save generated multimedia (images, audio, video) to disk.
If False, multimedia is stored directly in the DataFrame (e.g., images as base64).
Default is True.
resume: If True, resume generation from the last completed batch (sync engine) or
row group (async engine) found in the existing artifact directory. The run parameters
(num_records, buffer_size) must match those of the original run. Any in-flight
partial results from the interrupted run are discarded.

Returns:
Path to the generated dataset directory.
Expand All @@ -172,9 +185,23 @@ def build(
start_time = time.perf_counter()
buffer_size = self._resource_provider.run_config.buffer_size

if resume and not self.artifact_storage.metadata_file_path.exists():
# No metadata.json means the previous run was interrupted before any batch (sync) or
# row group (async) completed. Nothing to resume — discard any leftover partial
# results and start fresh.
logger.info(
"▶️ No metadata.json found — the previous run was interrupted before any batch "
"completed. Starting generation from the beginning."
)
self.artifact_storage.clear_partial_results()
resume = False

generated = True
if DATA_DESIGNER_ASYNC_ENGINE:
self._validate_async_compatibility()
self._build_async(generators, num_records, buffer_size, on_batch_complete)
generated = self._build_async(generators, num_records, buffer_size, on_batch_complete, resume=resume)
elif resume:
generated = self._build_with_resume(generators, num_records, buffer_size, on_batch_complete)
else:
group_id = uuid.uuid4().hex
self.batch_manager.start(num_records=num_records, buffer_size=buffer_size)
Expand All @@ -189,11 +216,97 @@ def build(
)
self.batch_manager.finish()

self._processor_runner.run_after_generation(buffer_size)
if generated:
self._processor_runner.run_after_generation(buffer_size)
self._resource_provider.model_registry.log_model_usage(time.perf_counter() - start_time)

return self.artifact_storage.final_dataset_path

def _load_resume_state(self, num_records: int, buffer_size: int) -> _ResumeState:
    """Read and validate resume state from an existing metadata.json.

    Args:
        num_records: Target record count of the current run; must match the interrupted run.
        buffer_size: Batch/row-group size of the current run; must match the interrupted run.

    Returns:
        A validated ``_ResumeState`` describing how far the interrupted run progressed.

    Raises:
        DatasetGenerationError: If metadata is missing, malformed, or incompatible
            with the current run parameters.
    """
    try:
        metadata = self.artifact_storage.read_metadata()
    except FileNotFoundError:
        # `from None`: the missing-file traceback adds nothing beyond this message.
        raise DatasetGenerationError(
            "🛑 Cannot resume: metadata.json not found in the existing dataset directory. "
            "Run without resume=True to start a new generation."
        ) from None

    target = metadata.get("target_num_records")
    if target != num_records:
        raise DatasetGenerationError(
            f"🛑 Cannot resume: num_records={num_records} does not match the original run's "
            f"target_num_records={target}. Use the same num_records as the interrupted run, "
            "or start a new run without resume=True."
        )

    meta_buffer_size = metadata.get("buffer_size")
    if meta_buffer_size != buffer_size:
        raise DatasetGenerationError(
            f"🛑 Cannot resume: buffer_size={buffer_size} does not match the original run's "
            f"buffer_size={meta_buffer_size}. Use the same buffer_size as the interrupted run, "
            "or start a new run without resume=True."
        )

    try:
        return _ResumeState(
            num_completed_batches=metadata["num_completed_batches"],
            actual_num_records=metadata["actual_num_records"],
            buffer_size=buffer_size,
        )
    except KeyError as e:
        # Metadata written by an older/incompatible version may lack progress fields;
        # surface the documented error type instead of a raw KeyError.
        raise DatasetGenerationError(
            f"🛑 Cannot resume: metadata.json is missing required field {e}. "
            "Run without resume=True to start a new generation."
        ) from e

def _build_with_resume(
    self,
    generators: list[ColumnGenerator],
    num_records: int,
    buffer_size: int,
    on_batch_complete: Callable[[Path], None] | None,
) -> bool:
    """Resume generation from the last completed batch.

    Returns:
        False if the dataset was already complete (no new records generated),
        True after successfully generating the remaining batches.
    """
    resume_state = self._load_resume_state(num_records, buffer_size)

    # Seed the batch manager with the progress recorded by the interrupted run.
    self.batch_manager.start(
        num_records=num_records,
        buffer_size=buffer_size,
        start_batch=resume_state.num_completed_batches,
        initial_actual_num_records=resume_state.actual_num_records,
    )

    total_batches = self.batch_manager.num_batches
    already_done = resume_state.num_completed_batches
    if already_done >= total_batches:
        logger.warning(
            "⚠️ Dataset is already complete — all batches were found in the existing artifact directory. "
            "Nothing to resume. Remove resume=True if you want to generate a new dataset."
        )
        return False

    logger.info(
        f"▶️ Resuming from batch {already_done + 1} of {total_batches} "
        f"({resume_state.actual_num_records} records already generated)."
    )

    # Any half-written partial results from the interrupted run are stale — drop them.
    self.artifact_storage.clear_partial_results()

    group_id = uuid.uuid4().hex
    for batch_number in range(already_done, total_batches):
        logger.info(f"⏳ Processing batch {batch_number + 1} of {total_batches}")
        self._run_batch(
            generators,
            batch_mode="batch",
            group_id=group_id,
            current_batch_number=batch_number,
            on_batch_complete=on_batch_complete,
        )
    self.batch_manager.finish()
    return True

def build_preview(self, *, num_records: int) -> pd.DataFrame:
self._run_model_health_check_if_needed()
self._run_mcp_tool_check_if_needed()
Expand Down Expand Up @@ -253,25 +366,78 @@ def _validate_async_compatibility(self) -> None:
f"disable the async scheduler."
)

def _find_completed_row_group_ids(self) -> set[int]:
"""Scan the final dataset directory for already-written row group parquet files.

Returns:
Set of row-group IDs (batch numbers) that have a parquet file in ``parquet-files/``.
"""
final_path = self.artifact_storage.final_dataset_path
if not final_path.exists():
return set()
ids: set[int] = set()
for p in final_path.glob("batch_*.parquet"):
try:
ids.add(int(p.stem.split("_", 1)[1]))
except (ValueError, IndexError):
continue
return ids

def _build_async(
self,
generators: list[ColumnGenerator],
num_records: int,
buffer_size: int,
on_batch_complete: Callable[[Path], None] | None = None,
) -> None:
"""Async task-queue builder path - dispatches tasks based on dependency readiness."""
*,
resume: bool = False,
) -> bool:
"""Async task-queue builder path - dispatches tasks based on dependency readiness.

Returns:
False if the dataset was already complete (no new records generated),
True after successfully running the scheduler.
"""
logger.info("⚡ DATA_DESIGNER_ASYNC_ENGINE is enabled - using async task-queue builder")

settings = self._resource_provider.run_config
trace_enabled = settings.async_trace or os.environ.get("DATA_DESIGNER_ASYNC_TRACE", "0") == "1"

skip_row_groups: frozenset[int] = frozenset()
initial_actual_num_records = 0
initial_total_num_batches = 0

if resume:
state = self._load_resume_state(num_records, buffer_size)
completed_ids = self._find_completed_row_group_ids()
skip_row_groups = frozenset(completed_ids)
initial_actual_num_records = state.actual_num_records
# Use filesystem count as source of truth — metadata may lag by one row group
# if a crash occurred between move_partial_result_to_final_file_path and write_metadata.
initial_total_num_batches = len(completed_ids)
self.artifact_storage.clear_partial_results()

total_row_groups = -(-num_records // buffer_size) # ceiling division
if len(completed_ids) >= total_row_groups:
logger.warning(
"⚠️ Dataset is already complete — all row groups were found in the existing artifact "
"directory. Nothing to resume. Remove resume=True if you want to generate a new dataset."
)
return False

logger.info(
f"▶️ Resuming async run: {len(completed_ids)} of {total_row_groups} row group(s) already "
f"complete ({state.actual_num_records} records), skipping them."
)

def finalize_row_group(rg_id: int) -> None:
def on_complete(final_path: Path | str | None) -> None:
if final_path is not None and on_batch_complete:
on_batch_complete(final_path)

buffer_manager.checkpoint_row_group(rg_id, on_complete=on_complete)
# Write incremental metadata after each row group so interrupted runs can be resumed.
buffer_manager.write_metadata(target_num_records=num_records, buffer_size=buffer_size)

scheduler, buffer_manager = self._prepare_async_run(
generators,
Expand All @@ -282,6 +448,9 @@ def on_complete(final_path: Path | str | None) -> None:
shutdown_error_window=settings.shutdown_error_window,
disable_early_shutdown=settings.disable_early_shutdown,
trace=trace_enabled,
skip_row_groups=skip_row_groups,
initial_actual_num_records=initial_actual_num_records,
initial_total_num_batches=initial_total_num_batches,
)

# Telemetry snapshot
Expand All @@ -302,8 +471,9 @@ def on_complete(final_path: Path | str | None) -> None:
except Exception:
logger.debug("Failed to emit batch telemetry for async run", exc_info=True)

# Write metadata
# Write final metadata (overwrites the last incremental write with identical content).
buffer_manager.write_metadata(target_num_records=num_records, buffer_size=buffer_size)
return True

def _prepare_async_run(
self,
Expand All @@ -317,6 +487,9 @@ def _prepare_async_run(
shutdown_error_window: int = 10,
disable_early_shutdown: bool = False,
trace: bool = False,
skip_row_groups: frozenset[int] = frozenset(),
initial_actual_num_records: int = 0,
initial_total_num_batches: int = 0,
) -> tuple[AsyncTaskScheduler, RowGroupBufferManager]:
"""Build a fully-wired scheduler and buffer manager for async generation.

Expand All @@ -339,18 +512,23 @@ def _prepare_async_run(
for gen in generators:
gen.log_pre_generation()

# Partition into row groups
# Partition into row groups, skipping any already completed on resume.
row_groups: list[tuple[int, int]] = []
remaining = num_records
rg_id = 0
while remaining > 0:
size = min(buffer_size, remaining)
row_groups.append((rg_id, size))
if rg_id not in skip_row_groups:
row_groups.append((rg_id, size))
remaining -= size
rg_id += 1

tracker = CompletionTracker.with_graph(graph, row_groups)
buffer_manager = RowGroupBufferManager(self.artifact_storage)
buffer_manager = RowGroupBufferManager(
self.artifact_storage,
initial_actual_num_records=initial_actual_num_records,
initial_total_num_batches=initial_total_num_batches,
)

# Pre-batch processor callback: runs after seed tasks complete for a row group.
# If it raises, the scheduler drops all rows in the row group (skips it).
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,14 @@ def reset(self, delete_files: bool = False) -> None:
except OSError as e:
raise DatasetBatchManagementError(f"🛑 Failed to delete directory {dir_path}: {e}")

def start(self, *, num_records: int, buffer_size: int) -> None:
def start(
self,
*,
num_records: int,
buffer_size: int,
start_batch: int = 0,
initial_actual_num_records: int = 0,
) -> None:
if num_records <= 0:
raise DatasetBatchManagementError("🛑 num_records must be positive.")
if buffer_size <= 0:
Expand All @@ -168,6 +175,8 @@ def start(self, *, num_records: int, buffer_size: int) -> None:
if remaining_records := num_records % buffer_size:
self._num_records_list.append(remaining_records)
self.reset()
self._current_batch_number = start_batch
self._actual_num_records = initial_actual_num_records

def write(self) -> Path | None:
"""Write the current batch to a parquet file.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,13 +27,18 @@ class RowGroupBufferManager:
exclusively by the async scheduler.
"""

def __init__(self, artifact_storage: ArtifactStorage) -> None:
def __init__(
    self,
    artifact_storage: ArtifactStorage,
    initial_actual_num_records: int = 0,
    initial_total_num_batches: int = 0,
) -> None:
    """Initialize the buffer manager.

    Args:
        artifact_storage: Storage backend used to persist completed row groups.
        initial_actual_num_records: Record count carried over from a resumed run
            (0 for a fresh run).
        initial_total_num_batches: Completed row-group count carried over from a
            resumed run (0 for a fresh run).
    """
    # In-flight row buffers keyed by row-group ID.
    self._buffers: dict[int, list[dict]] = {}
    # Declared size (row count) of each row group.
    self._row_group_sizes: dict[int, int] = {}
    # Row indices recorded as dropped, per row group.
    self._dropped: dict[int, set[int]] = {}
    self._artifact_storage = artifact_storage
    # Seeded from resume state so metadata written later reflects total progress.
    self._actual_num_records: int = initial_actual_num_records
    self._total_num_batches: int = initial_total_num_batches

def init_row_group(self, row_group: int, size: int) -> None:
"""Allocate a buffer for *row_group* with *size* empty rows."""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ class ArtifactStorage(BaseModel):
partial_results_folder_name: str = "tmp-partial-parquet-files"
dropped_columns_folder_name: str = "dropped-columns-parquet-files"
processors_outputs_folder_name: str = PROCESSORS_OUTPUTS_FOLDER_NAME
resume: bool = False
_media_storage: MediaStorage = PrivateAttr(default=None)

@property
Expand All @@ -67,12 +68,19 @@ def artifact_path_exists(self) -> bool:
def resolved_dataset_name(self) -> str:
    """Resolve the directory name the dataset will be written to.

    A non-empty existing dataset directory is reused when resuming, or renamed
    with a timestamp suffix otherwise. Resuming with no existing dataset raises.
    """
    dataset_path = self.artifact_path / self.dataset_name
    has_existing_dataset = dataset_path.exists() and len(list(dataset_path.iterdir())) > 0
    if has_existing_dataset:
        if self.resume:
            # Resume writes into the directory left behind by the interrupted run.
            return self.dataset_name
        timestamp = datetime.now().strftime('%m-%d-%Y_%H%M%S')
        new_dataset_name = f"{self.dataset_name}_{timestamp}"
        logger.info(
            f"📂 Dataset path {str(dataset_path)!r} already exists. Dataset from this session"
            f"\n\t\t will be saved to {str(self.artifact_path / new_dataset_name)!r} instead."
        )
        return new_dataset_name
    if self.resume:
        raise ArtifactStorageError(
            f"🛑 Cannot resume: no existing dataset found at {str(dataset_path)!r}. "
            "Run without resume=True to start a new generation."
        )
    return self.dataset_name

@property
Expand Down Expand Up @@ -204,6 +212,11 @@ def load_dataset_with_dropped_columns(self) -> pd.DataFrame:
df = lazy.pd.concat([df, df_dropped], axis=1)
return df

def clear_partial_results(self) -> None:
    """Remove any in-flight partial results left over from an interrupted run."""
    partials_dir = self.partial_results_path
    if partials_dir.exists():
        shutil.rmtree(partials_dir)

def move_partial_result_to_final_file_path(self, batch_number: int) -> Path:
partial_result_path = self.create_batch_file_path(batch_number, batch_stage=BatchStage.PARTIAL_RESULT)
if not partial_result_path.exists():
Expand Down
Loading
Loading