# limitations under the License.

import torch
from absl import flags, logging
from absl.testing import absltest, parameterized

from emerging_optimizers.mixin import WeightDecayMixin


# Command-line knobs for the test run: target device and optional RNG seed.
flags.DEFINE_enum("device", "cpu", ["cpu", "cuda"], "Device to run tests on")
flags.DEFINE_integer("seed", None, "Random seed for reproducible tests")
FLAGS = flags.FLAGS


def setUpModule() -> None:
    """Seed torch's RNGs from --seed (when provided) before any test runs.

    Seeds both the CPU generator and, if CUDA is available, every CUDA
    device, so test runs are reproducible across devices.
    """
    if FLAGS.seed is not None:
        logging.info("Setting random seed to %d", FLAGS.seed)
        torch.manual_seed(FLAGS.seed)
        if torch.cuda.is_available():
            torch.cuda.manual_seed_all(FLAGS.seed)
class _WeightDecayTestHelper(WeightDecayMixin):
    """Thin wrapper so we can set weight_decay_method and call the mixin."""

    def __init__(self, method: str):
        # The mixin dispatches on this attribute; no other state is needed.
        self.weight_decay_method = method
2741
2842
class WeightDecayMixinTest(parameterized.TestCase):
    """Behavioral tests for each weight-decay method of WeightDecayMixin.

    All tensors are created on the device selected via the --device flag.
    """

    @classmethod
    def setUpClass(cls):
        # Resolve the target device once; every test builds its tensors on it.
        cls.device = FLAGS.device

    @parameterized.parameters("decoupled", "independent", "l2", "palm")
    def test_zero_weight_decay_is_noop(self, method):
        """Neither p nor grad should change when weight_decay is 0."""
        helper = _WeightDecayTestHelper(method)
        p = torch.tensor([1.0, 2.0, 3.0], device=self.device)
        grad = torch.tensor([0.5, -0.5, 1.0], device=self.device)
        p_orig, grad_orig = p.clone(), grad.clone()

        helper._apply_weight_decay_inplace(p, grad, lr=0.1, weight_decay=0.0)

        # NOTE(review): the p assertion was hidden in the diff; reconstructed
        # from the docstring contract — confirm against upstream.
        torch.testing.assert_close(p, p_orig, atol=0, rtol=0)
        torch.testing.assert_close(grad, grad_orig, atol=0, rtol=0)

    @parameterized.parameters(
        {"lr": 0.25, "wd": 0.5},
        {"lr": 0.025, "wd": 1.0},
        {"lr": 1.0, "wd": 0.05},
    )
    def test_decoupled(self, lr, wd):
        """Decoupled: p <- p * (1 - wd * lr), grad untouched."""
        helper = _WeightDecayTestHelper("decoupled")
        p = torch.tensor([4.0, -2.0, 0.0, 8.0], device=self.device)
        grad = torch.tensor([1.0, 2.0, 1.0, 1.0], device=self.device)
        p_orig, grad_orig = p.clone(), grad.clone()

        helper._apply_weight_decay_inplace(p, grad, lr=lr, weight_decay=wd)

        # NOTE(review): expected-value assertion hidden in the diff;
        # reconstructed from the docstring contract — confirm upstream.
        torch.testing.assert_close(p, p_orig * (1 - wd * lr))
        torch.testing.assert_close(grad, grad_orig, atol=0, rtol=0)

    @parameterized.parameters(
        {"lr": 0.25, "wd": 0.5},
        {"lr": 0.025, "wd": 1.0},
        {"lr": 1.0, "wd": 0.05},
    )
    def test_independent(self, lr, wd):
        """Independent: p <- p * (1 - wd), grad untouched, lr irrelevant."""
        helper = _WeightDecayTestHelper("independent")
        p = torch.tensor([4.0, -2.0, 0.0, 7.0], device=self.device)
        grad = torch.tensor([1.0, 1.0, 1.0, 1.0], device=self.device)
        p_orig, grad_orig = p.clone(), grad.clone()

        helper._apply_weight_decay_inplace(p, grad, lr=lr, weight_decay=wd)

        # NOTE(review): assertions hidden in the diff; reconstructed from the
        # docstring contract — confirm upstream.
        torch.testing.assert_close(p, p_orig * (1 - wd))
        torch.testing.assert_close(grad, grad_orig, atol=0, rtol=0)

    def test_independent_ignores_lr(self):
        """Two different lr values must produce identical results for independent decay."""
        wd = 0.3
        p1 = torch.tensor([5.0, -3.0, 1.0], device=self.device)
        p2 = p1.clone()
        grad1 = torch.tensor([1.0, 1.0, 1.0], device=self.device)
        grad2 = grad1.clone()

        _WeightDecayTestHelper("independent")._apply_weight_decay_inplace(p1, grad1, lr=0.001, weight_decay=wd)
        _WeightDecayTestHelper("independent")._apply_weight_decay_inplace(p2, grad2, lr=100.0, weight_decay=wd)

        torch.testing.assert_close(p1, p2, atol=0, rtol=0)

    @parameterized.parameters(
        # NOTE(review): these parameter dicts were unchanged diff context and
        # hidden from view; values taken from the pre-change originals —
        # confirm against upstream.
        {"lr": 0.1, "wd": 0.5},
        {"lr": 0.01, "wd": 1.0},
        {"lr": 1.0, "wd": 0.01},
    )
    def test_l2(self, lr, wd):
        """L2: grad <- grad + p * wd, p untouched."""
        helper = _WeightDecayTestHelper("l2")
        p = torch.tensor([4.0, -2.0, 0.0, 7.0], device=self.device)
        grad = torch.tensor([1.0, 1.0, 1.0, 1.0], device=self.device)
        p_orig, grad_orig = p.clone(), grad.clone()

        helper._apply_weight_decay_inplace(p, grad, lr=lr, weight_decay=wd)

        # NOTE(review): assertions hidden in the diff; reconstructed from the
        # docstring contract — confirm upstream.
        torch.testing.assert_close(p, p_orig, atol=0, rtol=0)
        torch.testing.assert_close(grad, grad_orig + p_orig * wd)

    def test_l2_ignores_lr(self):
        """Two different lr values must produce identical results for L2 decay."""
        wd = 0.3
        p1 = torch.tensor([5.0, -3.0, 1.0], device=self.device)
        p2 = p1.clone()
        grad1 = torch.tensor([1.0, 1.0, 1.0], device=self.device)
        grad2 = grad1.clone()

        _WeightDecayTestHelper("l2")._apply_weight_decay_inplace(p1, grad1, lr=0.001, weight_decay=wd)
        _WeightDecayTestHelper("l2")._apply_weight_decay_inplace(p2, grad2, lr=100.0, weight_decay=wd)

        torch.testing.assert_close(grad1, grad2, atol=0, rtol=0)

    @parameterized.parameters(
        {"lr": 0.25, "wd": 0.5},
        {"lr": 0.025, "wd": 1.0},
        {"lr": 1.0, "wd": 0.05},
    )
    def test_palm(self, lr, wd):
        """PaLM: p <- p * (1 - wd * lr^2), grad untouched."""
        helper = _WeightDecayTestHelper("palm")
        p = torch.tensor([4.0, -2.0, 0.0, 7.0], device=self.device)
        grad = torch.tensor([1.0, 1.0, 1.0, 1.0], device=self.device)
        p_orig, grad_orig = p.clone(), grad.clone()

        helper._apply_weight_decay_inplace(p, grad, lr=lr, weight_decay=wd)

        # NOTE(review): assertions hidden in the diff; reconstructed from the
        # docstring contract — confirm upstream.
        torch.testing.assert_close(p, p_orig * (1 - wd * lr**2))
        torch.testing.assert_close(grad, grad_orig, atol=0, rtol=0)

    def test_default_method_is_l2(self):
        """When weight_decay_method attribute is absent, default to L2."""
        helper = WeightDecayMixin()
        p = torch.tensor([4.0, -2.0, 0.0, 7.0], device=self.device)
        grad = torch.tensor([1.0, 1.0, 1.0, 1.0], device=self.device)
        p_orig, grad_orig = p.clone(), grad.clone()

        wd = 0.5
        helper._apply_weight_decay_inplace(p, grad, lr=0.1, weight_decay=wd)

        # NOTE(review): the apply call's asserts were hidden in the diff;
        # reconstructed assuming L2 semantics per the docstring — confirm
        # upstream.
        torch.testing.assert_close(p, p_orig, atol=0, rtol=0)
        torch.testing.assert_close(grad, grad_orig + p_orig * wd)

    def test_invalid_method_raises(self):
        """An unrecognized weight_decay_method must raise ValueError."""
        helper = _WeightDecayTestHelper("bogus")
        p = torch.tensor([1.0], device=self.device)
        grad = torch.tensor([1.0], device=self.device)
        with self.assertRaises(ValueError):
            helper._apply_weight_decay_inplace(p, grad, lr=0.1, weight_decay=0.1)
162180
0 commit comments