@@ -21,8 +21,9 @@ class ImageSegmenter:
2121
2222 def __init__ ( # type: ignore
2323 self ,
24- img ,
24+ imgs ,
2525 classes = 1 ,
26+ color_image = False ,
2627 mask = None ,
2728 mask_colors = None ,
2829 mask_alpha = 0.75 ,
@@ -38,10 +39,13 @@ def __init__( # type: ignore
3839
3940 Parameters
4041 ----------
41- img : array_like
42- A valid argument to imshow
42+ imgs : array_like
43+ A single image, or a stack of images shape (N, Y, X)
4344 classes : int, iterable[string], default 1
4445 If a number How many classes to have in the mask.
46+ color_image : bool, default False
47+ If True treat the final dimension of `imgs` as the RGB(A) axis.
48+ Allows for shapes like ([N], Y, X, [3,4])
4549 mask : arraylike, optional
4650 If you want to pre-seed the mask
4751 mask_colors : None, color, or array of colors, optional
@@ -89,29 +93,27 @@ def __init__( # type: ignore
8993 # should probably check the shape here
9094 self .mask_colors [:, - 1 ] = self .mask_alpha
9195
92- self ._img = np .asarray (img )
96+ imgs = np .asanyarray (imgs )
97+ self ._imgs = self ._pad_to_stack (imgs , "imgs" , color_image )
98+ self ._color_image = color_image
9399
100+ self ._image_index = 0
101+
102+ self ._overlay = np .zeros ((* self ._imgs .shape [1 :3 ], 4 ))
94103 if mask is None :
95- self .mask = np .zeros (self ._img .shape [:2 ])
96- """See :doc:`/examples/image-segmentation`."""
104+ self .mask = np .zeros (self ._imgs .shape [:3 ])
97105 else :
98- self .mask = mask
106+ self .mask = self ._pad_to_stack (mask , "mask" , False )
107+ self ._refresh_overlay_values ()
99108
100- self ._overlay = np .zeros ((* self ._img .shape [:2 ], 4 ))
101- for i in range (self ._n_classes + 1 ):
102- idx = self .mask == i
103- if i == 0 :
104- self ._overlay [idx ] = [0 , 0 , 0 , 0 ]
105- else :
106- self ._overlay [idx ] = self .mask_colors [i - 1 ]
107109 if ax is not None :
108110 self .ax = ax
109111 self .fig = self .ax .figure
110112 else :
111113 with ioff ():
112114 self .fig , self .ax = subplots (figsize = figsize )
113- self .displayed = self .ax .imshow (self ._img , ** kwargs )
114- self ._mask = self .ax .imshow (self ._overlay )
115+ self ._displayed = self .ax .imshow (self ._imgs [ self . _image_index ] , ** kwargs )
116+ self ._mask_im = self .ax .imshow (self ._overlay )
115117
116118 default_props = {"color" : "black" , "linewidth" : 1 , "alpha" : 0.8 }
117119 if props is None :
@@ -142,8 +144,10 @@ def __init__( # type: ignore
142144 )
143145 self .lasso .set_visible (True )
144146
145- pix_x = np .arange (self ._img .shape [0 ])
146- pix_y = np .arange (self ._img .shape [1 ])
147+ # offset shape by 1 because we always pad into
148+ # being a stack (N, Y, X, [3,4])
149+ pix_x = np .arange (self ._imgs .shape [1 ])
150+ pix_y = np .arange (self ._imgs .shape [2 ])
147151 xv , yv = np .meshgrid (pix_y , pix_x )
148152 self .pix = np .vstack ((xv .flatten (), yv .flatten ())).T
149153
@@ -153,6 +157,74 @@ def __init__( # type: ignore
153157 self ._erasing = False
154158 self ._paths : dict [str , list [Path ]] = {"adding" : [], "erasing" : []}
155159
160+ def _refresh_overlay_values (self ) -> None :
161+ # leave the actual updating of image to other code
162+ # in order to easily manage what gets updated and when
163+ # the drawing happens
164+ for i in range (self ._n_classes + 1 ):
165+ idx = self ._mask [self ._image_index ] == i
166+ if i == 0 :
167+ self ._overlay [idx ] = [0 , 0 , 0 , 0 ]
168+ else :
169+ self ._overlay [idx ] = self .mask_colors [i - 1 ]
170+
171+ @staticmethod
172+ def _pad_to_stack (arr : np .ndarray , name : str , color_image : bool ) -> np .ndarray :
173+ if color_image and arr .ndim < 3 :
174+ raise ValueError (
175+ f"{ name } must be at least 3 dimensional when *color_image* is True"
176+ f" but it is { arr .ndim } D"
177+ )
178+ if arr .ndim == (2 + color_image ):
179+ # make shape (1, M, N)
180+ # or (1, M, N, [3, 4])
181+ return arr [None , ...]
182+ elif arr .ndim == (3 + color_image ):
183+ return arr
184+ else :
185+ raise ValueError (
186+ f"{ name } must be either 2 or 3 dimensional. Did"
187+ " you mean to set *color_image* to True?"
188+ )
189+
190+ @property
191+ def mask (self ) -> np .ndarray :
192+ if self ._mask .shape [0 ] == 1 :
193+ # don't complicate things in the simple case of
194+ # one image
195+ return self ._mask [0 ]
196+ else :
197+ return self ._mask
198+
199+ @mask .setter
200+ def mask (self , val : np .ndarray ) -> None :
201+ val = self ._pad_to_stack (np .asanyarray (val ), "mask" , False )
202+ if self ._color_image :
203+ compare_shape = self ._imgs .shape [:- 1 ]
204+ else :
205+ compare_shape = self ._imgs .shape
206+ if val .shape != compare_shape :
207+ raise ValueError ("Mask must have the same shape as imgs" )
208+ self ._mask = val
209+
210+ @property
211+ def image_index (self ) -> int :
212+ return self ._image_index
213+
214+ @image_index .setter
215+ def image_index (self , val : int ) -> None :
216+ if not isinstance (val , Integral ):
217+ raise ValueError ("image_index must be an integer" )
218+ if val >= self ._imgs .shape [0 ]:
219+ raise ValueError (
220+ f"Too large - This segmenter only has { self ._imgs .shape [0 ]} images."
221+ )
222+ self ._image_index = val
223+ self ._refresh_overlay_values ()
224+ self ._displayed .set_data (self ._imgs [val ])
225+ self ._mask_im .set_data (self ._overlay )
226+ self .fig .canvas .draw_idle ()
227+
156228 @property
157229 def panmanager (self ) -> PanManager :
158230 return self ._pm
@@ -200,17 +272,17 @@ def get_paths(self) -> dict[str, list[Path]]:
200272
201273 def _onselect (self , verts : Any ) -> None :
202274 p = Path (verts )
203- self . indices = p .contains_points (self .pix , radius = 0 ).reshape (self .mask .shape )
275+ indices = p .contains_points (self .pix , radius = 0 ).reshape (self ._mask .shape [ 1 : 3 ] )
204276 if self ._erasing :
205- self .mask [self .indices ] = 0
206- self ._overlay [self . indices ] = [0 , 0 , 0 , 0 ]
277+ self ._mask [self ._image_index ][ indices ] = 0
278+ self ._overlay [indices ] = [0 , 0 , 0 , 0 ]
207279 self ._paths ["erasing" ].append (p )
208280 else :
209- self .mask [self .indices ] = self ._cur_class_idx
210- self ._overlay [self . indices ] = self .mask_colors [self ._cur_class_idx - 1 ]
281+ self ._mask [self ._image_index ][ indices ] = self ._cur_class_idx
282+ self ._overlay [indices ] = self .mask_colors [self ._cur_class_idx - 1 ]
211283 self ._paths ["adding" ].append (p )
212284
213- self ._mask .set_data (self ._overlay )
285+ self ._mask_im .set_data (self ._overlay )
214286 self .fig .canvas .draw_idle ()
215287
216288 def _ipython_display_ (self ) -> None :
0 commit comments