Skip to content

Commit b6fa1c7

Browse files
committed
simplify optimizer arguments via shared common_kwargs (note: not a pure refactor — also drops Muon's momentum=0.95, raises the Muon-paired AdamW lr from lr*0.1 to lr, and changes SOAP precondition_frequency from 10 to 1)
Signed-off-by: Hao Wu <skyw@nvidia.com>
1 parent d9e86b2 commit b6fa1c7

1 file changed

Lines changed: 9 additions & 5 deletions

File tree

tests/convergence/moe_c4_convergence.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -99,18 +99,22 @@ def build_optimizer(
9999
else:
100100
params_other.append(p)
101101

102+
common_kwargs = {
103+
"lr": lr,
104+
"weight_decay": weight_decay,
105+
}
102106
if optimizer_name == "muon":
103-
muon_opt = Muon(params_2d, lr=lr, momentum=0.95, weight_decay=weight_decay)
104-
adam_opt = torch.optim.AdamW(params_other, lr=lr * 0.1, weight_decay=weight_decay)
107+
muon_opt = Muon(params_2d, **common_kwargs)
108+
adam_opt = torch.optim.AdamW(params_other, **common_kwargs)
105109
return _CombinedOptimizer([muon_opt, adam_opt])
106110

107111
elif optimizer_name == "soap":
108-
soap_opt = SOAP(params_2d, lr=lr, weight_decay=weight_decay, precondition_frequency=10)
109-
adam_opt = torch.optim.AdamW(params_other, lr=lr, weight_decay=weight_decay)
112+
soap_opt = SOAP(params_2d, **common_kwargs, precondition_frequency=1)
113+
adam_opt = torch.optim.AdamW(params_other, **common_kwargs)
110114
return _CombinedOptimizer([soap_opt, adam_opt])
111115

112116
elif optimizer_name == "adamw":
113-
return torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)
117+
return torch.optim.AdamW(model.parameters(), **common_kwargs)
114118

115119
else:
116120
raise ValueError(f"Unknown optimizer: {optimizer_name}")

0 commit comments

Comments
 (0)