量化¶
量化降低模型权重和激活值的精度,使模型更小、更快、运行成本更低。本文件涵盖数值格式、训练后量化、量化感知训练、仅权重量化方法(GPTQ、AWQ)、激活量化、混合精度和KV缓存量化
-
一个70B参数的模型在float16下需要140 GB内存,超过任何单张GPU的容量。量化到INT4后,它可以放入35 GB(一张A100)甚至20 GB(搭配offloading的消费级RTX 4090)。量化不是一个可有可无的优化;它是使大模型部署在经济上可行的关键。
-
根本权衡:更低精度意味着更少内存、更高吞吐量和更低功耗,但会引入量化误差,可能降低模型质量。量化的艺术在于将这个误差最小化。
为什么需要量化¶
-
内存缩减:INT8比FP16小2倍,INT4小4倍。对于LLM,模型权重主导内存占用。将精度减半意味着内存需求减半。
-
吞吐量提升:更低精度意味着每秒更多操作。NVIDIA Tensor Core(第16章)在FP16下达到FP32的2倍吞吐量,INT8比FP16再翻2倍,INT4又比INT8翻2倍。一张H100在FP8下能达到989 TFLOPS,而FP32只有67 TFLOPS——相差15倍。
-
带宽节省:LLM推理通常是内存带宽受限的(第16章,roofline模型)。瓶颈在于从GPU内存加载权重,而非用权重进行计算。更小的权重意味着更少的字节需要传输,直接提高每秒token数。这就是为什么量化常常为LLM推理带来接近线性的加速。
-
能耗节省:更低精度每次操作消耗更少的能量。在数据中心规模(数千张GPU)下,这转化为显著的电力成本缩减。
数值格式¶
- 我们在第13章(计算机体系结构)中介绍了IEEE 754浮点数。以下是ML的完整精度全景:
| 格式 | 位数 | 指数位 | 尾数位 | 范围 | 使用场景 |
|---|---|---|---|---|---|
| FP32 | 32 | 8 | 23 | ±3.4×10³⁸ | 训练(黄金标准) |
| TF32 | 19 | 8 | 10 | ±3.4×10³⁸ | Tensor Core训练(A100+) |
| FP16 | 16 | 5 | 10 | ±65504 | 混合精度训练 |
| BF16 | 16 | 8 | 7 | ±3.4×10³⁸ | 训练(与FP32相同范围) |
| FP8 E4M3 | 8 | 4 | 3 | ±448 | 前向传播(Hopper+) |
| FP8 E5M2 | 8 | 5 | 2 | ±57344 | 梯度(更宽范围) |
| INT8 | 8 | — | — | -128到127 | PTQ推理 |
| INT4 | 4 | — | — | -8到7 | 仅权重量化 |
| INT2/三元 | 2 | — | — | {-1, 0, 1} | 极致压缩 |
-
FP8有两种变体:E4M3(4位指数、3位尾数,范围更窄但精度更高)用于前向传播,E5M2(5位指数、2位尾数,范围更宽但精度更低)用于梯度。Transformer Engine(第16章)按张量自动在两者之间切换。
-
BF16 vs FP16:BF16与FP32具有相同的指数范围(无溢出风险),但尾数精度较低。FP16精度更高但范围较窄(最大65504),训练时需要loss scaling。对于推理,两者都很好;对于训练,BF16更安全。
-
整数格式没有指数——它们表示定点值。在浮点和整数之间转换需要一个缩放因子和可选的零点:\(x_{\text{float}} = \text{scale} \times (x_{\text{int}} - \text{zero\_point})\)。
量化方程¶
- 所有量化方法都将浮点值映射到整数并反向映射:
-
缩放因子(scale)决定分辨率:\(\text{scale} = \frac{x_{\max} - x_{\min}}{q_{\max} - q_{\min}}\)。对于INT8:\(q_{\min} = -128\),\(q_{\max} = 127\)。
-
对称量化设置\(\text{zero\_point} = 0\),因此\(\text{scale} = \frac{\max(|x|)}{127}\)。更简单、更快(推理时无需零点减法)。
-
非对称量化使用非零\(\text{zero\_point}\)来处理非对称分布(例如ReLU输出全为非负值)。将\([x_{\min}, x_{\max}]\)映射到\([0, 255]\)用于无符号INT8。
- 量化粒度:多少个值共享同一个缩放因子:
- 逐张量:整个张量一个缩放因子。最简单但准确度最低(一个异常值会扭曲整个张量的缩放)。
- 逐通道:每个输出通道(卷积)或每行(线性层)一个缩放因子。效果好得多,开销极小。
- 逐组:每组\(g\)个元素一个缩放因子(例如\(g = 128\))。最佳准确度,用于现代仅权重量化(GPTQ、AWQ)。
- 逐token:激活值的每个token一个缩放因子。处理不同token具有非常不同的激活幅度的情况。
训练后量化(PTQ)¶
- PTQ在没有任何重新训练的情况下量化一个预训练模型。你通过一个校准集(一个小的代表性数据集,通常128-512个样本)运行模型以收集激活统计信息,然后计算最优缩放因子。
校准方法¶
-
最小-最大法:基于观察到的最小值和最大值设置缩放。简单但对异常值敏感(一个极端值将大部分量化范围浪费在极少使用的值上)。
-
百分位法:使用99.99百分位而非绝对最大值。削减极端异常值,为大多数值提供更好的分辨率。被削减的值饱和到\(q_{\min}\)或\(q_{\max}\)。
-
MSE最优:找到最小化原始张量和量化张量之间均方误差的缩放因子。这是一个1D优化(在可能的剪切值上搜索),通常给出最佳的PTQ准确度。
-
基于熵(KL散度):找到最小化原始值分布和量化值分布之间KL散度的缩放因子。用于TensorRT的INT8校准。
PTQ实践¶
# 简化的PTQ(PyTorch概念示例)
import torch
def quantise_tensor_symmetric(tensor, bits=8):
qmax = 2 ** (bits - 1) - 1 # INT8时为127
scale = tensor.abs().max() / qmax
quantised = torch.clamp(torch.round(tensor / scale), -qmax, qmax).to(torch.int8)
return quantised, scale
def dequantise(quantised, scale):
return quantised.float() * scale
# 量化一个权重矩阵
weight = torch.randn(512, 512) # 预训练权重
weight_q, scale = quantise_tensor_symmetric(weight, bits=8)
weight_reconstructed = dequantise(weight_q, scale)
# 量化误差
error = (weight - weight_reconstructed).abs().mean()
print(f"平均绝对误差: {error:.6f}")
print(f"压缩比: {weight.numel() * 4 / (weight_q.numel() * 1 + 4):.1f}x") # +4字节用于缩放因子
- PTQ在大多数模型上以<1%的准确度下降在INT8下工作良好。对于INT4,PTQ质量下降显著——仅权重方法(见下文)能更好地处理INT4。
量化感知训练(QAT)¶
- QAT在训练图中插入伪量化操作:权重和激活在前向传播中被量化和反量化,但梯度像没有量化一样流过(直通估计器)。
-
模型在训练时学习对量化噪声具有鲁棒性。QAT通常能恢复PTQ丢失的大部分或全部准确度,特别是在低位宽(INT4、INT2)下。
-
代价:QAT需要重新训练(或微调)模型,这对于大模型来说很昂贵。对于一个70B参数的模型,QAT可能花费\(10,000-\)100,000的计算成本。PTQ基本零成本(仅校准)。
-
何时使用QAT:当PTQ质量不可接受时(通常是INT4或更低)、当部署到有严格延迟预算的边缘设备时,或者当模型将被量化数百万次时(一次性QAT成本被分摊)。
仅权重量化¶
- 对于LLM推理,瓶颈在于从内存加载权重,而非用权重进行计算(内存带宽受限阶段)。仅权重量化将权重量化到INT4或INT3,同时保持激活在FP16。计算在FP16中进行(即时反量化权重后),但内存消耗和带宽减少了4-8倍。
GPTQ¶
- GPTQ(Frantar et al., 2022)每次量化一列权重,通过调整后续列来补偿每列的误差。它使用Hessian矩阵(来自校准集的二阶信息)来确定最优量化顺序和误差补偿:
-
关键洞察:量化第\(j\)列会引入误差。GPTQ立即通过调整所有剩余列来补偿,使层的总输出(\(XW\))变化尽可能小。这是最优大脑量化(OBQ)应用于Transformer的方法。
-
使用4位组量化(组大小128)的GPTQ在大多数LLM上实现<1%困惑度下降。量化一个70B模型在单张GPU上约需1小时。
AWQ¶
-
AWQ(激活感知权重量化,Lin et al., 2023)观察到一小部分权重通道(1-3%)比其他通道重要得多——它们对应具有大幅度的激活通道。保护这些重要通道显著减少量化误差。
-
AWQ在量化前将重要通道乘以因子\(s\)(使其更大,受舍入影响更小),并将对应激活乘以\(1/s\)(以保持输出不变)。因子\(s\)按组优化以最小化整体量化误差。
-
AWQ比GPTQ更简单(无需Hessian计算),运行更快,并达到可比较的质量。它已成为许多开源LLM量化流水线的默认选择。
GGUF / llama.cpp量化¶
-
GGUF(GGML通用格式)是llama.cpp用于CPU推理的格式。它支持多种量化方案:
- Q4_0:4位、32元素块、对称。
- Q4_K_M:4位与混合精度重要通道(k-quants)。
- Q5_K_M:5位与k-quants(更高质量)。
- Q8_0:8位、简单快速。
-
"K"变体(k-quants)为重要权重块分配更多位,类似于AWQ的洞察但在格式层面实现。Q4_K_M是大多数模型的最佳甜点:平均4位、质量损失最小。
QuIP和QuIP¶
-
QuIP(Chee et al., 2023)引入不相干处理:在量化前使用随机正交变换旋转权重矩阵。这将信息分散到所有权重上,防止少数异常权重主导量化误差。
-
直觉:如果一个权重是100而其他约~1,使用相同的缩放因子量化所有会将大部分INT4范围浪费在异常值上。经过正交旋转(保持矩阵的数学性质)后,所有权重具有相似幅度,均匀量化效果更好。
-
QuIP#通过格码本扩展:不再映射到均匀整数网格,而是映射到最优格(8D中的E8格)中的点。格码在相同位数内打包更多量化点,达到比均匀量化更好的率失真。QuIP#在2位精度下实现可用质量——是典型INT4方法的一半位数。
SpQR¶
-
SpQR(Dettmers et al., 2023)观察到极少量权重(0.1-1%)是异常值,对输出质量贡献不成比例。SpQR不将所有内容量化到相同精度,而是:
- 使用敏感性分析识别异常值权重(量化此权重会如何改变层输出?)。
- 以完整精度(FP16)以稀疏格式存储异常值。
- 将所有剩余权重量化到INT3或INT4。
-
结果:约99%的权重被激进量化(小),而关键的1%保持完整精度(准确)。稀疏异常值存储增加的开销极小(<5%总大小)。
HQQ¶
-
HQQ(半二次量化,Badri & Shaji, 2023)是一种零样本仅权重量化方法,完全不需要校准数据。它将量化公式化为半二次优化问题,迭代求解最优量化权重和缩放因子。
-
优势:无校准集意味着无数据依赖、即时量化、无校准数据不匹配的风险。HQQ对于无法获取代表性校准数据或校准数据敏感的模型特别有用。
AQLM¶
- AQLM(Egiazarian et al., 2024)将加法量化(多码本向量量化)应用于LLM。不是独立量化每个权重,AQLM将权重分组为向量,并将每个向量表示为来自多个学习码本条目的和:
- 其中\(\mathbf{c}_i^{(m)}\)是来自码本\(m\)的一个条目。使用\(M = 2\)个码本,每个256个条目,一个8元素向量被编码为两个8位索引 = 8个权重2字节 = 有效每权重2位。AQLM在2位精度下达到最先进质量,在此极端压缩级别上优于GPTQ和AWQ。
BitNet和1位LLM¶
-
BitNet(Wang et al., 2023)将量化做到极致:权重为三元(\(\{-1, 0, +1\}\)),每个权重仅需约1.58位。矩阵乘法变为仅加法和减法——不需要浮点乘法。
-
BitNet b1.58(Ma et al., 2024)将每个权重约束为\(\{-1, 0, +1\}\)。"1.58位"来自\(\log_2(3) \approx 1.58\)。在此精度下,70B模型装入约15 GB,推理无需乘法操作——只需加法、减法和符号翻转。
-
矩阵乘法变为:
- 这在任何硬件上都比FP16矩阵乘法便宜得多,可能在没有浮点计算单元的设设备上实现LLM推理。当前模型的质量让步显著,但随着规模扩大和训练时量化意识的增强而改善。
Microscaling(MX)格式¶
- Microscaling(MX)格式是一种新的行业标准(由AMD、Arm、Intel、Meta、Microsoft、NVIDIA、Qualcomm支持),使用块浮点:一组元素共享一个指数,每个元素有自己的尾数。
| 格式 | 共享指数 | 元素位数 | 总计(每元素) | 等效 |
|---|---|---|---|---|
| MXFP8 | 每块8位 | 8(E4M3/E5M2) | ~8 | 类似FP8但范围更好 |
| MXFP6 | 每块8位 | 6 | ~6.5 | 介于FP8和INT4之间 |
| MXFP4 | 每块8位 | 4 | ~4.5 | 类似INT4但具有浮点行为 |
| MXINT8 | 每块8位 | 8(整数) | ~8.5 | 具有共享缩放的INT8 |
- 共享指数将指数成本分摊到一个块(通常16-32个元素)。每个元素比单独拥有指数时保留更多的尾数位,提供更高的每比特精度。预计MX格式将在未来硬件中取代单独的FP8和INT8格式。
FP8训练¶
-
FP8训练(不仅是推理)现在在NVIDIA Hopper和Blackwell GPU上实用。方案:
-
前向传播:权重和激活使用E4M3(更高精度,更窄范围)。Transformer Engine使用延迟缩放(跟踪前一次迭代的统计信息,应用于当前迭代)动态计算逐张量缩放因子。
-
反向传播:梯度使用E5M2(更宽范围,更低精度)。梯度比权重/激活具有更宽的值范围,因此额外的指数位防止溢出。
-
主权重:以FP32保留用于优化器状态(类似于FP16的标准混合精度训练,第6章)。FP8计算仅用于矩阵乘法,不用于权重更新。
-
Loss Scaling:FP8仍然需要,就像FP16一样。动态loss scaler调整缩放因子以保持梯度值在FP8可表示的范围内。
-
-
FP8训练在大多数模型规模下达到与BF16训练可比较的质量,吞吐量提升约2倍。它是H100集群上新的大规模训练运行的默认选项。
激活量化¶
-
激活(层之间流动的中间张量)也可以量化,从而实现完全INT8计算(权重和激活都为INT8,累积为INT32)。
-
动态量化:在运行时从实际的激活值计算缩放因子。更准确(适应每个输入)但增加开销(在每个层计算最小/最大或百分位)。
-
静态量化:在校准期间一次性计算缩放因子并在推理时固定使用。推理更快(无需运行时统计)但如果校准数据不具有代表性则准确度较低。
-
逐token量化:为序列中的每个token计算单独的缩放因子。对LLM至关重要,因为不同token可能有非常不同的激活幅度(某些token产生的激活比其他token大100倍)。
-
激活量化比权重量化更难,因为激活是数据依赖的(每个输入都不同),而权重是固定的。"异常值"问题特别严重:少数激活通道具有极端值(均值的100倍),用与正常通道相同的缩放因子量化它们会浪费精度。
-
SmoothQuant(Xiao et al., 2022)通过数学上将量化难度从激活(因异常值难以量化)迁移到权重(易于量化)来解决异常值问题:将激活乘以\(1/s\),权重乘以\(s\),其中\(s\)平衡了难度。输出\(XW = (X \cdot \text{diag}(s^{-1})) \cdot (\text{diag}(s) \cdot W)\)保持不变。
混合精度量化¶
-
并非所有层对量化的敏感度相同。注意力层通常容忍INT4,而嵌入层和最终分类器需要更高精度。
-
敏感度分析:单独量化每一层并测量准确度影响。高敏感度的层获得更多位;不敏感的层获得更少位。
-
Transformer Engine(第16章,NVIDIA Hopper)在操作级别实现动态混合精度:每个矩阵乘法根据张量统计信息在FP8和FP16之间选择,在保持质量的同时最大化吞吐量。
KV缓存量化¶
- 在LLM生成过程中,KV缓存存储所有先前token的键和值张量。对于长序列,这主导了内存占用:
-
一个70B模型有80层、64个头、128维头,在序列长度128K下FP16:\(2 \times 80 \times 64 \times 128 \times 131072 \times 2 = 330\) GB。这超过了GPU内存。
-
KV缓存量化通过以INT8或INT4而非FP16存储缓存键值来减少此内存。量化误差在序列上累积(每个新token关注所有缓存的K/V),但通过逐通道或逐头量化,质量下降是可接受的。
-
KV缓存量化具有乘法效益:它支持更长的序列(更多上下文)、更大的批次大小(更多并发用户)和更快的推理(加载缓存所需的内存带宽更少)。这是LLM推理服务中影响最大的优化之一。
编程任务(使用CoLab或notebook)¶
-
从头实现对称INT8量化。量化一个权重矩阵,反量化它,并测量重建误差作为值分布的函数。
import jax.numpy as jnp import jax def quantise_int8(tensor): scale = jnp.max(jnp.abs(tensor)) / 127.0 quantised = jnp.clip(jnp.round(tensor / scale), -127, 127).astype(jnp.int8) return quantised, scale def dequantise(quantised, scale): return quantised.astype(jnp.float32) * scale # 正态分布权重(训练模型的典型情况) key = jax.random.PRNGKey(0) weights = jax.random.normal(key, (1024, 1024)) * 0.02 q, s = quantise_int8(weights) recon = dequantise(q, s) print(f"原始: {weights.nbytes / 1024:.0f} KB") print(f"量化: {q.nbytes / 1024:.0f} KB ({weights.nbytes / q.nbytes:.0f}x 更小)") print(f"平均绝对误差: {jnp.abs(weights - recon).mean():.6f}") print(f"最大绝对误差: {jnp.abs(weights - recon).max():.6f}") print(f"相对误差: {jnp.abs(weights - recon).mean() / jnp.abs(weights).mean():.4%}") -
演示异常值问题。创建具有少数极端通道的激活,展示逐张量量化失败而逐通道成功。
import jax.numpy as jnp import jax key = jax.random.PRNGKey(42) # 激活:大多数通道正常,2个通道有100倍异常值 activations = jax.random.normal(key, (32, 512)) * 0.1 activations = activations.at[:, 0].set(activations[:, 0] * 100) # 异常值通道 activations = activations.at[:, 1].set(activations[:, 1] * 50) # 异常值通道 # 逐张量量化(整个张量一个缩放因子) scale_tensor = jnp.max(jnp.abs(activations)) / 127.0 q_tensor = jnp.clip(jnp.round(activations / scale_tensor), -127, 127) recon_tensor = q_tensor * scale_tensor # 逐通道量化(每个通道一个缩放因子) scales_channel = jnp.max(jnp.abs(activations), axis=0) / 127.0 q_channel = jnp.clip(jnp.round(activations / scales_channel), -127, 127) recon_channel = q_channel * scales_channel err_tensor = jnp.abs(activations - recon_tensor).mean() err_channel = jnp.abs(activations - recon_channel).mean() print(f"逐张量误差: {err_tensor:.6f}") print(f"逐通道误差: {err_channel:.6f}") print(f"逐通道好 {err_tensor / err_channel:.1f}倍") print(f"\n异常值通道浪费了量化范围的 {(activations.shape[1] - 2) / activations.shape[1]:.0%}," f"仅用于 {2 / activations.shape[1]:.1%} 的通道") -
计算不同模型大小和序列长度的KV缓存内存。展示为什么KV缓存量化对于长上下文模型至关重要。
def kv_cache_gb(n_layers, n_heads, d_head, seq_len, bytes_per_elem): return 2 * n_layers * n_heads * d_head * seq_len * bytes_per_elem / 1e9 models = [ ("Llama-7B", 32, 32, 128), ("Llama-70B", 80, 64, 128), ("GPT-4 (估)", 120, 96, 128), ] print(f"{'模型':<15} {'SeqLen':>8} {'FP16 (GB)':>10} {'INT8 (GB)':>10} {'INT4 (GB)':>10}") print("-" * 60) for name, layers, heads, d_head in models: for seq_len in [4096, 32768, 131072]: fp16 = kv_cache_gb(layers, heads, d_head, seq_len, 2) int8 = kv_cache_gb(layers, heads, d_head, seq_len, 1) int4 = kv_cache_gb(layers, heads, d_head, seq_len, 0.5) print(f"{name:<15} {seq_len:>8} {fp16:>9.1f} {int8:>9.1f} {int4:>9.1f}") print()