Skip to content

Commit e8a4d85

Browse files
committed
improve weight decay mixin tests
Signed-off-by: Hao Wu <[email protected]>
1 parent c88e3ea commit e8a4d85

File tree

1 file changed

+56
-38
lines changed

1 file changed

+56
-38
lines changed

tests/test_weight_decay_mixin.py

Lines changed: 56 additions & 38 deletions
Original file line number | Diff line number | Diff line change
@@ -14,25 +14,43 @@
1414
# limitations under the License.
1515

1616
import torch
17+
from absl import flags, logging
1718
from absl.testing import absltest, parameterized
1819

1920
from emerging_optimizers.mixin import WeightDecayMixin
2021

2122

22-
class _WeightDecayHelper(WeightDecayMixin):
23+
# Command-line knobs shared by every test in this module.
flags.DEFINE_enum("device", "cpu", ["cpu", "cuda"], "Device to run tests on")
flags.DEFINE_integer("seed", None, "Random seed for reproducible tests")

FLAGS = flags.FLAGS
26+
27+
28+
def setUpModule() -> None:
    """Seed every RNG once for the module when --seed is supplied."""
    seed = FLAGS.seed
    if seed is None:
        # No seed requested: leave torch's default (nondeterministic) state.
        return
    logging.info("Setting random seed to %d", seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
34+
35+
36+
class _WeightDecayTestHelper(WeightDecayMixin):
    """Thin wrapper so we can set weight_decay_method and call the mixin."""

    def __init__(self, method: str):
        # The mixin dispatches on this attribute when applying decay.
        self.weight_decay_method = method
2741

2842

2943
class WeightDecayMixinTest(parameterized.TestCase):
44+
@classmethod
def setUpClass(cls):
    # Resolve the --device flag once and cache it for all tests in the class.
    cls.device = FLAGS.device
47+
3048
@parameterized.parameters("decoupled", "independent", "l2", "palm")
3149
def test_zero_weight_decay_is_noop(self, method):
3250
"""Neither p nor grad should change when weight_decay is 0."""
33-
helper = _WeightDecayHelper(method)
34-
p = torch.tensor([1.0, 2.0, 3.0])
35-
grad = torch.tensor([0.5, -0.5, 1.0])
51+
helper = _WeightDecayTestHelper(method)
52+
p = torch.tensor([1.0, 2.0, 3.0], device=self.device)
53+
grad = torch.tensor([0.5, -0.5, 1.0], device=self.device)
3654
p_orig, grad_orig = p.clone(), grad.clone()
3755

3856
helper._apply_weight_decay_inplace(p, grad, lr=0.1, weight_decay=0.0)
@@ -41,15 +59,15 @@ def test_zero_weight_decay_is_noop(self, method):
4159
torch.testing.assert_close(grad, grad_orig, atol=0, rtol=0)
4260

4361
@parameterized.parameters(
44-
{"lr": 0.1, "wd": 0.5},
45-
{"lr": 0.01, "wd": 1.0},
46-
{"lr": 1.0, "wd": 0.01},
62+
{"lr": 0.25, "wd": 0.5},
63+
{"lr": 0.025, "wd": 1.0},
64+
{"lr": 1.0, "wd": 0.05},
4765
)
4866
def test_decoupled(self, lr, wd):
4967
"""Decoupled: p <- p * (1 - wd * lr), grad untouched."""
50-
helper = _WeightDecayHelper("decoupled")
51-
p = torch.tensor([4.0, -2.0, 0.0, 7.0])
52-
grad = torch.tensor([1.0, 1.0, 1.0, 1.0])
68+
helper = _WeightDecayTestHelper("decoupled")
69+
p = torch.tensor([4.0, -2.0, 0.0, 8.0], device=self.device)
70+
grad = torch.tensor([1.0, 2.0, 1.0, 1.0], device=self.device)
5371
p_orig, grad_orig = p.clone(), grad.clone()
5472

5573
helper._apply_weight_decay_inplace(p, grad, lr=lr, weight_decay=wd)
@@ -59,15 +77,15 @@ def test_decoupled(self, lr, wd):
5977
torch.testing.assert_close(grad, grad_orig, atol=0, rtol=0)
6078

6179
@parameterized.parameters(
62-
{"lr": 0.1, "wd": 0.5},
63-
{"lr": 0.01, "wd": 1.0},
64-
{"lr": 1.0, "wd": 0.01},
80+
{"lr": 0.25, "wd": 0.5},
81+
{"lr": 0.025, "wd": 1.0},
82+
{"lr": 1.0, "wd": 0.05},
6583
)
6684
def test_independent(self, lr, wd):
6785
"""Independent: p <- p * (1 - wd), grad untouched, lr irrelevant."""
68-
helper = _WeightDecayHelper("independent")
69-
p = torch.tensor([4.0, -2.0, 0.0, 7.0])
70-
grad = torch.tensor([1.0, 1.0, 1.0, 1.0])
86+
helper = _WeightDecayTestHelper("independent")
87+
p = torch.tensor([4.0, -2.0, 0.0, 7.0], device=self.device)
88+
grad = torch.tensor([1.0, 1.0, 1.0, 1.0], device=self.device)
7189
p_orig, grad_orig = p.clone(), grad.clone()
7290

7391
helper._apply_weight_decay_inplace(p, grad, lr=lr, weight_decay=wd)
@@ -79,13 +97,13 @@ def test_independent(self, lr, wd):
7997
def test_independent_ignores_lr(self):
8098
"""Two different lr values must produce identical results for independent decay."""
8199
wd = 0.3
82-
p1 = torch.tensor([5.0, -3.0, 1.0])
100+
p1 = torch.tensor([5.0, -3.0, 1.0], device=self.device)
83101
p2 = p1.clone()
84-
grad1 = torch.tensor([1.0, 1.0, 1.0])
102+
grad1 = torch.tensor([1.0, 1.0, 1.0], device=self.device)
85103
grad2 = grad1.clone()
86104

87-
_WeightDecayHelper("independent")._apply_weight_decay_inplace(p1, grad1, lr=0.001, weight_decay=wd)
88-
_WeightDecayHelper("independent")._apply_weight_decay_inplace(p2, grad2, lr=100.0, weight_decay=wd)
105+
_WeightDecayTestHelper("independent")._apply_weight_decay_inplace(p1, grad1, lr=0.001, weight_decay=wd)
106+
_WeightDecayTestHelper("independent")._apply_weight_decay_inplace(p2, grad2, lr=100.0, weight_decay=wd)
89107

90108
torch.testing.assert_close(p1, p2, atol=0, rtol=0)
91109

@@ -96,9 +114,9 @@ def test_independent_ignores_lr(self):
96114
)
97115
def test_l2(self, lr, wd):
98116
"""L2: grad <- grad + p * wd, p untouched."""
99-
helper = _WeightDecayHelper("l2")
100-
p = torch.tensor([4.0, -2.0, 0.0, 7.0])
101-
grad = torch.tensor([1.0, 1.0, 1.0, 1.0])
117+
helper = _WeightDecayTestHelper("l2")
118+
p = torch.tensor([4.0, -2.0, 0.0, 7.0], device=self.device)
119+
grad = torch.tensor([1.0, 1.0, 1.0, 1.0], device=self.device)
102120
p_orig, grad_orig = p.clone(), grad.clone()
103121

104122
helper._apply_weight_decay_inplace(p, grad, lr=lr, weight_decay=wd)
@@ -110,26 +128,26 @@ def test_l2(self, lr, wd):
110128
def test_l2_ignores_lr(self):
    """Two different lr values must produce identical results for L2 decay."""
    wd = 0.3
    decayed_grads = []
    # Run the same decay with wildly different learning rates; L2 decay
    # only modifies grad and must not consult lr at all.
    for lr in (0.001, 100.0):
        p = torch.tensor([5.0, -3.0, 1.0], device=self.device)
        g = torch.tensor([1.0, 1.0, 1.0], device=self.device)
        _WeightDecayTestHelper("l2")._apply_weight_decay_inplace(p, g, lr=lr, weight_decay=wd)
        decayed_grads.append(g)

    torch.testing.assert_close(decayed_grads[0], decayed_grads[1], atol=0, rtol=0)
122140

123141
@parameterized.parameters(
124-
{"lr": 0.1, "wd": 0.5},
125-
{"lr": 0.01, "wd": 1.0},
126-
{"lr": 1.0, "wd": 0.01},
142+
{"lr": 0.25, "wd": 0.5},
143+
{"lr": 0.025, "wd": 1.0},
144+
{"lr": 1.0, "wd": 0.05},
127145
)
128146
def test_palm(self, lr, wd):
129147
"""PaLM: p <- p * (1 - wd * lr^2), grad untouched."""
130-
helper = _WeightDecayHelper("palm")
131-
p = torch.tensor([4.0, -2.0, 0.0, 7.0])
132-
grad = torch.tensor([1.0, 1.0, 1.0, 1.0])
148+
helper = _WeightDecayTestHelper("palm")
149+
p = torch.tensor([4.0, -2.0, 0.0, 7.0], device=self.device)
150+
grad = torch.tensor([1.0, 1.0, 1.0, 1.0], device=self.device)
133151
p_orig, grad_orig = p.clone(), grad.clone()
134152

135153
helper._apply_weight_decay_inplace(p, grad, lr=lr, weight_decay=wd)
@@ -141,8 +159,8 @@ def test_palm(self, lr, wd):
141159
def test_default_method_is_l2(self):
142160
"""When weight_decay_method attribute is absent, default to L2."""
143161
helper = WeightDecayMixin()
144-
p = torch.tensor([4.0, -2.0, 0.0, 7.0])
145-
grad = torch.tensor([1.0, 1.0, 1.0, 1.0])
162+
p = torch.tensor([4.0, -2.0, 0.0, 7.0], device=self.device)
163+
grad = torch.tensor([1.0, 1.0, 1.0, 1.0], device=self.device)
146164
p_orig, grad_orig = p.clone(), grad.clone()
147165

148166
wd = 0.5
@@ -154,9 +172,9 @@ def test_default_method_is_l2(self):
154172

155173
def test_invalid_method_raises(self):
    """An unrecognized weight_decay_method must raise ValueError."""
    p = torch.tensor([1.0], device=self.device)
    grad = torch.tensor([1.0], device=self.device)
    helper = _WeightDecayTestHelper("bogus")
    with self.assertRaises(ValueError):
        helper._apply_weight_decay_inplace(p, grad, lr=0.1, weight_decay=0.1)
162180

0 commit comments

Comments
 (0)