Dataset Partitioning in Distributed Training

1. DistributedSampler

torch.utils.data.distributed.DistributedSampler(dataset, num_replicas=None, rank=None, shuffle=True, seed=0, drop_last=False)

Sampler that restricts data loading to a subset of the dataset; it gives each process in distributed training its own shard of the data.

Note that in distributed mode you must call set_epoch at the start of every epoch, otherwise every epoch iterates over the data in the same order.

The following examples illustrate DistributedSampler and its set_epoch method:

Without calling set_epoch

Run: CUDA_VISIBLE_DEVICES=0,1 python -m torch.distributed.launch --nproc_per_node=2 --master_port 29501 exe.py

import torch
from torch.utils.data import Dataset, DataLoader
from torch.utils.data.distributed import DistributedSampler

output_size = 2
batch_size = 2
data_size = 16

torch.distributed.init_process_group(backend="nccl")
local_rank = torch.distributed.get_rank()
torch.cuda.set_device(local_rank)
device = torch.device("cuda", local_rank)
class CustomDataset(Dataset):
    def __init__(self, length, local_rank):
        self.len = length
        self.data = torch.tensor([1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16]).to('cuda')
        self.local_rank = local_rank
    def __getitem__(self, index):
        return self.data[index]
    def __len__(self):
        return self.len

dataset = CustomDataset(data_size, local_rank)
sampler = DistributedSampler(dataset)  # shards the 16 samples across the 2 processes
data_loader = DataLoader(dataset=dataset, batch_size=batch_size, sampler=sampler)

for epoch in range(2):
    # sampler.set_epoch(epoch)
    for data in data_loader:
        if local_rank==0:
            print(data)
'''
*****************************************
Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed. 
*****************************************
tensor([13, 10], device='cuda:0')
tensor([12, 14], device='cuda:0')
tensor([ 3, 16], device='cuda:0')
tensor([5, 8], device='cuda:0')
tensor([13, 10], device='cuda:0')
tensor([12, 14], device='cuda:0')
tensor([ 3, 16], device='cuda:0')
tensor([5, 8], device='cuda:0')
'''

With set_epoch called, i.e., with the sampler.set_epoch(epoch) line in the code above uncommented:

'''
*****************************************
Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed. 
*****************************************
tensor([13, 10], device='cuda:0')
tensor([12, 14], device='cuda:0')
tensor([ 3, 16], device='cuda:0')
tensor([5, 8], device='cuda:0')
tensor([6, 7], device='cuda:0')
tensor([12,  8], device='cuda:0')
tensor([ 2, 10], device='cuda:0')
tensor([11, 14], device='cuda:0')
'''

From the outputs above: without set_epoch, both epochs iterate over the data on cuda:0 in exactly the same order, whereas with set_epoch the order changes from epoch to epoch. In other words, set_epoch mixes the epoch number into the shuffling seed so that every epoch gets a fresh permutation.
With two GPUs the data is split evenly: the 16 samples are randomly partitioned into two halves, one per process.
One more thing to note: if torch.utils.data.BatchSampler is added on top of this, the principle stays the same (a combined sketch is given in the BatchSampler section below).
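The set_epoch behaviour does not require multiple GPUs to observe: DistributedSampler also accepts num_replicas and rank explicitly, so a minimal single-process sketch (illustrative only, no process group needed) is enough:

from torch.utils.data.distributed import DistributedSampler

data = list(range(16))  # a plain list works: the sampler only needs len(dataset)
sampler = DistributedSampler(data, num_replicas=2, rank=0, shuffle=True, seed=0)

for epoch in range(2):
    sampler.set_epoch(epoch)  # comment this out and both printed lists are identical
    print(list(sampler))      # the 8 indices assigned to rank 0 in this epoch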

The second example shows what the num_replicas argument does:

import os
import torch
import torch.distributed as dist
from torch.utils.data import Dataset, DataLoader
from torch.utils.data.distributed import DistributedSampler


# A simple custom dataset
class SimpleDataset(Dataset):
    def __init__(self, size):
        self.data = list(range(size))  # the data is the list of integers 0 to size-1

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        return self.data[index]


# Distributed data sharding: each rank prints its own shard
def main():
    # dataset size
    dataset_size = 10  # 10 samples in total
    batch_size = 2

    # initialize the distributed process group
    dist.init_process_group("nccl")  # NCCL backend (for GPU); "gloo" also works (for CPU)
    local_rank = int(os.environ["LOCAL_RANK"])  # LOCAL_RANK is this process's GPU index
    torch.cuda.set_device(local_rank)  # bind this process to its GPU
    world_size = dist.get_world_size()  # total number of processes (e.g., 2 GPUs)
    rank = dist.get_rank()  # global rank of this process (0 to world_size-1)

    # create the dataset and the distributed sampler
    dataset = SimpleDataset(size=dataset_size)
    sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank, shuffle=False)
    dataloader = DataLoader(dataset, batch_size=batch_size, sampler=sampler)

    # print the data shard assigned to this rank
    print(f"Rank {rank}/{world_size} is processing the following data:")
    for batch in dataloader:
        print(f"Rank {rank}: Batch {batch.tolist()}")
    # clean up the distributed process group
    dist.destroy_process_group()

if __name__ == "__main__":
    main()

'''
# OMP_NUM_THREADS=1 torchrun --nproc_per_node=2 data.py 
Rank 0/2 is processing the following data:
Rank 0: Batch [0, 2]
Rank 0: Batch [4, 6]
Rank 0: Batch [8]
Rank 1/2 is processing the following data:
Rank 1: Batch [1, 3]
Rank 1: Batch [5, 7]
Rank 1: Batch [9]
'''

In distributed training, DistributedSampler assigns each GPU a different subset of the dataset, so that the processes never train on overlapping data. num_replicas specifies the total number of parallel processes and thus determines how the dataset is split: each rank receives roughly len(dataset) / num_replicas samples. With shuffle=False the split is strided, i.e., rank r gets indices r, r + num_replicas, r + 2 * num_replicas, ..., which is exactly the interleaved pattern in the output above. If num_replicas and rank are left as None, they default to the values from the current process group, so passing them explicitly is mainly useful for overriding those values or for testing.
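The strided split can be observed without launching any processes, since num_replicas and rank can be passed by hand; a minimal single-process sketch:

from torch.utils.data.distributed import DistributedSampler

data = list(range(10))
for rank in range(2):  # pretend to be each of the 2 ranks in turn
    sampler = DistributedSampler(data, num_replicas=2, rank=rank, shuffle=False)
    print(rank, list(sampler))
'''
0 [0, 2, 4, 6, 8]
1 [1, 3, 5, 7, 9]
'''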

2. BatchSampler

torch.utils.data.BatchSampler(sampler, batch_size, drop_last)

Wraps another sampler (or iterable of indices) and yields the indices for one mini-batch at a time.

from torch.utils.data import BatchSampler
sampler = list(BatchSampler(range(10), batch_size=3, drop_last=True))
sampler2 = list(BatchSampler(range(10), batch_size=3, drop_last=False))
print(sampler)
print(sampler2)
'''
[[0, 1, 2], [3, 4, 5], [6, 7, 8]]
[[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]]
'''

Generally, in distributed training you build the DistributedSampler first, then wrap it in a BatchSampler, and pass that BatchSampler to the DataLoader via its batch_sampler argument, as sketched below.
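A minimal sketch of that wiring; num_replicas and rank are hard-coded here so it runs in a single process, whereas in real training they come from the initialized process group:

from torch.utils.data import BatchSampler, DataLoader, Dataset
from torch.utils.data.distributed import DistributedSampler

class SimpleDataset(Dataset):
    def __init__(self, size):
        self.data = list(range(size))
    def __len__(self):
        return len(self.data)
    def __getitem__(self, index):
        return self.data[index]

dataset = SimpleDataset(10)
sampler = DistributedSampler(dataset, num_replicas=2, rank=0, shuffle=True)
batch_sampler = BatchSampler(sampler, batch_size=2, drop_last=False)
# batch_sampler replaces the DataLoader's batch_size/shuffle/sampler/drop_last arguments
loader = DataLoader(dataset, batch_sampler=batch_sampler)

for epoch in range(2):
    sampler.set_epoch(epoch)  # set_epoch is still called on the inner DistributedSampler
    print([batch.tolist() for batch in loader])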

3. SequentialSampler

torch.utils.data.SequentialSampler(dataset)

Samples the elements of a dataset sequentially: it simply walks through the samples in their original order, with no shuffling. SequentialSampler is therefore typically used when evaluating on validation and test sets, where no randomness is needed (it is also what a DataLoader falls back to when shuffle=False and no sampler is given).

from torch.utils.data import SequentialSampler, DataLoader, Dataset

# A simple dataset
class SimpleDataset(Dataset):
    def __init__(self, size):
        self.data = list(range(size))

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        return self.data[index]

# create the dataset and the sequential sampler
dataset_val = SimpleDataset(size=10)  # dataset of size 10
sampler_val = SequentialSampler(dataset_val)  # sample in order
dataloader_val = DataLoader(dataset_val, sampler=sampler_val, batch_size=2)

# print the sampling order on the validation set
for batch in dataloader_val:
    print(batch)
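Running this prints the samples strictly in order:

'''
tensor([0, 1])
tensor([2, 3])
tensor([4, 5])
tensor([6, 7])
tensor([8, 9])
'''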