Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
33 commits
Select commit Hold shift + click to select a range
3732322
feat: enhance prediction handling by validating output channels and a…
rhoadesScholar Jan 28, 2026
3de888c
Merge branch 'main' into multi_channel_writing
rhoadesScholar Feb 19, 2026
e0961bf
Update src/cellmap_segmentation_challenge/predict.py
rhoadesScholar Feb 19, 2026
8467547
Update src/cellmap_segmentation_challenge/predict.py
rhoadesScholar Feb 19, 2026
5329c81
fix: adjust shape check for target arrays in _predict function
rhoadesScholar Feb 19, 2026
4ca53a7
feat: enhance model input/output handling with singleton dimension ut…
rhoadesScholar Feb 19, 2026
04cbe5f
Initial plan
Copilot Feb 19, 2026
5da027f
Initial plan
Copilot Feb 19, 2026
f259b36
Update tests/test_utils.py
rhoadesScholar Feb 19, 2026
f6b6578
fix: use rank-based check for channel dimension detection
Copilot Feb 19, 2026
d998f2d
Merge branch 'multi_channel_writing' into copilot/sub-pr-188-another-one
rhoadesScholar Feb 19, 2026
b8595cf
Merge branch 'multi_channel_writing' into copilot/sub-pr-188
rhoadesScholar Feb 19, 2026
84aa23d
feat: add validation for num_channels_per_class in structure_model_ou…
Copilot Feb 19, 2026
6915174
fix: remove duplicate TestDownloadFile class and format code
Copilot Feb 19, 2026
6696cf1
refactor: improve test method names for clarity
Copilot Feb 19, 2026
1e18d67
feat: use deepcopy and spatial rank check for robust shape handling
Copilot Feb 19, 2026
c15980d
test: add comprehensive unit tests for shape adjustment logic
Copilot Feb 19, 2026
99b75ed
refactor: address code review feedback on imports and clarity
Copilot Feb 19, 2026
a2e8902
Merge pull request #193 from janelia-cellmap/copilot/sub-pr-188-anoth…
rhoadesScholar Feb 19, 2026
619a23c
Merge branch 'multi_channel_writing' into copilot/sub-pr-188
rhoadesScholar Feb 19, 2026
9595f57
Merge pull request #191 from janelia-cellmap/copilot/sub-pr-188
rhoadesScholar Feb 19, 2026
4bc7b2e
Apply suggestion from @rhoadesScholar
rhoadesScholar Feb 19, 2026
7b68a2f
Apply suggestion from @Copilot
rhoadesScholar Feb 19, 2026
102bb09
Apply suggestion from @rhoadesScholar
rhoadesScholar Feb 19, 2026
4865a59
fix: improve documentation in predict_2D.py and predict_3D.py
rhoadesScholar Feb 20, 2026
dcbb99c
fix: update cellmap-data dependency version to >=2026.2.19.2140
rhoadesScholar Feb 20, 2026
d315e4b
fix: adjust shape handling in prediction functions and improve deepco…
rhoadesScholar Feb 20, 2026
fbfdd06
Merge branch 'main' into multi_channel_writing
rhoadesScholar Feb 20, 2026
8757c23
Update src/cellmap_segmentation_challenge/predict.py
rhoadesScholar Feb 20, 2026
c1f9c36
Apply suggestion from @rhoadesScholar
rhoadesScholar Feb 20, 2026
b7878f1
Initial plan
Copilot Feb 20, 2026
4acf5fa
fix: treat num_channels_per_class=1 same as None to match docstring b…
Copilot Feb 20, 2026
248caf4
Merge pull request #196 from janelia-cellmap/copilot/sub-pr-188
rhoadesScholar Feb 20, 2026
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions examples/predict_2D.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
# This file is used to predict the segmentation logits of the 3D test datasets using the model trained in the train_2D.py script.
# It does so by using the 'train_2D.py' configuration file, and the 'predict' function from the cellmap_segmentation_challenge package, which loads the trained model and runs inference on the test data.
# %%
# Imports
from cellmap_segmentation_challenge import predict
Expand Down
1 change: 1 addition & 0 deletions examples/predict_3D.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
# This file is used to predict the segmentation logits of the 3D test datasets using the model trained in the train_3D.py script.
# It does so by using the 'train_3D.py' configuration file, and the 'predict' function from the cellmap_segmentation_challenge package, which loads the trained model and runs inference on the test data.
# %%
# Imports
from cellmap_segmentation_challenge import predict
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ dependencies = [
"tensorboard",
"tensorboardX",
"click>=8, <9",
"cellmap-data",
"cellmap-data>=2026.2.19.2140",
"tqdm",
"numcodecs < 0.16.0",
"zarr < 3.0.0",
Expand Down
128 changes: 107 additions & 21 deletions src/cellmap_segmentation_challenge/predict.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
import copy
import os
import tempfile
from glob import glob
from typing import Any

import numpy as np
import torch
import torchvision.transforms.v2 as T
from cellmap_data import CellMapDatasetWriter, CellMapImage
Expand All @@ -18,7 +18,15 @@

from .config import CROP_NAME, PREDICTIONS_PATH, RAW_NAME, SEARCH_PATH
from .models import get_model
from .utils import load_safe_config, get_test_crops
from .utils import (
load_safe_config,
get_test_crops,
get_data_from_batch,
get_singleton_dim,
squeeze_singleton_dim,
structure_model_output,
unsqueeze_singleton_dim,
)
from .utils.datasplit import get_formatted_fields, get_raw_path


Expand Down Expand Up @@ -111,36 +119,102 @@ def _predict(
The batch size to use for prediction
"""

value_transforms = T.Compose(
[
T.ToDtype(torch.float, scale=True),
NaNtoNum({"nan": 0, "posinf": None, "neginf": None}),
],
)
model.eval()
device = dataset_writer_kwargs["device"]
input_keys = list(dataset_writer_kwargs["input_arrays"].keys())

dataset_writer = CellMapDatasetWriter(
**dataset_writer_kwargs, raw_value_transforms=value_transforms
# Test a single batch to get number of output channels
test_batch = {
k: torch.rand((1, *info["shape"])).unsqueeze(0).to(device)
for k, info in dataset_writer_kwargs["input_arrays"].items()
}
test_inputs = get_data_from_batch(test_batch, input_keys, device)
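# Apply singleton-dimension squeezing as in the main prediction loop.
# NOTE(review): this probe uses offset +1 while the main loop uses +2; the probe tensor is
# (1, 1, *shape), so the two only agree when the singleton axis is index 0 — confirm intended.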
singleton_dim = get_singleton_dim(
list(dataset_writer_kwargs["input_arrays"].values())[0]["shape"]
)
if singleton_dim is not None:
test_inputs = squeeze_singleton_dim(test_inputs, singleton_dim + 1)
with torch.no_grad():
test_outputs = model(test_inputs)
model_returns_class_dict = False
num_channels_per_class = None
if isinstance(test_outputs, dict):
if set(test_outputs.keys()) == set(dataset_writer_kwargs["classes"]):
# Keys are the class names; values are already per-class tensors
model_returns_class_dict = True
else:
# Dict with non-class keys (e.g., resolution levels): use the first
# value tensor to detect the channel count
test_outputs = next(iter(test_outputs.values()))
if not model_returns_class_dict and test_outputs.shape[1] > len(
dataset_writer_kwargs["classes"]
):
if test_outputs.shape[1] % len(dataset_writer_kwargs["classes"]) == 0:
num_channels_per_class = test_outputs.shape[1] // len(
dataset_writer_kwargs["classes"]
)
# To avoid mutating the input dictionary (which may be shared across multiple
# prediction calls), create a deep copy of target_arrays and update the shape
# to include the channel dimension.
target_arrays_copy = copy.deepcopy(dataset_writer_kwargs["target_arrays"])
for key in target_arrays_copy.keys():
current_shape = target_arrays_copy[key]["shape"]
# Use the first input array's shape to determine expected spatial rank
# (all input arrays should have the same spatial dimensions)
first_input_key = next(iter(dataset_writer_kwargs["input_arrays"]))
expected_spatial_rank = len(
dataset_writer_kwargs["input_arrays"][first_input_key]["shape"]
)
# Only prepend the channel dimension if the shape doesn't already include it.
# A current rank equal to the expected spatial rank means no channel dimension is present yet.
if len(current_shape) == expected_spatial_rank:
target_arrays_copy[key]["shape"] = (
num_channels_per_class,
*current_shape,
)
# Replace target_arrays in the kwargs with the modified copy
dataset_writer_kwargs = {
**dataset_writer_kwargs,
"target_arrays": target_arrays_copy,
}
else:
raise ValueError(
f"Number of output channels ({test_outputs.shape[1]}) does not match number of "
f"classes ({len(dataset_writer_kwargs['classes'])}). Should be a multiple of the "
"number of classes."
)
del test_batch, test_inputs, test_outputs

if "raw_value_transforms" not in dataset_writer_kwargs:
dataset_writer_kwargs["raw_value_transforms"] = T.Compose(
[
T.ToDtype(torch.float, scale=True),
NaNtoNum({"nan": 0, "posinf": None, "neginf": None}),
],
)

dataset_writer = CellMapDatasetWriter(**dataset_writer_kwargs)
dataloader = dataset_writer.loader(batch_size=batch_size)
model.eval()

# Find singleton dimension if there is one
# Only the first singleton dimension will be used for squeezing/unsqueezing.
# If there are multiple singleton dimensions, only the first is handled.
singleton_dim = np.where(
[s == 1 for s in dataset_writer_kwargs["input_arrays"]["input"]["shape"]]
)[0]
singleton_dim = singleton_dim[0] if singleton_dim.size > 0 else None
with torch.no_grad():
for batch in tqdm(dataloader, dynamic_ncols=True):
# Get the inputs and outputs
inputs = batch["input"].to(dataset_writer_kwargs["device"])
# Get the inputs, handling dict vs. tensor data
inputs = get_data_from_batch(batch, input_keys, device)
if singleton_dim is not None:
# Remove singleton dimension
inputs = inputs.squeeze(dim=singleton_dim + 2)
inputs = squeeze_singleton_dim(inputs, singleton_dim + 2)
outputs = model(inputs)
if singleton_dim is not None:
outputs = outputs.unsqueeze(dim=singleton_dim + 2)
outputs = {"output": outputs}
outputs = unsqueeze_singleton_dim(outputs, singleton_dim + 2)

outputs = structure_model_output(
outputs,
dataset_writer_kwargs["classes"],
num_channels_per_class,
)

# Save the outputs
dataset_writer[batch["idx"]] = outputs
Expand Down Expand Up @@ -185,6 +259,16 @@ def predict(
config, "input_array_info", {"shape": (1, 128, 128), "scale": (8, 8, 8)}
)
target_array_info = getattr(config, "target_array_info", input_array_info)
value_transforms = getattr(
config,
"value_transforms",
T.Compose(
[
T.ToDtype(torch.float, scale=True),
NaNtoNum({"nan": 0, "posinf": None, "neginf": None}),
],
),
)
model = config.model

# %% Check that the GPU is available
Expand Down Expand Up @@ -266,6 +350,7 @@ def predict(
"target_bounds": target_bounds,
"overwrite": overwrite,
"device": device,
"raw_value_transforms": value_transforms,
}
)
else:
Expand Down Expand Up @@ -322,6 +407,7 @@ def predict(
"target_bounds": target_bounds,
"overwrite": overwrite,
"device": device,
"raw_value_transforms": value_transforms,
}
)

Expand Down
43 changes: 24 additions & 19 deletions src/cellmap_segmentation_challenge/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,10 @@
make_datasplit_csv,
make_s3_datasplit_csv,
format_string,
get_data_from_batch,
get_singleton_dim,
squeeze_singleton_dim,
unsqueeze_singleton_dim,
)


Expand Down Expand Up @@ -359,6 +363,13 @@ def _save_training_batch_for_viz(batch, inputs, outputs, targets):
input_keys = list(train_loader.dataset.input_arrays.keys())
target_keys = list(train_loader.dataset.target_arrays.keys())

# Find singleton dimension for 2D-model squeeze/unsqueeze
if "shape" in input_array_info:
first_input_shape = input_array_info["shape"]
else:
first_input_shape = list(input_array_info.values())[0]["shape"]
singleton_dim = get_singleton_dim(first_input_shape)

# %% Train the model
post_fix_dict = {}

Expand Down Expand Up @@ -394,20 +405,16 @@ def _save_training_batch_for_viz(batch, inputs, outputs, targets):
# Increment the training iteration
n_iter += 1

# Forward pass (compute the output of the model)
if len(input_keys) > 1:
inputs = {key: batch[key].to(device) for key in input_keys}
else:
inputs = batch[input_keys[0]].to(device)
# Assumes the model input is a single tensor
# Forward pass (compute the model outputs)
inputs = get_data_from_batch(batch, input_keys, device)
if singleton_dim is not None:
inputs = squeeze_singleton_dim(inputs, singleton_dim + 2)
outputs = model(inputs)
if singleton_dim is not None:
outputs = unsqueeze_singleton_dim(outputs, singleton_dim + 2)

# Compute the loss
if len(target_keys) > 1:
targets = {key: batch[key].to(device) for key in target_keys}
else:
targets = batch[target_keys[0]].to(device)
# Assumes the model output is a single tensor
targets = get_data_from_batch(batch, target_keys, device)
loss = criterion(outputs, targets) / gradient_accumulation_steps

# Backward pass (compute the gradients)
Expand Down Expand Up @@ -500,17 +507,15 @@ def _save_training_batch_for_viz(batch, inputs, outputs, targets):
with torch.no_grad():
i = 0
for batch in val_bar:
if len(input_keys) > 1:
inputs = {key: batch[key].to(device) for key in input_keys}
else:
inputs = batch[input_keys[0]].to(device)
inputs = get_data_from_batch(batch, input_keys, device)
if singleton_dim is not None:
inputs = squeeze_singleton_dim(inputs, singleton_dim + 2)
outputs = model(inputs)
if singleton_dim is not None:
outputs = unsqueeze_singleton_dim(outputs, singleton_dim + 2)

# Compute the loss
if len(target_keys) > 1:
targets = {key: batch[key].to(device) for key in target_keys}
else:
targets = batch[target_keys[0]].to(device)
targets = get_data_from_batch(batch, target_keys, device)
val_score += criterion(outputs, targets).item()
i += 1

Expand Down
5 changes: 5 additions & 0 deletions src/cellmap_segmentation_challenge/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,11 @@
"download_file",
"format_string",
"get_git_hash",
"get_data_from_batch",
"structure_model_output",
"get_singleton_dim",
"squeeze_singleton_dim",
"unsqueeze_singleton_dim",
],
"submission": [
"package_submission",
Expand Down
Loading