@@ -179,9 +179,13 @@ class NodeImplBase(tp.Generic[Node, Leaf, AuxData]):
179179 type : type [Node ]
180180 flatten : tp .Callable [[Node ], tuple [tp .Sequence [tuple [Key , Leaf ]], AuxData ]]
181181
182- def node_dict (self , node : Node ) -> dict [Key , Leaf ]:
183- nodes , _ = self .flatten (node )
184- return dict (nodes )
182+ def node_dict (self , node : Node ) -> dict [Key , tp .Any ]:
183+ node_seq , _ = self .flatten (node )
184+ nodes = {
185+ key : node .value if isinstance (node , DataElem | StaticElem ) else node
186+ for key , node in node_seq
187+ }
188+ return nodes
185189
186190
187191@dataclasses .dataclass (frozen = True , slots = True )
@@ -533,32 +537,21 @@ def __treescope_repr__(self, path, subtree_renderer):
533537
534538
535539@dataclasses .dataclass (frozen = True , slots = True )
536- class ArrayAttr :
537- pass
538-
539-
540- ARRAY_ATTR = ArrayAttr ()
541-
542-
543- @dataclasses .dataclass (frozen = True , slots = True )
544- class MutableArrayAttr :
540+ class NodeAttr :
545541 pass
546542
547543
548- MUTABLE_ARRAY_ATTR = MutableArrayAttr ()
549-
544+ NODE_ATTR = NodeAttr ()
550545
551546@dataclasses .dataclass (frozen = True , slots = True )
552- class NodeAttr :
547+ class LeafAttr :
553548 pass
554549
555-
556- NODE_ATTR = NodeAttr ()
550+ LEAF_ATTR = LeafAttr ()
557551
558552AttrType = tp .Union [
559553 NodeAttr ,
560- ArrayAttr ,
561- MutableArrayAttr ,
554+ LeafAttr ,
562555 'Static[tp.Any]' ,
563556]
564557
@@ -710,6 +703,14 @@ def flatten( # type: ignore[invalid-annotation]
710703 else :
711704 return graphdef , leaves
712705
706+ @dataclasses .dataclass (frozen = True , slots = True )
707+ class DataElem :
708+ value : tp .Any
709+
710+
711+ @dataclasses .dataclass (frozen = True , slots = True )
712+ class StaticElem :
713+ value : tp .Any
713714
714715def _graph_flatten (
715716 node : Node ,
@@ -827,6 +828,18 @@ def make_mutable_arraydef(value: variablelib.Ref):
827828 nodes .append (nodedef )
828829
829830 for key , value in values :
831+ is_data = None
832+ if isinstance (value , DataElem ):
833+ value = value .value
834+ is_data = True
835+ elif isinstance (value , StaticElem ):
836+ value = value .value
837+ is_data = False
838+
839+ if is_data is False :
840+ attributes .append ((key , Static (value )))
841+ continue
842+
830843 value_node_impl = get_node_impl (value )
831844 if path is not None :
832845 path .append (key )
@@ -844,15 +857,15 @@ def make_mutable_arraydef(value: variablelib.Ref):
844857 paths ,
845858 )
846859 elif variablelib .is_array_ref (value ):
847- attributes .append ((key , MUTABLE_ARRAY_ATTR ))
860+ attributes .append ((key , NODE_ATTR ))
848861 array_refdef , leaf = make_mutable_arraydef (value )
849862 if not isinstance (leaf , Repeated ):
850863 leaves .append (leaf )
851864 if paths is not None :
852865 paths .append (tuple (path )) # type: ignore
853866 nodes .append (array_refdef )
854- elif isinstance (value , (jax .Array , np .ndarray )):
855- attributes .append ((key , ARRAY_ATTR ))
867+ elif isinstance (value , (jax .Array , np .ndarray )) or is_data :
868+ attributes .append ((key , LEAF_ATTR ))
856869 if paths is not None :
857870 paths .append (tuple (path )) # type: ignore
858871 leaves .append (value )
@@ -1092,41 +1105,33 @@ def _get_children() -> list[tuple[Key, tp.Any]]:
10921105 key , value = next (attribute_iter )
10931106 if type (value ) is Static :
10941107 children .append ((key , value .value )) # type: ignore[attribute-error]
1095- elif type (value ) is MutableArrayAttr :
1096- array_refdef = next (node_iter )
1097- assert (
1098- type (array_refdef ) is ArrayRefDef or type (array_refdef ) is NodeRef
1099- )
1100- if type (array_refdef ) is NodeRef :
1101- array_ref = index_ref [array_refdef .index ]
1102- else :
1103- assert type (array_refdef ) is ArrayRefDef
1108+ elif type (value ) is LeafAttr :
1109+ leaf = next (leaves_iter )
1110+ children .append ((key , leaf ))
1111+ elif type (value ) is NodeAttr :
1112+ node_def = next (node_iter )
1113+ if isinstance (node_def , NodeRef ):
1114+ node = index_ref [node_def .index ]
1115+ elif isinstance (node_def , ArrayRefDef ):
11041116 leaf = next (leaves_iter )
1105- array_ref = get_mutable_array (array_refdef , leaf )
1106- children .append ((key , array_ref ))
1107- elif type (value ) is ArrayAttr :
1108- array = next (leaves_iter )
1109- children .append ((key , array ))
1117+ node = get_mutable_array (node_def , leaf )
1118+ elif isinstance (node_def , NodeDef | VariableDef ):
1119+ value_node_impl = get_node_impl_for_type (node_def .type )
1120+ node = _graph_unflatten (
1121+ node_def ,
1122+ value_node_impl ,
1123+ node_iter ,
1124+ attribute_iter ,
1125+ leaves_iter ,
1126+ index_ref ,
1127+ outer_index_outer_ref ,
1128+ copy_variables ,
1129+ )
1130+ else :
1131+ raise RuntimeError (f'Unknown node definition: { node_def !r} ' )
1132+ children .append ((key , node ))
11101133 elif type (value ) is NodeRef :
11111134 children .append ((key , index_ref [value .index ])) # type: ignore[attribute-error]
1112- elif type (value ) is NodeAttr :
1113- # if the key is a subgraph we create an empty node
1114- subgraphdef = next (node_iter )
1115- if type (subgraphdef ) is NodeDef :
1116- value_node_impl = get_node_impl_for_type (subgraphdef .type ) # type: ignore[attribute-error]
1117- else :
1118- value_node_impl = None
1119- subnode = _graph_unflatten (
1120- subgraphdef ,
1121- value_node_impl ,
1122- node_iter ,
1123- attribute_iter ,
1124- leaves_iter ,
1125- index_ref ,
1126- outer_index_outer_ref ,
1127- copy_variables ,
1128- )
1129- children .append ((key , subnode ))
11301135 else :
11311136 raise RuntimeError (f'Unknown static field: { key !r} ' )
11321137
0 commit comments