Skip to content

图神经网络

图神经网络通过在连接节点之间传递消息,从图结构数据中学习。本文涵盖消息传递框架、GCN、GraphSAGE、GIN、过平滑、图池化以及节点/边/图级别的任务;这些核心架构支撑着分子性质预测、社交网络分析和推荐系统。

  • 在前面的文档中,我们已经奠定了数学基础:几何深度学习(文档1)告诉我们要利用对称性,图论(文档2)则提供了节点、边和邻接的语言。现在我们直接在图结构上构建神经网络。

  • 核心挑战:图数据是不规则的。与图像(固定网格)或序列(固定顺序)不同,图具有可变数量的节点、可变的连接性,并且没有规范的节点顺序。用于图的神经网络必须处理所有这些,同时保持置换等变性(重新标记节点不应改变输出)。

消息传递框架

  • 几乎所有 GNN 都遵循相同的配方,称为消息传递(也称为邻域聚合)。这个想法简单而优雅:每个节点通过收集来自其邻居的信息来更新自身的表示。

  • 在每一层 \(l\),每个节点 \(i\) 做三件事:

    1. 消息:每个邻居 \(j\) 根据其当前特征计算一条消息 \(\mathbf{m}_{j \to i}\)
    2. 聚合:节点 \(i\) 收集所有传入的消息,并使用一个置换不变函数(求和、均值或最大值)将它们组合起来。
    3. 更新:节点 \(i\) 将聚合后的消息与自身的特征结合,产生新的表示。
  • 形式化地:

\[\mathbf{m}_i^{(l)} = \bigoplus_{j \in \mathcal{N}(i)} \phi^{(l)}\left(\mathbf{h}_i^{(l)}, \mathbf{h}_j^{(l)}, \mathbf{e}_{ij}\right)\]
\[\mathbf{h}_i^{(l+1)} = \psi^{(l)}\left(\mathbf{h}_i^{(l)}, \mathbf{m}_i^{(l)}\right)\]
  • 其中 \(\mathcal{N}(i)\) 是节点 \(i\) 的邻居集合,\(\bigoplus\) 是置换不变的聚合函数(求和、均值、最大值),\(\phi\) 是消息函数,\(\psi\) 是更新函数,\(\mathbf{e}_{ij}\) 是可选的边特征。

消息传递:邻居发送消息,一个置换不变的聚合函数对它们进行聚合,然后节点更新其特征

  • 聚合函数 \(\bigoplus\) 必须是置换不变的(处理邻居的顺序无关紧要),以确保整体函数是置换等变的。这直接实现了文档1中的对称性原则。

  • 经过 \(k\) 层消息传递后,每个节点的表示编码了来自其 \(k\) 跳邻域的信息:在 \(k\) 条边内可到达的所有节点。第1层看到直接邻居,第2层看到邻居的邻居,依此类推。这就是局部信息如何传播以构建全局理解的方式。

  • GNN 的感受野随深度增长,就像 CNN 的感受野随层数增长一样(第8章)。但与规则网格上的 CNN 不同,感受野的形状取决于图拓扑结构,每个节点各不相同。

图卷积网络(GCN)

  • GCN(Kipf & Welling, 2017)是基础性的 GNN 架构。它将谱图卷积(来自文档2)简化为一个优雅、高效的公式。

  • 从谱卷积 \(g_\theta \star \mathbf{x} = U \, \text{diag}(\hat{g}_\theta) \, U^T \mathbf{x}\) 出发,Kipf 和 Welling 使用一阶切比雪夫多项式近似谱滤波器,从而完全避免了计算特征分解。简化后,层更新变为:

\[H^{(l+1)} = \sigma\left(\hat{A} H^{(l)} W^{(l)}\right)\]
  • 其中:

    • \(H^{(l)} \in \mathbb{R}^{n \times d}\) 是第 \(l\) 层的节点特征矩阵
    • \(W^{(l)} \in \mathbb{R}^{d \times d'}\) 是可学习的权重矩阵
    • \(\hat{A} = \tilde{D}^{-1/2} \tilde{A} \tilde{D}^{-1/2}\) 是带自环的对称归一化邻接矩阵
    • \(\tilde{A} = A + I\) 添加了自环(这样每个节点也接收自身的消息)
    • \(\tilde{D}\)\(\tilde{A}\) 的度数矩阵
    • \(\sigma\) 是非线性激活函数(ReLU,如第6章)
  • 矩阵乘法 \(\hat{A} H^{(l)}\) 是聚合步骤:对每个节点,它计算其邻居特征(加上自身特征,通过自环)的加权平均。权重矩阵 \(W^{(l)}\) 是可学习的变换,在所有节点上共享。激活函数增加非线性。

  • 这非常简单:仅仅是矩阵乘法,然后是一个学习到的线性映射和激活函数。整个 GCN 层可以写成一行代码。通过 \(\tilde{D}^{-1/2}\) 归一化可以防止度大的节点主导:高度数的节点其消息会被缩小。

  • 在消息传递框架中,GCN 使用:

    • 消息:\(\phi(\mathbf{h}_j) = \mathbf{h}_j\)(只发送自身的特征)
    • 聚合:归一化求和(按度数加权)
    • 更新:线性变换 + 激活

GraphSAGE

  • GCN 是直推式的:它在训练时需要整个图,并且不能处理新的、未见过的节点。如果一个新用户加入社交网络,GCN 必须在整个图上重新训练。GraphSAGE(Hamilton 等,2017)通过归纳式方法解决了这个问题。

  • 关键思想是邻域采样:不使用所有邻居,而是采样一个固定大小的子集。这使得计算独立于整个图结构,并允许泛化到未见过的节点和图。

  • GraphSAGE 对节点 \(i\) 的更新:

\[\mathbf{h}_i^{(l+1)} = \sigma\left(W^{(l)} \cdot \text{CONCAT}\left(\mathbf{h}_i^{(l)}, \text{AGG}\left(\{\mathbf{h}_j^{(l)} : j \in \mathcal{S}(i)\}\right)\right)\right)\]
  • 其中 \(\mathcal{S}(i)\)采样的邻居子集(例如,从500个邻居中随机采样10个)。CONCAT 操作将节点自身的特征与聚合后的邻居特征显式地分开,让网络学习“自身”和“邻居”的不同变换。

  • GraphSAGE 支持多种聚合函数:

    • 均值\(\text{AGG} = \frac{1}{|\mathcal{S}|} \sum_{j \in \mathcal{S}} \mathbf{h}_j\)(简单、有效)
    • LSTM:将采样的邻居送入 LSTM(但这引入了顺序依赖,一定程度上违反了置换不变性)
    • 池化\(\text{AGG} = \max(\{\sigma(W_{\text{pool}} \mathbf{h}_j + \mathbf{b})\})\)(非线性变换后再取最大值)
  • 采样策略使 GraphSAGE 可扩展到非常大的图。训练使用节点的 mini-batch:对于每个目标节点,在第1层采样 \(k_1\) 个邻居,然后对这些邻居中的每一个在第2层采样 \(k_2\) 个邻居。当 \(k_1 = k_2 = 10\),2层网络时,每个节点的计算树最多包含 \(10 \times 10 = 100\) 个节点,与图的大小无关。

图同构网络(GIN)

  • 不同的 GNN 架构具有不同的表达能力:它们区分结构不同的图的能力。GCN 和 GraphSAGE 虽然在实践中有效,但在区分某些图结构方面存在理论上的限制。

  • 衡量 GNN 表达能力的理论工具是Weisfeiler-Lehman (WL) 测试,这是一个经典的图同构测试算法(判断两个图是否结构相同)。WL 测试通过将每个节点的标签与其邻居标签的多重集一起进行哈希,迭代地细化节点标签。

  • GIN(Xu 等,2019)被设计为具有与 WL 测试相同的表达能力,使其成为最强大的消息传递 GNN(在消息传递的理论限制内)。关键见解:聚合函数必须是在多重集上单射的(不同的邻居特征多重集必须产生不同的聚合值)。

  • 求和聚合在多重集上是单射的(\(\{1, 1, 2\}\) 求和为4,而 \(\{1, 3\}\) 求和也为4,但在足够维度的特征向量上,不同多重集的求和通常是可区分的)。均值和最大值不是单射的:均值无法区分 \(\{1, 1\}\)\(\{2, 2\}\),最大值无法区分 \(\{1, 2, 3\}\)\(\{1, 1, 3\}\)

  • GIN 的更新公式为:

\[\mathbf{h}_i^{(l+1)} = \text{MLP}^{(l)}\left((1 + \epsilon^{(l)}) \cdot \mathbf{h}_i^{(l)} + \sum_{j \in \mathcal{N}(i)} \mathbf{h}_j^{(l)}\right)\]
  • 其中 \(\epsilon\) 是可学习的标量(或固定为0),MLP 提供非线性的单射映射。求和聚合保留了多重集结构,MLP 可以学会区分任意两个不同的聚合值。

过平滑

  • GNN 中的一个主要挑战是过平滑:随着层数增加,所有节点的表示收敛到相同的值,失去了区分不同节点的能力。

过平滑:第1层时不同的节点特征,在经过更深层后逐渐混合成均匀的特征

  • 机理很直观。每一层消息传递都将节点的特征与其邻居的特征进行平均。经过多轮平均后,每个节点都看到了(并与)其连通分量中的所有其他节点进行了混合。特征变成均匀的平均值,这相当于将图像过度模糊直到变成单一颜色。

  • 形式化地,重复应用归一化邻接矩阵 \(\hat{A}\) 会收敛到一个秩为1的矩阵(每一行都正比于图上随机游走的平稳分布)。这与幂迭代收敛到主特征向量的过程相同(第2章)。

  • 过平滑将 GNN 限制在较浅的深度(通常为 2-4 层),而 CNN 和 Transformer 可以从数十或数百层中受益。这意味着每个节点只能看到有限的邻域,这对于需要长程信息的任务来说是个问题。

  • 缓解方法包括:

    • 残差连接(来自 ResNet,第8章):\(\mathbf{h}_i^{(l+1)} = \mathbf{h}_i^{(l+1)} + \mathbf{h}_i^{(l)}\),保留来自早期层的信息。
    • 跳跃知识:连接或注意力池化来自所有层的表示,而不仅仅是最后一层。
    • DropEdge:训练期间随机移除边,减缓信息传播。
    • 图 Transformer(文档4):通过全局注意力绕过局部消息传递的瓶颈。

图池化

  • 对于图级别任务(预测整个图的性质,如分子的毒性),我们需要将所有节点表示折叠成一个图级别的向量。这就是图池化,相当于 CNN 中的全局平均池化(第8章)。

  • 最简单的方法是读出:对节点特征集合应用一个置换不变函数:

\[\mathbf{h}_G = \text{READOUT}(\{\mathbf{h}_i^{(L)} : i \in V\}) = \sum_i \mathbf{h}_i^{(L)} \quad \text{或} \quad \frac{1}{|V|} \sum_i \mathbf{h}_i^{(L)} \quad \text{或} \quad \max_i \mathbf{h}_i^{(L)}\]
  • 这是在最后一层 GNN 之后应用的 DeepSets 聚合(来自文档1)。求和保留大小信息(有100个节点的图比有10个节点的图会有更大的求和值),而均值则对大小进行归一化。

  • 层次化池化逐步粗化图,模仿 CNN 逐步下采样图像的方式。在每一级,若干节点被合并成“超级节点”:

  • DiffPool(可微分池化)学习一个软分配矩阵 \(S^{(l)} \in \mathbb{R}^{n_l \times n_{l+1}}\),将每个节点分配到一个簇:

\[X^{(l+1)} = S^{(l)T} H^{(l)}, \quad A^{(l+1)} = S^{(l)T} A^{(l)} S^{(l)}\]
  • 分配矩阵由一个单独的 GNN 预测,使聚类成为端到端可微的。这创建了一个层次结构:原始图 → 节点更少的粗化图 → 更粗的图 → 单个节点(图的表示)。

  • TopKPool 采用更简单的方法:为每个节点学习一个标量分数,保留得分最高的 \(k\) 个节点,丢弃其余节点。这是硬选择(而非软分配),计算上比 DiffPool 更便宜。

异构图

  • 迄今为止所有的 GNN 都假设同构图:一种节点类型,一种边类型。但大多数现实世界的图是异构的:多种节点类型和多种边类型。知识图包含人物节点、组织节点和位置节点,由“工作于”、“出生于”和“位于”等边连接。推荐系统包含用户节点和物品节点,由“购买”、“浏览”和“评分”等边连接。

  • 异构图有一个模式(也称为元图),定义了允许的节点类型和边类型。每种边类型连接特定的源类型和目标类型。例如,“工作于”连接 Person → Organisation。

  • 关系图卷积网络(R-GCN)(Schlichtkrull 等,2018)通过为每种边类型使用单独的权重矩阵来处理异构边:

\[\mathbf{h}_i^{(l+1)} = \sigma\left(\sum_{r \in \mathcal{R}} \sum_{j \in \mathcal{N}_r(i)} \frac{1}{|\mathcal{N}_r(i)|} W_r^{(l)} \mathbf{h}_j^{(l)} + W_0^{(l)} \mathbf{h}_i^{(l)}\right)\]
  • 其中 \(\mathcal{R}\) 是边类型的集合,\(\mathcal{N}_r(i)\) 是通过关系 \(r\) 连接到节点 \(i\) 的邻居集合,\(W_r\) 是关系 \(r\) 特定的权重矩阵。自连接 \(W_0\) 单独处理节点自身的特征。

  • 问题:关系类型很多时,参数数量爆炸(每个关系一个 \(d \times d\) 矩阵)。R-GCN 通过基分解缓解:\(W_r = \sum_{b=1}^{B} a_{rb} V_b\),其中 \(V_b\) 是共享的基矩阵,\(a_{rb}\) 是每个关系的标量系数。这类似于低秩分解(第2章):关系特定的矩阵位于一个低维子空间中。

  • 异构图 Transformer(HGT)(Hu 等,2020)将注意力机制应用于异构图。关键见解是:注意力应同时依赖于节点类型和连接它们的边类型。HGT 使用类型特定的投影矩阵用于查询、键和值:

\[\text{Attention}(i, j) = \left(W_{\tau(i)}^Q \mathbf{h}_i\right)^T \cdot \frac{W_{\phi(i,j)}^{\text{ATT}}}{\sqrt{d}} \cdot \left(W_{\tau(j)}^K \mathbf{h}_j\right)\]
  • 其中 \(\tau(i)\) 是节点 \(i\) 的类型,\(\phi(i,j)\) 是它们之间的边类型。这确保模型对不同关系类型的关注不同:一篇论文关注其作者时的注意力权重应该与关注其参考文献时的注意力权重不同。

  • 基于元路径的方法定义通过模式的有意义路径(例如,Author → Paper → Author 表示合著关系),并沿这些路径聚合信息。HAN(异构图注意力网络)在两个层面上应用注意力:在每个元路径内部(沿着该路径的哪些邻居重要?)以及在元路径之间(哪些关系模式重要?)。

链接预测与知识图谱补全

  • 链接预测要解决的是:给定现有边,哪些缺失的边可能存在?这是知识图谱补全(预测缺失事实)、推荐(预测用户喜欢哪些物品)和社交网络分析(预测未来友谊)的核心任务。

  • 基于嵌入的方法为每个实体学习一个向量,为每个关系学习一个变换,然后通过实体和关系的契合程度对潜在边进行评分:

  • TransE 将关系建模为嵌入空间中的平移:如果 \((h, r, t)\) 是一个有效三元组(头实体,关系,尾实体),则 \(\mathbf{h} + \mathbf{r} \approx \mathbf{t}\)。评分函数为 \(f(h, r, t) = -\|\mathbf{h} + \mathbf{r} - \mathbf{t}\|\)。直观上,关系向量将头实体在嵌入空间中“移动”到尾实体。

  • RotatE 将关系建模为复空间中的旋转:\(\mathbf{t} = \mathbf{h} \circ \mathbf{r}\),其中 \(\circ\) 是逐元素复数乘法,\(|\mathbf{r}_i| = 1\)(单位复数就是旋转)。它可以建模 TransE 无法处理的对称、反对称、反转和组合模式。

  • ComplEx 使用复值嵌入和埃尔米特点积,能够建模非对称关系(如果A是B的上司,则B不是A的上司)。

  • 基于 GNN 的链接预测通过消息传递计算节点嵌入,然后使用端点嵌入对边进行评分。这结合了 GNN 的结构推理能力和嵌入方法的关系建模能力。GNN 编码器捕获了单嵌入方法无法得到的多跳邻域结构。

任务类型

  • GNN 解决三类任务:

  • 节点级别任务:预测每个节点的属性。示例:社交网络中的用户分类(机器人还是人类)、交互网络中的每个蛋白质功能预测、半监督节点分类(标记少数节点,预测其余节点)。输出是节点嵌入 \(\mathbf{h}_i^{(L)}\) 通过一个分类器。

  • 边级别任务:预测每条边的属性或预测边是否存在。示例:链接预测(这两个用户会成为朋友吗?)、知识图谱补全(实体间是否存在这种关系?)、药物-药物相互作用预测。输出通常使用两个端点的嵌入:\(\hat{y}_{ij} = f(\mathbf{h}_i, \mathbf{h}_j)\),其中 \(f\) 可以是点积、拼接+MLP 或其他组合。

  • 图级别任务:预测整个图的属性。示例:分子性质预测(这个分子有毒吗?)、图分类(这个社交网络是机器人网络吗?)、图生成(设计具有所需属性的分子)。输出通过图池化得到 \(\mathbf{h}_G\),然后进行分类或回归。

编码任务(使用 Colab 或 notebook)

  1. 使用归一化邻接矩阵从头实现一个 GCN 层。将其应用到一个小的图上,观察节点特征如何被平滑。

    import jax
    import jax.numpy as jnp
    
    # 图:5个节点,简单的链加一个分支
    A = jnp.array([[0, 1, 0, 0, 0],
                   [1, 0, 1, 0, 0],
                   [0, 1, 0, 1, 1],
                   [0, 0, 1, 0, 0],
                   [0, 0, 1, 0, 0]], dtype=float)
    
    # 添加自环
    A_hat = A + jnp.eye(5)
    D_hat = jnp.diag(A_hat.sum(axis=1))
    D_inv_sqrt = jnp.diag(1.0 / jnp.sqrt(A_hat.sum(axis=1)))
    A_norm = D_inv_sqrt @ A_hat @ D_inv_sqrt
    
    # 节点特征:单位矩阵(每个节点一个独热编码)
    H = jnp.eye(5)
    
    # 权重矩阵(随机初始化)
    rng = jax.random.PRNGKey(0)
    W = jax.random.normal(rng, (5, 3)) * 0.5
    
    # GCN层: H' = ReLU(A_norm @ H @ W)
    H_new = jax.nn.relu(A_norm @ H @ W)
    
    print("原始特征(独热编码):")
    print(H)
    print("\n经过GCN层后:")
    print(jnp.round(H_new, 3))
    print("\n注意:相连的节点现在具有相似的表示")
    

  2. 使用求和聚合(GIN风格)实现消息传递,并与均值聚合(GCN风格)进行比较。展示求和能够区分均值无法区分的多重集。

    import jax.numpy as jnp
    
    # 两个不同的邻域多重集,但均值相同
    # 节点A:邻居特征为 [1, 1, 1, 1]  (四个邻居,都是1)
    # 节点B:邻居特征为 [2, 2]        (两个邻居,都是2)
    
    neighbours_A = jnp.array([[1.0], [1.0], [1.0], [1.0]])
    neighbours_B = jnp.array([[2.0], [2.0]])
    
    # 均值聚合
    mean_A = neighbours_A.mean(axis=0)
    mean_B = neighbours_B.mean(axis=0)
    print(f"均值 A: {mean_A}, 均值 B: {mean_B}, 相同: {jnp.allclose(mean_A, mean_B)}")
    
    # 求和聚合
    sum_A = neighbours_A.sum(axis=0)
    sum_B = neighbours_B.sum(axis=0)
    print(f"求和 A:  {sum_A},  求和 B:  {sum_B},  相同: {jnp.allclose(sum_A, sum_B)}")
    print("\n求和能够区分这些多重集;均值不能!")
    

  3. 演示过平滑现象。重复应用归一化邻接矩阵,观察节点特征如何收敛。

    import jax.numpy as jnp
    import matplotlib.pyplot as plt
    
    # 随机图
    A = jnp.array([[0,1,1,0,0,0],
                   [1,0,1,0,0,0],
                   [1,1,0,1,0,0],
                   [0,0,1,0,1,1],
                   [0,0,0,1,0,1],
                   [0,0,0,1,1,0]], dtype=float)
    
    A_hat = A + jnp.eye(6)
    D_inv_sqrt = jnp.diag(1.0 / jnp.sqrt(A_hat.sum(axis=1)))
    A_norm = D_inv_sqrt @ A_hat @ D_inv_sqrt
    
    # 初始特征:每个节点各不相同
    H = jnp.array([[1,0], [0,1], [1,1], [-1,0], [0,-1], [-1,-1]], dtype=float)
    
    distances = []
    for k in range(20):
        H = A_norm @ H
        # 衡量特征的区别程度(节点间的标准差)
        spread = jnp.std(H, axis=0).mean()
        distances.append(float(spread))
    
    plt.plot(distances, "o-")
    plt.xlabel("消息传递轮数")
    plt.ylabel("特征离散程度(节点间的标准差)")
    plt.title("过平滑:深度增加导致特征收敛")
    plt.show()