Skip to content

高效架构

让模型更快不仅仅是降低精度,还在于设计更智能的架构,让每个token做更少的计算。本文件涵盖StreamingLLM、稀疏注意力和线性注意力、多查询和分组查询注意力、推理时的混合专家、知识蒸馏、剪枝和神经架构搜索

  • 量化(文件01)让每次操作更便宜。本文件让操作数量首先变得更少。两者互补:一个既在架构上高效又被量化的模型可以比原始模型快10-100倍。

StreamingLLM:无限长度生成

  • 标准Transformer将所有先前token存储在KV缓存中,缓存随序列长度线性增长。在某个时刻,缓存超过GPU内存,生成失败。StreamingLLM(Xiao et al., 2023)通过固定大小的滚动KV缓存解决此问题。

  • 关键观察:序列中的前几个token无论其内容如何,都获得不成比例的高注意力分数。这些被称为注意力汇(attention sinks)。如果将它们从缓存中剔除,注意力分布会崩塌,生成质量灾难性下降。

  • StreamingLLM的解决方案:在缓存中永久保留少量汇token(前1-4个token),加上最近\(w\)个token的滚动窗口。总缓存大小为\(\text{sink} + w\),无论已生成多少token都固定不变。

\[\text{Cache} = [\text{token}_0, \text{token}_1, \text{token}_{t-w+1}, \ldots, \text{token}_t]\]
  • 注意力汇锚定softmax分布,滚动窗口提供近期上下文。这使得以恒定内存实现无限长度生成成为可能,代价是丢失序列中间部分的上下文。

  • StreamingLLM对已自然形成注意力汇的模型(大多数预训练LLM都有)无需任何重训练即可工作。对于没有的模型,在训练时添加一个可学习的汇token即可修复。

稀疏注意力

  • 完整自注意力在序列长度\(n\)上是\(O(n^2)\)的,因为每个token关注所有其他token。对于\(n = 128K\),注意力矩阵有\(128K^2 = 160\)亿个条目。稀疏注意力模式通过限制哪些token关注哪些来减少这个数量。

注意力稀疏模式:完整注意力是O(n²),滑动窗口是O(n·w),局部+全局添加了长距离token

  • 滑动窗口注意力(Mistral、Gemma):每个token只关注前\(w\)个token(例如\(w = 4096\))。注意力是\(O(n \cdot w)\)而非\(O(n^2)\)。信息通过多个层在窗口之外传播:经过\(L\)层后,有效上下文是\(L \times w\)

  • 局部+全局注意力(Longformer、BigBird):大多数token使用滑动窗口注意力(局部),但少数指定token(如[CLS]、每512个token)关注所有token(全局)。这同时捕获局部模式和长距离依赖。

  • 膨胀注意力:在窗口内每隔\(k\)个token关注一次,创建稀疏模式,以相同数量的注意力分数覆盖更大范围。在各层增加膨胀率,创建类似膨胀卷积(第8章)的层次模式。

  • 现代LLM的实际赢家是滑动窗口+完整注意力交替:某些层使用滑动窗口(便宜,处理局部上下文),某些层使用完整注意力(昂贵,捕获长距离)。Mistral/Mixtral使用这种模式。

线性注意力和状态空间模型

  • 能否完全替代\(O(n^2)\)的注意力?线性注意力状态空间模型(SSM)通过避免显式注意力矩阵以\(O(n)\)时间处理序列。

  • 线性注意力用核近似替代softmax注意力:

\[\text{标准: } O = \text{softmax}(QK^T / \sqrt{d}) V$$ $$\text{线性: } O = \phi(Q) (\phi(K)^T V)\]
  • 通过首先结合\(K^T V\)积(它是\(d \times d\),与序列长度无关),计算变为\(O(n \cdot d^2)\)而非\(O(n^2 \cdot d)\)。对于\(n \gg d\)的长序列,这是巨大的节省。

  • RWKV结合了RNN和Transformer的思想。它使用一种循环公式,像RNN一样顺序处理token,但在训练时可并行化(像Transformer一样)。推理时每个token是\(O(1)\)(恒定内存,无KV缓存增长)。

  • Mamba(Gu & Dao, 2023)是一种选择性的状态空间模型。它通过学习的状态转换处理序列:

\[h_t = \bar{A} h_{t-1} + \bar{B} x_t, \quad y_t = C h_t\]
  • 其中\(\bar{A}\)\(\bar{B}\)是输入依赖的(选择性),允许Mamba动态关注或忽略输入的部分。与固定SSM不同,选择性使Mamba在语言任务上与Transformer有竞争力,同时保持\(O(n)\)扩展。

  • 权衡:线性注意力和SSM对长序列更快,但在需要精确长距离检索的任务上通常不如完整注意力。混合架构(一些Transformer层+一些Mamba层)通常取两者之长。

多查询和分组查询注意力

  • 标准多头注意力(MHA,第7章)为每个头使用独立的\(K\)\(V\)投影。对于\(h\)个头,这意味着KV缓存中有\(h\)个独立的键和值张量。多查询注意力(MQA)分组查询注意力(GQA)减少这个数量。

  • MQA(Shazeer, 2019):所有头共享一组\(K\)\(V\)投影。每个头仍有自己的\(Q\)投影。KV缓存缩小\(h\)倍(例如,32个头缩小32倍)。

  • GQA(Ainslie et al., 2023):中间方案。头被分组,每组共享一组\(K\)\(V\)投影。使用\(h = 32\)个头和\(g = 8\)个组,每组4个头共享K/V。KV缓存缩小\(h/g = 4\)倍。

\[\text{MHA: } h \text{个头, } h \text{组K/V} \quad \to \quad \text{GQA: } h \text{个头, } g \text{组K/V} \quad \to \quad \text{MQA: } h \text{个头, } 1 \text{组K/V}\]

MHA vs GQA vs MQA:MHA为每个头提供独立KV,GQA在组间共享KV,MQA为所有头使用单一KV——大幅减少KV缓存大小

  • 大多数现代LLM使用GQA(Llama 2/3、Gemma、Mistral)。它减少KV缓存内存和推理延迟,与MHA相比质量损失可忽略不计。

多头潜在注意力(MLA)

  • MLA(DeepSeek-V2, 2024)比GQA更进一步,将KV缓存压缩到低秩潜在空间。MLA不缓存完整的键和值向量,而是缓存每个token的压缩潜在向量\(\mathbf{c}_t\),在注意力期间即时重建K/V:
\[\mathbf{c}_t = W_{\text{compress}} \cdot [\mathbf{k}_t; \mathbf{v}_t], \quad \mathbf{k}_t = W_K^{\text{up}} \cdot \mathbf{c}_t, \quad \mathbf{v}_t = W_V^{\text{up}} \cdot \mathbf{c}_t\]
  • 压缩向量\(\mathbf{c}_t\)比原始K和V的总和小得多。DeepSeek-V2实现了与MHA相比93.3%的KV缓存缩减,超过甚至MQA,同时保持MHA级别的质量。

  • 权衡:从潜在向量重建K/V增加了每次注意力操作的小计算成本。但由于LLM解码是内存带宽受限(而非计算受限),这是净收益:加载更少的内存 > 每个token稍多的计算。

Flash Attention

  • Flash Attention(Dao et al., 2022,在第16章文件05中详细介绍)不是架构更改,而是属于任何关于高效注意力的讨论中都应该包含的实现优化。它计算精确的标准注意力,具有:

    • O(n)内存而非O(n²)(注意力矩阵从未在HBM中实现化)。
    • 比标准注意力快2-4倍(通过tiling和online softmax将数据保留在SRAM中)。
    • 无质量损失——输出在数学上与标准注意力完全相同。
  • Flash Attention现在是PyTorch(torch.nn.functional.scaled_dot_product_attention)、JAX和所有主流推理框架中默认的注意力实现。如果你在2024+年运行注意力,几乎肯定在使用Flash Attention。

环形注意力

  • 环形注意力(Liu et al., 2023)将注意力计算分布到多个设备上,用于即使使用Flash Attention也无法放入单GPU内存的超长序列。

  • 思想:将序列分区到\(N\)个设备上。每个设备保存\(n/N\)个token的Q、K、V。设备排列成环形。每步:

    1. 每个设备计算局部注意力(其Q对其局部K/V)。
    2. 每个设备将其K/V块发送到环中的下一个设备。
    3. 每个设备从前一个设备接收K/V并计算针对这些的注意力。
    4. 经过\(N\)步后,每个设备都已关注每个K/V块。
  • 通信与计算重叠:在对当前K/V块计算注意力的同时,下一个块正在传输中。这几乎完全隐藏了通信延迟。

  • 环形注意力通过将KV缓存分布到GPU环上,实现百万token上下文窗口。每设备内存为O(n/N),使任意长序列成为可能(仅受设备数量限制)。

推理时的混合专家

  • MoE模型(第7章)每个token只激活一小部分参数(通常是8个专家中的2个)。在推理时,独特的挑战是专家缓存:所有专家必须在内存中(因为任何token可能路由到任何专家),但每个token只有2个处于活跃状态。

  • 对于Mixtral 8x7B模型:总参数 = 47B(8×7B专家,但有共享组件)。每个token的活跃参数 ≈ 13B(2个专家+共享层)。模型具有LLM-70B级别的质量,同时具有LLM-13B级别的推理成本,但需要47B参数在内存中。

  • 专家卸载:对于GPU内存受限的部署,将非活跃专家保留在CPU或SSD上,按需加载。这之所以可行,是因为token路由足够可预测,可以预取可能的专家。

  • 专家缓存:在GPU内存中维护最近使用的专家的LRU缓存。如果相同的专家被重复激活(对域内数据常见),缓存命中率就高。

知识蒸馏

  • 蒸馏(第6章)训练一个小型"学生"模型模仿大型"教师"。"学生"从教师的软预测(类别上的概率分布)中学习,这包含比硬标签更多的信息。
\[\mathcal{L} = \alpha \cdot \text{KL}(p_{\text{teacher}}^{T} \| p_{\text{student}}^{T}) + (1 - \alpha) \cdot \mathcal{L}_{\text{CE}}(y, p_{\text{student}})\]
  • 其中\(T\)是温度(更高的\(T\)软化分布,揭示教师的不确定性),\(\alpha\)平衡蒸馏损失与标准交叉熵损失。

  • 对于LLM:蒸馏用于从大型、能力强的模型创建小型、快速模型。GPT-4 → 一个捕捉GPT-4大部分行为的7B学生模型,用于特定任务。学生模型的服务成本可便宜10-100倍。

  • 任务特定蒸馏:仅在与你部署任务相关的数据上蒸馏。在医疗问答上从70B教师蒸馏出的7B模型在该特定任务上可以超越70B模型(因为学生的有限容量完全集中在目标领域)。

剪枝

  • 剪枝移除不必要的权重(设为零),减少模型大小和计算量。

  • 非结构化剪枝(基于幅度):移除具有最小绝对值的单独权重。创建稀疏权重矩阵。对压缩简单有效,但当前硬件(GPU)无法高效加速稀疏操作,除非稀疏性遵循特定模式。

  • 结构化剪枝:移除整个单元——注意力头、MLP神经元或层。生成更小的稠密模型,在标准硬件上轻松加速。权衡是更粗的粒度(移除完整头可能同时移除有用和无关的权重)。

  • 2:4稀疏性(NVIDIA Ampere+):一种硬件支持的稀疏模式,每4个权重中2个为零。GPU的稀疏Tensor Core跳过零乘法,实现约2倍加速。这是目前唯一具有实用硬件加速的稀疏模式。

  • 彩票假设(Frankle & Carlin, 2019):在随机初始化的网络中存在一个子网络("中奖彩票"),可以被隔离训练并达到完整网络的性能。找到这些子网络(通过训练、剪枝和回滚)代价高昂,但这一洞察推动了剪枝研究。

神经架构搜索(NAS)

  • NAS通过在可能的架构空间上搜索来自动化架构设计,找到在硬件约束(延迟、内存、功耗)下最大化准确度的架构。

  • EfficientNet(第8章)就是通过NAS找到的:复合缩放规则(平衡深度、宽度、分辨率)来自搜索,而非人类直觉。

  • 对于推理效率,NAS可以找到为特定硬件目标优化的架构:"在iPhone Neural Engine上找到延迟<5ms且在ImageNet上>80%准确度的模型"。搜索空间包括层类型、宽度、激活函数和注意力模式。

  • 一劳永逸网络训练一个过参数化的网络,并从中提取子网络用于不同部署目标。一次训练产出适用于云GPU、移动GPU和CPU的模型,每个都为其目标优化。

编程任务(使用CoLab或notebook)

  1. 实现滑动窗口注意力并与完整注意力比较内存使用。

    import jax
    import jax.numpy as jnp
    
    def full_attention(Q, K, V):
        """标准O(n^2)注意力。"""
        scores = Q @ K.T / jnp.sqrt(Q.shape[-1])
        weights = jax.nn.softmax(scores, axis=-1)
        return weights @ V
    
    def sliding_window_attention(Q, K, V, window_size=128):
        """滑动窗口注意力:每个token关注前window_size个token。"""
        n = Q.shape[0]
        d = Q.shape[-1]
        output = jnp.zeros_like(Q)
    
        for i in range(n):
            start = max(0, i - window_size + 1)
            k_window = K[start:i+1]
            v_window = V[start:i+1]
            scores = Q[i] @ k_window.T / jnp.sqrt(d)
            weights = jax.nn.softmax(scores)
            output = output.at[i].set(weights @ v_window)
    
        return output
    
    n, d = 512, 64
    key = jax.random.PRNGKey(0)
    Q = jax.random.normal(key, (n, d))
    K = jax.random.normal(jax.random.PRNGKey(1), (n, d))
    V = jax.random.normal(jax.random.PRNGKey(2), (n, d))
    
    print(f"完整注意力内存:    O(n^2) = {n*n} 个条目")
    print(f"窗口 (w=128) 内存: O(n*w) = {n*128} 个条目")
    print(f"缩减: {n*n / (n*128):.1f}x")
    

  2. 比较MHA、GQA和MQA的KV缓存大小。展示为什么GQA是实际最佳甜点。

    def kv_cache_size(n_heads, n_kv_heads, d_head, seq_len, bytes=2):
        """KV缓存大小(MB)。"""
        return 2 * n_kv_heads * d_head * seq_len * bytes / 1e6
    
    n_heads = 32
    d_head = 128
    seq_len = 32768
    
    mha = kv_cache_size(n_heads, n_heads, d_head, seq_len)       # 32个KV头
    gqa = kv_cache_size(n_heads, 8, d_head, seq_len)              # 8个KV头
    mqa = kv_cache_size(n_heads, 1, d_head, seq_len)              # 1个KV头
    
    print(f"MHA (32个KV头): {mha:.0f} MB/层")
    print(f"GQA (8个KV头):  {gqa:.0f} MB/层 ({mha/gqa:.0f}x 更小)")
    print(f"MQA (1个KV头):   {mqa:.0f} MB/层 ({mha/mqa:.0f}x 更小)")
    

  3. 模拟结构化剪枝:从随机注意力层中移除最不重要的注意力头并测量输出变化。

    import jax
    import jax.numpy as jnp
    
    key = jax.random.PRNGKey(0)
    n_heads, seq_len, d_head = 8, 64, 32
    
    # 随机多头注意力输出(每头一个)
    head_outputs = jax.random.normal(key, (n_heads, seq_len, d_head))
    
    # 完整输出:拼接所有头
    full_output = head_outputs.reshape(seq_len, n_heads * d_head)
    
    # 重要性:通过范数衡量每个头的贡献
    head_norms = jnp.linalg.norm(head_outputs, axis=(1, 2))
    print("头重要性(按范数):", jnp.round(head_norms, 2))
    
    # 剪枝最不重要的头
    for n_keep in [8, 6, 4, 2]:
        top_heads = jnp.argsort(head_norms)[-n_keep:]
        pruned = head_outputs[top_heads].reshape(seq_len, n_keep * d_head)
    
        # 填充到原始大小以便比较(将剪枝的头归零)
        full_pruned = jnp.zeros_like(head_outputs)
        full_pruned = full_pruned.at[top_heads].set(head_outputs[top_heads])
        full_pruned = full_pruned.reshape(seq_len, n_heads * d_head)
    
        error = jnp.linalg.norm(full_output - full_pruned) / jnp.linalg.norm(full_output)
        print(f"保留 {n_keep}/{n_heads} 个头: 相对误差 = {error:.4f}, "
              f"内存 = {n_keep/n_heads:.0%}")