Skip to content

Commit c56fa29

Browse files
authored
Merge pull request #1532 from hanhainebula/master
fix bug: safe dist.get_rank()
2 parents 5e1f42f + 24bc8d1 commit c56fa29

File tree

3 files changed

+8
-5
lines changed

3 files changed

+8
-5
lines changed

FlagEmbedding/abc/finetune/embedder/AbsDataset.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,8 @@ def _load_dataset(self, file_path: str):
6363
Returns:
6464
datasets.Dataset: Loaded HF dataset.
6565
"""
66-
if dist.get_rank() == 0:
66+
safe_rank = dist.get_rank() if dist.is_initialized() else 0
67+
if safe_rank == 0:
6768
logger.info(f'loading data from {file_path} ...')
6869

6970
temp_dataset = datasets.load_dataset('json', data_files=file_path, split='train', cache_dir=self.args.cache_path)
@@ -342,7 +343,8 @@ def _load_dataset(self, file_path: str):
342343
Returns:
343344
datasets.Dataset: The loaded dataset.
344345
"""
345-
if dist.get_rank() == 0:
346+
safe_rank = dist.get_rank() if dist.is_initialized() else 0
347+
if safe_rank == 0:
346348
logger.info(f'loading data from {file_path} ...')
347349

348350
temp_dataset = datasets.load_dataset('json', data_files=file_path, split='train', cache_dir=self.args.cache_path)

FlagEmbedding/abc/finetune/embedder/AbsModeling.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -54,8 +54,8 @@ def __init__(
5454
if self.negatives_cross_device:
5555
if not dist.is_initialized():
5656
raise ValueError('Distributed training has not been initialized for representation all gather.')
57-
self.process_rank = dist.get_rank()
58-
self.world_size = dist.get_world_size()
57+
self.process_rank = dist.get_rank() if dist.is_initialized() else 0
58+
self.world_size = dist.get_world_size() if dist.is_initialized() else 1
5959

6060
self.sub_batch_size = sub_batch_size
6161
self.kd_loss_type = kd_loss_type

FlagEmbedding/abc/finetune/reranker/AbsDataset.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,8 @@ def _load_dataset(self, file_path: str):
6464
Returns:
6565
datasets.Dataset: Loaded HF dataset.
6666
"""
67-
if dist.get_rank() == 0:
67+
safe_rank = dist.get_rank() if dist.is_initialized() else 0
68+
if safe_rank == 0:
6869
logger.info(f'loading data from {file_path} ...')
6970

7071
temp_dataset = datasets.load_dataset('json', data_files=file_path, split='train', cache_dir=self.args.cache_path)

0 commit comments

Comments
 (0)