引言
在当今人工智能领域,多模态学习已成为最前沿的研究方向之一。与单模态模型相比,多模态模型能够同时处理和理解来自不同来源(如图像、文本、音频等)的数据,从而更接近人类认知世界的方式。然而,这种能力的提升也带来了新的技术挑战,特别是在数据预处理和训练策略方面。上一篇我们讨论了FSDP,这一篇进一步讨论一个在分布式训练中“常被忽略但很重要”的工程细节:PyTorch 的 FSDP 不会自动处理 padding 或 batch 对齐问题。
❌ FSDP 不做什么?
FSDP(Fully Sharded Data Parallel)仅负责:
-
分布式训练结构:模型参数切片(shard)、梯度通信、优化状态切片
-
提供高效的多卡训练支持,包括 activation checkpoint、zero-redundancy optimizer 等
它不会:
-
自动填充(padding)不同长度输入
-
自动对齐 batch 的 sequence 长度
-
调整不同 GPU 上的数据量平衡
-
插件生成 attention mask
也就是说,所有与 sequence 长度分布和 padding 相关的问题都需要你自己处理。
🚨 你必须解决的问题
手写 collate_fn:实现 padding 和 truncation,确保 batch 中每条样本长度一致
控制每个 device 的样本负载:避免某块 GPU 处理大量超长文本而其他 GPU 空闲,导致速度和显存不平衡
生成 attention_mask:padding 部分不应参与 attention 计算,需要 mask 掉
SplitModalitySampler
SplitModalitySampler 是一种用于多模态数据(如文本+图像)的采样器,旨在解决 load balancing(负载均衡)和 长度相近样本分组(减少padding)之间的矛盾。它的核心思想是 按模态拆分样本,再分别进行动态批处理(dynamic batching),从而兼顾效率和内存优化。
1. 核心目标
-
Load Balancing: 在多卡或多进程训练时,确保每个GPU的batch计算量均衡,避免某些卡负载过高而拖慢整体训练。
-
长度相近的样本分组: 将长度相似的样本(如文本长度)分到同一个batch,减少padding带来的计算浪费。
传统方法(如纯按长度分桶)可能导致某些GPU分到太多长样本而负载不均,而纯随机采样则可能增加padding。
2. 工作原理
(1) 按模态拆分样本
-
将数据集按模态(如纯文本、文本+图像、纯图像等)分成多个子集,因为一般带图片的对话和纯文本的圣诞长度相差很大。
-
例如:
-
子集A:仅文本(无图像)
-
子集B:文本+图像
-
def get_modality_lengths(self, n_image_patches: int) -> List[Tuple[bool, int]]:
"""Get a list of modalities (unimodal / text-only vs. multimodal) and length of conversations per example."""
modality_lengths = []
for example in self.examples:
is_multimodal = "image" in example
n_words = sum([len(turn["value"].replace("<image>", "").split()) for turn in example["conversations"]])
modality_lengths.append((is_multimodal, (n_image_patches + n_words) if is_multimodal else n_words))
return modality_lengths
multimodal_indices, multimodal_lengths = zip(
*[(idx, length) for idx, (is_multimodal, length) in enumerate(self.modality_lengths) if is_multimodal]
)
# Handle Special Case –> no "unimodal" inputs
unimodal_split = [
(idx, length) for idx, (is_multimodal, length) in enumerate(self.modality_lengths) if not is_multimodal
]
(2) 均衡分配批次
-
从每个子集的按需采样,确保:
-
不同GPU分配的batch总计算量(如文本token总数+图像像素总数)均衡。比如我们可以用以下代码将batch sz = 128的数据分到两个gpu上:
-
@staticmethod
def reindex_batch(batch_idxs: List[int], idx2lengths: List[int], n_buckets: int) -> List[List[int]]:
"""Re-indexes a batch in a way that is conducive to DistributedSampler + grouping by seqlen per rank."""
assert len(batch_idxs) % n_buckets == 0, "Batch length is not divisible by `num_replicas`!"
# Establish initial buckets, capacities, and max number of elements per bucket
n_examples_per_bucket = len(batch_idxs) // n_buckets
bucket_indices = [[] for _ in range(n_buckets)]
bucket_lengths = [0 for _ in range(n_buckets)]
# Note that `batch_idxs` is already sorted by corresponding length (in descending order)
for idx in batch_idxs:
shortest_bucket_idx = bucket_lengths.index(min(bucket_lengths))
bucket_indices[shortest_bucket_idx].append(idx)
# Update `bucket_lengths` –> set length to infinity if at capacity!
bucket_lengths[shortest_bucket_idx] += idx2lengths[idx]
if len(bucket_indices[shortest_bucket_idx]) == n_examples_per_bucket:
bucket_lengths[shortest_bucket_idx] = float("inf")
return bucket_indices
mm_sorted_batch_idxs = [sorted(b, key=lambda i: multimodal_lengths[i], reverse=True) for b in mm_batch_idxs]
mm_length_bucketed_idxs = [
self.reindex_batch(batch, multimodal_lengths, self.num_replicas) for batch in mm_sorted_batch_idxs
]
# flatten
mm_output_idxs = [idx for batch in mm_length_bucketed_idxs for bucket in batch for idx in bucket]
分完后的结果如下:
Rank 0:128 125 124 121 …
Rank 1:127 126 123 122 …
这样就实现了负载均衡。
最后拉平,变成了:[128, 125, 124, 121, 127, 126, 123, 122].
(3) 长度相近的样本分组
- Pytorch将一个batch的数据分发到不同的GPU上,默认是按rank顺序发的,这样就造成了实际长度相近的样本并未分在一组。比如[100,99,2,1]:
-
Rank 0:100,2 → 需要 98 个pad token
-
Rank 1:99, 1 → 需要 98 个pad token
-
结果是组内长度差别过大,不仅显存的利用率不高,也会减缓训练速度。
-
-
如果是先按每个rank需要处理的数据量reshape,再重新分配:
-
Rank 0:100, 99→ 需要 1个pad token
-
Rank 1: 2,1→ 需要 1 个pad token
-
# naive way: without reshape
# optimized way: reshape and slicing, here per_replica_batch_size = 2
per_replica_batch_indices_t = indices_t.reshape(-1, per_replica_batch_size)
# continuous chunk within the mini-batch
replica_indices_t = per_replica_batch_indices_t[self.rank :: self.num_replicas]
这里rank 0设备由于要处理长得多的序列(100 vs 2)而耗时明显更长,导致设备间不平衡和空闲等待。
实际上,在我们上一步先按长度均分后,rank 1中的最大长度将是当前batch中第二长的,各设备间的负载会很平衡(128 vs 127)。
建议:把最长的批次放在最前面,这样如果会出现OOM,就能快速失败。
🎲 仍然足够随机
因为划分batch之前会先随机打乱id,batch之间还是会有一定的随机性。能保证每个batch的数据量大致差不多。
3. 优势
-
减少Padding:同一batch内样本长度相近,显存利用率高。
-
负载均衡:通过拆分模态和动态调度,平衡多GPU的计算量。
-
灵活性:支持混合模态(如部分样本有文本无图像)
📚 参考文献
LLaVA Trainer 视觉-语言对齐与指令微调训练逻辑的核心实现
HuggingFace Transformers – Trainer Utilities 高效训练与分布式批处理的PyTorch核心工具集
OpenVLA – Batching Utilities 多模态token对齐任务中使用的定制化批处理逻辑
评论前必须登录!
注册