Commit d97d52c

chore: streamline generation metadata + consolidate sdg json (#226)
1 parent 1643645 commit d97d52c

11 files changed

Lines changed: 320 additions & 110 deletions

src/data_designer/engine/compiler.py

Lines changed: 1 addition & 6 deletions
@@ -6,7 +6,6 @@
 import logging
 
 from data_designer.config.column_configs import SeedDatasetColumnConfig
-from data_designer.config.config_builder import DataDesignerConfigBuilder
 from data_designer.config.data_designer_config import DataDesignerConfig
 from data_designer.config.errors import InvalidConfigError
 from data_designer.engine.resources.resource_provider import ResourceProvider
@@ -16,13 +15,9 @@
 logger = logging.getLogger(__name__)
 
 
-def compile_data_designer_config(
-    config_builder: DataDesignerConfigBuilder, resource_provider: ResourceProvider
-) -> DataDesignerConfig:
-    config = config_builder.build()
+def compile_data_designer_config(config: DataDesignerConfig, resource_provider: ResourceProvider) -> DataDesignerConfig:
     _resolve_and_add_seed_columns(config, resource_provider.seed_reader)
     _validate(config)
-
     return config
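With this change, compile_data_designer_config accepts an already-built DataDesignerConfig rather than the builder itself, so callers resolve the builder first. A minimal usage sketch, assuming a config_builder (DataDesignerConfigBuilder) and a resource_provider (ResourceProvider) are already constructed as they are elsewhere in this commit:

    # Sketch only: the caller now builds the config before compiling it.
    config = config_builder.build()
    compiled = compile_data_designer_config(config, resource_provider)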

src/data_designer/engine/dataset_builders/artifact_storage.py

Lines changed: 80 additions & 7 deletions
@@ -24,6 +24,7 @@
 logger = logging.getLogger(__name__)
 
 BATCH_FILE_NAME_FORMAT = "batch_{batch_number:05d}.parquet"
+SDG_CONFIG_FILENAME = "sdg.json"
 
 
 class BatchStage(StrEnum):
@@ -170,12 +171,6 @@ def move_partial_result_to_final_file_path(self, batch_number: int) -> Path:
         shutil.move(partial_result_path, final_file_path)
         return final_file_path
 
-    def write_configs(self, json_file_name: str, configs: list[dict]) -> Path:
-        self.mkdir_if_needed(self.base_dataset_path)
-        with open(self.base_dataset_path / json_file_name, "w") as file:
-            json.dump([c.model_dump(mode="json") for c in configs], file, indent=4)
-        return self.base_dataset_path / json_file_name
-
     def write_batch_to_parquet_file(
         self,
         batch_number: int,
@@ -200,11 +195,89 @@ def write_parquet_file(
         dataframe.to_parquet(file_path, index=False)
         return file_path
 
+    def get_parquet_file_paths(self) -> list[str]:
+        """Get list of parquet file paths relative to base_dataset_path.
+
+        Returns:
+            List of relative paths to parquet files in the final dataset folder.
+        """
+        return [str(f.relative_to(self.base_dataset_path)) for f in sorted(self.final_dataset_path.glob("*.parquet"))]
+
+    def get_processor_file_paths(self) -> dict[str, list[str]]:
+        """Get processor output files organized by processor name.
+
+        Returns:
+            Dictionary mapping processor names to lists of relative file paths.
+        """
+        processor_files: dict[str, list[str]] = {}
+        if self.processors_outputs_path.exists():
+            for processor_dir in sorted(self.processors_outputs_path.iterdir()):
+                if processor_dir.is_dir():
+                    processor_name = processor_dir.name
+                    processor_files[processor_name] = [
+                        str(f.relative_to(self.base_dataset_path))
+                        for f in sorted(processor_dir.rglob("*"))
+                        if f.is_file()
+                    ]
+        return processor_files
+
+    def get_file_paths(self) -> dict[str, list[str] | dict[str, list[str]]]:
+        """Get all file paths organized by type.
+
+        Returns:
+            Dictionary with 'parquet-files' and 'processor-files' keys.
+        """
+        file_paths = {
+            "parquet-files": self.get_parquet_file_paths(),
+        }
+        processor_file_paths = self.get_processor_file_paths()
+        if processor_file_paths:
+            file_paths["processor-files"] = processor_file_paths
+
+        return file_paths
+
+    def read_metadata(self) -> dict:
+        """Read metadata from the metadata.json file.
+
+        Returns:
+            Dictionary containing the metadata.
+
+        Raises:
+            FileNotFoundError: If metadata file doesn't exist.
+        """
+        with open(self.metadata_file_path, "r") as file:
+            return json.load(file)
+
     def write_metadata(self, metadata: dict) -> Path:
+        """Write metadata to the metadata.json file.
+
+        Args:
+            metadata: Dictionary containing metadata to write.
+
+        Returns:
+            Path to the written metadata file.
+        """
         self.mkdir_if_needed(self.base_dataset_path)
         with open(self.metadata_file_path, "w") as file:
-            json.dump(metadata, file)
+            json.dump(metadata, file, indent=4, sort_keys=True)
         return self.metadata_file_path
 
+    def update_metadata(self, updates: dict) -> Path:
+        """Update existing metadata with new fields.
+
+        Args:
+            updates: Dictionary of fields to add/update in metadata.
+
+        Returns:
+            Path to the updated metadata file.
+        """
+        try:
+            existing_metadata = self.read_metadata()
+        except FileNotFoundError:
+            existing_metadata = {}
+
+        existing_metadata.update(updates)
+        return self.write_metadata(existing_metadata)
+
     def _get_stage_path(self, stage: BatchStage) -> Path:
         return getattr(self, resolve_string_enum(stage, BatchStage).value)
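The new read_metadata/update_metadata pair gives the storage layer a simple read-merge-write cycle over metadata.json, and write_metadata now pretty-prints with sorted keys. A self-contained sketch of the same pattern using only the standard library (metadata_path is an illustrative stand-in for self.metadata_file_path):

    import json
    from pathlib import Path

    def update_metadata(metadata_path: Path, updates: dict) -> Path:
        # Merge new fields into whatever is already on disk; start fresh if the file is missing.
        try:
            existing = json.loads(metadata_path.read_text())
        except FileNotFoundError:
            existing = {}
        existing.update(updates)
        metadata_path.parent.mkdir(parents=True, exist_ok=True)
        # Mirrors the new write_metadata formatting: indented, keys sorted.
        metadata_path.write_text(json.dumps(existing, indent=4, sort_keys=True))
        return metadata_path

    # First call creates the file, second merges an extra key into it.
    path = update_metadata(Path("dataset/metadata.json"), {"dataset_name": "demo"})
    update_metadata(path, {"column_statistics": []})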

src/data_designer/engine/dataset_builders/column_wise_builder.py

Lines changed: 20 additions & 18 deletions
@@ -13,6 +13,8 @@
 from typing import TYPE_CHECKING, Callable
 
 from data_designer.config.column_types import ColumnConfigT
+from data_designer.config.config_builder import BuilderConfig
+from data_designer.config.data_designer_config import DataDesignerConfig
 from data_designer.config.dataset_builders import BuildStage
 from data_designer.config.processors import (
     DropColumnsProcessorConfig,
@@ -25,13 +27,15 @@
     GenerationStrategy,
 )
 from data_designer.engine.column_generators.utils.generator_classification import column_type_is_model_generated
-from data_designer.engine.dataset_builders.artifact_storage import ArtifactStorage
+from data_designer.engine.compiler import compile_data_designer_config
+from data_designer.engine.dataset_builders.artifact_storage import SDG_CONFIG_FILENAME, ArtifactStorage
 from data_designer.engine.dataset_builders.errors import DatasetGenerationError, DatasetProcessingError
-from data_designer.engine.dataset_builders.multi_column_configs import DatasetBuilderColumnConfigT, MultiColumnConfig
+from data_designer.engine.dataset_builders.multi_column_configs import MultiColumnConfig
 from data_designer.engine.dataset_builders.utils.concurrency import (
     MAX_CONCURRENCY_PER_NON_LLM_GENERATOR,
     ConcurrentThreadExecutor,
 )
+from data_designer.engine.dataset_builders.utils.config_compiler import compile_dataset_builder_column_configs
 from data_designer.engine.dataset_builders.utils.dataset_batch_manager import DatasetBatchManager
 from data_designer.engine.models.telemetry import InferenceEvent, NemoSourceEnum, TaskStatusEnum, TelemetryHandler
 from data_designer.engine.processing.processors.base import Processor
@@ -54,17 +58,20 @@
 class ColumnWiseDatasetBuilder:
     def __init__(
         self,
-        column_configs: list[DatasetBuilderColumnConfigT],
-        processor_configs: list[ProcessorConfig],
+        data_designer_config: DataDesignerConfig,
         resource_provider: ResourceProvider,
         registry: DataDesignerRegistry | None = None,
     ):
         self.batch_manager = DatasetBatchManager(resource_provider.artifact_storage)
         self._resource_provider = resource_provider
         self._records_to_drop: set[int] = set()
         self._registry = registry or DataDesignerRegistry()
-        self._column_configs = column_configs
-        self._processors: dict[BuildStage, list[Processor]] = self._initialize_processors(processor_configs)
+
+        self._data_designer_config = compile_data_designer_config(data_designer_config, resource_provider)
+        self._column_configs = compile_dataset_builder_column_configs(self._data_designer_config)
+        self._processors: dict[BuildStage, list[Processor]] = self._initialize_processors(
+            self._data_designer_config.processors or []
+        )
         self._validate_column_configs()
 
     @property
@@ -91,9 +98,8 @@ def build(
         num_records: int,
         on_batch_complete: Callable[[Path], None] | None = None,
     ) -> Path:
-        self._write_configs()
         self._run_model_health_check_if_needed()
-
+        self._write_builder_config()
         generators = self._initialize_generators()
         start_time = time.perf_counter()
         group_id = uuid.uuid4().hex
@@ -152,6 +158,12 @@ def _initialize_generators(self) -> list[ColumnGenerator]:
             for config in self._column_configs
         ]
 
+    def _write_builder_config(self) -> None:
+        self.artifact_storage.mkdir_if_needed(self.artifact_storage.base_dataset_path)
+        BuilderConfig(data_designer=self._data_designer_config).to_json(
+            self.artifact_storage.base_dataset_path / SDG_CONFIG_FILENAME
+        )
+
     def _run_batch(
         self, generators: list[ColumnGenerator], *, batch_mode: str, save_partial_results: bool = True, group_id: str
     ) -> None:
@@ -303,16 +315,6 @@ def _worker_error_callback(self, exc: Exception, *, context: dict | None = None)
     def _worker_result_callback(self, result: dict, *, context: dict | None = None) -> None:
         self.batch_manager.update_record(context["index"], result)
 
-    def _write_configs(self) -> None:
-        self.artifact_storage.write_configs(
-            json_file_name="column_configs.json",
-            configs=self._column_configs,
-        )
-        self.artifact_storage.write_configs(
-            json_file_name="model_configs.json",
-            configs=self._resource_provider.model_registry.model_configs.values(),
-        )
-
     def _emit_batch_inference_events(
         self, batch_mode: str, usage_deltas: dict[str, ModelUsageStats], group_id: str
     ) -> None:
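ColumnWiseDatasetBuilder now owns config compilation: the raw DataDesignerConfig is compiled in __init__, column and processor configs are derived from it, and _write_builder_config serializes everything to a single sdg.json instead of the old column_configs.json/model_configs.json pair. A rough sketch of the resulting call shape, assuming only the names visible in this diff (everything around them is illustrative):

    # Illustrative only -- mirrors the new flow rather than documenting a public API.
    builder = ColumnWiseDatasetBuilder(
        data_designer_config=config_builder.build(),  # raw config; compiled inside __init__
        resource_provider=resource_provider,
    )
    builder.build(num_records=100)  # health check runs, sdg.json is written, then generation starts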

src/data_designer/engine/dataset_builders/utils/dataset_batch_manager.py

Lines changed: 1 addition & 2 deletions
@@ -91,8 +91,7 @@ def finish_batch(self, on_complete: Callable[[Path], None] | None = None) -> Pat
             "total_num_batches": self.num_batches,
             "buffer_size": self._buffer_size,
             "schema": {field.name: str(field.type) for field in pq.read_schema(final_file_path)},
-            "file_paths": [str(f) for f in sorted(self.artifact_storage.final_dataset_path.glob("*.parquet"))],
-            "num_records": self.num_records_list[: self._current_batch_number + 1],
+            "file_paths": self.artifact_storage.get_file_paths(),
             "num_completed_batches": self._current_batch_number + 1,
             "dataset_name": self.artifact_storage.dataset_name,
         }
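The metadata's file_paths entry switches from a flat list of absolute parquet paths to the structured mapping returned by get_file_paths(), and the per-batch num_records list is dropped. The resulting shape with hypothetical values (directory and processor names below are placeholders; only the batch file naming follows BATCH_FILE_NAME_FORMAT):

    # Hypothetical example of the new "file_paths" value in metadata.json
    file_paths = {
        "parquet-files": ["final/batch_00000.parquet", "final/batch_00001.parquet"],
        "processor-files": {"example_processor": ["processors/example_processor/output.json"]},
    }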

src/data_designer/interface/data_designer.py

Lines changed: 12 additions & 9 deletions
@@ -9,6 +9,7 @@
 
 from data_designer.config.analysis.dataset_profiler import DatasetProfilerResults
 from data_designer.config.config_builder import DataDesignerConfigBuilder
+from data_designer.config.data_designer_config import DataDesignerConfig
 from data_designer.config.default_model_settings import (
     get_default_model_configs,
     get_default_model_providers_missing_api_keys,
@@ -34,7 +35,6 @@
 from data_designer.engine.compiler import compile_data_designer_config
 from data_designer.engine.dataset_builders.artifact_storage import ArtifactStorage
 from data_designer.engine.dataset_builders.column_wise_builder import ColumnWiseDatasetBuilder
-from data_designer.engine.dataset_builders.utils.config_compiler import compile_dataset_builder_column_configs
 from data_designer.engine.model_provider import resolve_model_provider_registry
 from data_designer.engine.resources.managed_storage import init_managed_blob_storage
 from data_designer.engine.resources.resource_provider import ResourceProvider, create_resource_provider
@@ -165,7 +165,7 @@ def create(
 
         resource_provider = self._create_resource_provider(dataset_name, config_builder)
 
-        builder = self._create_dataset_builder(config_builder, resource_provider)
+        builder = self._create_dataset_builder(config_builder.build(), resource_provider)
 
         try:
             builder.build(num_records=num_records)
@@ -183,6 +183,12 @@ def create(
 
         dataset_metadata = resource_provider.get_dataset_metadata()
 
+        # Update metadata with column statistics from analysis
+        if analysis:
+            builder.artifact_storage.update_metadata(
+                {"column_statistics": [stat.model_dump(mode="json") for stat in analysis.column_statistics]}
+            )
+
         return DatasetCreationResults(
             artifact_storage=builder.artifact_storage,
             analysis=analysis,
@@ -213,7 +219,7 @@ def preview(
         logger.info(f"{RandomEmoji.previewing()} Preview generation in progress")
 
         resource_provider = self._create_resource_provider("preview-dataset", config_builder)
-        builder = self._create_dataset_builder(config_builder, resource_provider)
+        builder = self._create_dataset_builder(config_builder.build(), resource_provider)
 
         try:
             raw_dataset = builder.build_preview(num_records=num_records)
@@ -277,7 +283,7 @@ def validate(self, config_builder: DataDesignerConfigBuilder) -> None:
             InvalidConfigError: If the configuration is invalid.
         """
         resource_provider = self._create_resource_provider("validate-configuration", config_builder)
-        compile_data_designer_config(config_builder, resource_provider)
+        compile_data_designer_config(config_builder.build(), resource_provider)
 
     def get_default_model_configs(self) -> list[ModelConfig]:
         """Get the default model configurations.
@@ -342,14 +348,11 @@ def _resolve_model_providers(self, model_providers: list[ModelProvider] | None)
 
     def _create_dataset_builder(
         self,
-        config_builder: DataDesignerConfigBuilder,
+        data_designer_config: DataDesignerConfig,
         resource_provider: ResourceProvider,
     ) -> ColumnWiseDatasetBuilder:
-        config = compile_data_designer_config(config_builder, resource_provider)
-
         return ColumnWiseDatasetBuilder(
-            column_configs=compile_dataset_builder_column_configs(config),
-            processor_configs=config.processors or [],
+            data_designer_config=data_designer_config,
             resource_provider=resource_provider,
         )
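create() now pushes the profiler's column statistics back into metadata.json through the new update_metadata helper once analysis completes. Restated as a standalone snippet (assuming, as the diff implies, that analysis.column_statistics is a list of pydantic models supporting model_dump):

    # What the added block in create() effectively does, outside its surrounding method.
    if analysis:
        stats = [stat.model_dump(mode="json") for stat in analysis.column_statistics]
        builder.artifact_storage.update_metadata({"column_statistics": stats})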

tests/engine/conftest.py

Lines changed: 1 addition & 0 deletions
@@ -37,6 +37,7 @@ def stub_resource_provider(tmp_path, stub_model_facade):
     mock_provider.artifact_storage = ArtifactStorage(artifact_path=tmp_path)
     mock_provider.blob_storage = Mock(spec=ManagedBlobStorage)
     mock_provider.seed_reader = Mock()
+    mock_provider.seed_reader.get_column_names.return_value = []
     mock_provider.run_config = RunConfig()
     return mock_provider
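The stub needs a concrete return value because the builder now calls compile_data_designer_config in its constructor, which resolves seed columns through the seed reader (presumably via get_column_names()). A self-contained illustration of why a bare Mock is no longer enough:

    from unittest.mock import Mock

    seed_reader = Mock()
    seed_reader.get_column_names.return_value = []  # without this, the call would return another Mock

    # Code that iterates over the column names now sees an empty list instead of a Mock object.
    for name in seed_reader.get_column_names():
        print(name)  # never reached, but no TypeError either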
