@@ -99,18 +99,22 @@ def build_optimizer(
             else:
                 params_other.append(p)

+    common_kwargs = {
+        "lr": lr,
+        "weight_decay": weight_decay,
+    }
     if optimizer_name == "muon":
-        muon_opt = Muon(params_2d, lr=lr, momentum=0.95, weight_decay=weight_decay)
-        adam_opt = torch.optim.AdamW(params_other, lr=lr * 0.1, weight_decay=weight_decay)
+        muon_opt = Muon(params_2d, **common_kwargs)
+        adam_opt = torch.optim.AdamW(params_other, **common_kwargs)
         return _CombinedOptimizer([muon_opt, adam_opt])

     elif optimizer_name == "soap":
-        soap_opt = SOAP(params_2d, lr=lr, weight_decay=weight_decay, precondition_frequency=10)
-        adam_opt = torch.optim.AdamW(params_other, lr=lr, weight_decay=weight_decay)
+        soap_opt = SOAP(params_2d, **common_kwargs, precondition_frequency=1)
+        adam_opt = torch.optim.AdamW(params_other, **common_kwargs)
         return _CombinedOptimizer([soap_opt, adam_opt])

     elif optimizer_name == "adamw":
-        return torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)
+        return torch.optim.AdamW(model.parameters(), **common_kwargs)

     else:
         raise ValueError(f"Unknown optimizer: {optimizer_name}")