Skip to content

图像与视频标记化

图像与视频标记化将连续视觉数据转换为离散标记序列,使 Transformer 能像处理文本一样处理它们。本文涵盖 VQ-VAE、VQ-GAN、码本学习、DALL-E 的 dVAE、视频标记化以及无查找量化。

为什么要标记化图像

  • 将语言想象成有限字母表:英语约有 26 个字母,现代语言模型将文本切分为 30,000-100,000 个子词标记。每个句子都变成离散符号序列,Transformer 可以逐个预测。而图像存在于连续高维空间中:单张 256×256 的 RGB 图像是 \(\mathbb{R}^{256 \times 256 \times 3} \approx \mathbb{R}^{196{,}608}\) 中的一个点。如果你想让语言模型用处理英语的相同机制"说"图像,就需要将这些连续像素数组转换为有限词汇表中可管理的离散标记序列。这种转换就是图像标记化

  • 想象你是一位马赛克艺术家。你没有无限种瓷砖颜色;你只有固定调色板,比如 8192 种独特瓷砖颜色。要将照片再现为马赛克,你必须:(1) 决定每块瓷砖代表照片的哪个区域,(2) 为每个区域选择最接近的瓷砖颜色,(3) 接受某些细节丢失但整体画面可识别。图像标记化正是如此:编码器将空间块压缩为潜在向量,码本将每个向量映射到其最近条目,结果是整数索引网格,每个块一个索引,离散模型可以处理。

  • 标记化的好处有三方面。首先,它大幅压缩图像:256×256 图像可能变成 16×16 标记网格,将序列长度从 65,536 像素减少到 256 个标记,这对于成本随序列长度二次方增长的注意力模型是可处理的。其次,它统一了表示:文本标记和图像标记存在于同一离散词汇表中,使单个自回归 Transformer 能够生成交错的文本和图像。第三,它施加了有用的瓶颈,迫使模型学习语义上有意义的编码,而非记忆像素噪声。

图像标记化流程概览:连续图像进入编码器,潜在向量针对码本量化,产生离散标记索引网格

  • 回想第 8 章卷积网络如何从图像提取分层特征图,以及第 7 章文本标记器如何将字符串转换为整数序列。图像标记化位于交叉点:它使用 CNN 或视觉 Transformer 编码器(第 8 章)生成空间特征,然后借用离散词汇表的概念(第 7 章)将这些特征转换为标记索引。

VQ-VAE:向量量化

  • 正如我们在第 6 章所见,标准变分自编码器(VAE)将输入编码为连续潜在分布,并从该分布中采样解码回重建。潜在空间是连续的,这使得将其输入离散序列模型变得不便。向量量化变分自编码器(VQ-VAE,van den Oord et al., 2017)通过引入可学习的嵌入向量码本并将每个编码器输出"吸附"到其最近码本条目,用离散潜在替换连续潜在。

  • 想象一个恰好有 \(K\) 个带标签书架的图书馆。当新书(编码器输出)到达时,图书管理员将其放在与其现有书籍(码本向量)最相似的书架上,并记录书架编号。稍后检索书籍时,你只需要书架编号:该书架上的码本条目是足够好的替代品。这就是向量量化。

  • 形式上,VQ-VAE 有三个组件:

  • 编码器 \(E\) 将输入图像 \(\mathbf{x} \in \mathbb{R}^{H \times W \times 3}\) 映射到连续潜在向量的空间网格 \(\mathbf{z}_e = E(\mathbf{x}) \in \mathbb{R}^{h \times w \times d}\),其中 \(h \times w\) 是下采样空间分辨率,\(d\) 是嵌入维度。

  • 码本 \(\mathcal{C} = \{\mathbf{e}_1, \mathbf{e}_2, \ldots, \mathbf{e}_K\} \subset \mathbb{R}^d\) 包含 \(K\) 个可学习嵌入向量。典型码本大小范围为 512 到 16,384 个条目。

  • 解码器 \(D\) 从量化潜在重建图像。

  • 量化步骤将每个空间位置 \((i, j)\) 的编码器输出 \(\mathbf{z}_e(\mathbf{x})\) 替换为其最近码本条目:

\[\mathbf{z}_q(i,j) = \mathbf{e}_{k^\ast} \quad \text{其中} \quad k^\ast = \arg\min_k \|\mathbf{z}_e(i,j) - \mathbf{e}_k\|_2\]
  • 这是嵌入空间中的最近邻查找,与 k-means 分配(第 6 章)完全相同的操作。索引 \(k^\ast\) 是空间位置 \((i,j)\) 的离散标记,完整图像表示为来自 \(\{1, \ldots, K\}\)\(h \times w\) 整数网格。

VQ-VAE 架构:编码器产生连续潜在,每个潜在向量匹配到最近码本条目,解码器从量化编码重建

  • 挑战在于 \(\arg\min\) 不可微分:你无法通过离散选择反向传播。VQ-VAE 用直通估计器(straight-through estimator)解决此问题:前向传播时,解码器接收 \(\mathbf{z}_q\)(量化向量);反向传播时,重建损失关于 \(\mathbf{z}_q\) 的梯度直接复制到 \(\mathbf{z}_e\),仿佛量化步骤是恒等函数。这简洁地写为:
\[\mathbf{z}_q = \mathbf{z}_e + \text{sg}(\mathbf{z}_q - \mathbf{z}_e)\]
  • 其中 \(\text{sg}(\cdot)\) 是停止梯度算子。前向传播时这求值为 \(\mathbf{z}_q\);反向传播时,梯度仅通过 \(\mathbf{z}_e\) 项流动。

  • 完整 VQ-VAE 损失有三项:

\[\mathcal{L} = \underbrace{\|\mathbf{x} - D(\mathbf{z}_q)\|_2^2}_{\text{重建}} + \underbrace{\|\text{sg}(\mathbf{z}_e) - \mathbf{e}\|_2^2}_{\text{码本(VQ)}} + \underbrace{\beta \|\mathbf{z}_e - \text{sg}(\mathbf{e})\|_2^2}_{\text{承诺}}\]
  • 重建损失训练编码器和解码器忠实再现输入。码本损失(也称 VQ 损失)将码本向量拉向编码器输出;注意 \(\text{sg}(\mathbf{z}_e)\) 意味着编码器不从此项接收梯度,因此仅更新码本。承诺损失反之:它鼓励编码器输出保持接近码本向量,防止编码器"逃离"码本。超参数 \(\beta\)(通常 0.25)控制码本和承诺项之间的平衡。

  • 实践中,码本通常用指数移动平均(EMA)而非梯度下降更新,这更稳定。令 \(\mathbf{n}_k\) 为分配给码本条目 \(k\) 的编码器输出计数,\(\mathbf{s}_k\) 为其和。EMA 更新为:

\[\mathbf{n}_k \leftarrow \gamma \mathbf{n}_k + (1 - \gamma) |\{(i,j) : k^\ast_{ij} = k\}|\]
\[\mathbf{s}_k \leftarrow \gamma \mathbf{s}_k + (1 - \gamma) \sum_{(i,j) : k^\ast_{ij} = k} \mathbf{z}_e(i,j)\]
\[\mathbf{e}_k \leftarrow \frac{\mathbf{s}_k}{\mathbf{n}_k}\]
  • 其中 \(\gamma\) 是衰减率(通常 0.99)。这等价于在编码器输出上运行在线 k-means 算法。

码本坍塌

  • VQ-VAE 的一个著名失败模式是码本坍塌(也称索引坍塌):模型学会仅使用 \(K\) 个码本条目的一小部分,使大多数条目"死亡"。想象一个图书馆中 90% 的书架是空的,因为图书管理员总是将书籍路由到相同的几个热门书架。这浪费了表示能力。

  • 码本坍塌发生是因为编码器、码本和解码器在训练期间协同适应。如果一个条目在几个批次中未被选择,它会漂移远离编码器流形,使其更不可能被选择,形成正反馈循环。

  • 几种技术可缓解码本坍塌:

    • 码本重置:定期通过复制随机采样的编码器输出重新初始化死亡条目。这给死亡条目在潜在空间活跃区域附近的新起点。
    • 带拉普拉斯平滑的 EMA 更新:向 \(\mathbf{n}_k\) 添加小常数以防止任何条目计数为零,确保所有条目接收梯度信号。
    • 承诺损失调优:增加 \(\beta\) 迫使编码器输出更紧密地聚集在码本条目周围,更均匀地分配分配。
    • 因子化编码:将码本查找分解为较小编查的乘积(如两个大小为 \(\sqrt{K}\) 的码本),通过减少每次查找的有效码本大小提高利用率。
    • 熵正则化:添加鼓励码本使用均匀分布的惩罚,最大化熵 \(H = -\sum_k p_k \log p_k\),其中 \(p_k\) 是经验分配概率。

码本利用率:分配均匀的健康码本与大多数条目未使用的坍塌码本对比

VQ-GAN:用于更高保真度的对抗训练

  • VQ-VAE 产生不错的重建,但像素级 \(\ell_2\) 损失倾向于生成模糊输出,因为它平等惩罚每个像素偏差,在合理细节上平均而非选择清晰细节。想象要求某人绘制一张最小化与所有可能脸平均差异的脸——他们会绘制模糊的平均脸,而非清晰的个体脸。

  • VQ-GAN(Esser et al., 2021)通过将 VQ-VAE 框架与生成对抗网络(第 6 章)的判别器结合来解决此问题。判别器是基于块的卷积网络,判断局部图像块是真实的(来自训练数据)还是假的(来自解码器)。这种对抗损失鼓励解码器产生感知上清晰、真实的纹理,而非像素级平均。

  • VQ-GAN 目标在 VQ-VAE 损失上添加两项:

\[\mathcal{L}_\text{VQ-GAN} = \mathcal{L}_\text{VQ-VAE} + \lambda_\text{adv} \mathcal{L}_\text{adv} + \lambda_\text{perc} \mathcal{L}_\text{perc}\]
  • 对抗损失 \(\mathcal{L}_\text{adv}\) 是应用于解码器输出的标准 GAN 目标。判别器 \(\mathcal{D}\) 尝试区分真实块与解码块,解码器(生成器)尝试欺骗它。非饱和形式为:
\[\mathcal{L}_\text{adv} = -\mathbb{E}[\log \mathcal{D}(D(\mathbf{z}_q))]\]
  • 感知损失 \(\mathcal{L}_\text{perc}\) 比较预训练网络(通常 VGG 或 LPIPS)在原始与重建图像之间的特征激活:
\[\mathcal{L}_\text{perc} = \sum_l \|\phi_l(\mathbf{x}) - \phi_l(D(\mathbf{z}_q))\|_2^2\]
  • 其中 \(\phi_l\) 表示预训练网络第 \(l\) 层的特征图。此损失捕获高级结构相似性而非像素级精度。

  • 权重 \(\lambda_\text{adv}\) 自适应设置,使对抗梯度与重建梯度平衡,防止对抗损失在重建质量差的训练早期主导。

VQ-GAN 训练:编码器和解码器通过量化步骤连接,块判别器对解码输出提供对抗反馈

  • 结果是标记器在相同码本大小下产生比 VQ-VAE 显著更清晰的重建。VQ-GAN 是许多主要图像生成系统(包括原始 DALL-E、Parti 和众多文本到图像模型)的骨干标记器。它将 256×256 图像转换为 16×16 或 32×32 离散标记网格,码本大小 1024-16384,在每个空间维度实现 16 倍到 64 倍的压缩比。

残差量化与多尺度码本

  • 单个码本对重建质量施加硬性上限:每个空间位置由恰好一个码本向量表示,任何比码本可表达的更细细节都会丢失。想象用固定调色板中的单个词描述颜色:"青绿色"接近但不精确。如果你能添加细化——"青绿色,但略偏蓝且稍亮"——你会更接近。

  • 残差量化(RQ)迭代应用此思想。第一次量化步骤产生 \(\mathbf{z}_q^{(1)}\) 后,计算残差 \(\mathbf{r}^{(1)} = \mathbf{z}_e - \mathbf{z}_q^{(1)}\),然后用第二个码本量化残差得到 \(\mathbf{z}_q^{(2)}\),依此类推进行 \(T\) 级:

\[\mathbf{r}^{(0)} = \mathbf{z}_e\]
\[\mathbf{z}_q^{(t)} = \text{Quantise}(\mathbf{r}^{(t-1)}, \mathcal{C}^{(t)})\]
\[\mathbf{r}^{(t)} = \mathbf{r}^{(t-1)} - \mathbf{z}_q^{(t)}\]
  • 最终量化表示为 \(\hat{\mathbf{z}} = \sum_{t=1}^{T} \mathbf{z}_q^{(t)}\)。使用 \(T\) 级、每级码本大小为 \(K\),有效词汇表大小为 \(K^T\),但你只需存储 \(T \times K\) 个向量而非 \(K^T\)。例如,8 级、\(K = 1024\) 给出有效 \(1024^8 \approx 10^{24}\) 条目,同时仅存储 8192 个向量。

  • 每级捕获更细细节:第一个码本捕获粗略结构,第二个捕获中频校正,依此类推。这类似于 JPEG 中的逐次逼近或网页图像中的渐进渲染,其中粗略版本首先出现,细节逐步填充。

残差量化:原始向量在连续阶段近似,每阶段量化前一阶段的残差

  • 多尺度码本通过在不同空间分辨率操作扩展此思想。不是重复量化相同空间网格,而是在多个尺度量化:粗网格捕获全局结构,细网格捕获局部细节。这与第 8 章目标检测部分的特征金字塔思想相关,其中不同尺度的特征捕获不同级别的细节。

  • 乘积量化是相关技术,其中 \(d\) 维潜在向量被拆分为 \(M\) 个维度为 \(d/M\) 的子向量,每个子向量用其自己的码本独立量化。这给出 \(K^M\) 的有效词汇表,同时仅存储 \(M \times K\) 个向量。乘积量化广泛用于近似最近邻搜索(第 13 章),并已适配用于图像标记化。

  • 有限标量量化(FSQ,Mentzer et al., 2023)采用完全不同的方法:不是学习码本,它只是将潜在向量的每个维度舍入到固定整数级别集之一(如 \(\{-2, -1, 0, 1, 2\}\))。每维度 \(L\) 个级别、\(d\) 个维度,隐式码本大小为 \(L^d\)。FSQ 完全避免码本坍塌,因为没有可学习码本向量,只有确定性舍入的学习编码器输出。直通估计器处理舍入的不可微分性。

实践中的图像标记器

  • 从 VQ-VAE 到 VQ-GAN 到残差量化的演进催生了一系列用于最先进生成模型的实用图像标记器家族。

DALL-E 标记器(dVAE)

  • 原始DALL-E(Ramesh et al., 2021)使用离散 VAE(dVAE)将 256×256 图像标记化为 32×32 标记网格,码本大小 8192。dVAE 用 Gumbel-Softmax 松弛替换硬 \(\arg\min\) 量化,使训练期间前向传播可微分。推理时使用 \(\arg\max\) 产生硬标记分配。dVAE 用重建损失、对抗均匀先验的 KL 散度以及 Gumbel-Softmax 的学习温度调度组合训练。DALL-E 然后训练 120 亿参数自回归 Transformer 来建模 256 个文本标记和 1024 个图像标记(32×32)的联合分布。

LlamaGen

  • LlamaGen(Sun et al., 2024)表明,只要有好的图像标记器,就可以重用标准 Llama 风格语言模型架构(第 7 章)进行自回归图像生成。LlamaGen 使用改进的 VQ-GAN 标记器,大码本(16,384 条目),训练纯自回归 Transformer(除标记器外无特殊图像特定修改)以光栅扫描顺序从左到右预测图像标记。关键洞见是:一旦图像被标记化为离散序列,适用于语言的相同下一标记预测范式也适用于图像,验证了标记化真正桥接模态差距的想法。

Cosmos 标记器

  • Cosmos 标记器(NVIDIA, 2024)在统一框架中设计用于图像和视频。它使用因果 3D 架构,将图像视为单帧视频,使相同标记器能处理两种模态。Cosmos 支持连续和离散标记化模式:连续模式输出实值潜在向量(用于扩散模型后端),离散模式应用有限标量量化产生整数标记(用于自回归模型后端)。编码器使用因果 3D 卷积,使每帧的标记仅依赖当前和先前帧,支持流式视频标记化。

图像标记器架构对比:带 Gumbel-Softmax 的 dVAE、带码本查找的 VQ-GAN、带标量舍入的 FSQ

视频标记化

  • 视频为图像的空间维度添加了第三轴——时间。视频是帧序列,通常 24-30 帧/秒,相邻帧高度冗余,因为视觉世界在 33 毫秒内不会剧烈变化。视频标记化利用这种时间冗余实现比独立标记化每帧高得多的压缩。

  • 想象视频压缩像翻页书。如果你从头绘制每页,你需要数千张详细绘图。但大多数页与其邻居几乎相同,所以你可以每 10 页绘制完整"关键帧",仅在中间页记录小变化。视频标记器自动学习此技巧。

3D VQ-VAE

  • VQ-VAE 扩展到视频的最直接方式是3D VQ-VAE,用同时操作空间和时间维度的 3D 卷积替换编码器和解码器中的 2D 卷积。如果编码器在空间上下采样 \(f_s\) 倍、时间上下采样 \(f_t\) 倍,\(T \times H \times W\) 的视频片段变成 \((T/f_t) \times (H/f_s) \times (W/f_s)\) 的标记网格。

  • 例如,\(f_s = 16\)\(f_t = 4\) 时,16 帧 256×256 视频片段变成 \(4 \times 16 \times 16 = 1024\) 标记序列。这足够紧凑供 Transformer 自回归建模,而原始像素数将是 \(16 \times 256 \times 256 \times 3 \approx 310\) 万值。

  • 3D 卷积联合学习空间和时间特征。早期层捕获局部运动(帧间边缘移动),更深层次捕获更高级动态(物体出现、消失或变形)。这与第 8 章卷积网络的分层特征提取原理相同,沿时间轴扩展。

视频 3D VQ-VAE:短视频片段通过 3D 卷积编码为时空潜在向量网格,量化,解码回帧

因果视频标记器

  • 标准 3D 卷积查看过去、当前和未来帧,这意味着你需要整个视频片段才能标记化其中任何部分。因果视频标记器约束时间卷积,使每个输出仅依赖当前和先前帧,永不依赖未来帧。这类似于自回归 Transformer 中的因果掩码(第 7 章):信息在时间上向前流动,永不向后。

  • 因果标记化对两种用例至关重要。首先,流式:你可以在帧到达时实时标记化视频,无需缓冲未来帧。其次,自回归生成:当 Transformer 逐帧生成视频时,帧 \(t\) 的标记必须在不知道帧 \(t+1\) 的情况下可计算,因为帧 \(t+1\) 尚未生成。

  • 因果约束通过非对称填充时间卷积实现:时间大小为 \(k\) 的核在过去侧填充 \(k-1\) 个零、未来侧填充零个零,确保时间 \(t\) 的输出仅依赖时间 \(t-k+1, \ldots, t\) 的输入。

  • 因果视频标记器的一个优雅特性是它们可以标记化单张图像("视频"一帧)而无需特殊处理。第一帧没有过去上下文,所以其标记仅从该帧计算。这种图像-视频统一意味着单个标记器服务于两种模态,简化架构并使模型能用相同解码器生成图像和视频。

时间压缩策略

  • 不同应用需要不同的时间压缩比。对于动作识别(细微运动重要),温和压缩(\(f_t = 2\))保留时间细节。对于长格式视频生成(存储数千帧不可行),需要激进压缩(\(f_t = 8\) 或更高)。

  • 一些标记器使用因子化压缩:空间和时间压缩在独立阶段执行。首先,2D 编码器独立压缩每帧,产生每帧潜在网格。然后,1D 时间编码器跨时间维度压缩。这种因子化比完整 3D 卷积计算更便宜,并允许空间和时间不同压缩比。权衡是它无法像联合 3D 编码那样高效捕获时空模式(如球对角移动)。

  • 时间插值标记是近期创新,标记器仅完全编码关键帧,并将中间帧表示为轻量级插值编码,描述如何在关键帧之间变形。这镜像经典视频压缩(H.264/HEVC 中的 I 帧和 P 帧),但在学习的潜在空间中。

时间压缩策略:帧独立空间编码后接时间编码,与联合时空 3D 编码对比

连续 vs 离散标记

  • 并非每个下游模型都需要离散标记。扩散模型(第 10 章,文件 04)原生处理连续值——它们迭代去噪高斯样本,其损失函数(去噪得分匹配)定义在连续空间上。对于扩散后端,标记器编码器产生永不量化的连续潜在向量。潜在扩散模型(Stable Diffusion、DALL-E 3、Flux)使用类 VQ-GAN 编码器-解码器,但完全跳过码本,在连续潜在空间中操作。

  • 自回归模型(GPT 风格)则用 \(K\) 类 softmax 从有限词汇表预测下一标记。它们根本上需要离散标记。每个使用自回归 Transformer 的图像生成系统(DALL-E、Parti、LlamaGen、Chameleon)都依赖离散标记器。

  • 因此,连续与离散标记的选择由生成后端驱动:

  • 使用离散标记当:模型是自回归的(带交叉熵损失的下一标记预测)、你想与文本标记共享词汇表以实现统一多模态模型、或你需要精确标记级控制(如通过标记替换进行检索或编辑)。

  • 使用连续标记当:模型是扩散模型或流匹配模型、任务需要非常高保真重建(连续潜在完全避免量化误差)、或你想使用操作实值向量的回归损失。

  • 一些近期架构支持两种模式。例如,Cosmos 标记器可以从相同编码器输出连续潜在(用于其扩散模式)或 FSQ 离散标记(用于其自回归模式),带可开关的轻量级量化头。

  • 软量化是中间方案:不用硬 \(\arg\min\) 分配,计算最接近的 top-\(k\) 码本条目的加权平均,权重由负距离的 softmax 给出。这比硬量化保留更多信息,同时仍近似离散。一些系统在训练时使用软量化,推理时使用硬量化。

基于下游生成模型选择连续与离散标记化的决策树

应用

自回归图像生成

  • 一旦图像是离散标记序列,你就可以训练标准自回归 Transformer 来建模它们。图像标记被展平为一维序列(通常按光栅扫描顺序:从左到右、从上到下),Transformer 用标准交叉熵损失学习 \(p(\text{token}_i | \text{token}_1, \ldots, \text{token}_{i-1})\)。生成时,标记逐个采样,完成的网格通过标记器的解码器产生像素。

  • 以文本为条件很简单:在图像标记序列前添加文本标记,使模型学习 \(p(\text{图像标记} | \text{文本标记})\)。这正是 DALL-E、Parti 和 LlamaGen 执行文本到图像生成的方式。文本和图像标记共享相同 Transformer、相同注意力机制,且通常共享相同嵌入表(文本和图像标记占据不同索引范围)。

  • 光栅扫描顺序引入人为不对称性:图像左上角最先生成,没有任何关于右下角的上下文。多项工作解决此问题。掩码图像建模(MaskGIT)训练双向 Transformer 同时生成所有标记但具有不同置信度,迭代取消掩码最置信标记。多尺度生成首先生成粗标记(捕获全局构图),然后用残差标记细化。这些方法在纯从左到右生成的简单性与更好全局一致性之间权衡。

统一视觉-语言标记

  • 图像标记化的最深动机是统一:将视觉和语言放入相同表示格式,使单个模型架构处理两者。正如我们在第 7 章讨论的,语言模型是非凡的序列到序列机器。通过将图像表示为标记序列,我们免费继承语言建模的所有基础设施——预训练配方、缩放定律、RLHF、上下文长度扩展。

  • Chameleon(Meta, 2024)是突出示例:它使用 8192 码本条目的 VQ-GAN 标记器将图像转换为标记,与文本标记交错在约 65,000 条目的单一词汇表中(文本 + 图像)。标准 Transformer 在混合文本-图像序列上训练,使其能根据图像生成文本、根据文本生成图像、或生成交错文本-图像内容,全部使用相同前向传播。

  • Gemini(Google, 2024)以大规模采用类似方法,在单个 Transformer 中原生理解和生成图像、音频和文本,模态特定标记器输入共享序列。

  • 统一模型中的关键工程挑战是词汇表平衡:如果 65,000 词汇表条目中有 8192 个是图像标记,模型可能为视觉分配不足容量。解决方案包括为每种模态使用独立嵌入层(仅在注意力级别共享)、模态特定损失加权、以及预训练期间仔细的数据混合比例。

统一视觉-语言模型:来自独立标记器的文本和图像标记交错成单个序列,由一个 Transformer 处理

编程任务(使用 CoLab 或 Notebook)

  1. 在 JAX 中实现最小 VQ 层:给定一批编码器输出向量,执行最近邻码本查找并计算 VQ-VAE 损失(重建 + 码本 + 承诺)。将码本利用率可视化为直方图。

    import jax
    import jax.numpy as jnp
    import matplotlib.pyplot as plt
    
    # --- 最小 VQ 层 ---
    key = jax.random.PRNGKey(42)
    d = 8          # 嵌入维度
    K = 64         # 码本大小
    n_vectors = 256  # 一批编码器输出
    
    # 随机编码器输出和码本
    k1, k2 = jax.random.split(key)
    z_e = jax.random.normal(k1, (n_vectors, d))       # 编码器输出
    codebook = jax.random.normal(k2, (K, d)) * 0.1     # 码本(小初始化)
    
    # 最近邻查找: 为每个 z_e 找到最近码本条目
    # distances[i, k] = ||z_e[i] - codebook[k]||^2
    distances = (
        jnp.sum(z_e ** 2, axis=1, keepdims=True)
        - 2 * z_e @ codebook.T
        + jnp.sum(codebook ** 2, axis=1, keepdims=True).T
    )
    indices = jnp.argmin(distances, axis=1)       # 标记索引
    z_q = codebook[indices]                        # 量化向量
    
    # VQ-VAE 损失项
    beta = 0.25
    loss_codebook = jnp.mean((jax.lax.stop_gradient(z_e) - z_q) ** 2)
    loss_commit   = jnp.mean((z_e - jax.lax.stop_gradient(z_q)) ** 2)
    loss_total    = loss_codebook + beta * loss_commit
    print(f"码本损失: {loss_codebook:.4f}, 承诺损失: {loss_commit:.4f}")
    
    # 码本利用率
    unique, counts = jnp.unique(indices, return_counts=True, size=K, fill_value=-1)
    plt.figure(figsize=(10, 4))
    plt.bar(range(K), counts, color='#3498db', alpha=0.8)
    plt.xlabel('码本索引'); plt.ylabel('分配计数')
    plt.title(f'码本利用率({jnp.sum(counts > 0)}/{K} 个条目被使用)')
    plt.grid(True, alpha=0.3); plt.tight_layout(); plt.show()
    # 尝试: 增加 K 到 512 并观察坍塌。然后添加码本重置逻辑。
    

  2. 构建玩具 2D 向量量化器,学习平铺 2D 分布。生成随机 2D 点,通过 EMA 更新学习码本,并可视化 Voronoi 区域。

    import jax
    import jax.numpy as jnp
    import matplotlib.pyplot as plt
    
    # 从高斯混合生成 2D 数据
    key = jax.random.PRNGKey(0)
    n_points = 2000
    K = 16  # 码本条目
    gamma = 0.99  # EMA 衰减
    
    # 四个簇
    keys = jax.random.split(key, 5)
    centres = jnp.array([[2, 2], [-2, 2], [-2, -2], [2, -2]], dtype=jnp.float32)
    data = jnp.concatenate([
        jax.random.normal(keys[i], (n_points // 4, 2)) * 0.5 + centres[i]
        for i in range(4)
    ])
    
    # 从随机数据点初始化码本
    idx = jax.random.choice(keys[4], n_points, (K,), replace=False)
    codebook = data[idx]
    ema_count = jnp.ones(K)
    ema_sum = codebook.copy()
    
    # 运行基于 EMA 的码本学习数个周期
    for epoch in range(30):
        # 将每个点分配到最近码本条目
        dists = jnp.sum((data[:, None, :] - codebook[None, :, :]) ** 2, axis=2)
        assignments = jnp.argmin(dists, axis=1)
        # EMA 更新
        for k in range(K):
            mask = (assignments == k)
            count_k = jnp.sum(mask)
            ema_count = ema_count.at[k].set(gamma * ema_count[k] + (1 - gamma) * count_k)
            if count_k > 0:
                sum_k = jnp.sum(data[mask], axis=0)
                ema_sum = ema_sum.at[k].set(gamma * ema_sum[k] + (1 - gamma) * sum_k)
        codebook = ema_sum / ema_count[:, None]
    
    # 可视化分配和码本
    fig, ax = plt.subplots(1, 1, figsize=(8, 8))
    colors = plt.cm.tab20(jnp.linspace(0, 1, K))
    for k in range(K):
        mask = assignments == k
        ax.scatter(data[mask, 0], data[mask, 1], c=[colors[k]], s=5, alpha=0.3)
    ax.scatter(codebook[:, 0], codebook[:, 1], c='black', s=120, marker='X',
               edgecolors='white', linewidths=1.5, zorder=10, label='码本')
    ax.set_title(f'在 2D 数据上学习的 VQ 码本({K} 个条目)')
    ax.legend(); ax.set_aspect('equal'); ax.grid(True, alpha=0.3)
    plt.tight_layout(); plt.show()
    # 尝试: 增加 K 到 64 并观察更细平铺。减少 gamma 并观察不稳定性。
    

  3. 演示残差量化:用 \(T\) 个连续量化阶段编码一批向量,并测量重建误差如何随每级降低。

    import jax
    import jax.numpy as jnp
    import matplotlib.pyplot as plt
    
    key = jax.random.PRNGKey(7)
    d = 16         # 嵌入维度
    K = 32         # 每级码本大小
    T = 8          # 残差级数
    n_vectors = 512
    
    # 要量化的随机数据
    k1, *cb_keys = jax.random.split(key, T + 1)
    z = jax.random.normal(k1, (n_vectors, d))
    
    # 每级独立随机码本
    codebooks = [jax.random.normal(cb_keys[t], (K, d)) * (0.5 ** t)
                 for t in range(T)]
    
    # 残差量化循环
    residual = z.copy()
    z_hat = jnp.zeros_like(z)
    errors = []
    
    for t in range(T):
        cb = codebooks[t]
        dists = (jnp.sum(residual ** 2, axis=1, keepdims=True)
                 - 2 * residual @ cb.T
                 + jnp.sum(cb ** 2, axis=1, keepdims=True).T)
        indices = jnp.argmin(dists, axis=1)
        z_q_t = cb[indices]
        z_hat = z_hat + z_q_t
        residual = residual - z_q_t
        mse = jnp.mean(jnp.sum((z - z_hat) ** 2, axis=1))
        errors.append(float(mse))
        print(f"级别 {t+1}: MSE = {mse:.4f}")
    
    plt.figure(figsize=(8, 5))
    plt.plot(range(1, T + 1), errors, 'o-', color='#e74c3c', linewidth=2, markersize=8)
    plt.xlabel('残差量化级别')
    plt.ylabel('重建 MSE')
    plt.title('残差量化的误差降低')
    plt.xticks(range(1, T + 1)); plt.grid(True, alpha=0.3)
    plt.tight_layout(); plt.show()
    # 尝试: 使用大小为 K*T 的单个码本并与 RQ 比较。哪个胜出?
    

  4. 模拟简单 1D"视频标记器":生成 1D 信号序列(模拟视频帧),应用因果时间压缩,并在重建质量方面与非因果压缩比较。

    import jax
    import jax.numpy as jnp
    import matplotlib.pyplot as plt
    
    key = jax.random.PRNGKey(99)
    n_frames = 16
    frame_len = 64
    
    # 生成"视频": 高斯包络在帧间缓慢移动
    x_axis = jnp.linspace(-3, 3, frame_len)
    frames = jnp.stack([
        jnp.exp(-0.5 * (x_axis - (-2 + 4 * t / n_frames)) ** 2)
        for t in range(n_frames)
    ])  # 形状: (n_frames, frame_len)
    
    # 因果时间压缩: 每帧的编码仅依赖过去帧
    # 简单方法: 用过去帧的指数衰减平均当前帧
    alpha_causal = 0.6
    causal_codes = jnp.zeros_like(frames)
    causal_codes = causal_codes.at[0].set(frames[0])
    for t in range(1, n_frames):
        causal_codes = causal_codes.at[t].set(
            alpha_causal * frames[t] + (1 - alpha_causal) * causal_codes[t - 1]
        )
    
    # 非因果: 平均过去和未来(双边平滑)
    kernel = jnp.array([0.2, 0.6, 0.2])  # 过去,当前,未来
    padded = jnp.concatenate([frames[:1], frames, frames[-1:]], axis=0)
    noncausal_codes = jnp.stack([
        kernel[0] * padded[t] + kernel[1] * padded[t+1] + kernel[2] * padded[t+2]
        for t in range(n_frames)
    ])
    
    # 重建误差
    mse_causal = jnp.mean((frames - causal_codes) ** 2)
    mse_noncausal = jnp.mean((frames - noncausal_codes) ** 2)
    print(f"因果 MSE: {mse_causal:.6f}, 非因果 MSE: {mse_noncausal:.6f}")
    
    fig, axes = plt.subplots(1, 3, figsize=(15, 5))
    for ax, data, title in zip(axes,
        [frames, causal_codes, noncausal_codes],
        ['原始帧', f'因果(MSE={mse_causal:.5f})',
         f'非因果(MSE={mse_noncausal:.5f})']):
        ax.imshow(data, aspect='auto', cmap='viridis', origin='lower')
        ax.set_xlabel('空间位置'); ax.set_ylabel('帧索引')
        ax.set_title(title)
    plt.tight_layout(); plt.show()
    # 尝试: 改变 alpha_causal 和核权重。alpha=1.0 时会发生什么?