diff --git a/open_instruct/grpo_fast.py b/open_instruct/grpo_fast.py
index be56ef50dd..35621673cc 100644
--- a/open_instruct/grpo_fast.py
+++ b/open_instruct/grpo_fast.py
@@ -359,6 +359,8 @@ class Args:
     """whether to offload parameters to CPU (reduces GPU memory usage)"""
     deepspeed_offload_optimizer: bool = False
     """whether to offload optimizer states to CPU (reduces GPU memory usage)"""
+    deepspeed_cpu_adam: bool = False
+    """whether to use the DeepSpeedCPUAdam optimizer (recommended when offloading optimizer states to CPU)"""
     gather_whole_model: bool = True
     """whether to gather the whole model to boardcast (not doable for 70B but can be faster for 8B)"""
     enable_queue_dashboard: bool = True
@@ -711,7 +713,11 @@ def load(self, path: str, map_location=None):
             optim_params = get_optimizer_grouped_parameters(self.policy, args.weight_decay)
         else:
             optim_params = self.policy.parameters()
-        self.optimizer = torch.optim.AdamW(optim_params, lr=args.learning_rate, fused=args.fused_optimizer)
+        if args.deepspeed_cpu_adam:
+            from deepspeed.ops.adam import DeepSpeedCPUAdam
+            self.optimizer = DeepSpeedCPUAdam(optim_params, lr=args.learning_rate)
+        else:
+            self.optimizer = torch.optim.AdamW(optim_params, lr=args.learning_rate, fused=args.fused_optimizer)
         num_scheduler_steps = args.num_training_steps * args.num_epochs * args.num_mini_batches
         warm_up_steps = args.warm_up_steps
         if args.warmup_ratio > 0.0:
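
For context, a minimal sketch of how the new flag is meant to be used (only the two Args fields and the DeepSpeedCPUAdam/AdamW calls come from this diff; the wrapper function and everything else below is illustrative, not code from the PR). DeepSpeedCPUAdam is DeepSpeed's CPU Adam kernel, so the optimizer step runs on the host; it is the optimizer DeepSpeed expects when optimizer states are offloaded, which is why --deepspeed_cpu_adam would normally be enabled together with --deepspeed_offload_optimizer.

# Illustrative sketch only (not part of this PR): mirrors the selection logic
# added in the diff so the interaction of the two flags is easy to see.
import torch
from deepspeed.ops.adam import DeepSpeedCPUAdam

def build_optimizer(args, optim_params):
    if args.deepspeed_cpu_adam:
        # CPU Adam kernel: the optimizer step runs on the host, matching a
        # ZeRO config that offloads optimizer states to CPU.
        return DeepSpeedCPUAdam(optim_params, lr=args.learning_rate)
    # Default path, unchanged by the diff: (optionally fused) GPU AdamW.
    return torch.optim.AdamW(optim_params, lr=args.learning_rate, fused=args.fused_optimizer)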