云计算百科
云计算领域专业知识百科平台

多模态学习系列(四):手动对齐和均衡批处理

引言

在当今人工智能领域,多模态学习已成为最前沿的研究方向之一。与单模态模型相比,多模态模型能够同时处理和理解来自不同来源(如图像、文本、音频等)的数据,从而更接近人类认知世界的方式。然而,这种能力的提升也带来了新的技术挑战,特别是在数据预处理和训练策略方面。上一篇我们讨论了FSDP,这一篇进一步讨论一个在分布式训练中“常被忽略但很重要”的工程细节:PyTorch 的 FSDP 不会自动处理 padding 或 batch 对齐问题。


❌ FSDP 不做什么?

FSDP(Fully Sharded Data Parallel)仅负责:

  • 分布式训练结构:模型参数切片(shard)、梯度通信、优化状态切片

  • 提供高效的多卡训练支持,包括 activation checkpoint、zero-redundancy optimizer 等

它不会:

  • 自动填充(padding)不同长度输入

  • 自动对齐 batch 的 sequence 长度

  • 调整不同 GPU 上的数据量平衡

  • 插件生成 attention mask

也就是说,所有与 sequence 长度分布和 padding 相关的问题都需要你自己处理。


🚨 你必须解决的问题

  • 手写 collate_fn:实现 padding 和 truncation,确保 batch 中每条样本长度一致

  • 控制每个 device 的样本负载:避免某块 GPU 处理大量超长文本而其他 GPU 空闲,导致速度和显存不平衡

  • 生成 attention_mask:padding 部分不应参与 attention 计算,需要 mask 掉


  • SplitModalitySampler

    SplitModalitySampler 是一种用于多模态数据(如文本+图像)的采样器,旨在解决 load balancing(负载均衡)和 长度相近样本分组(减少padding)之间的矛盾。它的核心思想是 按模态拆分样本,再分别进行动态批处理(dynamic batching),从而兼顾效率和内存优化。


    1. 核心目标

    • Load Balancing: 在多卡或多进程训练时,确保每个GPU的batch计算量均衡,避免某些卡负载过高而拖慢整体训练。

    • 长度相近的样本分组: 将长度相似的样本(如文本长度)分到同一个batch,减少padding带来的计算浪费。

    传统方法(如纯按长度分桶)可能导致某些GPU分到太多长样本而负载不均,而纯随机采样则可能增加padding。


    2. 工作原理

    (1) 按模态拆分样本
    • 将数据集按模态(如纯文本、文本+图像、纯图像等)分成多个子集,因为一般带图片的对话和纯文本的圣诞长度相差很大。

    • 例如:

      • 子集A:仅文本(无图像)

      • 子集B:文本+图像

    def get_modality_lengths(self, n_image_patches: int) -> List[Tuple[bool, int]]:
            """Get a list of modalities (unimodal / text-only vs. multimodal) and length of conversations per example."""
            modality_lengths = []
            for example in self.examples:
                is_multimodal = "image" in example
                n_words = sum([len(turn["value"].replace("<image>", "").split()) for turn in example["conversations"]])
                modality_lengths.append((is_multimodal, (n_image_patches + n_words) if is_multimodal else n_words))
            return modality_lengths
    multimodal_indices, multimodal_lengths = zip(
    *[(idx, length) for idx, (is_multimodal, length) in enumerate(self.modality_lengths) if is_multimodal]
    )

    # Handle Special Case –> no "unimodal" inputs
    unimodal_split = [
    (idx, length) for idx, (is_multimodal, length) in enumerate(self.modality_lengths) if not is_multimodal
    ]

    (2) 均衡分配批次
    • 从每个子集的按需采样,确保:

      • 不同GPU分配的batch总计算量(如文本token总数+图像像素总数)均衡。比如我们可以用以下代码将batch sz = 128的数据分到两个gpu上:

    @staticmethod
    def reindex_batch(batch_idxs: List[int], idx2lengths: List[int], n_buckets: int) -> List[List[int]]:
    """Re-indexes a batch in a way that is conducive to DistributedSampler + grouping by seqlen per rank."""
    assert len(batch_idxs) % n_buckets == 0, "Batch length is not divisible by `num_replicas`!"

    # Establish initial buckets, capacities, and max number of elements per bucket
    n_examples_per_bucket = len(batch_idxs) // n_buckets
    bucket_indices = [[] for _ in range(n_buckets)]
    bucket_lengths = [0 for _ in range(n_buckets)]

    # Note that `batch_idxs` is already sorted by corresponding length (in descending order)
    for idx in batch_idxs:
    shortest_bucket_idx = bucket_lengths.index(min(bucket_lengths))
    bucket_indices[shortest_bucket_idx].append(idx)

    # Update `bucket_lengths` –> set length to infinity if at capacity!
    bucket_lengths[shortest_bucket_idx] += idx2lengths[idx]
    if len(bucket_indices[shortest_bucket_idx]) == n_examples_per_bucket:
    bucket_lengths[shortest_bucket_idx] = float("inf")

    return bucket_indices
    mm_sorted_batch_idxs = [sorted(b, key=lambda i: multimodal_lengths[i], reverse=True) for b in mm_batch_idxs]
    mm_length_bucketed_idxs = [
    self.reindex_batch(batch, multimodal_lengths, self.num_replicas) for batch in mm_sorted_batch_idxs
    ]
    # flatten
    mm_output_idxs = [idx for batch in mm_length_bucketed_idxs for bucket in batch for idx in bucket]

    分完后的结果如下:

    Rank 0:128 125 124 121 …

    Rank 1:127 126 123 122 …

    这样就实现了负载均衡。

    最后拉平,变成了:[128, 125, 124, 121, 127, 126, 123, 122].

    (3) 长度相近的样本分组
    • Pytorch将一个batch的数据分发到不同的GPU上,默认是按rank顺序发的,这样就造成了实际长度相近的样本并未分在一组。比如[100,99,2,1]:
      • Rank 0:100,2 → 需要 98 个pad token

      • Rank 1:99, 1 → 需要 98 个pad token

      • 结果是组内长度差别过大,不仅显存的利用率不高,也会减缓训练速度。

    • 如果是先按每个rank需要处理的数据量reshape,再重新分配:

      • Rank 0:100, 99→ 需要 1个pad token

      • Rank 1: 2,1→ 需要 1 个pad token

    # naive way: without reshape
    # optimized way: reshape and slicing, here per_replica_batch_size = 2
    per_replica_batch_indices_t = indices_t.reshape(-1, per_replica_batch_size)
    # continuous chunk within the mini-batch
    replica_indices_t = per_replica_batch_indices_t[self.rank :: self.num_replicas]

    这里rank 0设备由于要处理长得多的序列(100 vs 2)而耗时明显更长,导致设备间不平衡和空闲等待。

    实际上,在我们上一步先按长度均分后,rank 1中的最大长度将是当前batch中第二长的,各设备间的负载会很平衡(128 vs 127)。

    建议:把最长的批次放在最前面,这样如果会出现OOM,就能快速失败。

    🎲 仍然足够随机

    因为划分batch之前会先随机打乱id,batch之间还是会有一定的随机性。能保证每个batch的数据量大致差不多。


    3. 优势

    • 减少Padding:同一batch内样本长度相近,显存利用率高。

    • 负载均衡:通过拆分模态和动态调度,平衡多GPU的计算量。

    • 灵活性:支持混合模态(如部分样本有文本无图像)


    📚 参考文献

  • LLaVA Trainer  视觉-语言对齐与指令微调训练逻辑的核心实现

  • HuggingFace Transformers – Trainer Utilities 高效训练与分布式批处理的PyTorch核心工具集

  • OpenVLA – Batching Utilities 多模态token对齐任务中使用的定制化批处理逻辑

  • 赞(0)
    未经允许不得转载:网硕互联帮助中心 » 多模态学习系列(四):手动对齐和均衡批处理
    分享到: 更多 (0)

    评论 抢沙发

    评论前必须登录!