自动语音识别¶
自动语音识别将口语音频转换为书面文本,架起人类语音与机器可读语言之间的桥梁。本文涵盖GMM-HMM、CTC损失、RNN-Transducer、基于注意力机制的编码器-解码器模型(LAS)、Whisper以及端到端ASR,从经典流程到现代神经架构。
-
自动语音识别(Automatic Speech Recognition, ASR)是将口语音频转换为书面文本的任务。它是人工智能领域最古老的问题之一(1950年代的首批系统仅能识别单个数字),也是商业部署最广泛的应用之一(语音助手、转录服务、字幕生成)。
-
其难度源于语音的巨大变异性:不同说话人、口音、语速、背景噪声、麦克风特性,以及将连续声学信号映射到离散词汇的根本性歧义。
-
可以将ASR想象成法庭速记员。速记员听到连续的声音流,在脑中将其分割为单词,利用上下文消解歧义(如"they're"、"their"、"there"),然后输入结果。ASR系统做同样的事情,但通过可显式表达并可独立或联合优化的阶段完成。
-
经典ASR流程通过一系列独立阶段处理音频:原始音频被转换为特征(文件01中的MFCCs或对数梅尔谱图),声学模型评估每个特征帧与各音素单元的匹配程度,发音模型(词典)将音素单元映射到单词,语言模型评估词序列的可能性,解码器搜索使综合得分最高的词序列。每个组件分别训练和调优。
-
音素(Phonemes)是语言中区分单词的最小声音单位。英语约有39-44个音素(确切数量取决于方言和所用音素清单)。例如,"bat"和"pat"仅在一个音素上不同(/b/ vs /p/)。大多数ASR系统建模上下文相关音素,称为三音子(triphones):由左右邻居定义的音素(例如"b_t"上下文中的"a"与"c_t"上下文中的"a"是不同单元),因为音素的声学实现受其邻居强烈影响(这称为协同发音,coarticulation)。
-
可能的三音子数量巨大(40个音素的立方=64,000),因此决策树聚类将声学相似的三音子分组为音素状态(senones,通常2000-10,000类)。每个senone拥有独立的声学模型。这种聚类是文件06中决策树算法的一种形式。
-
GMM-HMM(高斯混合模型-隐马尔可夫模型)是1980年代至2010年代初主导的声学建模方法。HMM(文件05)建模语音的时间结构:每个音素是一个从左到右的HMM,含3-5个状态,每个状态代表一个子音素片段(起始、中间、结束)。状态间转移隐式建模时长。
-
在每个HMM状态,发射概率(给定该状态下特定特征向量的可能性)由高斯混合模型(GMM,文件05中的多元高斯分布加权和)建模:
-
其中 \(\mathbf{x}\) 是特征向量(如39维MFCCs),\(s\) 是HMM状态,\(M\) 是混合分量数量(通常8-64),\(w_m\) 是混合权重,\(\boldsymbol{\mu}_m\)、\(\boldsymbol{\Sigma}_m\) 是每个高斯分量的均值和协方差。为计算效率,协方差矩阵通常设为对角阵(假设特征维度独立,这对经DCT去相关的MFCCs近似成立)。
-
训练使用Baum-Welch算法(文件05中EM算法的特例)从带标注语音数据中迭代估计GMM参数和HMM转移概率。解码(寻找最可能的状态序列)使用Viterbi算法(文件05中的动态规划):
-
其中 \(\delta_t(j)\) 是在时间 \(t\) 以状态 \(j\) 结束的最优路径概率,\(a_{ij}\) 是从状态 \(i\) 到 \(j\) 的转移概率,\(b_j(\mathbf{x}_t)\) 是状态 \(j\) 下特征 \(\mathbf{x}_t\) 的发射概率。
-
DNN-HMM(Hinton et al., 2012)用深度神经网络(文件06中的DNN)替换GMM发射模型,从特征帧窗口预测senone后验概率 \(p(s | \mathbf{x})\)。HMM仍处理时间结构和序列,但神经网络提供更具判别力的发射分数。这种混合方法相比GMM将词错误率相对降低20-30%,是2012-2016年的主导范式。
-
WFST解码(加权有限状态变换器)是传统ASR的标准解码框架。每个组件(HMM拓扑H、上下文依赖C、词典L、语法/语言模型G)表示为加权有限状态变换器,并组合成单一搜索图 \(H \circ C \circ L \circ G\)。Viterbi搜索随后在该组合图中寻找最低代价路径。WFST允许知识源的模块化组合和高效动态规划搜索。其数学框架来自有限自动机理论(与文件05中的状态机相关)。
-
端到端ASR消除了独立组件(发音模型、音素清单、WFST解码器),训练单个神经网络直接从音频特征映射到字符或词片段。核心挑战是对齐问题:输入(每秒数百个特征帧)和输出(每秒几个字符)长度差异巨大,且训练时对齐关系未知。
-
连接时序分类(CTC)(Graves et al., 2006)通过引入特殊空白(blank)标记解决对齐问题,允许网络输出任意字符和空白序列,只要折叠连续重复项并移除空白后能得到正确转写。例如,转写"cat"可由输出序列"--cc-aa-t--"产生(其中"-"表示空白)。
-
形式上,CTC定义了一个多对一映射 \(\mathcal{B}\),从所有长度-\(T\) 的输出序列(字母表加空白)到标签序列。标签序列 \(\mathbf{y}\) 的概率是所有折叠后得到它的对齐路径概率之和:
-
朴素计算该求和需枚举指数级数量的对齐,但CTC前向-后向算法使用动态规划在 \(O(T \cdot |\mathbf{y}|)\) 时间内高效计算,类比于文件05中的HMM前向-后向算法。
-
CTC做出条件独立假设:给定输入,每个时间步的输出独立于其他所有输出。这意味着CTC无法建模输出依赖(例如无法学习"q"几乎总是后跟"u")。必须使用外部语言模型处理此类依赖。
-
CTC解码选项:
- 贪婪解码:每步取最可能标记,然后折叠。速度快但次优。
- 束搜索:每步维护前 \(k\) 个部分假设,合并折叠后前缀相同的假设。可融入语言模型分数。
- 前缀束搜索:修改的束搜索,正确处理CTC空白合并,确保假设在折叠后比较。
-
RNN-Transducer(RNN-T)(Graves, 2012)通过添加显式预测网络(类似语言模型的RNN)扩展CTC,使每个输出条件依赖于先前输出,移除条件独立假设。RNN-T包含三个组件:
- 编码器:处理音频特征生成隐藏表示 \(\mathbf{h}_t^\text{enc}\)(通常为LSTM或Conformer层堆叠)。
- 预测网络:自回归RNN,从先前输出的标签生成隐藏表示 \(\mathbf{h}_u^\text{pred}\)。
- 联合网络:在每个(时间,标签)位置组合编码器和预测网络输出,生成下一个标记(含空白)的分布:
-
RNN-T每时间步可输出零个或多个标签(通过在进入下一时间步前输出非空白标记,或输出空白以推进时间但不输出)。训练使用2D(时间,标签)格上的前向-后向算法,复杂度 \(O(T \cdot U)\)(\(U\) 为输出长度)。由于天然支持流式处理(编码器从左到右处理音频,预测网络增量生成输出),RNN-T是设备端流式ASR的主导架构(用于Google Pixel手机等类似产品)。
-
Listen, Attend and Spell(LAS)(Chan et al., 2016)是基于注意力机制的编码器-解码器模型(文件06中的序列到序列架构)。包含三个组件:
- Listener(编码器):金字塔双向LSTM,处理完整输入序列并以8倍因子下采样(每层拼接连续隐藏状态对),生成较短的编码器隐藏状态序列。
- Attention:每解码步计算所有编码器状态的注意力权重以形成上下文向量(文件07中的相同注意力机制)。
- Speller(解码器):自回归LSTM,每次生成一个字符,条件于上下文向量和先前生成的字符。
-
LAS取得强劲结果,但需要完整话语可用后才能解码(因为注意力需关注所有编码器状态),不适合流式应用。对于超长话语,注意力在长序列上变得分散,效果下降。
-
Conformer(Gulati et al., 2020)结合卷积的局部模式捕获能力与自注意力的全局依赖建模能力。每个Conformer模块采用三明治结构,含四个子模块:
- 前馈模块(半步):带残差连接的前馈网络,使用一半残差权重。
- 多头自注意力模块:文件07中的标准Transformer自注意力,含相对位置编码。
- 卷积模块:逐点卷积、门控线性单元(GLU)、1D深度卷积、批归一化、Swish激活、另一个逐点卷积。深度卷积捕获局部上下文(类似特征序列上的n-gram)。
- 前馈模块(半步):与模块1相同。
-
输出为:\(\mathbf{y} = \text{LayerNorm}(\mathbf{x} + \frac{1}{2}\text{FFN}_1 + \text{MHSA} + \text{Conv} + \frac{1}{2}\text{FFN}_2)\)。这种类似马卡龙的结构(FFN-Attention-Conv-FFN)配合半步残差,经实证发现优于其他排序方式。Conformer已成为CTC和RNN-T系统的默认编码器,优于纯Transformer和纯LSTM编码器。
-
Whisper(Radford et al., 2023)是OpenAI的大规模基于注意力的ASR模型。使用标准编码器-解码器Transformer架构(文件07),在从网络抓取的68万小时弱监督数据(音频配对近似转写)上训练。关键设计选择:
- 输入:80通道对数梅尔谱图(文件01),25ms窗长、10ms跳长,归一化为零均值单位方差。
- 编码器:标准Transformer编码器,含正弦位置嵌入和预激活层归一化。
- 解码器:Transformer解码器,使用字节级BPE分词器(文件07)自回归生成标记。
- 多任务:单个模型处理转写、翻译、语言识别和时间戳预测,通过解码器提示中的特殊任务标记条件控制。
- 训练数据规模(而非架构创新)是Whisper在领域、口音和语言间强大泛化能力的主要驱动因素。
-
wav2vec 2.0(Baevski et al., 2020)是语音表征的自监督预训练框架。核心思想是从大量无标签音频学习语音表征,然后用少量标注数据微调。这遵循与BERT(文件07)相同的自监督范式,但适配连续音频信号。
-
wav2vec 2.0架构包含三部分:
- 特征编码器:多层1D CNN,处理原始波形样本,以20ms帧率(16kHz下每320个样本一个向量)生成潜在表示 \(\mathbf{z}_t\)。
- 量化模块:使用乘积量化将潜在表示离散化为有限码本(将向量分组并独立量化每组,从 \(G\) 个码本中各选 \(V\) 个条目)。这为对比学习目标生成目标 \(\mathbf{q}_t\)。
- 上下文网络:Transformer编码器,接收(部分掩码的)潜在表示并生成上下文化表示 \(\mathbf{c}_t\)。
- 预训练期间,随机跨度的潜在表示被掩码(替换为学习到的掩码嵌入),模型必须从一组干扰项(从同一话语其他位置采样的负样本)中识别掩码位置的真实量化表示。对比损失为:
-
其中 \(\text{sim}\) 为余弦相似度,\(\kappa\) 为温度参数,\(Q_t\) 包含真实量化目标和干扰项。额外的多样性损失鼓励均衡使用所有码本条目。该损失本质上是InfoNCE对比损失,与视觉自监督学习中使用的对比目标同属一族。
-
预训练后,在顶部添加线性投影和CTC头,并在标注数据上微调。仅用10分钟标注数据(使用53,000小时无标签音频预训练),wav2vec 2.0即达到接近最先进的结果,展示了自监督学习在低资源语音识别中的强大能力。
-
HuBERT(Hsu et al., 2021)是另一种自监督方法,用掩码预测目标(预测掩码帧的离散聚类分配)替换对比目标。目标由离线聚类步骤生成(第一轮在MFCCs上执行k-means,后续轮次在HuBERT特征上执行k-means)。相比wav2vec 2.0,HuBERT简化了训练流程(无需量化模块或对比采样),并达到相当或更好的结果。
-
Fast Conformer(Rekesh et al., 2023, NVIDIA NeMo)用下采样注意力机制替换标准Conformer中的二次自注意力:计算注意力前压缩输入序列(通常通过步长卷积压缩8倍),然后扩展回原尺寸。这将注意力代价从 \(O(T^2)\) 降低到 \(O(T^2/64)\),同时保留全局上下文,使训练超长话语(长达数分钟)时无内存问题。Fast Conformer是NVIDIA NeMo工具包的默认编码器,构成其生产级模型的主干。
-
Parakeet(NVIDIA, 2024)是基于Fast Conformer编码器、搭配CTC和RNN-T解码器的高精度英语ASR模型家族,在64,000小时英语语音上训练。发布时,Parakeet模型(0.6B和1.1B参数)在标准基准上达到最低词错误率,在大多数英语测试集上超越Whisper large-v3。关键要素包括高效的Fast Conformer架构、激进的数据增强(SpecAugment、速度扰动、噪声混合)和大规模监督训练数据——表明对已知组件的精心工程仍能推动技术发展。
-
Canary(NVIDIA, 2024)将NeMo框架扩展到多语言和多任务ASR。使用Fast Conformer编码器搭配基于注意力的解码器(而非CTC或RNN-T),在单个模型中处理多语言的转写和翻译(类似Whisper的多任务设计,但采用更高效的Fast Conformer主干)。Canary模型支持英语、德语、西班牙语和法语,具有竞争力精度。
-
Moonshine(Useful Sensors, 2024)是专为设备端和边缘部署优化的ASR模型家族。编码器采用混合架构,用小型CNN后接少量Transformer层替换初始Transformer/Conformer层,大幅降低模型尺寸(基础模型不足3000万参数)。Moonshine面向CPU和低功耗设备上的实时流式处理,在Whisper过大过慢的场景下,以部分精度换取5-10倍的更低延迟和内存占用。
-
Distil-Whisper(Gandhi et al., 2023)应用知识蒸馏(文件06)将Whisper压缩为更小、更快的模型。学生模型仅用2层解码器(相比Whisper的32层),保留完整编码器,训练以匹配Whisper的输出分布。Distil-Whisper在词错误率上与教师模型差距在1%以内,同时速度快6倍,使完整Whisper模型过慢的实时应用变得可行。
-
通用语音模型(USM)(Zhang et al., 2023, Google)将自监督预训练扩展到300+语言的1200万小时无标签音频,随后进行监督微调。USM证明wav2vec 2.0/自监督范式可扩展到真正大规模数据场景,在标注数据极少的低资源语言上实现强劲性能。
-
大规模多语言语音(MMS)(Pratap et al., 2023, Meta)将wav2vec 2.0预训练扩展到1100+语言,使用宗教录音和其他多语言音频源。MMS覆盖的语言远超以往任何ASR系统,首次使许多低资源语言能够进行语音识别。
-
现代ASR格局正收敛于几种主导模式:(1) 用于流式的Conformer家族编码器搭配CTC或RNN-T,(2) 用于离线/多任务的编码器-解码器Transformer,(3) 用于低资源场景的自监督预训练,(4) 规模——更多数据和更大模型持续提升精度。选择取决于部署约束:延迟预算、可用算力、语言数量、以及应用是流式还是批处理。
-
语言模型集成通过融入声学模型未捕获的语言学知识提升ASR。基本思想是在解码时将声学模型分数 \(p(\mathbf{x} | \mathbf{y})\)(音频与转写的匹配程度)与语言模型分数 \(p(\mathbf{y})\)(转写作为句子的可能性)结合。
-
浅层融合在束搜索时结合分数:
-
其中 \(\lambda\) 为可调权重,\(p_\text{LM}\) 为外部语言模型(通常为文件07中的n-gram或神经语言模型)。这种方法简单有效,但要求语言模型与ASR模型使用相同的标记词汇表。
-
深层融合(Gulcehre et al., 2015)将语言模型集成到解码器网络内部:LM隐藏状态与解码器隐藏状态拼接,经门控机制处理后进行输出投影。整个系统(包括预训练LM)联合微调。这允许更深层次集成,但训练更复杂。
-
冷融合(Sriram et al., 2018)类似深层融合,但从头训练集成语言模型的ASR解码器,而非微调预训练解码器。这迫使声学模型学习互补信息,而非重复LM已掌握的知识。
-
重排序(N-best重排序)是两阶段方法:先用束搜索生成 \(N\) 个候选转写,再用更强大的语言模型(如大型Transformer LM)重新排序。实现简单,且允许使用对首遍解码过慢的超大型语言模型。
-
内部语言模型估计(ILME)解决一个微妙问题:端到端模型从训练转写中隐式学习内部语言模型,可能在浅层融合时与外部语言模型冲突(本质上是双重计数语言先验)。ILME估计内部语言模型并在融合时减去其分数:
-
流式 vs. 离线ASR是根本性的架构选择。离线(或批处理)ASR在处理完整话语后才输出任何结果。流式ASR在音频到达时增量输出,具有有界延迟。
-
流式对实时应用至关重要:实时字幕、语音助手(用户期望在说完前得到响应)、电话通话转录。挑战在于:部分未来上下文有助于识别(知道下一个词是"York"可消解"New"的歧义),但流式系统不能等待任意长的未来上下文。
-
单向编码器(从左到右的LSTM、因果卷积、因果Transformer)天然支持流式,因为每个输出仅依赖过去和当前输入。双向编码器(关注未来上下文)无法直接支持流式。
-
分块注意力(也称块状或分段注意力)将输入划分为固定长度块,仅在每块内(及可选的前几个块)应用自注意力。这将延迟限制为块大小加处理时间,同时允许每块内一定的局部双向上下文。权衡是:块尺寸越小,精度下降越明显。
-
前瞻允许流式编码器在输出当前帧前窥视少量未来帧(如300-900ms)。通过为单向计算添加小的右上下文实现。前瞻窗口增加延迟,但显著提升精度。
-
流式ASR中的延迟包含多个组件:
- 算法延迟:音频到达至模型可处理的时间差(由块大小、前瞻和特征提取决定)。
- 计算延迟:运行模型前向传播的时间。
- 端点检测延迟:检测用户说完话的延迟。
- 首词延迟:第一个词出现的速度。最终确认延迟:最终输出确认的速度(流式系统通常生成临时输出,随更多音频到达而修正)。
-
ASR的评估指标:
-
词错误率(Word Error Rate, WER)是主要指标。通过编辑距离(将一个序列转换为另一个所需的最少替换、删除、插入操作数)对齐假设(系统输出)与参考(真实转写),然后计算:
-
其中 \(S\) 为替换数,\(D\) 为删除数,\(I\) 为插入数,\(N\) 为参考中的总词数。若插入很多,WER可超过100%。对于清晰朗读语音,5%的WER大致相当于人类水平;对话或噪声语音则困难得多(10-20%+)。
-
字符错误率(Character Error Rate, CER)是相同公式在字符级别而非词级别的应用。对于无清晰词边界的语言(中文、日语)以及评估近似错误("cat" vs "bat"的WER为100%但CER为33%),CER更具信息量。
-
词信息丢失(WIL)和词信息保留(WIP)是信息论替代指标,比WER更精确地考虑参考与假设间的相关性,但报告较少。
-
实时因子(Real-Time Factor, RTF)衡量计算效率:处理时间与音频时长的比值。RTF < 1表示系统运行快于实时;RTF > 1表示无法跟上实时音频。流式系统必须保持RTF < 1。
-
数据增强对鲁棒ASR至关重要。常用技术:
- 速度扰动:以0.9倍和1.1倍速度重采样音频(改变音高和时长)。
- SpecAugment(Park et al., 2019):在谱图中掩码随机频带和时间步。这是音频领域的dropout类比,是ASR最有效的正则化技术之一。无需额外数据。
- 噪声增强:以不同信噪比将干净语音与录制噪声混合。
- 房间脉冲响应模拟:将干净语音与模拟房间声学卷积,模拟混响环境。
-
ASR的分词决定模型的输出词汇表。选项包括:
- 字符:简单、词汇表小(英语约30个),但输出序列长且无隐式语言建模。
- 词片段 / BPE(文件07):子词单元,平衡词汇表大小和序列长度。现代系统的标准(Whisper使用字节级BPE,约50,000标记)。
- 词:词汇表大(50,000+)、输出序列短,但无法处理未登录词。
- 音素:语言学动机、紧凑,但需要发音词典。
-
ASR的演进可总结为:从高度工程化的模块化系统(1990年代-2010年代的GMM-HMM + WFST解码),到混合系统(2012-2016年的DNN-HMM),再到将越来越多流程吸收到单一神经网络中的端到端系统(2016-2020年的CTC、RNN-T、LAS),再到利用海量无标签或弱标签数据的大规模预训练模型(2020年至今的wav2vec 2.0、Whisper)。每次转变都简化了工程同时提升了精度,遵循机器学习更广泛的趋势:从数据中学习表征而非手工设计(如同文件06中图像特征被CNN取代、文件07中NLP特征被Transformer取代的故事)。
编程任务(使用CoLab或Notebook)¶
-
用JAX从头实现CTC损失。创建一个短序列logits和目标标签的玩具示例,计算CTC前向算法得到总概率,并计算负对数似然损失。
import jax import jax.numpy as jnp import matplotlib.pyplot as plt def ctc_forward(log_probs, targets): """ CTC前向算法(对数域以保证数值稳定性)。 log_probs: (T, V) 词汇表上的对数概率(索引0=空白) targets: (U,) 目标标签索引(不含空白) 返回:目标序列在CTC下的对数概率。 """ T, V = log_probs.shape U = len(targets) # 构建带空白的扩展标签序列: [blank, y1, blank, y2, ..., yU, blank] S = 2 * U + 1 labels = jnp.zeros(S, dtype=jnp.int32) # 全空白 for i in range(U): labels = labels.at[2 * i + 1].set(targets[i]) # 初始化alpha(对数域) NEG_INF = -1e30 alpha = jnp.full((T, S), NEG_INF) alpha = alpha.at[0, 0].set(log_probs[0, labels[0]]) # 从空白开始 alpha = alpha.at[0, 1].set(log_probs[0, labels[1]]) # 或第一个标签 # 前向填充 for t in range(1, T): for s in range(S): # 同一状态 a = alpha[t - 1, s] # 来自前一状态 if s > 0: a = jnp.logaddexp(a, alpha[t - 1, s - 1]) # 跳过空白(如果当前和隔一个标签不同) if s > 1 and labels[s] != 0 and labels[s] != labels[s - 2]: a = jnp.logaddexp(a, alpha[t - 1, s - 2]) alpha = alpha.at[t, s].set(a + log_probs[t, labels[s]]) # 总对数概率:最后时间步最后两个状态之和 log_prob = jnp.logaddexp(alpha[T - 1, S - 1], alpha[T - 1, S - 2]) return log_prob, alpha # --- 玩具示例 --- T = 12 # 输入长度(时间步) V = 5 # 词汇表大小(0=空白, 1='c', 2='a', 3='t', 4='x') targets = jnp.array([1, 2, 3]) # "c", "a", "t" # 创建随机logits并转换为对数概率 key = jax.random.PRNGKey(42) logits = jax.random.normal(key, (T, V)) log_probs = jax.nn.log_softmax(logits, axis=-1) log_prob, alpha = ctc_forward(log_probs, targets) ctc_loss = -log_prob print(f"目标序列: {targets.tolist()} ('c', 'a', 't')") print(f"输入长度 T={T}, 词汇表大小 V={V}") print(f"CTC对数概率: {log_prob:.4f}") print(f"CTC损失(负对数概率): {ctc_loss:.4f}") # 可视化前向变量(alpha格) fig, ax = plt.subplots(figsize=(12, 5)) # 从对数转换为线性以便可视化 alpha_linear = jnp.exp(alpha - jnp.max(alpha)) # 归一化以便显示 im = ax.imshow(alpha_linear.T, aspect='auto', origin='lower', cmap='viridis') ax.set_xlabel('时间步 (t)') ax.set_ylabel('扩展标签索引 (s)') label_names = ['_', 'c', '_', 'a', '_', 't', '_'] # _ = 空白 ax.set_yticks(range(len(label_names))) ax.set_yticklabels(label_names) ax.set_title(f'CTC前向变量(alpha格)| 损失 = {ctc_loss:.2f}') plt.colorbar(im, ax=ax, label='归一化概率') plt.tight_layout(); plt.show() -
用JAX构建简单的基于编码器-解码器注意力的ASR模型(最小LAS类架构)。使用1D卷积编码器和单层解码器搭配点积注意力。在合成数据上运行并可视化注意力权重。
import jax import jax.numpy as jnp import matplotlib.pyplot as plt # --- 最小注意力编码器-解码器ASR模型 --- def init_params(key, input_dim, hidden_dim, vocab_size): """初始化小型LAS类模型参数.""" keys = jax.random.split(key, 8) scale = 0.1 params = { # 编码器: 简单线性投影(模拟卷积输出) 'enc_w': jax.random.normal(keys[0], (input_dim, hidden_dim)) * scale, 'enc_b': jnp.zeros(hidden_dim), # 注意力: query, key, value投影 'attn_q': jax.random.normal(keys[1], (hidden_dim, hidden_dim)) * scale, 'attn_k': jax.random.normal(keys[2], (hidden_dim, hidden_dim)) * scale, 'attn_v': jax.random.normal(keys[3], (hidden_dim, hidden_dim)) * scale, # 解码器RNN(简单Elman RNN示例) 'dec_wh': jax.random.normal(keys[4], (hidden_dim, hidden_dim)) * scale, 'dec_wx': jax.random.normal(keys[5], (vocab_size, hidden_dim)) * scale, 'dec_wc': jax.random.normal(keys[6], (hidden_dim, hidden_dim)) * scale, 'dec_b': jnp.zeros(hidden_dim), # 输出投影 'out_w': jax.random.normal(keys[7], (hidden_dim, vocab_size)) * scale, 'out_b': jnp.zeros(vocab_size), } return params def encode(params, x): """编码器: 线性投影(替代卷积/LSTM堆叠).""" return jnp.tanh(x @ params['enc_w'] + params['enc_b']) def attend(params, query, enc_out): """对编码器输出的点积注意力.""" q = query @ params['attn_q'] # (hidden,) k = enc_out @ params['attn_k'] # (T_enc, hidden) v = enc_out @ params['attn_v'] # (T_enc, hidden) d_k = q.shape[-1] scores = (k @ q) / jnp.sqrt(d_k) # (T_enc,) weights = jax.nn.softmax(scores) # (T_enc,) context = weights @ v # (hidden,) return context, weights def decode_step(params, h_prev, y_prev_onehot, enc_out): """单步解码: RNN + 注意力.""" # 嵌入前一标记 y_emb = y_prev_onehot @ params['dec_wx'] # (hidden,) # 注意力编码 context, attn_w = attend(params, h_prev, enc_out) # RNN更新 h = jnp.tanh(h_prev @ params['dec_wh'] + y_emb + context @ params['dec_wc'] + params['dec_b']) # 输出logits logits = h @ params['out_w'] + params['out_b'] return h, logits, attn_w # --- 设置 --- key = jax.random.PRNGKey(0) input_dim = 40 # 如40个梅尔频带 hidden_dim = 64 vocab_size = 10 # 演示用小词汇表 T_enc = 30 # 编码器时间步 T_dec = 8 # 解码器步数 params = init_params(key, input_dim, hidden_dim, vocab_size) # 合成输入: 随机类梅尔特征 key, subkey = jax.random.split(key) x = jax.random.normal(subkey, (T_enc, input_dim)) # 编码 enc_out = encode(params, x) # 解码(教师强制,使用随机目标) key, subkey = jax.random.split(key) targets = jax.random.randint(subkey, (T_dec,), 0, vocab_size) h = jnp.zeros(hidden_dim) all_logits = [] all_attn = [] for t in range(T_dec): y_prev = jax.nn.one_hot(targets[t] if t > 0 else 0, vocab_size) h, logits, attn_w = decode_step(params, h, y_prev, enc_out) all_logits.append(logits) all_attn.append(attn_w) all_attn = jnp.stack(all_attn) # (T_dec, T_enc) all_logits = jnp.stack(all_logits) # (T_dec, vocab_size) # --- 可视化注意力权重 --- fig, axes = plt.subplots(1, 2, figsize=(14, 5)) im = axes[0].imshow(all_attn, aspect='auto', cmap='Blues', origin='lower') axes[0].set_xlabel('编码器时间步') axes[0].set_ylabel('解码步') axes[0].set_title('注意力权重(解码器 -> 编码器)') plt.colorbar(im, ax=axes[0]) # 显示每解码步的预测标记分布 im2 = axes[1].imshow(jax.nn.softmax(all_logits, axis=-1), aspect='auto', cmap='Oranges', origin='lower') axes[1].set_xlabel('词汇表索引') axes[1].set_ylabel('解码步') axes[1].set_title('输出标记概率') plt.colorbar(im2, ax=axes[1]) plt.suptitle('最小注意力基ASR模型(未训练)') plt.tight_layout(); plt.show() -
使用动态规划(编辑距离)从头计算词错误率(WER),并评估多个假设与参考的匹配。可视化编辑距离矩阵。
import jax.numpy as jnp import matplotlib.pyplot as plt import numpy as np def compute_wer(reference, hypothesis): """ 使用动态规划计算WER(词级莱文斯坦距离)。 返回WER、替换数、删除数、插入数和DP矩阵。 """ ref_words = reference.split() hyp_words = hypothesis.split() N = len(ref_words) M = len(hyp_words) # DP矩阵: d[i][j] = ref[:i] 与 hyp[:j] 的编辑距离 d = np.zeros((N + 1, M + 1), dtype=np.int32) # 回溯矩阵用于计数S, D, I ops = np.zeros((N + 1, M + 1, 3), dtype=np.int32) # [sub, del, ins] for i in range(N + 1): d[i][0] = i # 全删除 for j in range(M + 1): d[0][j] = j # 全插入 for i in range(1, N + 1): for j in range(1, M + 1): if ref_words[i - 1] == hyp_words[j - 1]: sub_cost = d[i - 1][j - 1] # 匹配,无编辑 else: sub_cost = d[i - 1][j - 1] + 1 # 替换 del_cost = d[i - 1][j] + 1 # 删除 ins_cost = d[i][j - 1] + 1 # 插入 d[i][j] = min(sub_cost, del_cost, ins_cost) # 回溯计数操作 i, j = N, M S, D, I = 0, 0, 0 while i > 0 or j > 0: if i > 0 and j > 0 and d[i][j] == d[i-1][j-1] and ref_words[i-1] == hyp_words[j-1]: i -= 1; j -= 1 # 正确 elif i > 0 and j > 0 and d[i][j] == d[i-1][j-1] + 1: S += 1; i -= 1; j -= 1 # 替换 elif i > 0 and d[i][j] == d[i-1][j] + 1: D += 1; i -= 1 # 删除 elif j > 0 and d[i][j] == d[i][j-1] + 1: I += 1; j -= 1 # 插入 else: break wer = (S + D + I) / N if N > 0 else 0.0 return wer, S, D, I, d # --- 测试用例 --- reference = "the cat sat on the mat" hypotheses = [ "the cat sat on the mat", # 完美 "the cat sit on the mat", # 1替换 "the cat on the mat", # 1删除 "the big cat sat on the mat", # 1插入 "a dog sat in a rug", # 多错误 ] print(f"参考: '{reference}'\n") print(f"{'假设':<40s} {'WER':>6s} {'S':>3s} {'D':>3s} {'I':>3s}") print("-" * 60) results = [] for hyp in hypotheses: wer, S, D, I, dp = compute_wer(reference, hyp) results.append((hyp, wer, S, D, I, dp)) print(f"'{hyp}':<40s} {wer:>6.1%} {S:>3d} {D:>3d} {I:>3d}") # 可视化最坏情况的DP矩阵 worst = results[-1] hyp_words = worst[0].split() ref_words = reference.split() dp_matrix = worst[5] fig, axes = plt.subplots(1, 2, figsize=(14, 5)) # DP矩阵 im = axes[0].imshow(dp_matrix, cmap='YlOrRd', origin='upper') axes[0].set_xticks(range(len(hyp_words) + 1)) axes[0].set_xticklabels([''] + hyp_words, rotation=45, ha='right', fontsize=9) axes[0].set_yticks(range(len(ref_words) + 1)) axes[0].set_yticklabels([''] + ref_words, fontsize=9) axes[0].set_xlabel('假设词') axes[0].set_ylabel('参考词') axes[0].set_title(f'编辑距离矩阵\nWER = {worst[1]:.1%}') for i in range(dp_matrix.shape[0]): for j in range(dp_matrix.shape[1]): axes[0].text(j, i, str(dp_matrix[i, j]), ha='center', va='center', fontsize=8) plt.colorbar(im, ax=axes[0]) # WER对比条形图 names = [f'Hyp {i+1}' for i in range(len(results))] wers = [r[1] * 100 for r in results] colors = ['#27ae60' if w == 0 else '#f39c12' if w < 30 else '#e74c3c' for w in wers] axes[1].barh(names, wers, color=colors) axes[1].set_xlabel('WER (%)') axes[1].set_title('词错误率对比') for i, (w, r) in enumerate(zip(wers, results)): axes[1].text(w + 1, i, f'{w:.0f}% (S={r[2]}, D={r[3]}, I={r[4]})', va='center', fontsize=9) axes[1].set_xlim(0, max(wers) * 1.4) plt.tight_layout(); plt.show() -
在对数梅尔谱图上实现SpecAugment(频域掩码和时间掩码),并可视化原始与增强版本。从合成信号生成谱图。
import jax import jax.numpy as jnp import matplotlib.pyplot as plt # --- 生成合成对数梅尔谱图 --- key = jax.random.PRNGKey(42) fs = 16000 duration = 2.0 t = jnp.arange(0, duration, 1.0 / fs) # 模拟语音: 带谐波的啁啾信号 f0 = 120.0 x = sum(jnp.sin(2 * jnp.pi * f0 * k * t * (1 + 0.1 * t)) / k for k in range(1, 10)) key, subkey = jax.random.split(key) x = x + 0.05 * jax.random.normal(subkey, t.shape) # 计算对数梅尔谱图(简化版) frame_len = 400 # 25 ms hop_len = 160 # 10 ms n_fft = 512 n_mels = 80 n_frames = (len(x) - frame_len) // hop_len + 1 hamming = 0.54 - 0.46 * jnp.cos(2 * jnp.pi * jnp.arange(frame_len) / (frame_len - 1)) frames = jnp.stack([x[i * hop_len : i * hop_len + frame_len] for i in range(n_frames)]) windowed = frames * hamming spectra = jnp.abs(jnp.fft.rfft(windowed, n=n_fft)) ** 2 # 简单梅尔滤波器组 def hz_to_mel(f): return 2595 * jnp.log10(1 + f / 700) def mel_to_hz(m): return 700 * (10 ** (m / 2595) - 1) mel_points = jnp.linspace(hz_to_mel(0), hz_to_mel(fs / 2), n_mels + 2) hz_pts = mel_to_hz(mel_points) bins = jnp.floor((n_fft + 1) * hz_pts / fs).astype(jnp.int32) n_freqs = n_fft // 2 + 1 fb = jnp.zeros((n_mels, n_freqs)) for m in range(n_mels): lo, mid, hi = int(bins[m]), int(bins[m+1]), int(bins[m+2]) for k in range(lo, mid): if mid != lo: fb = fb.at[m, k].set((k - lo) / (mid - lo)) for k in range(mid, hi): if hi != mid: fb = fb.at[m, k].set((hi - k) / (hi - mid)) log_mel = jnp.log(spectra @ fb.T + 1e-10) # --- SpecAugment --- def spec_augment(spec, key, n_freq_masks=2, freq_mask_width=15, n_time_masks=2, time_mask_width=25): """应用SpecAugment: 频域和时间掩码.""" augmented = spec.copy() T, F = spec.shape # 频域掩码 for _ in range(n_freq_masks): key, k1, k2 = jax.random.split(key, 3) f_width = jax.random.randint(k1, (), 1, freq_mask_width + 1) f_start = jax.random.randint(k2, (), 0, max(1, F - freq_mask_width)) mask = (jnp.arange(F) >= f_start) & (jnp.arange(F) < f_start + f_width) augmented = jnp.where(mask[None, :], 0.0, augmented) # 时间掩码 for _ in range(n_time_masks): key, k1, k2 = jax.random.split(key, 3) t_width = jax.random.randint(k1, (), 1, time_mask_width + 1) t_start = jax.random.randint(k2, (), 0, max(1, T - time_mask_width)) mask = (jnp.arange(T) >= t_start) & (jnp.arange(T) < t_start + t_width) augmented = jnp.where(mask[:, None], 0.0, augmented) return augmented key, subkey = jax.random.split(key) log_mel_aug = spec_augment(log_mel, subkey) # --- 可视化 --- fig, axes = plt.subplots(2, 1, figsize=(14, 8)) im0 = axes[0].imshow(log_mel.T, aspect='auto', origin='lower', cmap='inferno', extent=[0, duration, 0, n_mels]) axes[0].set_title('原始对数梅尔谱图') axes[0].set_xlabel('时间 (s)'); axes[0].set_ylabel('梅尔频带') plt.colorbar(im0, ax=axes[0], label='对数能量') im1 = axes[1].imshow(log_mel_aug.T, aspect='auto', origin='lower', cmap='inferno', extent=[0, duration, 0, n_mels]) axes[1].set_title('SpecAugment后(频域+时间掩码)') axes[1].set_xlabel('时间 (s)'); axes[1].set_ylabel('梅尔频带') plt.colorbar(im1, ax=axes[1], label='对数能量') plt.tight_layout(); plt.show()