Skip to content

Commit 18a7bb2

Browse files
committed
fix: stacking bugs + add more tests
fix: mask drawing
1 parent e2134bd commit 18a7bb2

File tree

3 files changed

+133
-25
lines changed

3 files changed

+133
-25
lines changed

.github/workflows/ci.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ jobs:
4949
pip install -e ".[test]"
5050
5151
- name: Test
52-
run: pytest -v --color=yes --cov=pymmcore_plus --cov-report=xml
52+
run: pytest -v --color=yes --cov=mpl_image_segmenter --cov-report=xml
5353

5454
- name: Coverage
5555
uses: codecov/codecov-action@v3

src/mpl_image_segmenter/_segmenter.py

Lines changed: 96 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,9 @@ class ImageSegmenter:
2121

2222
def __init__( # type: ignore
2323
self,
24-
img,
24+
imgs,
2525
classes=1,
26+
color_image=False,
2627
mask=None,
2728
mask_colors=None,
2829
mask_alpha=0.75,
@@ -38,10 +39,13 @@ def __init__( # type: ignore
3839
3940
Parameters
4041
----------
41-
img : array_like
42-
A valid argument to imshow
42+
imgs : array_like
43+
A single image, or a stack of images shape (N, Y, X)
4344
classes : int, iterable[string], default 1
4445
If a number How many classes to have in the mask.
46+
color_image : bool, default False
47+
If True treat the final dimension of `imgs` as the RGB(A) axis.
48+
Allows for shapes like ([N], Y, X, [3,4])
4549
mask : arraylike, optional
4650
If you want to pre-seed the mask
4751
mask_colors : None, color, or array of colors, optional
@@ -89,29 +93,27 @@ def __init__( # type: ignore
8993
# should probably check the shape here
9094
self.mask_colors[:, -1] = self.mask_alpha
9195

92-
self._img = np.asarray(img)
96+
imgs = np.asanyarray(imgs)
97+
self._imgs = self._pad_to_stack(imgs, "imgs", color_image)
98+
self._color_image = color_image
9399

100+
self._image_index = 0
101+
102+
self._overlay = np.zeros((*self._imgs.shape[1:3], 4))
94103
if mask is None:
95-
self.mask = np.zeros(self._img.shape[:2])
96-
"""See :doc:`/examples/image-segmentation`."""
104+
self.mask = np.zeros(self._imgs.shape[:3])
97105
else:
98-
self.mask = mask
106+
self.mask = self._pad_to_stack(mask, "mask", False)
107+
self._refresh_overlay_values()
99108

100-
self._overlay = np.zeros((*self._img.shape[:2], 4))
101-
for i in range(self._n_classes + 1):
102-
idx = self.mask == i
103-
if i == 0:
104-
self._overlay[idx] = [0, 0, 0, 0]
105-
else:
106-
self._overlay[idx] = self.mask_colors[i - 1]
107109
if ax is not None:
108110
self.ax = ax
109111
self.fig = self.ax.figure
110112
else:
111113
with ioff():
112114
self.fig, self.ax = subplots(figsize=figsize)
113-
self.displayed = self.ax.imshow(self._img, **kwargs)
114-
self._mask = self.ax.imshow(self._overlay)
115+
self._displayed = self.ax.imshow(self._imgs[self._image_index], **kwargs)
116+
self._mask_im = self.ax.imshow(self._overlay)
115117

116118
default_props = {"color": "black", "linewidth": 1, "alpha": 0.8}
117119
if props is None:
@@ -142,8 +144,10 @@ def __init__( # type: ignore
142144
)
143145
self.lasso.set_visible(True)
144146

145-
pix_x = np.arange(self._img.shape[0])
146-
pix_y = np.arange(self._img.shape[1])
147+
# offset shape by 1 because we always pad into
148+
# being a stack (N, Y, X, [3,4])
149+
pix_x = np.arange(self._imgs.shape[1])
150+
pix_y = np.arange(self._imgs.shape[2])
147151
xv, yv = np.meshgrid(pix_y, pix_x)
148152
self.pix = np.vstack((xv.flatten(), yv.flatten())).T
149153

@@ -153,6 +157,74 @@ def __init__( # type: ignore
153157
self._erasing = False
154158
self._paths: dict[str, list[Path]] = {"adding": [], "erasing": []}
155159

160+
def _refresh_overlay_values(self) -> None:
161+
# leave the actual updating of image to other code
162+
# in order to easily manage what gets updated and when
163+
# the drawing happens
164+
for i in range(self._n_classes + 1):
165+
idx = self._mask[self._image_index] == i
166+
if i == 0:
167+
self._overlay[idx] = [0, 0, 0, 0]
168+
else:
169+
self._overlay[idx] = self.mask_colors[i - 1]
170+
171+
@staticmethod
172+
def _pad_to_stack(arr: np.ndarray, name: str, color_image: bool) -> np.ndarray:
173+
if color_image and arr.ndim < 3:
174+
raise ValueError(
175+
f"{name} must be at least 3 dimensional when *color_image* is True"
176+
f" but it is {arr.ndim}D"
177+
)
178+
if arr.ndim == (2 + color_image):
179+
# make shape (1, M, N)
180+
# or (1, M, N, [3, 4])
181+
return arr[None, ...]
182+
elif arr.ndim == (3 + color_image):
183+
return arr
184+
else:
185+
raise ValueError(
186+
f"{name} must be either 2 or 3 dimensional. Did"
187+
" you mean to set *color_image* to True?"
188+
)
189+
190+
@property
191+
def mask(self) -> np.ndarray:
192+
if self._mask.shape[0] == 1:
193+
# don't complicate things in the simple case of
194+
# one image
195+
return self._mask[0]
196+
else:
197+
return self._mask
198+
199+
@mask.setter
200+
def mask(self, val: np.ndarray) -> None:
201+
val = self._pad_to_stack(np.asanyarray(val), "mask", False)
202+
if self._color_image:
203+
compare_shape = self._imgs.shape[:-1]
204+
else:
205+
compare_shape = self._imgs.shape
206+
if val.shape != compare_shape:
207+
raise ValueError("Mask must have the same shape as imgs")
208+
self._mask = val
209+
210+
@property
211+
def image_index(self) -> int:
212+
return self._image_index
213+
214+
@image_index.setter
215+
def image_index(self, val: int) -> None:
216+
if not isinstance(val, Integral):
217+
raise ValueError("image_index must be an integer")
218+
if val >= self._imgs.shape[0]:
219+
raise ValueError(
220+
f"Too large - This segmenter only has {self._imgs.shape[0]} images."
221+
)
222+
self._image_index = val
223+
self._refresh_overlay_values()
224+
self._displayed.set_data(self._imgs[val])
225+
self._mask_im.set_data(self._overlay)
226+
self.fig.canvas.draw_idle()
227+
156228
@property
157229
def panmanager(self) -> PanManager:
158230
return self._pm
@@ -200,17 +272,17 @@ def get_paths(self) -> dict[str, list[Path]]:
200272

201273
def _onselect(self, verts: Any) -> None:
202274
p = Path(verts)
203-
self.indices = p.contains_points(self.pix, radius=0).reshape(self.mask.shape)
275+
indices = p.contains_points(self.pix, radius=0).reshape(self._mask.shape[1:3])
204276
if self._erasing:
205-
self.mask[self.indices] = 0
206-
self._overlay[self.indices] = [0, 0, 0, 0]
277+
self._mask[self._image_index][indices] = 0
278+
self._overlay[indices] = [0, 0, 0, 0]
207279
self._paths["erasing"].append(p)
208280
else:
209-
self.mask[self.indices] = self._cur_class_idx
210-
self._overlay[self.indices] = self.mask_colors[self._cur_class_idx - 1]
281+
self._mask[self._image_index][indices] = self._cur_class_idx
282+
self._overlay[indices] = self.mask_colors[self._cur_class_idx - 1]
211283
self._paths["adding"].append(p)
212284

213-
self._mask.set_data(self._overlay)
285+
self._mask_im.set_data(self._overlay)
214286
self.fig.canvas.draw_idle()
215287

216288
def _ipython_display_(self) -> None:

tests/test_mpl_image_segmenter.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,42 @@
55
from mpl_image_segmenter import ImageSegmenter
66

77

8+
def test_rbga_image():
9+
img_rgb = np.zeros([128, 128, 3])
10+
img_rgba = np.zeros([128, 128, 4])
11+
ImageSegmenter(img_rgb, classes=["a", "b", "c"], color_image=True)
12+
ImageSegmenter(img_rgba, classes=["a", "b", "c"], color_image=True)
13+
14+
15+
@pytest.mark.parametrize(
16+
["shape", "color"],
17+
[
18+
((3, 128, 128), False),
19+
((4, 128, 128, 3), True),
20+
((5, 128, 128, 4), True),
21+
],
22+
)
23+
def test_image_stacks(shape: tuple, color: bool):
24+
top = 25
25+
left = 25
26+
right = 100
27+
bottom = 100
28+
imgs = np.zeros(shape)
29+
# imgs_rgb = np.zeros([4, 128, 128, 3])
30+
# imgs_rgba = np.zeros([5, 128, 128, 4])
31+
seg = ImageSegmenter(imgs, classes=["a", "b", "c"], color_image=color)
32+
seg.current_class = "a"
33+
seg._onselect([(left, top), (left, bottom), (right, bottom), (right, top)])
34+
seg.image_index = 1
35+
assert seg.mask[0].sum() == 5550
36+
seg._onselect([(left, top + 40), (left, bottom), (right, bottom), (right, top)])
37+
assert seg.mask[1].sum() == 4105
38+
assert seg.mask[2].sum() == 0
39+
40+
with pytest.raises(ValueError):
41+
seg.image_index = 5
42+
43+
844
def test_current_class():
945
img = np.zeros([128, 128])
1046
seg = ImageSegmenter(img, classes=["a", "b", "c"])

0 commit comments

Comments
 (0)