Skip to content

Commit 945db18

Browse files
committed
Fix align_corners mismatch in AffineTransform
The to_norm_affine call was using hardcoded align_corners=False while affine_grid and grid_sample used self.align_corners (default True). This caused a half-pixel offset between coordinate systems. - Change to_norm_affine to use self.align_corners for consistency - Update test expected values to reflect correct behavior - Add tests for align_corners consistency verification Signed-off-by: Soumya Snigdha Kundu <soumya_snigdha.kundu@kcl.ac.uk>
1 parent 57fdd59 commit 945db18

File tree

2 files changed

+51
-4
lines changed

2 files changed

+51
-4
lines changed

monai/networks/layers/spatial_transforms.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -566,7 +566,7 @@ def forward(
566566
affine=theta,
567567
src_size=src_size[2:],
568568
dst_size=dst_size[2:],
569-
align_corners=False,
569+
align_corners=self.align_corners,
570570
zero_centered=self.zero_centered,
571571
)
572572
if self.reverse_indexing:

tests/networks/layers/test_affine_transform.py

Lines changed: 50 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -154,21 +154,21 @@ def test_zoom_1(self):
154154
affine = torch.as_tensor([[2.0, 0.0, 0.0], [0.0, 1.0, 0.0]])
155155
image = torch.arange(1.0, 13.0).view(1, 1, 3, 4).to(device=torch.device("cpu:0"))
156156
out = AffineTransform()(image, affine, (1, 4))
157-
expected = [[[[2.333333, 3.333333, 4.333333, 5.333333]]]]
157+
expected = [[[[5.0, 6.0, 7.0, 8.0]]]]
158158
np.testing.assert_allclose(out, expected, atol=_rtol)
159159

160160
def test_zoom_2(self):
161161
affine = torch.as_tensor([[2.0, 0.0, 0.0], [0.0, 2.0, 0.0]], dtype=torch.float32)
162162
image = torch.arange(1.0, 13.0).view(1, 1, 3, 4).to(device=torch.device("cpu:0"))
163163
out = AffineTransform((1, 2))(image, affine)
164-
expected = [[[[1.458333, 4.958333]]]]
164+
expected = [[[[5.0, 7.0]]]]
165165
np.testing.assert_allclose(out, expected, atol=1e-5, rtol=_rtol)
166166

167167
def test_zoom_zero_center(self):
168168
affine = torch.as_tensor([[2.0, 0.0, 0.0], [0.0, 2.0, 0.0]], dtype=torch.float32)
169169
image = torch.arange(1.0, 13.0).view(1, 1, 3, 4).to(device=torch.device("cpu:0"))
170170
out = AffineTransform((1, 2), zero_centered=True)(image, affine)
171-
expected = [[[[5.5, 7.5]]]]
171+
expected = [[[[5.0, 8.0]]]]
172172
np.testing.assert_allclose(out, expected, atol=1e-5, rtol=_rtol)
173173

174174
def test_affine_transform_minimum(self):
@@ -380,6 +380,53 @@ def test_forward_3d(self):
380380
np.testing.assert_allclose(actual, expected)
381381
np.testing.assert_allclose(list(theta.shape), [1, 3, 4])
382382

383+
def test_align_corners_consistency(self):
384+
"""
385+
Test that align_corners is consistently used between to_norm_affine and grid_sample.
386+
387+
With an identity affine transform, the output should match the input regardless of
388+
the align_corners setting. This test verifies that the coordinate normalization
389+
in to_norm_affine uses the same align_corners value as affine_grid/grid_sample.
390+
"""
391+
# Create a simple test image
392+
image = torch.arange(1.0, 13.0).view(1, 1, 3, 4)
393+
394+
# Identity affine in pixel space (i, j, k convention with reverse_indexing=True)
395+
identity_affine = torch.eye(3).unsqueeze(0)
396+
397+
# Test with align_corners=True (the default)
398+
xform_true = AffineTransform(align_corners=True)
399+
out_true = xform_true(image, identity_affine)
400+
np.testing.assert_allclose(out_true.numpy(), image.numpy(), atol=1e-5, rtol=_rtol)
401+
402+
# Test with align_corners=False
403+
xform_false = AffineTransform(align_corners=False)
404+
out_false = xform_false(image, identity_affine)
405+
np.testing.assert_allclose(out_false.numpy(), image.numpy(), atol=1e-5, rtol=_rtol)
406+
407+
def test_align_corners_true_translation(self):
408+
"""
409+
Test that translation works correctly with align_corners=True.
410+
411+
This ensures to_norm_affine correctly converts pixel-space translations
412+
to normalized coordinates when align_corners=True.
413+
"""
414+
# 4x4 image
415+
image = torch.arange(1.0, 17.0).view(1, 1, 4, 4)
416+
417+
# Translate by +1 pixel in the j direction (column direction)
418+
# With reverse_indexing=True (default), this is the last spatial dimension
419+
# Positive translation in the affine shifts the sampling grid, resulting in
420+
# the output appearing shifted in the opposite direction
421+
affine = torch.tensor([[[1.0, 0.0, 0.0], [0.0, 1.0, 1.0], [0.0, 0.0, 1.0]]])
422+
423+
xform = AffineTransform(align_corners=True, padding_mode="zeros")
424+
out = xform(image, affine)
425+
426+
# Expected: shift columns left by 1, rightmost column becomes 0
427+
expected = torch.tensor([[[[2, 3, 4, 0], [6, 7, 8, 0], [10, 11, 12, 0], [14, 15, 16, 0]]]], dtype=torch.float32)
428+
np.testing.assert_allclose(out.numpy(), expected.numpy(), atol=1e-4, rtol=_rtol)
429+
383430

384431
if __name__ == "__main__":
385432
unittest.main()

0 commit comments

Comments
 (0)