Skip to content

多模态表示

多模态表示将视觉、语言和音频桥接到共享的嵌入空间。本文涵盖融合策略、CLIP、ALIGN、SigLIP、对比损失函数(InfoNCE、NT-Xent)、零样本分类以及检索评估。

  • 想象你坐在咖啡馆里。你看到桌上冒着热气的杯子,听到陶瓷的轻碰声,闻到烘焙咖啡豆的香气,感受到从马克杯辐射出的温暖。没有任何单一感官能告诉你全部信息:你的大脑将这些信号融合成"热咖啡"的统一感知。多模态学习为机器做同样的事情:结合来自多种模态(视觉、语言、音频等)的信息,构建比任何单一模态单独提供更丰富、更鲁棒的表示。

  • 模态(modality)是独立的信息通道。在机器学习中,最常见的模态是图像(像素网格)、文本(标记序列)、音频(波形或谱图,如第9章)、视频(帧序列)和结构化数据(表格、图)。每种模态有其自身的统计结构:图像具有空间连贯性,文本是序列且离散的,音频是时间且连续的。多模态学习的挑战在于桥接这些根本不同的数据类型。

  • 为什么要结合多种模态?因为它们提供互补信息。一张狗的照片告诉你它的品种和颜色,但不告诉你它的名字。标题"我的金毛寻回犬Max"告诉你名字和品种,但不告诉你确切姿势。图像和文本结合比单独任一提供更完整的画面。这种互补性是多模态模型能够回答问题、生成内容并做出单模态模型无法完成的决策的核心动机。

多模态学习概览:独立编码器处理图像、文本和音频输入,其表示在共享嵌入空间中交汇

融合策略

  • 想象一个小组项目。你可以用两种方式结合想法:所有人从一开始就在同一房间协作(共享原始笔记和草稿),或者每个人独立撰写自己的部分,然后合并最终文档。这对应于多模态学习中的早期融合晚期融合

  • 早期融合(也称特征级融合)在任何实质性处理发生之前,拼接或混合来自不同模态的原始或低级特征。例如,你可能将图像的像素特征与文本的标记嵌入拼接,然后将组合序列输入单个Transformer。模型可以从一开始学习细粒度的跨模态交互,但输入空间很大,模型必须学会同时处理非常不同的数据类型。

  • 形式上,给定来自两个模态的特征向量 \(x_{\text{img}} \in \mathbb{R}^{d_1}\)\(x_{\text{txt}} \in \mathbb{R}^{d_2}\),早期融合简单地拼接它们:

\[x_{\text{fused}} = [x_{\text{img}}; x_{\text{txt}}] \in \mathbb{R}^{d_1 + d_2}\]
  • 该拼接向量随后由共享网络处理。优点是模型可以在每一层发现跨模态相关性。缺点是计算成本以及对齐非常不同特征类型(密集像素值与稀疏标记索引)的难度。

  • 晚期融合(也称决策级融合)通过各自独立的编码器处理每种模态,为每种模态生成高级表示甚至最终预测。这些输出随后被组合,通常通过平均分数、投票或学习到的组合层。晚期融合更简单,允许你直接重用预训练的单模态模型,但无法捕获低级跨模态交互,因为模态从未"看到"彼此的原始特征。

  • 给定模态特定预测 \(\hat{y}_1\)\(\hat{y}_2\),简单的晚期融合规则是:

\[\hat{y} = \alpha \hat{y}_1 + (1 - \alpha) \hat{y}_2\]
  • 其中 \(\alpha \in [0, 1]\) 是学习到的或手工调整的混合权重。

  • 中期融合(也称中间融合)是大多数现代系统采用的务实中间方案。每种模态首先由其自己的编码器处理(提取模态特定特征),然后编码表示在网络中途组合,通常通过交叉注意力层。这允许每个编码器专攻其模态,同时仍能实现丰富的跨模态交互。Flamingo、LLaVA和大多数视觉-语言模型(文件02)使用中期融合。

早期、中期和晚期融合策略:早期融合拼接原始输入,中期融合通过交叉注意力合并中间表示,晚期融合组合最终预测

  • 融合策略的选择取决于数据可用性、计算预算和任务。早期融合强大但数据需求高。晚期融合廉价但能力有限。带交叉注意力的中期融合已成为大规模多模态模型的主导方法,因为它在表达能力与模块化之间取得平衡。

联合嵌入空间

  • 想象一个通用翻译器,可以将任何语言的任何句子映射到共享"意义空间"中的同一点。英语、法语或日语中的句子"海滩上的一只狗"都会落到相同的坐标。联合嵌入空间正是如此,但跨模态实现:海滩上狗的图像和文本"海滩上的一只狗"应映射到同一向量空间中的邻近点。

  • 形式上,我们学习两个编码器函数:\(f_\theta : \mathcal{X}_1 \to \mathbb{R}^d\) 用于模态1(如图像),\(g_\phi : \mathcal{X}_2 \to \mathbb{R}^d\) 用于模态2(如文本)。两者都将输入映射到相同的 \(d\) 维空间。训练目标确保语义匹配的对 \((x_1, x_2)\) 具有接近的嵌入 \(f_\theta(x_1)\)\(g_\phi(x_2)\)(高余弦相似度),而不匹配的对则相距较远。

  • 这是第7章词嵌入空间的直接推广。回想Word2Vec和GloVe将语义相似的词放在向量空间中彼此靠近的位置。联合嵌入空间将这一思想扩展到跨模态:我们不再测量词与词的相似度,而是测量图像与文本、音频与文本、甚至图像与音频的相似度。

  • 相似度度量几乎总是余弦相似度(第1章):

\[\text{sim}(u, v) = \frac{u \cdot v}{\|u\| \|v\|}\]
  • 通过对所有嵌入进行 \(L_2\) 归一化到单位超球面,余弦相似度简化为简单的点积 \(u \cdot v\),计算极其高效,且可通过近似最近邻库加速。

联合嵌入空间:图像编码器和文本编码器将各自输入映射到共享向量空间,匹配对在此聚类

  • 联合嵌入空间的强大之处在于它支持零样本迁移。一旦对齐了图像和文本嵌入,你就可以将图像分类到从未训练过的类别:只需将类别名称作为文本嵌入,找到与图像嵌入最接近的文本嵌入。无需任务特定微调。这是CLIP及其后继者的关键洞见。

用于多模态对齐的对比学习

  • 想象一个课堂练习:学生拿到打乱的照片和标题对,要求将每张照片与其正确标题匹配。要做好这一点,你需要理解视觉内容和语言,并知道它们如何关联。对比学习正是以这种方式训练模型:给定一批(图像,文本)对,模型必须找出哪张图像对应哪个文本。

  • 正如我们在第8章(文件04)所见,单模态设置中的对比学习(SimCLR、MoCo)将同一图像的不同增强视图拉近,将不同图像的视图推远。多模态对比学习用"匹配模态"替换"增强视图":图像及其标题是正例对;图像与批次中任何其他标题配对是负例对。

CLIP

  • CLIP(Contrastive Language-Image Pre-training, Radford et al., 2021)是多模态对比学习的基础模型。它在从网络抓取的4亿(图像,文本)对上联合训练图像编码器(ViT或ResNet,第8章)和文本编码器(Transformer,第7章)。

  • 给定一批 \(N\) 个图像-文本对,CLIP计算所有图像嵌入与所有文本嵌入之间的 \(N \times N\) 余弦相似度矩阵。对角线项是匹配对(正例);所有非对角线项是不匹配对(负例)。训练损失推高对角线项、压低非对角线项。

  • 损失是对称交叉熵。对于与文本 \(j = i\) 配对的图像 \(i\),图像到文本的损失是:

\[\mathcal{L}_{i \to t} = -\frac{1}{N} \sum_{i=1}^{N} \log \frac{\exp(\text{sim}(z_i^{\text{img}}, z_i^{\text{txt}}) / \tau)}{\sum_{k=1}^{N} \exp(\text{sim}(z_i^{\text{img}}, z_k^{\text{txt}}) / \tau)}\]
  • 文本到图像的损失角色互换,形式相同:
\[\mathcal{L}_{t \to i} = -\frac{1}{N} \sum_{i=1}^{N} \log \frac{\exp(\text{sim}(z_i^{\text{txt}}, z_i^{\text{img}}) / \tau)}{\sum_{k=1}^{N} \exp(\text{sim}(z_i^{\text{txt}}, z_k^{\text{img}}) / \tau)}\]
  • 总CLIP损失是两者的平均:
\[\mathcal{L}_{\text{CLIP}} = \frac{1}{2}(\mathcal{L}_{i \to t} + \mathcal{L}_{t \to i})\]
  • 这里 \(\tau\) 是学习的温度参数(初始化为 \(\tau = 0.07\))。温度控制softmax分布的锐度:低 \(\tau\) 使模型更专注于最接近的匹配,高 \(\tau\) 使概率分布更均匀。CLIP将 \(\tau\) 与模型权重联合学习,而非将其视为固定超参数。

CLIP训练:一批N个图像-文本对产生NxN相似度矩阵,训练最大化对角线项并最小化非对角线项

  • CLIP的图像编码器通常是ViT-L/14(大型Vision Transformer,14×14补丁,第8章文件04)。文本编码器是12层Transformer,带因果掩码(类似GPT,第7章文件04)。两个编码器都通过学习的线性投影将输出映射到共享的512或768维空间,随后进行 \(L_2\) 归一化。

  • CLIP最显著的特性是零样本图像分类。要将图像分类到 \(K\) 个类别之一,你创建 \(K\) 个文本提示如"一张{类别名称}的照片",用文本编码器嵌入每个提示,用图像编码器嵌入图像,然后选择文本嵌入与图像嵌入余弦相似度最高的类别。在ImageNet上,CLIP在从未见过任何ImageNet训练样本的情况下实现了有竞争力的准确率。

ALIGN

  • ALIGN(Jia et al., 2021)将CLIP的方法扩展到更嘈杂、更大的数据集:18亿图像-文本对,几乎无过滤。CLIP精心整理数据,而ALIGN表明规模可以补偿噪声。ALIGN使用EfficientNet图像编码器和BERT文本编码器,并用相同的对比损失训练。关键发现是:有足够数据时,你不需要昂贵的数据清洗:对比目标自然降低噪声对的权重,因为它们产生不一致的梯度。

SigLIP

  • SigLIP(Sigmoid Loss for Language-Image Pre-training, Zhai et al., 2023)用更简单的sigmoid损失替换CLIP基于softmax的对比损失。SigLIP不再将 \(N \times N\) 相似度矩阵视为分类问题(每行是对列的softmax),而是将每个条目独立视为二元分类:这个(图像,文本)对是否匹配?

  • 单个对 \((i, j)\) 的SigLIP损失是:

\[\mathcal{L}_{ij} = -y_{ij} \log \sigma(z_i^{\text{img}} \cdot z_j^{\text{txt}} / \tau) - (1 - y_{ij}) \log(1 - \sigma(z_i^{\text{img}} \cdot z_j^{\text{txt}} / \tau))\]
  • 其中 \(y_{ij} = 1\) 如果 \(i = j\)(匹配),否则 \(y_{ij} = 0\)\(\sigma\) 是sigmoid函数。

  • SigLIP的关键优势是消除了对整个批次进行全局softmax归一化的需求。在CLIP中,softmax分母需要跨所有设备收集所有嵌入,这在分布式训练中是通信瓶颈。SigLIP的逐对sigmoid损失可在本地计算,支持更高效地扩展到超大批次。SigLIP以更低的训练成本达到与CLIP相当的质量。

对比损失函数详解

  • 对比学习中使用的损失函数具有共同结构:它们都试图使正例对的相似度分数高于负例对,并通过某种"边际"或"温度"概念控制模型推动的力度。让我们形式化关键变体。

InfoNCE

  • InfoNCE(Noise-Contrastive Estimation, van den Oord et al., 2018)是CLIP损失的理论基础。给定查询 \(q\)、一个正例键 \(k^+\)\(K\) 个负例键 \(\{k_1^-, \ldots, k_K^-\}\),损失是:
\[\mathcal{L}_{\text{InfoNCE}} = -\log \frac{\exp(q \cdot k^+ / \tau)}{\exp(q \cdot k^+ / \tau) + \sum_{j=1}^{K} \exp(q \cdot k_j^- / \tau)}\]
  • 这是一个 \((K+1)\) 路分类问题:从 \(K+1\) 个候选中识别正例。InfoNCE是查询与正例键之间互信息的下界,这就是为什么最大化它能对齐语义匹配输入的表示。随着负例数量 \(K\) 增加,该界变紧,这解释了对比方法为何受益于大批次。

NT-Xent

  • NT-Xent(Normalised Temperature-scaled Cross-Entropy, Chen et al., 2020)是SimCLR(第8章文件04)中使用的损失,本质上是批次内对称应用的InfoNCE。对于 \(N\) 个对的批次,\(2N\) 个增强视图为每个锚点产生 \(2N - 2\) 个负例(除自身及其正例外的所有视图)。正例对 \((i, j)\) 的损失是:
\[\ell_{i,j} = -\log \frac{\exp(\text{sim}(z_i, z_j) / \tau)}{\sum_{k=1}^{2N} \mathbf{1}_{[k \neq i]} \exp(\text{sim}(z_i, z_k) / \tau)}\]
  • NT-Xent和InfoNCE是相同的数学公式;名称不同是因为它们在不同上下文中引入(自监督视觉与表示学习理论)。

温度的作用

  • 温度 \(\tau\) 是对比学习中最重要的超参数之一。为建立直觉,从物理角度思考温度:高温时分子随机运动(softmax平坦,所有负例看起来同样差);低温时分子凝固成刚性结构(softmax尖锐,只有最难负例重要)。

  • 形式上,当 \(\tau \to 0\),softmax趋近于硬argmax,仅选择单个最难负例。当 \(\tau \to \infty\),所有负例贡献相等。实践中,对于归一化嵌入,\(\tau \in [0.01, 0.1]\) 效果良好。温度过低会导致训练不稳定(难负例的梯度变得非常大);温度过高会使损失对违规不敏感。

  • CLIP初始化 \(\tau = 0.07\) 并将其学习为对数参数化标量 \(\tau = \exp(t)\),其中 \(t\) 与模型权重一起通过梯度下降更新。这允许模型在训练过程中自动调整对比任务的难度。

温度对对比softmax的影响:低温产生聚焦于难负例的尖锐分布,高温产生平坦分布

三元组损失与基于边际的替代方案

  • 在InfoNCE主导之前,三元组损失是度量学习的标准。给定锚点 \(a\)、正例 \(p\) 和负例 \(n\)
\[\mathcal{L}_{\text{triplet}} = \max(0, \|a - p\|^2 - \|a - n\|^2 + m)\]
  • 其中 \(m\) 是边际,确保正例至少比负例近 \(m\)。三元组损失在单个三元组上操作而非批次,使其样本效率低于InfoNCE。它对挖掘策略也很敏感:随机负例往往太容易(损失为零),因此难负例挖掘(选择最接近的错误匹配)或半难挖掘(选择边际内的负例)至关重要。

  • InfoNCE隐式地在整个批次上执行难负例挖掘,这是它在大规模下优于三元组损失的原因之一。InfoNCE中的softmax归一化自动提升与锚点相似度高的难负例的权重,提供自然的课程学习而无需显式挖掘。

图像-文本检索与零样本分类

  • 一旦训练好联合嵌入空间,你就可以执行图像-文本检索:给定图像查询,从数据库中查找最相关的文本(图像到文本检索),或给定文本查询,查找最相关的图像(文本到图像检索)。这仅仅是共享嵌入空间中的最近邻搜索。

  • 想象一位图书管理员可以瞬间比较任何照片与百万项目录中的任何标题。他们无需提前理解每个可能的类别;只需测量每张照片与每个标题的"接近程度"。这就是CLIP风格模型执行检索和零样本分类的方式。

  • 零样本分类是文本到图像检索的特例。给定 \(K\) 个类别名称,你构建文本提示 \(\{t_1, \ldots, t_K\}\)(如"一张猫的照片"、"一张狗的照片")并嵌入它们。对于新图像 \(x\),预测类别是:

\[\hat{y} = \arg\max_{k} \; \text{sim}(f_\theta(x), g_\phi(t_k))\]
  • 关键洞见是文本编码器充当灵活的分类器头。你无需为每个下游任务训练新的线性层,只需用自然语言描述任务。这就是为什么CLIP泛化能力如此之强:文本编码器在预训练期间见过数百万种多样化描述。

  • 提示工程很重要。仅通过将提示模板从"{类别名称}"改为"一张{类别名称}的照片",CLIP在ImageNet上的零样本准确率就从63.2%提升到68.4%。更好的是,提示集成平均多个模板的文本嵌入(如"一张{类别名称}的照片"、"一张好的{类别名称}照片"、"一幅{类别名称}的画")以产生更鲁棒的文本表示。

零样本分类:每个类别的文本提示与图像一起嵌入,选择余弦相似度最高的类别

音频-视觉对应

  • 闭上眼睛听某人拍篮球。你可以从有节奏的砰砰声中判断球何时落地。现在睁开眼睛:视觉上的弹跳与每次砰声完美对齐。音频与视觉事件之间的这种紧密对应是机器可以学习的免费监督信号。音频-视觉对应学习训练模型在无人工标注的情况下将声音与其视觉源关联。

  • 这个想法与CLIP惊人地相似,但用音频替换文本。给定配对的视频帧和音频片段,模型学习一个嵌入空间,其中时间对齐的音频-视觉对接近,未对齐的对远离。

  • 音频-视觉嵌入(AVE)方法(Arandjelovic and Zisserman, 2017)用对比损失在视频数据上训练视觉编码器 \(f\) 和音频编码器 \(g\)。正例对是(视频帧,同一时间的音频片段),负例是来自不同视频或不同时间的音频片段。模型学会吠叫声与狗的图像关联、吉他声与吉他的图像关联,全部无需标注。

  • 音频编码器通常使用CNN或音频Transformer处理对数梅尔谱图(第9章文件01),生成固定大小嵌入。视觉编码器使用标准图像骨干网络(ResNet、ViT)处理视频帧。两者都投影到共享的 \(d\) 维空间,训练使用与CLIP相同的InfoNCE损失:

\[\mathcal{L}_{\text{AV}} = -\log \frac{\exp(\text{sim}(z^{\text{vis}}, z^{\text{aud}}) / \tau)}{\sum_{k=1}^{N} \exp(\text{sim}(z^{\text{vis}}, z_k^{\text{aud}}) / \tau)}\]

音频-视觉对应:视觉编码器处理视频帧,音频编码器处理谱图,对比学习对齐时间匹配的对

  • 应用包括:声源定位(图像中声音来自何处?)、音频-视觉语音识别(结合唇动与音频,如第9章文件02)、音频-视觉源分离(通过观看说话者面部隔离其声音,第9章文件05的"鸡尾酒会"问题),以及以音频为条件的视频生成。

  • ImageBind(Girdhar et al., 2023)将此扩展到六种模态:图像、文本、音频、深度、热成像和IMU数据。关键洞见是你不需要每种组合的配对数据。通过将每种模态与图像对齐(用图像-文本对对齐文本,用图像-音频对对齐音频等),所有模态通过共享图像嵌入空间隐式对齐。这种通过共同锚定模态的"绑定"产生涌现对齐:音频和文本变得相似,尽管它们从未直接一起训练。

评估

  • 评估多模态模型需要能捕获跨模态理解的指标。两种主导评估范式是零样本基准检索指标

零样本基准

  • 零样本评估衡量模型是否能执行从未显式训练过的任务。最常见的基准是ImageNet零样本准确率:将所有1000个ImageNet类别名称作为文本嵌入,嵌入每个测试图像,并基于余弦相似度测量top-1和top-5分类准确率。CLIP ViT-L/14零样本实现75.5% top-1准确率,与在ImageNet上监督训练的ResNet-50相当。

  • 其他零样本基准包括:CIFAR-10/100、STL-10、Food-101、Oxford Pets和Flowers-102。在多个数据集上评估测试模型是否具有真正的通用视觉理解,还是仅仅记忆了预训练数据中的模式。

  • 线性探测评估是互补测试。你冻结预训练图像编码器,为标注数据集提取特征,然后在顶部训练简单线性分类器。这独立于零样本检索机制,衡量学习表示的质量。CLIP的特征是优秀的线性探测特征,通常匹配或超越监督预训练。

检索指标

  • 对于检索任务(图像到文本和文本到图像),标准指标是Recall@K(R@K):正确匹配出现在前 \(K\) 个检索结果中的查询比例。常用值为R@1、R@5和R@10。

  • 形式上,对于 \(Q\) 个查询的集合:

\[\text{R@}K = \frac{1}{Q} \sum_{q=1}^{Q} \mathbf{1}[\text{rank}(q) \leq K]\]
  • 其中 \(\text{rank}(q)\) 是查询 \(q\) 的排序检索列表中正确匹配的位置。

  • 标准检索基准包括Flickr30K(31,000张图像,每张5个标题)和MS-COCO(123,000张图像,每张5个标题)。评估在测试集上进行:给定图像,从完整测试集中检索正确标题,反之亦然。

  • 中位排名(MedR)是互补指标:所有查询中正确匹配位置的中位数。完美模型的MedR = 1。越低越好。

  • 除检索外,多模态模型还在组合理解基准上评估,如Winoground(测试模型能否区分"杯子里的狗"与"狗里的杯子")和ARO(属性、关系、顺序),测试模型是否真正理解语言结构,还是仅仅匹配词袋。CLIP风格模型在这些基准上往往表现不佳,揭示了一个根本局限:对比预训练对齐全局语义,但可能无法捕获细粒度组合结构。

检索评估:给定查询图像,模型按相似度对所有文本候选排序,Recall@K衡量正确标题是否出现在前K结果中

综合总结

  • 本文涵盖的多模态表示构成了本章后续内容的基础。由CLIP及其后继者训练的联合嵌入空间是连接视觉与语言的"胶水"。文件02在此基础上构建超越检索的视觉-语言模型,生成关于图像的文本。文件03探索图像和视频如何被标记化以用于序列模型。文件04涵盖跨模态生成(文本到图像、文本到视频)。文件05研究在单个模型内处理多种模态的统一架构。

  • 核心要点:在配对数据上的对比学习产生嵌入空间,其中不同模态可互换。图像嵌入和文本嵌入变成"同一种东西",支持零样本分类、检索,并无缝集成到更大系统中。这一思想的简洁性——只需将匹配对拉近、不匹配对推远——与其非凡的有效性形成鲜明对比。

编程任务(使用CoLab或Notebook)

  1. 从头实现CLIP对比损失。创建随机图像和文本嵌入,计算相似度矩阵,并计算对称交叉熵损失。

    import jax
    import jax.numpy as jnp
    import matplotlib.pyplot as plt
    
    def clip_loss(image_embeds, text_embeds, temperature=0.07):
        """计算对称CLIP对比损失."""
        # L2归一化嵌入
        image_embeds = image_embeds / jnp.linalg.norm(image_embeds, axis=1, keepdims=True)
        text_embeds = text_embeds / jnp.linalg.norm(text_embeds, axis=1, keepdims=True)
    
        # 计算余弦相似度矩阵 (N x N)
        logits = image_embeds @ text_embeds.T / temperature  # (N, N)
    
        # 标签: 对角线 (第i张图像匹配第i个文本)
        N = logits.shape[0]
        labels = jnp.arange(N)
    
        # 对称交叉熵: 图像到文本 + 文本到图像
        loss_i2t = -jnp.mean(jax.nn.log_softmax(logits, axis=1)[jnp.arange(N), labels])
        loss_t2i = -jnp.mean(jax.nn.log_softmax(logits, axis=0)[labels, jnp.arange(N)])
        return (loss_i2t + loss_t2i) / 2, logits * temperature
    
    # 模拟一批8个图像-文本对,64维空间
    key = jax.random.PRNGKey(42)
    k1, k2 = jax.random.split(key)
    N, D = 8, 64
    image_embeds = jax.random.normal(k1, (N, D))
    text_embeds = jax.random.normal(k2, (N, D))
    
    loss, sim_matrix = clip_loss(image_embeds, text_embeds)
    print(f"CLIP损失(随机嵌入): {loss:.4f}")
    
    # 可视化相似度矩阵
    fig, ax = plt.subplots(figsize=(6, 5))
    im = ax.imshow(sim_matrix, cmap='coolwarm', vmin=-1, vmax=1)
    ax.set_xlabel("文本索引"); ax.set_ylabel("图像索引")
    ax.set_title(f"余弦相似度矩阵(损失={loss:.3f})")
    plt.colorbar(im); plt.tight_layout(); plt.show()
    # 尝试改变温度(0.01, 0.1, 1.0)并观察损失变化
    # 尝试使匹配对相似: 设置 text_embeds = image_embeds + 小噪声
    

  2. 构建玩具联合嵌入模型,使用InfoNCE损失和梯度下降学习对齐2D"图像"(随机向量)与"标题"(不同随机向量)。

    import jax
    import jax.numpy as jnp
    import matplotlib.pyplot as plt
    
    def info_nce_loss(img_enc, txt_enc, img_data, txt_data, tau=0.1):
        """一批配对(图像,文本)数据上的InfoNCE."""
        z_img = img_data @ img_enc  # (N, D)
        z_txt = txt_data @ txt_enc  # (N, D)
        # L2归一化
        z_img = z_img / jnp.linalg.norm(z_img, axis=1, keepdims=True)
        z_txt = z_txt / jnp.linalg.norm(z_txt, axis=1, keepdims=True)
        logits = z_img @ z_txt.T / tau
        labels = jnp.arange(logits.shape[0])
        return -jnp.mean(jax.nn.log_softmax(logits, axis=1)[jnp.arange(len(labels)), labels])
    
    # 创建32个配对样本: 图像在R^8, 文本在R^6, 嵌入到R^4
    key = jax.random.PRNGKey(0)
    k1, k2, k3, k4 = jax.random.split(key, 4)
    N, d_img, d_txt, d_embed = 32, 8, 6, 4
    
    img_data = jax.random.normal(k1, (N, d_img))
    txt_data = jax.random.normal(k2, (N, d_txt))
    
    # 可学习投影矩阵
    img_enc = jax.random.normal(k3, (d_img, d_embed)) * 0.1
    txt_enc = jax.random.normal(k4, (d_txt, d_embed)) * 0.1
    
    grad_fn = jax.jit(jax.grad(info_nce_loss, argnums=(0, 1)))
    lr = 0.05
    losses = []
    
    for step in range(300):
        loss = info_nce_loss(img_enc, txt_enc, img_data, txt_data)
        losses.append(float(loss))
        g_img, g_txt = grad_fn(img_enc, txt_enc, img_data, txt_data)
        img_enc = img_enc - lr * g_img
        txt_enc = txt_enc - lr * g_txt
    
    print(f"初始损失: {losses[0]:.3f}, 最终损失: {losses[-1]:.3f}")
    print(f"随机基线(log N): {jnp.log(N):.3f}")
    
    plt.figure(figsize=(8, 4))
    plt.plot(losses, color='#2c3e50')
    plt.axhline(y=0, color='green', linestyle='--', alpha=0.5, label='完美对齐')
    plt.axhline(y=float(jnp.log(N)), color='red', linestyle='--', alpha=0.5, label='随机(log N)')
    plt.xlabel("步骤"); plt.ylabel("InfoNCE损失")
    plt.title("学习联合嵌入空间")
    plt.legend(); plt.grid(alpha=0.3); plt.tight_layout(); plt.show()
    # 修改d_embed(尝试2, 4, 16)观察嵌入维度如何影响对齐
    

  3. 使用预计算嵌入实现零样本分类。将类别"原型"模拟为文本嵌入,并通过最近邻查找分类新图像。

    import jax
    import jax.numpy as jnp
    import matplotlib.pyplot as plt
    
    # 模拟5个类别,每个类别在R^32中有原型文本嵌入
    key = jax.random.PRNGKey(42)
    n_classes, d = 5, 32
    class_names = ["cat", "dog", "car", "plane", "ship"]
    
    # 类别原型(想象这些来自文本编码器)
    k1, k2 = jax.random.split(key)
    class_prototypes = jax.random.normal(k1, (n_classes, d))
    class_prototypes = class_prototypes / jnp.linalg.norm(class_prototypes, axis=1, keepdims=True)
    
    # 生成200个测试"图像"(嵌入靠近其类别原型 + 噪声)
    n_per_class = 40
    true_labels = jnp.repeat(jnp.arange(n_classes), n_per_class)
    keys = jax.random.split(k2, n_classes * n_per_class)
    
    image_embeds = []
    for i in range(n_classes):
        noise = jax.random.normal(keys[i], (n_per_class, d)) * 0.5
        cluster = class_prototypes[i] + noise
        image_embeds.append(cluster)
    image_embeds = jnp.concatenate(image_embeds, axis=0)
    image_embeds = image_embeds / jnp.linalg.norm(image_embeds, axis=1, keepdims=True)
    
    # 零样本分类: 与每个原型的余弦相似度
    similarities = image_embeds @ class_prototypes.T  # (200, 5)
    predicted_labels = jnp.argmax(similarities, axis=1)
    accuracy = jnp.mean(predicted_labels == true_labels)
    print(f"零样本准确率: {accuracy:.1%}")
    
    # 混淆矩阵
    conf = jnp.zeros((n_classes, n_classes), dtype=jnp.int32)
    for true, pred in zip(true_labels, predicted_labels):
        conf = conf.at[true, pred].add(1)
    
    fig, ax = plt.subplots(figsize=(6, 5))
    im = ax.imshow(conf, cmap='Blues')
    ax.set_xticks(range(n_classes)); ax.set_xticklabels(class_names, rotation=45)
    ax.set_yticks(range(n_classes)); ax.set_yticklabels(class_names)
    ax.set_xlabel("预测"); ax.set_ylabel("真实")
    for i in range(n_classes):
        for j in range(n_classes):
            ax.text(j, i, int(conf[i, j]), ha='center', va='center', fontsize=11)
    ax.set_title(f"零样本混淆矩阵(准确率={accuracy:.1%})")
    plt.colorbar(im); plt.tight_layout(); plt.show()
    # 尝试增加噪声(0.5 -> 1.0 -> 2.0)观察准确率下降
    # 尝试添加提示集成: 平均每个原型的3个噪声副本