Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions src/cellmap_segmentation_challenge/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -594,7 +594,7 @@ def structure_model_output(
Ordered list of class names.
num_channels_per_class:
When ``> 1``, each class occupies this many consecutive channels in the
channel dimension. ``None`` means one channel per class.
channel dimension. ``None`` or ``1`` means one channel per class.

Returns
-------
Expand All @@ -608,7 +608,7 @@ def structure_model_output(
# Dict with non-class keys (e.g. resolution levels): split each value
structured = {}
for k, v in outputs.items():
if num_channels_per_class is not None:
if num_channels_per_class is not None and num_channels_per_class > 1:
expected_channels = len(classes) * num_channels_per_class
if v.shape[1] != expected_channels:
raise ValueError(
Expand All @@ -631,7 +631,7 @@ def structure_model_output(
f"classes ({len(classes)}). Should be a multiple of the number of classes."
)
return structured
elif num_channels_per_class is not None:
elif num_channels_per_class is not None and num_channels_per_class > 1:
expected_channels = len(classes) * num_channels_per_class
if outputs.shape[1] != expected_channels:
raise ValueError(
Expand Down
20 changes: 20 additions & 0 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,6 +235,26 @@ def test_resolution_dict_with_num_channels_per_class_wrong_count_raises(self):
with pytest.raises(ValueError, match="does not match expected"):
structure_model_output(d, CLASSES, num_channels_per_class=2)

def test_tensor_num_channels_per_class_equals_one(self):
"""Test that num_channels_per_class=1 is treated the same as None (returns tensor, not dict)"""
out = torch.zeros(2, 3, 8, 8) # 3 classes, 1 ch each
result = structure_model_output(out, CLASSES, num_channels_per_class=1)
assert set(result.keys()) == {"output"}
# Should return a plain tensor, not a dict of class names
assert isinstance(result["output"], torch.Tensor)
assert result["output"].shape == (2, 3, 8, 8)

def test_resolution_dict_num_channels_per_class_equals_one(self):
"""Test that num_channels_per_class=1 with resolution dict returns tensors, not nested dicts"""
d = {"8nm": torch.zeros(2, 3, 8, 8), "32nm": torch.zeros(2, 3, 4, 4)}
result = structure_model_output(d, CLASSES, num_channels_per_class=1)
assert set(result.keys()) == {"8nm", "32nm"}
# Should return plain tensors, not dicts
assert isinstance(result["8nm"], torch.Tensor)
assert isinstance(result["32nm"], torch.Tensor)
assert result["8nm"].shape == (2, 3, 8, 8)
assert result["32nm"].shape == (2, 3, 4, 4)


class TestDownloadFile:
"""Tests for download_file function"""
Expand Down