Skip to content

Commit 0b97fe0

Browse files
committed
transform data with transformation sequence and return DataArray
1 parent b1a8c03 commit 0b97fe0

File tree

1 file changed

+25
-6
lines changed

1 file changed

+25
-6
lines changed
Lines changed: 25 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,15 @@
1-
from dataclasses import asdict
2-
31
import numpy as np
42
from xarray import DataArray
53
import networkx as nx
4+
import ome_zarr_models as ozm
5+
6+
7+
def validata_point_shape(point: np.ndarray, transformation_sequence: ozm._v06.coordinate_transforms.Sequence):
8+
for transformation in transformation_sequence.transformations:
9+
assert len(point) == transformation.ndim, "Point ndim doesn't match transformation ndim"
610

7-
def transform(data: np.ndarray, axes: list[str], transformation_graph: nx.DiGraph, input_coordinate_system_name: str, output_coordinate_system_name: str) -> DataArray:
11+
def transform_with_sequence(data: np.ndarray, transformation_sequence: ozm._v06.coordinate_transforms.Sequence,
12+
output_axes: list[str]) -> DataArray:
813
# locate (inside the graph) the coordinate_system classes from the coordinate_system names
914
# first validate the input data wrt to axes and input_coordinate_system
1015
# 1. check that the data shape is (n x len(axes))
@@ -13,6 +18,20 @@ def transform(data: np.ndarray, axes: list[str], transformation_graph: nx.DiGrap
1318
# apply the transformations to the data (code to get inspired from https://github.com/scverse/spatialdata/blob/6652a03b1d66c8902a8f7a159176c51d8c9f823b/src/spatialdata/transformations/operations.py#L212)
1419
# tranform the data
1520
# return the transformed data as tuple (numpy array, output axes from the output coordinate systme)
16-
src_node = input_coordinate_system_name # assuming that node names are the coordinate system names
17-
def validate_input():
18-
pass
21+
22+
H, W, C = data.shape
23+
yy, xx = np.meshgrid(np.arange(H), np.arange(W), indexing='ij')
24+
points = np.stack([xx, yy], axis=-1).reshape(-1, 2)
25+
26+
validata_point_shape(points[0], transformation_sequence)
27+
28+
transformed_points = np.array([transformation_sequence.transform_point(p) for p in points])
29+
x_prime = transformed_points[:, 0].reshape(H, W)
30+
y_prime = transformed_points[:, 1].reshape(H, W)
31+
32+
return xarray.DataArray(data,
33+
coords={
34+
"x_prime": (("y", "x"), x_prime),
35+
"y_prime": (("y", "x"), y_prime),
36+
},
37+
dims=output_axes)

0 commit comments

Comments
 (0)