2626 "zfp-round" : "ZFP-ROUND" ,
2727 "sz3" : "SZ3" ,
2828 "bitround-pco-conservative-rel" : "BitRound + PCO" ,
29- "bitround-conservative-rel" : "BitRound + Zlib " ,
30- "stochround" : "StochRound + Zlib " ,
29+ "bitround-conservative-rel" : "BitRound + Zstd " ,
30+ "stochround" : "StochRound + Zstd " ,
3131 "stochround-pco" : "StochRound + PCO" ,
3232 "tthresh" : "TTHRESH" ,
3333}
3434
3535
3636def plot_metrics (
37- basepath : Path = Path (), bound_names : list [str ] = ["low" , "mid" , "high" ]
37+ basepath : Path = Path (),
38+ data_loader_base_path : None | Path = None ,
39+ bound_names : list [str ] = ["low" , "mid" , "high" ],
3840):
3941 metrics_path = basepath / "metrics"
4042 plots_path = basepath / "plots"
43+ datasets = (data_loader_base_path or basepath ) / "datasets"
44+ compressed_datasets = basepath / "compressed-datasets"
4145
4246 df = pd .read_csv (metrics_path / "all_results.csv" )
4347 plot_per_variable_metrics (
44- basepath = basepath ,
48+ datasets = datasets ,
49+ compressed_datasets = compressed_datasets ,
4550 plots_path = plots_path ,
4651 all_results = df ,
4752 )
@@ -56,9 +61,9 @@ def plot_metrics(
5661 for metric in ["Relative MAE" , "Relative DSSIM" , "Relative MaxAbsError" ]:
5762 plot_aggregated_rd_curve (
5863 normalized_df ,
59- plots_path / f"rd_curve_{ metric .lower ().replace (' ' , '_' )} .pdf" ,
6064 compression_metric = "Relative CR" ,
6165 distortion_metric = metric ,
66+ outfile = plots_path / f"rd_curve_{ metric .lower ().replace (' ' , '_' )} .pdf" ,
6267 agg = "median" ,
6368 bound_names = bound_names ,
6469 )
@@ -138,7 +143,10 @@ def get_normalizer(row):
138143
139144
140145def plot_per_variable_metrics (
141- basepath : Path , plots_path : Path , all_results : pd .DataFrame
146+ datasets : Path ,
147+ compressed_datasets : Path ,
148+ plots_path : Path ,
149+ all_results : pd .DataFrame ,
142150):
143151 """Creates all the plots which only depend on a single variable."""
144152 for dataset in all_results ["Dataset" ].unique ():
@@ -155,8 +163,9 @@ def plot_per_variable_metrics(
155163 continue
156164 plot_variable_rd_curve (
157165 df [df ["Variable" ] == var ],
158- dataset_plots_path / f"{ var } _compression_ratio_{ metric_name } .pdf" ,
159166 distortion_metric = dist_metric ,
167+ outfile = dataset_plots_path
168+ / f"{ var } _compression_ratio_{ metric_name } .pdf" ,
160169 )
161170
162171 error_bounds = df [df ["Variable" ] == var ]["Error Bound" ].unique ()
@@ -170,51 +179,49 @@ def plot_per_variable_metrics(
170179 for comp in compressors :
171180 print (f"Plotting { var } error for { comp } ..." )
172181 plot_variable_error (
173- basepath ,
182+ datasets ,
183+ compressed_datasets ,
174184 dataset ,
175185 err_bound ,
176186 comp ,
177187 var ,
178- err_bound_path / f"{ var } _{ comp } .png" ,
188+ outfile = err_bound_path / f"{ var } _{ comp } .png" ,
179189 )
180190
181191
182- def plot_variable_error (basepath , dataset_name , error_bound , compressor , var , outfile ):
183- if outfile .exists ():
192+ def plot_variable_error (
193+ datasets : Path ,
194+ compressed_datasets : Path ,
195+ dataset_name : str ,
196+ error_bound : str ,
197+ compressor : str ,
198+ var : str ,
199+ outfile : None | Path = None ,
200+ ):
201+ if outfile is not None and outfile .exists ():
184202 # These plots can be quite expensive to generate, so we skip if they already exist.
185203 return
186204
187205 compressed = (
188- basepath
189- / ".."
190- / "compressor"
191- / "compressed-datasets"
206+ compressed_datasets
192207 / dataset_name
193208 / error_bound
194209 / compressor
195210 / "decompressed.zarr"
196211 )
197- input = (
198- basepath
199- / ".."
200- / "data-loader"
201- / "datasets"
202- / dataset_name
203- / "standardized.zarr"
204- )
212+ input = datasets / dataset_name / "standardized.zarr"
205213
206214 ds = xr .open_dataset (input , chunks = dict (), engine = "zarr" ).compute ()
207215 ds_new = xr .open_dataset (compressed , chunks = dict (), engine = "zarr" ).compute ()
208- ds , ds_new = ds [var ], ds_new [var ]
209216
210217 plotter = PLOTTERS .get (dataset_name , None )
211218 if plotter :
212- plotter ().plot (ds , ds_new , dataset_name , compressor , var , outfile )
219+ plotter ().plot (ds [ var ] , ds_new [ var ] , dataset_name , compressor , var , outfile )
213220 else :
214221 print (f"No plotter found for dataset { dataset_name } " )
215222
216223
217- def plot_variable_rd_curve (df , outfile , distortion_metric ):
224+ def plot_variable_rd_curve (df , distortion_metric , outfile : None | Path = None ):
218225 plt .figure (figsize = (8 , 6 ))
219226 compressors = df ["Compressor" ].unique ()
220227 for comp in compressors :
@@ -268,15 +275,17 @@ def plot_variable_rd_curve(df, outfile, distortion_metric):
268275 )
269276
270277 plt .tight_layout ()
271- plt .savefig (outfile , dpi = 300 )
278+ if outfile is not None :
279+ with outfile .open ("wb" ) as f :
280+ plt .savefig (f , dpi = 300 )
272281 plt .close ()
273282
274283
275284def plot_aggregated_rd_curve (
276285 normalized_df ,
277- outfile ,
278286 compression_metric ,
279287 distortion_metric ,
288+ outfile : None | Path = None ,
280289 agg = "median" ,
281290 bound_names = ["low" , "mid" , "high" ],
282291):
@@ -367,11 +376,13 @@ def plot_aggregated_rd_curve(
367376 )
368377
369378 plt .tight_layout ()
370- plt .savefig (outfile , dpi = 300 )
379+ if outfile is not None :
380+ with outfile .open ("wb" ) as f :
381+ plt .savefig (f , dpi = 300 )
371382 plt .close ()
372383
373384
374- def plot_bound_violations (df , bound_names , outfile ):
385+ def plot_bound_violations (df , bound_names , outfile : None | Path = None ):
375386 fig , axs = plt .subplots (1 , 3 , figsize = (len (bound_names ) * 6 , 6 ), sharey = True )
376387
377388 for i , bound_name in enumerate (bound_names ):
@@ -401,7 +412,9 @@ def plot_bound_violations(df, bound_names, outfile):
401412 axs [i ].set_ylabel ("" )
402413
403414 fig .tight_layout ()
404- fig .savefig (outfile , dpi = 300 )
415+ if outfile is not None :
416+ with outfile .open ("wb" ) as f :
417+ fig .savefig (f , dpi = 300 )
405418 plt .close ()
406419
407420
0 commit comments