Skip to content

深度学习

深度学习堆叠非线性层以构建层次化表示,自动将原始输入转化为有用的特征。本文件涵盖多层感知机、激活函数、反向传播、卷积神经网络、循环神经网络、长短期记忆网络、注意力机制、Transformer、生成对抗网络、变分自编码器、扩散模型以及归一化技术

  • 是什么让一个网络“深”?浅层网络只有一个隐藏层;深层网络有许多隐藏层。深度让网络能够构建层次化表示:早期层学习简单特征(边缘、色调),后期层将它们组合成复杂概念(人脸、句子)。这种组合性是深度学习强大能力的来源。

  • 最简单的深度网络是多层感知机(MLP),也称为全连接网络或稠密网络。每一层计算:

\[h = \sigma(Wx + b)\]
  • 这里 \(W\) 是权重矩阵(第02章),\(b\) 是偏置向量,\(\sigma\) 是非线性激活函数。一层的输出成为下一层的输入。没有非线性,堆叠层将毫无意义:\(W_2(W_1 x) = (W_2 W_1)x\),仅仅是另一个线性变换。这正是第02章中矩阵乘法坍缩的体现。

  • 激活函数引入非线性,使深度变得有意义。

  • ReLU(修正线性单元):\(\text{ReLU}(x) = \max(0, x)\)。它是使用最广泛的激活函数。计算快,对正输入不会饱和,并产生稀疏激活(许多神经元输出精确为零)。缺点:输入为负的神经元总是输出零,如果它们永久卡在那里,就会“死亡”并停止学习。

  • Sigmoid\(\sigma(x) = \frac{1}{1+e^{-x}}\),将输入压缩到 \((0, 1)\)。适用于二分类的输出层,但在隐藏层中有问题,因为当输入远离零时梯度会消失(曲线几乎平坦)。

  • Tanh\(\tanh(x) = \frac{e^x - e^{-x}}{e^x + e^{-x}}\),压缩到 \((-1, 1)\)。零中心(不同于sigmoid),有助于梯度流动,但在极值处仍存在梯度消失问题。

  • GELU(高斯误差线性单元):\(\text{GELU}(x) = x \cdot \Phi(x)\),其中 \(\Phi\) 是标准正态分布的累积分布函数。它是ReLU的光滑近似,允许小的负值通过。GELU 是 GPT 和 BERT 中的默认激活函数。

  • Swish\(\text{Swish}(x) = x \cdot \sigma(x)\),另一种光滑门控函数,实践中与GELU类似。

ReLU、Sigmoid、Tanh和GELU的并排图,附有各自的关键属性

  • 一个输入维度为 \(d_{\text{in}}\)、输出维度为 \(d_{\text{out}}\) 的稠密层有 \(d_{\text{in}} \times d_{\text{out}} + d_{\text{out}}\) 个参数(权重加偏置)。矩阵乘法 \(Wx\) 就是第02章的矩阵-向量乘法。在批处理设置中,输入是形状为 \((B, d_{\text{in}})\) 的矩阵 \(X\),输出是形状为 \((B, d_{\text{out}})\)\(XW^T + b\)

  • 万能逼近定理指出:一个具有足够多神经元的单隐藏层可以以任意精度逼近紧致域上的任何连续函数。这听起来好像深度不重要,但关键在于“足够多神经元”。在实践中,深度网络可以用指数级更少的参数表示同样的函数。深度带来的不仅是表示能力,更是效率。

  • 随着网络变深,两种梯度病理问题出现。梯度消失:当梯度通过许多层时(通过链式法则,第03章),它们会乘以许多因子。如果这些因子持续小于1(如 sigmoid 和 tanh 饱和时出现的情况),梯度会指数级收缩到零。早期层几乎学不到东西。梯度爆炸:如果因子持续大于1,梯度会指数级增长,导致数值溢出和训练不稳定。

  • 梯度消失/爆炸的解决方案:

  • 使用 ReLU 或 GELU 激活函数(正输入时梯度为1,不饱和)
  • 谨慎的权重初始化
  • 归一化层
  • 残差连接(跳跃连接)
  • 梯度裁剪(针对梯度爆炸):将梯度范数限制在最大值以内

  • 权重初始化很重要,因为它决定了训练开始时激活值和梯度的尺度。权重大大会使激活值爆炸;太小则使激活值消失。

  • Xavier(Glorot)初始化从方差为 \(\frac{2}{d_{\text{in}} + d_{\text{out}}}\) 的分布中设置权重。这使激活值的方差在各层之间大致保持不变,假设使用线性或 tanh 激活函数。

  • He(Kaiming)初始化使用方差 \(\frac{2}{d_{\text{in}}}\),是为 ReLU 激活函数校准的(因为 ReLU 将一半激活值置零,需要加倍方差来补偿)。

  • 归一化层通过确保每层输入具有一致的统计特性(大致零均值、单位方差)来稳定训练。

  • 批量归一化(BatchNorm)在批次维度上归一化:对每个通道/特征,计算小批量中所有样本的均值和方差,然后归一化。它添加了可学习的缩放参数 \(\gamma\) 和偏移参数 \(\beta\),以便网络在需要时可以撤销归一化:

\[\hat{x} = \frac{x - \mu_B}{\sqrt{\sigma_B^2 + \epsilon}}, \quad y = \gamma \hat{x} + \beta\]
  • BatchNorm 有一个问题:它依赖于批次大小。当批次非常小时,统计量有噪声。在推理时,使用运行平均值而不是批次统计量,这会造成训练/测试不一致。

  • 层归一化(LayerNorm)对每个样本在特征维度上进行归一化。它不依赖于批次中的其他样本,因此成为 Transformer 和循环网络的标准选择。

  • 实例归一化对每个样本和每个通道独立地在空间维度上归一化。它在风格迁移中很流行。

  • 组归一化将通道分成组,在每组内归一化。它是 LayerNorm 和 InstanceNorm 之间的折中。

3D张量,彩色的切片显示BatchNorm、LayerNorm和InstanceNorm分别归一化哪些维度

  • Dropout 是一种正则化技术,在训练期间随机将一部分 \(p\) 的神经元置为零。这迫使网络不依赖任何单个神经元,鼓励冗余表示。测试时,所有神经元都激活。反向 Dropout 在训练期间将激活值缩放 \(1/(1-p)\),这样测试时就无需缩放。这是标准实现。

  • 卷积神经网络(CNN)利用空间结构。卷积层不是将每个输入连接到每个输出(如稠密层那样),而是在输入上滑动一个小的滤波器(核),在每个位置计算点积。相同的滤波器权重在所有位置共享,这极大地减少了参数,并内置了平移不变性。

  • 对于二维输入,滤波器 \(K\) 大小为 \(k \times k\)卷积操作

\[(\text{input} * K)[i,j] = \sum_{m=0}^{k-1} \sum_{n=0}^{k-1} \text{input}[i+m, j+n] \cdot K[m, n]\]

输入网格,一个3x3滤波器在其上滑动,在每个位置进行逐元素相乘并求和,生成输出特征图

  • 输出大小取决于三个超参数。步长控制滤波器在位置之间移动的像素数(步长2会使空间尺寸减半)。填充在输入边界周围添加零(“same”填充保持空间尺寸,“valid”填充则不加)。输出大小公式:\(\text{out} = \lfloor (\text{in} - k + 2p) / s \rfloor + 1\)

  • 池化层对特征图进行下采样。最大池化取每个窗口中的最大值;平均池化取均值。池化降低空间维度,同时保留最重要的信息。

  • 空洞卷积在滤波器元素之间插入间隔,在不增加参数的情况下增大感受野。空洞率为2意味着3x3滤波器覆盖5x5区域。

  • 1x1卷积是使用1x1滤波器的卷积。它们不关注空间邻居,而是在通道间混合信息。可以将其视为在每个空间位置应用一个稠密层。它们用于廉价地改变通道数。

  • 跳跃连接(残差连接)允许输入绕过一层或多层:\(\text{output} = F(x) + x\)。层只需要学习残差 \(F(x) = \text{output} - x\),当最优变换接近恒等映射时这更容易。ResNet(残差网络)利用这个技巧堆叠了超过100层,解决了更深网络反而比浅层网络表现更差的退化问题。

  • CNN 构建特征层次结构。早期层检测边缘和纹理。中间层将它们组合成部件(眼睛、轮子)。后期层识别完整物体。每一层的感受野(它能够“看到”的输入区域)随深度增长。

  • 嵌入将离散的标记(单词、字符、物品ID)映射为稠密向量。嵌入层本质上是一个查找表:形状为(词汇表大小,嵌入维度)的矩阵 \(E\)。查找标记 \(i\) 就是选取 \(E\) 的第 \(i\) 行。这等价于乘以一个独热向量,而这只是矩阵-向量乘法的一个特例(第02章)。嵌入在训练过程中学习,因此相似的标记最终会有相似的向量。

  • 分词是将原始文本转换为标记序列的过程。单词级分词按空格切分,但无法处理未见过的单词。子词分词(BPE、WordPiece、SentencePiece)将文本切分为频繁出现的子词单元,平衡词汇表大小和覆盖率。例如 "unhappiness" 可能变成 ["un", "happiness"] 或 ["un", "happ", "iness"]。

  • 循环神经网络(RNN)逐个元素处理序列,维护一个携带信息向前的隐藏状态:

\[h_t = \tanh(W_h h_{t-1} + W_x x_t + b)\]
  • 隐藏状态 \(h_t\) 是网络截至时间 \(t\) 所看到的所有内容的压缩摘要。相同的权重 \(W_h\)\(W_x\) 在所有时间步共享(权重共享,类似于 CNN 共享空间权重)。

  • 普通 RNN 难以处理长序列,因为梯度消失:从时间步 \(t\)\(t-k\) 的梯度信号要经过 \(k\) 次乘 \(W_h\),会指数级收缩(或爆炸)。

  • LSTM(长短期记忆网络)通过引入一个独立的细胞状态 \(c_t\) 来解决这个问题,该状态在时间上以最小的干扰流动。三个门控制哪些信息进入、离开和持续:

  • 遗忘门决定从细胞状态中擦除什么:\(f_t = \sigma(W_f [h_{t-1}, x_t] + b_f)\)

  • 输入门决定写入什么新信息:\(i_t = \sigma(W_i [h_{t-1}, x_t] + b_i)\),同时有候选值 \(\tilde{c}_t = \tanh(W_c [h_{t-1}, x_t] + b_c)\)
  • 细胞状态更新:\(c_t = f_t \odot c_{t-1} + i_t \odot \tilde{c}_t\)
  • 输出门决定暴露什么:\(o_t = \sigma(W_o [h_{t-1}, x_t] + b_o)\),以及 \(h_t = o_t \odot \tanh(c_t)\)

LSTM单元示意图,显示遗忘门、输入门、输出门、细胞状态高速公路和数据流连接

  • 细胞状态就像一条传送带:信息可以在许多时间步上保持不变地流动(遗忘门保持接近1),这解决了长距离依赖的梯度消失问题。

  • GRU(门控循环单元)简化了 LSTM,将细胞状态和隐藏状态合并为一个,使用两个门而不是三个:一个更新门(合并遗忘和输入)和一个重置门。GRU 参数更少,性能通常与 LSTM 相当。

  • RNN(包括 LSTM)的根本限制是顺序处理:必须先处理标记1,再处理标记2,然后标记3。这阻止了并行化,并造成信息瓶颈,因为所有上下文都必须挤过固定大小的隐藏状态。

  • 注意力解决了这两个问题。注意力机制不将整个输入压缩成一个固定向量,而是让模型能够回顾所有输入位置,并决定哪些位置与当前输出相关。

  • 现代公式使用查询、键和值(Q, K, V)。可以将其想象为一次图书馆搜索:你有一个查询(你在寻找什么),每个书本有一个键(标签),以及实际的书本内容(值)。你将查询与所有键进行比较,以确定检索哪些值。

  • 缩放点积注意力

\[\text{Attention}(Q, K, V) = \text{softmax}\!\left(\frac{QK^T}{\sqrt{d_k}}\right) V\]
  • \(QK^T\) 计算每个查询与每个键之间的相似度。这是一个矩阵乘法(第02章),其条目是点积,用于衡量余弦相似度(第01章)。除以 \(\sqrt{d_k}\) 防止点积变得太大(否则会使 softmax 饱和,产生接近独热的分布,导致梯度消失)。softmax 将相似度转换为概率分布。乘以 \(V\) 产生值的加权组合。

  • 多头注意力并行运行 \(h\) 个注意力操作,每个操作使用不同的学习到的 Q、K、V 投影。这让模型可以同时从不同的表示子空间中获取信息。一个头可能关注句法关系,而另一个头关注语义关系。输出被拼接并投影:

\[\text{MultiHead}(Q, K, V) = \text{Concat}(\text{head}_1, \ldots, \text{head}_h) W^O\]
  • Transformer 架构(Vaswani 等人,2017)完全由注意力和前馈层构成,没有循环。编码器块重复:多头自注意力、相加和层归一化、前馈网络、相加和层归一化。解码器块添加了一个掩蔽自注意力(防止模型看到未来标记)和一个交叉注意力层(关注编码器的输出)。

Transformer编码器块:多头注意力、相加和层归一化、前馈网络、相加和层归一化,带有残差连接

  • 位置编码是必要的,因为注意力是置换等变的,意味着它将输入视为一个集合而非序列。没有位置信息,“猫坐在垫子上”和“垫子坐在猫上”将会是相同的。原始 Transformer 使用正弦位置编码:
\[PE_{(pos, 2i)} = \sin\!\left(\frac{pos}{10000^{2i/d}}\right), \quad PE_{(pos, 2i+1)} = \cos\!\left(\frac{pos}{10000^{2i/d}}\right)\]
  • 每个位置得到一个唯一的向量,模型可以用它来区分位置。现代模型通常使用可学习的位置嵌入或相对位置编码(RoPE、ALiBi)来代替。

  • Transformer 并行处理所有标记(自注意力矩阵 \(QK^T\) 通过一次矩阵乘法计算),这使得在现代硬件上训练比 RNN 快得多。代价是自注意力在序列长度上是 \(O(n^2)\)(每个标记关注所有其他标记),而 RNN 是 \(O(n)\)。这就是为什么长上下文模型需要特殊的注意力变体(稀疏注意力、线性注意力、Flash Attention)。

  • 视觉 Transformer(ViT) 通过将图像分割成固定大小的块(例如16×16),将每个块展平为一个向量,并将这些块视为标记序列,从而将 Transformer 应用于图像。在序列开头添加一个可学习的 [CLS] 标记,其最终表示用于分类。尽管没有卷积的归纳偏置,但在足够数据上训练时,ViT 可以达到或超越 CNN。

  • MLP-Mixer 是一种更简单的架构,用 MLP 取代了注意力和卷积。它交替进行“标记混合”MLP(在空间位置上应用)和“通道混合”MLP(在特征上应用)。它的表现具有竞争力,这表明现代架构的关键洞察不是注意力本身,而是有效混合跨标记和跨特征的信息。

  • 自编码器通过训练网络重构自身输入来学习压缩表示。编码器将输入映射到更低维的瓶颈(潜在编码),解码器将其映射回来:

\[z = f_{\text{enc}}(x), \quad \hat{x} = f_{\text{dec}}(z), \quad \mathcal{L} = \|x - \hat{x}\|^2\]
  • 瓶颈迫使网络学习最重要的特征。自编码器用于降维、去噪(用带噪声的输入训练,重构干净输出)以及异常检测(高重构误差表示输入不寻常)。

  • 变分自编码器(VAE) 添加了概率性的改变。编码器不是编码到一个单点 \(z\),而是输出一个分布的参数(高斯的均值 \(\mu\) 和方差 \(\sigma^2\))。潜在代码从这个分布中采样:\(z = \mu + \sigma \odot \epsilon\),其中 \(\epsilon \sim \mathcal{N}(0, I)\)。这种重参数化技巧使采样可微分,从而梯度可以流过。

  • VAE 的损失有两项:

\[\mathcal{L} = \underbrace{\|x - \hat{x}\|^2}_{\text{重构}} + \underbrace{D_{\text{KL}}(q(z|x) \| p(z))}_{\text{正则化}}\]
  • KL 散度项(来自第05章)将学到的后验 \(q(z|x)\) 推向先验 \(p(z) = \mathcal{N}(0, I)\),确保潜在空间平滑且结构良好。然后你可以从先验中采样并解码以生成新数据。这就是 VAE 成为生成模型的原因。

  • 扩散模型 通过逐步向数据添加噪声,然后学习逆转该过程来生成数据。前向过程:从真实数据 \(x_0\) 开始,逐渐添加高斯噪声,得到 \(x_1, x_2, \ldots, x_T\),直到 \(x_T\) 接近纯噪声。反向过程:学习一个去噪网络 \(\epsilon_\theta(x_t, t)\),预测添加到 \(x_t\) 中的噪声,从而逐步恢复 \(x_{t-1}\)

  • 训练目标是简化的均方误差:\(\mathcal{L} = \|\epsilon - \epsilon_\theta(x_t, t)\|^2\),其中 \(\epsilon\) 是实际添加的噪声。在推理时,从纯噪声 \(x_T\) 开始,重复应用去噪网络 \(T\) 步,得到生成的 \(x_0\)。扩散模型在图像生成方面取得了最先进的结果(DALL-E 2、Stable Diffusion、Imagen)。

  • 生成对抗网络(GAN) 包含两个博弈的网络:生成器 \(G\) 尝试生成逼真的假样本,判别器 \(D\) 尝试区分真实样本和假样本。训练目标是极小极大博弈:

\[\min_G \max_D \mathbb{E}_{x \sim p_{\text{data}}}[\log D(x)] + \mathbb{E}_{z \sim p(z)}[\log(1 - D(G(z)))]\]
  • GAN 能生成高度逼真的样本,但训练可能不稳定,容易出现模式坍缩(生成器只产生少数几种变体)。现代的变体如 StyleGAN 和 Wasserstein GAN 缓解了这些问题。

  • 归一化流 构建一系列可逆变换,将简单分布(如高斯)映射到复杂的数据分布。通过变量变换公式精确计算似然,实现直接的密度估计和生成。与 VAE 和 GAN 相比,流的计算成本更高,但能提供精确的似然。

编码任务(使用 CoLab 或 notebook)

  1. 在 JAX 中从零构建一个简单的 MLP。在二维分类问题(例如同心圆)上训练,并可视化决策边界。

    import jax
    import jax.numpy as jnp
    import matplotlib.pyplot as plt
    from sklearn.datasets import make_circles
    
    # 数据
    X, y = make_circles(n_samples=500, noise=0.1, factor=0.5, random_state=42)
    X, y = jnp.array(X), jnp.array(y, dtype=jnp.float32)
    
    # 初始化一个2层MLP: 2 -> 16 -> 16 -> 1
    def init_params(key):
        k1, k2, k3 = jax.random.split(key, 3)
        return {
            'W1': jax.random.normal(k1, (2, 16)) * 0.5,
            'b1': jnp.zeros(16),
            'W2': jax.random.normal(k2, (16, 16)) * 0.5,
            'b2': jnp.zeros(16),
            'W3': jax.random.normal(k3, (16, 1)) * 0.5,
            'b3': jnp.zeros(1),
        }
    
    def forward(params, x):
        h = jnp.maximum(0, x @ params['W1'] + params['b1'])  # ReLU
        h = jnp.maximum(0, h @ params['W2'] + params['b2'])   # ReLU
        logit = (h @ params['W3'] + params['b3']).squeeze()
        return jax.nn.sigmoid(logit)
    
    def loss_fn(params, X, y):
        pred = forward(params, X)
        return -jnp.mean(y * jnp.log(pred + 1e-7) + (1 - y) * jnp.log(1 - pred + 1e-7))
    
    grad_fn = jax.jit(jax.grad(loss_fn))
    params = init_params(jax.random.PRNGKey(0))
    lr = 0.1
    
    for step in range(2000):
        grads = grad_fn(params, X, y)
        params = {k: params[k] - lr * grads[k] for k in params}
    
    # 绘制决策边界
    xx, yy = jnp.meshgrid(jnp.linspace(-2, 2, 200), jnp.linspace(-2, 2, 200))
    grid = jnp.column_stack([xx.ravel(), yy.ravel()])
    zz = forward(params, grid).reshape(xx.shape)
    
    plt.figure(figsize=(7, 6))
    plt.contourf(xx, yy, zz, levels=[0, 0.5, 1], alpha=0.3, colors=['#e74c3c', '#3498db'])
    plt.scatter(X[y==0,0], X[y==0,1], c='#e74c3c', s=10, label='类别0')
    plt.scatter(X[y==1,0], X[y==1,1], c='#3498db', s=10, label='类别1')
    plt.title("MLP在同心圆上的决策边界")
    plt.legend(); plt.grid(alpha=0.3); plt.show()
    
    acc = jnp.mean((forward(params, X) > 0.5) == y)
    print(f"准确率: {acc:.2%}")
    

  2. 从零实现一维卷积。将一个简单的边缘检测滤波器应用于信号,并与内置的 jnp.convolve 进行比较。

    import jax.numpy as jnp
    import matplotlib.pyplot as plt
    
    def conv1d(signal, kernel):
        """一维卷积(valid模式),从零实现。"""
        n, k = len(signal), len(kernel)
        output = jnp.zeros(n - k + 1)
        for i in range(n - k + 1):
            output = output.at[i].set(jnp.sum(signal[i:i+k] * kernel))
        return output
    
    # 创建一个带有阶跃函数的信号
    t = jnp.linspace(0, 4, 200)
    signal = jnp.where(t < 1, 0.0, jnp.where(t < 2, 1.0, jnp.where(t < 3, 0.5, 1.5)))
    
    # 边缘检测核
    edge_kernel = jnp.array([-1.0, 0.0, 1.0])
    
    # 我们的实现 vs 内置函数
    our_output = conv1d(signal, edge_kernel)
    jnp_output = jnp.convolve(signal, edge_kernel, mode='valid')
    
    fig, axes = plt.subplots(3, 1, figsize=(10, 6), sharex=True)
    axes[0].plot(t, signal, color='#3498db', linewidth=1.5)
    axes[0].set_title("原始信号"); axes[0].set_ylabel("值")
    
    axes[1].plot(t[:len(our_output)], our_output, color='#e74c3c', linewidth=1.5)
    axes[1].set_title("边缘检测后(我们的conv1d)"); axes[1].set_ylabel("值")
    
    axes[2].plot(t[:len(jnp_output)], jnp_output, color='#27ae60', linewidth=1.5, linestyle='--')
    axes[2].set_title("边缘检测后(jnp.convolve)"); axes[2].set_ylabel("值")
    axes[2].set_xlabel("t")
    
    plt.tight_layout(); plt.show()
    print(f"输出匹配: {jnp.allclose(our_output, jnp_output)}")
    

  3. 从零实现缩放点积注意力。在一个小例子中计算注意力权重,并将注意力矩阵可视化为热力图。

    import jax
    import jax.numpy as jnp
    import matplotlib.pyplot as plt
    
    def scaled_dot_product_attention(Q, K, V):
        """缩放点积注意力。"""
        d_k = Q.shape[-1]
        scores = Q @ K.T / jnp.sqrt(d_k)
        weights = jax.nn.softmax(scores, axis=-1)
        output = weights @ V
        return output, weights
    
    # 示例:4个标记,嵌入维度8
    key = jax.random.PRNGKey(42)
    k1, k2, k3 = jax.random.split(key, 3)
    seq_len, d_model = 4, 8
    
    Q = jax.random.normal(k1, (seq_len, d_model))
    K = jax.random.normal(k2, (seq_len, d_model))
    V = jax.random.normal(k3, (seq_len, d_model))
    
    output, weights = scaled_dot_product_attention(Q, K, V)
    
    print(f"Q形状: {Q.shape}")
    print(f"注意力权重形状: {weights.shape}")
    print(f"输出形状: {output.shape}")
    print(f"\n注意力权重(行和为1):")
    print(weights)
    print(f"行和: {weights.sum(axis=-1)}")
    
    # 可视化注意力
    fig, ax = plt.subplots(figsize=(5, 4))
    im = ax.imshow(weights, cmap='Blues', vmin=0, vmax=1)
    ax.set_xlabel("键位置"); ax.set_ylabel("查询位置")
    ax.set_title("注意力权重")
    tokens = ['tok 0', 'tok 1', 'tok 2', 'tok 3']
    ax.set_xticks(range(4)); ax.set_xticklabels(tokens)
    ax.set_yticks(range(4)); ax.set_yticklabels(tokens)
    for i in range(4):
        for j in range(4):
            ax.text(j, i, f"{weights[i,j]:.2f}", ha='center', va='center', fontsize=10)
    plt.colorbar(im); plt.tight_layout(); plt.show()
    

  4. 构建一个简单的自编码器,将二维数据通过一维瓶颈压缩并重构。可视化潜在空间和重构结果。

    import jax
    import jax.numpy as jnp
    import matplotlib.pyplot as plt
    from sklearn.datasets import make_moons
    
    # 数据
    X, _ = make_moons(n_samples=500, noise=0.05, random_state=42)
    X = jnp.array(X)
    
    # 自编码器: 2 -> 8 -> 1 -> 8 -> 2
    def init_ae(key):
        k1, k2, k3, k4 = jax.random.split(key, 4)
        return {
            'enc_W1': jax.random.normal(k1, (2, 8)) * 0.5, 'enc_b1': jnp.zeros(8),
            'enc_W2': jax.random.normal(k2, (8, 1)) * 0.5, 'enc_b2': jnp.zeros(1),
            'dec_W1': jax.random.normal(k3, (1, 8)) * 0.5, 'dec_b1': jnp.zeros(8),
            'dec_W2': jax.random.normal(k4, (8, 2)) * 0.5, 'dec_b2': jnp.zeros(2),
        }
    
    def encode(p, x):
        h = jnp.tanh(x @ p['enc_W1'] + p['enc_b1'])
        return h @ p['enc_W2'] + p['enc_b2']
    
    def decode(p, z):
        h = jnp.tanh(z @ p['dec_W1'] + p['dec_b1'])
        return h @ p['dec_W2'] + p['dec_b2']
    
    def ae_loss(p, X):
        z = encode(p, X)
        X_hat = decode(p, z)
        return jnp.mean((X - X_hat) ** 2)
    
    grad_fn = jax.jit(jax.grad(ae_loss))
    params = init_ae(jax.random.PRNGKey(0))
    lr = 0.01
    
    for step in range(3000):
        grads = grad_fn(params, X)
        params = {k: params[k] - lr * grads[k] for k in params}
    
    z = encode(params, X)
    X_hat = decode(params, z)
    
    fig, axes = plt.subplots(1, 2, figsize=(12, 5))
    axes[0].scatter(X[:,0], X[:,1], c=z.squeeze(), cmap='viridis', s=10)
    axes[0].set_title("原始数据(按潜在代码着色)")
    axes[1].scatter(X_hat[:,0], X_hat[:,1], c=z.squeeze(), cmap='viridis', s=10)
    axes[1].set_title("从一维瓶颈重构")
    for ax in axes:
        ax.set_aspect('equal'); ax.grid(alpha=0.3)
    plt.tight_layout(); plt.show()
    
    print(f"重构MSE: {ae_loss(params, X):.4f}")