diff --git a/src/tilefusion/utils.py b/src/tilefusion/utils.py index d0d1769..944d6ec 100644 --- a/src/tilefusion/utils.py +++ b/src/tilefusion/utils.py @@ -2,54 +2,474 @@ Shared utilities for tilefusion. GPU/CPU detection, array operations, and helper functions. +All functions support GPU acceleration via PyTorch with automatic CPU fallback. """ +from typing import Any, Callable + import numpy as np +__all__ = [ + # GPU detection flags + "TORCH_AVAILABLE", + "CUDA_AVAILABLE", + "USING_GPU", + # Array module (legacy compatibility) + "xp", + "cp", + # Core functions + "phase_cross_correlation", + "shift_array", + "match_histograms", + "block_reduce", + "compute_ssim", + # Utility functions + "make_1d_profile", + "to_numpy", + "to_device", +] + +# GPU detection - PyTorch based try: - import cupy as cp - from cupyx.scipy.ndimage import shift as cp_shift - from cucim.skimage.exposure import match_histograms - from cucim.skimage.measure import block_reduce - from cucim.skimage.registration import phase_cross_correlation - from opm_processing.imageprocessing.ssim_cuda import ( - structural_similarity_cupy_sep_shared as ssim_cuda, - ) + import torch + import torch.nn.functional as F + + TORCH_AVAILABLE = True + CUDA_AVAILABLE = torch.cuda.is_available() +except ImportError: + torch = None + F = None + TORCH_AVAILABLE = False + CUDA_AVAILABLE = False + +# CPU fallbacks +from scipy.ndimage import shift as _shift_cpu +from skimage.exposure import match_histograms as _match_histograms_cpu +from skimage.measure import block_reduce as _block_reduce_cpu +from skimage.metrics import structural_similarity as _ssim_cpu +from skimage.registration import phase_cross_correlation as _phase_cross_correlation_cpu + +# Legacy compatibility - used by core.py and registration.py +# xp: array module (numpy, since cupy was removed) +# cp: cupy module (always None now, kept for API compatibility) +# USING_GPU: exported in __init__.py for user code +USING_GPU = CUDA_AVAILABLE +xp = np +cp = 
None # cupy removed; GPU ops now use PyTorch internally + +# Constants +_FFT_EPS = 1e-10 # Epsilon for FFT normalization to avoid division by zero +_PARABOLIC_EPS = 1e-10 # Epsilon for parabolic fit denominator check +_SSIM_K1 = 0.01 # SSIM constant K1 (luminance) +_SSIM_K2 = 0.03 # SSIM constant K2 (contrast) + + +# ============================================================================= +# Phase Cross-Correlation (GPU FFT) +# ============================================================================= + + +def phase_cross_correlation( + reference_image: np.ndarray, + moving_image: np.ndarray, + upsample_factor: int = 1, + **kwargs, +) -> tuple[np.ndarray, float, float]: + """ + Phase cross-correlation using GPU (torch FFT) or CPU (skimage). + + Parameters + ---------- + reference_image : ndarray + Reference image. + moving_image : ndarray + Image to register. + upsample_factor : int + Upsampling factor for subpixel precision. + + Returns + ------- + shift : ndarray + Shift vector (y, x). + error : float + Translation invariant normalized RMS error. + Note: GPU path returns 0.0 (not computed). + phasediff : float + Global phase difference. + Note: GPU path returns 0.0 (not computed). 
def _phase_cross_correlation_torch(
    reference_image: np.ndarray, moving_image: np.ndarray, upsample_factor: int = 1
) -> tuple:
    """
    GPU phase cross-correlation using torch FFT.

    Parameters
    ----------
    reference_image : ndarray
        2D reference image.
    moving_image : ndarray
        2D image to register; must have the same shape as the reference.
    upsample_factor : int
        Values > 1 enable a 3-point parabolic refinement of the peak.
        The factor itself is not used as a resolution; it only toggles
        the refinement.

    Returns
    -------
    shift : ndarray
        Shift vector (y, x).
    error : float
        Always 0.0 (not computed on this path).
    phasediff : float
        Always 0.0 (not computed on this path).

    Raises
    ------
    ValueError
        If the two images do not share a shape.
    """
    # Without this check torch would either broadcast silently or fail with
    # an opaque size-mismatch error.
    if reference_image.shape != moving_image.shape:
        raise ValueError("images must be same shape")

    ref = torch.from_numpy(reference_image.astype(np.float32)).cuda()
    mov = torch.from_numpy(moving_image.astype(np.float32)).cuda()

    # Normalised cross-power spectrum: keep only the phase so the inverse
    # transform peaks at the translation.
    cross_power = torch.fft.fft2(ref) * torch.conj(torch.fft.fft2(mov))
    cross_power = cross_power / (torch.abs(cross_power) + _FFT_EPS)
    correlation = torch.fft.ifft2(cross_power).real

    # argmax returns a flat index; one host transfer yields both coordinates.
    h, w = correlation.shape
    peak_y, peak_x = divmod(int(torch.argmax(correlation)), w)

    # Peaks in the upper half of each axis represent negative shifts.
    if peak_y > h // 2:
        peak_y -= h
    if peak_x > w // 2:
        peak_x -= w

    if upsample_factor > 1:
        shift = _subpixel_refine_torch(correlation, peak_y, peak_x, h, w)
    else:
        shift = np.array([float(peak_y), float(peak_x)])

    return shift, 0.0, 0.0


def _subpixel_refine_torch(correlation, peak_y, peak_x, h, w):
    """
    Refine an integer correlation peak with a 3-point parabolic fit per axis.

    Parameters
    ----------
    correlation : torch.Tensor
        2D correlation surface of shape (h, w).
    peak_y, peak_x : int
        Integer peak position; may be negative (already unwrapped).
    h, w : int
        Shape of `correlation`, used for periodic neighbour lookup.

    Returns
    -------
    shift : ndarray
        (y, x) peak position with a sub-pixel correction in [-0.5, 0.5]
        added to each component.
    """
    py = peak_y % h
    px = peak_x % w
    rows = [(py - 1) % h, py, (py + 1) % h]
    cols = [(px - 1) % w, px, (px + 1) % w]

    # Gather the 3x3 neighbourhood on the tensor's own device and bring it
    # to the host in a single transfer.
    patch = correlation[rows][:, cols].tolist()
    center = patch[1][1]

    def _vertex_offset(before: float, after: float) -> float:
        # Vertex of the parabola through (-1, before), (0, center), (1, after):
        #   x* = (before - after) / (2 * (before - 2*center + after))
        denom = 2.0 * (before - 2.0 * center + after)
        if abs(denom) <= _PARABOLIC_EPS:
            return 0.0  # flat neighbourhood: no reliable curvature
        return max(-0.5, min(0.5, (before - after) / denom))

    dy = _vertex_offset(patch[0][1], patch[2][1])
    dx = _vertex_offset(patch[1][0], patch[1][2])

    return np.array([float(peak_y) + dy, float(peak_x) + dx])
def _shift_array_torch(arr: np.ndarray, shift_vec: tuple[float, float]) -> np.ndarray:
    """
    GPU shift using torch.nn.functional.grid_sample.

    Parameters
    ----------
    arr : ndarray
        2D input array of any real dtype.
    shift_vec : tuple[float, float]
        (dy, dx) shift amounts in pixels.

    Returns
    -------
    shifted : ndarray
        Floating-point array with the same shape as `arr`; samples that
        fall outside the input are zero.
    """
    h, w = arr.shape

    # Guard against degenerate arrays: the normalisation below divides by
    # (size - 1). Cast to float first so the fallback matches the CPU path
    # of `shift_array` instead of interpolating in an integer dtype.
    if h < 2 or w < 2:
        return _shift_cpu(arr.astype(np.float64), shift=shift_vec, order=1, prefilter=False)

    dy, dx = float(shift_vec[0]), float(shift_vec[1])

    # Pixel coordinate grids
    y_coords = torch.arange(h, device="cuda", dtype=torch.float32)
    x_coords = torch.arange(w, device="cuda", dtype=torch.float32)
    grid_y, grid_x = torch.meshgrid(y_coords, x_coords, indexing="ij")

    # To shift the output by (dy, dx), sample the input at (y - dy, x - dx),
    # normalised to [-1, 1] as required by grid_sample(align_corners=True).
    sample_x = 2 * (grid_x - dx) / (w - 1) - 1
    sample_y = 2 * (grid_y - dy) / (h - 1) - 1

    # (1, H, W, 2) with (x, y) ordering in the last axis
    grid = torch.stack([sample_x, sample_y], dim=-1).unsqueeze(0)

    # Convert on the host first: torch.from_numpy rejects some dtypes and
    # negatively strided views that callers may legitimately pass.
    host = np.ascontiguousarray(arr, dtype=np.float32)
    t = torch.from_numpy(host).cuda()[None, None]  # (1, 1, H, W)

    out = F.grid_sample(t, grid, mode="bilinear", padding_mode="zeros", align_corners=True)

    # Index the batch/channel axes explicitly rather than squeeze().
    return out[0, 0].cpu().numpy()
Exception: - cp = None - cp_shift = None - from skimage.exposure import match_histograms - from skimage.measure import block_reduce - from skimage.registration import phase_cross_correlation - from scipy.ndimage import shift as _shift_cpu - from skimage.metrics import structural_similarity as _ssim_cpu - - xp = np - USING_GPU = False - - -def shift_array(arr, shift_vec): - """Shift array using GPU if available, else CPU fallback.""" - if USING_GPU and cp_shift is not None: - return cp_shift(arr, shift=shift_vec, order=1, prefilter=False) - return _shift_cpu(arr, shift=shift_vec, order=1, prefilter=False) - - -def compute_ssim(arr1, arr2, win_size: int) -> float: - """SSIM wrapper that routes to GPU kernel or CPU skimage.""" - if USING_GPU and "ssim_cuda" in globals(): - return float(ssim_cuda(arr1, arr2, win_size=win_size)) - arr1_np = np.asarray(arr1) - arr2_np = np.asarray(arr2) +def match_histograms( + image: np.ndarray, + reference: np.ndarray, + preserve_dtype: bool = True, +) -> np.ndarray: + """ + Match histogram of image to reference using GPU (torch) or CPU (skimage). + + Parameters + ---------- + image : ndarray + Image to transform. + reference : ndarray + Reference image for histogram matching. + preserve_dtype : bool + If True, output dtype matches input dtype. Default True. + + Returns + ------- + matched : ndarray + Image with matched histogram. 
+ """ + image_np = np.asarray(image) + reference_np = np.asarray(reference) + original_dtype = image_np.dtype + + if CUDA_AVAILABLE and image_np.ndim == 2: + result = _match_histograms_torch(image_np, reference_np) + else: + result = _match_histograms_cpu(image_np, reference_np) + + if preserve_dtype and result.dtype != original_dtype: + return result.astype(original_dtype) + return result + + +def _match_histograms_torch(image: np.ndarray, reference: np.ndarray) -> np.ndarray: + """GPU histogram matching using torch operations.""" + # Move to GPU + img = torch.from_numpy(image.astype(np.float32)).cuda().flatten() + ref = torch.from_numpy(reference.astype(np.float32)).cuda().flatten() + + # Get sorted indices + _, img_indices = torch.sort(img) + ref_sorted, _ = torch.sort(ref) + + # Create inverse mapping (rank of each pixel) + inv_indices = torch.empty_like(img_indices) + inv_indices[img_indices] = torch.arange(len(img), device="cuda") + + # Map image values to reference values via quantile matching + # inv_indices[i] = rank of pixel i, so look up ref value at that scaled rank + interp_values = ref_sorted[ + (inv_indices.float() / len(img) * len(ref)).long().clamp(0, len(ref) - 1) + ] + + return interp_values.reshape(image.shape).cpu().numpy() + + +# ============================================================================= +# Block Reduce (GPU avg_pool2d) +# ============================================================================= + + +def block_reduce( + arr: np.ndarray, + block_size: tuple[int, ...], + func: Callable = np.mean, + preserve_dtype: bool = True, +) -> np.ndarray: + """ + Block reduce array using GPU (torch) or CPU (skimage). + + Parameters + ---------- + arr : ndarray + Input array (2D or 3D with channel dim first). + block_size : tuple[int, ...] + Reduction factors per dimension. + func : Callable + Reduction function (only np.mean supported on GPU). + preserve_dtype : bool + If True, output dtype matches input dtype. Default True. 
+ + Returns + ------- + reduced : ndarray + """ + arr_np = np.asarray(arr) + original_dtype = arr_np.dtype + + if CUDA_AVAILABLE and func == np.mean and arr_np.ndim >= 2: + result = _block_reduce_torch(arr_np, block_size) + else: + result = _block_reduce_cpu(arr_np, block_size, func) + + if preserve_dtype and result.dtype != original_dtype: + return result.astype(original_dtype) + return result + + +def _block_reduce_torch(arr: np.ndarray, block_size: tuple) -> np.ndarray: + """GPU block reduce using torch.nn.functional.avg_pool2d.""" + ndim = arr.ndim + + if ndim == 2: + kernel = (block_size[0], block_size[1]) + t = torch.from_numpy(arr).float().cuda().unsqueeze(0).unsqueeze(0) + out = torch.nn.functional.avg_pool2d(t, kernel, stride=kernel) + return out.squeeze().cpu().numpy() + + elif ndim == 3: + # For 3D arrays (C, H, W), extract spatial kernel from block_size + if len(block_size) == 3: + # block_size is (c_factor, h_factor, w_factor) + # Only use spatial dimensions for avg_pool2d + kernel = (block_size[1], block_size[2]) + else: + # block_size is (h_factor, w_factor) - apply to spatial dims + kernel = (block_size[0], block_size[1]) + t = torch.from_numpy(arr).float().cuda().unsqueeze(0) + out = torch.nn.functional.avg_pool2d(t, kernel, stride=kernel) + return out.squeeze(0).cpu().numpy() + + return _block_reduce_cpu(arr, block_size, np.mean) + + +# ============================================================================= +# Compute SSIM (GPU conv2d) +# ============================================================================= + + +def compute_ssim(arr1: np.ndarray, arr2: np.ndarray, win_size: int) -> float: + """ + Compute SSIM using GPU (torch) or CPU (skimage). + + Parameters + ---------- + arr1, arr2 : ndarray + Input images (2D). + win_size : int + Window size for local statistics. + + Returns + ------- + ssim : float + Mean SSIM value. 
def _compute_ssim_torch(
    arr1: np.ndarray, arr2: np.ndarray, win_size: int, data_range: float
) -> float:
    """
    GPU SSIM using torch conv2d for local statistics.

    Only windows that lie entirely inside the image contribute, and
    variances/covariance use the sample (N - 1) normalisation.

    Parameters
    ----------
    arr1, arr2 : ndarray
        2D images of identical shape.
    win_size : int
        Side length of the square uniform window.
    data_range : float
        Dynamic range used to scale the stabilising constants.

    Returns
    -------
    ssim : float
        Mean SSIM over all fully-covered windows.

    Raises
    ------
    ValueError
        If the window is larger than either image dimension.
    """
    if win_size > min(arr1.shape):
        raise ValueError("win_size exceeds image extent.")

    C1 = (_SSIM_K1 * data_range) ** 2
    C2 = (_SSIM_K2 * data_range) ** 2

    n_pix = win_size * win_size
    # Uniform averaging window
    window = torch.ones(1, 1, win_size, win_size, device="cuda") / n_pix
    # Converts the population moments computed below to sample moments.
    cov_norm = n_pix / (n_pix - 1) if n_pix > 1 else 1.0

    # Convert to tensors (1, 1, H, W)
    img1 = torch.from_numpy(arr1).float().cuda().unsqueeze(0).unsqueeze(0)
    img2 = torch.from_numpy(arr2).float().cuda().unsqueeze(0).unsqueeze(0)

    # No padding: zero-padded borders would drag local means toward zero and
    # inflate local variances near the edges.
    mu1 = F.conv2d(img1, window)
    mu2 = F.conv2d(img2, window)

    mu1_sq = mu1**2
    mu2_sq = mu2**2
    mu1_mu2 = mu1 * mu2

    # Local variances and covariance
    sigma1_sq = cov_norm * (F.conv2d(img1**2, window) - mu1_sq)
    sigma2_sq = cov_norm * (F.conv2d(img2**2, window) - mu2_sq)
    sigma12 = cov_norm * (F.conv2d(img1 * img2, window) - mu1_mu2)

    # SSIM formula
    ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / (
        (mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2)
    )

    return float(ssim_map.mean().cpu())
def to_numpy(arr) -> np.ndarray:
    """
    Convert array to numpy, handling both CPU and GPU arrays.

    Parameters
    ----------
    arr : array-like or torch.Tensor
        Input data. Tensors on any device, including ones that track
        gradients, are accepted.

    Returns
    -------
    out : ndarray
        Host array. For non-tensor input this is `np.asarray(arr)`.
    """
    if TORCH_AVAILABLE and isinstance(arr, torch.Tensor):
        # detach() first: Tensor.numpy() refuses tensors that require grad.
        return arr.detach().cpu().numpy()
    return np.asarray(arr)


def to_device(arr) -> Any:
    """
    Move array to GPU if available, else return numpy array.

    Parameters
    ----------
    arr : array-like or torch.Tensor
        Input data.

    Returns
    -------
    out : torch.Tensor or ndarray
        CUDA tensor when CUDA is available, otherwise `np.asarray(arr)`.
    """
    if not CUDA_AVAILABLE:
        return np.asarray(arr)

    # Tensors move directly; np.asarray cannot read a tensor already on GPU.
    if isinstance(arr, torch.Tensor):
        return arr.cuda()

    host = np.asarray(arr)
    # torch.from_numpy rejects negatively strided views (e.g. arr[::-1]);
    # a C-ordered copy is always accepted.
    if not host.flags.c_contiguous:
        host = host.copy()
    return torch.from_numpy(host).cuda()
= block_reduce(arr, block_size, np.mean) + expected = skimage_block_reduce(arr, block_size, np.mean) + + np.testing.assert_allclose(result, expected, rtol=1e-5) + + def test_2d_large(self, rng): + """Test larger 2D array.""" + arr = rng.random((1024, 1024)).astype(np.float32) + block_size = (8, 8) + + result = block_reduce(arr, block_size, np.mean) + expected = skimage_block_reduce(arr, block_size, np.mean) + + np.testing.assert_allclose(result, expected, rtol=1e-5) + + def test_3d_multichannel(self, rng): + """Test 3D array with channel dimension.""" + arr = rng.random((3, 256, 256)).astype(np.float32) + block_size = (1, 4, 4) + + result = block_reduce(arr, block_size, np.mean) + expected = skimage_block_reduce(arr, block_size, np.mean) + + np.testing.assert_allclose(result, expected, rtol=1e-5) + + def test_output_shape(self, rng): + """Test output shape is correct.""" + arr = rng.random((512, 512)).astype(np.float32) + block_size = (4, 4) + + result = block_reduce(arr, block_size, np.mean) + + assert result.shape == (128, 128) + + def test_non_divisible_shape(self, rng): + """Test block reduce with non-divisible dimensions.""" + arr = rng.random((100, 100)).astype(np.float32) + block_size = (8, 8) + + result = block_reduce(arr, block_size, np.mean) + expected = skimage_block_reduce(arr, block_size, np.mean) + + np.testing.assert_allclose(result, expected, rtol=1e-5) + + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) diff --git a/tests/test_cpu_fallback.py b/tests/test_cpu_fallback.py new file mode 100644 index 0000000..1b72206 --- /dev/null +++ b/tests/test_cpu_fallback.py @@ -0,0 +1,136 @@ +"""Tests for CPU fallback paths and dtype preservation.""" + +import numpy as np +import pytest +import sys + +sys.path.insert(0, "src") + +from tilefusion.utils import ( + phase_cross_correlation, + shift_array, + match_histograms, + block_reduce, + compute_ssim, +) + + +class TestCPUFallback: + """Test that CPU fallback paths work correctly.""" + + def 
test_phase_cross_correlation_cpu(self, force_cpu, rng): + """Test phase_cross_correlation with CPU fallback.""" + ref = rng.random((128, 128)).astype(np.float32) + mov = np.roll(ref, 5, axis=0) + + shift, error, phasediff = phase_cross_correlation(ref, mov) + + assert abs(shift[0] - (-5)) < 1, f"Y shift {shift[0]} not close to -5" + + def test_shift_array_cpu(self, force_cpu, rng): + """Test shift_array with CPU fallback.""" + arr = rng.random((128, 128)).astype(np.float32) + result = shift_array(arr, (3.0, -2.0)) + + assert result.shape == arr.shape + assert result.dtype == arr.dtype + + def test_match_histograms_cpu(self, force_cpu, rng): + """Test match_histograms with CPU fallback.""" + img = rng.random((128, 128)).astype(np.float32) + ref = rng.random((128, 128)).astype(np.float32) * 2 + + result = match_histograms(img, ref) + + assert result.shape == img.shape + + def test_block_reduce_cpu(self, force_cpu, rng): + """Test block_reduce with CPU fallback.""" + arr = rng.random((128, 128)).astype(np.float32) + result = block_reduce(arr, (4, 4), np.mean) + + assert result.shape == (32, 32) + + def test_compute_ssim_cpu(self, force_cpu, rng): + """Test compute_ssim with CPU fallback.""" + arr1 = rng.random((128, 128)).astype(np.float32) + arr2 = arr1 + rng.random((128, 128)).astype(np.float32) * 0.1 + + ssim = compute_ssim(arr1, arr2, win_size=7) + + assert 0.0 <= ssim <= 1.0 + + +class TestDtypePreservation: + """Test that dtype is preserved when preserve_dtype=True.""" + + @pytest.mark.parametrize("dtype", [np.uint8, np.uint16, np.float32, np.float64]) + def test_shift_array_dtype(self, dtype, force_cpu, rng): + """Test shift_array preserves dtype.""" + arr = (rng.random((64, 64)) * 255).astype(dtype) + result = shift_array(arr, (1.5, -1.5), preserve_dtype=True) + + assert result.dtype == dtype, f"Expected {dtype}, got {result.dtype}" + + @pytest.mark.parametrize("dtype", [np.uint8, np.uint16, np.float32, np.float64]) + def test_match_histograms_dtype(self, 
dtype, force_cpu, rng): + """Test match_histograms preserves dtype.""" + img = (rng.random((64, 64)) * 255).astype(dtype) + ref = (rng.random((64, 64)) * 255).astype(dtype) + result = match_histograms(img, ref, preserve_dtype=True) + + assert result.dtype == dtype, f"Expected {dtype}, got {result.dtype}" + + @pytest.mark.parametrize("dtype", [np.uint8, np.uint16, np.float32, np.float64]) + def test_block_reduce_dtype(self, dtype, force_cpu, rng): + """Test block_reduce preserves dtype.""" + arr = (rng.random((64, 64)) * 255).astype(dtype) + result = block_reduce(arr, (4, 4), np.mean, preserve_dtype=True) + + assert result.dtype == dtype, f"Expected {dtype}, got {result.dtype}" + + def test_shift_array_no_preserve(self, force_cpu, rng): + """Test shift_array returns float when preserve_dtype=False.""" + arr = (rng.random((64, 64)) * 255).astype(np.uint16) + result = shift_array(arr, (1.5, -1.5), preserve_dtype=False) + + # Should return float64 (scipy default) + assert result.dtype in [np.float32, np.float64] + + +class TestEdgeCases: + """Test edge cases and boundary conditions.""" + + def test_shift_zero(self, force_cpu, rng): + """Test that zero shift returns nearly identical array.""" + arr = rng.random((64, 64)).astype(np.float32) + result = shift_array(arr, (0.0, 0.0)) + + np.testing.assert_allclose(result, arr, rtol=1e-5, atol=1e-5) + + def test_identical_images_ssim(self, force_cpu, rng): + """Test SSIM of identical images is ~1.0.""" + arr = rng.random((64, 64)).astype(np.float32) + ssim = compute_ssim(arr, arr, win_size=7) + + assert ssim > 0.99, f"SSIM of identical images should be ~1.0, got {ssim}" + + def test_block_reduce_3d(self, force_cpu, rng): + """Test block_reduce with 3D array.""" + arr = rng.random((3, 64, 64)).astype(np.float32) + result = block_reduce(arr, (1, 4, 4), np.mean) + + assert result.shape == (3, 16, 16) + + def test_different_size_histogram_match(self, force_cpu, rng): + """Test histogram matching with different sized images.""" + 
img = rng.random((64, 64)).astype(np.float32) + ref = rng.random((128, 128)).astype(np.float32) + + result = match_histograms(img, ref) + + assert result.shape == img.shape + + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) diff --git a/tests/test_fft.py b/tests/test_fft.py new file mode 100644 index 0000000..27e19bd --- /dev/null +++ b/tests/test_fft.py @@ -0,0 +1,102 @@ +"""Unit tests for GPU phase_cross_correlation (FFT).""" + +import numpy as np +import pytest +import sys + +sys.path.insert(0, "src") + +from tilefusion.utils import phase_cross_correlation +from skimage.registration import phase_cross_correlation as skimage_pcc + + +class TestPhaseCorrelation: + """Tests for phase_cross_correlation function.""" + + def test_known_shift(self, rng): + """Test detection of known integer shift.""" + ref = rng.random((256, 256)).astype(np.float32) + + # Create shifted version: mov is ref shifted by (+5, -3) + # phase_cross_correlation returns shift to apply to mov to align with ref + # So it should return (-5, +3) + mov = np.zeros_like(ref) + mov[5:, :253] = ref[:-5, 3:] + + shift, _, _ = phase_cross_correlation(ref, mov) + + assert abs(shift[0] - (-5)) < 1, f"Y shift {shift[0]} not close to -5" + assert abs(shift[1] - 3) < 1, f"X shift {shift[1]} not close to 3" + + def test_zero_shift(self, rng): + """Test that identical images give zero shift.""" + ref = rng.random((256, 256)).astype(np.float32) + + shift, _, _ = phase_cross_correlation(ref, ref) + + assert abs(shift[0]) < 0.5, f"Y shift {shift[0]} should be ~0" + assert abs(shift[1]) < 0.5, f"X shift {shift[1]} should be ~0" + + def test_matches_skimage_direction(self, rng): + """Test that shift direction matches skimage convention.""" + ref = rng.random((128, 128)).astype(np.float32) + + # Shift by rolling + mov = np.roll(np.roll(ref, 10, axis=0), -7, axis=1) + + gpu_shift, _, _ = phase_cross_correlation(ref, mov) + cpu_shift, _, _ = skimage_pcc(ref, mov) + + # Directions should match + assert 
np.sign(gpu_shift[0]) == np.sign(cpu_shift[0]), "Y direction mismatch" + assert np.sign(gpu_shift[1]) == np.sign(cpu_shift[1]), "X direction mismatch" + + +class TestSubpixelRefinement: + """Tests for subpixel phase correlation refinement.""" + + def test_subpixel_refinement(self, rng): + """Test subpixel accuracy with upsample_factor > 1.""" + ref = rng.random((128, 128)).astype(np.float32) + + # Use integer shift for ground truth (subpixel refinement should still work) + mov = np.roll(np.roll(ref, 7, axis=0), -4, axis=1) + + # Test with upsample_factor=10 for subpixel refinement + shift_subpixel, _, _ = phase_cross_correlation(ref, mov, upsample_factor=10) + + # Should detect the shift direction correctly + assert ( + abs(shift_subpixel[0] - (-7)) < 1 + ), f"Subpixel Y shift {shift_subpixel[0]} not close to -7" + assert ( + abs(shift_subpixel[1] - 4) < 1 + ), f"Subpixel X shift {shift_subpixel[1]} not close to 4" + + # Verify reasonable range + assert -10 < shift_subpixel[0] < 0, f"Subpixel Y shift {shift_subpixel[0]} out of range" + assert 0 < shift_subpixel[1] < 10, f"Subpixel X shift {shift_subpixel[1]} out of range" + + def test_subpixel_vs_integer_consistency(self, rng): + """Test that subpixel and integer modes give consistent direction.""" + ref = rng.random((64, 64)).astype(np.float32) + mov = np.roll(np.roll(ref, 3, axis=0), -2, axis=1) + + shift_int, _, _ = phase_cross_correlation(ref, mov, upsample_factor=1) + shift_sub, _, _ = phase_cross_correlation(ref, mov, upsample_factor=10) + + # Signs should match + assert np.sign(shift_int[0]) == np.sign( + shift_sub[0] + ), "Y direction mismatch between int/subpixel" + assert np.sign(shift_int[1]) == np.sign( + shift_sub[1] + ), "X direction mismatch between int/subpixel" + + # Magnitudes should be close + assert abs(shift_int[0] - shift_sub[0]) < 1, "Y magnitude differs too much" + assert abs(shift_int[1] - shift_sub[1]) < 1, "X magnitude differs too much" + + +if __name__ == "__main__": + 
pytest.main([__file__, "-v"]) diff --git a/tests/test_histogram_match.py b/tests/test_histogram_match.py new file mode 100644 index 0000000..9223c50 --- /dev/null +++ b/tests/test_histogram_match.py @@ -0,0 +1,64 @@ +"""Unit tests for GPU histogram matching.""" + +import numpy as np +import pytest +import sys + +sys.path.insert(0, "src") + +from tilefusion.utils import match_histograms +from skimage.exposure import match_histograms as skimage_match + + +class TestMatchHistograms: + """Tests for match_histograms function.""" + + def test_histogram_range(self, rng): + """Test output is in reference range.""" + img = rng.random((256, 256)).astype(np.float32) + ref = rng.random((256, 256)).astype(np.float32) * 2 + 1 + result = match_histograms(img, ref) + # Output should be in reference range + assert result.min() >= ref.min() - 0.1 + assert result.max() <= ref.max() + 0.1 + + def test_histogram_correlation(self, rng): + """Test histogram correlation with skimage.""" + img = rng.random((256, 256)).astype(np.float32) + ref = rng.random((256, 256)).astype(np.float32) + + cpu = skimage_match(img, ref) + gpu = match_histograms(img, ref) + + cpu_hist, _ = np.histogram(cpu.flatten(), bins=100) + gpu_hist, _ = np.histogram(gpu.flatten(), bins=100) + corr = np.corrcoef(cpu_hist, gpu_hist)[0, 1] + assert corr > 0.99, f"Histogram correlation {corr} too low" + + def test_same_image(self, rng): + """Test matching image to itself.""" + img = rng.random((128, 128)).astype(np.float32) + result = match_histograms(img, img) + np.testing.assert_allclose(result, img, rtol=1e-5) + + def test_different_sizes(self, rng): + """Test matching images of different sizes.""" + img = rng.random((64, 64)).astype(np.float32) + ref = rng.random((128, 128)).astype(np.float32) + result = match_histograms(img, ref) + assert result.shape == img.shape + + def test_pixel_values_match_skimage(self, rng): + """Test pixel-by-pixel matching against skimage.""" + img = rng.random((64, 64)).astype(np.float32) + 
ref = rng.random((64, 64)).astype(np.float32) * 2 + 1 + + cpu = skimage_match(img, ref) + gpu = match_histograms(img, ref) + + # Pixel values should be close (not just histogram shape) + np.testing.assert_allclose(gpu, cpu, rtol=1e-4, atol=1e-4) + + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) diff --git a/tests/test_shift_array.py b/tests/test_shift_array.py new file mode 100644 index 0000000..4843eeb --- /dev/null +++ b/tests/test_shift_array.py @@ -0,0 +1,52 @@ +"""Unit tests for GPU shift_array.""" + +import numpy as np +import pytest +import sys + +sys.path.insert(0, "src") + +from tilefusion.utils import shift_array +from scipy.ndimage import shift as scipy_shift + + +class TestShiftArray: + """Tests for shift_array function.""" + + def test_integer_shift(self, rng): + """Test integer shift matches scipy.""" + arr = rng.random((256, 256)).astype(np.float32) + cpu = scipy_shift(arr, (3.0, -5.0), order=1, prefilter=False) + gpu = shift_array(arr, (3.0, -5.0)) + np.testing.assert_allclose(gpu, cpu, rtol=1e-4, atol=1e-4) + + def test_subpixel_mean_error(self, rng): + """Test subpixel shift has low mean error vs scipy.""" + arr = rng.random((256, 256)).astype(np.float32) + cpu = scipy_shift(arr, (5.5, -3.2), order=1, prefilter=False) + gpu = shift_array(arr, (5.5, -3.2)) + mean_diff = np.abs(cpu - gpu).mean() + assert mean_diff < 0.01, f"Mean diff {mean_diff} too high" + + def test_zero_shift(self, rng): + """Test zero shift returns nearly identical array.""" + arr = rng.random((256, 256)).astype(np.float32) + result = shift_array(arr, (0.0, 0.0)) + # Allow small tolerance due to grid_sample interpolation + np.testing.assert_allclose(result, arr, rtol=1e-4, atol=1e-4) + + def test_small_array(self, rng): + """Test shift works on small arrays (edge case).""" + arr = rng.random((4, 4)).astype(np.float32) + result = shift_array(arr, (1.0, 1.0)) + assert result.shape == arr.shape + + def test_1pixel_fallback(self): + """Test 1-pixel array falls back to 
CPU.""" + arr = np.array([[1.0]], dtype=np.float32) + result = shift_array(arr, (0.5, 0.5)) + assert result.shape == (1, 1) + + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) diff --git a/tests/test_ssim.py b/tests/test_ssim.py new file mode 100644 index 0000000..d143889 --- /dev/null +++ b/tests/test_ssim.py @@ -0,0 +1,49 @@ +"""Unit tests for GPU SSIM.""" + +import numpy as np +import pytest +import sys + +sys.path.insert(0, "src") + +from tilefusion.utils import compute_ssim +from skimage.metrics import structural_similarity as skimage_ssim + + +class TestComputeSSIM: + """Tests for compute_ssim function.""" + + def test_ssim_similar_images(self, rng): + """Test SSIM of similar images matches skimage.""" + arr1 = rng.random((256, 256)).astype(np.float32) + arr2 = arr1 + rng.random((256, 256)).astype(np.float32) * 0.1 + + data_range = arr1.max() - arr1.min() + cpu = skimage_ssim(arr1, arr2, win_size=15, data_range=data_range) + gpu = compute_ssim(arr1, arr2, win_size=15) + + assert abs(cpu - gpu) < 0.01, f"SSIM diff {abs(cpu - gpu)} too high" + + def test_ssim_identical_images(self, rng): + """Test SSIM of identical images is ~1.0.""" + arr = rng.random((256, 256)).astype(np.float32) + ssim = compute_ssim(arr, arr, win_size=15) + assert ssim > 0.99, f"SSIM of identical images should be ~1.0, got {ssim}" + + def test_ssim_different_images(self, rng): + """Test SSIM of random images is low.""" + arr1 = rng.random((256, 256)).astype(np.float32) + arr2 = rng.random((256, 256)).astype(np.float32) + ssim = compute_ssim(arr1, arr2, win_size=15) + assert ssim < 0.5, f"SSIM of random images should be low, got {ssim}" + + def test_ssim_range(self, rng): + """Test SSIM is in valid range [0, 1].""" + arr1 = rng.random((128, 128)).astype(np.float32) + arr2 = rng.random((128, 128)).astype(np.float32) + ssim = compute_ssim(arr1, arr2, win_size=7) + assert 0.0 <= ssim <= 1.0, f"SSIM {ssim} outside valid range" + + +if __name__ == "__main__": + pytest.main([__file__, 
"-v"])