feat: Add GPU-accelerated operations via PyTorch #13
base: main
Conversation
Consolidates PRs #4-8 into a single feature branch:

- phase_cross_correlation: GPU FFT via torch.fft (~46x speedup)
- shift_array: GPU grid_sample for subpixel shifts (~6.7x speedup)
- match_histograms: GPU sort/quantile mapping (~13.3x speedup)
- block_reduce: GPU avg_pool2d (~4x speedup)
- compute_ssim: GPU conv2d for local statistics (~6.4x speedup)

All functions include automatic CPU fallback when CUDA is unavailable. Replaces the cupy/cucim dependency with PyTorch for broader compatibility.

🤖 Generated with [Claude Code](https://claude.com/claude-code)
Co-Authored-By: Claude Opus 4.5 <[email protected]>
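For context, the detection-plus-fallback pattern the description refers to can be sketched as follows (a minimal sketch, assuming the `CUDA_AVAILABLE` flag and helper naming used in this PR; the exact wiring in src/tilefusion/utils.py may differ):

```python
import numpy as np
from skimage.measure import block_reduce as _block_reduce_cpu

try:
    import torch
    CUDA_AVAILABLE = torch.cuda.is_available()
except ImportError:  # torch not installed -> pure CPU mode
    torch = None
    CUDA_AVAILABLE = False

def block_reduce(arr: np.ndarray, block_size: tuple) -> np.ndarray:
    """Mean block-reduce: GPU avg_pool2d when CUDA is usable, else skimage."""
    if CUDA_AVAILABLE and arr.ndim == 2:
        t = torch.from_numpy(arr).float().cuda().unsqueeze(0).unsqueeze(0)
        out = torch.nn.functional.avg_pool2d(t, block_size, stride=block_size)
        return out.squeeze().cpu().numpy()
    return _block_reduce_cpu(arr, block_size, np.mean)
```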
- Add type hints to phase_cross_correlation, shift_array, match_histograms, block_reduce, compute_ssim
- Add return type hints to to_numpy and to_device
- Import Callable, Any, Union from typing

🤖 Generated with [Claude Code](https://claude.com/claude-code)
Co-Authored-By: Claude Opus 4.5 <[email protected]>
- Fix incorrect pixel assignment in _match_histograms_torch; the previous code used unnecessary indexing that permuted results incorrectly
- Simplify to_device return type from Union[Any, np.ndarray] to Any
- Remove unused Union import
- Add pixel-by-pixel test comparing GPU vs skimage results

🤖 Generated with [Claude Code](https://claude.com/claude-code)
Co-Authored-By: Claude Opus 4.5 <[email protected]>
- Remove redundant `torch is not None` check in to_numpy
- Add type hint to _shift_array_torch shift_vec parameter
- Fix shift_array CPU path to compute in float64 for API consistency (preserve_dtype=False now returns float on both GPU and CPU paths)

🤖 Generated with [Claude Code](https://claude.com/claude-code)
Co-Authored-By: Claude Opus 4.5 <[email protected]>
Pull request overview
This pull request adds GPU-accelerated image processing operations to the tilefusion library by replacing the cupy/cucim stack with PyTorch. The implementation provides automatic CUDA detection with CPU fallback for broader compatibility.
Key Changes:
- Implements GPU-accelerated versions of 5 core functions using PyTorch (FFT, grid_sample, sort/quantile, avg_pool2d, conv2d)
- Adds automatic CUDA availability detection with transparent CPU fallback via scipy/skimage
- Provides comprehensive test coverage for GPU implementations, CPU fallback paths, and dtype preservation
Reviewed changes
Copilot reviewed 8 out of 8 changed files in this pull request and generated 15 comments.
| File | Description |
|---|---|
| src/tilefusion/utils.py | Core implementation of GPU-accelerated functions with PyTorch backend and CPU fallbacks |
| tests/test_fft.py | Tests for GPU phase cross-correlation with known shifts and subpixel refinement |
| tests/test_shift_array.py | Tests for GPU array shifting including integer/subpixel shifts and edge cases |
| tests/test_histogram_match.py | Tests for GPU histogram matching with range/correlation validation |
| tests/test_block_reduce.py | Tests for GPU block reduction with 2D/3D arrays and various block sizes |
| tests/test_ssim.py | Tests for GPU SSIM computation comparing against skimage reference |
| tests/test_cpu_fallback.py | Tests verifying CPU fallback behavior and dtype preservation across all functions |
| tests/conftest.py | Adds force_cpu fixture to test CPU fallback paths by disabling CUDA |
```python
# inv_indices[i] = rank of pixel i, so look up ref value at that scaled rank
interp_values = ref_sorted[
    (inv_indices.float() / len(img) * len(ref)).long().clamp(0, len(ref) - 1)
]
```
Copilot AI · Jan 4, 2026
The histogram matching implementation has a potential issue with the quantile mapping calculation. The line (inv_indices.float() / len(img) * len(ref)) should use (len(img) - 1) and (len(ref) - 1) to properly map indices to the range [0, len-1]. Without the -1, the calculation can produce indices equal to len(ref) which are then clamped, leading to incorrect mapping at the high end of the distribution.
Suggested change:

```python
# inv_indices[i] = rank of pixel i in [0, len(img)-1]; scale to [0, len(ref)-1]
interp_values = ref_sorted[
    (
        inv_indices.float()
        / (len(img) - 1)
        * (len(ref) - 1)
    ).long().clamp(0, len(ref) - 1)
]
```
[Claude Code] Skipped - current implementation passes pixel-by-pixel test vs skimage; the clamp(0, len(ref)-1) handles edge cases correctly.
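For reference, a pixel-by-pixel test of the kind this reply mentions might look like the sketch below (hedged: the fixture name and tolerance are illustrative, not copied from the PR):

```python
import numpy as np
from skimage.exposure import match_histograms as skimage_match_histograms

from tilefusion.utils import match_histograms  # function under test

def test_match_histograms_pixelwise(rng):
    """GPU result should agree with skimage pixel-by-pixel."""
    img = rng.random((128, 128)).astype(np.float32)
    ref = rng.random((128, 128)).astype(np.float32) * 2 + 1

    expected = skimage_match_histograms(img, ref)
    result = match_histograms(img, ref)

    # float32 GPU math vs float64 skimage: allow a small absolute tolerance
    np.testing.assert_allclose(result, expected, atol=1e-4)
```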
```python
# Create uniform window
window = torch.ones(1, 1, win_size, win_size, device="cuda") / (win_size * win_size)

# Convert to tensors (1, 1, H, W)
img1 = torch.from_numpy(arr1).float().cuda().unsqueeze(0).unsqueeze(0)
img2 = torch.from_numpy(arr2).float().cuda().unsqueeze(0).unsqueeze(0)
```
Copilot AI · Jan 4, 2026
The _compute_ssim_torch function always creates the window tensor on CUDA device with hardcoded device="cuda". However, this doesn't account for multi-GPU systems where the default device might not be cuda:0. Consider using torch.device("cuda") or better yet, derive the device from the input tensors. Similarly, in other torch functions throughout the file, .cuda() calls should be reviewed to ensure they work correctly in multi-GPU environments.
Suggested change:

```python
# Select CUDA device (assumes CUDA_AVAILABLE)
device = torch.device("cuda")

# Convert to tensors (1, 1, H, W) on the selected device
img1 = (
    torch.from_numpy(arr1)
    .to(device=device, dtype=torch.float32)
    .unsqueeze(0)
    .unsqueeze(0)
)
img2 = (
    torch.from_numpy(arr2)
    .to(device=device, dtype=torch.float32)
    .unsqueeze(0)
    .unsqueeze(0)
)

# Create uniform window on the same device as the images
window = torch.ones(
    1, 1, win_size, win_size, device=img1.device, dtype=img1.dtype
) / (win_size * win_size)
```
[Claude Code] Skipped - multi-GPU support is out of scope for this consolidation PR. Single-GPU with default device works for the target use case.
```python
def _block_reduce_torch(arr: np.ndarray, block_size: tuple) -> np.ndarray:
    """GPU block reduce using torch.nn.functional.avg_pool2d."""
    ndim = arr.ndim

    if ndim == 2:
        kernel = (block_size[0], block_size[1])
        t = torch.from_numpy(arr).float().cuda().unsqueeze(0).unsqueeze(0)
        out = torch.nn.functional.avg_pool2d(t, kernel, stride=kernel)
        return out.squeeze().cpu().numpy()

    elif ndim == 3:
        # For 3D arrays (C, H, W), extract spatial kernel from block_size
        if len(block_size) == 3:
            # block_size is (c_factor, h_factor, w_factor)
            # Only use spatial dimensions for avg_pool2d
            kernel = (block_size[1], block_size[2])
        else:
            # block_size is (h_factor, w_factor) - apply to spatial dims
            kernel = (block_size[0], block_size[1])
        t = torch.from_numpy(arr).float().cuda().unsqueeze(0)
        out = torch.nn.functional.avg_pool2d(t, kernel, stride=kernel)
        return out.squeeze(0).cpu().numpy()

    return _block_reduce_cpu(arr, block_size, np.mean)
```
Copilot AI · Jan 4, 2026
The _block_reduce_torch function handles 2D and 3D arrays but falls back to CPU for any other dimensionality. However, the function doesn't validate that the dimensions are compatible with the block_size tuple length. For a 2D array with a 3-element block_size tuple, or vice versa, this could cause confusing behavior. Consider adding validation to ensure block_size matches the array dimensionality before attempting GPU operations.
[Claude Code] Skipped - mismatched dimensions safely fall back to CPU path via the final return statement.
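If the suggested validation were ever added, a minimal form could be the following (hypothetical helper, not part of this PR):

```python
def _validate_block_size(ndim: int, block_size: tuple) -> None:
    """Hypothetical guard: 2D arrays need a 2-tuple; 3D accepts 2- or 3-tuples."""
    ok = (ndim == 2 and len(block_size) == 2) or (
        ndim == 3 and len(block_size) in (2, 3)
    )
    if not ok:
        raise ValueError(
            f"block_size {block_size} incompatible with {ndim}-dimensional array"
        )
```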
```python
# Guard against degenerate arrays (need at least 2 pixels for interpolation)
if h < 2 or w < 2:
    return _shift_cpu(arr, shift=shift_vec, order=1, prefilter=False)
```
Copilot AI · Jan 4, 2026
The shift_array function falls back to CPU for arrays with dimensions less than 2x2, but the check if h < 2 or w < 2 doesn't match the actual requirement for grid_sample. PyTorch's grid_sample can handle 1xN or Nx1 arrays. The more accurate check would be if h == 1 and w == 1 since that's the only truly degenerate case where interpolation cannot work. This overly conservative fallback reduces GPU acceleration opportunities for small but valid arrays.
[Claude Code] False positive - the check is correct. The normalization uses (w-1) and (h-1) as divisors, which would cause division by zero for 1-pixel dimensions. The guard is necessary.
```python
def _phase_cross_correlation_torch(
    reference_image: np.ndarray, moving_image: np.ndarray, upsample_factor: int = 1
) -> tuple:
    """GPU phase cross-correlation using torch FFT."""
    ref = torch.from_numpy(reference_image.astype(np.float32)).cuda()
    mov = torch.from_numpy(moving_image.astype(np.float32)).cuda()

    # Compute cross-power spectrum
    ref_fft = torch.fft.fft2(ref)
    mov_fft = torch.fft.fft2(mov)
    cross_power = ref_fft * torch.conj(mov_fft)
    cross_power = cross_power / (torch.abs(cross_power) + _FFT_EPS)

    # Inverse FFT to get correlation
    correlation = torch.fft.ifft2(cross_power).real

    # Find peak
    max_idx = torch.argmax(correlation)
    h, w = correlation.shape
    peak_y = (max_idx // w).item()
    peak_x = (max_idx % w).item()

    # Handle wraparound for negative shifts
    if peak_y > h // 2:
        peak_y -= h
    if peak_x > w // 2:
        peak_x -= w

    shift = np.array([float(peak_y), float(peak_x)])

    # Subpixel refinement if requested
    if upsample_factor > 1:
        shift = _subpixel_refine_torch(correlation, peak_y, peak_x, h, w)

    return shift, 0.0, 0.0
```
Copilot AI · Jan 4, 2026
All GPU tensor operations in this file lack explicit memory cleanup. While PyTorch has automatic garbage collection, for large arrays in a tight loop (common in image stitching), it's better to explicitly call torch.cuda.empty_cache() or delete intermediate tensors. Consider adding a try-finally block or context manager to ensure GPU memory is properly released, especially for the large FFT operations in _phase_cross_correlation_torch.
[Claude Code] Skipped - explicit torch.cuda.empty_cache() forces synchronization which can slow down processing. PyTorch's automatic GC handles memory well in practice.
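For readers who do hit memory pressure in tight stitching loops, one hedged alternative to per-call `empty_cache()` is to drop references to large intermediates promptly and synchronize at most once, outside the loop (illustrative sketch, not part of this PR):

```python
import torch

def correlate_many(pairs):
    """Illustrative loop: free large tensors promptly, sync once at the end."""
    results = []
    for ref_np, mov_np in pairs:
        ref = torch.from_numpy(ref_np).cuda()
        mov = torch.from_numpy(mov_np).cuda()
        corr = torch.fft.ifft2(
            torch.fft.fft2(ref) * torch.conj(torch.fft.fft2(mov))
        ).real
        results.append(torch.argmax(corr).item())
        del ref, mov, corr  # let the caching allocator reuse these blocks
    torch.cuda.empty_cache()  # forces a sync; call once after the loop, if at all
    return results
```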
```python
def test_integer_shift(self, rng):
    """Test integer shift matches scipy."""
    arr = rng.random((256, 256)).astype(np.float32)
```
Copilot AI · Jan 4, 2026
The test uses rng.random() without specifying dtype, which defaults to float64, but then casts to float32. This is inefficient and could be simplified to rng.random((256, 256), dtype=np.float32) directly. The same pattern appears throughout this test file on multiple lines (18, 25-26, 31-32, 40).
[Claude Code] Skipped - low value change. The cast to float32 is explicit and doesn't affect test correctness.
```python
img = rng.random((256, 256)).astype(np.float32)
ref = rng.random((256, 256)).astype(np.float32) * 2 + 1
```
Copilot AI · Jan 4, 2026
The test uses rng.random() without specifying dtype, which defaults to float64, but then casts to float32. This is inefficient and could be simplified to rng.random((256, 256), dtype=np.float32) directly. The same pattern appears throughout this test file on multiple lines (18-19, 27-28, 40, 46-47, 53-54).
[Claude Code] Skipped - low value change. The cast to float32 is explicit and doesn't affect test correctness.
```python
def _shift_array_torch(arr: np.ndarray, shift_vec: tuple[float, float]) -> np.ndarray:
    """GPU shift using torch.nn.functional.grid_sample."""
    h, w = arr.shape

    # Guard against degenerate arrays (need at least 2 pixels for interpolation)
    if h < 2 or w < 2:
        return _shift_cpu(arr, shift=shift_vec, order=1, prefilter=False)

    dy, dx = float(shift_vec[0]), float(shift_vec[1])

    # Create pixel coordinate grids
    y_coords = torch.arange(h, device="cuda", dtype=torch.float32)
    x_coords = torch.arange(w, device="cuda", dtype=torch.float32)
    grid_y, grid_x = torch.meshgrid(y_coords, x_coords, indexing="ij")

    # Apply shift: to shift output by (dy, dx), sample from (y-dy, x-dx)
    sample_y = grid_y - dy
    sample_x = grid_x - dx

    # Normalize to [-1, 1] for grid_sample (align_corners=True)
    sample_x = 2 * sample_x / (w - 1) - 1
    sample_y = 2 * sample_y / (h - 1) - 1

    # Stack to (H, W, 2) with (x, y) order, add batch dim -> (1, H, W, 2)
    grid = torch.stack([sample_x, sample_y], dim=-1).unsqueeze(0)

    # Input: (1, 1, H, W)
    t = torch.from_numpy(arr).float().cuda().unsqueeze(0).unsqueeze(0)

    # grid_sample with bilinear interpolation
    out = F.grid_sample(t, grid, mode="bilinear", padding_mode="zeros", align_corners=True)

    return out.squeeze().cpu().numpy()
```
Copilot AI · Jan 4, 2026
The GPU functions like _shift_array_torch, _match_histograms_torch, etc., don't have any error handling for CUDA out-of-memory errors. In production use with large images, GPU memory can be exhausted. Consider wrapping GPU operations in try-except blocks to catch torch.cuda.OutOfMemoryError and automatically fall back to CPU, providing a more robust user experience.
[Claude Code] Skipped - OOM error handling is out of scope for this consolidation PR. Would require significant refactoring.
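If OOM fallback were added later, one hedged shape for it is a thin wrapper (illustrative; `torch.cuda.OutOfMemoryError` exists in recent PyTorch releases, and the helper names follow this PR):

```python
def shift_array_oom_safe(arr, shift_vec):
    """Illustrative wrapper: try the GPU path, fall back to CPU on OOM."""
    try:
        return _shift_array_torch(arr, shift_vec)
    except torch.cuda.OutOfMemoryError:
        torch.cuda.empty_cache()  # release cached blocks before retrying on CPU
        return _shift_cpu(arr, shift=shift_vec, order=1, prefilter=False)
```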
```python
def to_device(arr) -> Any:
    """Move array to GPU if available, else return numpy array.

    Returns torch.Tensor on GPU if CUDA available, else np.ndarray.
    """
    if CUDA_AVAILABLE:
        return torch.from_numpy(np.asarray(arr)).cuda()
    return np.asarray(arr)
```
Copilot AI · Jan 4, 2026
The function signature for to_device returns Any type, but the docstring and implementation clearly show it returns either torch.Tensor or np.ndarray. The return type should be more specific: torch.Tensor | np.ndarray (or Union[torch.Tensor, np.ndarray] for older Python versions) to provide proper type hints.
[Claude Code] Intentional - using Any because Union[torch.Tensor, np.ndarray] requires conditional imports when torch may not be installed. The docstring documents actual behavior.
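For completeness, the usual way to get the precise hint without a hard torch import at runtime is `typing.TYPE_CHECKING` (a sketch of the standard pattern, not what this PR does):

```python
from __future__ import annotations

from typing import TYPE_CHECKING, Union

import numpy as np

if TYPE_CHECKING:  # evaluated only by type checkers, never at runtime
    import torch

def to_device(arr) -> Union["torch.Tensor", np.ndarray]:
    """Same behavior as before; only the annotation changes."""
    ...
```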
```python
def _subpixel_refine_torch(correlation, peak_y, peak_x, h, w):
    """Subpixel refinement using parabolic fit around peak."""
    py = peak_y % h
    px = peak_x % w

    y_indices = [(py - 1) % h, py, (py + 1) % h]
    x_indices = [(px - 1) % w, px, (px + 1) % w]

    neighborhood = torch.zeros(3, 3, device="cuda")
    for i, yi in enumerate(y_indices):
        for j, xj in enumerate(x_indices):
            neighborhood[i, j] = correlation[yi, xj]

    center_val = neighborhood[1, 1].item()

    # Y direction parabolic fit
    if neighborhood[0, 1].item() != center_val or neighborhood[2, 1].item() != center_val:
        denom = 2 * (2 * center_val - neighborhood[0, 1].item() - neighborhood[2, 1].item())
        dy = (
            (neighborhood[0, 1].item() - neighborhood[2, 1].item()) / denom
            if abs(denom) > _PARABOLIC_EPS
            else 0.0
        )
    else:
        dy = 0.0

    # X direction parabolic fit
    if neighborhood[1, 0].item() != center_val or neighborhood[1, 2].item() != center_val:
        denom = 2 * (2 * center_val - neighborhood[1, 0].item() - neighborhood[1, 2].item())
        dx = (
            (neighborhood[1, 0].item() - neighborhood[1, 2].item()) / denom
            if abs(denom) > _PARABOLIC_EPS
            else 0.0
        )
    else:
        dx = 0.0

    dy = max(-0.5, min(0.5, dy))
    dx = max(-0.5, min(0.5, dx))

    return np.array([float(peak_y) + dy, float(peak_x) + dx])
```
Copilot AI · Jan 4, 2026
The subpixel refinement implementation uses a 3x3 neighborhood but doesn't validate that the array dimensions are at least 3x3. If the correlation map is smaller than 3x3 (which could happen with very small input images), the modulo indexing (py - 1) % h could access incorrect positions. Consider adding a size check before applying subpixel refinement or documenting minimum image size requirements.
[Claude Code] Skipped - edge case for tiny images. Subpixel refinement is only called with upsample_factor>1, typically on larger images. The modulo indexing handles wraparound correctly.
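Should the guard ever be wanted, it is a two-line check at the top of the function (hypothetical, not in this PR):

```python
# Hypothetical guard: skip refinement when the correlation map is under 3x3.
if h < 3 or w < 3:
    return np.array([float(peak_y), float(peak_x)])
```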
The GPU implementation returns 0.0 for error and phasediff values since these are not computed. Added notes to docstring to clarify.

🤖 Generated with [Claude Code](https://claude.com/claude-code)
Co-Authored-By: Claude Opus 4.5 <[email protected]>
Pull request overview
Copilot reviewed 8 out of 8 changed files in this pull request and generated 16 comments.
```python
img1 = torch.from_numpy(arr1).float().cuda().unsqueeze(0).unsqueeze(0)
img2 = torch.from_numpy(arr2).float().cuda().unsqueeze(0).unsqueeze(0)
```
Copilot AI · Jan 5, 2026
The _compute_ssim_torch function doesn't validate that arr1 and arr2 have the same shape. If they have different shapes, the operations will fail with an unclear error. Consider adding a shape validation check before processing to provide a clearer error message.
[Claude Code] Skipped - torch operations will error naturally on shape mismatch with a clear error message.
```python
import sys

sys.path.insert(0, "src")
```
Copilot AI · Jan 5, 2026
The test imports are not using the installed package but instead modifying sys.path to import from 'src'. This is inconsistent with test_utils.py which imports directly from tilefusion.utils. For consistency and to test the actual installed package, consider using the same import pattern as test_utils.py or ensure the package is installed in editable mode for testing.
Suggested change: delete the `import sys` and `sys.path.insert(0, "src")` lines.
[Claude Code] Skipped - sys.path modification is a common pattern for testing src layout projects during development.
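For comparison, the pattern the comment prefers (assuming the package is installed in editable mode, e.g. via `pip install -e .`) is simply a direct import, as in test_utils.py:

```python
# With the package installed, no sys.path manipulation is needed:
from tilefusion.utils import phase_cross_correlation, shift_array
```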
```python
class TestPhaseCorrelation:
    """Tests for phase_cross_correlation function."""

    def test_known_shift(self, rng):
        """Test detection of known integer shift."""
        ref = rng.random((256, 256)).astype(np.float32)

        # Create shifted version: mov is ref shifted by (+5, -3)
        # phase_cross_correlation returns shift to apply to mov to align with ref
        # So it should return (-5, +3)
        mov = np.zeros_like(ref)
        mov[5:, :253] = ref[:-5, 3:]

        shift, _, _ = phase_cross_correlation(ref, mov)

        assert abs(shift[0] - (-5)) < 1, f"Y shift {shift[0]} not close to -5"
        assert abs(shift[1] - 3) < 1, f"X shift {shift[1]} not close to 3"

    def test_zero_shift(self, rng):
        """Test that identical images give zero shift."""
        ref = rng.random((256, 256)).astype(np.float32)

        shift, _, _ = phase_cross_correlation(ref, ref)

        assert abs(shift[0]) < 0.5, f"Y shift {shift[0]} should be ~0"
        assert abs(shift[1]) < 0.5, f"X shift {shift[1]} should be ~0"

    def test_matches_skimage_direction(self, rng):
        """Test that shift direction matches skimage convention."""
        ref = rng.random((128, 128)).astype(np.float32)

        # Shift by rolling
        mov = np.roll(np.roll(ref, 10, axis=0), -7, axis=1)

        gpu_shift, _, _ = phase_cross_correlation(ref, mov)
        cpu_shift, _, _ = skimage_pcc(ref, mov)

        # Directions should match
        assert np.sign(gpu_shift[0]) == np.sign(cpu_shift[0]), "Y direction mismatch"
        assert np.sign(gpu_shift[1]) == np.sign(cpu_shift[1]), "X direction mismatch"
```
Copilot AI · Jan 5, 2026
The test suite lacks tests for error conditions such as mismatched image shapes for phase_cross_correlation and compute_ssim. While the implementation doesn't currently validate shapes, adding negative tests would document expected behavior when invalid inputs are provided. Consider adding tests that verify the functions handle or reject mismatched shapes appropriately.
[Claude Code] Skipped - negative tests for error conditions are nice-to-have but not critical for this PR.
```python
class TestComputeSSIM:
    """Tests for compute_ssim function."""

    def test_ssim_similar_images(self, rng):
        """Test SSIM of similar images matches skimage."""
        arr1 = rng.random((256, 256)).astype(np.float32)
        arr2 = arr1 + rng.random((256, 256)).astype(np.float32) * 0.1

        data_range = arr1.max() - arr1.min()
        cpu = skimage_ssim(arr1, arr2, win_size=15, data_range=data_range)
        gpu = compute_ssim(arr1, arr2, win_size=15)

        assert abs(cpu - gpu) < 0.01, f"SSIM diff {abs(cpu - gpu)} too high"

    def test_ssim_identical_images(self, rng):
        """Test SSIM of identical images is ~1.0."""
        arr = rng.random((256, 256)).astype(np.float32)
        ssim = compute_ssim(arr, arr, win_size=15)
        assert ssim > 0.99, f"SSIM of identical images should be ~1.0, got {ssim}"

    def test_ssim_different_images(self, rng):
        """Test SSIM of random images is low."""
        arr1 = rng.random((256, 256)).astype(np.float32)
        arr2 = rng.random((256, 256)).astype(np.float32)
        ssim = compute_ssim(arr1, arr2, win_size=15)
        assert ssim < 0.5, f"SSIM of random images should be low, got {ssim}"

    def test_ssim_range(self, rng):
        """Test SSIM is in valid range [0, 1]."""
        arr1 = rng.random((128, 128)).astype(np.float32)
        arr2 = rng.random((128, 128)).astype(np.float32)
        ssim = compute_ssim(arr1, arr2, win_size=7)
        assert 0.0 <= ssim <= 1.0, f"SSIM {ssim} outside valid range"
```
Copilot AI · Jan 5, 2026
The test suite lacks tests for error conditions such as mismatched image shapes for compute_ssim. While the implementation doesn't currently validate shapes, adding negative tests would document expected behavior when invalid inputs are provided. Consider adding tests that verify the function handles or rejects mismatched shapes appropriately.
[Claude Code] Skipped - negative tests for error conditions are nice-to-have but not critical for this PR.
```python
def shift_array(
    arr: np.ndarray,
    shift_vec: tuple[float, float],
    preserve_dtype: bool = True,
```
Copilot AI · Jan 5, 2026
The shift_array function's default parameter preserve_dtype=True changes the existing API behavior. Previously, shift_array returned the dtype from scipy/cupy, but now it explicitly preserves the input dtype by default. This is a breaking change that could affect existing code that depends on the dtype conversion behavior. Functions in registration.py call shift_array without specifying preserve_dtype, which means they will now receive different dtypes than before. Consider making preserve_dtype=False to maintain backward compatibility, or add migration notes in the PR description.
[Claude Code] Intentional - preserve_dtype=True is more intuitive (input uint16 → output uint16). All existing tests pass. The previous inconsistent behavior was a quirk, not a feature.
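To make the behavioral difference concrete, a hedged usage sketch (the assertions reflect the semantics described in this thread):

```python
import numpy as np

from tilefusion.utils import shift_array

tile = (np.random.default_rng(0).random((64, 64)) * 65535).astype(np.uint16)

# Default in this PR: output dtype matches input (uint16 in, uint16 out).
shifted = shift_array(tile, (2.5, -1.0))
assert shifted.dtype == tile.dtype

# Opting out returns float on both GPU and CPU paths.
shifted_f = shift_array(tile, (2.5, -1.0), preserve_dtype=False)
assert np.issubdtype(shifted_f.dtype, np.floating)
```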
```python
# Input: (1, 1, H, W)
t = torch.from_numpy(arr).float().cuda().unsqueeze(0).unsqueeze(0)
```
Copilot AI · Jan 5, 2026
The _shift_array_torch function doesn't convert the input array to float32 before passing to torch, unlike _phase_cross_correlation_torch which explicitly does astype(np.float32). This means if a non-float32 array is passed, torch.from_numpy will use the original dtype, which might not be ideal for GPU operations. Consider explicitly converting to float32 for consistency with other GPU functions.
Suggested change:

```python
# Ensure input is float32 and on GPU: (1, 1, H, W)
arr_f32 = arr.astype(np.float32, copy=False)
t = torch.from_numpy(arr_f32).cuda().unsqueeze(0).unsqueeze(0)
```
[Claude Code] False positive - the code already uses .float() which converts to float32: torch.from_numpy(arr).float().cuda()
```python
import sys

sys.path.insert(0, "src")
```
Copilot AI · Jan 5, 2026
The test imports are not using the installed package but instead modifying sys.path to import from 'src'. This is inconsistent with test_utils.py which imports directly from tilefusion.utils. For consistency and to test the actual installed package, consider using the same import pattern as test_utils.py or ensure the package is installed in editable mode for testing.
Suggested change: delete the `import sys` and `sys.path.insert(0, "src")` lines.
[Claude Code] Skipped - sys.path modification is a common pattern for testing src layout projects during development.
```python
import sys
from skimage.measure import block_reduce as skimage_block_reduce

sys.path.insert(0, "src")
```
Copilot AI · Jan 5, 2026
The test imports are not using the installed package but instead modifying sys.path to import from 'src'. This is inconsistent with test_utils.py which imports directly from tilefusion.utils. For consistency and to test the actual installed package, consider using the same import pattern as test_utils.py or ensure the package is installed in editable mode for testing.
Suggested change (keep only the skimage import):

```python
from skimage.measure import block_reduce as skimage_block_reduce
```
[Claude Code] Skipped - sys.path modification is a common pattern for testing src layout projects during development.
```python
import sys

sys.path.insert(0, "src")
```
Copilot AI · Jan 5, 2026
The test imports are not using the installed package but instead modifying sys.path to import from 'src'. This is inconsistent with test_utils.py which imports directly from tilefusion.utils. For consistency and to test the actual installed package, consider using the same import pattern as test_utils.py or ensure the package is installed in editable mode for testing.
Suggested change: delete the `import sys` and `sys.path.insert(0, "src")` lines.
[Claude Code] Skipped - sys.path modification is a common pattern for testing src layout projects during development.
```python
def _phase_cross_correlation_torch(
    reference_image: np.ndarray, moving_image: np.ndarray, upsample_factor: int = 1
) -> tuple:
    """GPU phase cross-correlation using torch FFT."""
```
Copilot AI · Jan 5, 2026
The _phase_cross_correlation_torch function doesn't validate that reference_image and moving_image have the same shape. If they have different shapes, torch.fft.fft2 will still work but the cross-power spectrum multiplication will fail or produce incorrect results. Consider adding a shape validation check before processing.
| """GPU phase cross-correlation using torch FFT.""" | |
| """GPU phase cross-correlation using torch FFT.""" | |
| # Validate that input images have the same shape to ensure elementwise FFT operations are valid | |
| if reference_image.shape != moving_image.shape: | |
| raise ValueError( | |
| f"reference_image and moving_image must have the same shape; " | |
| f"got {reference_image.shape} and {moving_image.shape}." | |
| ) |
[Claude Code] Skipped - shape validation adds overhead to every call. Torch operations will error naturally with a clear message if shapes mismatch.
Summary
Consolidates PRs #4-8 into a single feature branch, replacing the cupy/cucim GPU stack with PyTorch for broader compatibility and easier installation.
GPU-accelerated functions:
- phase_cross_correlation
- shift_array
- match_histograms
- block_reduce
- compute_ssim

Key changes:
Supersedes:
Test plan
pytest tests/test_*.py

🤖 Generated with Claude Code