Huggingface Core Techniques (1): LengthGroupedSampler
Word count: ~4k characters | Reading time ≈ 20 min

The core idea of LengthGroupedSampler is length grouping: it produces an index ordering such that the samples inside each DataLoader batch have lengths as close to each other as possible, which reduces padding and improves training efficiency.
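For context, you rarely construct this sampler by hand: Trainer switches to it when group_by_length=True is set in TrainingArguments. A minimal sketch (model and train_dataset are assumed to be an already-prepared model and tokenized dataset; output_dir is a placeholder):

from transformers import Trainer, TrainingArguments

args = TrainingArguments(
    output_dir="./out",               # placeholder path
    per_device_train_batch_size=16,
    group_by_length=True,             # Trainer then uses LengthGroupedSampler for training
    length_column_name="length",      # optional precomputed length column; otherwise lengths are inferred from input_ids
)
trainer = Trainer(model=model, args=args, train_dataset=train_dataset)  # model / train_dataset defined elsewhere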

Without further ado, here is the source code:

import torch

def get_length_grouped_indices(lengths, batch_size, mega_batch_mult=None, generator=None):
    # Default for mega_batch_mult: 50 or the number to get 4 megabatches, whichever is smaller.
    if mega_batch_mult is None:
        mega_batch_mult = min(len(lengths) // (batch_size * 4), 50)
        # Just in case, for tiny datasets
        if mega_batch_mult == 0:
            mega_batch_mult = 1

    # We need to use torch for the random part as a distributed sampler will set the random seed for torch.
    indices = torch.randperm(len(lengths), generator=generator)
    megabatch_size = mega_batch_mult * batch_size
    megabatches = [indices[i : i + megabatch_size].tolist() for i in range(0, len(lengths), megabatch_size)]
    megabatches = [sorted(megabatch, key=lambda i: lengths[i], reverse=True) for megabatch in megabatches]

    # The rest is to get the biggest batch first.
    # Since each megabatch is sorted by descending length, the longest element is the first
    megabatch_maximums = [lengths[megabatch[0]] for megabatch in megabatches]
    max_idx = torch.argmax(torch.tensor(megabatch_maximums)).item()
    # Switch to put the longest element in first position
    megabatches[0][0], megabatches[max_idx][0] = megabatches[max_idx][0], megabatches[0][0]

    return [i for megabatch in megabatches for i in megabatch]

Parameter        Meaning
lengths          Length of each sample (usually the number of tokens after tokenization).
batch_size       Number of samples in one batch.
mega_batch_mult  Mega-batch multiplier; controls the granularity of the "sort within a local window" step.
generator        Random number generator (keeps the shuffle identical across all ranks in distributed training).
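For reference, lengths is typically just the token count of each example after tokenization. A minimal sketch, assuming an arbitrary checkpoint and toy texts:

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")   # any checkpoint works here
texts = ["a short sentence", "a noticeably longer sentence with many more tokens", "hi"]
lengths = [len(tokenizer(t)["input_ids"]) for t in texts]
print(lengths)   # e.g. [5, 10, 3] -- the exact numbers depend on the tokenizer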

1. The LengthGroupedSampler class

The default for mega_batch_mult is min(len(lengths) // (batch_size * 4), 50), and this rule is hard-coded: each mega-batch contains mega_batch_mult * batch_size samples, so the default multiplier is either 50 or the value that splits the data into roughly 4 mega-batches, whichever is smaller, clamped to a minimum of 1 for tiny datasets.
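A quick illustration of how that default plays out (the dataset sizes below are made up):

def default_mega_batch_mult(num_samples, batch_size):
    # mirrors the default used in get_length_grouped_indices
    m = min(num_samples // (batch_size * 4), 50)
    return max(m, 1)

print(default_mega_batch_mult(100, 8))      # 100 // 32 = 3   -> roughly 4 mega-batches
print(default_mega_batch_mult(100_000, 8))  # capped at 50    -> mega-batches of 400 samples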

2. Workflow

  1. Random shuffle: first shuffle all sample indices randomly
  2. Grouping: split them into mega-batches (local blocks)
  3. Local sort: within each block, sort by length in descending order (longest first)
  4. Longest first: swap so that the batch containing the longest sample comes first (to trigger a potential OOM as early as possible)
  5. Flatten: finally flatten all blocks into a single index list

A simple example

Suppose we have the following setup:

lengths = [3, 8, 4, 10, 9, 5, 2, 7]
batch_size = 2
mega_batch_mult = 2

The processing goes as follows:

Original indices: [0,1,2,3,4,5,6,7]
Random shuffle → [3,0,5,2,1,7,4,6]
mega_batch_size = 4
Split into 2 mega-batches:
  MB1 = [3,0,5,2] → lengths [10,3,5,4] → sorted [3,5,2,0]
  MB2 = [1,7,4,6] → lengths [8,7,9,2] → sorted [4,1,7,6]
Concatenate → [3,5,2,0,4,1,7,6]

The longest sample is at index 3 (length 10), which is already at the very front.
This step is handled by the line megabatches[0][0], megabatches[max_idx][0] = megabatches[max_idx][0], megabatches[0][0]; in this trace the swap is a no-op.
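The local sort is easy to check by hand. A minimal sketch that replays the trace above (the shuffled order is copied from the trace rather than generated):

lengths = [3, 8, 4, 10, 9, 5, 2, 7]
batch_size, mega_batch_mult = 2, 2

shuffled = [3, 0, 5, 2, 1, 7, 4, 6]   # permutation assumed from the trace above
megabatch_size = mega_batch_mult * batch_size
megabatches = [shuffled[i:i + megabatch_size] for i in range(0, len(shuffled), megabatch_size)]
# sort each mega-batch by length, longest first
megabatches = [sorted(mb, key=lambda i: lengths[i], reverse=True) for mb in megabatches]
print(megabatches)   # [[3, 5, 2, 0], [4, 1, 7, 6]]

# step 4: the globally longest sample (index 3, length 10) is already in the first
# mega-batch, so the swap leaves everything unchanged here
maxima = [lengths[mb[0]] for mb in megabatches]   # [10, 9]
print(maxima.index(max(maxima)))                  # 0 -> swapping position 0 with itself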

3. A complete example

This example walks through the whole sampling and batching process end to end.

import torch
from torch.utils.data import Dataset, DataLoader, Sampler
from transformers import BatchEncoding
from typing import List, Optional

# Define an example dataset
class ExampleDataset(Dataset):
    def __init__(self, data):
        self.data = data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]

def visualize_grouping_process(lengths, indices, megabatches, sorted_megabatches, max_idx, step="", prefix=""):
    """
    Helper that visualizes the grouping process.

    Args:
        lengths: list of sequence lengths
        indices: current list of indices
        megabatches: list of mega-batches
        sorted_megabatches: list of mega-batches after sorting
        max_idx: index of the mega-batch containing the longest sequence
        step: description of the current step
        prefix: output prefix (used for indentation)
    """
    if not step:
        return
        
    print(f"\n{prefix}>>> Visualization step: {step}")
    if step == "random_permutation":
        print(f"{prefix}Indices after random shuffle: {indices.tolist()}")
    
    elif step == "create_megabatches":
        print(f"{prefix}Mega-batches after splitting:")
        for i, megabatch in enumerate(megabatches):
            print(f"{prefix}Mega-batch {i+1}: {megabatch}")
            print(f"{prefix}Corresponding lengths: {[lengths[idx] for idx in megabatch]}")
    
    elif step == "sort_megabatches":
        print(f"{prefix}Mega-batches after sorting:")
        for i, megabatch in enumerate(sorted_megabatches):
            print(f"{prefix}Mega-batch {i+1}: {megabatch}")
            print(f"{prefix}Corresponding lengths: {[lengths[idx] for idx in megabatch]}")
    
    elif step == "adjust_longest":
        print(f"{prefix}The longest sequence is in Mega-batch {max_idx + 1}")
        print(f"{prefix}Mega-batches after adjustment:")
        for i, megabatch in enumerate(sorted_megabatches):
            print(f"{prefix}Mega-batch {i+1}: {megabatch}")
            print(f"{prefix}Corresponding lengths: {[lengths[idx] for idx in megabatch]}")

def get_length_grouped_indices(lengths, batch_size, mega_batch_mult=None, generator=None, visualize=False):
    """
    Return a list of indices so that each slice of `batch_size` consecutive indices correspond to elements of similar
    lengths. To do this, the indices are:

    - randomly permuted
    - grouped in mega-batches of size `mega_batch_mult * batch_size`
    - sorted by length in each mega-batch

    The result is the concatenation of all mega-batches, with the batch of `batch_size` containing the element of
    maximum length placed first, so that an OOM happens sooner rather than later.
    """
    # Default for mega_batch_mult: 50 or the number to get 4 megabatches, whichever is smaller.
    if mega_batch_mult is None:
        mega_batch_mult = min(len(lengths) // (batch_size * 4), 50)
        # Just in case, for tiny datasets
        if mega_batch_mult == 0:
            mega_batch_mult = 1

    # We need to use torch for the random part as a distributed sampler will set the random seed for torch.
    indices = torch.randperm(len(lengths), generator=generator)
    megabatch_size = mega_batch_mult * batch_size
    megabatches = [indices[i : i + megabatch_size].tolist() for i in range(0, len(lengths), megabatch_size)]
    megabatches = [sorted(megabatch, key=lambda i: lengths[i], reverse=True) for megabatch in megabatches]

    # The rest is to get the biggest batch first.
    # Since each megabatch is sorted by descending length, the longest element is the first
    megabatch_maximums = [lengths[megabatch[0]] for megabatch in megabatches]
    max_idx = torch.argmax(torch.tensor(megabatch_maximums)).item()
    # Switch to put the longest element in first position
    megabatches[0][0], megabatches[max_idx][0] = megabatches[max_idx][0], megabatches[0][0]

    return [i for megabatch in megabatches for i in megabatch]

# Define the LengthGroupedSampler class
class LengthGroupedSampler(Sampler):
    def __init__(self, batch_size: int, dataset: Optional[Dataset] = None, lengths: Optional[List[int]] = None, model_input_name: Optional[str] = None, generator=None):
        if dataset is None and lengths is None:
            raise ValueError("One of dataset and lengths must be provided.")

        self.batch_size = batch_size
        if lengths is None:
            model_input_name = model_input_name if model_input_name is not None else "input_ids"
            if not (isinstance(dataset[0], dict) or isinstance(dataset[0], BatchEncoding)) or model_input_name not in dataset[0]:
                raise ValueError(f"Can only automatically infer lengths for datasets whose items are dictionaries with an '{model_input_name}' key.")
            lengths = [len(feature[model_input_name]) for feature in dataset]
        elif isinstance(lengths, torch.Tensor):
            lengths = lengths.tolist()

        self.lengths = lengths
        self.generator = generator

    def __len__(self):
        return len(self.lengths)

    def __iter__(self):
        indices = get_length_grouped_indices(self.lengths, self.batch_size, generator=self.generator)
        return iter(indices)

# Custom collate_fn
def collate_fn(batch):
    # Find the longest sequence length in the current batch
    max_length = max(len(item["input_ids"]) for item in batch)
    
    # Pad every sequence in the batch to the max length
    padded_batch = []
    for item in batch:
        input_ids = item["input_ids"]
        padding_length = max_length - len(input_ids)
        padded_input_ids = input_ids + [0] * padding_length
        padded_batch.append({"input_ids": padded_input_ids})
    
    # For better visualization, also print the original lengths
    print("\nCurrent batch info:")
    print(f"Max length: {max_length}")
    for i, (orig, padded) in enumerate(zip(batch, padded_batch)):
        print(f"Sequence {i+1}: original length {len(orig['input_ids'])}, "
              f"original data: {orig['input_ids']}, "
              f"padded: {padded['input_ids']}")
    
    return padded_batch

def visualize_length_grouped_sampling(data, batch_size=3, mega_batch_mult=2, generator=None):
    """
    Visualize the length-grouped sampling process.

    Args:
        data: list of samples, each a dict containing 'input_ids'
        batch_size: batch size
        mega_batch_mult: mega-batch multiplier
        generator: random number generator used for the shuffle
    """
    print("\n" + "="*80)
    print("Length-Grouped Sampling Visualization")
    print("="*80)
    
    # 1. Show the original data
    print("\n1. Original data:")
    print("-"*50)
    lengths = [len(item["input_ids"]) for item in data]
    for idx, (item, length) in enumerate(zip(data, lengths)):
        print(f"Index: {idx:<2} | Length: {length:<2} | Data: {item['input_ids']}")
    
    # 2. Random shuffle
    print("\nGenerator state before shuffling:", generator.get_state())  # print the generator state
    indices = torch.randperm(len(lengths), generator=generator).tolist()
    print("Indices after random shuffle:", indices)
    print("Generator state after shuffling:", generator.get_state())  # print the generator state
    
    # 3. Split into mega-batches
    megabatch_size = mega_batch_mult * batch_size
    megabatches = [indices[i : i + megabatch_size] for i in range(0, len(lengths), megabatch_size)]
    
    print(f"\n3. Split into mega-batches (size={megabatch_size}):")
    for i, megabatch in enumerate(megabatches):
        print(f"\nMega-batch {i+1} (before sorting):")
        print("-"*50)
        for idx in megabatch:
            print(f"Index: {idx:<2} | Length: {lengths[idx]:<2} | Data: {data[idx]['input_ids']}")
    
    # 4. Sort within each mega-batch
    sorted_megabatches = [sorted(megabatch, key=lambda i: lengths[i], reverse=True) 
                         for megabatch in megabatches]
    
    print("\n4. Sort each mega-batch internally by length:")
    for i, megabatch in enumerate(sorted_megabatches):
        print(f"\nMega-batch {i+1} (after sorting):")
        print("-"*50)
        for idx in megabatch:
            print(f"Index: {idx:<2} | Length: {lengths[idx]:<2} | Data: {data[idx]['input_ids']}")
    
    # 5. Find the longest sequence and adjust its position
    megabatch_maximums = [lengths[megabatch[0]] for megabatch in sorted_megabatches]
    max_idx = torch.argmax(torch.tensor(megabatch_maximums)).item()
    
    print("\n5. Longest-sequence adjustment:")
    print("-"*50)
    print(f"The longest sequence is in Mega-batch {max_idx + 1}")
    
    if max_idx > 0:
        sorted_megabatches[0][0], sorted_megabatches[max_idx][0] = \
            sorted_megabatches[max_idx][0], sorted_megabatches[0][0]
        
        print("\nMega-batches after adjustment:")
        for i, megabatch in enumerate(sorted_megabatches):
            print(f"\nMega-batch {i+1}:")
            print("-"*50)
            for idx in megabatch:
                print(f"Index: {idx:<2} | Length: {lengths[idx]:<2} | Data: {data[idx]['input_ids']}")
    
    # 6. Final flattened result
    final_indices = [i for megabatch in sorted_megabatches for i in megabatch]
    print("\n6. Final sampling order:")
    print("-"*50)
    for i, idx in enumerate(final_indices):
        batch_num = i // batch_size + 1
        if i % batch_size == 0:
            print(f"\nBatch {batch_num}:")
        print(f"Index: {idx:<2} | Length: {lengths[idx]:<2} | Data: {data[idx]['input_ids']}")

if __name__ == "__main__":
    # Create the example dataset
    data = [
        {"input_ids": [1, 2, 3]},                    # length 3
        {"input_ids": [1, 2]},                       # length 2
        {"input_ids": [1, 2, 3, 4, 5]},             # length 5
        {"input_ids": [1]},                          # length 1
        {"input_ids": [1, 2, 3, 4]},                # length 4
        {"input_ids": [1, 2, 3, 4, 5, 6]},          # length 6
        {"input_ids": [1, 2, 3, 4, 5, 6, 7]},       # length 7
        {"input_ids": [1, 2, 3, 4, 5, 6, 7, 8]},    # length 8
        {"input_ids": [1, 2, 3]},                    # length 3
        {"input_ids": [1, 2, 3, 4]},                # length 4
        {"input_ids": [1]},                          # length 1
        {"input_ids": [1, 2, 3, 4, 5]},             # length 5
    ]

    batch_size = 3

    # Set the random seed so the results are reproducible
    generator = torch.Generator()
    generator.manual_seed(42)  # set the initial seed

    # First run the visualization function
    print("\n=== Visualizing the length-grouped sampling process ===")
    visualize_length_grouped_sampling(data, batch_size=batch_size, mega_batch_mult=2, generator=generator)

    print("\nRandom generator state:", generator.get_state())  # print the current generator state

    # Reset the generator with the same seed
    generator = torch.Generator()
    generator.manual_seed(42)  # use the same seed

    print("\n=== Actual run results ===")
    # Create the dataset and DataLoader
    dataset = ExampleDataset(data)
    sampler = LengthGroupedSampler(batch_size=batch_size, dataset=dataset, 
                                 model_input_name="input_ids", generator=generator)
    dataloader = DataLoader(dataset, batch_size=batch_size, 
                          sampler=sampler, collate_fn=collate_fn)

    # Print the contents of each batch
    print("\nActual batch output:")
    for i, batch in enumerate(dataloader, 1):
        print(f"\nBatch {i}:")
        for item in batch:
            print(f"Length: {len(item['input_ids'])}, Data: {item['input_ids']}")

Output

=== Visualizing the length-grouped sampling process ===

================================================================================
Length-Grouped Sampling Visualization
================================================================================

1. Original data:
--------------------------------------------------
Index: 0  | Length: 3  | Data: [1, 2, 3]
Index: 1  | Length: 2  | Data: [1, 2]
Index: 2  | Length: 5  | Data: [1, 2, 3, 4, 5]
Index: 3  | Length: 1  | Data: [1]
Index: 4  | Length: 4  | Data: [1, 2, 3, 4]
Index: 5  | Length: 6  | Data: [1, 2, 3, 4, 5, 6]
Index: 6  | Length: 7  | Data: [1, 2, 3, 4, 5, 6, 7]
Index: 7  | Length: 8  | Data: [1, 2, 3, 4, 5, 6, 7, 8]
Index: 8  | Length: 3  | Data: [1, 2, 3]
Index: 9  | Length: 4  | Data: [1, 2, 3, 4]
Index: 10 | Length: 1  | Data: [1]
Index: 11 | Length: 5  | Data: [1, 2, 3, 4, 5]

Generator state before shuffling: tensor([42,  0,  0,  ...,  0,  0,  0], dtype=torch.uint8)
Indices after random shuffle: [6, 8, 1, 7, 0, 2, 10, 11, 4, 3, 5, 9]
Generator state after shuffling: tensor([42,  0,  0,  ...,  0,  0,  0], dtype=torch.uint8)

3. Split into mega-batches (size=6):

Mega-batch 1 (before sorting):
--------------------------------------------------
Index: 6  | Length: 7  | Data: [1, 2, 3, 4, 5, 6, 7]
Index: 8  | Length: 3  | Data: [1, 2, 3]
Index: 1  | Length: 2  | Data: [1, 2]
Index: 7  | Length: 8  | Data: [1, 2, 3, 4, 5, 6, 7, 8]
Index: 0  | Length: 3  | Data: [1, 2, 3]
Index: 2  | Length: 5  | Data: [1, 2, 3, 4, 5]

Mega-batch 2 (before sorting):
--------------------------------------------------
Index: 10 | Length: 1  | Data: [1]
Index: 11 | Length: 5  | Data: [1, 2, 3, 4, 5]
Index: 4  | Length: 4  | Data: [1, 2, 3, 4]
Index: 3  | Length: 1  | Data: [1]
Index: 5  | Length: 6  | Data: [1, 2, 3, 4, 5, 6]
Index: 9  | Length: 4  | Data: [1, 2, 3, 4]

4. Sort each mega-batch internally by length:

Mega-batch 1 (after sorting):
--------------------------------------------------
Index: 7  | Length: 8  | Data: [1, 2, 3, 4, 5, 6, 7, 8]
Index: 6  | Length: 7  | Data: [1, 2, 3, 4, 5, 6, 7]
Index: 2  | Length: 5  | Data: [1, 2, 3, 4, 5]
Index: 8  | Length: 3  | Data: [1, 2, 3]
Index: 0  | Length: 3  | Data: [1, 2, 3]
Index: 1  | Length: 2  | Data: [1, 2]

Mega-batch 2 (after sorting):
--------------------------------------------------
Index: 5  | Length: 6  | Data: [1, 2, 3, 4, 5, 6]
Index: 11 | Length: 5  | Data: [1, 2, 3, 4, 5]
Index: 4  | Length: 4  | Data: [1, 2, 3, 4]
Index: 9  | Length: 4  | Data: [1, 2, 3, 4]
Index: 10 | Length: 1  | Data: [1]
Index: 3  | Length: 1  | Data: [1]

5. Longest-sequence adjustment:
--------------------------------------------------
The longest sequence is in Mega-batch 1

6. Final sampling order:
--------------------------------------------------

Batch 1:
Index: 7  | Length: 8  | Data: [1, 2, 3, 4, 5, 6, 7, 8]
Index: 6  | Length: 7  | Data: [1, 2, 3, 4, 5, 6, 7]
Index: 2  | Length: 5  | Data: [1, 2, 3, 4, 5]

Batch 2:
Index: 8  | Length: 3  | Data: [1, 2, 3]
Index: 0  | Length: 3  | Data: [1, 2, 3]
Index: 1  | Length: 2  | Data: [1, 2]

Batch 3:
Index: 5  | Length: 6  | Data: [1, 2, 3, 4, 5, 6]
Index: 11 | Length: 5  | Data: [1, 2, 3, 4, 5]
Index: 4  | Length: 4  | Data: [1, 2, 3, 4]

Batch 4:
Index: 9  | Length: 4  | Data: [1, 2, 3, 4]
Index: 10 | Length: 1  | Data: [1]
Index: 3  | Length: 1  | Data: [1]

Random generator state: tensor([42,  0,  0,  ...,  0,  0,  0], dtype=torch.uint8)

=== Actual run results ===

Actual batch output:

Current batch info:
Max length: 8
Sequence 1: original length 8, original data: [1, 2, 3, 4, 5, 6, 7, 8], padded: [1, 2, 3, 4, 5, 6, 7, 8]
Sequence 2: original length 3, original data: [1, 2, 3], padded: [1, 2, 3, 0, 0, 0, 0, 0]
Sequence 3: original length 2, original data: [1, 2], padded: [1, 2, 0, 0, 0, 0, 0, 0]

Batch 1:
Length: 8, Data: [1, 2, 3, 4, 5, 6, 7, 8]
Length: 8, Data: [1, 2, 3, 0, 0, 0, 0, 0]
Length: 8, Data: [1, 2, 0, 0, 0, 0, 0, 0]

Current batch info:
Max length: 7
Sequence 1: original length 7, original data: [1, 2, 3, 4, 5, 6, 7], padded: [1, 2, 3, 4, 5, 6, 7]
Sequence 2: original length 5, original data: [1, 2, 3, 4, 5], padded: [1, 2, 3, 4, 5, 0, 0]
Sequence 3: original length 3, original data: [1, 2, 3], padded: [1, 2, 3, 0, 0, 0, 0]

Batch 2:
Length: 7, Data: [1, 2, 3, 4, 5, 6, 7]
Length: 7, Data: [1, 2, 3, 4, 5, 0, 0]
Length: 7, Data: [1, 2, 3, 0, 0, 0, 0]

Current batch info:
Max length: 5
Sequence 1: original length 5, original data: [1, 2, 3, 4, 5], padded: [1, 2, 3, 4, 5]
Sequence 2: original length 4, original data: [1, 2, 3, 4], padded: [1, 2, 3, 4, 0]
Sequence 3: original length 1, original data: [1], padded: [1, 0, 0, 0, 0]

Batch 3:
Length: 5, Data: [1, 2, 3, 4, 5]
Length: 5, Data: [1, 2, 3, 4, 0]
Length: 5, Data: [1, 0, 0, 0, 0]

Current batch info:
Max length: 6
Sequence 1: original length 6, original data: [1, 2, 3, 4, 5, 6], padded: [1, 2, 3, 4, 5, 6]
Sequence 2: original length 4, original data: [1, 2, 3, 4], padded: [1, 2, 3, 4, 0, 0]
Sequence 3: original length 1, original data: [1], padded: [1, 0, 0, 0, 0, 0]

Batch 4:
Length: 6, Data: [1, 2, 3, 4, 5, 6]
Length: 6, Data: [1, 2, 3, 4, 0, 0]
Length: 6, Data: [1, 0, 0, 0, 0, 0]

A mega-batch is simply a block of data larger than a regular batch. A concrete example:

Suppose we have the following settings:

batch_size = 3           # regular batch size
mega_batch_mult = 2      # mega-batch multiplier
data_size = 12           # total number of samples

Then:

  1. A regular batch contains 3 samples
  2. A mega-batch contains 6 samples (batch_size * mega_batch_mult = 3 * 2 = 6), as the sketch below restates
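The same bookkeeping in a few lines of Python, just restating the numbers above:

import math

batch_size = 3
mega_batch_mult = 2
data_size = 12

megabatch_size = batch_size * mega_batch_mult             # 6 samples per mega-batch
num_megabatches = math.ceil(data_size / megabatch_size)   # 2 mega-batches
num_batches = math.ceil(data_size / batch_size)           # 4 batches
print(megabatch_size, num_megabatches, num_batches)       # 6 2 4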

Data processing flow:

Original indices (12 samples):
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]

↓ Random shuffle

Shuffled order:
[5, 2, 8, 1, 11, 4, 7, 3, 9, 0, 6, 10]

↓ Split into mega-batches (6 samples each)

Mega-batch 1: [5, 2, 8, 1, 11, 4]
Mega-batch 2: [7, 3, 9, 0, 6, 10]

↓ Sort each mega-batch internally by length

Sorted Mega-batch 1: [11, 8, 5, 4, 2, 1]  (descending by sequence length)
Sorted Mega-batch 2: [9, 7, 6, 3, 10, 0]  (descending by sequence length)

↓ Finally split into batches of batch_size=3

Batch 1: [11, 8, 5]
Batch 2: [4, 2, 1]
Batch 3: [9, 7, 6]
Batch 4: [3, 10, 0]
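To close, a rough sketch of why this matters in practice: it compares how many padding tokens a plain random order and a length-grouped order produce on synthetic lengths. It reuses the get_length_grouped_indices function shown earlier in this post, and the exact numbers are illustrative (they depend on the random seed):

import torch

torch.manual_seed(0)
lengths = torch.randint(1, 512, (10_000,)).tolist()   # synthetic sequence lengths
batch_size = 8

def padding_waste(order):
    # total padding tokens added across all batches for a given sample order
    waste = 0
    for i in range(0, len(order), batch_size):
        batch = [lengths[j] for j in order[i:i + batch_size]]
        waste += max(batch) * len(batch) - sum(batch)
    return waste

random_order = torch.randperm(len(lengths)).tolist()
grouped_order = get_length_grouped_indices(lengths, batch_size)  # defined earlier in this post

print("padding tokens, random batching :", padding_waste(random_order))
print("padding tokens, length grouping :", padding_waste(grouped_order))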