Skip to content

Commit dc1bf52

Browse files
authored
New features robert (#70)
* add remove_to_label_for_cc_filter * add remove_to * add itk global_cord support * add referecne for loading POI (POI.load) * enable loading via Path * add set_above_3_point_plane * make an internal function to make vert-seg * do not corp when empty instead of failing * fix issue with empty images and cropout * bug fixes * ruff * remove deprecated function * x --------- Co-authored-by: ga84mun <[email protected]>
1 parent efa5486 commit dc1bf52

File tree

11 files changed

+374
-121
lines changed

11 files changed

+374
-121
lines changed

TPTBox/core/nii_wrapper.py

Lines changed: 21 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -747,18 +747,18 @@ def pad_to(self,target_shape:list[int]|tuple[int,int,int] | Self, mode:MODES="co
747747
s = s.apply_crop(tuple(crop),inplace=inplace)
748748
return s.apply_pad(padding,inplace=inplace,mode=mode)
749749

750-
def apply_pad(self,padd:Sequence[tuple[int|None,int]],mode:MODES="constant",inplace = False):
750+
def apply_pad(self,padd:Sequence[tuple[int|None,int]],mode:MODES="constant",inplace = False,verbose:logging=True):
751751
#TODO add other modes
752752
#TODO add testcases and options for modes
753753
transform = np.eye(4, dtype=int)
754754
for i, (before,_) in enumerate(padd):
755755
#transform[i, i] = pad_slice.step if pad_slice.step is not None else 1
756756
transform[i, 3] = -before if before is not None else 0
757757
affine = self.affine.dot(transform)
758-
print(mode)
759758
args = {}
760759
if mode == "constant":
761760
args["constant_values"]=self.get_c_val()
761+
log.print(f"Padd {padd}; {mode=}, {args}",verbose=verbose)
762762
arr = np.pad(self.get_array(),padd,mode=mode,**args) # type: ignore
763763

764764
nii:_unpacked_nii = (arr,affine,self.header)
@@ -1260,21 +1260,24 @@ def get_segmentation_connected_components(self, labels: int |list[int], connecti
12601260
cc = {i: self.set_array(k) for i,k in cc.items()}
12611261
return cc, cc_n
12621262

1263-
def get_connected_components(self, labels: int |list[int]=1, connectivity: int = 3, verbose: bool=False,inplace=False) -> Self:
1264-
arr = self.get_seg_array()
1265-
cc, _ = np_connected_components(arr, connectivity=connectivity, label_ref=labels, verbose=verbose)
1266-
out = None
1267-
1268-
for i,k in cc.items():
1269-
if out is None:
1270-
out = k
1271-
else:
1272-
out += i*k
1273-
if out is None:
1274-
return self if inplace else self.copy()
1263+
def get_connected_components(self, labels: int |list[int]=1, connectivity: int = 3, verbose: bool=False,inplace=False) -> Self: # noqa: ARG002
1264+
out = np_get_largest_k_connected_components(self.get_seg_array(), label_ref=labels, connectivity=connectivity, return_original_labels=False)
12751265
return self.set_array(out,inplace=inplace)
12761266

1277-
def filter_connected_components(self, labels: int |list[int]|None,min_volume:int=0,max_volume:int|None=None, max_count_component = None, connectivity: int = 3,keep_label=False, inplace=False):
1267+
#arr = self.get_seg_array()
1268+
#cc, _ = np_connected_components(arr, connectivity=connectivity, label_ref=labels, verbose=verbose)
1269+
#out = None
1270+
1271+
#for i,k in cc.items():
1272+
# if out is None:
1273+
# out = k
1274+
# else:
1275+
# out += i*k
1276+
#if out is None:
1277+
# return self if inplace else self.copy()
1278+
#return self.set_array(out,inplace=inplace)
1279+
1280+
def filter_connected_components(self, labels: int |list[int]|None,min_volume:int=0,max_volume:int|None=None, max_count_component = None, connectivity: int = 3,removed_to_label=0,keep_label=False, inplace=False):
12781281
"""
12791282
Filter connected components in a segmentation array based on specified volume constraints.
12801283
@@ -1289,7 +1292,7 @@ def filter_connected_components(self, labels: int |list[int]|None,min_volume:int
12891292
Returns:
12901293
None
12911294
"""
1292-
nii = self.get_largest_k_segmentation_connected_components(max_count_component,labels,connectivity=connectivity,return_original_labels=keep_label,min_volume=min_volume,max_volume=max_volume)
1295+
nii = self.get_largest_k_segmentation_connected_components(max_count_component,labels,connectivity=connectivity,return_original_labels=keep_label,min_volume=min_volume,max_volume=max_volume,removed_to_label=removed_to_label)
12931296
if keep_label and labels is not None:
12941297
if isinstance(labels,int):
12951298
labels = [labels]
@@ -1319,7 +1322,7 @@ def get_segmentation_connected_components_center_of_mass(self, label: int, conne
13191322
return np_get_connected_components_center_of_mass(arr, label=label, connectivity=connectivity, sort_by_axis=sort_by_axis)
13201323

13211324

1322-
def get_largest_k_segmentation_connected_components(self, k: int | None, labels: int | list[int] | None = None, connectivity: int = 1, return_original_labels: bool = True,inplace=False,min_volume:int=0,max_volume:int|None=None):
1325+
def get_largest_k_segmentation_connected_components(self, k: int | None, labels: int | list[int] | None = None, connectivity: int = 1, return_original_labels: bool = True,inplace=False,min_volume:int=0,max_volume:int|None=None,removed_to_label=0):
13231326
"""Finds the largest k connected components in a given array (does NOT work with zero as label!)
13241327
13251328
Args:
@@ -1329,7 +1332,7 @@ def get_largest_k_segmentation_connected_components(self, k: int | None, labels:
13291332
return_original_labels (bool): If set to False, will label the components from 1 to k. Defaults to True
13301333
"""
13311334
msk_i_data = self.get_seg_array()
1332-
out = np_get_largest_k_connected_components(msk_i_data, k=k, label_ref=labels, connectivity=connectivity, return_original_labels=return_original_labels,min_volume=min_volume,max_volume=max_volume)
1335+
out = np_get_largest_k_connected_components(msk_i_data, k=k, label_ref=labels, connectivity=connectivity, return_original_labels=return_original_labels,min_volume=min_volume,max_volume=max_volume,removed_to_label=removed_to_label)
13331336
return self.set_array(out,inplace=inplace)
13341337

13351338
def compute_surface_mask(self, connectivity: int, dilated_surface: bool = False):

TPTBox/core/np_utils.py

Lines changed: 31 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ def np_extract_label(
5353
arr = arr.copy()
5454

5555
if use_crop:
56-
crop = np_bbox_binary(arr, px_dist=1)
56+
crop = np_bbox_binary(arr, px_dist=1, raise_error=False)
5757
arrc = arr[crop]
5858
else:
5959
arrc = arr
@@ -82,10 +82,13 @@ def np_extract_label(
8282

8383
def cc3dstatistics(arr: UINTARRAY, use_crop: bool = True) -> dict:
8484
assert np.issubdtype(arr.dtype, np.unsignedinteger), f"cc3dstatistics expects uint type, got {arr.dtype}"
85-
if use_crop:
86-
crop = np_bbox_binary(arr)
87-
arrc = arr[crop]
88-
return _cc3dstats(arrc)
85+
try:
86+
if use_crop:
87+
crop = np_bbox_binary(arr, raise_error=False, px_dist=2)
88+
arrc = arr[crop]
89+
return _cc3dstats(arrc)
90+
except ValueError as e:
91+
print(e)
8992
return _cc3dstats(arr)
9093

9194

@@ -303,10 +306,7 @@ def np_dilate_msk(
303306
# try:
304307
arr_bin = arr.copy()
305308
arr_bin[np.isin(arr_bin, labels, invert=True)] = 0
306-
try:
307-
crop = np_bbox_binary(arr_bin, px_dist=1 + n_pixel)
308-
except AssertionError:
309-
crop = tuple([slice(None)] * arr.ndim)
309+
crop = np_bbox_binary(arr_bin, px_dist=1 + n_pixel, raise_error=False)
310310
arrc = arr[crop]
311311
else:
312312
arrc = arr
@@ -379,10 +379,7 @@ def np_erode_msk(
379379
# try:
380380
arr_bin = arr.copy()
381381
arr_bin[np.isin(arr_bin, labels, invert=True)] = 0
382-
try:
383-
crop = np_bbox_binary(arr_bin, px_dist=1 + n_pixel)
384-
except AssertionError:
385-
crop = tuple([slice(None)] * arr.ndim)
382+
crop = np_bbox_binary(arr_bin, px_dist=1 + n_pixel, raise_error=False)
386383
arrc = arr[crop]
387384
else:
388385
arrc = arr
@@ -461,9 +458,9 @@ def np_calc_crop_around_centerpoint(
461458
n_dim = len(poi)
462459
if isinstance(pad_to_size, int):
463460
pad_to_size = np.ones(n_dim) * pad_to_size
464-
assert (
465-
n_dim == len(arr.shape) == len(cutout_size) == len(pad_to_size)
466-
), f"dimension mismatch, got dim {n_dim}, poi {poi}, arr shape {arr.shape}, cutout {cutout_size}, pad_to_size {pad_to_size}"
461+
assert n_dim == len(arr.shape) == len(cutout_size) == len(pad_to_size), (
462+
f"dimension mismatch, got dim {n_dim}, poi {poi}, arr shape {arr.shape}, cutout {cutout_size}, pad_to_size {pad_to_size}"
463+
)
467464

468465
poi = tuple(int(i) for i in poi)
469466
shape = arr.shape
@@ -491,7 +488,7 @@ def np_calc_crop_around_centerpoint(
491488
)
492489

493490

494-
def np_bbox_binary(img: np.ndarray, px_dist: int | Sequence[int] | np.ndarray = 0) -> tuple[slice, ...]:
491+
def np_bbox_binary(img: np.ndarray, px_dist: int | Sequence[int] | np.ndarray = 0, raise_error=True) -> tuple[slice, ...]:
495492
"""calculates a bounding box in n dimensions given a image (factor ~2 times faster than compute_crop)
496493
497494
Args:
@@ -502,7 +499,11 @@ def np_bbox_binary(img: np.ndarray, px_dist: int | Sequence[int] | np.ndarray =
502499
list of boundary coordinates as slices tuple
503500
"""
504501
assert img is not None, "bbox_nd: received None as image"
505-
assert np_count_nonzero(img) > 0, "bbox_nd: img is empty, cannot calculate a bbox"
502+
if np_count_nonzero(img) == 0:
503+
if raise_error:
504+
assert AssertionError("bbox_nd: img is empty, cannot calculate a bbox")
505+
return tuple([slice(None)] * img.ndim)
506+
506507
n = img.ndim
507508
shp = img.shape
508509
if isinstance(px_dist, int):
@@ -704,6 +705,8 @@ def np_get_largest_k_connected_components(
704705
return_original_labels: bool = True,
705706
min_volume: float = 0,
706707
max_volume: float | None = None,
708+
removed_to_label=0,
709+
_return_unsorted=False,
707710
) -> UINTARRAY:
708711
"""finds the largest k connected components in a given array (does NOT work with zero as label!)
709712
@@ -731,6 +734,8 @@ def np_get_largest_k_connected_components(
731734
arr2[np.isin(arr, labels, invert=True)] = 0 # type:ignore
732735

733736
labels_out, n = connected_components(arr2, connectivity=connectivity, return_N=True)
737+
if _return_unsorted:
738+
return labels_out
734739
if k is None:
735740
k = n
736741
k = min(k, n) # if k > N, will return all N but still sorted
@@ -750,7 +755,8 @@ def np_get_largest_k_connected_components(
750755
if k == i:
751756
break
752757
i += 1
753-
758+
if removed_to_label != 0:
759+
arr[np.logical_and(labels_out != 0, arr == 0)] = removed_to_label
754760
if return_original_labels:
755761
arr *= cc_out > 0 # to get original labels
756762
return arr
@@ -843,10 +849,7 @@ def np_translate_arr(arr: np.ndarray, translation_vector: tuple[int, int] | tupl
843849

844850

845851
def np_fill_holes(
846-
arr: np.ndarray,
847-
label_ref: LABEL_REFERENCE = None,
848-
slice_wise_dim: int | None = None,
849-
use_crop: bool = True,
852+
arr: np.ndarray, label_ref: LABEL_REFERENCE = None, slice_wise_dim: int | None = None, use_crop: bool = True, pbar=False
850853
) -> np.ndarray:
851854
"""Fills holes in segmentations
852855
@@ -863,16 +866,19 @@ def np_fill_holes(
863866
labels: Sequence[int] = _to_labels(arr, label_ref)
864867

865868
if use_crop:
866-
gcrop = np_bbox_binary(arr, px_dist=1)
869+
gcrop = np_bbox_binary(arr, px_dist=1, raise_error=False)
867870
arrc = arr[gcrop]
868871
else:
869872
arrc = arr
873+
if pbar:
874+
from tqdm import tqdm
870875

876+
labels = tqdm(labels, desc="fill_holes") # type: ignore
871877
for l in labels: # type:ignore
872878
arr_l = arrc.copy()
873879
arr_l = np_extract_label(arr_l, l)
874880
if use_crop:
875-
crop = np_bbox_binary(arr_l, px_dist=1)
881+
crop = np_bbox_binary(arr_l, px_dist=1, raise_error=False)
876882
arr_lc = arr_l[crop]
877883
else:
878884
arr_lc = arr_l

TPTBox/core/poi.py

Lines changed: 22 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -182,6 +182,10 @@ def zoom(self, value):
182182
else:
183183
self._zoom = tuple(round(float(v), ROUNDING_LVL) for v in value) # type: ignore
184184

185+
@spacing.setter
186+
def spacing(self, value):
187+
self.zoom = value
188+
185189
def clone(self, **qargs):
186190
return self.copy(**qargs)
187191

@@ -560,12 +564,7 @@ def rescale(self, voxel_spacing: ZOOMS = (1, 1, 1), decimals=ROUNDING_LVL, verbo
560564
return self.copy(centroids=points, zoom=voxel_spacing, shape=shp)
561565

562566
def rescale_(self, voxel_spacing: ZOOMS = (1, 1, 1), decimals=3, verbose: logging = False) -> Self:
563-
return self.rescale(
564-
voxel_spacing=voxel_spacing,
565-
decimals=decimals,
566-
verbose=verbose,
567-
inplace=True,
568-
)
567+
return self.rescale(voxel_spacing=voxel_spacing, decimals=decimals, verbose=verbose, inplace=True)
569568

570569
def to_global(self):
571570
"""Converts the Centroids object to a global POI_Global object.
@@ -666,8 +665,8 @@ def make_point_cloud_nii(self, affine=None, s=8, sphere=False):
666665
affine = self.affine
667666
arr = np.zeros(self.shape_int)
668667
arr2 = np.zeros(self.shape_int)
669-
s1 = min(s // 2, 1)
670-
s2 = min(s - s1, 1)
668+
s1 = max(s // 2, 1)
669+
s2 = max(s - s1, 1)
671670
from math import ceil, floor
672671

673672
if sphere:
@@ -723,7 +722,7 @@ def filter_points_inside_shape(self, inplace=False) -> Self:
723722
return self.copy(filtered_centroids)
724723

725724
@classmethod
726-
def load(cls, poi: POI_Reference):
725+
def load(cls, poi: POI_Reference, reference: Has_Grid | None = None):
727726
"""Load a Centroids object from various input sources.
728727
729728
This method provides a convenient way to load a Centroids object from different sources,
@@ -761,7 +760,18 @@ def load(cls, poi: POI_Reference):
761760
>>> existing_poi = POI(...)
762761
>>> loaded_poi = POI.load(existing_poi)
763762
"""
764-
return load_poi(poi)
763+
poi_obj = load_poi(poi)
764+
if reference is not None:
765+
if poi_obj.spacing is None:
766+
poi_obj.spacing = reference.spacing
767+
if poi_obj.rotation is None:
768+
poi_obj.rotation = reference.rotation
769+
if poi_obj.shape is None:
770+
poi_obj.shape = reference.shape
771+
if poi_obj.origin is None:
772+
poi_obj.origin = reference.origin
773+
reference.assert_affine(poi_obj)
774+
return poi_obj
765775

766776
def assert_affine(
767777
self,
@@ -1468,12 +1478,7 @@ def _is_not_yet_computed(ids_in_arr: Sequence[int], extend_to: POI | None, subre
14681478

14691479

14701480
def calc_centroids(
1471-
msk: Image_Reference,
1472-
decimals=3,
1473-
first_stage=-1,
1474-
second_stage: int | Location = 50,
1475-
extend_to: POI | None = None,
1476-
inplace: bool = False,
1481+
msk: Image_Reference, decimals=3, first_stage=-1, second_stage: int | Location = 50, extend_to: POI | None = None, inplace: bool = False
14771482
) -> POI:
14781483
"""
14791484
Calculates the centroid coordinates of each region in the given mask image.
@@ -1531,10 +1536,7 @@ def calc_centroids(
15311536
######## Utility #######
15321537

15331538

1534-
def calc_poi_average(
1535-
pois: list[POI],
1536-
keep_points_not_present_in_all_pois: bool = False,
1537-
) -> POI:
1539+
def calc_poi_average(pois: list[POI], keep_points_not_present_in_all_pois: bool = False) -> POI:
15381540
"""Calculates average of POI across list of POIs and removes all points that are not fully present in all given POIs
15391541
15401542
Args:

TPTBox/core/poi_fun/poi_global.py

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,17 @@ class POI_Global(Abstract_POI):
1717
Inherits from the `Abstract_POI` class and contains methods for converting the POI to different coordinate systems.
1818
"""
1919

20-
def __init__(self, input_poi: poi.POI | POI_Descriptor):
20+
def __init__(self, input_poi: poi.POI | POI_Descriptor | dict[str, dict[str, tuple[float, ...]]], itk_coords: bool = False):
21+
self.itk_coords = itk_coords
22+
23+
if isinstance(input_poi, dict):
24+
global_points = POI_Descriptor()
25+
self.info = {}
26+
self.format = None
27+
for k1, d1 in input_poi.items():
28+
for k2, v in d1.items():
29+
global_points[k1:k2] = v
30+
2131
if isinstance(input_poi, POI_Descriptor):
2232
global_points = input_poi
2333
self.info = {}
@@ -78,7 +88,7 @@ def to_other_poi(self, ref: poi.POI_Reference) -> poi.POI:
7888
"""
7989
return self.to_other(poi.POI.load(ref))
8090

81-
def to_other(self, msk: Has_Grid) -> poi.POI:
91+
def to_other(self, msk: Has_Grid, verbose=False) -> poi.POI:
8292
"""
8393
Convert the POI to another coordinate system.
8494
@@ -90,7 +100,12 @@ def to_other(self, msk: Has_Grid) -> poi.POI:
90100
"""
91101
out = poi.POI_Descriptor(definition=self._get_centroids().definition)
92102
for k1, k2, v in self.items():
103+
if self.itk_coords:
104+
assert len(v) == 3, "n-d vec not implemented for n != 3"
105+
v = (-v[0], -v[1], v[2]) # noqa: PLW2901
93106
v_out = msk.global_to_local(v)
107+
if verbose:
108+
print(v, "-->", v_out)
94109
out[k1, k2] = tuple(v_out)
95110

96111
return poi.POI(centroids=out, **msk._extract_affine(), info=self.info, format=self.format)
@@ -101,4 +116,5 @@ def copy(self, centroids: POI_Descriptor | None = None) -> Self:
101116
p = POI_Global(centroids)
102117
p.format = self.format
103118
p.info = deepcopy(self.info)
119+
p.itk_coords = self.itk_coords
104120
return p # type: ignore

0 commit comments

Comments
 (0)