Skip to content

高级文本生成

高级文本生成超越了普通的自回归解码,旨在提高质量、可控性和速度。本文件涵盖文本扩散模型(D3PM、MDLM)、光学字符识别(OCR)、用于对齐的RLHF和DPO、长上下文方法(RoPE缩放、环注意力)、检索增强生成以及用于更快推理的推测解码。

  • 标准的自回归生成(文件04)从左到右一次生成一个标记。这简单有效,但本质上是顺序的,不允许全局规划,并且对输出的控制有限。本文件涵盖超越普通自回归解码的方法:文本扩散模型、光学字符识别、通过人类反馈实现可控生成、处理长上下文、检索增强生成以及用于更快推理的推测解码。

  • 文本扩散模型将(第08章为图像引入的)扩散框架应用于离散文本。核心挑战在于文本是离散的:你不能像对像素添加噪声那样对标记添加连续的高斯噪声。有几种方法可以解决这个问题。

  • D3PM(离散去噪扩散概率模型,Austin等人,2021)使用转移矩阵直接在离散标记上定义前向损坏过程。在每个前向步骤中,一个标记有一定概率被另一个标记替换(均匀噪声)、被掩码(吸收状态)或保持不变。反向过程学习去噪,从损坏的标记预测干净的标记。步骤\(t\)的转移矩阵\(Q_t\)控制损坏:

\[q(x_t \mid x_{t-1}) = \text{Cat}(x_t ; \, x_{t-1} Q_t)\]
  • 其中\(\text{Cat}\)表示类别分布,\(x\)是一个独热向量。多步前向过程\(q(x_t \mid x_0)\)有一个闭式解:\(q(x_t \mid x_0) = \text{Cat}(x_t ; \, x_0 \bar{Q}_t)\),其中\(\bar{Q}_t = Q_1 Q_2 \cdots Q_t\)是到步骤\(t\)为止所有转移矩阵的乘积。训练最小化一个变分下界,该下界跨时间步分解,类似于连续情况(第08章):
\[\mathcal{L}_{\text{D3PM}} = D_{\text{KL}}(q(x_T \mid x_0) \| p(x_T)) + \sum_{t=2}^{T} D_{\text{KL}}(q(x_{t-1} \mid x_t, x_0) \| p_\theta(x_{t-1} \mid x_t)) - \log p_\theta(x_0 \mid x_1)\]
  • 第一项确保完全损坏的分布与先验(均匀或全掩码)匹配。KL项的总和训练模型逆转每个损坏步骤:真实的后验\(q(x_{t-1} \mid x_t, x_0)\)可以使用贝叶斯规则和已知的转移矩阵以闭式形式计算,模型\(p_\theta(x_{t-1} \mid x_t)\)被训练去匹配它。

  • 由于两个分布都是类别分布,KL散度是对词汇表条目求和。最后一项衡量从最少损坏状态的重建质量。

  • MDLM(掩码扩散语言模型,Sahoo等人,2024)通过使用掩码作为唯一的损坏操作简化了D3PM:前向过程逐渐将标记替换为[MASK]标记,反向过程预测原始标记。这将文本扩散与掩码语言建模(BERT,文件04)联系起来,扩散时间步控制掩码的标记比例。在\(t=0\)时,文本完全干净;在\(t=T\)时,文本完全被掩码。

  • 连续文本扩散通过在连续嵌入空间中工作来回避离散问题。标记首先被映射到它们的嵌入向量(第06章),在这个连续空间中添加噪声,一个去噪模型(通常是Transformer)学习逆转这个过程。在生成时,模型产生连续向量,通过找到最近的嵌入将其映射回离散标记。挑战在于连续空间中的小误差可能导致完全错误的标记,因此需要仔细的舍入和限制。

文本扩散过程

  • 文本扩散的吸引力在于它通过迭代细化同时生成所有标记,而不是从左到右。这允许全局连贯性和轻松填充(在段落中间生成缺失的文本),但目前文本扩散模型在长文本的生成质量上仍然落后于自回归模型。

  • 文本OCR(光学字符识别)是从图像中提取文本的任务。虽然传统上不属于语言生成,但现代OCR系统与NLP深度集成,并越来越多地使用语言模型组件。

  • 场景文本检测定位自然图像中的文本区域(街道标志、产品标签、车牌)。这很有挑战性,因为野外文本以任意角度、尺度、字体出现,并且背景杂乱。检测方法通常使用CNN或Transformer主干来生成文本区域周围的边界框或分割掩码。

  • CRNN(卷积循环神经网络,Shi等人,2017)是一种经典的文本识别架构。CNN从文本图像中提取视觉特征,特征图被切割成一列序列(每个水平位置一列),一个双向LSTM读取此序列以建模上下文。输出使用CTC(连接时序分类)解码,它处理输入列和输出字符之间的对齐,无需显式分割。

  • CTC解决的基本问题:模型产生\(T\)个输出分布(每个输入列一个),但目标文本有\(L \leq T\)个字符。

  • 我们不知道哪些列对应哪些字符。CTC引入一个空白标记\(\epsilon\)并定义一个多对一的映射\(\mathcal{B}\),该映射折叠重复字符并移除空白:\(\mathcal{B}(\text{"HH-ee-ll-ll-oo"}) = \text{"Hello"}\)(其中“-”是空白)。

  • 目标序列\(y\)的概率是坍缩到\(y\)的所有输入对齐路径的概率之和:

\[P(y \mid x) = \sum_{\pi \in \mathcal{B}^{-1}(y)} \prod_{t=1}^{T} P(\pi_t \mid x)\]
  • 其中\(\pi\)是一个长度为\(T\)的对齐路径(每列一个标签,包括空白)。对所有路径朴素求和是指数级的,但前向算法(第05章HMM)使用动态规划在\(O(T \cdot L)\)时间内高效计算此总和。

  • 空白标记是必不可少的:没有它,像“Hello”中重复的“ll”将无法与单个“l”区分开。训练最大化\(\log P(y \mid x)\),在推理时,通过在CTC输出上进行集束搜索或贪婪解码找到最佳路径。

  • 文档OCR处理结构化文档(发票、表格、科学论文),并且除了识别字符外还必须理解布局。像LayoutLM这样的现代系统将文本识别与空间位置特征相结合:每个标记同时获得其文本嵌入和编码其在页面上\((x, y)\)坐标的位置嵌入。这允许模型理解出现在“总计:”下方的数字就是总金额。

CRNN OCR流水线

  • 视觉-语言OCR模型(如TrOCR)将文本识别视为图像到文本的生成:一个视觉Transformer编码器处理图像,一个语言模型解码器逐个字符生成文本。这利用了预训练视觉和语言模型的能力,无需手工特征工程即可处理多样的文字、字体和布局。

  • 可控生成是引导语言模型产生具有期望属性(特定风格、主题、情感、安全级别或事实准确性)输出的挑战。模型应在保持流畅和连贯的同时遵循指令。

  • 文本的无分类器引导(CFG) 改编自图像生成的技术。在训练期间,条件信号(例如提示)以一定比例随机丢弃,同时训练有条件和无条件模型。在推理时,输出logits被插值:

\[\text{logits}_{\text{guided}} = (1 + w) \cdot \text{logits}_{\text{conditional}} - w \cdot \text{logits}_{\text{unconditional}}\]
  • 其中\(w > 0\)放大条件的影响。较高的\(w\)使输出更强烈地遵循提示,但会降低多样性。

  • RLHF(基于人类反馈的强化学习,Ouyang等人,2022)是将语言模型与人类偏好对齐的主要方法。该过程分三个阶段:

  • 首先,监督微调:在高质量人工编写的提示-响应对数据集上微调基础语言模型。

  • 其次,奖励模型训练:收集人类比较(给定提示\(x\)和两个响应\(y_1, y_2\),哪个更好?)并训练奖励模型\(r_\phi(x, y)\)来预测人类偏好。奖励模型使用成对排序损失进行训练:

\[\mathcal{L}_{\text{RM}} = -\log \sigma(r_\phi(x, y_w) - r_\phi(x, y_l))\]
  • 其中\(y_w\)是被偏好的响应,\(y_l\)是不被偏好的响应。

  • 第三,强化学习微调:优化语言模型以最大化奖励,同时保持接近SFT模型(以防止模式坍缩)。这使用带有KL惩罚的PPO(近端策略优化,来自第06章):

\[\mathcal{L}_{\text{RL}} = -\mathbb{E}\left[r_\phi(x, y) - \beta \, D_{\text{KL}}(\pi_\theta \| \pi_{\text{SFT}})\right]\]
  • KL项防止模型偏离基础模型太远并利用奖励模型的怪癖(“奖励破解”)。

RLHF流水线

  • DPO(直接偏好优化,Rafailov等人,2023)通过完全消除奖励模型来简化RLHF。关键的数学洞察是上述KL约束的RL目标有一个闭式最优策略:
\[\pi^\ast(y \mid x) = \frac{1}{Z(x)} \pi_{\text{ref}}(y \mid x) \exp\!\left(\frac{r(x, y)}{\beta}\right)\]
  • 其中\(Z(x)\)是归一化配分函数。将其重新排列为奖励形式给出\(r(x, y) = \beta \log \frac{\pi^\ast(y \mid x)}{\pi_{\text{ref}}(y \mid x)} + \beta \log Z(x)\)。将这个隐式奖励代入Bradley-Terry偏好模型\(P(y_w \succ y_l) = \sigma(r(x, y_w) - r(x, y_l))\),导致难以处理的\(Z(x)\)项抵消,直接产生DPO损失:
\[\mathcal{L}_{\text{DPO}} = -\log \sigma\!\left(\beta \log \frac{\pi_\theta(y_w \mid x)}{\pi_{\text{ref}}(y_w \mid x)} - \beta \log \frac{\pi_\theta(y_l \mid x)}{\pi_{\text{ref}}(y_l \mid x)}\right)\]
  • 这在数学上等价于RLHF,但将奖励模型和RL训练压缩成一个监督步骤。

  • sigmoid内部的表达式可以解读为:“增加偏好响应的相对概率,降低不偏好响应的相对概率,以参考模型为基准衡量。”

  • \(\beta\)参数控制策略可以偏离参考的程度。在实践中,DPO实现更简单(仅计算当前模型和参考模型下两个补全的对数概率),并避免了PPO训练的不稳定性。

  • 宪法式AI(Bai等人,2022)自动化了对齐过程的部分。它不收集人类比较,而是使用语言模型自身根据一组原则(“宪法”)来批评和修订自己的输出,例如“选择伤害较小的响应”。然后使用AI生成的比较进行偏好训练(RLAIF:来自AI反馈的RL)。

  • 长上下文方法解决了标准自注意力的\(O(n^2)\)内存和计算成本,这限制了序列长度。当\(n\)增长到数万或数十万标记时,标准注意力变得不可行。

  • 稀疏注意力用稀疏模式替换密集的\(n \times n\)注意力矩阵,其中每个标记只关注其他标记的子集。常见模式包括局部注意力(每个标记关注固定大小的邻居窗口)、步长注意力(关注每第\(k\)个标记)和随机注意力(关注随机子集)。这些模式的组合(用于BigBird、Longformer)实现了\(O(n)\)\(O(n \sqrt{n})\)复杂度,同时保持捕获局部和全局依赖关系的能力。

稀疏注意力模式

  • 滑动窗口注意力限制每个标记仅关注前\(w\)个标记(其局部窗口)。这是\(O(nw)\)而非\(O(n^2)\),但长距离信息必须通过层间的重叠窗口传播。使用\(L\)层和窗口大小\(w\),有效感受野为\(L \times w\)个标记。

  • 环注意力通过将长序列以环形拓扑分布在多个设备上。每个设备持有一块序列,并计算其块的注意力,同时向环中的下一个设备发送键值块。这使计算与通信重叠,并允许任意长度的序列,其限制仅由所有设备的总内存决定,而非单个设备的内存。

  • 记忆增强模型通过为Transformer配备外部记忆库来扩展上下文。在每一层,模型可以使用注意力从该记忆中读取和写入。记忆化Transformer缓存来自先前块的键值对,并在后续块中关注它们,有效地将上下文扩展到训练窗口之外。检索是近似进行(使用缓存键的\(k\)-最近邻)以保持高效。

  • 上述方法是长上下文的架构解决方案。同样重要的是如何训练模型以有效利用长上下文。

  • 渐进式上下文扩展是标准方法。从一开始就在非常长的序列上训练成本高得令人望而却步(\(O(n^2)\)注意力成本),因此模型在较短的上下文长度(通常4K-8K标记)上进行预训练,然后继续预训练分阶段扩展到目标长度。

  • Llama 3.1 在8000亿标记上将上下文从8K扩展到128K,序列长度逐渐增加。DeepSeek-V3 在4K上训练,然后扩展到32K,再到128K。

  • 每个阶段使用的标记数量相对较少(相对于完整的预训练预算),因为模型只需要学习如何使用更长的位置,而不是重新学习语言本身。

  • 在扩展期间必须调整位置编码。RoPE插值缩小位置索引,使得模型看到与训练时相同的旋转角度,只是分布在更长的序列上。如果模型在长度\(L\)上训练,你想扩展到\(L' = 4L\),你将所有位置索引除以4。

  • 这意味着模型永远不会看到它未曾遇到的旋转角度,但相邻位置之间的有效分辨率会下降。

  • RoPE外推保持原始位置索引不变,只是将RoPE应用于超过\(L\)的位置,依赖模型泛化到未见过的角度。

  • 插值要稳定得多;外推如果不调整基频(ABF)会迅速退化。

  • YaRN(又一个RoPE扩展)通过认识到并非所有RoPE维度都应同等对待来改进朴素插值。

  • 高频维度(\(\theta_i = \theta_{\text{base}}^{-2i/d}\)中较小的\(i\))在训练长度内旋转多次,可以很好地外推。

  • 低频维度(较大的\(i\))旋转缓慢,对长度扩展更敏感。

  • YaRN仅插值低频维度,外推高频维度,并对注意力logits应用温度缩放\(t\)以补偿分布偏移:

\[\text{score}'_{ij} = \frac{q_i^T k_j}{t \sqrt{d_k}}\]
  • 其中\(t > 1\)使注意力分布平坦化,防止模型在位置信号被压缩时过度关注附近的标记。

  • 长上下文数据整理是一个关键且常被低估的挑战。大多数预训练语料库由短文档(新闻文章、网页、社交媒体帖子)组成。

  • 长上下文训练需要一个能实际锻炼完整上下文窗口的数据混合:书籍、代码仓库、长篇科学文章、多轮对话日志以及主题相关的串联文档。

  • 如果模型仅在填充或打包以填满上下文窗口的短文档上训练,它会学会忽略远处的标记,因为它们从不相关。

  • 序列打包是一种训练效率技术:多个文档连接成一个训练序列以避免填充浪费,注意力掩码防止跨文档注意力。

  • 对于长上下文训练,打包策略很重要:打包许多不相关的短文档教会模型远处的标记是噪声,而打包更少的、真正长的文档教会它使用完整上下文。

  • 一个已知的失败模式是“中间丢失”现象(Liu等人,2023):语言模型倾向于有效使用上下文窗口开头和结尾的信息,但难以处理放在中间的信息。

  • 这类似于人类记忆中的序列位置效应(首因效应和近因效应)。

  • 其部分原因是训练数据分布(重要信息通常在文档的开头或结尾),部分原因是注意力模式集中在附近和初始的标记上。

  • 使用关键信息不同放置位置的长上下文训练可以缓解但不能完全解决这个问题。

  • 大海捞针评估测试模型是否能够从长干扰上下文(“干草堆”)中检索放置在各个位置的特定事实(“针”)。

  • 一个具有真正长上下文能力的模型无论针放在哪里都应该实现近乎完美的检索。

  • 这个测试清楚地揭示了中间丢失效应,并用于基准测试上下文扩展方法。

  • 预训练后的长上下文微调使用有针对性的SFT数据:长多轮对话、证据散布在数千个标记中的文档问答、长文本摘要以及仓库级代码理解。

  • Qwen3 在此阶段使用双块注意力(DCA),它将长序列处理成块对,其中块内注意力是完整的,块间注意力是高效的,在微调期间实现了4倍的有效序列容量。

  • 状态空间模型(SSM)提供了一种根本不同的长序列建模方法。它们不是修改注意力,而是用受连续时间控制理论启发的线性动力系统完全替换它。

  • 一个SSM通过一个潜在状态\(x(t) \in \mathbb{R}^N\)将输入序列\(u(t)\)映射到输出\(y(t)\),由以下方程控制:

\[x'(t) = Ax(t) + Bu(t), \quad y(t) = Cx(t) + Du(t)\]
  • 其中\(A \in \mathbb{R}^{N \times N}\)是状态转移矩阵,\(B \in \mathbb{R}^{N \times 1}\)是输入投影,\(C \in \mathbb{R}^{1 \times N}\)是输出投影,\(D\)是跳跃连接。

  • 为了将其应用于离散序列(标记),使用步长\(\Delta\)对连续系统进行离散化。零阶保持离散化给出:

\[\bar{A} = \exp(\Delta A), \quad \bar{B} = (\Delta A)^{-1}(\exp(\Delta A) - I) \cdot \Delta B\]
  • 然后离散递归变为\(x_k = \bar{A} x_{k-1} + \bar{B} u_k\)\(y_k = C x_k + D u_k\),看起来像一个RNN:用隐藏状态一次处理一个标记。

  • 与RNN不同,这个递归也可以展开为一个全局卷积:因为系统是线性的,输出是\(y = \bar{K} \ast u\),其中核\(\bar{K} = (C\bar{B}, \, C\bar{A}\bar{B}, \, C\bar{A}^2\bar{B}, \ldots)\)仅依赖于固定参数。

  • 这种双重观点——递归用于高效自回归推理(每步\(O(1)\))和卷积用于高效并行训练(通过FFT实现\(O(n \log n)\))——是SSM的核心洞察。

SSM双重观点:用于推理的递归,用于训练的卷积,以及Mamba的选择性扩展

  • S4(结构化状态空间序列建模,Gu等人,2022)通过解决关键数值挑战使SSM实用:状态矩阵\(A\)必须捕获长距离依赖,但朴素参数化会导致动态消失或爆炸(与普通RNN相同的问题)。

  • S4使用HiPPO(高阶多项式投影算子)矩阵初始化\(A\),该矩阵源于连续信号最优多项式逼近理论。HiPPO矩阵具有特定结构,可证明使状态能够以优雅衰减的方式维护整个输入历史的压缩表示:

\[ A_{nk} = -\begin{cases} (2n+1)^{1/2}(2k+1)^{1/2} & \text{if } n > k \\ n+1 & \text{if } n = k \\ 0 & \text{if } n < k \end{cases} \]
  • 这种下三角结构确保状态使用Legendre多项式充当输入信号的在线逼近。计算长核的\(\bar{A}^k\)代价高昂,因此S4利用HiPPO矩阵可以分解为低秩项和对角项之和这一事实,实现了\(O(n \log n)\)的核计算。

  • Mamba(Gu和Dao,2023)引入了选择性状态空间的关键创新:使SSM参数依赖于输入。在S4中,矩阵\(A\)\(B\)\(C\)和步长\(\Delta\)是固定的——无论内容如何,相同的动态应用于每个标记。Mamba使\(B\)\(C\)\(\Delta\)成为输入的函数:

\[B_k = \text{Linear}(u_k), \quad C_k = \text{Linear}(u_k), \quad \Delta_k = \text{softplus}(\text{Linear}(u_k))\]
  • 这种选择性允许模型在每个位置决定在状态中存储什么信息以及忽略什么——类似于注意力选择相关标记的方式,但没有二次成本。步长\(\Delta_k\)控制“门”:大的\(\Delta\)导致状态强烈整合当前输入(连续动态向前迈出一大步,有效地重置状态),而小的\(\Delta\)则保留现有状态并忽略当前输入。

  • 权衡是依赖输入的参数破坏了卷积视图(核不再是固定的),因此Mamba不能使用基于FFT的训练。相反,它使用硬件感知的并行扫描算法,利用递归的结合性:状态更新\((x_k, u_k) \mapsto x_{k+1}\)可以表示为一系列结合操作,并使用前缀和(扫描)进行并行化,类似于硬件设计中的并行前缀加法。这在GPU上以\(O(n)\)时间、\(O(\log n)\)深度运行,几乎匹配卷积的效率。

  • Mamba实现了真正每步\(O(1)\)的推理(仅更新固定大小的状态,没有随上下文增长的KV缓存),这使得在长序列长度上其内存效率从根本上优于Transformer。状态大小\(N\)(通常为16)远小于Transformer的KV缓存,后者存储\(O(n \cdot d)\)个值。在实践中,在相同的参数数量下,Mamba在语言建模基准上匹配或超过Transformer的质量,并且在长序列上推理显著更快。

  • 混合架构将SSM层与注意力层结合起来,使用SSM处理大多数层(高效长距离传播),并穿插少量注意力层(精确的基于内容的检索)。像Jamba和Zamba这样的模型交错Mamba和Transformer块,实现了比纯SSM更好的质量,同时保持了大部分推理效率优势。这表明注意力和SSM捕捉到了互补的能力:SSM擅长平滑的长距离状态传播,而注意力擅长精确的、依赖内容的查找。

  • 检索增强生成(RAG)通过在推理时让语言模型访问外部知识库来解决语言模型的知识局限。RAG不是仅仅依赖训练期间编码在模型参数中的知识,而是检索相关文档并基于它们进行生成。

  • 经典的检索器-阅读器架构有两个组件。检索器接收查询并从语料库中获取最相关的\(k\)个段落。阅读器(一个语言模型)基于查询和检索到的段落生成答案。检索器可以使用稀疏方法(BM25,它扩展了文件02中的TF-IDF)或稠密方法。

  • 稠密段落检索(DPR)使用双编码器架构:一个编码器将问题映射到向量,另一个将段落映射到向量。两者通常基于BERT。在索引时,所有段落被编码并存储。在查询时,问题被编码,并使用近似最近邻搜索(如FAISS)找到最近的段落。相似度度量是问题和段落向量之间的点积。

  • 分块策略显著影响检索质量。文档必须被分割成足够小的段落以供检索器处理,但又足够大以包含完整的思想。固定大小的分块(例如,256个标记,50个标记重叠)很简单,但可能会不自然地分割句子。语义分块在段落或章节边界处分割。层次化分块创建不同粒度的摘要树。

RAG架构

  • RAG提供了几个优点:知识库可以无需重新训练模型即可更新,模型可以引用来源,并且由于模型可以将其答案基于检索到的文本,幻觉得以减少。主要挑战是检索质量(如果检索到错误的段落,模型可能会自信地给出错误答案)和延迟(检索给推理增加了一个步骤)。

  • 推测解码通过使用一个小的、快速的草稿模型并行提出多个标记,然后由大的目标模型在单次前向传递中验证,从而加速自回归生成。

  • 算法如下:草稿模型自回归地生成\(k\)个候选标记(这很快,因为草稿模型很小)。

  • 然后目标模型在单次前向传递中同时对所有\(k\)个标记进行评分(这很高效,因为工作被批处理了)。

  • 对于从草稿分布\(p_d(t)\)中采样的每个候选标记\(t\),它以概率\(\min(1, \, p_{\text{target}}(t) / p_d(t))\)被接受。如果被拒绝,则从调整后的分布\(p_{\text{adj}}(t) = \max(0, \, p_{\text{target}}(t) - p_d(t))\)(归一化后)重新采样一个修正后的标记。

  • 这种接受-拒绝方案保证了输出分布与仅使用目标模型相同。

  • 为了理解原因,考虑发出标记\(t\)的有效概率。它可以被直接接受(概率\(p_d(t) \cdot \min(1, p_{\text{target}}(t)/p_d(t))\))或通过重采样产生。

  • 对于\(p_{\text{target}}(t) \leq p_d(t)\)的标记,直接接受贡献\(p_{\text{target}}(t)\)。对于\(p_{\text{target}}(t) > p_d(t)\)的标记,直接接受贡献\(p_d(t)\),重采样贡献剩余部分\(p_{\text{target}}(t) - p_d(t)\)(考虑到拒绝概率后)。

  • 在两种情况下,发出\(t\)的总概率等于\(p_{\text{target}}(t)\)。草稿模型只影响速度,不影响质量。

推测解码

  • 加速取决于接受率:如果草稿模型与目标模型对齐良好,大多数标记被接受,墙钟时间大致等于草稿模型的时间。典型的加速是在没有质量下降的情况下达到2-3倍。

  • Medusa(Cai等人,2024)采用了不同的方法:它不单独使用草稿模型,而是向目标模型本身添加多个轻量级预测头。每个头同时预测不同的未来标记位置(\(k = 1, 2, 3, \ldots\)步之后)。在每一步中,Medusa使用树结构提出多个候选延续,通过目标模型注意力层的单次前向传递验证哪些候选是一致的。这完全避免了需要单独的草稿模型。

  • 并行生成方法更广泛地旨在打破自回归解码的顺序瓶颈。Jacobi解码用猜测初始化所有位置,并并行迭代细化直到收敛,将生成视为不动点迭代。非自回归模型在单次前向传递中同时生成所有标记,但通常会遭受质量下降,需要诸如迭代细化、CTC损失或来自自回归教师的知识蒸馏等技术来缩小差距。

  • 上述技术——对齐、长上下文、检索、高效解码、状态空间模型——在现代生产级LLM中结合在一起。

  • 本文件的其余部分调查了前沿模型的架构创新,展示了文件01-04中的理论思想和上述方法在实践中是如何结合的。

  • 分组查询注意力(GQA)是最广泛采用的注意力效率技术。标准的多头注意力为每个头维护单独的键和值投影,需要每个标记缓存\(n_{\text{heads}} \times d_{\text{head}}\)个值。GQA将多个查询头分组,共享一个键值头。

  • 使用64个查询头和8个KV头(Llama 3、Qwen、Gemma中的常见配置),每个KV头被8个查询头共享,与MHA相比,KV缓存减少了8倍。

  • 输出质量与MHA几乎相同,因为查询仍然可以关注不同的模式,它们只是共享相同的键值子空间。多查询注意力是极端情况,所有查询共享一个KV头,但GQA提供了更好的质量-效率权衡。

  • 多头潜在注意力(MLA),在DeepSeek-V2中引入,实现了更激进的KV缓存压缩。MLA不是缓存完整的键值投影(即使有GQA),而是将隐藏状态下投影到一个低秩潜在向量\(c_t \in \mathbb{R}^{d_c}\),其中\(d_c \ll n_{\text{heads}} \times d_{\text{head}}\)

\[c_t = W_{\text{down}} \, h_t\]
  • 只有这个压缩向量被缓存。在注意力计算时,完整的键和值表示通过上投影重建:\(k_t = W_{\text{up}}^K c_t\)\(v_t = W_{\text{up}}^V c_t\)。在DeepSeek-V3(总参数671B,激活参数37B)中,压缩维度\(d_c = 512\),而完整MHA为\(128 \times 128 = 16{,}384\),KV缓存减少了93%。

  • 一个微妙之处:标准RoPE依赖于位置,与共享压缩不兼容,因此MLA使用解耦的RoPE:一小部分独立的查询和键流(每头64维)通过RoPE携带位置信息,而表示的主体部分通过压缩的潜在路径流动。

注意力KV缓存策略:MHA、GQA和MLA比较

  • 大规模位置编码与原始的正弦方案已有显著不同。所有前沿模型都使用RoPE(文件04),但为了长上下文做了关键修改。原始RoPE公式\(\theta_i = \theta_{\text{base}}^{-2i/d}\)中的基频\(\theta_{\text{base}}\)通常为10,000,这限制了超出训练长度的外推。

  • 调整基频(ABF) 简单地将\(\theta_{\text{base}}\)增加到500,000(Llama 3)或1,000,000(Qwen3、Gemma 3),拉伸旋转周期,使得模型在训练期间遇到较少完整的旋转,并能外推得更远。

  • YaRN(又一个RoPE扩展)应用频率相关的插值:低频维度被插值(缩小),高频维度被外推,温度因子调整注意力分布。DeepSeek-V3、Qwen和Kimi K2都使用基于YaRN的扩展,从在4K-8K上预训练的模型达到128K上下文。

  • iRoPE(交错RoPE),在Llama 4中引入,采取了更激进的方法:每第4个注意力层完全不使用位置编码(NoPE),而其他层使用带有分块注意力的标准RoPE。

  • NoPE层可以关注所有位置而没有任何位置偏差,而RoPE层提供局部顺序。结合推理时的温度缩放,这使Llama 4 Scout能够实现1000万标记的上下文窗口——比任何纯RoPE方法都高出几个数量级。

  • 大规模的混合专家已成为前沿模型的主导架构(文件04介绍了MoE的基础知识)。关键的设计选择是专家数量、路由稀疏性和负载均衡。

  • 路由稀疏性差异显著:DeepSeek-V3使用256个专家,top-8路由(32倍稀疏性);Qwen3使用128个专家,top-8(16倍稀疏性);Mixtral使用8个专家,top-2(4倍稀疏性);Llama 4 Maverick使用128个专家,top-1加一个共享专家(128倍稀疏性)。

  • 更高的稀疏性意味着相同激活计算量下更多的总参数,但需要更仔细的负载均衡和通信基础设施。

  • 无辅助损失的负载均衡(DeepSeek-V3)取代了传统的负载均衡损失(文件04),后者被发现会降低模型质量。取而代之,每个专家维护一个动态偏置项,在每个训练步骤进行调整:过载的专家偏置降低(接收更少的标记),欠载的专家偏置增加。这实现了平衡的路由,没有辅助损失污染主训练信号。

  • 共享专家出现在大多数MoE设计中:一个或多个专家FFN处理每个标记,无论路由结果如何。它们处理所有标记都需要的常见模式(基本句法、功能词),释放被路由的专家进行特化。Llama 4每个标记使用1个共享专家加1个路由专家(非常稀疏);DeepSeek-V3使用1个共享加8个路由。

  • 交替的稠密层和MoE层提供了另一个设计维度。Gemma 2和3交替使用局部/全局注意力层(Gemma 3中比例为5:1,其中局部层使用1024标记的滑动窗口,只有全局层缓存完整的128K上下文)。

  • Llama 4 Maverick交错使用稠密FFN层和MoE层。Kimi K2使用混合稀疏层(在专家层之间穿插一个稠密层)。这种异构设计允许不同的层服务于不同的功能。

  • 多标记预测(MTP),在DeepSeek-V3中使用,训练模型不仅预测下一个标记,还预测之后的标记。在每个位置,一个次级预测模块(共享主模型的嵌入)预测一个额外的未来标记。MTP损失相对于主下一个标记损失的权重为0.1-0.3。除了在训练期间提高表示质量外,MTP头还可以在推理时用作推测解码的草稿头,提供免费加速。

  • 知识蒸馏是一种训练策略,其中大型“教师”模型的输出指导较小“学生”模型的训练。Gemma 2和3广泛使用蒸馏:较小的模型(2B、4B)在计算最优数据量50倍的数据上训练,使用教师的概率分布作为软目标。这就是为什么Gemma 3-4B在质量上匹配Gemma 2-27B。

  • 蒸馏损失替换或补充了标准交叉熵:学生最小化其输出分布与教师分布之间的KL散度:

\[\mathcal{L}_{\text{distill}} = D_{\text{KL}}(p_{\text{teacher}}(\cdot \mid x) \| p_{\text{student}}(\cdot \mid x))\]
  • DeepSeek-R1将其671B推理模型蒸馏成小至1.5B的稠密模型,使用了80万个精选的思维链样本,产生了具有不成比例强大推理能力的小型模型。

  • 通过强化学习进行推理代表了LLM能力最近最显著的进步。DeepSeek-R1展示了在基础模型上进行纯强化学习(没有监督微调)可以引发出思维链推理、自我验证和纠错行为,这些行为在模型因正确答案获得奖励时会自发出现。

  • DeepSeek-R1使用GRPO(群组相对策略优化),它消除了PPO所需的价值网络。对于每个提示,GRPO采样一组\(G\)个输出,计算它们的奖励,并在组内对优势进行归一化:

\[A_i = \frac{r_i - \text{mean}(r_1, \ldots, r_G)}{\text{std}(r_1, \ldots, r_G)}\]
  • 然后策略梯度使用这些群组相对优势和一个裁剪的目标函数(类似于PPO的裁剪)。

  • 消除批评者网络将RL训练的内存和计算需求减半,使得用RL训练671B参数的模型变得可行。

  • 一个关键的设计选择:DeepSeek-R1使用基于规则的奖励(根据标准答案检查数学答案,运行代码测试用例)而不是神经奖励模型,因为在这个规模上发现神经奖励模型容易受到奖励破解的影响。

  • Qwen3的混合思考模式将推理(使用<think>标签进行逐步思维链)和快速直接响应集成到单个模型中,允许用户控制“思考预算”,以权衡延迟和推理深度。

  • 这是通过在思考和思考数据上共同训练实现的,而不是通过单独的模型检查点。

  • 大规模训练稳定性需要超越标准实践的新技术。Logit软裁剪(Gemma 2)通过\(s \cdot \tanh(\text{logits} / s)\)传递注意力分数,并使用软上限\(s\)(通常30-50)来防止无界增长。

  • QK-Norm(Qwen3)在计算注意力分数之前对查询和键向量应用RMSNorm,取代了对QKV偏置的需求。QK-Clip(Kimi K2的MuonClip优化器)在训练期间监控最大注意力logit,当超过阈值时重新缩放查询-键权重矩阵,使得能够稳定预训练万亿参数模型而零不稳定事件。

  • FP8混合精度训练(DeepSeek-V3)在前向和后向传播中对计算密集的矩阵乘法使用8位浮点数,而将主权重保持在更高精度。

  • 与BF16/FP16训练相比,这大致使吞吐量翻倍,而质量损失可忽略不计。DeepSeek-V3训练其671B参数模型仅用了280万H800 GPU小时——是同类模型的一小部分——这主要归功于这一点和其他工程优化。

编码任务(使用 CoLab 或 notebook)

  1. 从头实现一个简单的检索增强生成流水线。使用TF-IDF(文件02)为文档集建立索引,为查询检索最相关的段落,并将其前置到提示中。

    import jax.numpy as jnp
    import math
    from collections import Counter
    
    # 知识库:一组短文
    knowledge_base = [
        "埃菲尔铁塔是法国巴黎的一座锻铁格子塔。它建于1887年至1889年,是1889年世界博览会的中心建筑。",
        "中国长城是在中国北部边境修建的一系列防御工事。始建于公元前7世纪。",
        "光合作用是植物利用叶绿素将阳光、水和二氧化碳转化为葡萄糖和氧气的过程。",
        "广义相对论由阿尔伯特·爱因斯坦于1915年发表,将引力描述为质量和能量引起的时空曲率。",
        "Python是一种以其简洁语法和可读性而闻名的高级编程语言。它由Guido van Rossum创建并于1991年发布。",
        "线粒体是真核细胞中发现的细胞器。它们产生细胞大部分的三磷酸腺苷供应,用作化学能源。",
    ]
    
    # 构建 TF-IDF 索引(复用文件02的概念)
    def tokenise(text):
        return text.lower().split()
    
    vocab = sorted(set(w for doc in knowledge_base for w in tokenise(doc)))
    word2idx = {w: i for i, w in enumerate(vocab)}
    V = len(vocab)
    N = len(knowledge_base)
    
    # 文档频率
    doc_freq = Counter()
    for doc in knowledge_base:
        for w in set(tokenise(doc)):
            doc_freq[w] += 1
    
    def tfidf_vector(text):
        words = tokenise(text)
        counts = Counter(words)
        vec = jnp.zeros(V)
        for w, c in counts.items():
            if w in word2idx:
                tf = 1 + math.log(c)
                idf = math.log(N / (doc_freq.get(w, 0) + 1))
                vec = vec.at[word2idx[w]].set(tf * idf)
        return vec
    
    # 索引所有文档
    doc_vectors = jnp.stack([tfidf_vector(doc) for doc in knowledge_base])
    
    def cosine_sim(a, b):
        return jnp.dot(a, b) / (jnp.linalg.norm(a) * jnp.linalg.norm(b) + 1e-8)
    
    def retrieve(query, top_k=2):
        """为查询检索最相关的前 k 个段落。"""
        q_vec = tfidf_vector(query)
        sims = jnp.array([cosine_sim(q_vec, doc_vectors[i]) for i in range(N)])
        top_indices = jnp.argsort(-sims)[:top_k]
        return [(int(i), float(sims[i]), knowledge_base[int(i)]) for i in top_indices]
    
    # 测试检索
    queries = [
        "谁建造了埃菲尔铁塔?",
        "植物如何制造食物?",
        "爱因斯坦发现了什么?",
    ]
    
    for query in queries:
        results = retrieve(query, top_k=1)
        print(f"\n查询: '{query}'")
        for idx, sim, passage in results:
            print(f"  检索到 (相似度={sim:.3f}): '{passage[:80]}...'")
    
        # RAG 风格的提示构造
        context = results[0][2]
        rag_prompt = f"上下文:{context}\n\n问题:{query}\n答案:"
        print(f"  RAG 提示:\n    {rag_prompt[:120]}...")
    

  2. 使用玩具草稿模型和目标模型实现推测解码。展示接受的输出与目标模型的分布匹配。

    import jax
    import jax.numpy as jnp
    
    # 模拟草稿模型(快速,精度较低)和目标模型(慢速,精度高)
    vocab_size = 8
    seq_len = 5
    
    key = jax.random.PRNGKey(42)
    
    # 目标模型:给定序列返回 logits
    def target_model(seq, key):
        """模拟的目标模型:产生标记 logits(昂贵)。"""
        # 实际中这会是一个大型 Transformer 前向传播
        k1, k2 = jax.random.split(key)
        logits = jax.random.normal(k1, (len(seq), vocab_size)) * 2
        # 使其具有某种可预测性:偏向于标记 (seq[-1] + 1) % vocab_size
        for i in range(len(seq)):
            logits = logits.at[i, (seq[i] + 1) % vocab_size].add(3.0)
        return logits
    
    def draft_model(seq, key):
        """模拟的草稿模型:相似但噪声更大(便宜)。"""
        k1, k2 = jax.random.split(key)
        logits = jax.random.normal(k1, (len(seq), vocab_size))
        for i in range(len(seq)):
            logits = logits.at[i, (seq[i] + 1) % vocab_size].add(2.0)
        return logits
    
    def sample_token(logits, key):
        return jax.random.categorical(key, logits)
    
    def speculative_decode(prefix, draft_steps=3, key=jax.random.PRNGKey(0)):
        """推测解码:草稿提议,目标验证。"""
        seq = list(prefix)
        total_accepted = 0
        total_proposed = 0
    
        for _ in range(4):  # 生成 4 轮
            key, *subkeys = jax.random.split(key, draft_steps + 3)
    
            # 草稿模型提议 draft_steps 个标记
            draft_tokens = []
            draft_probs = []
            draft_seq = list(seq)
            for i in range(draft_steps):
                d_logits = draft_model(jnp.array(draft_seq), subkeys[i])
                d_probs = jax.nn.softmax(d_logits[-1])
                tok = sample_token(d_logits[-1], subkeys[i])
                draft_tokens.append(int(tok))
                draft_probs.append(d_probs)
                draft_seq.append(int(tok))
    
            # 目标模型在一次前向传播中对所有草稿标记评分
            target_logits = target_model(jnp.array(draft_seq), subkeys[draft_steps])
            target_start = len(seq) - 1  # 最后一个前缀标记的位置
    
            # 接受/拒绝每个草稿标记
            accepted = 0
            for i in range(draft_steps):
                t_probs = jax.nn.softmax(target_logits[target_start + i])
                d_prob = draft_probs[i][draft_tokens[i]]
                t_prob = t_probs[draft_tokens[i]]
    
                # 以概率 min(1, target_prob / draft_prob) 接受
                accept_prob = jnp.minimum(1.0, t_prob / (d_prob + 1e-10))
                key, accept_key = jax.random.split(key)
                if jax.random.uniform(accept_key) < accept_prob:
                    seq.append(draft_tokens[i])
                    accepted += 1
                else:
                    # 拒绝:从调整后的分布采样
                    key, resample_key = jax.random.split(key)
                    adjusted = jnp.maximum(0, t_probs - draft_probs[i])
                    adjusted = adjusted / (adjusted.sum() + 1e-10)
                    new_tok = jax.random.categorical(resample_key, jnp.log(adjusted + 1e-10))
                    seq.append(int(new_tok))
                    break
    
            total_accepted += accepted
            total_proposed += draft_steps
    
        return seq, total_accepted, total_proposed
    
    # 运行推测解码
    prefix = [0, 1]
    result_seq, accepted, proposed = speculative_decode(prefix)
    acceptance_rate = accepted / proposed if proposed > 0 else 0
    
    print(f"前缀: {prefix}")
    print(f"生成的序列: {result_seq}")
    print(f"草稿提议数: {proposed}")
    print(f"接受数: {accepted}")
    print(f"接受率: {acceptance_rate:.1%}")
    print(f"潜在加速: {(accepted + proposed) / proposed:.2f}x")
    

  3. 构建一个简单的 DPO 训练循环。给定偏好和非偏好补全对,使用 DPO 损失更新一个小模型。

    import jax
    import jax.numpy as jnp
    
    # 微型语言模型:从独热到 logits 的线性投影
    vocab_size = 10
    seq_len = 4
    
    key = jax.random.PRNGKey(42)
    k1, k2 = jax.random.split(key)
    
    # 当前策略参数(可训练)
    theta = jax.random.normal(k1, (vocab_size, vocab_size)) * 0.1
    # 参考策略参数(初始 theta 的冻结副本)
    theta_ref = theta.copy()
    
    def log_prob_sequence(params, sequence):
        """在简单的自回归模型下计算 P(序列) 的对数。"""
        total = 0.0
        for t in range(1, len(sequence)):
            # 简单:位置 t 的 logits 依赖于位置 t-1 的标记
            logits = params[sequence[t-1]]
            log_probs = jax.nn.log_softmax(logits)
            total += log_probs[sequence[t]]
        return total
    
    def dpo_loss(theta, theta_ref, preferred, dispreferred, beta=0.1):
        """单对偏好的直接偏好优化损失。"""
        log_pi_w = log_prob_sequence(theta, preferred)
        log_pi_l = log_prob_sequence(theta, dispreferred)
        log_ref_w = log_prob_sequence(theta_ref, preferred)
        log_ref_l = log_prob_sequence(theta_ref, dispreferred)
    
        # DPO 目标
        return -jax.nn.log_sigmoid(
            beta * ((log_pi_w - log_ref_w) - (log_pi_l - log_ref_l))
        )
    
    # 偏好数据集:(提示前缀, 偏好补全, 非偏好补全)
    preferences = [
        (jnp.array([1, 3, 5, 7]), jnp.array([1, 3, 5, 2])),  # 末尾偏好 7 而非 2
        (jnp.array([0, 2, 4, 6]), jnp.array([0, 2, 4, 9])),  # 偏好 6 而非 9
        (jnp.array([3, 3, 3, 3]), jnp.array([3, 3, 3, 0])),  # 偏好重复而非 0
        (jnp.array([5, 6, 7, 8]), jnp.array([5, 6, 7, 1])),  # 偏好 8 而非 1
    ]
    
    grad_fn = jax.jit(jax.grad(dpo_loss))
    lr = 0.05
    
    print("训练 DPO...")
    for epoch in range(100):
        total_loss = 0.0
        for preferred, dispreferred in preferences:
            loss = dpo_loss(theta, theta_ref, preferred, dispreferred)
            grads = grad_fn(theta, theta_ref, preferred, dispreferred)
            theta = theta - lr * grads
            total_loss += loss
        if (epoch + 1) % 20 == 0:
            avg_loss = total_loss / len(preferences)
            print(f"  轮次 {epoch+1}: 平均 DPO 损失 = {avg_loss:.4f}")
    
    # 检查:模型现在应该偏好偏好补全
    print("\nDPO 训练后的偏好检查:")
    for preferred, dispreferred in preferences:
        lp_w = log_prob_sequence(theta, preferred)
        lp_l = log_prob_sequence(theta, dispreferred)
        print(f"  偏好 {list(preferred.astype(int))}: logP={lp_w:.3f}  "
              f"非偏好 {list(dispreferred.astype(int))}: logP={lp_l:.3f}  "
              f"{'正确' if lp_w > lp_l else '错误'}")