分布式深度学习¶
分布式训练将计算分散到多个 GPU 和机器上,以训练那些单设备无法容纳或训练过慢的模型。本文件涵盖混合精度、数据并行、模型并行、流水线并行、ZeRO、FSDP、张量并行以及 all-reduce 等通信原语——这些对于大规模训练大型语言模型至关重要。
-
在单个 GPU 上训练大型神经网络最终会遇到瓶颈。模型可能无法放入内存,或者训练可能需要数月时间。分布式训练将工作负载分散到多个设备(GPU、TPU 或整台机器)上,以实现更快的训练和更大的模型。本文件介绍了实现这一目标的技术。
-
要理解为什么分布式如此重要,首先要了解训练的计算成本。对于一个具有 \(d_{\text{in}}\) 个输入、\(d_{\text{out}}\) 个输出的稠密层,在一个批次大小为 \(B\) 的样本上进行一次前向传播需要大约 \(2 \cdot B \cdot d_{\text{in}} \cdot d_{\text{out}}\) 次浮点运算(FLOPs):输出矩阵的每个元素对应一次乘法和一次加法。反向传播的成本大约是前向传播的两倍(计算关于输入和权重的梯度),因此稠密层的一个训练步骤大约需要 \(6 \cdot B \cdot d_{\text{in}} \cdot d_{\text{out}}\) 次浮点运算。
-
对于一个隐藏维度为 \(d\) 的 Transformer 层,自注意力模块包含四个投影(Q、K、V 和输出),每个投影的成本为 \(O(B \cdot n \cdot d^2)\) 次浮点运算(其中 \(n\) 是序列长度),再加上注意力矩阵计算 \(O(B \cdot n^2 \cdot d)\)。前馈模块有两个稠密层,通常先扩展到 \(4d\) 再返回:\(O(B \cdot n \cdot 8d^2)\)。每层总计约为 \(O(B \cdot n \cdot 12d^2 + B \cdot n^2 \cdot d)\)。再乘以层数,你就明白为什么训练 GPT 规模的模型需要数千个 GPU 小时。
-
内存墙通常是更严格的限制。在训练过程中,GPU 内存必须同时容纳四样东西:
- 参数:模型权重。一个 70 亿参数的模型,如果用 FP32(每参数 4 字节)存储,仅权重就需要 28 GB。
- 梯度:与参数大小相同。又是 28 GB。
- 优化器状态:Adam 维护两个额外的缓冲区(一阶和二阶矩估计),每个都与参数大小相同。即使模型使用较低精度,这些缓冲区也以 FP32 形式保持以保证数值稳定性。对于我们的 70 亿模型,那就是 \(2 \times 28 = 56\) GB。
-
激活值:前向传播过程中保存的用于反向传播的中间值。大小取决于批次大小、序列长度和模型宽度。这通常是最大的组成部分,并且随批次大小线性增长。
-
对于我们的 70 亿模型,使用 FP32 Adam,仅参数、梯度和优化器状态就需要 28 + 28 + 56 = 112 GB,这还没算激活值。单个 80 GB 的 A100 GPU 无法容纳这些。这就是为什么分布式策略必不可少。
-
混合精度训练是第一道防线。不是将所有内容都存储在 FP32(32 位浮点数)中,而是在前向和后向传播中使用 FP16 或 BF16(16 位),同时保留一份 FP32 权重的主副本用于优化器更新。
-
FP16 具有高精度(10 位尾数),但数值范围有限,可能导致溢出/下溢。损失缩放(在反向传播前将损失乘以一个大因子,然后将梯度除以同一因子)可以缓解这个问题。
-
BF16(脑浮点数)具有与 FP32 相同的指数范围(8 位指数),但精度较低(7 位尾数)。它几乎从不溢出,很少需要损失缩放,因此使用更简单。BF16 是现代 Transformer 训练的默认精度。
-
混合精度大致将激活值和梯度的内存(前向/反向传播中的主要开销)减半,同时优化器状态仍使用 FP32 以保证数值稳定性。
-
数据并行是最简单的分布式策略。你将完整模型复制到 \(N\) 个 GPU 上,将每个小批量分成 \(N\) 等份,然后将一份发送给每个 GPU。每个 GPU 独立对其分块运行前向和后向传播。然后跨所有 GPU 对梯度进行平均(使用 all-reduce 操作),每个 GPU 更新其本地模型副本。
-
从模型的角度来看,这等效于使用 \(N\) 倍大的小批量进行训练。如果每个 GPU 处理批次大小为 \(B\),则有效批次大小为 \(N \cdot B\)。
-
梯度平均可以同步或异步进行。同步 SGD 等待所有 GPU 完成后再进行平均,确保了与使用更大批量的单 GPU 训练数学等价。缺点是速度最慢的 GPU(“掉队者”)会拖累所有人。
-
异步 SGD 允许每个 GPU 独立地更新一个共享参数服务器,无需等待。这消除了掉队者问题,但引入了“陈旧梯度”:一个 GPU 可能基于稍微过时的参数计算梯度。陈旧梯度会增加噪声并可能减慢收敛速度。在实践中,具有高效通信的同步 SGD 更受青睐。
-
梯度累积是一种软件技巧,用于在有限硬件上模拟更大的批次大小。它不是每小批量做一次更新,而是运行多次前向/后向传播并累积梯度,然后做一次更新。这在不增加激活值内存需求(每次只有一个小批量的激活值在内存中)的情况下,给出了与更大批次相同的结果。
-
当模型本身太大而无法放入单个 GPU 时,就需要模型并行。主要有两种类型。
-
张量并行将单个层分割到多个 GPU 上。一个大的矩阵乘法 \(Y = XW\) 可以按列分割:将 \(W\) 划分为 \([W_1, W_2]\) 并放到两个 GPU 上,并行计算 \(Y_1 = XW_1\) 和 \(Y_2 = XW_2\),然后拼接。这适用于注意力投影和前馈层。它需要 GPU 之间快速的通信(通常使用节点内的 NVLink),因为每层都要合并部分结果。
-
流水线并行将不同的层分配给不同的 GPU。GPU 0 运行第 1-4 层,GPU 1 运行第 5-8 层,依此类推。数据像流水线一样流过这些阶段。朴素方法存在“流水线气泡”:当 GPU 0 处理微批次 1 的前向传播时,GPU 1-3 处于空闲状态。微批次通过将小批量分割成更小的微批次,让它们依次流经流水线,从而缓解了这个问题,使大多数时间所有 GPU 都保持忙碌。
-
混合并行结合了数据、张量和流水线并行。一个典型的大模型设置可能会在节点内(8 个通过高速 NVLink 连接的 GPU)使用张量并行,跨节点使用流水线并行,跨节点组使用数据并行。这就是 GPT-4 和 Llama 等模型的训练方式。
-
分布式训练的效率很大程度上取决于通信。关键操作是 all-reduce:给定 \(N\) 个 GPU 上各自有一个值,计算它们的和(或平均值),并将结果分发给所有 GPU。
-
朴素的 all-reduce 将所有数据发送到一个 GPU,求和,然后广播回去。这种方式的通信量为 \(O(N)\),并在根节点处造成瓶颈。
-
环形 all-reduce 高效得多。将 \(N\) 个 GPU 排列成一个环。每个 GPU 将其数据分成 \(N\) 块。在 \(N-1\) 步中,每个 GPU 向邻居发送一块,并从另一个邻居接收一块,累积部分和。再经过 \(N-1\) 步,完整的总和被传播到所有 GPU。每个 GPU 传输的总数据量为 \(2(N-1)/N\) 倍的数据大小,当 \(N\) 增长时趋近于 \(2\times\)。关键是,这不会随 \(N\) 增加而增加,因此是带宽最优的。
-
参数服务器是一种替代架构,其中专用的服务器节点保存模型参数。工作节点计算梯度并将其发送到服务器,服务器更新参数并将其发送回工作节点。这更简单,但可能在服务器处造成通信瓶颈。
-
NCCL(NVIDIA 集体通信库)是 GPU 间通信的标准库。它提供了 all-reduce、all-gather、broadcast 等集体操作的高效实现,并会根据网络拓扑自动选择最佳算法。
-
缩放定律描述了模型性能如何随计算量、数据量和模型规模而提升。最初的 Kaplan 等人(2020)的缩放定律发现,损失与每个因素呈幂律关系:
-
其中 \(N\) 是参数量,\(D\) 是数据集大小,\(C\) 是计算预算。
-
Chinchilla 缩放定律(Hoffmann 等人,2022)表明大多数模型训练不足:在给定的计算预算下,应该训练一个更小的模型,使用比以前认为的更多的数据。最优比例大约是每个参数 20 个 token。一个 70 亿参数的模型应该看到约 1400 亿 token,而不是 Llama 1 在 650 亿模型上使用的 3000 亿 token。这一发现将领域转向了“计算最优”训练。
-
混合专家(MoE) 是一种架构,可以在不按比例增加计算量的情况下扩展模型容量。在每层 Transformer 中,你不是只有一个前馈网络,而是有 \(N\) 个“专家”网络(每个都是标准的前馈网络)。一个门控网络(路由器)检查每个 token 并将其发送给排名前 \(K\) 的专家(通常 \(K=1\) 或 \(K=2\))。
-
总参数量要大得多(因为有 \(N\) 个专家),但每个 token 的浮点运算量大致保持不变(因为每个 token 只激活 \(K\) 个专家)。例如,Mixtral 8x7B 总共有 470 亿参数,但每次前向传播仅使用约 130 亿,以较小模型的计算成本提供了更大模型的性能。
-
MoE 带来了挑战。负载均衡:如果路由器将大多数 token 发送到同一个专家,其他专家就被浪费了。一个辅助损失函数鼓励均匀路由。通信:不同的专家可能位于不同的 GPU 上,因此路由 token 需要 all-to-all 通信,这很昂贵。
-
容错性在训练持续数周或数月、涉及数千个 GPU 时至关重要。如果单个 GPU 发生故障,你不会希望丢失所有进度。检查点定期将模型权重、优化器状态和训练状态(学习率、步数、数据位置)保存到磁盘。如果发生故障,你从最后一个检查点重新开始。
-
梯度检查点(也称为激活重计算)是一种内存优化技术,不是容错机制。在前向传播过程中,不保存所有用于反向传播的激活值,只保存某些检查点位置。在反向传播过程中,从检查点重新计算缺失的激活值。这是用计算换内存:它使前向传播成本增加约 33%,但可以将激活值内存减少 \(\sqrt{L}\) 倍(其中 \(L\) 是层数)。
-
综上所述,训练前沿模型结合了所有这些技术:BF16 混合精度、跨数千个 GPU 使用环形 all-reduce 的数据并行、节点内的张量并行、跨节点的流水线并行、减少内存的梯度检查点、提高参数效率的 MoE 以及用于容错的常规检查点。系统工程与算法设计同样具有挑战性。
-
分布式训练工具包总结:
| 技术 | 作用 | 权衡 |
|---|---|---|
| 混合精度 (BF16) | 将激活值/梯度内存减半 | 轻微的数值差异 |
| 数据并行 | 跨 GPU 扩展批次大小 | 梯度同步的通信开销 |
| 张量并行 | 将层拆分到多个 GPU | 需要高速互连 |
| 流水线并行 | 将模型阶段拆分到多个 GPU | 流水线气泡(浪费计算) |
| 梯度累积 | 模拟大批次 | 更慢(多次前向/后向传播) |
| 梯度检查点 | 减少激活值内存 | 增加约33%的计算量 |
| 环形 all-reduce | 高效的梯度平均 | 大模型时受带宽限制 |
| 混合专家 (MoE) | 更大容量,相同浮点运算量 | 负载均衡,路由复杂度 |
| 缩放定律 | 指导计算分配 | 经验性的,可能不适用于所有规模 |
编码任务(使用 CoLab 或 notebook)¶
-
计算一个 Transformer 层的浮点运算量和内存需求。给定隐藏维度 \(d\)、序列长度 \(n\)、批次大小 \(B\) 和层数,估算总训练成本。
import jax.numpy as jnp def transformer_layer_flops(d, n, B): """近似单层 Transformer 前向传播的浮点运算量。""" # QKV 投影:3 * (B * n * d * d) * 2(乘加) qkv_flops = 3 * 2 * B * n * d * d # 注意力:QK^T 需 (B * n * n * d) * 2,attn*V 需 (B * n * n * d) * 2 attn_flops = 2 * 2 * B * n * n * d # 输出投影:(B * n * d * d) * 2 out_flops = 2 * B * n * d * d # 前馈网络:两层,d->4d 和 4d->d:2 * (B * n * d * 4d) * 2 ffn_flops = 2 * 2 * B * n * d * 4 * d return qkv_flops + attn_flops + out_flops + ffn_flops def transformer_layer_memory(d, n, B, dtype_bytes=2): """近似每层激活值内存(字节)。""" # QKV: 3 * B * n * d qkv_mem = 3 * B * n * d * dtype_bytes # 注意力权重: B * heads * n * n (近似 B * n * n * sizeof) attn_mem = B * n * n * dtype_bytes # 前馈网络中间值: B * n * 4d ffn_mem = B * n * 4 * d * dtype_bytes return qkv_mem + attn_mem + ffn_mem # 以 GPT-2 规模为例 d, n, B, L = 1024, 1024, 8, 24 fwd_flops = transformer_layer_flops(d, n, B) total_flops = 3 * L * fwd_flops # 3倍用于前向+后向 act_mem = L * transformer_layer_memory(d, n, B) param_count = L * (12 * d * d + 13 * d) # 近似 print(f"模型: d={d}, n={n}, B={B}, L={L}") print(f"参数量: {param_count / 1e6:.0f}M") print(f"每步浮点运算量: {total_flops / 1e12:.2f} TFLOPs") print(f"激活值内存: {act_mem / 1e9:.2f} GB (BF16)") print(f"参数内存 (FP32): {param_count * 4 / 1e9:.2f} GB") print(f"Adam 优化器内存: {param_count * 8 / 1e9:.2f} GB") print(f"训练总内存: {(param_count * 16 + act_mem) / 1e9:.2f} GB") -
模拟数据并行训练。将数据集拆分到多个“虚拟 GPU”上,独立计算梯度,取平均,并验证结果与单 GPU 训练一致。
import jax import jax.numpy as jnp # 简单线性模型: y = wx + b key = jax.random.PRNGKey(0) X = jax.random.normal(key, (64, 4)) w_true = jnp.array([1.0, -2.0, 3.0, 0.5]) y = X @ w_true + 0.1 * jax.random.normal(key, (64,)) def loss_fn(w, X, y): return jnp.mean((X @ w - y) ** 2) grad_fn = jax.grad(loss_fn) # 单 GPU: 全批量梯度 w = jnp.zeros(4) grad_single = grad_fn(w, X, y) # 数据并行: 拆分到 4 个 "GPU" n_gpus = 4 chunk_size = len(X) // n_gpus grads = [] for i in range(n_gpus): X_chunk = X[i*chunk_size:(i+1)*chunk_size] y_chunk = y[i*chunk_size:(i+1)*chunk_size] grads.append(grad_fn(w, X_chunk, y_chunk)) # 全部归约: 平均梯度 grad_parallel = jnp.mean(jnp.stack(grads), axis=0) print("单 GPU 梯度:", grad_single) print("数据并行梯度 (平均):", grad_parallel) print(f"匹配: {jnp.allclose(grad_single, grad_parallel, atol=1e-5)}") # 训练两者并比较 w_single, w_parallel = jnp.zeros(4), jnp.zeros(4) lr = 0.1 for step in range(100): w_single = w_single - lr * grad_fn(w_single, X, y) grads = [grad_fn(w_parallel, X[i*chunk_size:(i+1)*chunk_size], y[i*chunk_size:(i+1)*chunk_size]) for i in range(n_gpus)] avg_grad = jnp.mean(jnp.stack(grads), axis=0) w_parallel = w_parallel - lr * avg_grad print(f"\n经过 100 步后:") print(f"单 GPU 权重: {w_single}") print(f"数据并行权重: {w_parallel}") print(f"最大差异: {jnp.max(jnp.abs(w_single - w_parallel)):.2e}") -
实现一个简单的混合专家层。创建一个门控网络,将 token 路由到 top-K 专家,并组合他们的输出。
import jax import jax.numpy as jnp def expert_fn(x, W1, b1, W2, b2): """简单的两层前馈专家网络。""" h = jnp.maximum(0, x @ W1 + b1) # ReLU return h @ W2 + b2 def moe_layer(x, gate_W, experts_params, top_k=2): """ MoE 前向传播。 x: (batch, d_model) gate_W: (d_model, n_experts) experts_params: 每个专家的 (W1, b1, W2, b2) 列表 """ n_experts = len(experts_params) # 门控:计算路由分数 gate_logits = x @ gate_W # (batch, n_experts) gate_probs = jax.nn.softmax(gate_logits, axis=-1) # 选择 Top-K top_k_indices = jnp.argsort(-gate_probs, axis=-1)[:, :top_k] top_k_probs = jnp.take_along_axis(gate_probs, top_k_indices, axis=-1) # 重新归一化 top_k_probs = top_k_probs / jnp.sum(top_k_probs, axis=-1, keepdims=True) # 计算专家输出(简化:运行所有专家,之后掩码) expert_outputs = jnp.stack([ expert_fn(x, *experts_params[i]) for i in range(n_experts) ], axis=1) # (batch, n_experts, d_model) # 收集 Top-K 专家输出并加权 batch_idx = jnp.arange(x.shape[0])[:, None] selected_outputs = expert_outputs[batch_idx, top_k_indices] # (batch, top_k, d_model) output = jnp.sum(selected_outputs * top_k_probs[:, :, None], axis=1) return output, gate_probs # 设置 key = jax.random.PRNGKey(42) batch, d_model, d_ff, n_experts = 8, 16, 32, 4 # 初始化专家 experts_params = [] for i in range(n_experts): k1, k2, key = jax.random.split(key, 3)[0], jax.random.split(key, 3)[1], jax.random.split(key, 3)[2] experts_params.append(( jax.random.normal(k1, (d_model, d_ff)) * 0.1, jnp.zeros(d_ff), jax.random.normal(k2, (d_ff, d_model)) * 0.1, jnp.zeros(d_model), )) key, subkey = jax.random.split(key) gate_W = jax.random.normal(subkey, (d_model, n_experts)) * 0.1 x = jax.random.normal(key, (batch, d_model)) output, gate_probs = moe_layer(x, gate_W, experts_params, top_k=2) print(f"输入形状: {x.shape}") print(f"输出形状: {output.shape}") print(f"门控概率(第一个样本): {gate_probs[0]}") print(f"专家使用率(跨批次平均):") for i in range(n_experts): usage = jnp.mean(gate_probs[:, i]) print(f" 专家 {i}: {usage:.3f}")