Skip to content

Commit 061a60e

Browse files
treigermjuntyr
andauthored
Make plotting compatible with notebook demo (#30)
* Make plotting compatible with notebook demo * Make mypy happy * Make bound violations plot conditional * Allow plotting to virtual paths * Move py.typed to the proper location --------- Co-authored-by: Juniper Tyree <juniper.tyree@helsinki.fi>
1 parent 7681219 commit 061a60e

File tree

3 files changed

+51
-33
lines changed

3 files changed

+51
-33
lines changed

src/climatebenchpress/compressor/plotting/plot_metrics.py

Lines changed: 44 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -26,22 +26,27 @@
2626
"zfp-round": "ZFP-ROUND",
2727
"sz3": "SZ3",
2828
"bitround-pco-conservative-rel": "BitRound + PCO",
29-
"bitround-conservative-rel": "BitRound + Zlib",
30-
"stochround": "StochRound + Zlib",
29+
"bitround-conservative-rel": "BitRound + Zstd",
30+
"stochround": "StochRound + Zstd",
3131
"stochround-pco": "StochRound + PCO",
3232
"tthresh": "TTHRESH",
3333
}
3434

3535

3636
def plot_metrics(
37-
basepath: Path = Path(), bound_names: list[str] = ["low", "mid", "high"]
37+
basepath: Path = Path(),
38+
data_loader_base_path: None | Path = None,
39+
bound_names: list[str] = ["low", "mid", "high"],
3840
):
3941
metrics_path = basepath / "metrics"
4042
plots_path = basepath / "plots"
43+
datasets = (data_loader_base_path or basepath) / "datasets"
44+
compressed_datasets = basepath / "compressed-datasets"
4145

4246
df = pd.read_csv(metrics_path / "all_results.csv")
4347
plot_per_variable_metrics(
44-
basepath=basepath,
48+
datasets=datasets,
49+
compressed_datasets=compressed_datasets,
4550
plots_path=plots_path,
4651
all_results=df,
4752
)
@@ -56,9 +61,9 @@ def plot_metrics(
5661
for metric in ["Relative MAE", "Relative DSSIM", "Relative MaxAbsError"]:
5762
plot_aggregated_rd_curve(
5863
normalized_df,
59-
plots_path / f"rd_curve_{metric.lower().replace(' ', '_')}.pdf",
6064
compression_metric="Relative CR",
6165
distortion_metric=metric,
66+
outfile=plots_path / f"rd_curve_{metric.lower().replace(' ', '_')}.pdf",
6267
agg="median",
6368
bound_names=bound_names,
6469
)
@@ -138,7 +143,10 @@ def get_normalizer(row):
138143

139144

140145
def plot_per_variable_metrics(
141-
basepath: Path, plots_path: Path, all_results: pd.DataFrame
146+
datasets: Path,
147+
compressed_datasets: Path,
148+
plots_path: Path,
149+
all_results: pd.DataFrame,
142150
):
143151
"""Creates all the plots which only depend on a single variable."""
144152
for dataset in all_results["Dataset"].unique():
@@ -155,8 +163,9 @@ def plot_per_variable_metrics(
155163
continue
156164
plot_variable_rd_curve(
157165
df[df["Variable"] == var],
158-
dataset_plots_path / f"{var}_compression_ratio_{metric_name}.pdf",
159166
distortion_metric=dist_metric,
167+
outfile=dataset_plots_path
168+
/ f"{var}_compression_ratio_{metric_name}.pdf",
160169
)
161170

162171
error_bounds = df[df["Variable"] == var]["Error Bound"].unique()
@@ -170,51 +179,49 @@ def plot_per_variable_metrics(
170179
for comp in compressors:
171180
print(f"Plotting {var} error for {comp}...")
172181
plot_variable_error(
173-
basepath,
182+
datasets,
183+
compressed_datasets,
174184
dataset,
175185
err_bound,
176186
comp,
177187
var,
178-
err_bound_path / f"{var}_{comp}.png",
188+
outfile=err_bound_path / f"{var}_{comp}.png",
179189
)
180190

181191

182-
def plot_variable_error(basepath, dataset_name, error_bound, compressor, var, outfile):
183-
if outfile.exists():
192+
def plot_variable_error(
193+
datasets: Path,
194+
compressed_datasets: Path,
195+
dataset_name: str,
196+
error_bound: str,
197+
compressor: str,
198+
var: str,
199+
outfile: None | Path = None,
200+
):
201+
if outfile is not None and outfile.exists():
184202
# These plots can be quite expensive to generate, so we skip if they already exist.
185203
return
186204

187205
compressed = (
188-
basepath
189-
/ ".."
190-
/ "compressor"
191-
/ "compressed-datasets"
206+
compressed_datasets
192207
/ dataset_name
193208
/ error_bound
194209
/ compressor
195210
/ "decompressed.zarr"
196211
)
197-
input = (
198-
basepath
199-
/ ".."
200-
/ "data-loader"
201-
/ "datasets"
202-
/ dataset_name
203-
/ "standardized.zarr"
204-
)
212+
input = datasets / dataset_name / "standardized.zarr"
205213

206214
ds = xr.open_dataset(input, chunks=dict(), engine="zarr").compute()
207215
ds_new = xr.open_dataset(compressed, chunks=dict(), engine="zarr").compute()
208-
ds, ds_new = ds[var], ds_new[var]
209216

210217
plotter = PLOTTERS.get(dataset_name, None)
211218
if plotter:
212-
plotter().plot(ds, ds_new, dataset_name, compressor, var, outfile)
219+
plotter().plot(ds[var], ds_new[var], dataset_name, compressor, var, outfile)
213220
else:
214221
print(f"No plotter found for dataset {dataset_name}")
215222

216223

217-
def plot_variable_rd_curve(df, outfile, distortion_metric):
224+
def plot_variable_rd_curve(df, distortion_metric, outfile: None | Path = None):
218225
plt.figure(figsize=(8, 6))
219226
compressors = df["Compressor"].unique()
220227
for comp in compressors:
@@ -268,15 +275,17 @@ def plot_variable_rd_curve(df, outfile, distortion_metric):
268275
)
269276

270277
plt.tight_layout()
271-
plt.savefig(outfile, dpi=300)
278+
if outfile is not None:
279+
with outfile.open("wb") as f:
280+
plt.savefig(f, dpi=300)
272281
plt.close()
273282

274283

275284
def plot_aggregated_rd_curve(
276285
normalized_df,
277-
outfile,
278286
compression_metric,
279287
distortion_metric,
288+
outfile: None | Path = None,
280289
agg="median",
281290
bound_names=["low", "mid", "high"],
282291
):
@@ -367,11 +376,13 @@ def plot_aggregated_rd_curve(
367376
)
368377

369378
plt.tight_layout()
370-
plt.savefig(outfile, dpi=300)
379+
if outfile is not None:
380+
with outfile.open("wb") as f:
381+
plt.savefig(f, dpi=300)
371382
plt.close()
372383

373384

374-
def plot_bound_violations(df, bound_names, outfile):
385+
def plot_bound_violations(df, bound_names, outfile: None | Path = None):
375386
fig, axs = plt.subplots(1, 3, figsize=(len(bound_names) * 6, 6), sharey=True)
376387

377388
for i, bound_name in enumerate(bound_names):
@@ -401,7 +412,9 @@ def plot_bound_violations(df, bound_names, outfile):
401412
axs[i].set_ylabel("")
402413

403414
fig.tight_layout()
404-
fig.savefig(outfile, dpi=300)
415+
if outfile is not None:
416+
with outfile.open("wb") as f:
417+
fig.savefig(f, dpi=300)
405418
plt.close()
406419

407420

src/climatebenchpress/compressor/plotting/variable_plotters.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from abc import ABC, abstractmethod
2+
from pathlib import Path
23

34
import cartopy.crs as ccrs
45
import matplotlib.colors as mcolors
@@ -16,7 +17,9 @@ def __init__(self):
1617
def plot_fields(self, fig, ax, ds, ds_new, dataset_name, var):
1718
pass
1819

19-
def plot(self, ds, ds_new, dataset_name, compressor, var, outfile):
20+
def plot(
21+
self, ds, ds_new, dataset_name, compressor, var, outfile: None | Path = None
22+
):
2023
fig, ax = plt.subplots(
2124
nrows=1,
2225
ncols=3,
@@ -32,7 +35,9 @@ def plot(self, ds, ds_new, dataset_name, compressor, var, outfile):
3235
ax[2].set_title("Error")
3336
fig.suptitle(f"{var} Error for {dataset_name} ({compressor})")
3437
fig.tight_layout()
35-
fig.savefig(outfile, dpi=300)
38+
if outfile is not None:
39+
with outfile.open("wb") as f:
40+
fig.savefig(f, dpi=300)
3641
plt.close()
3742

3843

0 commit comments

Comments
 (0)