44import com .google .gson .JsonArray ;
55import com .google .gson .JsonElement ;
66import com .google .gson .JsonParser ;
7- import java .util .ArrayList ;
87import java .util .List ;
98import org .apache .beam .sdk .Pipeline ;
109import org .apache .beam .sdk .io .gcp .spanner .SpannerWriteResult ;
1110import org .apache .beam .sdk .metrics .Counter ;
1211import org .apache .beam .sdk .metrics .Metrics ;
1312import org .apache .beam .sdk .options .PipelineOptionsFactory ;
14- import org .apache .beam .sdk .transforms .Flatten ;
13+ import org .apache .beam .sdk .transforms .Create ;
1514import org .apache .beam .sdk .transforms .Values ;
15+ import org .apache .beam .sdk .transforms .Wait ;
1616import org .apache .beam .sdk .values .PCollection ;
17- import org .apache .beam .sdk .values .PCollectionList ;
1817import org .apache .beam .sdk .values .PCollectionTuple ;
18+ import org .apache .beam .sdk .values .TypeDescriptor ;
1919import org .datacommons .ingestion .spanner .SpannerClient ;
2020import org .datacommons .ingestion .util .GraphReader ;
2121import org .datacommons .ingestion .util .PipelineUtils ;
@@ -70,12 +70,6 @@ public static void buildPipeline(
7070 Pipeline pipeline , IngestionPipelineOptions options , SpannerClient spannerClient ) {
7171 LOGGER .info ("Running import pipeline for imports: {}" , options .getImportList ());
7272
73- // Initialize lists to hold mutations from all imports.
74- List <PCollection <Void >> deleteOpsList = new ArrayList <>();
75- List <PCollection <Mutation >> obsMutationList = new ArrayList <>();
76- List <PCollection <Mutation >> edgeMutationList = new ArrayList <>();
77- List <PCollection <Mutation >> nodeMutationList = new ArrayList <>();
78-
7973 // Parse the input import list JSON.
8074 JsonElement jsonElement = JsonParser .parseString (options .getImportList ());
8175 JsonArray jsonArray = jsonElement .getAsJsonArray ();
@@ -97,37 +91,8 @@ public static void buildPipeline(
9791 String graphPath = pathElement .getAsString ();
9892
9993 // Process the individual import.
100- processImport (
101- pipeline ,
102- spannerClient ,
103- importName ,
104- graphPath ,
105- options .getSkipDelete (),
106- deleteOpsList ,
107- nodeMutationList ,
108- edgeMutationList ,
109- obsMutationList );
110- }
111- // Finally, aggregate all collected mutations and write them to Spanner.
112- // 1. Process Deletes:
113- // First, execute all delete mutations to clear old data for the imports.
114- PCollection <Void > deleted =
115- PCollectionList .of (deleteOpsList ).apply ("DeleteOps" , Flatten .pCollections ());
116-
117- // 2. Process Observations:
118- // Write observation mutations after deletes are complete.
119- if (options .getWriteObsGraph ()) {
120- spannerClient .writeMutations (pipeline , "Observations" , obsMutationList , deleted );
94+ processImport (pipeline , spannerClient , importName , graphPath , options .getSkipDelete ());
12195 }
122-
123- // 3. Process Nodes:
124- // Write node mutations after deletes are complete.
125- SpannerWriteResult writtenNodes =
126- spannerClient .writeMutations (pipeline , "Nodes" , nodeMutationList , deleted );
127-
128- // 4. Process Edges:
129- // Write edge mutations only after node mutations are complete to ensure referential integrity.
130- spannerClient .writeMutations (pipeline , "Edges" , edgeMutationList , writtenNodes .getOutput ());
13196 }
13297
13398 /**
@@ -138,67 +103,94 @@ public static void buildPipeline(
138103 * @param importName The name of the import.
139104 * @param graphPath The full path to the graph data.
140105 * @param skipDelete Whether to skip delete operations.
141- * @param deleteOpsList List to collect delete Ops.
142- * @param nodeMutationList List to collect node mutations.
143- * @param edgeMutationList List to collect edge mutations.
144- * @param obsMutationList List to collect observation mutations.
145106 */
146107 private static void processImport (
147108 Pipeline pipeline ,
148109 SpannerClient spannerClient ,
149110 String importName ,
150111 String graphPath ,
151- boolean skipDelete ,
152- List <PCollection <Void >> deleteOpsList ,
153- List <PCollection <Mutation >> nodeMutationList ,
154- List <PCollection <Mutation >> edgeMutationList ,
155- List <PCollection <Mutation >> obsMutationList ) {
112+ boolean skipDelete ) {
156113 LOGGER .info ("Import: {} Graph path: {}" , importName , graphPath );
157114
158115 String provenance = "dc/base/" + importName ;
159116
160117 // 1. Prepare Deletes:
161118 // Generate mutations to delete existing data for this import/provenance.
119+ // Create a dummy signal if deletes are skipped, so downstream dependencies are satisfied
120+ // immediately.
121+ PCollection <Void > deleteObsWait ;
122+ PCollection <Void > deleteEdgesWait ;
162123 if (!skipDelete ) {
163- List <PCollection <Void >> deleteOps =
164- GraphReader .deleteExistingDataForImport (importName , provenance , pipeline , spannerClient );
165- deleteOpsList .addAll (deleteOps );
124+ deleteObsWait =
125+ spannerClient .deleteDataForImport (
126+ pipeline , importName , spannerClient .getObservationTableName (), "import_name" );
127+ deleteEdgesWait =
128+ spannerClient .deleteDataForImport (
129+ pipeline , provenance , spannerClient .getEdgeTableName (), "provenance" );
130+ } else {
131+ deleteObsWait =
132+ pipeline .apply (
133+ "CreateEmptyObsWait-" + importName , Create .empty (TypeDescriptor .of (Void .class )));
134+ deleteEdgesWait =
135+ pipeline .apply (
136+ "CreateEmptyEdgesWait-" + importName , Create .empty (TypeDescriptor .of (Void .class )));
166137 }
167138
168139 // 2. Read and Split Graph:
169140 // Read the graph data (TFRecord or MCF files) and split into schema and observation nodes.
170141 PCollection <McfGraph > graph =
171142 graphPath .contains ("tfrecord" )
172- ? PipelineUtils .readMcfGraph (graphPath , pipeline )
173- : PipelineUtils .readMcfFiles (graphPath , pipeline );
174- PCollectionTuple graphNodes = PipelineUtils .splitGraph (graph );
143+ ? PipelineUtils .readMcfGraph (importName , graphPath , pipeline )
144+ : PipelineUtils .readMcfFiles (importName , graphPath , pipeline );
145+ PCollectionTuple graphNodes = PipelineUtils .splitGraph (importName , graph );
175146 PCollection <McfGraph > observationNodes = graphNodes .get (PipelineUtils .OBSERVATION_NODES_TAG );
176147 PCollection <McfGraph > schemaNodes = graphNodes .get (PipelineUtils .SCHEMA_NODES_TAG );
177148
178149 // 3. Process Schema Nodes:
179- // Combine schema nodes if required, then convert to Node and Edge mutations .
150+ // Combine nodes if required.
180151 PCollection <McfGraph > combinedGraph = schemaNodes ;
181152 if (IMPORTS_TO_COMBINE .contains (importName )) {
182- combinedGraph = PipelineUtils .combineGraphNodes (schemaNodes );
153+ combinedGraph = PipelineUtils .combineGraphNodes (importName , schemaNodes );
183154 }
155+
156+ // Convert all nodes to mutations
184157 PCollection <Mutation > nodeMutations =
185158 GraphReader .graphToNodes (
186- importName , combinedGraph , spannerClient , nodeCounter , nodeInvalidTypeCounter )
159+ "NodeMutations-" + importName ,
160+ combinedGraph ,
161+ spannerClient ,
162+ nodeCounter ,
163+ nodeInvalidTypeCounter )
187164 .apply ("ExtractNodeMutations-" + importName , Values .create ());
188165 PCollection <Mutation > edgeMutations =
189- GraphReader .graphToEdges (importName , combinedGraph , provenance , spannerClient , edgeCounter )
166+ GraphReader .graphToEdges (
167+ "EdgeMutations-" + importName ,
168+ combinedGraph ,
169+ provenance ,
170+ spannerClient ,
171+ edgeCounter )
190172 .apply ("ExtractEdgeMutations-" + importName , Values .create ());
191173
192- nodeMutationList .add (nodeMutations );
193- edgeMutationList .add (edgeMutations );
174+ // Write Nodes
175+ SpannerWriteResult writtenNodes =
176+ spannerClient .writeMutations (pipeline , "WriteNodesToSpanner-" + importName , nodeMutations );
177+
178+ // Write Edges (wait for Nodes write and Edges delete)
179+ edgeMutations .apply (
180+ "EdgesWaitOn-" + importName , Wait .on (List .of (writtenNodes .getOutput (), deleteEdgesWait )));
181+ spannerClient .writeMutations (pipeline , "WriteEdgesToSpanner-" + importName , edgeMutations );
194182
195183 // 4. Process Observation Nodes:
196184 // Build an optimized graph from observation nodes and convert to Observation mutations.
197185 PCollection <McfOptimizedGraph > optimizedGraph =
198- PipelineUtils .buildOptimizedMcfGraph (observationNodes );
186+ PipelineUtils .buildOptimizedMcfGraph (importName , observationNodes );
199187 PCollection <Mutation > observationMutations =
200188 GraphReader .graphToObservations (optimizedGraph , importName , spannerClient , obsCounter )
201- .apply ("ExtractObsMutations" , Values .create ());
202- obsMutationList .add (observationMutations );
189+ .apply ("ExtractObsMutations-" + importName , Values .create ());
190+ // Write Observations (wait for Obs delete)
191+ observationMutations .apply ("ObsWaitOn-" + importName , Wait .on (deleteObsWait ));
192+
193+ spannerClient .writeMutations (
194+ pipeline , "WriteObservationsToSpanner-" + importName , observationMutations );
203195 }
204196}
0 commit comments