@@ -647,31 +647,34 @@ def _collect_state_dict_for_vllm(self, parameter_group=None, parameter_group_no_
647647 def _move_full_model_to_vllm (self ):
648648 """Transfer full model weights to vLLM engine.
649649
650- Manages the lifecycle of gather and merge/unmerge:
651- - gather_if_zero3: once for the entire sync (DeepSpeed Zero3)
650+ Manages the lifecycle of gather and merge/unmerge per parameter_group :
651+ - gather_if_zero3: per parameter_group batch (DeepSpeed Zero3)
652652 - merge/unmerge: per parameter_group (must be within gather context)
653653 - No clone needed: unmerge happens after load completes
654654 """
655655 is_peft = is_peft_model (self .model )
656- # For DeepSpeed, merge within gather context; FSDP2 uses tensor-level merge
657656 should_merge = is_peft and not self ._is_fsdp2
658657
659658 gather_if_zero3 = get_gather_if_zero3_context (self )
660- parameters = [] if self ._is_fsdp2 else list (self .model .parameters ())
661659
662- with gather_if_zero3 (parameters ):
663- for i , parameter_group in enumerate (self .parameter_groups ):
664- parameter_group_no_lora = self .parameter_groups_no_lora [i ]
660+ for i , parameter_group in enumerate (self .parameter_groups ):
661+ parameter_group_no_lora = self .parameter_groups_no_lora [i ]
662+
663+ if not self ._is_fsdp2 :
664+ parameters = [
665+ parameter for name , parameter in self .model .named_parameters ()
666+ if not parameter_group or name in parameter_group
667+ ]
668+ else :
669+ parameters = []
665670
666- # Merge must be within gather context (needs full parameters)
671+ with gather_if_zero3 ( parameters ):
667672 if should_merge :
668673 with patch_lora_merge (self .model , parameter_group ):
669674 self .model .merge_adapter ()
670675
671676 try :
672- # Collect without clone - unmerge happens after load
673677 state_dict = self ._collect_state_dict_for_vllm (parameter_group , parameter_group_no_lora )
674- # Data is copied here (FlattenedTensorBucket.copy_ or vLLM load_weights)
675678 self ._load_state_dict_to_vllm (state_dict )
676679 finally :
677680 if should_merge :
0 commit comments