Skip to content

图注意力网络

图注意力网络用学习到的、依赖于数据的权重取代了统一的邻居聚合。本文涵盖 GAT、多头图注意力、GATv2、图 Transformer、位置编码与结构编码以及可扩展性。

  • 在 GCN(文件3)中,每个节点使用由图结构确定的固定权重(归一化邻接矩阵)来聚合其邻居的特征。一个有3个邻居的节点会给每个邻居大致相等的权重(约 \(1/3\))。但并非所有邻居都同等重要:来自亲密合作者的信息应该比来自远方熟人的信息更重要。

  • 图注意力网络通过学习应该关注哪些邻居来解决这个问题,使用的是与 Transformer(第7章)相同的注意力机制。每个节点在其邻居上计算动态的、基于内容的注意力分数,而非基于结构的固定权重。

GAT:图注意力网络

  • GAT(Veličković 等人,2018)计算每个节点与其邻居之间的注意力系数。对于节点 \(i\) 和邻居 \(j\)
\[e_{ij} = \text{LeakyReLU}\left(\mathbf{a}^T \left[W\mathbf{h}_i \| W\mathbf{h}_j\right]\right)\]
  • 其中 \(W \in \mathbb{R}^{d' \times d}\) 是共享线性变换,\(\|\) 表示拼接,\(\mathbf{a} \in \mathbb{R}^{2d'}\) 是可学习的注意力向量。分数 \(e_{ij}\) 衡量节点 \(j\) 的特征对节点 \(i\) 的重要程度。

  • 原始分数通过 softmax 在所有邻居上进行归一化:

\[\alpha_{ij} = \text{softmax}_j(e_{ij}) = \frac{\exp(e_{ij})}{\sum_{k \in \mathcal{N}(i)} \exp(e_{ik})}\]
  • 这确保了每个节点的邻域内注意力权重之和为1,就像 Transformer 注意力一样(第7章)。节点的更新特征为:
\[\mathbf{h}_i' = \sigma\left(\sum_{j \in \mathcal{N}(i)} \alpha_{ij} W\mathbf{h}_j\right)\]

GCN 对所有邻居分配固定的相等权重;GAT 学习数据依赖的注意力权重

  • 与 GCN 的关键区别:权重 \(\alpha_{ij}\)从数据中学习的,而非由图结构固定。节点可以学会关注信息量最大的邻居,同时忽略有噪声或不相关的邻居。

  • 注意,注意力仅在边上计算(节点 \(i\) 仅关注其邻居 \(\mathcal{N}(i)\)),而非所有节点对。这使得计算量与边的数量成正比,而不是节点数的平方。

多头图注意力

  • 正如 Transformer(第7章)一样,多头注意力并行运行 \(K\) 个独立的注意力机制,每个都有自己的参数 \(W^k\)\(\mathbf{a}^k\)。结果在中间层进行拼接,或在最后一层进行平均:
\[\mathbf{h}_i' = \Big\|_{k=1}^{K} \sigma\left(\sum_{j \in \mathcal{N}(i)} \alpha_{ij}^k W^k \mathbf{h}_j\right)\]
  • 每个头可以关注邻域的不同方面:一个头可能关注结构特征,另一个头关注语义相似性。这与 Transformer 中多头注意力的动机相同:不同的头捕获不同类型的关系。

  • 若有 \(K\) 个头且每个头输出维度为 \(d'\),则拼接后的输出维度为 \(K \times d'\)。最后一层通常采用平均而非拼接,以产生固定大小的输出。

GATv2:修复静态注意力

  • 原始 GAT 有一个微妙的局限:其注意力函数是静态的(也称为基于排序的)。注意力分数依赖于拼接 \([W\mathbf{h}_i \| W\mathbf{h}_j]\),但由于注意力向量 \(\mathbf{a}\) 是在拼接之后应用的,它可以分解为两个独立的部分:\(\mathbf{a}^T [W\mathbf{h}_i \| W\mathbf{h}_j] = \mathbf{a}_1^T W\mathbf{h}_i + \mathbf{a}_2^T W\mathbf{h}_j\)

  • 这意味着对于给定节点 \(i\),邻居的排序完全由邻居的特征 \(\mathbf{h}_j\) 决定(\(\mathbf{a}_1^T W\mathbf{h}_i\) 项在 \(i\) 的所有邻居中是常数)。注意力排序并不真正依赖于查询节点自身的特征。节点 \(i\) 和节点 \(k\) 会对相同的邻居集合给出完全相同的排序,这限制了表达能力。

  • GATv2(Brody 等人,2022)通过在注意力向量之前应用非线性激活来修复此问题:

\[e_{ij} = \mathbf{a}^T \text{LeakyReLU}\left(W \left[\mathbf{h}_i \| \mathbf{h}_j\right]\right)\]
  • 将 LeakyReLU 移到计算内部意味着注意力分数是联合特征的非线性函数,无法分解为独立项。这使得注意力变为动态:邻居的排序现在依赖于具体的查询节点。GATv2 在不增加计算成本的前提下,表达能力严格强于 GAT。

图 Transformer

  • 标准消息传递 GNN 受限于图拓扑结构:一个节点只能关注其直接邻居。经过 \(k\) 层后,来自 \(k\) 跳邻居的信息通过多次聚合步骤被混合,保真度下降。这种局部瓶颈(加上过平滑,文件3)限制了捕获长程依赖的能力。

  • 图 Transformer 通过对所有节点对应用全局自注意力(无论它们之间是否有边)打破了这个瓶颈。每个节点可以在单层中关注任何其他节点,就像标准 Transformer(第7章)一样。

  • 基本思想:将所有节点视为 token,并应用 Transformer 自注意力:

\[\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V\]
  • 其中 \(Q = XW_Q\)\(K = XW_K\)\(V = XW_V\) 是节点特征 \(X\) 的查询、键、值投影(与第7章完全相同)。这相当于在完全连通图(完全图 \(K_n\),文件2)上的 GNN。

  • 问题:完全连通图忽略了实际的图结构。边信息(谁和谁实际相连)丢失了。有两种方法可以恢复这些信息:

  • Graphormer(Ying 等人,2021)通过向注意力分数中添加偏置项将图结构注入 Transformer:

\[A_{ij} = \frac{(\mathbf{h}_i W_Q)(W_K^T \mathbf{h}_j^T)}{\sqrt{d_k}} + b_{\text{spatial}}(i, j) + b_{\text{edge}}(i, j)\]
  • 空间偏置 \(b_{\text{spatial}}\) 编码节点 \(i\)\(j\) 之间的最短路径距离。边偏置 \(b_{\text{edge}}\) 编码沿最短路径的边特征。此外,Graphormer 使用中心性编码,将节点的度数添加到其输入嵌入中,使模型获得关于每个节点结构角色的信息。

  • GPS(通用、强大、可扩展的图 Transformer,Rampášek 等人,2022)在每一层中将局部消息传递与全局注意力结合起来:

\[\mathbf{h}_i' = \text{MLP}\left(\mathbf{h}_i^{\text{MPNN}} + \mathbf{h}_i^{\text{Attention}}\right)\]
  • 每一层同时应用标准 GNN(用于局部结构)和 Transformer(用于全局上下文),然后组合结果。这兼得两者之优:来自消息传递的局部结构和来自注意力的长程依赖。

位置编码与结构编码

  • 序列上的 Transformer 使用位置编码(第7章)注入顺序信息。图没有规范的顺序,因此需要图特定的编码。

  • 拉普拉斯特征向量编码使用图拉普拉斯矩阵(文件2)的特征向量作为位置特征。最小的 \(k\) 个非平凡特征向量提供了图的谱嵌入:在图中“邻近”的节点具有相似的特征向量值。这些被拼接到节点特征上。

  • 一个小细节:拉普拉斯特征向量存在符号歧义性(若 \(\mathbf{u}\) 是特征向量,则 \(-\mathbf{u}\) 也是)。模型必须对这些符号翻转保持不变。解决方案包括:在训练期间使用随机符号翻转作为数据增强,或学习符号不变的变换。

  • 随机游走编码计算从节点 \(i\) 出发的随机游走在 \(k\) 步后返回节点 \(i\) 的概率,其中 \(k = 1, 2, \ldots, K\)。这些概率编码了局部结构信息:密集聚类中的节点具有较高的返回概率,而稀疏区域的节点返回概率较低。着陆概率 \(p_{ii}^{(k)} = (A_{\text{rw}}^k)_{ii}\),其中 \(A_{\text{rw}} = D^{-1}A\) 是随机游走转移矩阵。

  • 度数编码简单地将节点度数作为一个特征加入。这出人意料地有效,因为度数是一个强大的结构信号:叶子节点(度数为1)、桥节点和枢纽节点的行为不同。

  • 这些编码提供了普通 Transformer 所缺乏的结构信息,使得图 Transformer 在需要长程推理的任务上优于标准消息传递 GNN。

可扩展性

  • GNN 面临的基本可扩展性挑战在于图可能包含数百万个节点和数十亿条边。在整个图上训练 GNN 需要将所有节点特征和整个邻接矩阵存储在内存中,这通常是不可行的。

  • GNN 的小批量训练比图像或序列更复杂,因为节点之间是相互连接的。简单采样一批节点需要它们的邻居(第1层)、邻居的邻居(第2层)等等。这种邻域爆炸意味着,1000个目标节点的小批量可能需要在计算图中包含数百万个节点。

  • 邻域采样(GraphSAGE 风格,文件3)通过为每层每个节点采样固定数量的邻居来限制爆炸。使用2层和每层15个样本,每个目标节点的子图最多有 \(15^2 = 225\) 个节点,与完整图的大小无关。

  • Cluster-GCN(Chiang 等人,2019)使用图聚类算法(如 METIS)将图划分为若干个簇,然后每次在一个簇上进行训练。簇内边密集(大多数邻居位于同一簇中),因此子图捕获了相关结构。跨簇边通过偶尔包含簇之间的边来处理。

  • 图 Transformer 的可扩展性更为困难,因为全局注意力是 \(O(n^2)\) 的。对于拥有数百万个节点的图,完全注意力是不可行的。解决方案包括:

    • 稀疏注意力模式(仅关注图中 \(k\) 个最近的节点)
    • 线性注意力近似
    • 将局部消息传递(成本低,\(O(|E|)\))与粗化图(节点更少)上的全局注意力相结合

时序图与动态图

  • 到目前为止我们研究的图都是静态的:节点、边和特征是固定的。但许多现实世界的图会随时间演化:新用户加入社交网络,金融交易创建边,交通模式在一天中变化,分子相互作用波动。

  • 时序图为每条边增加一个时间戳:\((i, j, t)\) 表示节点 \(i\) 在时间 \(t\) 与节点 \(j\) 发生了交互。挑战在于学习能够同时捕获图结构和时序动态的表示。

  • 存在两种范式:

  • 离散时间动态图(DTDG):图表示为一系列快照 \(G_1, G_2, \ldots, G_T\),每个时间步一个快照。GNN 处理每个快照,RNN 或时序注意力机制捕获快照之间的演化。这种方法简单但会丢失细粒度的时间信息(快照之间的事件丢失了),并且需要选择快照频率。

  • 连续时间动态图(CTDG):事件被建模为带时间戳的交互流。每个事件 \((i, j, t)\) 在它发生的精确时间更新节点 \(i\)\(j\) 的表示。这保留了所有时间信息。

  • 时序图网络(TGN)(Rossi 等人,2020)是领先的 CTDG 架构。每个节点维护一个内存状态 \(\mathbf{s}_i(t)\),当节点参与交互时该状态会被更新:

\[\mathbf{s}_i(t^+) = \text{GRU}\left(\mathbf{s}_i(t^-), \; \mathbf{m}_i(t)\right)\]
  • 其中 \(\mathbf{m}_i(t)\) 是根据交互计算的消息(结合了两个节点的特征、边特征和时间编码)。GRU(第6章)选择性地保留和遗忘过去的信息,使内存能够捕获长期模式,同时适应近期事件。

  • 时间编码将自上次交互以来的经过时间表示为一个特征向量,类似于 Transformer 中的位置编码(第7章)。一种常见的方法是使用可学习的傅里叶特征:

\[\Phi(t) = \left[\cos(\omega_1 t), \sin(\omega_1 t), \ldots, \cos(\omega_d t), \sin(\omega_d t)\right]\]
  • 这为模型提供了时间间隔的丰富表示:“该用户5分钟前活跃”与“3个月前活跃”会被嵌入为不同的向量。

  • 时序图注意力(TGAT) 在节点的时序邻域上应用自注意力:一组最近的交互,每个交互同时根据特征相关性(如 GAT)和时间新鲜度进行加权。来自久远过去的交互自然会被降低权重。

  • 应用包括欺诈检测(金融图中的异常交易模式)、交通预测(根据历史流量模式预测拥堵)、社交网络动态(预测病毒式内容传播)以及随时间变化的药物相互作用预测。

编码任务(使用 CoLab 或 notebook)

  1. 从头实现一个单头 GAT 注意力层。计算节点与其邻居之间的注意力权重,并验证它们之和为1。

    import jax
    import jax.numpy as jnp
    
    rng = jax.random.PRNGKey(0)
    k1, k2, k3 = jax.random.split(rng, 3)
    
    n_nodes, d_in, d_out = 5, 4, 3
    
    # 随机节点特征
    H = jax.random.normal(k1, (n_nodes, d_in))
    
    # 可学习参数
    W = jax.random.normal(k2, (d_in, d_out)) * 0.5
    a = jax.random.normal(k3, (2 * d_out,)) * 0.5
    
    # 邻接关系(节点0连接到1、2、3)
    neighbours_of_0 = [1, 2, 3]
    
    # 变换特征
    Wh = H @ W  # (n_nodes, d_out)
    
    # 计算节点0的注意力分数
    h_i = Wh[0]
    scores = []
    for j in neighbours_of_0:
        h_j = Wh[j]
        e_ij = jnp.dot(a, jnp.concatenate([h_i, h_j]))
        e_ij = jax.nn.leaky_relu(e_ij, negative_slope=0.2)
        scores.append(float(e_ij))
    
    scores = jnp.array(scores)
    alpha = jax.nn.softmax(scores)
    
    print(f"原始分数: {scores}")
    print(f"注意力权重: {alpha}")
    print(f"权重之和: {alpha.sum():.4f}")
    
    # 加权聚合
    h_new = sum(alpha[k] * Wh[neighbours_of_0[k]] for k in range(len(neighbours_of_0)))
    print(f"更新后的节点0特征: {h_new}")
    

  2. 比较 GCN(固定权重)与 GAT(学习权重)的聚合。展示 GAT 可以对邻居分配不同的权重,而 GCN 则视它们为均等的。

    import jax
    import jax.numpy as jnp
    
    # 4个节点:节点0连接到1、2、3
    A = jnp.array([[0,1,1,1],
                   [1,0,0,0],
                   [1,0,0,0],
                   [1,0,0,0]], dtype=float)
    
    # 特征:节点1非常相关,节点2是噪声,节点3中等
    H = jnp.array([[0.0, 0.0],   # 节点0
                   [1.0, 0.0],   # 节点1(信号)
                   [0.0, 0.0],   # 节点2(噪声)
                   [0.5, 0.0]])  # 节点3(中等)
    
    # GCN:归一化邻接权重
    A_hat = A + jnp.eye(4)
    D_inv = jnp.diag(1.0 / A_hat.sum(axis=1))
    gcn_weights = (D_inv @ A_hat)[0]  # 节点0的权重
    print(f"节点0的 GCN 权重: {gcn_weights}")
    print("  → 所有邻居获得大致相等的权重")
    
    # GAT:学习到的注意力(模拟)
    # 假设注意力机制学会关注节点1
    gat_weights = jnp.array([0.1, 0.7, 0.05, 0.15])  # 学习到的
    print(f"\n节点0的 GAT 权重: {gat_weights}")
    print("  → 节点1(信息量大)获得最多的注意力")
    
    gcn_output = gcn_weights @ H
    gat_output = gat_weights @ H
    print(f"\nGCN 输出: {gcn_output}  (被噪声稀释)")
    print(f"GAT 输出: {gat_output}  (聚焦于信号)")
    

  3. 展示位置编码的益处。为一张图计算拉普拉斯特征向量编码,并展示结构相似的节点得到相似的编码。

    import jax.numpy as jnp
    import matplotlib.pyplot as plt
    
    # 杠铃图:两个团通过一座桥连接
    n = 10
    A = jnp.zeros((n, n))
    # 团1:节点0-4
    for i in range(5):
        for j in range(i+1, 5):
            A = A.at[i,j].set(1).at[j,i].set(1)
    # 团2:节点5-9
    for i in range(5, 10):
        for j in range(i+1, 10):
            A = A.at[i,j].set(1).at[j,i].set(1)
    # 桥
    A = A.at[4,5].set(1).at[5,4].set(1)
    
    D = jnp.diag(A.sum(axis=1))
    L = D - A
    eigenvalues, eigenvectors = jnp.linalg.eigh(L)
    
    # 使用前3个非平凡特征向量作为位置编码
    pe = eigenvectors[:, 1:4]
    
    print("拉普拉斯位置编码:")
    for i in range(n):
        group = "团1" if i < 5 else "团2"
        bridge = " (桥节点)" if i in [4, 5] else ""
        print(f"  节点 {i} ({group}{bridge}): {pe[i]}")
    
    plt.scatter(pe[:5, 0], pe[:5, 1], c="#3498db", s=80, label="团1")
    plt.scatter(pe[5:, 0], pe[5:, 1], c="#e74c3c", s=80, label="团2")
    plt.scatter(pe[[4,5], 0], pe[[4,5], 1], c="black", s=120, marker="*",
                label="桥节点", zorder=5)
    plt.legend(); plt.grid(True)
    plt.title("拉普拉斯特征向量位置编码")
    plt.xlabel("特征向量1"); plt.ylabel("特征向量2")
    plt.show()