Skip to content

视觉Transformer与生成模型

视觉Transformer将自注意力应用于图像块,以数据驱动的空间学习挑战CNN的主导地位。本文涵盖ViT、DeiT、Swin Transformer、基于GAN(StyleGAN)、VAE和扩散模型(DDPM、Stable Diffusion)的图像生成,以及超分辨率和神经风格迁移。

  • CNN(第02章)内建了强大的空间归纳偏置:局部连接、权重共享和平移等变性。视觉Transformer(ViT)提出了一个发人深省的问题:如果我们完全抛弃这些偏置,仅使用第06章中的注意力机制,让模型从数据中学习空间结构,结果会怎样?

  • 视觉Transformer(ViT)(Dosovitskiy 等,2021)直接将标准Transformer编码器应用于图像。其核心思想是将图像视为一个图像块的序列,就像NLP将文本视为一个词元序列一样。

  • 其工作流程如下:

    1. 将图像(高度 \(H\)、宽度 \(W\)、通道数 \(C\))划分为大小为 \(P \times P\) 的不重叠图像块网格。这样得到 \(N = HW / P^2\) 个图像块。
    2. 将每个图像块展平为长度为 \(P^2 \cdot C\) 的向量,并通过一个可学习的线性嵌入(单次矩阵乘法,第02章)投影到模型维度 \(D\)
    3. 在序列前添加一个可学习的 [CLS] 词元嵌入(类似于BERT中的[CLS],第07章)。该词元会关注所有图像块,其最终表示用于分类。
    4. 添加位置嵌入(每个位置一个可学习向量)以提供空间信息,因为注意力机制是置换等变的。
    5. \((N+1)\) 个词元嵌入的序列通过一个标准Transformer编码器(多头自注意力 + FFN,第06章)。
    6. 将[CLS]词元的最终表示送入分类头(一个小型MLP)。

ViT流程图:图像被分割为16x16的图像块,每个块被展平并线性投影,加入[CLS]词元,添加位置嵌入,然后通过Transformer编码器模块

  • 图像块嵌入等价于一个卷积核大小为 \(P\)、步长为 \(P\)(无重叠)的卷积。ViT将二维图像直接转换为一维序列,然后使用与语言处理相同的架构进行处理。

  • ViT的归纳偏置比CNN少:它不强制局部连接或平移等变性。这意味着它需要更多的训练数据来从头学习空间结构。在小数据集上,CNN优于ViT。但当在非常大的数据集(JFT-300M,3亿张图像)上训练时,ViT达到或超越了最佳的CNN,这表明CNN的归纳偏置有助于数据效率,但对于最终性能并非必需。

  • ViT的自注意力复杂度是图像块数量的 \(O(N^2)\)。对于224x224的图像使用16x16的图像块,\(N=196\),这是可管理的。但对于更高分辨率的图像或更小的图像块,平方成本变得难以承受。

  • DeiT(Data-efficient Image Transformer,Touvron 等,2021)表明,仅使用ImageNet数据集(无需庞大的JFT数据集),通过强数据增强、正则化(随机深度、标签平滑、丢弃法)和知识蒸馏(预训练的CNN教师提供软标签供ViT学生学习匹配),ViT也能被有效训练。DeiT在[CLS]词元旁边添加了一个蒸馏词元,其训练目标是预测教师网络的输出。

  • Swin Transformer(Liu 等,2021)解决了ViT的两个主要限制:其计算成本随图像大小呈平方增长,以及缺乏层次化特征图(检测和分割需要这样的特征图)。

  • Swin引入了移位窗口:不在所有图像块上进行全局自注意力,而是在局部窗口内计算注意力(例如,7x7个图像块)。这使得计算成本与图像大小成线性关系:\(O(N)\) 而不是 \(O(N^2)\)。但仅靠局部窗口会阻碍区域之间的信息流动。

  • 窗口移位解决了这个问题:在交替的层中,窗口划分偏移半个窗口大小。这创建了跨窗口的连接,使得信息可以在所有图像块之间流动,而不需要全局注意力的成本。

Swin Transformer:第l层在规则窗口内计算注意力,第l+1层将窗口划分偏移一半,创建跨窗口连接

  • Swin还通过跨阶段合并图像块来构建层次化表示。在每个阶段之后,相邻的2x2图像块被拼接并投影到双倍的通道数和一半的空间分辨率。这产生了类似于CNN和FPN(第03章)中的多尺度特征图,使得Swin可以直接兼容Faster R-CNN等检测头和U-Net等分割头。

  • PVT(Pyramid Vision Transformer)采用了类似的层次化方法,带有空间约简注意力:在每个阶段,在计算注意力之前对键和值进行空间下采样,降低了平方成本,同时保持了全局感受野。

  • 自监督视觉学习从未标注的图像中训练表示。标注成本高昂,但图像资源丰富。其目标是学习能够在没有任何人工标注的情况下,很好地迁移到下游任务的特征。

  • 对比学习训练模型认识到同一张图像的两个增强视图(“正样本对”)应该具有相似的表示,而不同图像的视图(“负样本对”)应该具有不相似的表示。

  • SimCLR(Chen 等,2020)为一个批次中的每张图像创建两个增强视图,使用共享骨干网络+投影头对两者进行编码,并应用NT-Xent损失(归一化温度缩放交叉熵):

\[\ell_{i,j} = -\log \frac{\exp(\text{sim}(z_i, z_j) / \tau)}{\sum_{k \neq i} \exp(\text{sim}(z_i, z_k) / \tau)}\]
  • 其中 \(\text{sim}\) 是余弦相似度(第01章),\(\tau\) 是温度参数。分子将正样本对拉近;分母将负样本对推远。SimCLR需要大批量(4096+)来提供足够的负样本。

  • MoCo(Momentum Contrast,He 等,2020)通过维护一个动量更新的负嵌入队列来避免大批量的需求。查询编码器通过梯度下降更新;键编码器作为查询编码器的指数移动平均(EMA,第04章)更新:\(\theta_k \leftarrow m \theta_k + (1 - m) \theta_q\),其中 \(m = 0.999\)。该队列存储最近的键嵌入,在不需大批量的情况下提供大量且一致的负样本集。

  • BYOL(Bootstrap Your Own Latent,Grill 等,2020)完全去掉了负样本对。它使用两个网络:“在线”网络和“目标”网络(在线网络的EMA)。在线网络预测目标网络对不同增强视图的表示。在没有负样本的情况下,BYOL通过预测头和非对称的EMA目标避免了模型坍塌问题(模型对所有输入输出相同向量)。

  • DINO(Self-Distillation with No Labels,Caron 等,2021)将自蒸馏应用于ViT。一个学生网络预测教师网络(学生的EMA)在不同增强视图上的输出。教师使用更大的裁剪,学生使用更小的裁剪。DINO产生的特征包含关于场景布局的显式信息:经过DINO训练的ViT的自注意力图无需任何分割监督就能自然地分割物体。

  • 掩码图像建模是BERT掩码语言建模(第07章)在视觉领域的对应方法。输入图像块的很大一部分被掩蔽,模型学习重建它们。

  • MAE(Masked Autoencoders,He 等,2022)掩蔽75%的图像块,并训练一个ViT编码器-解码器来重建缺失的像素值。编码器只处理未被掩蔽的图像块(预训练时节省了4倍计算量),轻量级解码器从编码后的可见图像块加上可学习的掩码词元重建完整图像。

  • BEiT(BERT Pre-training of Image Transformers,Bao 等,2022)掩蔽图像块并预测离散的视觉词元(从预训练的dVAE分词器获得),而不是原始像素。这与BERT预测离散词元的方式相呼应,避免了像素重建的低级细节。

  • 图像生成旨在产生训练集中不存在的新颖、逼真的图像。其核心挑战是对自然图像的高维概率分布进行建模。

  • 生成对抗网络(GAN)(Goodfellow 等,2014)使用两个相互竞争的网络:生成器 \(G\) 从随机噪声生成虚假图像,判别器 \(D\) 试图区分真实图像和虚假图像。它们通过对抗方式训练:\(G\) 试图欺骗 \(D\)\(D\) 试图捕获 \(G\)

\[\min_G \max_D \; \mathbb{E}_{x \sim p_{\text{data}}}[\log D(x)] + \mathbb{E}_{z \sim p(z)}[\log(1 - D(G(z)))]\]
  • 生成器接收一个随机潜在向量 \(z\)(从简单分布如高斯分布中采样),并通过一系列转置卷积将其映射生成图像。判别器是一个标准的CNN分类器。在平衡状态下,\(G\) 生成的图像与真实数据无法区分,\(D\) 对所有输入输出0.5。

  • 模式坍塌是GAN的主要失败模式:生成器只学会产生少数几种能欺骗判别器的图像,忽略了训练数据的多样性。生成器找到一小部分“安全”的输出,而不是覆盖整个分布。

  • 稳定GAN的训练技巧包括:谱归一化(约束判别器的Lipschitz常数)、渐进式增长(先在低分辨率训练,然后逐渐提高)、特征匹配(匹配中间判别器特征的统计量而非最终输出),以及使用Wasserstein距离代替原始的JS散度目标。

  • StyleGAN(Karras 等,2019)是最具影响力的高质量图像合成GAN架构。其关键创新是基于风格的生成器:不是将潜在向量 \(z\) 直接输入生成器,而是首先通过一个映射网络(一个8层MLP)将其映射为风格向量 \(w\)。然后通过自适应实例归一化(AdaIN) 将这个风格向量注入生成器的每一层,从而调制特征图的统计量:

\[\text{AdaIN}(x, y) = y_{s} \cdot \frac{x - \mu(x)}{\sigma(x)} + y_{b}\]
  • 其中 \(y_s\)\(y_b\) 是从 \(w\) 导出的缩放和偏置。不同层控制不同方面的特征:早期层控制粗糙特征(姿态、脸型),中间层控制中等级别特征(发型、眼睛),后期层控制精细细节(雀斑、发质)。StyleGAN能够生成1024x1024分辨率的逼真人脸。

  • 变分自编码器(VAE)(第06章)提供了另一种生成方法。与GAN不同,VAE具有原理性的概率框架和明确的训练目标(ELBO)。它们生成的图像通常比GAN更模糊,但提供了更平滑、结构更清晰的潜在空间。VAE是潜在扩散模型中用于将图像压缩到潜在空间以及从潜在空间重建图像的编码器-解码器对。

  • 扩散模型已成为图像生成的主导范式,在质量和多样性上都超过了GAN。其思想概念上简单:逐步向数据中添加噪声,直到变成纯高斯噪声(正向过程),然后学习逐步逆转这个过程(逆向过程)。

  • 正向过程\(T\) 个时间步中添加高斯噪声:

\[q(x_t | x_{t-1}) = \mathcal{N}(x_t; \sqrt{1 - \beta_t} \, x_{t-1}, \beta_t I)\]
  • 其中 \(\beta_t\) 是随时间增加的噪声调度。经过足够多的步数后,无论原始图像 \(x_0\) 是什么,\(x_T\) 都近似为纯高斯噪声。利用重参数化技巧(第06章),并令 \(\alpha_t = 1 - \beta_t\)\(\bar{\alpha}_t = \prod_{s=1}^{t} \alpha_s\),我们可以直接从 \(x_0\) 采样 \(x_t\)
\[x_t = \sqrt{\bar{\alpha}_t} \, x_0 + \sqrt{1 - \bar{\alpha}_t} \, \epsilon, \quad \epsilon \sim \mathcal{N}(0, I)\]
  • 逆向过程学习去噪:从纯噪声 \(x_T\) 开始,模型预测每一步添加的噪声 \(\epsilon\) 并将其减去以恢复 \(x_{t-1}\)。这个过程由一个神经网络 \(\epsilon_\theta\)(通常是一个U-Net,来自第03章)参数化,并使用简单的MSE损失进行训练:
\[\mathcal{L} = \mathbb{E}_{t, x_0, \epsilon}\left[\|\epsilon - \epsilon_\theta(x_t, t)\|^2\right]\]

扩散正向和逆向过程:干净图像在T步内逐渐被噪声破坏(正向),神经网络学习逆转每一步(逆向),从纯噪声开始生成干净图像

  • DDPM(Denoising Diffusion Probabilistic Models,Ho 等,2020)确立了这一框架。采样需要迭代所有 \(T\) 步(通常为1000步),速度较慢。DDIM(Denoising Diffusion Implicit Models,Song 等,2021)将采样过程重新表述为确定性映射,允许大步长跳跃(例如,50步代替1000步),且质量损失极小。

  • 基于分数的模型(Song 和 Ermon,2019)提供了另一种视角。模型不是预测噪声 \(\epsilon\),而是估计得分函数 \(\nabla_{x_t} \log p(x_t)\),即对数概率关于带噪图像的梯度。该梯度指向数据分布中概率更高(更干净)的区域。采样通过朗之万动力学沿着该梯度方向进行。基于分数的模型和DDPM在随机微分方程(SDE) 框架下统一起来:正向过程是添加噪声的SDE,逆向过程是时间反转的SDE。

  • 无分类器引导(Ho 和 Salimans,2022)控制了样本质量和多样性之间的权衡。模型同时有条件地(带有文本提示或类别标签)和无条件地(随机丢弃条件)进行训练。在采样时,预测是有权重的组合:

\[\hat{\epsilon} = \epsilon_\theta(x_t, \varnothing) + s \cdot (\epsilon_\theta(x_t, c) - \epsilon_\theta(x_t, \varnothing))\]
  • 其中 \(c\) 是条件,\(\varnothing\) 是空条件,\(s > 1\) 是引导尺度。越大的 \(s\) 生成的图像与条件匹配得越强,但多样性降低。\(s = 1\) 表示无引导模型;\(s = 7.5\) 是常见的默认值。

  • 潜在扩散(Rombach 等,2022;Stable Diffusion)将扩散过程从像素空间转移到学习的潜在空间。预训练VAE编码器将图像压缩到较低维的潜在表示(通常空间下采样4倍或8倍),扩散在这个压缩空间中进行,然后VAE解码器从去噪后的潜在表示重建像素。这极大地提高了效率:在像素空间对512x512图像进行扩散意味着处理 \(512 \times 512 \times 3\) 的张量,而在潜在空间仅需处理 \(64 \times 64 \times 4\) 的张量。

  • 潜在扩散中的去噪U-Net接收带噪潜在表示、时间步(编码为正弦嵌入,类似于Transformer中的位置编码)以及一个条件信号(来自冻结的CLIP或T5文本编码器的文本嵌入)。文本条件通过U-Net内部的交叉注意力层注入:文本嵌入作为键和值,图像特征作为查询。这使得模型能够在每个空间位置关注文本提示的相关部分。

  • 流匹配是扩散模型的一种新兴替代方案,它学习噪声和数据之间的直接传输路径,而不是DDPM的迭代去噪。

  • 连续归一化流(CNF) 定义了一个随时间变化的向量场 \(v_\theta(x, t)\),它将样本从简单分布 \(p_0\)(噪声)沿着平滑轨迹推动到数据分布 \(p_1\)。变换遵循常微分方程(ODE):

\[\frac{dx}{dt} = v_\theta(x, t), \quad t \in [0, 1]\]
  • \(x_0 \sim \mathcal{N}(0, I)\) 开始,将ODE向前积分到 \(t = 1\) 即可从数据分布中生成样本。向量场由神经网络参数化,并训练其匹配目标条件流。

  • 最优传输(OT)流匹配(Lipman 等,2023)使用噪声和数据之间的直线路径作为目标流:从噪声样本 \(x_0\) 到数据样本 \(x_1\) 的条件路径就是简单的 \(x_t = (1 - t) x_0 + t x_1\),目标速度为 \(v = x_1 - x_0\)。训练损失变为:

\[\mathcal{L} = \mathbb{E}_{t, x_0, x_1} \left[\|v_\theta(x_t, t) - (x_1 - x_0)\|^2\right]\]
  • 修正流(Rectified flows,Liu 等,2022)迭代地拉直学习到的流动路径。经过初始训练后,模型被用来通过模拟ODE生成(噪声,数据)对。这些对相比随机配对更对齐,然后用于重新训练模型。重复此过程会产生越来越直的路径,可以在更少的ODE步数(甚至单步)内完成,从而实现极快的生成。

  • 流匹配相比扩散模型有几个优势:训练目标更简单(直接的速度回归,无需噪声调度),采样ODE更平滑(需要更少的积分步数),并且与最优传输的联系提供了理论基础。Stable Diffusion 3 和 Flux 使用流匹配而不是传统的DDPM。

编程任务(使用CoLab或notebook)

  1. 从零实现ViT的图像块嵌入。将图像分割成块,展平,投影到模型维度,添加位置嵌入,并添加[CLS]词元。

    import jax
    import jax.numpy as jnp
    import matplotlib.pyplot as plt
    
    def create_patch_embedding(image, patch_size, d_model, params):
        """将图像转换为图像块嵌入序列。"""
        H, W, C = image.shape
        n_patches_h = H // patch_size
        n_patches_w = W // patch_size
        n_patches = n_patches_h * n_patches_w
    
        # 提取图像块
        patches = []
        for i in range(n_patches_h):
            for j in range(n_patches_w):
                patch = image[i*patch_size:(i+1)*patch_size,
                              j*patch_size:(j+1)*patch_size, :]
                patches.append(patch.ravel())
        patches = jnp.stack(patches)  # (N, P*P*C)
    
        # 线性投影到 d_model
        embeddings = patches @ params['proj_w'] + params['proj_b']  # (N, d_model)
    
        # 添加 CLS 词元
        cls_token = params['cls_token']  # (1, d_model)
        embeddings = jnp.concatenate([cls_token, embeddings], axis=0)  # (N+1, d_model)
    
        # 添加位置嵌入
        embeddings = embeddings + params['pos_embed']  # (N+1, d_model)
    
        return embeddings, patches
    
    # 设置参数
    H, W, C = 32, 32, 3
    patch_size = 8
    d_model = 64
    n_patches = (H // patch_size) * (W // patch_size)  # 16
    
    key = jax.random.PRNGKey(42)
    keys = jax.random.split(key, 5)
    
    # 创建合成图像,四个象限不同颜色
    image = jnp.zeros((H, W, C))
    image = image.at[:16, :16, 0].set(1.0)   # 红色 左上角
    image = image.at[:16, 16:, 1].set(1.0)   # 绿色 右上角
    image = image.at[16:, :16, 2].set(1.0)   # 蓝色 左下角
    image = image.at[16:, 16:, :2].set(1.0)  # 黄色 右下角
    
    params = {
        'proj_w': jax.random.normal(keys[0], (patch_size**2 * C, d_model)) * 0.02,
        'proj_b': jnp.zeros(d_model),
        'cls_token': jax.random.normal(keys[1], (1, d_model)) * 0.02,
        'pos_embed': jax.random.normal(keys[2], (n_patches + 1, d_model)) * 0.02,
    }
    
    embeddings, patches = create_patch_embedding(image, patch_size, d_model, params)
    
    print(f"图像形状: {image.shape}")
    print(f"图像块大小: {patch_size}x{patch_size}")
    print(f"图像块数量: {n_patches}")
    print(f"图像块向量长度: {patch_size**2 * C}")
    print(f"嵌入形状: {embeddings.shape}  (CLS + {n_patches} 个图像块)")
    
    # 可视化图像块
    fig, axes = plt.subplots(2, 5, figsize=(14, 6))
    axes[0, 0].imshow(image); axes[0, 0].set_title('完整图像'); axes[0, 0].axis('off')
    for idx in range(min(9, n_patches)):
        ax = axes[(idx+1) // 5, (idx+1) % 5]
        patch_img = patches[idx].reshape(patch_size, patch_size, C)
        ax.imshow(patch_img); ax.set_title(f'图像块 {idx}'); ax.axis('off')
    plt.suptitle('ViT 图像块分解')
    plt.tight_layout(); plt.show()
    

  2. 实现一个简单的GAN训练循环。在二维数据上训练生成器和判别器,并可视化生成的分布如何收敛到真实分布。

    import jax
    import jax.numpy as jnp
    import matplotlib.pyplot as plt
    
    def generator(z, params):
        h = jnp.tanh(z @ params['g_w1'] + params['g_b1'])
        h = jnp.tanh(h @ params['g_w2'] + params['g_b2'])
        return h @ params['g_w3'] + params['g_b3']
    
    def discriminator(x, params):
        h = jax.nn.leaky_relu(x @ params['d_w1'] + params['d_b1'], 0.2)
        h = jax.nn.leaky_relu(h @ params['d_w2'] + params['d_b2'], 0.2)
        return jax.nn.sigmoid(h @ params['d_w3'] + params['d_b3'])
    
    def init_params(key):
        keys = jax.random.split(key, 6)
        z_dim, h_dim, data_dim = 2, 32, 2
        scale = 0.1
        return {
            'g_w1': jax.random.normal(keys[0], (z_dim, h_dim)) * scale,
            'g_b1': jnp.zeros(h_dim),
            'g_w2': jax.random.normal(keys[1], (h_dim, h_dim)) * scale,
            'g_b2': jnp.zeros(h_dim),
            'g_w3': jax.random.normal(keys[2], (h_dim, data_dim)) * scale,
            'g_b3': jnp.zeros(data_dim),
            'd_w1': jax.random.normal(keys[3], (data_dim, h_dim)) * scale,
            'd_b1': jnp.zeros(h_dim),
            'd_w2': jax.random.normal(keys[4], (h_dim, h_dim)) * scale,
            'd_b2': jnp.zeros(h_dim),
            'd_w3': jax.random.normal(keys[5], (h_dim, 1)) * scale,
            'd_b3': jnp.zeros(1),
        }
    
    def d_loss(params, real_data, fake_data):
        real_score = discriminator(real_data, params)
        fake_score = discriminator(fake_data, params)
        return -jnp.mean(jnp.log(real_score + 1e-7) + jnp.log(1 - fake_score + 1e-7))
    
    def g_loss(params, fake_data):
        fake_score = discriminator(fake_data, params)
        return -jnp.mean(jnp.log(fake_score + 1e-7))
    
    # 真实数据:环状分布
    key = jax.random.PRNGKey(42)
    theta = jax.random.uniform(key, (512,)) * 2 * jnp.pi
    real_data = jnp.stack([jnp.cos(theta), jnp.sin(theta)], axis=1)
    real_data = real_data + jax.random.normal(key, real_data.shape) * 0.05
    
    params = init_params(jax.random.PRNGKey(0))
    d_grad = jax.grad(d_loss)
    g_grad = jax.grad(g_loss)
    lr = 0.001
    
    snapshots = []
    for step in range(3000):
        key, k1 = jax.random.split(key)
        z = jax.random.normal(k1, (512, 2))
        fake_data = generator(z, params)
    
        # 更新判别器
        grads = d_grad(params, real_data, fake_data)
        for k in ['d_w1', 'd_b1', 'd_w2', 'd_b2', 'd_w3', 'd_b3']:
            params[k] = params[k] - lr * grads[k]
    
        # 更新生成器
        fake_data = generator(z, params)
        grads = g_grad(params, fake_data)
        for k in ['g_w1', 'g_b1', 'g_w2', 'g_b2', 'g_w3', 'g_b3']:
            params[k] = params[k] - lr * grads[k]
    
        if step in [0, 500, 1500, 2999]:
            snapshots.append((step, fake_data.copy()))
    
    fig, axes = plt.subplots(1, 4, figsize=(16, 4))
    for ax, (step, fake) in zip(axes, snapshots):
        ax.scatter(real_data[:, 0], real_data[:, 1], s=5, alpha=0.3, c='#3498db', label='真实')
        ax.scatter(fake[:, 0], fake[:, 1], s=5, alpha=0.3, c='#e74c3c', label='生成')
        ax.set_title(f'步骤 {step}'); ax.set_xlim(-2, 2); ax.set_ylim(-2, 2)
        ax.set_aspect('equal'); ax.legend(markerscale=3)
    plt.suptitle('GAN训练:生成器学习环状分布')
    plt.tight_layout(); plt.show()
    

  3. 实现扩散正向过程:在递增的时间步上向图像添加噪声,并可视化逐步退化。然后实现一个单步去噪过程。

    import jax
    import jax.numpy as jnp
    import matplotlib.pyplot as plt
    
    def noise_schedule(T, beta_start=0.0001, beta_end=0.02):
        """线性噪声调度。"""
        betas = jnp.linspace(beta_start, beta_end, T)
        alphas = 1.0 - betas
        alpha_bars = jnp.cumprod(alphas)
        return betas, alphas, alpha_bars
    
    def forward_diffusion(x0, t, alpha_bars, key):
        """在时间步t向x0添加噪声。"""
        alpha_bar_t = alpha_bars[t]
        noise = jax.random.normal(key, x0.shape)
        xt = jnp.sqrt(alpha_bar_t) * x0 + jnp.sqrt(1 - alpha_bar_t) * noise
        return xt, noise
    
    # 创建一个简单的2D“图像”(棋盘格)
    img = jnp.zeros((32, 32))
    for i in range(4):
        for j in range(4):
            if (i + j) % 2 == 0:
                img = img.at[i*8:(i+1)*8, j*8:(j+1)*8].set(1.0)
    
    T = 1000
    betas, alphas, alpha_bars = noise_schedule(T)
    
    # 可视化正向过程
    timesteps = [0, 50, 200, 500, 999]
    key = jax.random.PRNGKey(42)
    
    fig, axes = plt.subplots(1, len(timesteps), figsize=(16, 3.5))
    for ax, t in zip(axes, timesteps):
        key, subkey = jax.random.split(key)
        xt, noise = forward_diffusion(img, t, alpha_bars, subkey)
        ax.imshow(xt, cmap='gray', vmin=-2, vmax=2)
        ax.set_title(f't={t}\n$\\bar{{\\alpha}}$={alpha_bars[t]:.3f}')
        ax.axis('off')
    plt.suptitle('扩散正向过程:渐进添加噪声')
    plt.tight_layout(); plt.show()
    
    # 简单去噪:训练一个微小的网络来预测t=200时的噪声
    t_denoise = 200
    key, k1 = jax.random.split(key)
    xt, true_noise = forward_diffusion(img, t_denoise, alpha_bars, k1)
    
    # 微型“去噪器”:仅用于演示,学习一个常数噪声估计
    noise_estimate = jnp.zeros_like(img)
    lr = 0.01
    for step in range(100):
        residual = noise_estimate - true_noise
        noise_estimate = noise_estimate - lr * residual
    
    # 单步逆向
    alpha_bar_t = alpha_bars[t_denoise]
    x_denoised = (xt - jnp.sqrt(1 - alpha_bar_t) * noise_estimate) / jnp.sqrt(alpha_bar_t)
    
    fig, axes = plt.subplots(1, 3, figsize=(12, 4))
    axes[0].imshow(img, cmap='gray'); axes[0].set_title('原始 $x_0$'); axes[0].axis('off')
    axes[1].imshow(xt, cmap='gray', vmin=-2, vmax=2)
    axes[1].set_title(f'带噪 $x_{{200}}$'); axes[1].axis('off')
    axes[2].imshow(x_denoised, cmap='gray')
    axes[2].set_title('去噪后(单步)'); axes[2].axis('off')
    plt.tight_layout(); plt.show()
    
    mse = jnp.mean((x_denoised - img)**2)
    print(f"去噪MSE: {mse:.4f}")