11from dataclasses import dataclass , field
22from typing import Optional
33
4+ from transformers .utils .versions import require_version
45from trl import CPOConfig as HfCPOConfig
56from trl import DPOConfig as HfDPOConfig
67from trl import GKDConfig as HfGKDConfig
@@ -58,6 +59,7 @@ def __post_init__(self):
5859class GRPOConfig (GRPOArgumentsMixin , SwiftArgumentsMixin , HfGRPOConfig ):
5960
6061 def __post_init__ (self ):
62+ require_version ('trl>=0.20' )
6163 GRPOArgumentsMixin .__post_init__ (self )
6264 SwiftArgumentsMixin .__post_init__ (self )
6365 if self .vllm_reasoning_parser is not None :
@@ -75,25 +77,6 @@ def __post_init__(self):
7577 # https://github.com/modelscope/ms-swift/issues/3863
7678 self .dataloader_drop_last = True
7779
78- # from trl https://github.com/huggingface/trl/blob/7a39ff3995f2f8b7cb4f8ca29a09390ac587a43d/trl/trainer/grpo_config.py#L843 # noqa: E501
79- num_processes = self .world_size
80- # The current default effective batch size
81- if self .generation_batch_size is None and self .steps_per_generation is None :
82- self .steps_per_generation = self .gradient_accumulation_steps
83- self .generation_batch_size = self .per_device_train_batch_size * num_processes * self .steps_per_generation
84- elif self .generation_batch_size is not None and self .steps_per_generation is None :
85- # Just ensure the value is divisible by the global batch size
86- if self .generation_batch_size % (self .per_device_train_batch_size * num_processes ) != 0 :
87- raise ValueError (
88- f'generation_batch_size ({ self .generation_batch_size } ) must be divisible by the global batch size '
89- f'({ self .per_device_train_batch_size * num_processes } ).' )
90- self .steps_per_generation = self .generation_batch_size // (self .per_device_train_batch_size * num_processes )
91- elif self .generation_batch_size is None and self .steps_per_generation is not None :
92- self .generation_batch_size = self .per_device_train_batch_size * num_processes * self .steps_per_generation
93- else :
94- raise ValueError (
95- "'generation_batch_size' and 'steps_per_generation' can not be both configured at the same time" )
96-
9780 self .check_num_generations ()
9881
9982 def check_num_generations (self ):
0 commit comments