Skip to content

Commit 8aa3efa

Browse files
authored
Merge pull request #140 from funkelab/47-move-the-graph-creation-code-from-the-toolbox-into-funtracks
move motile-toolbox code to funtracks
2 parents d0745b8 + cec43f3 commit 8aa3efa

12 files changed

Lines changed: 915 additions & 0 deletions

File tree

pyproject.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@ dependencies =[
3939
"pandas>=2.3.3",
4040
"zarr>=2.18,<4",
4141
"numcodecs>=0.13,<0.16",
42+
"tqdm>=4.66.1",
4243
]
4344

4445
[project.urls]
@@ -107,6 +108,7 @@ unfixable = [
107108
[tool.mypy]
108109
ignore_missing_imports = true
109110
python_version = "3.10"
111+
explicit_package_bases = true
110112

111113
[tool.coverage.report]
112114
exclude_also = [
Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
from .compute_graph import (
2+
compute_graph_from_points_list,
3+
compute_graph_from_seg,
4+
)
5+
from .iou import add_iou
6+
from .utils import add_cand_edges, nodes_from_segmentation
Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,87 @@
1+
import logging
2+
3+
import networkx as nx
4+
import numpy as np
5+
6+
from .iou import add_iou
7+
from .utils import add_cand_edges, nodes_from_points_list, nodes_from_segmentation
8+
9+
logger = logging.getLogger(__name__)
10+
11+
12+
def compute_graph_from_seg(
13+
segmentation: np.ndarray,
14+
max_edge_distance: float,
15+
iou: bool = False,
16+
scale: list[float] | None = None,
17+
) -> nx.DiGraph:
18+
"""Construct a candidate graph from a segmentation array. Nodes are placed at the
19+
centroid of each segmentation and edges are added for all nodes in adjacent frames
20+
within max_edge_distance.
21+
22+
Args:
23+
segmentation (np.ndarray): A numpy array with integer labels and dimensions
24+
(t, [z], y, x).
25+
max_edge_distance (float): Maximum distance that objects can travel between
26+
frames. All nodes with centroids within this distance in adjacent frames
27+
will by connected with a candidate edge.
28+
iou (bool, optional): Whether to include IOU on the candidate graph.
29+
Defaults to False.
30+
scale (list[float] | None, optional): The scale of the segmentation data.
31+
Will be used to rescale the point locations and attribute computations.
32+
Defaults to None, which implies the data is isotropic.
33+
34+
Returns:
35+
nx.DiGraph: A candidate graph that can be passed to the motile solver
36+
"""
37+
# add nodes
38+
cand_graph, node_frame_dict = nodes_from_segmentation(segmentation, scale=scale)
39+
logger.info("Candidate nodes: %d", cand_graph.number_of_nodes())
40+
41+
# add edges
42+
add_cand_edges(
43+
cand_graph,
44+
max_edge_distance=max_edge_distance,
45+
node_frame_dict=node_frame_dict,
46+
)
47+
if iou:
48+
# Scale does not matter to IOU, because both numerator and denominator
49+
# are scaled by the anisotropy.
50+
add_iou(cand_graph, segmentation, node_frame_dict)
51+
52+
logger.info("Candidate edges: %d", cand_graph.number_of_edges())
53+
54+
return cand_graph
55+
56+
57+
def compute_graph_from_points_list(
58+
points_list: np.ndarray,
59+
max_edge_distance: float,
60+
scale: list[float] | None = None,
61+
) -> nx.DiGraph:
62+
"""Construct a candidate graph from a points list.
63+
64+
Args:
65+
points_list (np.ndarray): An NxD numpy array with N points and D
66+
(3 or 4) dimensions. Dimensions should be in order (t, [z], y, x).
67+
max_edge_distance (float): Maximum distance that objects can travel between
68+
frames. All nodes with centroids within this distance in adjacent frames
69+
will by connected with a candidate edge.
70+
scale (list[float] | None, optional): Amount to scale the points in each
71+
dimension. Only needed if the provided points are in "voxel" coordinates
72+
instead of world coordinates. Defaults to None, which implies the data is
73+
isotropic.
74+
75+
Returns:
76+
nx.DiGraph: A candidate graph that can be passed to the motile solver.
77+
"""
78+
# add nodes
79+
cand_graph, node_frame_dict = nodes_from_points_list(points_list, scale=scale)
80+
logger.info("Candidate nodes: %d", cand_graph.number_of_nodes())
81+
# add edges
82+
add_cand_edges(
83+
cand_graph,
84+
max_edge_distance=max_edge_distance,
85+
node_frame_dict=node_frame_dict,
86+
)
87+
return cand_graph
Lines changed: 115 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,115 @@
1+
from itertools import product
2+
3+
import networkx as nx
4+
import numpy as np
5+
from tqdm import tqdm
6+
7+
from funtracks.data_model.graph_attributes import EdgeAttr
8+
9+
from .utils import _compute_node_frame_dict
10+
11+
12+
def _compute_ious(frame1: np.ndarray, frame2: np.ndarray) -> list[tuple[int, int, float]]:
13+
"""Compute label IOUs between two label arrays of the same shape. Ignores background
14+
(label 0).
15+
16+
Args:
17+
frame1 (np.ndarray): Array with integer labels
18+
frame2 (np.ndarray): Array with integer labels
19+
20+
Returns:
21+
list[tuple[int, int, float]]: List of tuples of label in frame 1, label in
22+
frame 2, and iou values. Labels that have no overlap are not included.
23+
"""
24+
frame1 = frame1.flatten()
25+
frame2 = frame2.flatten()
26+
# get indices where both are not zero (ignore background)
27+
# this speeds up computation significantly
28+
non_zero_indices = np.logical_and(frame1, frame2)
29+
flattened_stacked = np.array([frame1[non_zero_indices], frame2[non_zero_indices]])
30+
31+
values, counts = np.unique(flattened_stacked, axis=1, return_counts=True)
32+
frame1_values, frame1_counts = np.unique(frame1, return_counts=True)
33+
frame1_label_sizes = dict(zip(frame1_values, frame1_counts, strict=True))
34+
frame2_values, frame2_counts = np.unique(frame2, return_counts=True)
35+
frame2_label_sizes = dict(zip(frame2_values, frame2_counts, strict=True))
36+
ious: list[tuple[int, int, float]] = []
37+
for index in range(values.shape[1]):
38+
pair = values[:, index]
39+
intersection = counts[index]
40+
id1, id2 = pair
41+
union = frame1_label_sizes[id1] + frame2_label_sizes[id2] - intersection
42+
ious.append((id1, id2, intersection / union))
43+
return ious
44+
45+
46+
def _get_iou_dict(segmentation, multiseg=False) -> dict[int, dict[int, float]]:
47+
"""Get all ious values for the provided segmentations (all frames).
48+
Will return as map from node_id -> dict[node_id] -> iou for easy
49+
navigation when adding to candidate graph.
50+
51+
Args:
52+
segmentation (np.ndarray): Segmentations that were used to create cand_graph.
53+
Has shape ([h], t, [z], y, x), where h is the number of hypotheses
54+
if multiseg is True.
55+
multiseg (bool): Flag indicating if the provided segmentation contains
56+
multiple hypothesis segmentations. Defaults to False.
57+
58+
Returns:
59+
dict[int, dict[int, float]]: A map from node id to another dictionary, which
60+
contains node_ids to iou values.
61+
"""
62+
iou_dict: dict[int, dict[int, float]] = {}
63+
hypo_pairs: list[tuple[int, ...]] = [(0, 0)]
64+
if multiseg:
65+
num_hypotheses = segmentation.shape[0]
66+
if num_hypotheses > 1:
67+
hypo_pairs = list(product(range(num_hypotheses), repeat=2))
68+
else:
69+
segmentation = np.expand_dims(segmentation, 0)
70+
71+
for frame in range(segmentation.shape[1] - 1):
72+
for hypo1, hypo2 in hypo_pairs:
73+
seg1 = segmentation[hypo1][frame]
74+
seg2 = segmentation[hypo2][frame + 1]
75+
ious = _compute_ious(seg1, seg2)
76+
for label1, label2, iou in ious:
77+
if label1 not in iou_dict:
78+
iou_dict[label1] = {}
79+
iou_dict[label1][label2] = iou
80+
return iou_dict
81+
82+
83+
def add_iou(
84+
cand_graph: nx.DiGraph,
85+
segmentation: np.ndarray,
86+
node_frame_dict: dict[int, list[int]] | None = None,
87+
multiseg=False,
88+
) -> None:
89+
"""Add IOU to the candidate graph.
90+
91+
Args:
92+
cand_graph (nx.DiGraph): Candidate graph with nodes and edges already populated
93+
segmentation (np.ndarray): segmentation that was used to create cand_graph.
94+
Has shape ([h], t, [z], y, x), where h is the number of hypotheses if
95+
multiseg is True.
96+
node_frame_dict(dict[int, list[Any]] | None, optional): A mapping from
97+
time frames to nodes in that frame. Will be computed if not provided,
98+
but can be provided for efficiency (e.g. after running
99+
nodes_from_segmentation). Defaults to None.
100+
multiseg (bool): Flag indicating if the given segmentation is actually multiple
101+
stacked segmentations. Defaults to False.
102+
"""
103+
if node_frame_dict is None:
104+
node_frame_dict = _compute_node_frame_dict(cand_graph)
105+
frames = sorted(node_frame_dict.keys())
106+
ious = _get_iou_dict(segmentation, multiseg=multiseg)
107+
for frame in tqdm(frames):
108+
if frame + 1 not in node_frame_dict:
109+
continue
110+
next_nodes = node_frame_dict[frame + 1]
111+
for node_id in node_frame_dict[frame]:
112+
for next_id in next_nodes:
113+
iou = ious.get(node_id, {}).get(next_id, 0)
114+
if (node_id, next_id) in cand_graph.edges:
115+
cand_graph.edges[(node_id, next_id)][EdgeAttr.IOU.value] = iou

0 commit comments

Comments
 (0)