Skip to content

嵌入与序列模型

词嵌入将稀疏、符号化的文本压缩成稠密的向量空间,其中语义相似性变为几何上的邻近性。本文件涵盖 Word2Vec(CBOW、Skip-gram)、GloVe、FastText、RNN、LSTM、GRU、带注意力的 seq2seq 以及编码器-解码器范式——这是从词袋模型到上下文表示的演进历程。

  • 在文件 01 中,我们介绍了分布假说:出现在相似上下文中的单词往往具有相似的含义。在文件 02 中,我们使用稀疏、手工设计的特征(如 TF-IDF 向量)来表示文本。这些向量存在于非常高维的空间中(每个词汇表单词一个维度),并且大部分为零。词嵌入将这些信息压缩为稠密、低维的向量,捕获语义关系,并且它们直接从数据中学习得到。

  • Word2Vec(Mikolov 等人,2013)通过在一个简单的预测任务上训练一个浅层神经网络来学习词嵌入。有两种架构。

  • 连续词袋(CBOW)模型根据周围的上下文单词预测目标单词。给定一个上下文窗口(例如“the cat ___ on the”),模型对其上下文单词的嵌入向量取平均,并将结果通过一个线性层来预测缺失的单词(“sat”)。训练目标最大化:

\[P(w_t \mid w_{t-k}, \ldots, w_{t-1}, w_{t+1}, \ldots, w_{t+k})\]
  • Skip-gram模型则相反:给定一个目标单词,预测其周围的上下文单词。对于目标单词“sat”,模型在多个预测中分别尝试预测“the”、“cat”、“on”、“the”。目标最大化:
\[P(w_{t+j} \mid w_t) \quad \text{对于每个 } j \in [-k, k], \; j \neq 0\]

Skip-gram 和 CBOW 架构并排:CBOW 对上下文嵌入取平均以预测中心词,skip-gram 使用中心词嵌入预测每个上下文词

  • Skip-gram 对于稀有词效果更好,因为每个词会生成多个训练样本(每个上下文位置一个)。CBOW 速度更快,对常见词稍好,因为它平均了多个上下文信号。

  • 在整个词汇表上训练代价高昂,因为 softmax 分母要对所有 \(V\) 个词求和。负采样通过将问题转化为二元分类来近似:区分真正的上下文词(正样本)和随机采样的噪声词(负样本)。模型不再计算完整的 softmax,而是只更新目标词、真正上下文词以及少量负样本的嵌入:

\[\mathcal{L} = \log \sigma(v_{w_O}^T v_{w_I}) + \sum_{i=1}^{k} \mathbb{E}_{w_i \sim P_n} [\log \sigma(-v_{w_i}^T v_{w_I})]\]
  • 这里 \(v_{w_I}\) 是输入词嵌入,\(v_{w_O}\) 是输出(上下文)词嵌入,\(P_n\) 是噪声分布,通常是一元频率的 3/4 次方(这会降低像“the”这样的超高频词的权重)。

  • 为什么这个简单的目标函数能产生有意义的嵌入?Levy 和 Goldberg(2014)证明,带负采样的 skip-gram 实际上在分解一个移位点互信息矩阵。收敛时,两个词向量的点积近似:

\[v_w^T v_c \approx \text{PMI}(w, c) - \log k\]
  • 其中 \(\text{PMI}(w, c) = \log \frac{P(w, c)}{P(w) P(c)}\) 衡量词 \(w\)\(c\) 共现的概率比随机预期高出多少(第 05 章信息论),\(k\) 是负样本数量。共现远超预期的词具有高的 PMI,因此具有高的点积(相似的嵌入)。共现低于预期的词具有负的 PMI 和不相似的嵌入。这表明 Word2Vec 在做与经典分布语义学方法(如潜在语义分析,即对共现矩阵进行 SVD)相同的事情,但采用的是更可扩展、在线的方式。

  • Word2Vec 嵌入最令人惊讶的性质是它们通过向量算术捕捉类比。向量 \(v_{\text{king}} - v_{\text{man}} + v_{\text{woman}}\) 最接近 \(v_{\text{queen}}\)。这是因为嵌入空间将语义关系编码为近似的线性方向:“王权”方向大致是 \(v_{\text{king}} - v_{\text{man}}\),将其加到 \(v_{\text{woman}}\) 上就会落在 \(v_{\text{queen}}\) 附近。这与第 01 章的线性代数相联系:语义关系就是向量平移。

  • GloVe(Pennington 等人,2014)采取了不同的方法。它不是逐个地从局部上下文窗口学习,而是构建一个全局的词共现矩阵 \(X\),其中 \(X_{ij}\) 统计在整个语料库中词 \(j\) 出现在词 \(i\) 上下文中的次数。然后模型学习嵌入,使其点积近似于对数共现值:

\[w_i^T \tilde{w}_j + b_i + \tilde{b}_j = \log X_{ij}\]
  • 损失函数通过一个上限函数 \(f(X_{ij})\) 对每一对进行加权,防止过高的共现频率主导训练:
\[\mathcal{L} = \sum_{i,j=1}^{V} f(X_{ij}) \left(w_i^T \tilde{w}_j + b_i + \tilde{b}_j - \log X_{ij}\right)^2\]
  • GloVe 结合了全局矩阵分解(如潜在语义分析)和 Word2Vec 局部上下文学习的优点。在实践中,GloVe 和 Word2Vec 产生的嵌入质量相似。

  • FastText(Bojanowski 等人,2017)通过将每个词表示为其字符 n-gram 的集合来扩展 skip-gram。例如,词“where”在 \(n=3\) 时变为:“”,加上整个词的标记“”。该词的嵌入是其所有 n-gram 嵌入之和。

  • 这有一个关键优势:FastText 可以为训练过程中从未见过的词生成嵌入。例如“whereabouts”与“where”共享 n-gram,因此即使“whereabouts”从未出现在训练数据中,它的嵌入也会是合理的。这对于形态丰富的语言(文件 01)尤其有用,这类语言中有许多屈折形式。

  • 嵌入评估通常使用两种基准测试。类比任务测试 \(v_a - v_b + v_c \approx v_d\) 是否成立(例如“Paris” − “France” + “Italy” ≈ “Rome”)。相似度基准将词对之间的余弦相似度(第 01 章)与人工判断进行比较。常见的数据集包括 WordSim-353、SimLex-999 和 Google 类比测试集。一个实际的注意事项:擅长类比的嵌入未必最适合情感分类等下游任务。最好的评估往往就是任务本身。

  • 在第 06 章中,我们介绍了 RNN、LSTM 和 GRU 作为处理序列数据的架构。这里我们重点关注它们如何具体应用于语言任务。

  • 语言模型 RNN 逐个读取标记,并在每一步预测下一个标记。隐藏状态 \(h_t\) 将整个历史 \(w_1, \ldots, w_t\) 压缩为一个固定大小的向量,线性层加 softmax 将 \(h_t\) 映射为词汇表上的概率分布。训练使用与真实下一个标记的交叉熵损失,这与最小化困惑度(文件 02)相同。关键的限制是:固定大小的隐藏状态必须编码关于历史的所有信息,而早期标记的信息会逐渐被覆盖。

  • 双向 RNN 从两个方向处理序列:一个 RNN 从左到右读取,另一个从右到左读取。在每个位置 \(t\),前向隐藏状态 \(\overrightarrow{h}_t\) 和后向隐藏状态 \(\overleftarrow{h}_t\) 被拼接起来,形成一个上下文感知的表示 \(h_t = [\overrightarrow{h}_t ; \overleftarrow{h}_t]\)。这使得模型可以同时访问过去和未来的上下文,这对于像词性标注和命名实体识别(文件 02)这样的任务非常强大,因为这些任务中一个词的标签取决于其前后的词。双向 RNN 不能用于语言建模,因为在预测未来标记时不能窥视它们。

双向 RNN:前向 RNN 从左到右读取生成隐藏状态,后向 RNN 从右到左读取,每个位置的输出被拼接起来

  • 深层堆叠 RNN 将多个 RNN 层堆叠在一起。第 \(l\) 层在所有时间步的隐藏状态成为第 \(l+1\) 层的输入序列。通常堆叠 2-4 层可以通过构建层次化表示来提高性能,类似于更深的 CNN 构建特征层次结构(第 06 章)。超过 4 层后,除非在层之间添加残差连接,否则梯度消失和过拟合会成为问题。

  • 序列到序列(seq2seq)架构(Sutskever 等人,2014)将可变长度的输入序列映射到可变长度的输出序列。它由一个编码器 RNN 和一个解码器 RNN 组成。编码器读取输入并将其压缩成一个上下文向量(最终的隐藏状态);解码器基于这个上下文向量逐个生成输出标记。

Seq2seq 编码器-解码器:编码器 RNN 从左到右读取输入标记,最终的隐藏状态作为初始状态传递给解码器 RNN,解码器自回归地生成输出标记

  • Seq2seq 是机器翻译领域的突破性架构。编码器读取一个法语句子,解码器生成英语翻译。解码器以一个特殊的序列起始标记开始,并自回归地生成标记,直到产生序列结束标记。一个实用的技巧:将输入序列反转(输入“chat le”而不是“le chat”)改善了结果,因为这样将第一个输入词在计算图中放到了更靠近第一个输出词的位置,缩短了梯度路径。

  • 瓶颈问题:整个输入必须被压缩成一个固定大小的向量。对于长句子,这个向量无法捕获所有信息,性能会下降。这促使了注意力机制的出现。

  • 第 06 章介绍了现代 Q、K、V 形式的注意力。NLP 领域最初的注意力机制有不同的表述,即编码器和解码器状态之间的对齐模型。

  • Bahdanau 注意力(加性注意力,Bahdanau 等人,2015)使用一个学习到的前馈网络计算解码器隐藏状态 \(s_t\) 与每个编码器隐藏状态 \(h_i\) 之间的对齐分数:

\[e_{ti} = v^T \tanh(W_s s_{t-1} + W_h h_i)\]
  • 这些分数通过 softmax 归一化为注意力权重,上下文向量是编码器状态的加权和:
\[\alpha_{ti} = \frac{\exp(e_{ti})}{\sum_j \exp(e_{tj})}, \quad c_t = \sum_i \alpha_{ti} h_i\]
  • 然后解码器同时使用 \(s_{t-1}\)\(c_t\) 来产生下一个输出。关键的洞见是:不再为整个句子使用一个固定的上下文向量,每个解码器步骤都会获得编码器状态的不同加权组合,这使得模型能够“回顾”输入的相关部分。

  • Luong 注意力(乘性注意力,Luong 等人,2015)简化了分数计算。点积变体使用 \(e_{ti} = s_t^T h_i\)通用变体使用 \(e_{ti} = s_t^T W h_i\)。它们比 Bahdanau 的加性分数更快,因为它们使用矩阵乘法而不是前馈网络。Luong 注意力还从当前解码器状态 \(s_t\)(而不是 \(s_{t-1}\))计算上下文向量,这使其能够访问更多信息,但计算略有不同。

源句子与其翻译之间的注意力对齐热力图,显示每个目标词关注哪些源词,较亮的单元格表示更高的注意力权重

  • 注意力权重通常被可视化为热力图,显示解码器在生成每个输出标记时关注哪些输入标记。在翻译中,这些热力图大致勾勒出源语言和目标语言之间的词对齐,对角线模式会被词序调整(例如法语和英语中形容词-名词顺序不同)打破。

  • 在推理时,解码器必须在每一步选择一个标记。贪婪解码在每个位置选择概率最高的标记,但这可能导致次优序列:一个局部的好选择可能迫使模型进入一个全局糟糕的句子。集束搜索在每一步维护前 \(k\) 个(集束宽度)部分序列,通过所有可能的下一个标记扩展每个部分序列,并保留全局最好的 \(k\) 个。

  • 当集束宽度 \(k = 1\) 时,集束搜索退化为贪婪解码。典型值为 \(k = 4\)\(k = 10\)。更大的集束能找到更好的序列,但速度成比例变慢。集束搜索还需要长度归一化,以避免偏向较短的序列,因为较短的序列自然具有更高的总概率(它们相乘的项更少)。归一化分数为:

\[\text{score}(y) = \frac{1}{|y|^\alpha} \sum_{t=1}^{|y|} \log P(y_t \mid y_{<t})\]
  • 其中 \(|y|\) 是序列长度,\(\alpha\)(通常为 0.6-0.7)控制长度惩罚的强度。当 \(\alpha = 0\) 时,没有长度归一化。当 \(\alpha = 1\) 时,分数是每个标记的对数概率(几何平均值)。中间值平衡了偏好简洁输出和不过早截断之间的权衡。

  • 当 RNN 顺序处理文本时,一维 CNN 通过在标记序列上滑动滤波器来并行处理它们。每个滤波器检测一个局部模式(一个 n-gram 特征)。

  • TextCNN(Kim,2014)对输入嵌入矩阵应用多个不同宽度(例如 3、4、5 个标记)的一维卷积滤波器。每个滤波器产生一个特征图,随时间步最大池化从每个特征图中取单个最大值,捕获该模式是否在文本中的任何位置被检测到,而与位置无关。来自所有滤波器的池化特征被拼接起来并传递给一个分类器。

TextCNN 架构:输入嵌入通过宽度为 3、4 和 5 的并行卷积滤波器,每个滤波器后接随时间步最大池化,然后拼接并输入全连接分类器

  • TextCNN 速度快,对于情感分析等文本分类任务出奇地有效。它能捕获局部 n-gram 模式,但无法建模长距离依赖:宽度为 5 的滤波器只能看到 5 个连续的标记。空洞因果卷积通过在滤波器元素之间插入间隙(空洞)来解决这个问题。堆叠膨胀率指数增长(1, 2, 4, 8, …)的层,感受野呈指数增长而不增加参数,使模型能够捕获跨越数百个标记的依赖关系。

  • 到目前为止讨论的所有嵌入(Word2Vec、GloVe、FastText)都为每个词类型生成一个单一的向量,与上下文无关。“Bank”无论是表示金融机构还是河岸,都得到相同的嵌入。这是一个根本性的限制,而上下文嵌入解决了这个问题。

  • ELMo(来自语言模型的嵌入,Peters 等人,2018)通过在输入文本上运行一个深层的双向 LSTM 语言模型来产生上下文词表示。前向 LSTM 在每个位置预测下一个词;一个单独的后向 LSTM 预测前一个词。两者都在大规模语料上作为语言模型进行训练。

  • 在每个位置 \(k\),ELMo 使用任务特定的学习权重组合所有 \(L\) 层的隐藏状态:

\[\text{ELMo}_k = \gamma \sum_{j=0}^{L} s_j \, h_{k,j}\]
  • 这里 \(h_{k,j}\) 是位置 \(k\)\(j\) 层的隐藏状态(第 0 层是原始词嵌入),\(s_j\) 是经过 softmax 归一化的标量权重,\(\gamma\) 是任务特定的缩放因子。不同的层捕获不同的信息:低层捕获句法(词性标注、词形态),高层捕获语义(词义、语义角色)。通过使用学习到的权重混合所有层,ELMo 嵌入能够适应各种下游任务。

  • ELMo 标志着预训练然后微调范式的开端:在大量无标签文本上训练一个大型语言模型,然后将其表示用于下游任务。ELMo 具体做法是使用预训练表示作为固定或轻度调整的特征,与任务特定的输入拼接。BERT 和 GPT(文件 04)通过端到端微调整个模型更进一步,这被证明要有效得多。

  • 从 Word2Vec 到 ELMo 的演进说明了 NLP 中的一个反复出现的主题:从静态表示到动态表示,从局部上下文到全局上下文,从浅层模型到深层模型。每一步都以计算成本换取更丰富的表示。Transformer(文件 04)通过用注意力完全替代循环完成了这一演进,同时实现了深度上下文化和并行计算。

编码任务(使用 CoLab 或 notebook)

  1. 从头实现带负采样的 Word2Vec skip-gram。在一个小型语料上训练,并使用 PCA 可视化学习到的嵌入。

    import jax
    import jax.numpy as jnp
    import matplotlib.pyplot as plt
    
    # 小型语料
    corpus = """the king ruled the kingdom . the queen ruled the kingdom .
    the prince is the son of the king . the princess is the daughter of the queen .
    a man worked in the castle . a woman worked in the castle .
    the king and queen lived in the castle . the prince and princess played outside .""".lower().split()
    
    vocab = sorted(set(corpus))
    word2idx = {w: i for i, w in enumerate(vocab)}
    idx2word = {i: w for w, i in word2idx.items()}
    V = len(vocab)
    
    # 生成 skip-gram 对,窗口大小为 2
    window = 2
    pairs = []
    for i, word in enumerate(corpus):
        for j in range(max(0, i - window), min(len(corpus), i + window + 1)):
            if i != j:
                pairs.append((word2idx[word], word2idx[corpus[j]]))
    
    pairs = jnp.array(pairs)
    print(f"词汇表大小: {V} 词, 训练对数量: {len(pairs)}")
    
    # 模型参数
    embed_dim = 16
    key = jax.random.PRNGKey(42)
    k1, k2 = jax.random.split(key)
    W_in = jax.random.normal(k1, (V, embed_dim)) * 0.1    # 输入嵌入
    W_out = jax.random.normal(k2, (V, embed_dim)) * 0.1   # 输出嵌入
    
    # 单对负采样损失
    def neg_sampling_loss(W_in, W_out, target, context, neg_ids):
        v_in = W_in[target]      # (embed_dim,)
        v_out = W_out[context]   # (embed_dim,)
        v_neg = W_out[neg_ids]   # (k, embed_dim)
    
        pos_loss = -jax.nn.log_sigmoid(jnp.dot(v_in, v_out))
        neg_loss = -jnp.sum(jax.nn.log_sigmoid(-v_neg @ v_in))
        return pos_loss + neg_loss
    
    # 训练循环
    num_neg = 5
    lr = 0.05
    
    @jax.jit
    def train_step(W_in, W_out, target, context, neg_ids):
        loss, (g_in, g_out) = jax.value_and_grad(neg_sampling_loss, argnums=(0, 1))(
            W_in, W_out, target, context, neg_ids)
        return loss, W_in - lr * g_in, W_out - lr * g_out
    
    key = jax.random.PRNGKey(0)
    for epoch in range(50):
        total_loss = 0.0
        for i in range(len(pairs)):
            key, subkey = jax.random.split(key)
            neg_ids = jax.random.randint(subkey, (num_neg,), 0, V)
            loss, W_in, W_out = train_step(W_in, W_out, pairs[i, 0], pairs[i, 1], neg_ids)
            total_loss += loss
        if (epoch + 1) % 10 == 0:
            print(f"轮次 {epoch+1}: 平均损失 = {total_loss / len(pairs):.4f}")
    
    # 使用 PCA 可视化(第 01 章)
    embeddings = W_in
    mean = embeddings.mean(axis=0)
    centered = embeddings - mean
    U, S, Vt = jnp.linalg.svd(centered, full_matrices=False)
    coords = centered @ Vt[:2].T  # 投影到前两个主成分
    
    plt.figure(figsize=(10, 8))
    for i, word in idx2word.items():
        plt.scatter(coords[i, 0], coords[i, 1], c='#3498db', s=40)
        plt.annotate(word, (coords[i, 0] + 0.02, coords[i, 1] + 0.02), fontsize=9)
    plt.title("Word2Vec Skip-gram 嵌入 (PCA 投影)")
    plt.grid(alpha=0.3); plt.show()
    

  2. 构建一个字符级 RNN 语言模型,使其从一个小型训练字符串中学习生成文本。

    import jax
    import jax.numpy as jnp
    
    # 小型训练文本
    text = "to be or not to be that is the question "
    chars = sorted(set(text))
    char2idx = {c: i for i, c in enumerate(chars)}
    idx2char = {i: c for c, i in char2idx.items()}
    V = len(chars)
    data = jnp.array([char2idx[c] for c in text])
    
    # RNN 参数
    hidden_dim = 64
    key = jax.random.PRNGKey(0)
    k1, k2, k3, k4, k5 = jax.random.split(key, 5)
    
    params = {
        'Wx': jax.random.normal(k1, (V, hidden_dim)) * 0.1,
        'Wh': jax.random.normal(k2, (hidden_dim, hidden_dim)) * 0.05,
        'bh': jnp.zeros(hidden_dim),
        'Wy': jax.random.normal(k3, (hidden_dim, V)) * 0.1,
        'by': jnp.zeros(V),
    }
    
    def rnn_step(params, h, x_idx):
        x = jnp.eye(V)[x_idx]  # one-hot
        h = jnp.tanh(x @ params['Wx'] + h @ params['Wh'] + params['bh'])
        logits = h @ params['Wy'] + params['by']
        return h, logits
    
    def loss_fn(params, inputs, targets):
        h = jnp.zeros(hidden_dim)
        total_loss = 0.0
        for t in range(len(inputs)):
            h, logits = rnn_step(params, h, inputs[t])
            log_probs = jax.nn.log_softmax(logits)
            total_loss -= log_probs[targets[t]]
        return total_loss / len(inputs)
    
    grad_fn = jax.jit(jax.grad(loss_fn))
    
    # 训练
    inputs = data[:-1]
    targets = data[1:]
    lr = 0.01
    
    for step in range(500):
        grads = grad_fn(params, inputs, targets)
        params = {k: params[k] - lr * grads[k] for k in params}
        if (step + 1) % 100 == 0:
            l = loss_fn(params, inputs, targets)
            print(f"步数 {step+1}: 损失 = {l:.4f}")
    
    # 生成文本
    def generate(params, seed_char, length=60):
        h = jnp.zeros(hidden_dim)
        idx = char2idx[seed_char]
        result = [seed_char]
        key = jax.random.PRNGKey(42)
        for _ in range(length):
            h, logits = rnn_step(params, h, idx)
            key, subkey = jax.random.split(key)
            idx = jax.random.categorical(subkey, logits)
            result.append(idx2char[int(idx)])
        return ''.join(result)
    
    print(f"\n生成的文本: {generate(params, 't')}")
    

  3. 实现一个带有 Bahdanau 注意力的简易 seq2seq 模型,用于序列反转。可视化注意力对齐矩阵。

    import jax
    import jax.numpy as jnp
    import matplotlib.pyplot as plt
    
    # 任务:反转数字序列(例如 [3, 1, 4] -> [4, 1, 3])
    vocab_size = 10  # 数字 0-9
    SOS, EOS = 10, 11  # 特殊标记
    total_vocab = 12
    embed_dim, hidden_dim = 16, 32
    max_len = 5
    
    key = jax.random.PRNGKey(42)
    keys = jax.random.split(key, 8)
    
    params = {
        'embed': jax.random.normal(keys[0], (total_vocab, embed_dim)) * 0.1,
        'enc_Wx': jax.random.normal(keys[1], (embed_dim, hidden_dim)) * 0.1,
        'enc_Wh': jax.random.normal(keys[2], (hidden_dim, hidden_dim)) * 0.05,
        'dec_Wx': jax.random.normal(keys[3], (embed_dim, hidden_dim)) * 0.1,
        'dec_Wh': jax.random.normal(keys[4], (hidden_dim, hidden_dim)) * 0.05,
        # Bahdanau 注意力
        'Ws': jax.random.normal(keys[5], (hidden_dim, hidden_dim)) * 0.1,
        'Wh_att': jax.random.normal(keys[6], (hidden_dim, hidden_dim)) * 0.1,
        'v_att': jax.random.normal(keys[7], (hidden_dim,)) * 0.1,
        # 输出投影(从隐藏状态 + 上下文到词汇表)
        'Wo': jax.random.normal(keys[0], (hidden_dim * 2, total_vocab)) * 0.1,
    }
    
    def encode(params, seq):
        """编码输入序列,返回所有隐藏状态。"""
        h = jnp.zeros(hidden_dim)
        states = []
        for t in range(len(seq)):
            x = params['embed'][seq[t]]
            h = jnp.tanh(x @ params['enc_Wx'] + h @ params['enc_Wh'])
            states.append(h)
        return jnp.stack(states), h
    
    def bahdanau_attention(params, dec_state, enc_states):
        """计算 Bahdanau 注意力权重和上下文向量。"""
        scores = jnp.tanh(enc_states @ params['Wh_att'] + dec_state @ params['Ws'])
        e = scores @ params['v_att']  # (src_len,)
        alpha = jax.nn.softmax(e)
        context = alpha @ enc_states
        return context, alpha
    
    def decode_step(params, dec_h, prev_token, enc_states):
        x = params['embed'][prev_token]
        dec_h = jnp.tanh(x @ params['dec_Wx'] + dec_h @ params['dec_Wh'])
        context, alpha = bahdanau_attention(params, dec_h, enc_states)
        combined = jnp.concatenate([dec_h, context])
        logits = combined @ params['Wo']
        return dec_h, logits, alpha
    
    def seq2seq_loss(params, src, tgt):
        enc_states, enc_final = encode(params, src)
        dec_h = enc_final
        loss = 0.0
        prev_token = SOS
        for t in range(len(tgt)):
            dec_h, logits, _ = decode_step(params, dec_h, prev_token, enc_states)
            log_probs = jax.nn.log_softmax(logits)
            loss -= log_probs[tgt[t]]
            prev_token = tgt[t]
        return loss / len(tgt)
    
    # 生成训练数据:反转序列
    key = jax.random.PRNGKey(0)
    train_srcs, train_tgts = [], []
    for _ in range(200):
        key, subkey = jax.random.split(key)
        length = jax.random.randint(subkey, (), 3, max_len + 1)
        key, subkey = jax.random.split(key)
        seq = jax.random.randint(subkey, (int(length),), 0, vocab_size)
        train_srcs.append(seq)
        train_tgts.append(seq[::-1])  # 反转
    
    # 训练
    grad_fn = jax.grad(seq2seq_loss)
    lr = 0.01
    
    for epoch in range(100):
        total_loss = 0.0
        for src, tgt in zip(train_srcs, train_tgts):
            grads = grad_fn(params, src, tgt)
            params = {k: params[k] - lr * grads[k] for k in params}
            total_loss += seq2seq_loss(params, src, tgt)
        if (epoch + 1) % 20 == 0:
            print(f"轮次 {epoch+1}: 平均损失 = {total_loss / len(train_srcs):.4f}")
    
    # 可视化一个例子的注意力
    test_src = jnp.array([3, 1, 4, 1, 5])
    test_tgt = test_src[::-1]
    
    enc_states, enc_final = encode(params, test_src)
    dec_h = enc_final
    attentions = []
    prev_token = SOS
    for t in range(len(test_tgt)):
        dec_h, logits, alpha = decode_step(params, dec_h, prev_token, enc_states)
        attentions.append(alpha)
        prev_token = test_tgt[t]
    
    att_matrix = jnp.stack(attentions)
    fig, ax = plt.subplots(figsize=(6, 5))
    im = ax.imshow(att_matrix, cmap='Blues')
    ax.set_xlabel("源位置"); ax.set_ylabel("目标位置")
    src_labels = [str(int(x)) for x in test_src]
    tgt_labels = [str(int(x)) for x in test_tgt]
    ax.set_xticks(range(len(src_labels))); ax.set_xticklabels(src_labels)
    ax.set_yticks(range(len(tgt_labels))); ax.set_yticklabels(tgt_labels)
    for i in range(len(tgt_labels)):
        for j in range(len(src_labels)):
            ax.text(j, i, f"{att_matrix[i,j]:.2f}", ha='center', va='center', fontsize=9)
    ax.set_title("Bahdanau 注意力对齐(序列反转)")
    plt.colorbar(im); plt.tight_layout(); plt.show()