@@ -2029,8 +2029,44 @@ def test_allreduce_time(self, mode):
20292029 op = Operator ([Eq (ux .forward , ux + tt ), Inc (g , ux )], name = "Op" )
20302030 assert_structure (op , ['t,x,y' , 't' ], 'txy' )
20312031
2032+ # Reduce should be in time loop but not in space loop
2033+ iters = FindNodes (Iteration ).visit (op )
2034+ for i in iters :
2035+ if i .dim .is_Time :
2036+ assert len (FindNodes (Call ).visit (i )) == 1 # one allreduce
2037+ else :
2038+ assert len (FindNodes (Call ).visit (i )) == 0
2039+
2040+ op .apply (time_m = 0 , time_M = nt - 1 )
2041+ assert np .isclose (np .max (g .data ), 4356.0 )
2042+
2043+ @pytest .mark .parallel (mode = 2 )
2044+ def test_multi_allreduce_time (self , mode ):
2045+ space_order = 8
2046+ nx , ny = 11 , 11
2047+
2048+ grid = Grid (shape = (nx , ny ))
2049+ tt = grid .time_dim
2050+ nt = 10
2051+
2052+ ux = TimeFunction (name = "ux" , grid = grid , time_order = 1 , space_order = space_order )
2053+ g = TimeFunction (name = "g" , grid = grid , dimensions = (tt , ), shape = (nt ,))
2054+ h = TimeFunction (name = "h" , grid = grid , dimensions = (tt , ), shape = (nt ,))
2055+
2056+ op = Operator ([Eq (ux .forward , ux + tt ), Inc (g , ux ), Inc (h , ux )], name = "Op" )
2057+ assert_structure (op , ['t,x,y' , 't' ], 'txy' )
2058+
2059+ # Make sure the two allreduce calls are in the time the loop
2060+ iters = FindNodes (Iteration ).visit (op )
2061+ for i in iters :
2062+ if i .dim .is_Time :
2063+ assert len (FindNodes (Call ).visit (i )) == 2 # Two allreduce
2064+ else :
2065+ assert len (FindNodes (Call ).visit (i )) == 0
2066+
20322067 op .apply (time_m = 0 , time_M = nt - 1 )
20332068 assert np .isclose (np .max (g .data ), 4356.0 )
2069+ assert np .isclose (np .max (h .data ), 4356.0 )
20342070
20352071
20362072class TestOperatorAdvanced :
0 commit comments