Conversation
…s seems to work for images of the same size as the pattern. Adding an option to have an extended FOV (WIP).
…ed new meas_H that works, meas is still WIP
…and made PinvNet compatible with dynamic classes
… DynamicHadamSplit
Pull request overview
This PR integrates dynamic single-pixel imaging capabilities into the SpyRIT library by adding new tutorials, extending data-loading/visualization utilities, introducing dual-arm tooling, and updating documentation to include relevant references.
Changes:
- Added two new tutorials covering deformation fields and dynamic acquisition + motion-compensated reconstruction.
- Extended spyrit.misc with new transforms, Girder download options, acquisition readers, display helpers, and a new dual-arm module.
- Updated core warping/reconstruction utilities and documentation (references + Sphinx toctree).
Reviewed changes
Copilot reviewed 15 out of 18 changed files in this pull request and generated 8 comments.
| File | Description |
|---|---|
| tutorial/tuto_06_b_dynamic.py | New tutorial for dynamic measurements and motion-compensated reconstruction. |
| tutorial/tuto_06_a_warp.py | New tutorial introducing deformation fields and warping. |
| tutorial/tuto_04_pseudoinverse_cnn_linear.py | Minor tutorial text fix. |
| tutorial/tuto_02_noise.py | Minor typo fix in tutorial text. |
| tutorial/README.txt | Adds Tutorial 6 entries and removes outdated Tutorial 9 entry. |
| spyrit/misc/statistics.py | Adds transform_norm for natural-image transforms. |
| spyrit/misc/load_data.py | Extends Girder download helper and adds acquisition/tumor utilities. |
| spyrit/misc/dual_arm.py | New module for dual-arm SPC calibration + motion-field projection. |
| spyrit/misc/disp.py | Adds OpenCV import guards and new visualization/video export helpers. |
| spyrit/misc/color.py | Adds wavelength-aware plotting utilities. |
| spyrit/core/warp.py | Refactors deformation-field classes and expands documentation/examples. |
| spyrit/core/recon.py | Adds acquire() convenience method for dynamic acquisition operators. |
| docs/source/refs.rst | Adds a references page for cited works. |
| docs/source/index.rst | Adds refs to Sphinx toctree. |
| docs/source/build-tuto-local.md | Updates local tutorial build instructions. |
```python
# frame_bgr = cv2.cvtColor(frame8, cv2.COLOR_GRAY2BGR)
frame_bgr = cv2.cvtColor(frame8, cv2.COLOR_RGB2BGR)
```
save_motion_video() always converts frames with cv2.COLOR_RGB2BGR, but the function explicitly allows n_wav == 1. For single-channel inputs, frame8 won't be a 3-channel RGB image and cvtColor(..., COLOR_RGB2BGR) will fail. Handle the grayscale case explicitly (e.g., COLOR_GRAY2BGR or channel-repeat) and only use RGB2BGR for 3-channel inputs.
Suggested change:

```diff
- # frame_bgr = cv2.cvtColor(frame8, cv2.COLOR_GRAY2BGR)
- frame_bgr = cv2.cvtColor(frame8, cv2.COLOR_RGB2BGR)
+ if n_wav == 1:
+     # Convert single-channel (grayscale) frame to BGR for VideoWriter.
+     # OpenCV expects a 2D array for COLOR_GRAY2BGR, so squeeze the channel dimension if present.
+     if frame8.ndim == 3 and frame8.shape[2] == 1:
+         frame8_gray = frame8[..., 0]
+     else:
+         frame8_gray = frame8
+     frame_bgr = cv2.cvtColor(frame8_gray, cv2.COLOR_GRAY2BGR)
+ else:
+     # Convert RGB frame to BGR for VideoWriter
+     frame_bgr = cv2.cvtColor(frame8, cv2.COLOR_RGB2BGR)
```
```python
import warnings
from typing import Tuple
```
There are duplicated imports (warnings and Tuple) at the top of the file. Removing the duplicates will improve readability and avoid confusion during future edits.
Suggested change:

```diff
- import warnings
- from typing import Tuple
```
```python
# Compute singular values to understand the inverse problem difficulty.
# High condition number indicates need for regularization.

print("Analyzing system matrix conditioning...")
sing_vals = torch.linalg.svdvals(H_dyn)
condition_number = (sing_vals[0] / sing_vals[-1]).item()
sigma_max = sing_vals[0].item()
sigma_min = sing_vals[-1].item()

print(f"Singular value spectrum:")
print(f"  Maximum: {sigma_max:.2e}")
print(f"  Minimum: {sigma_min:.2e}")
print(f"  Condition number: {condition_number:.2e}")
```
This tutorial computes torch.linalg.svdvals(H_dyn) to estimate conditioning. With the default parameters (n=64, img_size=88, so H_dyn is thousands-by-thousands), a full SVD can be extremely slow / memory-heavy and may cause the tutorial to time out or OOM on typical machines. Consider replacing this with a cheaper estimate (e.g., power iteration for sigma_max, or skip conditioning analysis by default / use smaller n for the tutorial).
Suggested change (note: the power iteration must run on H^T H, since H_dyn may be rectangular and iterating H_dyn @ v alone would fail or converge to an eigenvalue rather than a singular value):

```diff
- # Compute singular values to understand the inverse problem difficulty.
- # High condition number indicates need for regularization.
- print("Analyzing system matrix conditioning...")
- sing_vals = torch.linalg.svdvals(H_dyn)
- condition_number = (sing_vals[0] / sing_vals[-1]).item()
- sigma_max = sing_vals[0].item()
- sigma_min = sing_vals[-1].item()
- print(f"Singular value spectrum:")
- print(f"  Maximum: {sigma_max:.2e}")
- print(f"  Minimum: {sigma_min:.2e}")
- print(f"  Condition number: {condition_number:.2e}")
+ # Estimate the largest singular value to understand the inverse problem difficulty.
+ # A high value (relative to noise) indicates the need for regularization.
+ print("Analyzing system matrix conditioning (estimating largest singular value)...")
+ # Power iteration on H^T H avoids the cost of a full SVD on large matrices.
+ num_iters = 20
+ v = torch.randn(H_dyn.shape[1], device=H_dyn.device)
+ v = v / v.norm()
+ for _ in range(num_iters):
+     w = H_dyn.T @ (H_dyn @ v)
+     if w.norm() == 0:
+         break
+     v = w / w.norm()
+ sigma_max = (H_dyn @ v).norm().item()
+ print("Singular value estimate:")
+ print(f"  Estimated maximum singular value: {sigma_max:.2e}")
+ print("  (Full spectrum / condition number skipped to keep the tutorial lightweight.)")
```
```python
def transform_norm(img_size, normalize=True):
    """
    Create a torchvision transform for natural images: resize, center crop,
    convert to tensor, and optionally normalize ([0, 1] -> [-1, 1]).

    Args:
        img_size (int): image size
    """
    transform = torchvision.transforms.Compose(
        [
            torchvision.transforms.Resize(
                img_size,
                interpolation=torchvision.transforms.InterpolationMode.BILINEAR,
            ),
            # torchvision.transforms.CenterCrop(img_size),
            CenterCrop(img_size),
            torchvision.transforms.ToTensor(),
            (
                torchvision.transforms.Normalize([0.5], [0.5])
                if normalize
                else torch.nn.Identity()
            ),
        ]
    )
    return transform
```
transform_norm() uses torchvision.transforms.Normalize([0.5], [0.5]), which will raise a runtime error for RGB images (3 channels) because the mean/std lists must match the channel count. Either switch to per-channel normalization (e.g., 3 values) when the input is RGB, or make transform_norm explicitly grayscale-only / detect channels and adapt.
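One channel-agnostic fix is possible; a minimal sketch (assuming the transform only needs to map [0, 1] inputs to [-1, 1], which is exactly what `Normalize([0.5], [0.5])` does for grayscale). The helper name is hypothetical, not part of the PR; it could be plugged into the `Compose` via `torchvision.transforms.Lambda(normalize_center)`:

```python
import torch


def normalize_center(x: torch.Tensor) -> torch.Tensor:
    # Hypothetical replacement for Normalize([0.5], [0.5]): maps [0, 1] to
    # [-1, 1] for any channel count, so it works for both grayscale and RGB
    # (C, H, W) tensors without a per-channel mean/std list.
    return (x - 0.5) / 0.5
```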
```python
def download_girder(
    server_url: str,
    hex_ids: Union[str, list[str]],
    local_folder: str,
    file_names: Union[str, list[str]] = None,
    gc_type="file",
):
    """
    Downloads data from a Girder server and saves it locally.

    This function first creates the local folder if it does not exist. Then, it
    connects to the Girder server and gets the file names for the files
    whose names are not provided. For each file, it checks whether the file
    name is already present in the local folder. If not, it downloads the
    file.

    Args:
        server_url (str): The URL of the Girder server.

        hex_ids (str or list[str]): The hexadecimal id(s) of the file(s) to
            download. If a list is provided, the files are downloaded in the
            same order and are saved in the same folder.

        local_folder (str): The path to the local folder where the files will
            be saved. If it does not exist, it will be created.

        file_names (str or list[str], optional): The name(s) of the file(s) to
            save. If a list is provided, it must have the same length as
            hex_ids. Each element equal to `None` will be replaced by the name
            of the file on the server. If None, all the names will be obtained
            from the server. Default is None. All names include the extension.

    Raises:
        ValueError: If the number of file names provided does not match the
            number of files to download.

    Returns:
        list[str]: The absolute paths to the downloaded files.
    """
    # leave import in function, so that the module can be used without
    # girder_client
    import girder_client

    # create the local folder if it does not exist
    if not os.path.exists(local_folder):
        print("Local folder not found, creating it... ", end="")
        os.makedirs(local_folder)
        print("done.")

    # connect to the server
    gc = girder_client.GirderClient(apiUrl=server_url)

    # create lists if strings are provided
    if type(hex_ids) is str:
        hex_ids = [hex_ids]
    if file_names is None:
        file_names = [None] * len(hex_ids)
    elif type(file_names) is str:
        file_names = [file_names]

    if len(file_names) != len(hex_ids):
        raise ValueError("There must be as many file names as hex ids.")

    abs_paths = []

    # for each file, check if it exists and download if necessary
    for id, name in zip(hex_ids, file_names):

        if name is None:
            # get the name from the server
            if gc_type == "file":
                name = gc.getFile(id)["name"]
            elif gc_type == "folder":
                name = gc.getFolder(id)["name"]

        # check whether the file already exists locally
        if not os.path.exists(os.path.join(local_folder, name)):
            # connect to the server to download the file or folder
            print(f"Downloading {name}... ", end="\r")
            if gc_type == "file":
                gc.downloadFile(id, os.path.join(local_folder, name))
            elif gc_type == "folder":
                gc.downloadFolderRecursive(id, os.path.join(local_folder, name))
            print(f"Downloading {name}... done.")
```
download_girder() accepts gc_type but doesn't validate it; if a caller passes an unexpected value, name can remain unset and later os.path.join(local_folder, name) will fail. Consider validating gc_type up front (raise ValueError for unsupported values) and documenting the parameter in the docstring.
```python
def generate_synthetic_tumors(
    x: torch.Tensor,
    tumor_params: List[dict],
) -> torch.Tensor:
    """
    Adds synthetic Gaussian tumors to a tensor of shape (batch, n_wav, *img_shape).

    Args:
        :attr:`x` (torch.Tensor): Input tensor of shape (batch, n_wav, *img_shape)

        :attr:`tumor_params` (List[dict]): List of tumor parameters. Each dict should contain:

            - :attr:`center`: (row, col) center position of the tumor
            - :attr:`sigma_x`: Standard deviation of the Gaussian in the x direction
            - :attr:`sigma_y`: Standard deviation of the Gaussian in the y direction
            - :attr:`amplitude`: Amplitude of the tumor
            - :attr:`channels`: List of channel indices to add the tumor to (if None, adds to all channels)
            - :attr:`angle` (optional): Rotation angle in degrees (counter-clockwise). Default is 0.

    Returns:
        torch.Tensor: Tensor with synthetic tumors added
    """
    dtype = x.dtype
    device = x.device

    _, n_wav, h, w = x.shape

    tumors = torch.zeros_like(x, dtype=dtype, device=device)

    # Create coordinate grids
    y_axis = torch.arange(h, dtype=dtype, device=device)
    x_axis = torch.arange(w, dtype=dtype, device=device)
    yy, xx = torch.meshgrid(y_axis, x_axis, indexing="ij")

    for tumor_param in tumor_params:
        center = tumor_param["center"]
        sigma_x = float(tumor_param["sigma_x"])
        sigma_y = float(tumor_param["sigma_y"])
        amplitude = float(tumor_param["amplitude"])
        channels = tumor_param.get("channels", None)
        # Optional rotation angle in degrees (default 0). Positive rotates counter-clockwise.
        angle_deg = float(tumor_param.get("angle", 0.0))
        theta = math.radians(angle_deg)

        # Coordinates relative to center
        x_rel = xx - float(center[1])
        y_rel = yy - float(center[0])

        # Rotate coordinates into the Gaussian's principal axes (apply R(-theta))
        c = math.cos(theta)
        s = math.sin(theta)
        x_rot = c * x_rel + s * y_rel
        y_rot = -s * x_rel + c * y_rel

        # Avoid division by zero
        sigma_x = max(sigma_x, 1e-8)
        sigma_y = max(sigma_y, 1e-8)

        # Generate rotated ellipsoidal Gaussian
        gauss = amplitude * torch.exp(
            -(x_rot**2 / (2 * sigma_x**2) + y_rot**2 / (2 * sigma_y**2))
        )

        if channels is None:
            channels = list(range(n_wav))

        tumors[:, channels, :, :] += gauss.unsqueeze(0).unsqueeze(0)

    return tumors, torch.clamp(x + tumors, 0.0, 1.0)
```
generate_synthetic_tumors() is annotated / documented as returning a single torch.Tensor, but it actually returns a tuple (tumors, torch.clamp(...)). This mismatch can break type checking and confuse callers; update the return type annotation and docstring to reflect the tuple (or change the function to return a single value).
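A minimal sketch of the corrected annotation (simplified to unrotated Gaussians; the function name is hypothetical, standing in for `generate_synthetic_tumors`): the return type becomes a `Tuple` matching the `(blob field, clamped image)` pair the body actually produces.

```python
from typing import List, Tuple

import torch


def add_gaussian_blobs(
    x: torch.Tensor, params: List[dict]
) -> Tuple[torch.Tensor, torch.Tensor]:
    # Simplified stand-in: the annotation now reflects the actual two-tensor
    # return value, so type checkers and callers agree on the contract.
    blobs = torch.zeros_like(x)
    h, w = x.shape[-2:]
    yy, xx = torch.meshgrid(
        torch.arange(h, dtype=x.dtype),
        torch.arange(w, dtype=x.dtype),
        indexing="ij",
    )
    for p in params:
        cy, cx = p["center"]
        blobs += p["amplitude"] * torch.exp(
            -((xx - cx) ** 2 / (2 * p["sigma_x"] ** 2)
              + (yy - cy) ** 2 / (2 * p["sigma_y"] ** 2))
        )
    return blobs, torch.clamp(x + blobs, 0.0, 1.0)
```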
```python
try:
    import cv2
except ImportError:
    warnings.warn(
        "Please install OpenCV to use the dual-arm module (necessary for defining keypoints), e.g. via 'pip install opencv-python'."
    )

from spyrit.misc.disp import torch2numpy
from spyrit.core.meas import HadamSplit2d
from spyrit.misc.statistics import Cov2Var
from spyrit.core.warp import DeformationField

from spyrit.misc.load_data import read_acquisition
from spyrit.misc.disp import get_frame


@dataclass
class _MouseState:
    """State container for mouse interactions."""

    x: int = 0
    y: int = 0
    img: Optional[np.ndarray] = None


# Global state for mouse callbacks (necessary for OpenCV callback system)
_cmos_state = _MouseState()
_sp_state = _MouseState()


def _draw_circle(event: int, x: int, y: int, flags: int, param) -> None:
    """Mouse callback for CMOS image interaction."""
    global _cmos_state
    if event == cv2.EVENT_LBUTTONDBLCLK and _cmos_state.img is not None:
        cv2.circle(_cmos_state.img, (x, y), 2, (255, 0, 0), -1)
        _cmos_state.x, _cmos_state.y = x, y


def _draw_circle_2(event: int, x: int, y: int, flags: int, param) -> None:
    """Mouse callback for single-pixel camera image interaction."""
    global _sp_state
    if event == cv2.EVENT_LBUTTONDBLCLK and _sp_state.img is not None:
        cv2.circle(_sp_state.img, (x, y), 1, (255, 0, 0), -1)
        _sp_state.x, _sp_state.y = x, y
```
cv2 is treated as an optional dependency (warning on ImportError), but the module still references cv2 unconditionally (e.g., in _draw_circle / place_hand_keypoints). If OpenCV isn't installed, importing this module will succeed but calling these functions will crash with NameError. Consider setting cv2 = None in the except block and raising a clear ImportError in code paths that require OpenCV.
```python
# Step 1: Convert motion from CMOS perspective to SPC
print("Step 1: Convert motion from CMOS perspective to SPC...")
self.estim_motion_from_CMOS(
    warping, amp_max=amp_max, show_deform_field=show_deform_field
```
MotionFieldProjector.forward() passes show_deform_field=... into estim_motion_from_CMOS(), but estim_motion_from_CMOS() doesn't accept that keyword argument. This will raise a TypeError at runtime. Either add show_deform_field to the estim_motion_from_CMOS signature/implementation or remove the extra argument here.
Suggested change:

```diff
-     warping, amp_max=amp_max, show_deform_field=show_deform_field
+     warping, amp_max=amp_max
```




Include the work on dynamic single-pixel imaging in the Spyrit library