统一多模态架构
统一多模态架构¶
统一多模态架构用一个单一系统替代了多个专用模型,该系统能够跨文本、图像、音频和视频进行读取、推理和生成。本文涵盖任意到任意模型(CoDi、NExT-GPT)、原生多模态大语言模型(Gemini、GPT-4o)、多模态标记化策略以及统一化的架构权衡。
统一化的理由¶
-
想象一个会说五种语言的翻译,他能在一句话中间毫无停顿地在不同语言间切换。早期的多模态系统更像是五个不同的翻译坐在不同的房间里,每个人处理一种语言,通过墙上的一个槽传递纸条。统一多模态架构就是那个单一的多语者:一个具有共享权重的模型,可以在一次前向传播中跨文本、图像、音频、视频甚至动作进行读取、写入和推理。
-
其动机既有实践性也有理论性。从实践角度看,为每种模态对(文本到图像、图像到文本、音频到文本等)维护单独的专用模型会导致组合爆炸:\(k\) 种模态需要多达 \(k(k-1)\) 个有向管道。一个统一模型将所有这些整合为一个系统。从理论角度看,人类认知并不是在隔离的模块中处理视觉和语言;跨模态绑定发生得很早且很深,统一化试图模仿这一点。
-
共享权重鼓励跨模态迁移。一个已经学习了文本中时间模式(主语在动词前,原因在结果前)的Transformer,可以将同样的注意力回路重新用于视频中的时间模式(物体在移动前出现)或音频中的时间模式(起音在持续音之前)。这类似于你在第07章的语言模型微调和第08章的ImageNet预训练中看到的迁移学习的多模态版本。
-
形式化地说,设 \(\mathcal{M} = \{m_1, m_2, \ldots, m_k\}\) 为一组模态。一个统一模型定义了一个单一参数化函数 \(f_\theta\),它将任意输入模态子集映射到任意输出模态子集:
- 其中 \(\mathcal{P}(\mathcal{M})\) 是模态的幂集(所有子集)。关键的约束是 \(\theta\) 大部分是共享的;只有少量的、针对特定模态的适配器层有所不同。
- 统一化的前景伴随着一个基本矛盾:模态在结构上是不同的。文本是离散词元的一维序列。图像是连续像素值的二维网格。音频是一维连续波形,其时间尺度与文本非常不同。视频在图像基础上增加了时间轴。将这些不同的结构整合成一个Transformer能够消化的单一序列是这一领域的核心工程挑战。
任意到任意模型¶
-
想象一个通用遥控器,可以通过同一个界面操作你的电视、空调和音响系统。任意到任意模型就是AI中的等价物:它们接受任何模态组合作为输入,并产生任何组合作为输出。
-
CoDi(Composable Diffusion)通过训练特定模态的扩散模型,然后通过共享的条件化机制对齐它们的潜在空间来实现任意到任意生成。每种模态都有自己的扩散过程(回顾本章第04章中的扩散模型),但噪声预测网络以一个联合交叉注意力层为条件,该层同时看到所有输入模态的嵌入。这使得CoDi能够,例如,从文本提示中一次性生成一张图像和匹配的音频。
-
NExT-GPT采取了不同的架构方法。它通过轻量级投影层将一个LLM骨干网络(“大脑”)连接到输入侧的特定模态编码器和输出侧的特定模态解码器。输入编码器(例如,来自CLIP的图像编码器,来自CLAP的音频编码器)将每种模态转换到LLM的嵌入空间。LLM对组合的词元序列进行推理,并发出特殊的“模态信号词元”,将信息路由到适当的解码器(例如,用于图像的Stable Diffusion,用于音频的AudioLDM)。只有投影层被训练;LLM和专用编码器/解码器保持冻结。
-
Gemini(Google DeepMind)从预训练开始就是原生的多模态模型。与NExT-GPT的即插即用方法不同,Gemini的Transformer是在文本、图像、音频和视频词元的交错序列上从头开始训练的。这意味着跨模态注意力模式是在预训练期间有机发展的,而不是事后附加的。该模型对文本使用SentencePiece分词器,并学习类似于本章第03章讨论的VQ方法的视觉分词器。
-
GPT-4o(“o”代表“omni”,即全功能)代表了另一种模式:一个端到端的模型,其中所有模态共享同一个Transformer和同一个下一个词元预测目标。音频输入被处理为频谱词元,图像为图像块词元,文本为子词词元,全部输入到一个序列中。该模型生成输出词元,并由特定模态的头进行解码。关键的创新在于,通过移除早期系统(如GPT-4V)所依赖的独立的ASR、LLM和TTS模型的级联,实现了低延迟。
-
这些模型处于一个整合深度的谱系上:
- 浅层整合(NExT-GPT):冻结的专家通过训练好的适配器连接。构建快速,跨模态推理能力有限。
- 中层整合(CoDi):跨特定模态生成器的共享条件化。对齐更好,但仍然模块化。
- 深层整合(Gemini、GPT-4o):在所有模态上端到端训练的单一模型。跨模态推理能力最丰富,但训练成本最高。
具有共享骨干网络的模态专用编码器和解码器¶
-
想象一个拥有单条装配线(共享骨干网络)但有不同的原材料装卸区(编码器)和成品发货区(解码器)的工厂。每个装卸区都针对其货物进行了专门化设计,但一旦进入工厂内部,所有东西都沿着同一条传送带移动。
-
统一模型的主导架构模式使用这种三部分结构:
- 模态编码器 \(E_m\),将来自模态 \(m\) 的原始输入转换为嵌入向量序列 \(\mathbf{h}_1^m, \mathbf{h}_2^m, \ldots, \mathbf{h}_{n_m}^m\),每个向量的维度为 \(d\)。
- 一个共享的Transformer骨干网络 \(T_\theta\),使用自注意力处理来自所有输入模态的拼接或交错嵌入。
- 模态解码器 \(D_m\),将骨干网络的输出嵌入转换回模态 \(m\) 的原始格式(文本词元、图像像素、音频波形)。
-
对于文本,编码器通常是一个嵌入查找表 \(E_\text{text}(w) = \mathbf{W}_e[w]\),其中 \(w\) 是词元索引,与你之前在第07章看到的Transformer相同。对于图像,编码器通常是一个视觉Transformer(ViT),它将图像分割成图像块并线性投影每个图像块,如第08章所述。对于音频,编码器计算梅尔频谱图,并使用卷积前端或音频频谱图Transformer(AST)处理它,如第09章所述。
-
共享骨干网络是一个标准的Transformer,它在所有模态词元上进行自注意力。给定一个拼接的输入序列 \(\mathbf{H} = [\mathbf{h}_1^{m_1}, \ldots, \mathbf{h}_{n_1}^{m_1}, \mathbf{h}_1^{m_2}, \ldots, \mathbf{h}_{n_2}^{m_2}]\),自注意力允许每个词元关注其他任何词元,而不管它们的模态:
-
这与第07章的注意力公式相同,但现在 \(\mathbf{Q}\)、\(\mathbf{K}\) 和 \(\mathbf{V}\) 包含来自多种模态的词元。一个图像块词元可以关注一个文本词元,从而实现跨模态推理,而无需任何额外的交叉注意力模块。
-
模态嵌入被添加到每个词元,以便骨干网络知道词元来自哪种模态。这类似于位置嵌入,但编码的是模态身份而非序列位置。一个可学习的向量 \(\mathbf{e}_m \in \mathbb{R}^d\) 被添加到来自模态 \(m\) 的每个词元:
- 其中 \(\mathbf{p}_i\) 是位置 \(i\) 的位置嵌入。
多模态标记化¶
-
想象你在写一封信,信中既有英文文本,也有手绘草图。你可能写一个句子,画一个图表,再写一个引用该图表的句子,然后贴上一段乐谱。这封信是一个单一的线性流,交织着不同的“模态”。多模态标记化正是做这件事:它将文本、图像、音频和视频转换为一个单一的扁平词元序列,供Transformer从左到右处理。
-
对于文本,标记化已经非常成熟:字节对编码(BPE)或SentencePiece产生一个子词词元词汇表,如第07章所述。挑战在于将这个想法扩展到连续的模态。
-
对于图像,有两种主要方法。离散方法使用VQ-VAE或VQ-GAN(本章第03章详述)将每个图像映射到码本索引序列。如果码本有 \(|\mathcal{C}|\) 个条目,并且一个图像被编码为 \(n\) 个码,则该图像变成了从大小为 \(|\mathcal{C}|\) 的词汇表中抽取的 \(n\) 个离散词元,直接与文本词汇表兼容。连续方法使用ViT或CNN编码器产生 \(n\) 个连续嵌入向量,这些向量被线性投影到Transformer的嵌入维度。Gemini和GPT-4o使用连续方法的变体;像Parti和LlamaGen这样的自回归图像生成器更倾向于离散路径。
-
对于音频,信号通常被转换为梅尔频谱图,然后要么用神经音频编解码器(例如,EnCodec,SoundStream,它们产生层次化的离散词元)进行离散化,要么通过一个学习到的编码器连续投影。例如,AudioLM将音频表示为来自多个码本级别的离散词元序列,然后以自回归方式建模它们。
-
对于视频,标记化建立在图像标记化的基础上,但还必须压缩时间维度。一种常见的策略使用3D VQ-VAE(如VideoGPT或来自第03章的Cosmos Tokeniser),它对时空图像块进行量化,生成离散词元。时间压缩因子至关重要:如果没有激进的时序下采样,24 fps的原始视频每秒会产生太多的词元。
-
一旦所有模态都被标记化,它们就会交错成一个单一的序列,并使用特殊的分隔词元标记模态边界。一个典型的格式看起来像:
[TEXT] 猫坐在垫子上 [/TEXT] [IMAGE] <img_tok_1> <img_tok_2> ... <img_tok_n> [/IMAGE] [AUDIO] <aud_tok_1> ... <aud_tok_m> [/AUDIO]
- 然后,Transformer使用其标准的因果(或双向)注意力机制处理这个完整的混合序列。模态分隔词元有两重作用:它们告知模型模态边界,并充当“汇聚点”,其表示总结了每个模态片段。
- 一个关键的设计选择是词元预算。一个被标记化为256个词元的图像和一个50个词元的文本描述意味着图像消耗的上下文窗口是文本的5倍多。模型必须在分辨率(更多词元 = 更多细节)和上下文长度(更多词元 = 更高的内存和计算成本)之间取得平衡。像词元合并(逐步合并相似的词元)和自适应标记化(对简单区域使用更少的词元,对复杂区域使用更多的词元)等技术有助于管理这种权衡。
训练方案:分阶段预训练与联合微调¶
-
你不会在孩子学算术之前教他微积分。同样,你不可能在随机初始化的情况下同时在所有模态上训练一个统一的多模态模型,并期望它能很好地收敛。主导方法是分阶段训练,模型在精心排序的阶段中逐步学习越来越复杂的跨模态能力。
-
阶段1:单模态预训练。 每个模态编码器在大型单模态数据集上独立训练。文本骨干网络在数万亿个文本词元上使用标准的语言建模目标(下一个词元预测)进行预训练,正如第07章所述。视觉编码器在图像分类或自监督目标(MAE,DINO)上预训练,如第08章所述。音频编码器在语音识别或音频分类数据上预训练,如第09章所述。这个阶段产生强大的单模态特征提取器。
-
阶段2:跨模态对齐。 预训练的编码器连接到共享骨干网络,模型在配对的多模态数据(图像-描述对,音频-转录对)上使用对比或生成目标进行训练。在此阶段,编码器权重可能被冻结(以保留单模态知识),而只有投影层和骨干网络被更新。这是本章第01章中的CLIP式对齐被整合到统一模型中的阶段。
-
阶段3:联合多模态预训练。 所有参数(或大部分参数)被解冻,模型在单模态和多模态数据的混合上训练,使用跨所有模态词元的单一“下一个词元预测”目标。损失函数为:
-
其中 \(x_t\) 可以是文本词元、图像词元或音频词元。模型必须学会预测下一个词元,无论其模态如何,这迫使它发展出真正的跨模态理解。
-
阶段4:指令微调与对齐。 预训练模型在经过整理的多模态指令跟随数据集(例如,“详细描述这张图像”,“这段视频发出什么声音?”,“生成一张X的图像”)上进行微调。此阶段通常使用基于人类反馈的强化学习(RLHF)或直接偏好优化(DPO)来使模型的输出与人类偏好对齐。
-
模态特定热身是在阶段内使用的一种技术,用于防止模态坍塌。如果一种模态(通常是文本,因为它拥有最多的训练数据)主导了梯度信号,模型可能会“遗忘”较弱的模态。热身策略包括:
- 梯度平衡:缩放来自每种模态的梯度,使它们对参数更新有相等的贡献。
- 数据比例调度:逐步增加多模态数据相对于单模态数据的比例。
- 损失加权:分配特定模态的权重 \(\lambda_m\),使得总损失为 \(\mathcal{L} = \sum_m \lambda_m \mathcal{L}_m\),其中 \(\lambda_m\) 被调整以平衡各模态的学习率。
- 为什么不跳过阶段? 从头开始联合训练所有内容是诱人的,但由于几个原因在实践中会失败。首先,模型必须同时学习低级特征(边缘检测、音素识别)和高层次跨模态推理,这两者具有非常不同的学习动态。其次,模态间的数据分布极不平衡(数万亿文本词元 vs 数十亿图像词元 vs 数亿音频片段)。第三,优化景观是高度非凸的,分阶段训练提供了一个课程,指导模型走向一个更好的盆地,类似于第06章的课程学习思想。
多模态思维链推理¶
-
当你解决一个几何问题时,你可能会画一个草图,标记角度,写出方程,然后一步步求解。你不会直接从问题陈述跳到答案。多模态思维链(CoT)推理使模型能够做同样的事情:在得出最终答案之前,生成可能涉及文本、视觉注释甚至生成图表的中间推理步骤。
-
在纯文本CoT中(如第07章关于提示策略的讨论所述),模型以自然语言生成一系列推理步骤。多模态CoT通过允许中间步骤引用或生成视觉内容来扩展这一点。例如,给定一个图表图像和问题“哪一年的销售额最高?”,一个多模态CoT模型可能首先描述图表(“图表显示了2018年至2023年的销售额...”),然后识别相关的视觉特征(“最高的柱子出现在2021年...”),最后输出答案(“2021年”)。
-
形式化地说,设 \(\mathbf{x}\) 为多模态输入,\(y\) 为目标答案。标准预测模型直接建模 \(p(y \mid \mathbf{x})\)。思维链引入中间推理 \(\mathbf{r} = (r_1, r_2, \ldots, r_L)\),并将预测分解为:
-
在实践中,求和通过对推理链的贪心搜索或波束搜索解码来近似。推理步骤 \(r_i\) 可以是文本词元、对图像区域的引用,甚至是生成的视觉词元(例如,叠加在输入图像上的边界框注释)。
-
训练多模态CoT通常涉及整理数据集,其中人类标注者提供逐步的多模态推理轨迹,然后在这些轨迹上微调模型。一些方法从更大的教师模型中蒸馏CoT能力:教师模型为大型数据集生成推理轨迹,然后较小的学生模型在输入和教师模型的轨迹上进行训练。
-
多模态CoT对于需要空间推理(例如,“红色球在蓝色立方体的左边吗?”)、对图表的数学推理(例如,几何问题)以及多步视觉问答(答案依赖于组合图像多个区域的信息)的任务尤其强大。
多模态智能体¶
-
想象一个厨房里的机器人厨师。它看着台面上的食材(视觉),在平板上阅读食谱(文本),听着计时器蜂鸣(音频),然后物理上拿起一把刀切洋葱(动作)。多模态智能体就是它的数字版本:一个通过多种模态感知世界、推理该做什么、并采取基于其感知的行动的模型。
-
智能体的循环遵循经典的观察-推理-行动周期:
- 观察:智能体从其环境接收多模态输入(屏幕截图、用户的口头指令、视频流)。
- 推理:统一模型处理多模态输入,可能使用思维链来规划一系列步骤。
- 行动:模型输出一个动作(文本响应、工具调用、在坐标 \((x, y)\) 处的鼠标点击、机器人电机命令)。
-
工具使用是多模态智能体的一个关键能力。模型被训练成能够识别何时无法直接回答问题,而必须调用外部工具:计算器、代码解释器、网络浏览器或搜索引擎。模型在其输出词元序列中生成结构化的工具调用(例如,
search("伦敦当前天气")),系统执行该调用,并将结果作为额外的输入词元反馈给模型处理。 -
视觉定位将语言连接到图像或视频中的特定区域。当一个智能体说“点击右上角的蓝色按钮”时,它必须将短语“右上角的蓝色按钮”连接到像素坐标。在架构上,这是通过训练模型输出作为特殊词元的边界框坐标,或者让模型在图像上生成一个指示所指区域的热图来实现的。这将本章第02章(视觉语言模型)讨论的定位和指代工作扩展到行动领域。
-
网络智能体,如WebVoyager和SeeAct,展示了多模态智能体在网站上导航的能力。智能体接收网页的屏幕截图,识别交互元素(按钮、文本字段、链接),并输出动作(点击、输入、滚动)以完成用户指定的目标。关键的挑战在于巨大的行动空间:一个典型的网页有数百个可能的点击目标。
-
具身智能体将此扩展到物理环境。一个带有摄像头和麦克风的机器人接收视觉和音频输入,通过统一模型处理它们,并输出电机命令。像PaLM-E(Google)这样的项目将机器人传感器数据直接嵌入语言模型的词元序列中,使机器人能够通过将指令定位在其视觉观察中并生成一系列电机动作来遵循诸如“拿起碗附近的绿色积木”的指令。
-
智能体的训练方案在标准的分阶段预训练之上增加了一个强化学习(RL)阶段。智能体与环境(模拟桌面、网络浏览器、机器人模拟器)交互,获得任务完成的奖励,并使用像PPO或REINFORCE这样的算法更新其策略。奖励信号通常是稀疏的(任务成功为1,否则为0),这使得优化具有挑战性,并且严重依赖于来自多模态预训练的强先验。
基准测试与评估¶
-
评估一个能够看到、听到、阅读和行动的模型需要一套多样化的基准测试。没有单一的指标能够捕捉多模态能力,因此该领域依赖于一系列专门的评估方法。
-
MMLU(Massive Multitask Language Understanding)测试了57个学科领域的知识。虽然最初只用于文本,但它作为一个基线:一个统一的多模态模型在获得视觉能力时,不应损失纯文本的性能。多模态训练后MMLU的下降标志着灾难性遗忘。
-
MMBench 跨20个细粒度的能力维度评估视觉语言理解,包括属性识别、空间关系理解和OCR。每个问题呈现一张图像和一个多项选择题。该基准系统地测试模型是真正理解了图像,还是依赖于纯文本的捷径。
-
SEED-Bench 提供了19,000个多项选择题,涵盖图像和视频理解的12个评估维度。它专门测试时间理解(在给定帧之前/之后发生了什么)和组合推理(组合多个视觉属性)。
-
MM-Vet 通过要求模型在单个问题中同时使用多种技能(识别、OCR、空间意识、语言生成和知识检索)来评估综合多模态能力。
-
MathVista 测试对视觉输入的数学推理:几何图、统计图表、函数图和科学图形。这个基准专门针对多模态思维链能力。
-
视听基准,如AVQA(音频-视觉问答),测试模型是否能推理所见与所听之间的关系。例如:“说话的人是左边的还是右边的?”
-
智能体基准,如WebArena、OSWorld和SWE-bench,在交互式环境中评估任务完成情况。指标通常是成功率:智能体正确完成任务的百分比?这些基准特别具有挑战性,因为它们需要长程规划和错误恢复。
-
整体评估框架,如LMSYS Chatbot Arena,以头对头的方式使用人类偏好判断。两个模型被展示相同的多模态输入,人类评判者选择哪个响应更好。Elo评分从数千次这样的比较中计算出来,提供一个与整体模型质量高度相关的单一标量。
-
多模态评估中的一个持续挑战是数据污染:因为这些模型是在互联网规模的数据上训练的,基准测试中的图像和问题可能出现在训练集中。仔细的去重和创建保留的测试集是必要但不完美的保障措施。
世界模型¶
-
想象一下,闭上眼睛,想象如果你把杯子推下桌子边缘会发生什么。你“看到”它掉下来,“听到”碎裂声,并“感觉”到这将是一个坏主意。你的大脑正在运行一个世界模型:一个对环境的物理和因果结构的内在模拟,可以预测跨多种模态的未来状态。
-
在AI领域,世界模型是一个学习到的函数,它在给定当前状态和动作的情况下预测世界的下一个状态:
-
其中 \(s_t\) 是当前状态表示(可能包含视觉、听觉和本体感觉信息),\(a_t\) 是一个动作,\(\hat{s}_{t+1}\) 是预测的下一个状态。状态 \(s_t\) 存在于一个学习到的潜在空间中,而不是原始像素空间,这使得预测问题变得可行。
-
视频预测模型,如Sora(OpenAI)和Genie(Google DeepMind),代表了向世界模型迈出的重要一步。它们学会在以文本提示和/或动作序列为条件下生成时间上连贯的视频帧。虽然它们常被讨论为视频生成器,但其底层能力更接近于世界模拟:模型已经内化了足够的物理知识(重力、碰撞、遮挡、流体动力学)来渲染合理的未来。
-
这与多模态架构的联系很深。一个只预测像素的世界模型是有限的;一个真正有用的世界模型会跨模态进行预测。如果你推杯子,世界模型应该预测视觉轨迹(杯子下落)、听觉事件(杯子破碎)和语义后果(现在地板上有碎玻璃)。统一的多模态架构是世界模型的自然候选者,因为它们已经在一个共享空间中表示所有模态。
-
形式化地说,一个多模态世界模型优化:
- 其中 \(s_{t+1}^m\) 是模态 \(m\) 的真实下一状态表示,\(g_\phi^m\) 是世界模型针对模态 \(m\) 的预测头。共享的潜在动态 \(g_\phi\) 在联合多模态空间中运行,而特定模态的头将预测解码到每种模态的原始格式。
- JEPA(联合嵌入预测架构)由Yann LeCun提出,提供了一个世界模型的框架,避免了像素级预测的陷阱。JEPA不是在像素空间预测(这会在无关的细节如精确纹理上浪费容量),而是在嵌入空间中进行预测。该模型学习一个将观测映射到嵌入的编码器,以及一个预测未来嵌入的预测器:
-
损失比较的是嵌入而不是原始观测,这对于感知混叠(许多不同的像素配置可能代表相同的语义状态)更加鲁棒。这种方法对于多模态世界模型尤其有希望,因为它自然地运行在统一架构已经提供的共享嵌入空间中。
-
世界模型除了学术兴趣外还有实际应用。在基于模型的强化学习中,智能体使用其世界模型在采取行动前“想象”行动的后果,从而显著减少所需的真实世界交互次数(回顾第11章关于基于模型的RL的讨论)。在自动驾驶中,世界模型预测给定不同转向决策后场景将在未来几秒内如何演变。在机器人学中,世界模型允许机器人在执行操作序列之前进行心理演练。
-
世界模型研究的前沿正朝着交互式世界模型发展,这些模型实时运行并对任意用户动作做出反应,本质上成为完全从数据中学习的通用模拟器。Genie 2(Google DeepMind)在3D环境中展示了这一点:给定一张图像,它生成一个可由用户探索的、可交互、可控的3D世界。世界模型与统一多模态架构的融合暗示了一个未来,其中单一模型可以跨所有模态进行感知、预测、模拟和行动。
编程任务(使用CoLab或notebook)¶
任务1:构建一个最小的多模态词元交织器
- 编写一个函数,接收一个文本字符串和一个虚拟的“图像”(一个小型2D数组),并将它们的标记化表示交织成一个带有模态嵌入的单一扁平序列。
import jax
import jax.numpy as jnp
# 模拟多模态标记化:文本词元 + "图像块" 词元
def interleave_modalities(text_tokens, image_patches, embed_dim=32, key=jax.random.PRNGKey(0)):
"""使用学习到的模态嵌入来交织文本和图像词元。"""
k1, k2, k3 = jax.random.split(key, 3)
n_text = text_tokens.shape[0]
n_img = image_patches.shape[0]
# 随机投影矩阵(替代真实编码器)
W_text = jax.random.normal(k1, (text_tokens.shape[-1], embed_dim)) * 0.02
W_img = jax.random.normal(k2, (image_patches.shape[-1], embed_dim)) * 0.02
# 模态嵌入:一个用于文本,一个用于图像
mod_emb = jax.random.normal(k3, (2, embed_dim)) * 0.02
text_embs = text_tokens @ W_text + mod_emb[0] # (n_text, embed_dim)
img_embs = image_patches @ W_img + mod_emb[1] # (n_img, embed_dim)
# 交织:先 [IMG] 词元,然后 [TEXT] 词元(像 LLaVA 一样)
combined = jnp.concatenate([img_embs, text_embs], axis=0)
print(f"组合序列: {n_img} 图像 + {n_text} 文本 = {combined.shape[0]} 词元")
return combined
# 试试:5个文本词元(维度16)和4个图像块(维度64)
text = jax.random.normal(jax.random.PRNGKey(1), (5, 16))
image = jax.random.normal(jax.random.PRNGKey(2), (4, 64))
seq = interleave_modalities(text, image)
# 实验:改变 embed_dim,交换交织顺序,添加第三种模态
任务2:可视化跨模态注意力模式
- 创建一个合成的多模态序列,并计算自注意力分数,以观察图像词元如何关注文本词元,反之亦然。
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
def cross_modal_attention(n_text=6, n_img=4, d=32, key=jax.random.PRNGKey(42)):
"""计算并可视化文本和图像词元之间的注意力。"""
k1, k2, k3 = jax.random.split(key, 3)
# 为两种模态模拟词元嵌入
text_embs = jax.random.normal(k1, (n_text, d))
img_embs = jax.random.normal(k2, (n_img, d))
seq = jnp.concatenate([img_embs, text_embs], axis=0) # (n_img+n_text, d)
# 学习到的 Q, K 投影
Wq = jax.random.normal(k3, (d, d)) * 0.1
Wk = jax.random.normal(jax.random.PRNGKey(99), (d, d)) * 0.1
Q, K = seq @ Wq, seq @ Wk
scores = Q @ K.T / jnp.sqrt(d)
attn = jax.nn.softmax(scores, axis=-1)
# 绘图
labels = [f"img_{i}" for i in range(n_img)] + [f"txt_{i}" for i in range(n_text)]
fig, ax = plt.subplots(figsize=(7, 6))
ax.imshow(attn, cmap="viridis")
ax.set_xticks(range(len(labels))); ax.set_xticklabels(labels, rotation=45, fontsize=8)
ax.set_yticks(range(len(labels))); ax.set_yticklabels(labels, fontsize=8)
ax.set_xlabel("键(被关注的对象)")
ax.set_ylabel("查询(发出关注的主体)")
ax.set_title("跨模态自注意力图")
plt.colorbar(ax.images[0], ax=ax, shrink=0.8)
plt.tight_layout(); plt.show()
cross_modal_attention()
# 实验:增加 d,添加因果掩码,观察注意力模式如何变化
任务3:使用模态特定损失权重模拟分阶段训练
- 演示模态特定的损失权重如何影响一个玩具多模态训练循环。观察平衡损失如何防止一种模态主导。
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
def staged_training_sim(steps=200, key=jax.random.PRNGKey(7)):
"""使用可调节的模态损失权重模拟多模态训练。"""
# 两种“模态”,具有不同的损失尺度(文本损失约比图像损失大10倍)
losses_text, losses_img = [], []
param = jnp.array([0.0, 0.0]) # 由两种模态损失更新的共享参数
lr = 0.05
# 尝试改变这些权重以观察对收敛平衡的影响
lambda_text, lambda_img = 1.0, 5.0 # 增加较弱模态的权重
for step in range(steps):
k1, k2, key = jax.random.split(key, 3)
noise_t = jax.random.normal(k1, ()) * 0.3
noise_i = jax.random.normal(k2, ()) * 0.1
loss_t = (param[0] - 3.0) ** 2 + noise_t # 文本目标 = 3.0
loss_i = 0.1 * (param[1] - 1.0) ** 2 + noise_i # 图像目标 = 1.0(尺度更小)
# 加权组合梯度
grad_t = lambda_text * 2 * (param[0] - 3.0)
grad_i = lambda_img * 0.2 * (param[1] - 1.0)
param = param - lr * jnp.array([grad_t, grad_i])
losses_text.append(float(loss_t)); losses_img.append(float(loss_i))
fig, ax = plt.subplots(figsize=(8, 4))
ax.plot(losses_text, label=f"文本损失 (权重={lambda_text})", alpha=0.7)
ax.plot(losses_img, label=f"图像损失 (权重={lambda_img})", alpha=0.7)
ax.set_xlabel("训练步数"); ax.set_ylabel("损失"); ax.legend()
ax.set_title("分阶段训练期间的模态损失平衡")
plt.tight_layout(); plt.show()
staged_training_sim()
# 实验:将 lambda_img 设为 1.0,观察图像损失收敛得慢得多