Skip to content

Commit 4cc1d1b

Browse files
authored
Merge pull request #5159 from samanklesaria/pr/5140
Have _graph_flatten respect nnx.data declarations (extension of #5140)
2 parents dddd218 + ebd7868 commit 4cc1d1b

File tree

3 files changed

+90
-59
lines changed

3 files changed

+90
-59
lines changed

flax/nnx/graph.py

Lines changed: 59 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -179,9 +179,13 @@ class NodeImplBase(tp.Generic[Node, Leaf, AuxData]):
179179
type: type[Node]
180180
flatten: tp.Callable[[Node], tuple[tp.Sequence[tuple[Key, Leaf]], AuxData]]
181181

182-
def node_dict(self, node: Node) -> dict[Key, Leaf]:
183-
nodes, _ = self.flatten(node)
184-
return dict(nodes)
182+
def node_dict(self, node: Node) -> dict[Key, tp.Any]:
183+
node_seq, _ = self.flatten(node)
184+
nodes = {
185+
key: node.value if isinstance(node, DataElem | StaticElem) else node
186+
for key, node in node_seq
187+
}
188+
return nodes
185189

186190

187191
@dataclasses.dataclass(frozen=True, slots=True)
@@ -533,32 +537,21 @@ def __treescope_repr__(self, path, subtree_renderer):
533537

534538

535539
@dataclasses.dataclass(frozen=True, slots=True)
536-
class ArrayAttr:
537-
pass
538-
539-
540-
ARRAY_ATTR = ArrayAttr()
541-
542-
543-
@dataclasses.dataclass(frozen=True, slots=True)
544-
class MutableArrayAttr:
540+
class NodeAttr:
545541
pass
546542

547543

548-
MUTABLE_ARRAY_ATTR = MutableArrayAttr()
549-
544+
NODE_ATTR = NodeAttr()
550545

551546
@dataclasses.dataclass(frozen=True, slots=True)
552-
class NodeAttr:
547+
class LeafAttr:
553548
pass
554549

555-
556-
NODE_ATTR = NodeAttr()
550+
LEAF_ATTR = LeafAttr()
557551

558552
AttrType = tp.Union[
559553
NodeAttr,
560-
ArrayAttr,
561-
MutableArrayAttr,
554+
LeafAttr,
562555
'Static[tp.Any]',
563556
]
564557

@@ -710,6 +703,14 @@ def flatten( # type: ignore[invalid-annotation]
710703
else:
711704
return graphdef, leaves
712705

706+
@dataclasses.dataclass(frozen=True, slots=True)
707+
class DataElem:
708+
value: tp.Any
709+
710+
711+
@dataclasses.dataclass(frozen=True, slots=True)
712+
class StaticElem:
713+
value: tp.Any
713714

714715
def _graph_flatten(
715716
node: Node,
@@ -827,6 +828,18 @@ def make_mutable_arraydef(value: variablelib.Ref):
827828
nodes.append(nodedef)
828829

829830
for key, value in values:
831+
is_data = None
832+
if isinstance(value, DataElem):
833+
value = value.value
834+
is_data = True
835+
elif isinstance(value, StaticElem):
836+
value = value.value
837+
is_data = False
838+
839+
if is_data is False:
840+
attributes.append((key, Static(value)))
841+
continue
842+
830843
value_node_impl = get_node_impl(value)
831844
if path is not None:
832845
path.append(key)
@@ -844,15 +857,15 @@ def make_mutable_arraydef(value: variablelib.Ref):
844857
paths,
845858
)
846859
elif variablelib.is_array_ref(value):
847-
attributes.append((key, MUTABLE_ARRAY_ATTR))
860+
attributes.append((key, NODE_ATTR))
848861
array_refdef, leaf = make_mutable_arraydef(value)
849862
if not isinstance(leaf, Repeated):
850863
leaves.append(leaf)
851864
if paths is not None:
852865
paths.append(tuple(path)) # type: ignore
853866
nodes.append(array_refdef)
854-
elif isinstance(value, (jax.Array, np.ndarray)):
855-
attributes.append((key, ARRAY_ATTR))
867+
elif isinstance(value, (jax.Array, np.ndarray)) or is_data:
868+
attributes.append((key, LEAF_ATTR))
856869
if paths is not None:
857870
paths.append(tuple(path)) # type: ignore
858871
leaves.append(value)
@@ -1092,41 +1105,33 @@ def _get_children() -> list[tuple[Key, tp.Any]]:
10921105
key, value = next(attribute_iter)
10931106
if type(value) is Static:
10941107
children.append((key, value.value)) # type: ignore[attribute-error]
1095-
elif type(value) is MutableArrayAttr:
1096-
array_refdef = next(node_iter)
1097-
assert (
1098-
type(array_refdef) is ArrayRefDef or type(array_refdef) is NodeRef
1099-
)
1100-
if type(array_refdef) is NodeRef:
1101-
array_ref = index_ref[array_refdef.index]
1102-
else:
1103-
assert type(array_refdef) is ArrayRefDef
1108+
elif type(value) is LeafAttr:
1109+
leaf = next(leaves_iter)
1110+
children.append((key, leaf))
1111+
elif type(value) is NodeAttr:
1112+
node_def = next(node_iter)
1113+
if isinstance(node_def, NodeRef):
1114+
node = index_ref[node_def.index]
1115+
elif isinstance(node_def, ArrayRefDef):
11041116
leaf = next(leaves_iter)
1105-
array_ref = get_mutable_array(array_refdef, leaf)
1106-
children.append((key, array_ref))
1107-
elif type(value) is ArrayAttr:
1108-
array = next(leaves_iter)
1109-
children.append((key, array))
1117+
node = get_mutable_array(node_def, leaf)
1118+
elif isinstance(node_def, NodeDef | VariableDef):
1119+
value_node_impl = get_node_impl_for_type(node_def.type)
1120+
node = _graph_unflatten(
1121+
node_def,
1122+
value_node_impl,
1123+
node_iter,
1124+
attribute_iter,
1125+
leaves_iter,
1126+
index_ref,
1127+
outer_index_outer_ref,
1128+
copy_variables,
1129+
)
1130+
else:
1131+
raise RuntimeError(f'Unknown node definition: {node_def!r}')
1132+
children.append((key, node))
11101133
elif type(value) is NodeRef:
11111134
children.append((key, index_ref[value.index])) # type: ignore[attribute-error]
1112-
elif type(value) is NodeAttr:
1113-
# if the key is a subgraph we create an empty node
1114-
subgraphdef = next(node_iter)
1115-
if type(subgraphdef) is NodeDef:
1116-
value_node_impl = get_node_impl_for_type(subgraphdef.type) # type: ignore[attribute-error]
1117-
else:
1118-
value_node_impl = None
1119-
subnode = _graph_unflatten(
1120-
subgraphdef,
1121-
value_node_impl,
1122-
node_iter,
1123-
attribute_iter,
1124-
leaves_iter,
1125-
index_ref,
1126-
outer_index_outer_ref,
1127-
copy_variables,
1128-
)
1129-
children.append((key, subnode))
11301135
else:
11311136
raise RuntimeError(f'Unknown static field: {key!r}')
11321137

flax/nnx/pytreelib.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
import warnings
2525

2626
from flax.nnx import variablelib
27+
from flax import nnx
2728
import jax
2829
import numpy as np
2930
import treescope # type: ignore[import-untyped]
@@ -917,8 +918,19 @@ def _pytree__unflatten(
917918
# Graph Definition
918919
# -------------------------
919920
def _graph_node_flatten(self):
920-
nodes = vars(self)
921-
nodes = sorted(nodes.items(), key=self._pytree__key_sort_fn)
921+
pytree_nodes = self._pytree__nodes
922+
nodes = (
923+
(
924+
name,
925+
value
926+
if not self._pytree__is_pytree
927+
else nnx.graph.DataElem(value)
928+
if name in pytree_nodes and pytree_nodes[name]
929+
else nnx.graph.StaticElem(value)
930+
)
931+
for name, value in vars(self).items()
932+
)
933+
nodes = sorted(nodes, key=self._pytree__key_sort_fn)
922934
return nodes, type(self)
923935

924936
def _graph_node_set_key(self, key: str, value: tp.Any):

tests/nnx/graph_utils_test.py

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -629,7 +629,7 @@ def __init__(self):
629629
self.assertFalse(hasattr(ctx, 'ctxtag'))
630630
self.assertIsInstance(graphdef1.nodes[0], nnx.graph.NodeDef)
631631
self.assertIsInstance(graphdef2.nodes[0], nnx.graph.NodeRef)
632-
self.assertLen(nnx.to_flat_state(state1), 1)
632+
self.assertLen(nnx.to_flat_state(state1), 2)
633633
self.assertLen(nnx.to_flat_state(state2), 0)
634634

635635
@jax.jit
@@ -717,7 +717,7 @@ def __init__(self):
717717
assert isinstance(t2, nnx.NodeStates)
718718
self.assertIsInstance(t1.graphdef.nodes[0], nnx.graph.NodeDef)
719719
self.assertIsInstance(t2.graphdef.nodes[0], nnx.graph.NodeRef)
720-
self.assertLen(nnx.to_flat_state(t1.states[0]), 1)
720+
self.assertLen(nnx.to_flat_state(t1.states[0]), 2)
721721
self.assertLen(nnx.to_flat_state(t2.states[0]), 0)
722722

723723
@jax.jit
@@ -744,7 +744,7 @@ def f(pure_tree):
744744
assert isinstance(t2, nnx.NodeStates)
745745
self.assertIsInstance(t1.graphdef.nodes[0], nnx.graph.NodeDef)
746746
self.assertIsInstance(t2.graphdef.nodes[0], nnx.graph.NodeRef)
747-
self.assertLen(nnx.to_flat_state(t1.states[0]), 1)
747+
self.assertLen(nnx.to_flat_state(t1.states[0]), 2)
748748
self.assertLen(nnx.to_flat_state(t2.states[0]), 0)
749749

750750
return pure_tree2
@@ -762,6 +762,20 @@ def f(pure_tree):
762762
self.assertEqual(m.b[...], 1) # type: ignore
763763
self.assertEqual(impure_tree2[1], 1)
764764

765+
def test_graph_flatten_with_data_wrapper(self):
766+
class Foo(nnx.Pytree):
767+
def __init__(self, value, static):
768+
self.value = nnx.data(value)
769+
self.static = nnx.static(static)
770+
771+
tree = Foo(1, 2)
772+
state = nnx.state(tree)
773+
774+
self.assertIn('value', state)
775+
self.assertIsInstance(state['value'], int)
776+
self.assertEqual(state['value'], 1)
777+
self.assertNotIn('static', state)
778+
765779
def test_to_tree_consistent_prefix(self):
766780
m = nnx.Linear(2, 3, rngs=nnx.Rngs(0))
767781
impure_tree = (m, 1, {'b': m})

0 commit comments

Comments
 (0)