分布式训练时数据集的分配
pytorch
本文字数:2k 字 | 阅读时长 ≈ 9 min

分布式训练时数据集的分配

pytorch
本文字数:2k 字 | 阅读时长 ≈ 9 min

注意:使用 DataLoader(…, batch_sampler=…) 时,不能再同时传入 batch_size、shuffle、sampler、drop_last 等参数给 DataLoader。

1. DistributedSampler

torch.utils.data.distributed.DistributedSampler(dataset, num_replicas=None, rank=None, shuffle=True, seed=0, drop_last=False)

注意在分布式的模式中,在每个 epoch 要调用 set.epoch 函数,不然进行迭代时每次都是相同的数据集顺序

下面通过例子来理解 DistrubtedSamplerset.epoch 函数:

set.epoch 在每个 epoch 设置了不同的随机种子

import torch
from torch.utils.data import Dataset, DataLoader
from torch.utils.data.distributed import DistributedSampler

class SimpleDataset(Dataset):
    def __init__(self):
        self.data = list(range(16))
    def __getitem__(self, idx):
        return self.data[idx]
    def __len__(self):
        return len(self.data)

num_replicas = 2   # 模拟2个进程
batch_size = 2
dataset = SimpleDataset()
for rank in range(num_replicas):
    sampler = DistributedSampler(dataset, num_replicas=num_replicas, rank=rank, shuffle=True)
    dataloader = DataLoader(dataset, batch_size=batch_size, sampler=sampler)
    print(f"\n=== Rank {rank} ===")
    for epoch in range(3):
        sampler.set_epoch(epoch)  # 调用和不调用 `set_epoch`
        print(f"Epoch {epoch} batches:")
        for batch in dataloader:
            print(batch)

实验结果:比较调用和不调用 set_epoch 的情况

Rank Epoch 是否调用 set_epoch Batch 1 Batch 2 Batch 3 Batch 4
0 0 调用 [12, 9] [11, 13] [2, 15] [4, 7]
0 1 调用 [5, 6] [11, 7] [1, 9] [10, 13]
0 2 调用 [8, 1] [6, 0] [4, 10] [15, 5]
1 0 调用 [10, 6] [8, 5] [14, 0] [3, 1]
1 1 调用 [15, 4] [2, 12] [0, 8] [3, 14]
1 2 调用 [13, 9] [7, 11] [2, 12] [14, 3]
0 0 不调用 [12, 9] [11, 13] [2, 15] [4, 7]
0 1 不调用 [12, 9] [11, 13] [2, 15] [4, 7]
0 2 不调用 [12, 9] [11, 13] [2, 15] [4, 7]
1 0 不调用 [10, 6] [8, 5] [14, 0] [3, 1]
1 1 不调用 [10, 6] [8, 5] [14, 0] [3, 1]
1 2 不调用 [10, 6] [8, 5] [14, 0] [3, 1]

2. SequentialSampler

torch.utils.data.SequentialSampler(dataset)

用于顺序采样数据集中的样本。它的作用很简单:按照数据集中样本的顺序排列依次取样,而不进行随机打乱。因此,SequentialSampler 通常用于在验证集和测试集上进行评估时的采样,因为这两个阶段通常不需要随机性。

import torch
from torch.utils.data import Dataset, DataLoader
from torch.utils.data import SequentialSampler

class SimpleDataset(Dataset):
    def __init__(self):
        self.data = list(range(16))
    def __getitem__(self, idx):
        return self.data[idx]
    def __len__(self):
        return len(self.data)

batch_size = 2
dataset = SimpleDataset()

sampler = SequentialSampler(dataset)  # 用 SequentialSampler 按顺序取样本
dataloader = DataLoader(dataset, batch_size=batch_size, sampler=sampler)

for epoch in range(2):
    print(f"\nEpoch {epoch} batches:")
    for batch in dataloader:
        print(batch)

实验结果:

Epoch 0 batches: tensor([0, 1]),tensor([2, 3]), ..., tensor([14, 15])
Epoch 1 batches: tensor([0, 1]),tensor([2, 3]), ..., tensor([14, 15])

3. RandomSampler

torch.utils.data.RandomSampler(data_source, replacement=False, num_samples=None, generator=None)

import torch
from torch.utils.data import Dataset, DataLoader
from torch.utils.data import RandomSampler

class SimpleDataset(Dataset):
    def __init__(self):
        self.data = list(range(16))
    def __getitem__(self, idx):
        return self.data[idx]
    def __len__(self):
        return len(self.data)

dataset = SimpleDataset()
batch_size = 4

sampler = RandomSampler(dataset, replacement=False)  # 不放回,遍历完所有样本
dataloader = DataLoader(dataset, batch_size=batch_size, sampler=sampler)
for epoch in range(2):
    print(f"Epoch {epoch} batches:")
    for batch in dataloader:
        print(batch)

输出内容如下

Epoch 0 batches:
tensor([ 5, 11,  4, 13])
tensor([ 9,  2,  1, 14])
tensor([10,  8,  6,  3])
tensor([15, 12,  0,  7])
Epoch 1 batches:
tensor([ 7,  4, 10,  2])
tensor([ 6, 13, 12, 11])
tensor([ 3,  5, 14,  8])
tensor([ 1,  0, 15,  9])

4. BatchSampler

torch.utils.data.BatchSampler(sampler, batch_size, drop_last)

将一个“样本采样器”(Sampler)产生的单个样本索引,按固定 batch_size 打包成一组索引列表,再交给 DataLoader 去取数据。可以与任意 Sampler 搭配:SequentialSampler(按顺序)、RandomSampler(随机),或 DistributedSampler(分布式)。

import torch
from torch.utils.data import Dataset, DataLoader
from torch.utils.data import SequentialSampler, RandomSampler, BatchSampler
from torch.utils.data.distributed import DistributedSampler

class SimpleDataset(Dataset):
    def __init__(self, n=16):
        self.data = list(range(n))
    def __getitem__(self, idx):
        return self.data[idx]
    def __len__(self):
        return len(self.data)

dataset = SimpleDataset(16)

print("=== 1) SequentialSampler + BatchSampler ===")
seq_sampler = SequentialSampler(dataset)
seq_batch_sampler = BatchSampler(seq_sampler, batch_size=4, drop_last=False)
seq_loader = DataLoader(dataset, batch_sampler=seq_batch_sampler)
for i, batch in enumerate(seq_loader): print(f"Batch {i}: {batch.tolist()}")

print("\n=== 2) RandomSampler + BatchSampler ===")
rnd_sampler = RandomSampler(dataset)
rnd_batch_sampler = BatchSampler(rnd_sampler, batch_size=4, drop_last=True)
rnd_loader = DataLoader(dataset, batch_sampler=rnd_batch_sampler)
for i, batch in enumerate(rnd_loader): print(f"Batch {i}: {batch.tolist()}")

输出如下

=== 1) SequentialSampler + BatchSampler ===
Batch 0: [0, 1, 2, 3]
Batch 1: [4, 5, 6, 7]
Batch 2: [8, 9, 10, 11]
Batch 3: [12, 13, 14, 15]

=== 2) RandomSampler + BatchSampler ===
Batch 0: [5, 7, 12, 15]
Batch 1: [0, 8, 14, 13]
Batch 2: [9, 6, 11, 10]
Batch 3: [2, 1, 4, 3]

5. 执行逻辑

从上面内容我们可以总结一下,首先 Sampler 和 BatchSampler 是一种生成器,Sampler 生成单个样本索引,BatchSampler 生成一组样本索引列表。Sampler 用来传入到 BatchSampler 中,BatchSampler 生成的索引列表再传入到 DataLoader 中。

数据加载关系图

Sampler → BatchSampler → DataLoader → Dataset

下面举个例子,假设有 12 个样本,Batch Size = 4

Dataset = [D0, D1, D2, D3, D4, D5, D6, D7, D8, D9, D10, D11]
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]  # `SequentialSampler` 的返回
# `batch_size=4` 时,`BatchSampler` 会把这些索引按 4 个一组打包
[ [0,1,2,3],
  [4,5,6,7],
  [8,9,10,11] ]

最后得到索引之后,DataLoader 接收到每个索引(例如第一个 batch [0,1,2,3]),会调用 Dataset.__getitem__ 取出对应数据:

batch_1 = [D0, D1, D2, D3]
batch_2 = [D4, D5, D6, D7]
batch_3 = [D8, D9, D10, D11]

总结一下

对象 返回内容 示例输出 用途
Sampler 单个样本索引 0, 1, 2, … 控制样本顺序(顺序 / 随机)
BatchSampler 索引列表(一个 batch) [0,1,2,3] 按 batch 封装索引
DataLoader 实际数据 batch [D0,D1,D2,D3] 用于训练/推理循环

6. 进阶

除了上面的 SamplerDataset,pytorch 还有一个核心的组件叫 collate_fn,它的作用是把 Dataset 返回的多个样本数据(一个 batch)合并成一个整体的 batch 数据。默认情况下,DataLoader 会使用一个默认的 collate_fn,它会把相同类型的数据(例如张量、列表、字典等)按维度拼接在一起形成一个 batch。具体执行流程如下

Dataset  --(索引)-->  Sampler/BatchSampler  --(索引列表)-->  DataLoader
                                                   |
                                           取到样本对象列表
                                                   v
                                         collate_fn / data_collator
                                       (对齐、padding、拼 batch、造 labels)
                                                   v
                                            批量张量 batch (dict/tensor)

其中,DataLoader 负责返回样本列表,但是返回之前会调用 collate_fn 函数对样本进行处理,默认的 collate_fn 规则大概是类似调用 torch.stack,能直接拼凑的就拼凑,不能直接拼凑就会报错,因此如果我们返回的样本维度不一致,或者需要对齐、padding、造 labels 等等,就需要自定义 collate_fn 函数。

下面给一个最简单的例子,看完就知道 collate_fn 是什么了

import torch
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence

class ToySeqDataset(Dataset):
    def __init__(self):
        self.data = [
            ([1, 2, 3], 0),
            ([4, 5], 1),
            ([6, 7, 8, 9], 0),
            ([10], 1),
        ]
    def __len__(self):
        return len(self.data)
    def __getitem__(self, i):
        x, y = self.data[i]
        return {"x": torch.tensor(x, dtype=torch.long),
                "y": torch.tensor(y, dtype=torch.long)}

def collate_fn(batch):
    xs = [b["x"] for b in batch]                   
    ys = torch.stack([b["y"] for b in batch])       
    X = pad_sequence(xs, batch_first=True, padding_value=0)
    return {"x": X, "y": ys}

if __name__ == "__main__":
    ds = ToySeqDataset()
    # 这里不调用自定义的 collate_fn 会报错
    good_loader = DataLoader(ds, batch_size=2, shuffle=False, collate_fn=collate_fn)
    for step, batch in enumerate(good_loader, 1):
        print(f"\nStep {step} \n x (padded):\n", batch["x"])
        print("y:", batch["y"])

这是一个长度不一的序列数据集,我们用 collate_fn 做了 padding 和拼 batch,让他等长,输出如下

Step 1 
 x (padded):
 tensor([[1, 2, 3],
        [4, 5, 0]])
y: tensor([0, 1])

Step 2 
 x (padded):
 tensor([[ 6,  7,  8,  9],
        [10,  0,  0,  0]])
y: tensor([0, 1])
8月 26, 2025