Skip to content

Commit 684c9f3

Browse files
authored
[bugfix] fix grpo move_model_batches (#8091)
1 parent 71268ab commit 684c9f3

1 file changed

Lines changed: 13 additions & 10 deletions

File tree

swift/trainers/rlhf_trainer/rollout_mixin.py

Lines changed: 13 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -647,31 +647,34 @@ def _collect_state_dict_for_vllm(self, parameter_group=None, parameter_group_no_
647647
def _move_full_model_to_vllm(self):
648648
"""Transfer full model weights to vLLM engine.
649649
650-
Manages the lifecycle of gather and merge/unmerge:
651-
- gather_if_zero3: once for the entire sync (DeepSpeed Zero3)
650+
Manages the lifecycle of gather and merge/unmerge per parameter_group:
651+
- gather_if_zero3: per parameter_group batch (DeepSpeed Zero3)
652652
- merge/unmerge: per parameter_group (must be within gather context)
653653
- No clone needed: unmerge happens after load completes
654654
"""
655655
is_peft = is_peft_model(self.model)
656-
# For DeepSpeed, merge within gather context; FSDP2 uses tensor-level merge
657656
should_merge = is_peft and not self._is_fsdp2
658657

659658
gather_if_zero3 = get_gather_if_zero3_context(self)
660-
parameters = [] if self._is_fsdp2 else list(self.model.parameters())
661659

662-
with gather_if_zero3(parameters):
663-
for i, parameter_group in enumerate(self.parameter_groups):
664-
parameter_group_no_lora = self.parameter_groups_no_lora[i]
660+
for i, parameter_group in enumerate(self.parameter_groups):
661+
parameter_group_no_lora = self.parameter_groups_no_lora[i]
662+
663+
if not self._is_fsdp2:
664+
parameters = [
665+
parameter for name, parameter in self.model.named_parameters()
666+
if not parameter_group or name in parameter_group
667+
]
668+
else:
669+
parameters = []
665670

666-
# Merge must be within gather context (needs full parameters)
671+
with gather_if_zero3(parameters):
667672
if should_merge:
668673
with patch_lora_merge(self.model, parameter_group):
669674
self.model.merge_adapter()
670675

671676
try:
672-
# Collect without clone - unmerge happens after load
673677
state_dict = self._collect_state_dict_for_vllm(parameter_group, parameter_group_no_lora)
674-
# Data is copied here (FlattenedTensorBucket.copy_ or vLLM load_weights)
675678
self._load_state_dict_to_vllm(state_dict)
676679
finally:
677680
if should_merge:

0 commit comments

Comments (0)