Skip to content

Commit 01293f6

Browse files
hjh0119 and Jintao-Huang
authored and committed
fix generation-batch-size&steps_per_generation check (#8048)
1 parent ec223f6 commit 01293f6

3 files changed

Lines changed: 3 additions & 35 deletions

File tree

swift/llm/argument/deploy_args.py

Lines changed: 0 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -115,19 +115,11 @@ class RolloutArguments(DeployArguments):
115115
context_manager: Optional[str] = None
116116

117117
def __post_init__(self):
118-
self._check_trl_version()
119118
self._set_default_engine_type()
120119
super().__post_init__()
121120
self._check_args()
122121
self._check_device_count()
123122

124-
def _check_trl_version(self):
125-
try:
126-
from trl.scripts.vllm_serve import WeightSyncWorkerExtension
127-
except ImportError as e:
128-
raise ImportError("Could not import 'WeightSyncWorkerExtension' from 'trl.scripts.vllm_serve'. "
129-
"Please upgrade your 'trl' package by 'pip install -U trl'") from e
130-
131123
def _set_default_engine_type(self):
132124
if self.vllm_use_async_engine is None:
133125
if self.multi_turn_scheduler:

swift/llm/argument/rlhf_args.py

Lines changed: 1 addition & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -451,14 +451,13 @@ def _check_grpo(self):
451451

452452
import trl
453453
trl_version = version.parse(trl.__version__)
454-
assert trl_version >= version.parse('0.17'), ('Your current version of `trl` is outdated. '
454+
assert trl_version >= version.parse('0.20'), ('Your current version of `trl` is outdated. '
455455
'Please update it by running: pip install -U trl')
456456
if is_mp() and self.use_vllm:
457457
raise ValueError('GRPO with vLLM is not compatible with `device_map`. '
458458
'Please set NPROC_PER_NODE equal to num_processes.')
459459
if self.use_liger_kernel:
460460
liger_kernel_version = version.parse(importlib.metadata.version('liger-kernel'))
461-
assert trl_version >= version.parse('0.18')
462461
if self.delta is not None:
463462
raise ValueError('Liger loss does not support two-sided GRPO loss yet.')
464463
if self.sequence_parallel_size > 1:
@@ -485,12 +484,6 @@ def _check_grpo(self):
485484
if self.async_generate and self.multi_turn_scheduler is not None:
486485
raise NotImplementedError('Currently, async_generate is not supported with multi-turn functionality.')
487486

488-
if self.generation_batch_size or self.steps_per_generation:
489-
from trl.trainer.grpo_config import GRPOConfig
490-
assert 'generation_batch_size' in GRPOConfig.__dict__, (
491-
'generation_batch_size or steps_per_generation needs trl >= 0.18, '
492-
'please install trl `pip install trl>=0.18')
493-
494487
def _external_vllm_warning(self):
495488
if self.rlhf_type not in rlhf_support_vllm_types or not self.vllm_server_host:
496489
return

swift/trainers/rlhf_arguments.py

Lines changed: 2 additions & 19 deletions
Original file line number | Diff line number | Diff line change
@@ -1,6 +1,7 @@
11
from dataclasses import dataclass, field
22
from typing import Optional
33

4+
from transformers.utils.versions import require_version
45
from trl import CPOConfig as HfCPOConfig
56
from trl import DPOConfig as HfDPOConfig
67
from trl import GKDConfig as HfGKDConfig
@@ -58,6 +59,7 @@ def __post_init__(self):
5859
class GRPOConfig(GRPOArgumentsMixin, SwiftArgumentsMixin, HfGRPOConfig):
5960

6061
def __post_init__(self):
62+
require_version('trl>=0.20')
6163
GRPOArgumentsMixin.__post_init__(self)
6264
SwiftArgumentsMixin.__post_init__(self)
6365
if self.vllm_reasoning_parser is not None:
@@ -75,25 +77,6 @@ def __post_init__(self):
7577
# https://github.com/modelscope/ms-swift/issues/3863
7678
self.dataloader_drop_last = True
7779

78-
# from trl https://github.com/huggingface/trl/blob/7a39ff3995f2f8b7cb4f8ca29a09390ac587a43d/trl/trainer/grpo_config.py#L843 # noqa: E501
79-
num_processes = self.world_size
80-
# The current default effective batch size
81-
if self.generation_batch_size is None and self.steps_per_generation is None:
82-
self.steps_per_generation = self.gradient_accumulation_steps
83-
self.generation_batch_size = self.per_device_train_batch_size * num_processes * self.steps_per_generation
84-
elif self.generation_batch_size is not None and self.steps_per_generation is None:
85-
# Just ensure the value is divisible by the global batch size
86-
if self.generation_batch_size % (self.per_device_train_batch_size * num_processes) != 0:
87-
raise ValueError(
88-
f'generation_batch_size ({self.generation_batch_size}) must be divisible by the global batch size '
89-
f'({self.per_device_train_batch_size * num_processes}).')
90-
self.steps_per_generation = self.generation_batch_size // (self.per_device_train_batch_size * num_processes)
91-
elif self.generation_batch_size is None and self.steps_per_generation is not None:
92-
self.generation_batch_size = self.per_device_train_batch_size * num_processes * self.steps_per_generation
93-
else:
94-
raise ValueError(
95-
"'generation_batch_size' and 'steps_per_generation' can not be both configured at the same time")
96-
9780
self.check_num_generations()
9881

9982
def check_num_generations(self):

0 commit comments

Comments (0)