Skip to content

自动语音识别

自动语音识别将口语音频转换为书面文本,架起人类语音与机器可读语言之间的桥梁。本文涵盖GMM-HMM、CTC损失、RNN-Transducer、基于注意力机制的编码器-解码器模型(LAS)、Whisper以及端到端ASR,从经典流程到现代神经架构。

  • 自动语音识别(Automatic Speech Recognition, ASR)是将口语音频转换为书面文本的任务。它是人工智能领域最古老的问题之一(1950年代的首批系统仅能识别单个数字),也是商业部署最广泛的应用之一(语音助手、转录服务、字幕生成)。

  • 其难度源于语音的巨大变异性:不同说话人、口音、语速、背景噪声、麦克风特性,以及将连续声学信号映射到离散词汇的根本性歧义。

  • 可以将ASR想象成法庭速记员。速记员听到连续的声音流,在脑中将其分割为单词,利用上下文消解歧义(如"they're"、"their"、"there"),然后输入结果。ASR系统做同样的事情,但通过可显式表达并可独立或联合优化的阶段完成。

  • 经典ASR流程通过一系列独立阶段处理音频:原始音频被转换为特征(文件01中的MFCCs或对数梅尔谱图),声学模型评估每个特征帧与各音素单元的匹配程度,发音模型(词典)将音素单元映射到单词,语言模型评估词序列的可能性,解码器搜索使综合得分最高的词序列。每个组件分别训练和调优。

从原始音频经特征提取、声学模型、解码器和语言模型到输出文本的ASR流程

  • 音素(Phonemes)是语言中区分单词的最小声音单位。英语约有39-44个音素(确切数量取决于方言和所用音素清单)。例如,"bat"和"pat"仅在一个音素上不同(/b/ vs /p/)。大多数ASR系统建模上下文相关音素,称为三音子(triphones):由左右邻居定义的音素(例如"b_t"上下文中的"a"与"c_t"上下文中的"a"是不同单元),因为音素的声学实现受其邻居强烈影响(这称为协同发音,coarticulation)。

  • 可能的三音子数量巨大(40个音素的立方=64,000),因此决策树聚类将声学相似的三音子分组为音素状态(senones,通常2000-10,000类)。每个senone拥有独立的声学模型。这种聚类是文件06中决策树算法的一种形式。

  • GMM-HMM(高斯混合模型-隐马尔可夫模型)是1980年代至2010年代初主导的声学建模方法。HMM(文件05)建模语音的时间结构:每个音素是一个从左到右的HMM,含3-5个状态,每个状态代表一个子音素片段(起始、中间、结束)。状态间转移隐式建模时长。

  • 在每个HMM状态,发射概率(给定该状态下特定特征向量的可能性)由高斯混合模型(GMM,文件05中的多元高斯分布加权和)建模:

\[ p(\mathbf{x} | s) = \sum_{m=1}^{M} w_m \cdot \mathcal{N}(\mathbf{x} ; \boldsymbol{\mu}_m, \boldsymbol{\Sigma}_m) \]
  • 其中 \(\mathbf{x}\) 是特征向量(如39维MFCCs),\(s\) 是HMM状态,\(M\) 是混合分量数量(通常8-64),\(w_m\) 是混合权重,\(\boldsymbol{\mu}_m\)\(\boldsymbol{\Sigma}_m\) 是每个高斯分量的均值和协方差。为计算效率,协方差矩阵通常设为对角阵(假设特征维度独立,这对经DCT去相关的MFCCs近似成立)。

  • 训练使用Baum-Welch算法(文件05中EM算法的特例)从带标注语音数据中迭代估计GMM参数和HMM转移概率。解码(寻找最可能的状态序列)使用Viterbi算法(文件05中的动态规划):

\[ \delta_t(j) = \max_{i} \left[ \delta_{t-1}(i) \cdot a_{ij} \right] \cdot b_j(\mathbf{x}_t) \]
  • 其中 \(\delta_t(j)\) 是在时间 \(t\) 以状态 \(j\) 结束的最优路径概率,\(a_{ij}\) 是从状态 \(i\)\(j\) 的转移概率,\(b_j(\mathbf{x}_t)\) 是状态 \(j\) 下特征 \(\mathbf{x}_t\) 的发射概率。

  • DNN-HMM(Hinton et al., 2012)用深度神经网络(文件06中的DNN)替换GMM发射模型,从特征帧窗口预测senone后验概率 \(p(s | \mathbf{x})\)。HMM仍处理时间结构和序列,但神经网络提供更具判别力的发射分数。这种混合方法相比GMM将词错误率相对降低20-30%,是2012-2016年的主导范式。

  • WFST解码(加权有限状态变换器)是传统ASR的标准解码框架。每个组件(HMM拓扑H、上下文依赖C、词典L、语法/语言模型G)表示为加权有限状态变换器,并组合成单一搜索图 \(H \circ C \circ L \circ G\)。Viterbi搜索随后在该组合图中寻找最低代价路径。WFST允许知识源的模块化组合和高效动态规划搜索。其数学框架来自有限自动机理论(与文件05中的状态机相关)。

  • 端到端ASR消除了独立组件(发音模型、音素清单、WFST解码器),训练单个神经网络直接从音频特征映射到字符或词片段。核心挑战是对齐问题:输入(每秒数百个特征帧)和输出(每秒几个字符)长度差异巨大,且训练时对齐关系未知。

  • 连接时序分类(CTC)(Graves et al., 2006)通过引入特殊空白(blank)标记解决对齐问题,允许网络输出任意字符和空白序列,只要折叠连续重复项并移除空白后能得到正确转写。例如,转写"cat"可由输出序列"--cc-aa-t--"产生(其中"-"表示空白)。

  • 形式上,CTC定义了一个多对一映射 \(\mathcal{B}\),从所有长度-\(T\) 的输出序列(字母表加空白)到标签序列。标签序列 \(\mathbf{y}\) 的概率是所有折叠后得到它的对齐路径概率之和:

\[P(\mathbf{y} | \mathbf{x}) = \sum_{\boldsymbol{\pi} \in \mathcal{B}^{-1}(\mathbf{y})} \prod_{t=1}^{T} p(\pi_t | \mathbf{x})\]

CTC对齐示意图:多条经过空白和字符标记的路径均折叠为相同输出文本

  • 朴素计算该求和需枚举指数级数量的对齐,但CTC前向-后向算法使用动态规划在 \(O(T \cdot |\mathbf{y}|)\) 时间内高效计算,类比于文件05中的HMM前向-后向算法。

  • CTC做出条件独立假设:给定输入,每个时间步的输出独立于其他所有输出。这意味着CTC无法建模输出依赖(例如无法学习"q"几乎总是后跟"u")。必须使用外部语言模型处理此类依赖。

  • CTC解码选项:

    • 贪婪解码:每步取最可能标记,然后折叠。速度快但次优。
    • 束搜索:每步维护前 \(k\) 个部分假设,合并折叠后前缀相同的假设。可融入语言模型分数。
    • 前缀束搜索:修改的束搜索,正确处理CTC空白合并,确保假设在折叠后比较。
  • RNN-Transducer(RNN-T)(Graves, 2012)通过添加显式预测网络(类似语言模型的RNN)扩展CTC,使每个输出条件依赖于先前输出,移除条件独立假设。RNN-T包含三个组件:

    • 编码器:处理音频特征生成隐藏表示 \(\mathbf{h}_t^\text{enc}\)(通常为LSTM或Conformer层堆叠)。
    • 预测网络:自回归RNN,从先前输出的标签生成隐藏表示 \(\mathbf{h}_u^\text{pred}\)
    • 联合网络:在每个(时间,标签)位置组合编码器和预测网络输出,生成下一个标记(含空白)的分布:
\[p(y | t, u) = \text{softmax}(W \cdot \text{tanh}(W_\text{enc} \mathbf{h}_t^\text{enc} + W_\text{pred} \mathbf{h}_u^\text{pred} + b))\]
  • RNN-T每时间步可输出零个或多个标签(通过在进入下一时间步前输出非空白标记,或输出空白以推进时间但不输出)。训练使用2D(时间,标签)格上的前向-后向算法,复杂度 \(O(T \cdot U)\)\(U\) 为输出长度)。由于天然支持流式处理(编码器从左到右处理音频,预测网络增量生成输出),RNN-T是设备端流式ASR的主导架构(用于Google Pixel手机等类似产品)。

  • Listen, Attend and Spell(LAS)(Chan et al., 2016)是基于注意力机制的编码器-解码器模型(文件06中的序列到序列架构)。包含三个组件:

    • Listener(编码器):金字塔双向LSTM,处理完整输入序列并以8倍因子下采样(每层拼接连续隐藏状态对),生成较短的编码器隐藏状态序列。
    • Attention:每解码步计算所有编码器状态的注意力权重以形成上下文向量(文件07中的相同注意力机制)。
    • Speller(解码器):自回归LSTM,每次生成一个字符,条件于上下文向量和先前生成的字符。
  • LAS取得强劲结果,但需要完整话语可用后才能解码(因为注意力需关注所有编码器状态),不适合流式应用。对于超长话语,注意力在长序列上变得分散,效果下降。

  • Conformer(Gulati et al., 2020)结合卷积的局部模式捕获能力与自注意力的全局依赖建模能力。每个Conformer模块采用三明治结构,含四个子模块:

    1. 前馈模块(半步):带残差连接的前馈网络,使用一半残差权重。
    2. 多头自注意力模块:文件07中的标准Transformer自注意力,含相对位置编码。
    3. 卷积模块:逐点卷积、门控线性单元(GLU)、1D深度卷积、批归一化、Swish激活、另一个逐点卷积。深度卷积捕获局部上下文(类似特征序列上的n-gram)。
    4. 前馈模块(半步):与模块1相同。
  • 输出为:\(\mathbf{y} = \text{LayerNorm}(\mathbf{x} + \frac{1}{2}\text{FFN}_1 + \text{MHSA} + \text{Conv} + \frac{1}{2}\text{FFN}_2)\)。这种类似马卡龙的结构(FFN-Attention-Conv-FFN)配合半步残差,经实证发现优于其他排序方式。Conformer已成为CTC和RNN-T系统的默认编码器,优于纯Transformer和纯LSTM编码器。

Conformer模块示意图:前馈、自注意力、卷积、前馈模块的三明治结构

  • Whisper(Radford et al., 2023)是OpenAI的大规模基于注意力的ASR模型。使用标准编码器-解码器Transformer架构(文件07),在从网络抓取的68万小时弱监督数据(音频配对近似转写)上训练。关键设计选择:

    • 输入:80通道对数梅尔谱图(文件01),25ms窗长、10ms跳长,归一化为零均值单位方差。
    • 编码器:标准Transformer编码器,含正弦位置嵌入和预激活层归一化。
    • 解码器:Transformer解码器,使用字节级BPE分词器(文件07)自回归生成标记。
    • 多任务:单个模型处理转写、翻译、语言识别和时间戳预测,通过解码器提示中的特殊任务标记条件控制。
    • 训练数据规模(而非架构创新)是Whisper在领域、口音和语言间强大泛化能力的主要驱动因素。
  • wav2vec 2.0(Baevski et al., 2020)是语音表征的自监督预训练框架。核心思想是从大量无标签音频学习语音表征,然后用少量标注数据微调。这遵循与BERT(文件07)相同的自监督范式,但适配连续音频信号。

  • wav2vec 2.0架构包含三部分:

    • 特征编码器:多层1D CNN,处理原始波形样本,以20ms帧率(16kHz下每320个样本一个向量)生成潜在表示 \(\mathbf{z}_t\)
    • 量化模块:使用乘积量化将潜在表示离散化为有限码本(将向量分组并独立量化每组,从 \(G\) 个码本中各选 \(V\) 个条目)。这为对比学习目标生成目标 \(\mathbf{q}_t\)
    • 上下文网络:Transformer编码器,接收(部分掩码的)潜在表示并生成上下文化表示 \(\mathbf{c}_t\)

wav2vec 2.0架构:CNN特征编码器、掩码、Transformer上下文网络、与量化目标的对比学习

  • 预训练期间,随机跨度的潜在表示被掩码(替换为学习到的掩码嵌入),模型必须从一组干扰项(从同一话语其他位置采样的负样本)中识别掩码位置的真实量化表示。对比损失为:
\[\mathcal{L} = -\log \frac{\exp(\text{sim}(\mathbf{c}_t, \mathbf{q}_t) / \kappa)}{\sum_{\tilde{\mathbf{q}} \in Q_t} \exp(\text{sim}(\mathbf{c}_t, \tilde{\mathbf{q}}) / \kappa)}\]
  • 其中 \(\text{sim}\) 为余弦相似度,\(\kappa\) 为温度参数,\(Q_t\) 包含真实量化目标和干扰项。额外的多样性损失鼓励均衡使用所有码本条目。该损失本质上是InfoNCE对比损失,与视觉自监督学习中使用的对比目标同属一族。

  • 预训练后,在顶部添加线性投影和CTC头,并在标注数据上微调。仅用10分钟标注数据(使用53,000小时无标签音频预训练),wav2vec 2.0即达到接近最先进的结果,展示了自监督学习在低资源语音识别中的强大能力。

  • HuBERT(Hsu et al., 2021)是另一种自监督方法,用掩码预测目标(预测掩码帧的离散聚类分配)替换对比目标。目标由离线聚类步骤生成(第一轮在MFCCs上执行k-means,后续轮次在HuBERT特征上执行k-means)。相比wav2vec 2.0,HuBERT简化了训练流程(无需量化模块或对比采样),并达到相当或更好的结果。

  • Fast Conformer(Rekesh et al., 2023, NVIDIA NeMo)用下采样注意力机制替换标准Conformer中的二次自注意力:计算注意力前压缩输入序列(通常通过步长卷积压缩8倍),然后扩展回原尺寸。这将注意力代价从 \(O(T^2)\) 降低到 \(O(T^2/64)\),同时保留全局上下文,使训练超长话语(长达数分钟)时无内存问题。Fast Conformer是NVIDIA NeMo工具包的默认编码器,构成其生产级模型的主干。

  • Parakeet(NVIDIA, 2024)是基于Fast Conformer编码器、搭配CTC和RNN-T解码器的高精度英语ASR模型家族,在64,000小时英语语音上训练。发布时,Parakeet模型(0.6B和1.1B参数)在标准基准上达到最低词错误率,在大多数英语测试集上超越Whisper large-v3。关键要素包括高效的Fast Conformer架构、激进的数据增强(SpecAugment、速度扰动、噪声混合)和大规模监督训练数据——表明对已知组件的精心工程仍能推动技术发展。

  • Canary(NVIDIA, 2024)将NeMo框架扩展到多语言和多任务ASR。使用Fast Conformer编码器搭配基于注意力的解码器(而非CTC或RNN-T),在单个模型中处理多语言的转写和翻译(类似Whisper的多任务设计,但采用更高效的Fast Conformer主干)。Canary模型支持英语、德语、西班牙语和法语,具有竞争力精度。

  • Moonshine(Useful Sensors, 2024)是专为设备端和边缘部署优化的ASR模型家族。编码器采用混合架构,用小型CNN后接少量Transformer层替换初始Transformer/Conformer层,大幅降低模型尺寸(基础模型不足3000万参数)。Moonshine面向CPU和低功耗设备上的实时流式处理,在Whisper过大过慢的场景下,以部分精度换取5-10倍的更低延迟和内存占用。

  • Distil-Whisper(Gandhi et al., 2023)应用知识蒸馏(文件06)将Whisper压缩为更小、更快的模型。学生模型仅用2层解码器(相比Whisper的32层),保留完整编码器,训练以匹配Whisper的输出分布。Distil-Whisper在词错误率上与教师模型差距在1%以内,同时速度快6倍,使完整Whisper模型过慢的实时应用变得可行。

  • 通用语音模型(USM)(Zhang et al., 2023, Google)将自监督预训练扩展到300+语言的1200万小时无标签音频,随后进行监督微调。USM证明wav2vec 2.0/自监督范式可扩展到真正大规模数据场景,在标注数据极少的低资源语言上实现强劲性能。

  • 大规模多语言语音(MMS)(Pratap et al., 2023, Meta)将wav2vec 2.0预训练扩展到1100+语言,使用宗教录音和其他多语言音频源。MMS覆盖的语言远超以往任何ASR系统,首次使许多低资源语言能够进行语音识别。

  • 现代ASR格局正收敛于几种主导模式:(1) 用于流式的Conformer家族编码器搭配CTC或RNN-T,(2) 用于离线/多任务的编码器-解码器Transformer,(3) 用于低资源场景的自监督预训练,(4) 规模——更多数据和更大模型持续提升精度。选择取决于部署约束:延迟预算、可用算力、语言数量、以及应用是流式还是批处理。

  • 语言模型集成通过融入声学模型未捕获的语言学知识提升ASR。基本思想是在解码时将声学模型分数 \(p(\mathbf{x} | \mathbf{y})\)(音频与转写的匹配程度)与语言模型分数 \(p(\mathbf{y})\)(转写作为句子的可能性)结合。

  • 浅层融合在束搜索时结合分数:

\[\hat{\mathbf{y}} = \arg\max_\mathbf{y} \left[ \log p_\text{AM}(\mathbf{y} | \mathbf{x}) + \lambda \log p_\text{LM}(\mathbf{y}) \right]\]
  • 其中 \(\lambda\) 为可调权重,\(p_\text{LM}\) 为外部语言模型(通常为文件07中的n-gram或神经语言模型)。这种方法简单有效,但要求语言模型与ASR模型使用相同的标记词汇表。

  • 深层融合(Gulcehre et al., 2015)将语言模型集成到解码器网络内部:LM隐藏状态与解码器隐藏状态拼接,经门控机制处理后进行输出投影。整个系统(包括预训练LM)联合微调。这允许更深层次集成,但训练更复杂。

  • 冷融合(Sriram et al., 2018)类似深层融合,但从头训练集成语言模型的ASR解码器,而非微调预训练解码器。这迫使声学模型学习互补信息,而非重复LM已掌握的知识。

  • 重排序(N-best重排序)是两阶段方法:先用束搜索生成 \(N\) 个候选转写,再用更强大的语言模型(如大型Transformer LM)重新排序。实现简单,且允许使用对首遍解码过慢的超大型语言模型。

  • 内部语言模型估计(ILME)解决一个微妙问题:端到端模型从训练转写中隐式学习内部语言模型,可能在浅层融合时与外部语言模型冲突(本质上是双重计数语言先验)。ILME估计内部语言模型并在融合时减去其分数:

\[\hat{\mathbf{y}} = \arg\max_\mathbf{y} \left[ \log p_\text{E2E}(\mathbf{y} | \mathbf{x}) - \beta \log p_\text{ILM}(\mathbf{y}) + \lambda \log p_\text{LM}(\mathbf{y}) \right]\]
  • 流式 vs. 离线ASR是根本性的架构选择。离线(或批处理)ASR在处理完整话语后才输出任何结果。流式ASR在音频到达时增量输出,具有有界延迟。

  • 流式对实时应用至关重要:实时字幕、语音助手(用户期望在说完前得到响应)、电话通话转录。挑战在于:部分未来上下文有助于识别(知道下一个词是"York"可消解"New"的歧义),但流式系统不能等待任意长的未来上下文。

  • 单向编码器(从左到右的LSTM、因果卷积、因果Transformer)天然支持流式,因为每个输出仅依赖过去和当前输入。双向编码器(关注未来上下文)无法直接支持流式。

  • 分块注意力(也称块状或分段注意力)将输入划分为固定长度块,仅在每块内(及可选的前几个块)应用自注意力。这将延迟限制为块大小加处理时间,同时允许每块内一定的局部双向上下文。权衡是:块尺寸越小,精度下降越明显。

  • 前瞻允许流式编码器在输出当前帧前窥视少量未来帧(如300-900ms)。通过为单向计算添加小的右上下文实现。前瞻窗口增加延迟,但显著提升精度。

  • 流式ASR中的延迟包含多个组件:

    • 算法延迟:音频到达至模型可处理的时间差(由块大小、前瞻和特征提取决定)。
    • 计算延迟:运行模型前向传播的时间。
    • 端点检测延迟:检测用户说完话的延迟。
    • 首词延迟:第一个词出现的速度。最终确认延迟:最终输出确认的速度(流式系统通常生成临时输出,随更多音频到达而修正)。
  • ASR的评估指标

  • 词错误率(Word Error Rate, WER)是主要指标。通过编辑距离(将一个序列转换为另一个所需的最少替换、删除、插入操作数)对齐假设(系统输出)与参考(真实转写),然后计算:

\[\text{WER} = \frac{S + D + I}{N}\]
  • 其中 \(S\) 为替换数,\(D\) 为删除数,\(I\) 为插入数,\(N\) 为参考中的总词数。若插入很多,WER可超过100%。对于清晰朗读语音,5%的WER大致相当于人类水平;对话或噪声语音则困难得多(10-20%+)。

  • 字符错误率(Character Error Rate, CER)是相同公式在字符级别而非词级别的应用。对于无清晰词边界的语言(中文、日语)以及评估近似错误("cat" vs "bat"的WER为100%但CER为33%),CER更具信息量。

  • 词信息丢失(WIL)和词信息保留(WIP)是信息论替代指标,比WER更精确地考虑参考与假设间的相关性,但报告较少。

  • 实时因子(Real-Time Factor, RTF)衡量计算效率:处理时间与音频时长的比值。RTF < 1表示系统运行快于实时;RTF > 1表示无法跟上实时音频。流式系统必须保持RTF < 1。

  • 数据增强对鲁棒ASR至关重要。常用技术:

    • 速度扰动:以0.9倍和1.1倍速度重采样音频(改变音高和时长)。
    • SpecAugment(Park et al., 2019):在谱图中掩码随机频带和时间步。这是音频领域的dropout类比,是ASR最有效的正则化技术之一。无需额外数据。
    • 噪声增强:以不同信噪比将干净语音与录制噪声混合。
    • 房间脉冲响应模拟:将干净语音与模拟房间声学卷积,模拟混响环境。
  • ASR的分词决定模型的输出词汇表。选项包括:

    • 字符:简单、词汇表小(英语约30个),但输出序列长且无隐式语言建模。
    • 词片段 / BPE(文件07):子词单元,平衡词汇表大小和序列长度。现代系统的标准(Whisper使用字节级BPE,约50,000标记)。
    • :词汇表大(50,000+)、输出序列短,但无法处理未登录词。
    • 音素:语言学动机、紧凑,但需要发音词典。
  • ASR的演进可总结为:从高度工程化的模块化系统(1990年代-2010年代的GMM-HMM + WFST解码),到混合系统(2012-2016年的DNN-HMM),再到将越来越多流程吸收到单一神经网络中的端到端系统(2016-2020年的CTC、RNN-T、LAS),再到利用海量无标签或弱标签数据的大规模预训练模型(2020年至今的wav2vec 2.0、Whisper)。每次转变都简化了工程同时提升了精度,遵循机器学习更广泛的趋势:从数据中学习表征而非手工设计(如同文件06中图像特征被CNN取代、文件07中NLP特征被Transformer取代的故事)。

编程任务(使用CoLab或Notebook)

  1. 用JAX从头实现CTC损失。创建一个短序列logits和目标标签的玩具示例,计算CTC前向算法得到总概率,并计算负对数似然损失。

    import jax
    import jax.numpy as jnp
    import matplotlib.pyplot as plt
    
    def ctc_forward(log_probs, targets):
        """
        CTC前向算法(对数域以保证数值稳定性)。
        log_probs: (T, V) 词汇表上的对数概率(索引0=空白)
        targets: (U,) 目标标签索引(不含空白)
        返回:目标序列在CTC下的对数概率。
        """
        T, V = log_probs.shape
        U = len(targets)
    
        # 构建带空白的扩展标签序列: [blank, y1, blank, y2, ..., yU, blank]
        S = 2 * U + 1
        labels = jnp.zeros(S, dtype=jnp.int32)  # 全空白
        for i in range(U):
            labels = labels.at[2 * i + 1].set(targets[i])
    
        # 初始化alpha(对数域)
        NEG_INF = -1e30
        alpha = jnp.full((T, S), NEG_INF)
        alpha = alpha.at[0, 0].set(log_probs[0, labels[0]])        # 从空白开始
        alpha = alpha.at[0, 1].set(log_probs[0, labels[1]])        # 或第一个标签
    
        # 前向填充
        for t in range(1, T):
            for s in range(S):
                # 同一状态
                a = alpha[t - 1, s]
                # 来自前一状态
                if s > 0:
                    a = jnp.logaddexp(a, alpha[t - 1, s - 1])
                # 跳过空白(如果当前和隔一个标签不同)
                if s > 1 and labels[s] != 0 and labels[s] != labels[s - 2]:
                    a = jnp.logaddexp(a, alpha[t - 1, s - 2])
                alpha = alpha.at[t, s].set(a + log_probs[t, labels[s]])
    
        # 总对数概率:最后时间步最后两个状态之和
        log_prob = jnp.logaddexp(alpha[T - 1, S - 1], alpha[T - 1, S - 2])
        return log_prob, alpha
    
    # --- 玩具示例 ---
    T = 12   # 输入长度(时间步)
    V = 5    # 词汇表大小(0=空白, 1='c', 2='a', 3='t', 4='x')
    targets = jnp.array([1, 2, 3])  # "c", "a", "t"
    
    # 创建随机logits并转换为对数概率
    key = jax.random.PRNGKey(42)
    logits = jax.random.normal(key, (T, V))
    log_probs = jax.nn.log_softmax(logits, axis=-1)
    
    log_prob, alpha = ctc_forward(log_probs, targets)
    ctc_loss = -log_prob
    
    print(f"目标序列: {targets.tolist()} ('c', 'a', 't')")
    print(f"输入长度 T={T}, 词汇表大小 V={V}")
    print(f"CTC对数概率: {log_prob:.4f}")
    print(f"CTC损失(负对数概率): {ctc_loss:.4f}")
    
    # 可视化前向变量(alpha格)
    fig, ax = plt.subplots(figsize=(12, 5))
    # 从对数转换为线性以便可视化
    alpha_linear = jnp.exp(alpha - jnp.max(alpha))  # 归一化以便显示
    im = ax.imshow(alpha_linear.T, aspect='auto', origin='lower', cmap='viridis')
    ax.set_xlabel('时间步 (t)')
    ax.set_ylabel('扩展标签索引 (s)')
    
    label_names = ['_', 'c', '_', 'a', '_', 't', '_']  # _ = 空白
    ax.set_yticks(range(len(label_names)))
    ax.set_yticklabels(label_names)
    ax.set_title(f'CTC前向变量(alpha格)| 损失 = {ctc_loss:.2f}')
    plt.colorbar(im, ax=ax, label='归一化概率')
    plt.tight_layout(); plt.show()
    

  2. 用JAX构建简单的基于编码器-解码器注意力的ASR模型(最小LAS类架构)。使用1D卷积编码器和单层解码器搭配点积注意力。在合成数据上运行并可视化注意力权重。

    import jax
    import jax.numpy as jnp
    import matplotlib.pyplot as plt
    
    # --- 最小注意力编码器-解码器ASR模型 ---
    
    def init_params(key, input_dim, hidden_dim, vocab_size):
        """初始化小型LAS类模型参数."""
        keys = jax.random.split(key, 8)
        scale = 0.1
        params = {
            # 编码器: 简单线性投影(模拟卷积输出)
            'enc_w': jax.random.normal(keys[0], (input_dim, hidden_dim)) * scale,
            'enc_b': jnp.zeros(hidden_dim),
            # 注意力: query, key, value投影
            'attn_q': jax.random.normal(keys[1], (hidden_dim, hidden_dim)) * scale,
            'attn_k': jax.random.normal(keys[2], (hidden_dim, hidden_dim)) * scale,
            'attn_v': jax.random.normal(keys[3], (hidden_dim, hidden_dim)) * scale,
            # 解码器RNN(简单Elman RNN示例)
            'dec_wh': jax.random.normal(keys[4], (hidden_dim, hidden_dim)) * scale,
            'dec_wx': jax.random.normal(keys[5], (vocab_size, hidden_dim)) * scale,
            'dec_wc': jax.random.normal(keys[6], (hidden_dim, hidden_dim)) * scale,
            'dec_b': jnp.zeros(hidden_dim),
            # 输出投影
            'out_w': jax.random.normal(keys[7], (hidden_dim, vocab_size)) * scale,
            'out_b': jnp.zeros(vocab_size),
        }
        return params
    
    def encode(params, x):
        """编码器: 线性投影(替代卷积/LSTM堆叠)."""
        return jnp.tanh(x @ params['enc_w'] + params['enc_b'])
    
    def attend(params, query, enc_out):
        """对编码器输出的点积注意力."""
        q = query @ params['attn_q']                   # (hidden,)
        k = enc_out @ params['attn_k']                 # (T_enc, hidden)
        v = enc_out @ params['attn_v']                 # (T_enc, hidden)
        d_k = q.shape[-1]
        scores = (k @ q) / jnp.sqrt(d_k)              # (T_enc,)
        weights = jax.nn.softmax(scores)               # (T_enc,)
        context = weights @ v                          # (hidden,)
        return context, weights
    
    def decode_step(params, h_prev, y_prev_onehot, enc_out):
        """单步解码: RNN + 注意力."""
        # 嵌入前一标记
        y_emb = y_prev_onehot @ params['dec_wx']       # (hidden,)
        # 注意力编码
        context, attn_w = attend(params, h_prev, enc_out)
        # RNN更新
        h = jnp.tanh(h_prev @ params['dec_wh'] + y_emb + context @ params['dec_wc']
                      + params['dec_b'])
        # 输出logits
        logits = h @ params['out_w'] + params['out_b']
        return h, logits, attn_w
    
    # --- 设置 ---
    key = jax.random.PRNGKey(0)
    input_dim = 40       # 如40个梅尔频带
    hidden_dim = 64
    vocab_size = 10      # 演示用小词汇表
    T_enc = 30           # 编码器时间步
    T_dec = 8            # 解码器步数
    
    params = init_params(key, input_dim, hidden_dim, vocab_size)
    
    # 合成输入: 随机类梅尔特征
    key, subkey = jax.random.split(key)
    x = jax.random.normal(subkey, (T_enc, input_dim))
    
    # 编码
    enc_out = encode(params, x)
    
    # 解码(教师强制,使用随机目标)
    key, subkey = jax.random.split(key)
    targets = jax.random.randint(subkey, (T_dec,), 0, vocab_size)
    
    h = jnp.zeros(hidden_dim)
    all_logits = []
    all_attn = []
    
    for t in range(T_dec):
        y_prev = jax.nn.one_hot(targets[t] if t > 0 else 0, vocab_size)
        h, logits, attn_w = decode_step(params, h, y_prev, enc_out)
        all_logits.append(logits)
        all_attn.append(attn_w)
    
    all_attn = jnp.stack(all_attn)  # (T_dec, T_enc)
    all_logits = jnp.stack(all_logits)  # (T_dec, vocab_size)
    
    # --- 可视化注意力权重 ---
    fig, axes = plt.subplots(1, 2, figsize=(14, 5))
    
    im = axes[0].imshow(all_attn, aspect='auto', cmap='Blues', origin='lower')
    axes[0].set_xlabel('编码器时间步')
    axes[0].set_ylabel('解码步')
    axes[0].set_title('注意力权重(解码器 -> 编码器)')
    plt.colorbar(im, ax=axes[0])
    
    # 显示每解码步的预测标记分布
    im2 = axes[1].imshow(jax.nn.softmax(all_logits, axis=-1), aspect='auto',
                          cmap='Oranges', origin='lower')
    axes[1].set_xlabel('词汇表索引')
    axes[1].set_ylabel('解码步')
    axes[1].set_title('输出标记概率')
    plt.colorbar(im2, ax=axes[1])
    
    plt.suptitle('最小注意力基ASR模型(未训练)')
    plt.tight_layout(); plt.show()
    

  3. 使用动态规划(编辑距离)从头计算词错误率(WER),并评估多个假设与参考的匹配。可视化编辑距离矩阵。

    import jax.numpy as jnp
    import matplotlib.pyplot as plt
    import numpy as np
    
    def compute_wer(reference, hypothesis):
        """
        使用动态规划计算WER(词级莱文斯坦距离)。
        返回WER、替换数、删除数、插入数和DP矩阵。
        """
        ref_words = reference.split()
        hyp_words = hypothesis.split()
        N = len(ref_words)
        M = len(hyp_words)
    
        # DP矩阵: d[i][j] = ref[:i] 与 hyp[:j] 的编辑距离
        d = np.zeros((N + 1, M + 1), dtype=np.int32)
        # 回溯矩阵用于计数S, D, I
        ops = np.zeros((N + 1, M + 1, 3), dtype=np.int32)  # [sub, del, ins]
    
        for i in range(N + 1):
            d[i][0] = i  # 全删除
        for j in range(M + 1):
            d[0][j] = j  # 全插入
    
        for i in range(1, N + 1):
            for j in range(1, M + 1):
                if ref_words[i - 1] == hyp_words[j - 1]:
                    sub_cost = d[i - 1][j - 1]  # 匹配,无编辑
                else:
                    sub_cost = d[i - 1][j - 1] + 1  # 替换
                del_cost = d[i - 1][j] + 1      # 删除
                ins_cost = d[i][j - 1] + 1      # 插入
    
                d[i][j] = min(sub_cost, del_cost, ins_cost)
    
        # 回溯计数操作
        i, j = N, M
        S, D, I = 0, 0, 0
        while i > 0 or j > 0:
            if i > 0 and j > 0 and d[i][j] == d[i-1][j-1] and ref_words[i-1] == hyp_words[j-1]:
                i -= 1; j -= 1  # 正确
            elif i > 0 and j > 0 and d[i][j] == d[i-1][j-1] + 1:
                S += 1; i -= 1; j -= 1  # 替换
            elif i > 0 and d[i][j] == d[i-1][j] + 1:
                D += 1; i -= 1  # 删除
            elif j > 0 and d[i][j] == d[i][j-1] + 1:
                I += 1; j -= 1  # 插入
            else:
                break
    
        wer = (S + D + I) / N if N > 0 else 0.0
        return wer, S, D, I, d
    
    # --- 测试用例 ---
    reference = "the cat sat on the mat"
    hypotheses = [
        "the cat sat on the mat",          # 完美
        "the cat sit on the mat",          # 1替换
        "the cat on the mat",              # 1删除
        "the big cat sat on the mat",      # 1插入
        "a dog sat in a rug",              # 多错误
    ]
    
    print(f"参考: '{reference}'\n")
    print(f"{'假设':<40s} {'WER':>6s} {'S':>3s} {'D':>3s} {'I':>3s}")
    print("-" * 60)
    results = []
    for hyp in hypotheses:
        wer, S, D, I, dp = compute_wer(reference, hyp)
        results.append((hyp, wer, S, D, I, dp))
        print(f"'{hyp}':<40s} {wer:>6.1%} {S:>3d} {D:>3d} {I:>3d}")
    
    # 可视化最坏情况的DP矩阵
    worst = results[-1]
    hyp_words = worst[0].split()
    ref_words = reference.split()
    dp_matrix = worst[5]
    
    fig, axes = plt.subplots(1, 2, figsize=(14, 5))
    
    # DP矩阵
    im = axes[0].imshow(dp_matrix, cmap='YlOrRd', origin='upper')
    axes[0].set_xticks(range(len(hyp_words) + 1))
    axes[0].set_xticklabels([''] + hyp_words, rotation=45, ha='right', fontsize=9)
    axes[0].set_yticks(range(len(ref_words) + 1))
    axes[0].set_yticklabels([''] + ref_words, fontsize=9)
    axes[0].set_xlabel('假设词')
    axes[0].set_ylabel('参考词')
    axes[0].set_title(f'编辑距离矩阵\nWER = {worst[1]:.1%}')
    for i in range(dp_matrix.shape[0]):
        for j in range(dp_matrix.shape[1]):
            axes[0].text(j, i, str(dp_matrix[i, j]), ha='center', va='center', fontsize=8)
    plt.colorbar(im, ax=axes[0])
    
    # WER对比条形图
    names = [f'Hyp {i+1}' for i in range(len(results))]
    wers = [r[1] * 100 for r in results]
    colors = ['#27ae60' if w == 0 else '#f39c12' if w < 30 else '#e74c3c' for w in wers]
    axes[1].barh(names, wers, color=colors)
    axes[1].set_xlabel('WER (%)')
    axes[1].set_title('词错误率对比')
    for i, (w, r) in enumerate(zip(wers, results)):
        axes[1].text(w + 1, i, f'{w:.0f}% (S={r[2]}, D={r[3]}, I={r[4]})',
                     va='center', fontsize=9)
    axes[1].set_xlim(0, max(wers) * 1.4)
    
    plt.tight_layout(); plt.show()
    

  4. 在对数梅尔谱图上实现SpecAugment(频域掩码和时间掩码),并可视化原始与增强版本。从合成信号生成谱图。

    import jax
    import jax.numpy as jnp
    import matplotlib.pyplot as plt
    
    # --- 生成合成对数梅尔谱图 ---
    key = jax.random.PRNGKey(42)
    fs = 16000
    duration = 2.0
    t = jnp.arange(0, duration, 1.0 / fs)
    
    # 模拟语音: 带谐波的啁啾信号
    f0 = 120.0
    x = sum(jnp.sin(2 * jnp.pi * f0 * k * t * (1 + 0.1 * t)) / k for k in range(1, 10))
    key, subkey = jax.random.split(key)
    x = x + 0.05 * jax.random.normal(subkey, t.shape)
    
    # 计算对数梅尔谱图(简化版)
    frame_len = 400  # 25 ms
    hop_len = 160    # 10 ms
    n_fft = 512
    n_mels = 80
    
    n_frames = (len(x) - frame_len) // hop_len + 1
    hamming = 0.54 - 0.46 * jnp.cos(2 * jnp.pi * jnp.arange(frame_len) / (frame_len - 1))
    
    frames = jnp.stack([x[i * hop_len : i * hop_len + frame_len] for i in range(n_frames)])
    windowed = frames * hamming
    spectra = jnp.abs(jnp.fft.rfft(windowed, n=n_fft)) ** 2
    
    # 简单梅尔滤波器组
    def hz_to_mel(f): return 2595 * jnp.log10(1 + f / 700)
    def mel_to_hz(m): return 700 * (10 ** (m / 2595) - 1)
    
    mel_points = jnp.linspace(hz_to_mel(0), hz_to_mel(fs / 2), n_mels + 2)
    hz_pts = mel_to_hz(mel_points)
    bins = jnp.floor((n_fft + 1) * hz_pts / fs).astype(jnp.int32)
    
    n_freqs = n_fft // 2 + 1
    fb = jnp.zeros((n_mels, n_freqs))
    for m in range(n_mels):
        lo, mid, hi = int(bins[m]), int(bins[m+1]), int(bins[m+2])
        for k in range(lo, mid):
            if mid != lo:
                fb = fb.at[m, k].set((k - lo) / (mid - lo))
        for k in range(mid, hi):
            if hi != mid:
                fb = fb.at[m, k].set((hi - k) / (hi - mid))
    
    log_mel = jnp.log(spectra @ fb.T + 1e-10)
    
    # --- SpecAugment ---
    def spec_augment(spec, key, n_freq_masks=2, freq_mask_width=15,
                     n_time_masks=2, time_mask_width=25):
        """应用SpecAugment: 频域和时间掩码."""
        augmented = spec.copy()
        T, F = spec.shape
    
        # 频域掩码
        for _ in range(n_freq_masks):
            key, k1, k2 = jax.random.split(key, 3)
            f_width = jax.random.randint(k1, (), 1, freq_mask_width + 1)
            f_start = jax.random.randint(k2, (), 0, max(1, F - freq_mask_width))
            mask = (jnp.arange(F) >= f_start) & (jnp.arange(F) < f_start + f_width)
            augmented = jnp.where(mask[None, :], 0.0, augmented)
    
        # 时间掩码
        for _ in range(n_time_masks):
            key, k1, k2 = jax.random.split(key, 3)
            t_width = jax.random.randint(k1, (), 1, time_mask_width + 1)
            t_start = jax.random.randint(k2, (), 0, max(1, T - time_mask_width))
            mask = (jnp.arange(T) >= t_start) & (jnp.arange(T) < t_start + t_width)
            augmented = jnp.where(mask[:, None], 0.0, augmented)
    
        return augmented
    
    key, subkey = jax.random.split(key)
    log_mel_aug = spec_augment(log_mel, subkey)
    
    # --- 可视化 ---
    fig, axes = plt.subplots(2, 1, figsize=(14, 8))
    
    im0 = axes[0].imshow(log_mel.T, aspect='auto', origin='lower', cmap='inferno',
                           extent=[0, duration, 0, n_mels])
    axes[0].set_title('原始对数梅尔谱图')
    axes[0].set_xlabel('时间 (s)'); axes[0].set_ylabel('梅尔频带')
    plt.colorbar(im0, ax=axes[0], label='对数能量')
    
    im1 = axes[1].imshow(log_mel_aug.T, aspect='auto', origin='lower', cmap='inferno',
                           extent=[0, duration, 0, n_mels])
    axes[1].set_title('SpecAugment后(频域+时间掩码)')
    axes[1].set_xlabel('时间 (s)'); axes[1].set_ylabel('梅尔频带')
    plt.colorbar(im1, ax=axes[1], label='对数能量')
    
    plt.tight_layout(); plt.show()