Skip to content

Commit 1d5ee61

Browse files
committed
allow option of different dict per node
1 parent ceb6cd7 commit 1d5ee61

2 files changed

Lines changed: 31 additions & 5 deletions

File tree

src/funtracks/user_actions/user_update_nodes_attrs.py

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -20,21 +20,30 @@ class UserUpdateNodesAttrs(ActionGroup):
2020
Args:
2121
tracks: The tracks to update node attributes for.
2222
nodes: The node ids to update.
23-
attrs: A mapping from attribute name to new attribute values,
24-
applied to all nodes.
23+
attrs: Either a single dict applied to all nodes, or a list of dicts
24+
with one entry per node (must match the length of nodes).
2525
"""
2626

2727
def __init__(
2828
self,
2929
tracks: SolutionTracks,
3030
nodes: list[int],
31-
attrs: dict[str, Any],
31+
attrs: dict[str, Any] | list[dict[str, Any]],
3232
):
3333
super().__init__(tracks, actions=[])
3434
self.tracks: SolutionTracks # Narrow type from base class
35-
for node in nodes:
35+
if isinstance(attrs, list):
36+
if len(attrs) != len(nodes):
37+
raise ValueError(
38+
f"attrs list length ({len(attrs)}) must match "
39+
f"nodes length ({len(nodes)})"
40+
)
41+
per_node_attrs = attrs
42+
else:
43+
per_node_attrs = [attrs] * len(nodes)
44+
for node, node_attrs in zip(nodes, per_node_attrs, strict=True):
3645
self.actions.append(
37-
UserUpdateNodeAttrs(tracks, node, attrs, _top_level=False)
46+
UserUpdateNodeAttrs(tracks, node, node_attrs, _top_level=False)
3847
)
3948

4049
self.tracks.action_history.add_new_action(self)

tests/user_actions/test_user_update_nodes_attrs.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,23 @@ def test_undo_redo(self, get_tracks, ndim, with_seg):
4545
for node in [1, 2]:
4646
assert tracks.get_node_attr(node, "score") == 0.9
4747

48+
def test_per_node_attrs(self, get_tracks, ndim, with_seg):
49+
"""Test bulk update with a different attr dict per node."""
50+
tracks = get_tracks(ndim=ndim, with_seg=with_seg, is_solution=True)
51+
52+
per_node = [{"score": 0.1}, {"score": 0.9}]
53+
UserUpdateNodesAttrs(tracks, nodes=[1, 2], attrs=per_node)
54+
55+
assert tracks.get_node_attr(1, "score") == 0.1
56+
assert tracks.get_node_attr(2, "score") == 0.9
57+
58+
def test_per_node_attrs_length_mismatch_raises(self, get_tracks, ndim, with_seg):
59+
"""Mismatched list length raises ValueError."""
60+
tracks = get_tracks(ndim=ndim, with_seg=with_seg, is_solution=True)
61+
62+
with pytest.raises(ValueError, match="attrs list length"):
63+
UserUpdateNodesAttrs(tracks, nodes=[1, 2], attrs=[{"score": 0.1}])
64+
4865
def test_protected_attr_raises(self, get_tracks, ndim, with_seg):
4966
"""Passing a protected attribute raises ValueError."""
5067
tracks = get_tracks(ndim=ndim, with_seg=with_seg, is_solution=True)

0 commit comments

Comments
 (0)