Huggingface 核心技巧(一): LengthGroupedSampler
本文字数:4.3k 字 | 阅读时长 ≈ 21 min

Huggingface 核心技巧(一): LengthGroupedSampler

本文字数:4.3k 字 | 阅读时长 ≈ 21 min

Huggingface 核心技巧(一): LengthGroupedSampler

1. 引言

在处理NLP任务时,我们经常遇到不同长度的序列。为了提高训练效率,我们需要合理地组织这些序列。Huggingface的LengthGroupedSampler就是为解决这个问题而设计的。

2. 为什么需要长度分组?

在处理变长序列时,我们面临两个主要问题:

  1. 内存效率:如果batch中的序列长度差异太大,短序列需要大量padding,造成内存浪费
  2. 训练效果:完全按长度排序会降低随机性,可能影响模型训练效果

LengthGroupedSampler通过"mega-batch"机制巧妙地平衡了这两个问题。

3. Mega-Batch机制详解

3.1 什么是Mega-Batch?

Mega-Batch是比普通batch更大的数据块。例如:

3.2 工作流程

  1. 随机打乱:首先对所有数据随机打乱
  2. 分组:将打乱后的数据分成多个mega-batch
  3. 局部排序:在每个mega-batch内部按序列长度排序
  4. 最长序列优先:将包含最长序列的batch放在最前面

3.3 示例演示

假设我们有以下数据:

data = [
    {"input_ids": [1, 2, 3]},         # 长度 3
    {"input_ids": [1, 2]},            # 长度 2
    {"input_ids": [1, 2, 3, 4, 5]},   # 长度 5
    {"input_ids": [1]},               # 长度 1
]

# 设置参数
batch_size = 2
mega_batch_mult = 2

处理流程:

1. 随机打乱:
   [3, 1, 4, 2]

2. 分成mega-batches (size=4):
   Mega-batch: [3, 1, 4, 2]

3. 按长度排序:
   [4, 3, 1, 2]  # 长度: 5,3,2,1

4. 最终batches:
   Batch 1: [4, 3]  # 长度: 5,3
   Batch 2: [1, 2]  # 长度: 2,1

4. 代码实现

4.1 核心采样器实现

class LengthGroupedSampler(Sampler):
    def __init__(self, batch_size: int, dataset: Optional[Dataset] = None, 
                 lengths: Optional[List[int]] = None, 
                 model_input_name: Optional[str] = None):
        if dataset is None and lengths is None:
            raise ValueError("One of dataset and lengths must be provided.")
        
        self.batch_size = batch_size
        self.lengths = self._get_lengths(dataset, lengths, model_input_name)

    def __iter__(self):
        indices = get_length_grouped_indices(self.lengths, self.batch_size)
        return iter(indices)

4.2 使用示例

# 创建数据加载器
dataset = YourDataset(...)
sampler = LengthGroupedSampler(batch_size=32, dataset=dataset)
dataloader = DataLoader(
    dataset, 
    batch_size=32, 
    sampler=sampler,
    collate_fn=collate_fn
)

# 训练循环
for batch in dataloader:
    # 处理batch
    ...

5. 优势与注意事项

5.1 主要优势

  1. 提高计算效率:减少padding,节省内存
  2. 保持随机性:通过mega-batch机制平衡排序和随机性
  3. 灵活可调:可通过mega_batch_mult参数调整分组粒度

5.2 使用建议

  1. 根据数据集大小调整mega_batch_mult
  2. 监控训练过程中的内存使用情况
  3. 在大规模数据集上效果更明显

6. 总结

LengthGroupedSampler是一个优秀的数据采样工具,通过巧妙的设计在效率和训练效果之间取得了很好的平衡。合理使用它可以显著提升训练效率,特别是在处理变长序列的NLP任务中。

参考资料

  1. Huggingface Transformers文档
  2. PyTorch DataLoader文档
import torch
from torch.utils.data import Dataset, DataLoader, Sampler
from transformers import BatchEncoding
from typing import List, Optional

# 定义示例数据集
class ExampleDataset(Dataset):
    def __init__(self, data):
        self.data = data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]

def visualize_grouping_process(lengths, indices, megabatches, sorted_megabatches, max_idx, step="", prefix=""):
    """
    可视化分组过程的辅助函数
    
    Args:
        lengths: 序列长度列表
        indices: 当前的索引列表
        megabatches: mega-batch列表
        sorted_megabatches: 排序后的mega-batch列表
        max_idx: 最长序列所在的mega-batch索引
        step: 当前步骤的描述
        prefix: 输出的前缀(用于缩进)
    """
    if not step:
        return
        
    print(f"\n{prefix}>>> 可视化步骤: {step}")
    if step == "random_permutation":
        print(f"{prefix}随机打乱后的索引: {indices.tolist()}")
    
    elif step == "create_megabatches":
        print(f"{prefix}划分的mega-batches:")
        for i, megabatch in enumerate(megabatches):
            print(f"{prefix}Mega-batch {i+1}: {megabatch}")
            print(f"{prefix}对应的长度: {[lengths[idx] for idx in megabatch]}")
    
    elif step == "sort_megabatches":
        print(f"{prefix}排序后的mega-batches:")
        for i, megabatch in enumerate(sorted_megabatches):
            print(f"{prefix}Mega-batch {i+1}: {megabatch}")
            print(f"{prefix}对应的长度: {[lengths[idx] for idx in megabatch]}")
    
    elif step == "adjust_longest":
        print(f"{prefix}最长序列在 Mega-batch {max_idx + 1}")
        print(f"{prefix}调整后的mega-batches:")
        for i, megabatch in enumerate(sorted_megabatches):
            print(f"{prefix}Mega-batch {i+1}: {megabatch}")
            print(f"{prefix}对应的长度: {[lengths[idx] for idx in megabatch]}")

def get_length_grouped_indices(lengths, batch_size, mega_batch_mult=None, generator=None, visualize=False):
    """
    Return a list of indices so that each slice of `batch_size` consecutive indices correspond to elements of similar
    lengths. To do this, the indices are:

    - randomly permuted
    - grouped in mega-batches of size `mega_batch_mult * batch_size`
    - sorted by length in each mega-batch

    The result is the concatenation of all mega-batches, with the batch of `batch_size` containing the element of
    maximum length placed first, so that an OOM happens sooner rather than later.
    """
    # Default for mega_batch_mult: 50 or the number to get 4 megabatches, whichever is smaller.
    if mega_batch_mult is None:
        mega_batch_mult = min(len(lengths) // (batch_size * 4), 50)
        # Just in case, for tiny datasets
        if mega_batch_mult == 0:
            mega_batch_mult = 1

    # We need to use torch for the random part as a distributed sampler will set the random seed for torch.
    indices = torch.randperm(len(lengths), generator=generator)
    megabatch_size = mega_batch_mult * batch_size
    megabatches = [indices[i : i + megabatch_size].tolist() for i in range(0, len(lengths), megabatch_size)]
    megabatches = [sorted(megabatch, key=lambda i: lengths[i], reverse=True) for megabatch in megabatches]

    # The rest is to get the biggest batch first.
    # Since each megabatch is sorted by descending length, the longest element is the first
    megabatch_maximums = [lengths[megabatch[0]] for megabatch in megabatches]
    max_idx = torch.argmax(torch.tensor(megabatch_maximums)).item()
    # Switch to put the longest element in first position
    megabatches[0][0], megabatches[max_idx][0] = megabatches[max_idx][0], megabatches[0][0]

    return [i for megabatch in megabatches for i in megabatch]

# 定义 LengthGroupedSampler 类
class LengthGroupedSampler(Sampler):
    def __init__(self, batch_size: int, dataset: Optional[Dataset] = None, lengths: Optional[List[int]] = None, model_input_name: Optional[str] = None, generator=None):
        if dataset is None and lengths is None:
            raise ValueError("One of dataset and lengths must be provided.")

        self.batch_size = batch_size
        if lengths is None:
            model_input_name = model_input_name if model_input_name is not None else "input_ids"
            if not (isinstance(dataset[0], dict) or isinstance(dataset[0], BatchEncoding)) or model_input_name not in dataset[0]:
                raise ValueError("Can only automatically infer lengths for datasets whose items are dictionaries with an '{model_input_name}' key.")
            lengths = [len(feature[model_input_name]) for feature in dataset]
        elif isinstance(lengths, torch.Tensor):
            lengths = lengths.tolist()

        self.lengths = lengths
        self.generator = generator

    def __len__(self):
        return len(self.lengths)

    def __iter__(self):
        indices = get_length_grouped_indices(self.lengths, self.batch_size, generator=self.generator)
        return iter(indices)

# 自定义 collate_fn 函数
def collate_fn(batch):
    # 找出当前batch中最长的序列长度
    max_length = max(len(item["input_ids"]) for item in batch)
    
    # 对batch中的每个序列进行填充
    padded_batch = []
    for item in batch:
        input_ids = item["input_ids"]
        padding_length = max_length - len(input_ids)
        padded_input_ids = input_ids + [0] * padding_length
        padded_batch.append({"input_ids": padded_input_ids})
    
    # 为了更好的可视化,我们也打印一下原始长度
    print("\n当前batch信息:")
    print(f"最大长度: {max_length}")
    for i, (orig, padded) in enumerate(zip(batch, padded_batch)):
        print(f"序列 {i+1}: 原始长度 {len(orig['input_ids'])}, "
              f"原始数据: {orig['input_ids']}, "
              f"填充后: {padded['input_ids']}")
    
    return padded_batch

def visualize_length_grouped_sampling(data, batch_size=3, mega_batch_mult=2, generator=None):
    """
    可视化长度分组采样的过程
    
    Args:
        data: 数据列表,每个元素是包含 'input_ids' 的字典
        batch_size: batch大小
        mega_batch_mult: mega batch 倍数
    """
    print("\n" + "="*80)
    print("长度分组采样可视化")
    print("="*80)
    
    # 1. 显示原始数据
    print("\n1. 原始数据:")
    print("-"*50)
    lengths = [len(item["input_ids"]) for item in data]
    for idx, (item, length) in enumerate(zip(data, lengths)):
        print(f"索引: {idx:<2} | 长度: {length:<2} | 数据: {item['input_ids']}")
    
    # 2. 随机打乱
    print("\n随机打乱前生成器状态:", generator.get_state())  # 打印生成器状态
    indices = torch.randperm(len(lengths), generator=generator).tolist()
    print("随机打乱后的索引:", indices)
    print("随机打乱后生成器状态:", generator.get_state())  # 打印生成器状态
    
    # 3. 划分mega-batches
    megabatch_size = mega_batch_mult * batch_size
    megabatches = [indices[i : i + megabatch_size] for i in range(0, len(lengths), megabatch_size)]
    
    print(f"\n3. 划分成mega-batches (size={megabatch_size}):")
    for i, megabatch in enumerate(megabatches):
        print(f"\nMega-batch {i+1} (排序前):")
        print("-"*50)
        for idx in megabatch:
            print(f"索引: {idx:<2} | 长度: {lengths[idx]:<2} | 数据: {data[idx]['input_ids']}")
    
    # 4. 对每个mega-batch内部排序
    sorted_megabatches = [sorted(megabatch, key=lambda i: lengths[i], reverse=True) 
                         for megabatch in megabatches]
    
    print("\n4. 每个mega-batch内部按长度排序:")
    for i, megabatch in enumerate(sorted_megabatches):
        print(f"\nMega-batch {i+1} (排序后):")
        print("-"*50)
        for idx in megabatch:
            print(f"索引: {idx:<2} | 长度: {lengths[idx]:<2} | 数据: {data[idx]['input_ids']}")
    
    # 5. 找出最长序列并调整位置
    megabatch_maximums = [lengths[megabatch[0]] for megabatch in sorted_megabatches]
    max_idx = torch.argmax(torch.tensor(megabatch_maximums)).item()
    
    print("\n5. 最长序列调整:")
    print("-"*50)
    print(f"最长序列在 Mega-batch {max_idx + 1}")
    
    if max_idx > 0:
        sorted_megabatches[0][0], sorted_megabatches[max_idx][0] = \
            sorted_megabatches[max_idx][0], sorted_megabatches[0][0]
        
        print("\n调整后的mega-batches:")
        for i, megabatch in enumerate(sorted_megabatches):
            print(f"\nMega-batch {i+1}:")
            print("-"*50)
            for idx in megabatch:
                print(f"索引: {idx:<2} | 长度: {lengths[idx]:<2} | 数据: {data[idx]['input_ids']}")
    
    # 6. 最终展平的结果
    final_indices = [i for megabatch in sorted_megabatches for i in megabatch]
    print("\n6. 最终的采样顺序:")
    print("-"*50)
    for i, idx in enumerate(final_indices):
        batch_num = i // batch_size + 1
        if i % batch_size == 0:
            print(f"\nBatch {batch_num}:")
        print(f"索引: {idx:<2} | 长度: {lengths[idx]:<2} | 数据: {data[idx]['input_ids']}")

if __name__ == "__main__":
    # 创建示例数据集
    data = [
        {"input_ids": [1, 2, 3]},                    # 长度 3
        {"input_ids": [1, 2]},                       # 长度 2
        {"input_ids": [1, 2, 3, 4, 5]},             # 长度 5
        {"input_ids": [1]},                          # 长度 1
        {"input_ids": [1, 2, 3, 4]},                # 长度 4
        {"input_ids": [1, 2, 3, 4, 5, 6]},          # 长度 6
        {"input_ids": [1, 2, 3, 4, 5, 6, 7]},       # 长度 7
        {"input_ids": [1, 2, 3, 4, 5, 6, 7, 8]},    # 长度 8
        {"input_ids": [1, 2, 3]},                    # 长度 3
        {"input_ids": [1, 2, 3, 4]},                # 长度 4
        {"input_ids": [1]},                          # 长度 1
        {"input_ids": [1, 2, 3, 4, 5]},             # 长度 5
    ]

    batch_size = 3

    # 设置随机种子以确保结果可重现
    generator = torch.Generator()
    generator.manual_seed(42)  # 设置初始种子

    # 首先运行可视化函数
    print("\n=== 可视化长度分组采样过程 ===")
    visualize_length_grouped_sampling(data, batch_size=batch_size, mega_batch_mult=2, generator=generator)

    print("\n随机生成器状态:", generator.get_state())  # 打印当前生成器状态

    # 重置生成器到相同的种子
    generator = torch.Generator()
    generator.manual_seed(42)  # 使用相同的种子

    print("\n=== 实际运行结果 ===")
    # 创建数据集和DataLoader
    dataset = ExampleDataset(data)
    sampler = LengthGroupedSampler(batch_size=batch_size, dataset=dataset, 
                                 model_input_name="input_ids", generator=generator)
    dataloader = DataLoader(dataset, batch_size=batch_size, 
                          sampler=sampler, collate_fn=collate_fn)

    # 打印每个批次的内容
    print("\n实际批次输出:")
    for i, batch in enumerate(dataloader, 1):
        print(f"\nBatch {i}:")
        for item in batch:
            print(f"长度: {len(item['input_ids'])}, 数据: {item['input_ids']}")

输出

=== 可视化长度分组采样过程 ===

================================================================================
长度分组采样可视化
================================================================================

1. 原始数据:
--------------------------------------------------
索引: 0  | 长度: 3  | 数据: [1, 2, 3]
索引: 1  | 长度: 2  | 数据: [1, 2]
索引: 2  | 长度: 5  | 数据: [1, 2, 3, 4, 5]
索引: 3  | 长度: 1  | 数据: [1]
索引: 4  | 长度: 4  | 数据: [1, 2, 3, 4]
索引: 5  | 长度: 6  | 数据: [1, 2, 3, 4, 5, 6]
索引: 6  | 长度: 7  | 数据: [1, 2, 3, 4, 5, 6, 7]
索引: 7  | 长度: 8  | 数据: [1, 2, 3, 4, 5, 6, 7, 8]
索引: 8  | 长度: 3  | 数据: [1, 2, 3]
索引: 9  | 长度: 4  | 数据: [1, 2, 3, 4]
索引: 10 | 长度: 1  | 数据: [1]
索引: 11 | 长度: 5  | 数据: [1, 2, 3, 4, 5]

随机打乱前生成器状态: tensor([42,  0,  0,  ...,  0,  0,  0], dtype=torch.uint8)
随机打乱后的索引: [6, 8, 1, 7, 0, 2, 10, 11, 4, 3, 5, 9]
随机打乱后生成器状态: tensor([42,  0,  0,  ...,  0,  0,  0], dtype=torch.uint8)

3. 划分成mega-batches (size=6):

Mega-batch 1 (排序前):
--------------------------------------------------
索引: 6  | 长度: 7  | 数据: [1, 2, 3, 4, 5, 6, 7]
索引: 8  | 长度: 3  | 数据: [1, 2, 3]
索引: 1  | 长度: 2  | 数据: [1, 2]
索引: 7  | 长度: 8  | 数据: [1, 2, 3, 4, 5, 6, 7, 8]
索引: 0  | 长度: 3  | 数据: [1, 2, 3]
索引: 2  | 长度: 5  | 数据: [1, 2, 3, 4, 5]

Mega-batch 2 (排序前):
--------------------------------------------------
索引: 10 | 长度: 1  | 数据: [1]
索引: 11 | 长度: 5  | 数据: [1, 2, 3, 4, 5]
索引: 4  | 长度: 4  | 数据: [1, 2, 3, 4]
索引: 3  | 长度: 1  | 数据: [1]
索引: 5  | 长度: 6  | 数据: [1, 2, 3, 4, 5, 6]
索引: 9  | 长度: 4  | 数据: [1, 2, 3, 4]

4. 每个mega-batch内部按长度排序:

Mega-batch 1 (排序后):
--------------------------------------------------
索引: 7  | 长度: 8  | 数据: [1, 2, 3, 4, 5, 6, 7, 8]
索引: 6  | 长度: 7  | 数据: [1, 2, 3, 4, 5, 6, 7]
索引: 2  | 长度: 5  | 数据: [1, 2, 3, 4, 5]
索引: 8  | 长度: 3  | 数据: [1, 2, 3]
索引: 0  | 长度: 3  | 数据: [1, 2, 3]
索引: 1  | 长度: 2  | 数据: [1, 2]

Mega-batch 2 (排序后):
--------------------------------------------------
索引: 5  | 长度: 6  | 数据: [1, 2, 3, 4, 5, 6]
索引: 11 | 长度: 5  | 数据: [1, 2, 3, 4, 5]
索引: 4  | 长度: 4  | 数据: [1, 2, 3, 4]
索引: 9  | 长度: 4  | 数据: [1, 2, 3, 4]
索引: 10 | 长度: 1  | 数据: [1]
索引: 3  | 长度: 1  | 数据: [1]

5. 最长序列调整:
--------------------------------------------------
最长序列在 Mega-batch 1

6. 最终的采样顺序:
--------------------------------------------------

Batch 1:
索引: 7  | 长度: 8  | 数据: [1, 2, 3, 4, 5, 6, 7, 8]
索引: 6  | 长度: 7  | 数据: [1, 2, 3, 4, 5, 6, 7]
索引: 2  | 长度: 5  | 数据: [1, 2, 3, 4, 5]

Batch 2:
索引: 8  | 长度: 3  | 数据: [1, 2, 3]
索引: 0  | 长度: 3  | 数据: [1, 2, 3]
索引: 1  | 长度: 2  | 数据: [1, 2]

Batch 3:
索引: 5  | 长度: 6  | 数据: [1, 2, 3, 4, 5, 6]
索引: 11 | 长度: 5  | 数据: [1, 2, 3, 4, 5]
索引: 4  | 长度: 4  | 数据: [1, 2, 3, 4]

Batch 4:
索引: 9  | 长度: 4  | 数据: [1, 2, 3, 4]
索引: 10 | 长度: 1  | 数据: [1]
索引: 3  | 长度: 1  | 数据: [1]

随机生成器状态: tensor([42,  0,  0,  ...,  0,  0,  0], dtype=torch.uint8)

=== 实际运行结果 ===

实际批次输出:

当前batch信息:
最大长度: 8
序列 1: 原始长度 8, 原始数据: [1, 2, 3, 4, 5, 6, 7, 8], 填充后: [1, 2, 3, 4, 5, 6, 7, 8]
序列 2: 原始长度 3, 原始数据: [1, 2, 3], 填充后: [1, 2, 3, 0, 0, 0, 0, 0]
序列 3: 原始长度 2, 原始数据: [1, 2], 填充后: [1, 2, 0, 0, 0, 0, 0, 0]

Batch 1:
长度: 8, 数据: [1, 2, 3, 4, 5, 6, 7, 8]
长度: 8, 数据: [1, 2, 3, 0, 0, 0, 0, 0]
长度: 8, 数据: [1, 2, 0, 0, 0, 0, 0, 0]

当前batch信息:
最大长度: 7
序列 1: 原始长度 7, 原始数据: [1, 2, 3, 4, 5, 6, 7], 填充后: [1, 2, 3, 4, 5, 6, 7]
序列 2: 原始长度 5, 原始数据: [1, 2, 3, 4, 5], 填充后: [1, 2, 3, 4, 5, 0, 0]
序列 3: 原始长度 3, 原始数据: [1, 2, 3], 填充后: [1, 2, 3, 0, 0, 0, 0]

Batch 2:
长度: 7, 数据: [1, 2, 3, 4, 5, 6, 7]
长度: 7, 数据: [1, 2, 3, 4, 5, 0, 0]
长度: 7, 数据: [1, 2, 3, 0, 0, 0, 0]

当前batch信息:
最大长度: 5
序列 1: 原始长度 5, 原始数据: [1, 2, 3, 4, 5], 填充后: [1, 2, 3, 4, 5]
序列 2: 原始长度 4, 原始数据: [1, 2, 3, 4], 填充后: [1, 2, 3, 4, 0]
序列 3: 原始长度 1, 原始数据: [1], 填充后: [1, 0, 0, 0, 0]

Batch 3:
长度: 5, 数据: [1, 2, 3, 4, 5]
长度: 5, 数据: [1, 2, 3, 4, 0]
长度: 5, 数据: [1, 0, 0, 0, 0]

当前batch信息:
最大长度: 6
序列 1: 原始长度 6, 原始数据: [1, 2, 3, 4, 5, 6], 填充后: [1, 2, 3, 4, 5, 6]
序列 2: 原始长度 4, 原始数据: [1, 2, 3, 4], 填充后: [1, 2, 3, 4, 0, 0]
序列 3: 原始长度 1, 原始数据: [1], 填充后: [1, 0, 0, 0, 0, 0]

Batch 4:
长度: 6, 数据: [1, 2, 3, 4, 5, 6]
长度: 6, 数据: [1, 2, 3, 4, 0, 0]
长度: 6, 数据: [1, 0, 0, 0, 0, 0]

mega_batch(大批次)是一个比普通batch更大的数据块的概念。让我用一个具体例子来解释:

假设我们有以下设置:

batch_size = 3           # 正常的批次大小
mega_batch_mult = 2      # mega batch 倍数
data_size = 12          # 总数据量

那么:

  1. 普通batch:包含3个样本
  2. mega_batch:包含6个样本 (batch_size * mega_batch_mult = 3 * 2 = 6)

数据处理流程:

原始数据 (12个样本):
[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]

↓ 随机打乱

随机顺序:
[5, 2, 8, 1, 11, 4, 7, 3, 9, 0, 6, 10]

↓ 分成mega-batches (每个包含6个样本)

Mega-batch 1: [5, 2, 8, 1, 11, 4]
Mega-batch 2: [7, 3, 9, 0, 6, 10]

↓ 在每个mega-batch内部按长度排序

排序后的Mega-batch 1: [11, 8, 5, 4, 2, 1]  (按序列长度降序)
排序后的Mega-batch 2: [9, 7, 6, 3, 10, 0]  (按序列长度降序)

↓ 最终按batch_size=3分批处理

Batch 1: [11, 8, 5]
Batch 2: [4, 2, 1]
Batch 3: [9, 7, 6]
Batch 4: [3, 10, 0]

使用mega_batch的好处:

  1. 平衡随机性和效率

    • 完全随机会导致长度差异大,需要大量padding
    • 完全按长度排序会失去随机性,可能影响模型训练
    • mega_batch在两者之间取得平衡
  2. 局部排序

    • 只在mega_batch内部排序
    • 保持一定的随机性,同时又能让相似长度的序列靠在一起
  3. 灵活性

    • 通过调整mega_batch_mult可以控制排序的粒度
    • 较大的mult值会使得长度分组更精确
    • 较小的mult值会增加随机性
  4. 内存效率

    • 相似长度的序列在一起处理时,需要的padding更少
    • 减少了内存浪费,提高了计算效率

这就是为什么在代码中使用mega_batch的概念,它是一个在随机性和计算效率之间取得平衡的巧妙设计。