Skip to content

Commit bab671b

Browse files
authored
Merge pull request #82 from fish-pace/copilot/fix-xarray-time-dimension
Fix NaN values when xarray Dataset variable has a time dimension
2 parents 0d1a4ae + 5334edf commit bab671b

2 files changed

Lines changed: 394 additions & 2 deletions

File tree

src/point_collocation/core/engine.py

Lines changed: 133 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,10 @@
5151
_VALID_OPEN_METHODS = {"dataset", "datatree-merge"}
5252
_VALID_SPATIAL_METHODS = {"nearest", "xoak"}
5353

54+
# Time dimension names used as a fallback when cf_xarray is not installed or
55+
# when the dataset lacks CF-convention axis/units attributes. Tried in order.
56+
_TIME_DIM_NAMES = ["time", "Time", "TIME"]
57+
5458

5559
def matchup(
5660
plan: "Plan",
@@ -351,6 +355,96 @@ def _ensure_coords(ds: xr.Dataset, lon_name: str, lat_name: str) -> xr.Dataset:
351355
return ds
352356

353357

358+
def _find_time_dim(ds: xr.Dataset) -> str | None:
    """Locate the time dimension of *ds*.

    Candidate names are gathered from two sources and checked in order:

    1. The names ``cf_xarray`` reports for the CF ``"T"`` axis
       (``ds.cf.axes["T"]``), when ``cf_xarray`` can be imported and the
       lookup succeeds.
    2. The conventional names listed in :data:`_TIME_DIM_NAMES`.

    A candidate is accepted only if it is one of ``ds.dims``; a time
    coordinate that is not a dimension (e.g. a scalar) is skipped, since
    only a dimensional time axis needs handling during extraction.

    Returns
    -------
    str | None
        The first candidate that is a dimension of *ds*, or ``None`` when
        no candidate matches.
    """
    cf_candidates: list = []
    try:
        import cf_xarray  # noqa: F401  (registers the .cf accessor)

        cf_candidates = list(ds.cf.axes.get("T", []))
    except ImportError:
        # cf_xarray is optional; rely on the name-based candidates only.
        pass
    except (AttributeError, KeyError):
        # cf_xarray is importable but the "T" axis lookup failed on this
        # dataset; rely on the name-based candidates only.
        pass

    # CF-detected names take precedence over the conventional fallbacks.
    ordered_candidates = cf_candidates + list(_TIME_DIM_NAMES)
    return next(
        (candidate for candidate in ordered_candidates if candidate in ds.dims),
        None,
    )
397+
398+
399+
def _select_time(
    da: xr.DataArray,
    time_dim: str,
    point_time: object,
) -> xr.DataArray:
    """Reduce *da* along *time_dim* to a single time step.

    Parameters
    ----------
    da:
        DataArray produced after spatial selection (lat/lon already resolved).
    time_dim:
        Name of the time dimension to handle.
    point_time:
        Timestamp of the observation point (anything ``pd.Timestamp``
        accepts).  Used to find the nearest time step when *da* has more
        than one.  Timezone-aware values are converted to naive UTC before
        the lookup.

    Returns
    -------
    xr.DataArray
        * *da* unchanged if *time_dim* is not one of ``da.dims``.
        * *da* with the time dimension squeezed out when there is exactly
          one time step.
        * Otherwise the time step nearest to *point_time*; the first time
          step is returned instead when *point_time* cannot be converted
          to a valid timestamp or the nearest-neighbour lookup fails.
    """
    if time_dim not in da.dims:
        return da

    if da.sizes[time_dim] == 1:
        return da.squeeze(time_dim)

    # Multiple time steps: select the one nearest to the point timestamp.
    try:
        ts = pd.Timestamp(point_time)
        if pd.isna(ts):
            raise ValueError("NaT")
        if ts.tzinfo is not None:
            # A tz-aware timestamp cannot be matched against a tz-naive
            # datetime64 time coordinate: the lookup fails and the code
            # below would silently return the first time step.  Convert to
            # naive UTC so the nearest step is actually found.
            # NOTE(review): assumes the dataset's time axis is in UTC.
            ts = ts.tz_convert(None)
        return da.sel({time_dim: ts}, method="nearest")
    except (TypeError, ValueError, KeyError):
        # Deliberate best-effort fallback to the first time step:
        # - TypeError / ValueError: point_time cannot be converted to a
        #   Timestamp, or is NaT.
        # - KeyError: the time coordinate is absent or the sel fails.
        return da.isel({time_dim: 0})
446+
447+
354448
def _check_geometry(
355449
ds: xr.Dataset,
356450
lon_name: str,
@@ -662,6 +756,10 @@ def _execute_plan(
662756
# the dataset itself may not have a time coordinate at all.
663757
granule_time = gm.begin + (gm.end - gm.begin) / 2
664758

759+
# Detect time dimension once per granule so that
760+
# extraction functions can handle (time, lat, lon) variables.
761+
time_dim = _find_time_dim(ds)
762+
665763
if spatial_method == "xoak":
666764
# Build the k-d tree index once for all points in this
667765
# granule instead of rebuilding it per point. This
@@ -674,7 +772,7 @@ def _execute_plan(
674772
row["granule_id"] = gm.granule_id
675773
row["granule_time"] = granule_time
676774
rows_for_granule.append(row)
677-
_extract_xoak_batch(ds, rows_for_granule, variables, lon_name, lat_name)
775+
_extract_xoak_batch(ds, rows_for_granule, variables, lon_name, lat_name, time_dim)
678776
output_rows.extend(rows_for_granule)
679777
batch_rows.extend(rows_for_granule)
680778
else:
@@ -683,7 +781,7 @@ def _execute_plan(
683781
row["pc_id"] = pt_idx
684782
row["granule_id"] = gm.granule_id
685783
row["granule_time"] = granule_time
686-
_extract_nearest(ds, row, variables, lon_name, lat_name)
784+
_extract_nearest(ds, row, variables, lon_name, lat_name, time_dim)
687785
output_rows.append(row)
688786
batch_rows.append(row)
689787

@@ -865,12 +963,21 @@ def _extract_nearest(
865963
variables: list[str],
866964
lon_name: str,
867965
lat_name: str,
966+
time_dim: str | None = None,
868967
) -> None:
869968
"""Extract values using ``ds.sel(..., method='nearest')`` (1-D coords).
870969
871970
Modifies *row* in-place, including ``granule_lat`` and ``granule_lon``
872971
columns for the matched grid location. ``granule_time`` is set by the
873972
caller from granule metadata before this function is called.
973+
974+
Parameters
975+
----------
976+
time_dim:
977+
Name of the time dimension in *ds*, as detected by
978+
:func:`_find_time_dim`. When not ``None``, each variable is
979+
squeezed or nearest-selected along this dimension after spatial
980+
selection so that the result is always free of the time axis.
874981
"""
875982
# Extract the actual matched coordinates (nearest-neighbour grid position).
876983
try:
@@ -888,6 +995,8 @@ def _extract_nearest(
888995
{lat_name: row["lat"], lon_name: row["lon"]},
889996
method="nearest",
890997
)
998+
if time_dim is not None:
999+
selected = _select_time(selected, time_dim, row.get("time"))
8911000
if selected.ndim == 0:
8921001
row[var] = float(selected)
8931002
else:
@@ -905,6 +1014,7 @@ def _extract_xoak(
9051014
variables: list[str],
9061015
lon_name: str,
9071016
lat_name: str,
1017+
time_dim: str | None = None,
9081018
) -> None:
9091019
"""Extract values using xoak nearest-neighbour (1-D or 2-D lat/lon arrays).
9101020
@@ -917,6 +1027,14 @@ def _extract_xoak(
9171027
index over both dimensions.
9181028
9191029
Modifies *row* in-place.
1030+
1031+
Parameters
1032+
----------
1033+
time_dim:
1034+
Name of the time dimension in *ds*, as detected by
1035+
:func:`_find_time_dim`. When not ``None``, each variable is
1036+
squeezed or nearest-selected along this dimension after spatial
1037+
selection so that the result is always free of the time axis.
9201038
"""
9211039
try:
9221040
from xoak.tree_adapters import SklearnKDTreeAdapter # type: ignore[import-untyped]
@@ -973,6 +1091,8 @@ def _extract_xoak(
9731091
# variables become scalar (0-D) and 3-D variables (e.g. Rrs with
9741092
# a wavelength dimension) become 1-D.
9751093
squeezed = selected[var].squeeze()
1094+
if time_dim is not None:
1095+
squeezed = _select_time(squeezed, time_dim, row.get("time"))
9761096
if squeezed.ndim == 0:
9771097
row[var] = float(squeezed)
9781098
else:
@@ -994,6 +1114,7 @@ def _extract_xoak_batch(
9941114
variables: list[str],
9951115
lon_name: str,
9961116
lat_name: str,
1117+
time_dim: str | None = None,
9971118
) -> None:
9981119
"""Extract values for all *rows* using a single xoak k-d tree index.
9991120
@@ -1006,6 +1127,14 @@ def _extract_xoak_batch(
10061127
``SklearnKDTreeAdapter``.
10071128
10081129
Modifies each dict in *rows* in-place.
1130+
1131+
Parameters
1132+
----------
1133+
time_dim:
1134+
Name of the time dimension in *ds*, as detected by
1135+
:func:`_find_time_dim`. When not ``None``, each variable is
1136+
squeezed or nearest-selected along this dimension after spatial
1137+
selection so that the result is always free of the time axis.
10091138
"""
10101139
try:
10111140
from xoak.tree_adapters import SklearnKDTreeAdapter # type: ignore[import-untyped]
@@ -1084,6 +1213,8 @@ def _extract_xoak_batch(
10841213
# size-1 spatial dims so extra dims (e.g. wavelength) are
10851214
# kept intact.
10861215
point_data = var_data.isel({query_dim: i}).squeeze()
1216+
if time_dim is not None:
1217+
point_data = _select_time(point_data, time_dim, row.get("time"))
10871218
if point_data.ndim == 0:
10881219
row[var] = float(point_data)
10891220
else:

0 commit comments

Comments
 (0)