Skip to content

几何深度学习

几何深度学习是一个统一的框架,揭示了 CNN、Transformer 和 GNN 都是同一原理的实例:利用对称性。本章涵盖对称群、群作用、不变性、等变性、五个几何域以及尺度分离

  • 在本书中,我们学习了多种架构:处理图像的 CNN(第 8 章)、处理语言的 Transformer(第 7 章)以及用于序列决策的强化学习策略(第 6 章)。它们看起来像是为完全不同的问题设计的完全不同的模型。但在这些表象之下,存在着更深刻的模式。

  • 几何深度学习 揭示了所有这些架构都是同一思想的实例:构建尊重数据对称性的网络。CNN 利用了图像中的平移对称性。Transformer 利用了序列中的置换对称性(注意力不依赖于绝对位置)。GNN 利用了图中的置换对称性。一旦你理解了这一点,架构的“动物园”就变成了一个单一的、连贯的框架。

对称性与群

  • 一个物体的对称性是指使其保持不变的一种变换。一个正方形有 8 种对称性:4 种旋转(0°、90°、180°、270°)和 4 种反射。一个圆有无穷多种对称性:绕其圆心的任何旋转。关键的洞见是,对称性告诉你什么是不重要的,而知道什么不重要对于学习来说是非常强大的。

  • 用机器学习的术语来说:如果一个任务具有某种对称性,那么无论模型看到输入的哪个“版本”,它都应该给出相同的答案。一个猫检测器,无论猫是在图像的左上角还是右下角,都应该能工作。这就是平移对称性。

  • 对称性被形式化为。一个群 \(G\) 是一个变换集合,具有四个性质:

    • 封闭性:两个变换的组合仍然在集合内。旋转 90° 再旋转 90° 得到 180°,仍在集合内。
    • 结合律\((g_1 \circ g_2) \circ g_3 = g_1 \circ (g_2 \circ g_3)\)。分组的顺序不重要(回顾第 2 章中矩阵乘法的结合律)。
    • 单位元:存在一个“什么都不做”的变换 \(e\),使得 \(e \circ g = g \circ e = g\)
    • 逆元:每个变换都有一个撤消操作:\(g \circ g^{-1} = e\)
  • 这些公理与向量空间(第 1 章)的公理相同,但针对的是变换而非向量。它们之间的联系很深:群作用于向量空间,而神经网络必须尊重这种作用。

  • 深度学习中出现的几个关键群:

    • 平移群 \((\mathbb{R}^n, +)\):平移图像或信号。这是 CNN 所利用的对称性。
    • 对称群 \(S_n\)\(n\) 个元素的所有置换。这是 GNN 和 Transformer 所利用的对称性(重新标记节点或标记不应改变结果)。
    • 旋转群 \(SO(n)\)\(n\) 维空间中的所有旋转。\(SO(2)\) 是平面中的旋转,\(SO(3)\) 是 3D 中的旋转(对分子和 3D 视觉任务至关重要)。
    • 欧几里得群 \(E(n)\):所有旋转、反射和平移。物理空间的对称性。
    • 特殊欧几里得群 \(SE(n)\):旋转和平移(无反射)。刚体运动的对称性。
  • 群作用 描述了群如何变换数据。如果 \(G\) 是一个群,\(X\) 是一个数据空间,那么作用 \(\rho: G \times X \to X\) 将每个群元素 \(g\) 和数据点 \(x\) 映射到一个变换后的点 \(\rho(g, x)\)。对于图像,平移群通过平移像素坐标来作用。对于图,对称群通过重新标记节点来作用。

不变性与等变性

  • 给定一个对称群,一个函数可以以两种重要方式与之相关:

  • 如果输入变换后输出不变,则函数 \(f\) 对群 \(G\)不变的

\[f(\rho(g, x)) = f(x) \quad \text{对所有 } g \in G\]
  • 例如:一张图像的总亮度不会因为图像平移而改变。图像分类应该是平移不变的:无论猫坐在哪里,“猫”这个类别都是一样的。

  • 如果输入变换时输出以相应方式变换,则函数 \(f\) 对群 \(G\)等变的

\[f(\rho_{\text{in}}(g, x)) = \rho_{\text{out}}(g, f(x)) \quad \text{对所有 } g \in G\]
  • 例如:如果你将一张图像向右平移 5 个像素,CNN 中的特征图也会向右平移 5 个像素。卷积运算是平移等变的:它保持了空间关系。目标检测应该是等变的:如果猫移动了,边界框也应该随之移动。

不变性:无论变换如何,输出保持不变。等变性:输出相应地变换

  • 这种区别很重要:中间层通常应该是等变的(为下游层保留结构),而最终输出应该是不变的(答案不应依赖于变换)。CNN 通过堆叠等变卷积层,然后在最后应用全局池化(它是不变的)来实现这一点。

  • 将等变性内建于架构中,远比从数据中学习它更高效。具有权重共享的平移等变 CNN 所需的参数,远少于一个必须独立学习“位置 (10,10) 处的猫”和“位置 (200,150) 处的猫”的全连接网络。对称性约束以指数方式减少了假设空间。

五个几何域

  • 几何深度学习识别出数据的五个基本域,每个域都有其自身的对称群。每一个神经网络架构都可以被理解为利用了这些域之一的对称性。

五个几何域:网格、集合、序列、图和流形,每个都有自己的对称性和架构

  • 1. 网格(欧几里得数据):图像、音频频谱图、体积数据。底层结构是一个具有平移对称性的规则网格。其群是平移群(还可能包括旋转和反射)。利用这种对称性的架构是 CNN:卷积正是对平移等变的操作。跨空间位置的权重共享就是平移等变性的具体体现。

  • 2. 集合(无序集合):点云、粒子系统。其对称性是置换不变性:元素的顺序无关紧要。其架构是 DeepSets(以及第 8 章的 PointNet):对每个元素应用一个共享函数,然后使用置换不变的操作(求和、平均或最大值)进行聚合。形式上表示为 \(f(\{x_1, \ldots, x_n\}) = \phi\left(\sum_i \psi(x_i)\right)\)

  • 3. 序列(有序数据):文本、时间序列。序列是一维网格,但有一点不同:对称性更为微妙。绝对位置可能重要,也可能不重要。RNN 以自回归方式处理序列。带有位置编码的 Transformer 可以关注任何位置,其自注意力在(添加位置编码之前)对置换是等变的。这就是 Transformer 泛化能力如此之强的原因:它们从置换等变开始,然后只添加了刚好够用的位置结构。

  • 4. 图(关系数据):社交网络、分子、知识图谱。其对称性是节点的置换:重新标记节点不应改变图的属性。其架构是 GNN:在连接节点之间进行消息传递,使用不依赖于节点顺序的共享函数。这是本章剩余部分将重点介绍的内容。

  • 5. 流形和网格(曲面,3D形状):曲面、3D 形状。其对称性包括微分同胚(光滑变形)。其架构使用由曲面几何本身定义的内在算子(例如 Laplace-Beltrami 算子),独立于曲面在空间中的嵌入方式。这联系到微分几何,并与形状分析、球面上的气候建模以及蛋白质表面分析相关。

  • 这个框架的强大之处在于统一性。CNN 是网格图上的 GNN。Transformer 是全连接图上的 GNN。DeepSets 是没有边的 GNN。将这些视为同一原理的实例,可以指导新架构的设计:识别你数据的对称性,然后构建尊重该对称性的网络。

尺度分离与粗化

  • 现实世界的数据具有多尺度结构。一幅图像有细粒度的纹理(像素级)、局部模式(边缘、角点)、物体部件(轮子、窗户)和全局结构(整个场景)。一个分子有原子级特征、官能团和整体分子形状。

  • 尺度分离 是指这些细节层次可以被分层处理:首先捕捉局部结构,然后逐步聚合成更粗略的表示。这就是粗化池化

  • 在 CNN 中,池化层(最大池化、平均池化)会降低空间分辨率,迫使更高层捕捉更大尺度的模式。从感受野的角度看(第 8 章),更深的层能“看到”图像更多的部分。这就是尺度分离的实际应用。

  • 在图结构中,粗化意味着将节点组聚类成“超节点”,生成一个保留基本结构的更小的图。这就是图池化,我们将在文件 3 中详细讨论。它与图像池化的类比是直接的:降低分辨率的同时保留重要特征。

  • 在序列中,分层处理(例如:句子 → 段落 → 文档)可以捕捉不同时间或语义尺度的结构。Swin Transformer(第 8 章)通过其移动窗口层次结构将这一思想应用于图像。

  • 从数学上讲,粗化定义了一个日益抽象的表示层次结构

\[x \xrightarrow{\text{局部特征}} h^{(1)} \xrightarrow{\text{粗化}} h^{(2)} \xrightarrow{\text{粗化}} \cdots \xrightarrow{\text{全局}} y\]
  • 在每个层次上,表示对该层次的对称群是等变的。最终的全局表示是不变的,捕捉了输入的本质,而对无关的变换不敏感。

  • 这个层次结构就是深度网络在处理结构化数据时比浅层网络效果更好的原因:每一层增加一个抽象层次,许多等变层的组合从简单的局部特征中构建出复杂的全局不变特征。

编程任务(使用 CoLab 或 notebook)

  1. 验证卷积的平移等变性。对一个图像应用卷积,然后平移图像再进行卷积。检查输出是否是彼此的平移版本。

    import jax
    import jax.numpy as jnp
    
    # 一维信号和一个简单滤波器
    signal = jnp.array([0, 0, 0, 1, 2, 3, 2, 1, 0, 0, 0], dtype=float)
    kernel = jnp.array([1, 0, -1], dtype=float)
    
    # 先卷积再平移
    conv_result = jnp.convolve(signal, kernel, mode="same")
    shifted_signal = jnp.roll(signal, 3)
    conv_shifted = jnp.convolve(shifted_signal, kernel, mode="same")
    shifted_conv = jnp.roll(conv_result, 3)
    
    print(f"先卷积再平移:  {shifted_conv}")
    print(f"先平移再卷积:  {conv_shifted}")
    print(f"等变:{jnp.allclose(shifted_conv, conv_shifted, atol=1e-5)}")
    

  2. 验证 DeepSets 风格聚合的置换不变性。对集合的每个元素应用一个共享函数,对结果求和,并检查输出是否与元素顺序无关。

    import jax
    import jax.numpy as jnp
    
    # 一个包含 4 个向量的“集合”(顺序不应该重要)
    x = jnp.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]])
    
    # 简单的共享函数:逐元素平方
    psi = lambda v: v ** 2
    
    # 通过求和聚合
    def deepsets(points):
        return jnp.sum(jax.vmap(psi)(points), axis=0)
    
    # 原始顺序
    result1 = deepsets(x)
    
    # 置换后的顺序
    perm = jnp.array([2, 0, 3, 1])
    result2 = deepsets(x[perm])
    
    print(f"原始顺序:  {result1}")
    print(f"置换顺序:  {result2}")
    print(f"不变:{jnp.allclose(result1, result2)}")
    

  3. 探索群结构。通过验证封闭性、结合律、单位元和逆元,来确认二维旋转矩阵构成一个群。

    import jax.numpy as jnp
    
    def rot2d(theta):
        return jnp.array([[jnp.cos(theta), -jnp.sin(theta)],
                           [jnp.sin(theta),  jnp.cos(theta)]])
    
    R1 = rot2d(jnp.pi / 6)
    R2 = rot2d(jnp.pi / 4)
    R3 = rot2d(jnp.pi / 3)
    
    # 封闭性:两个旋转的乘积仍然是一个旋转
    R12 = R1 @ R2
    print(f"封闭性 (det=1, 正交): det={jnp.linalg.det(R12):.4f}, "
          f"R^T R = I: {jnp.allclose(R12.T @ R12, jnp.eye(2), atol=1e-5)}")
    
    # 结合律
    print(f"结合律: {jnp.allclose((R1 @ R2) @ R3, R1 @ (R2 @ R3), atol=1e-5)}")
    
    # 单位元
    I = rot2d(0.0)
    print(f"单位元: {jnp.allclose(R1 @ I, R1, atol=1e-5)}")
    
    # 逆元
    R1_inv = rot2d(-jnp.pi / 6)
    print(f"逆元: {jnp.allclose(R1 @ R1_inv, jnp.eye(2), atol=1e-5)}")