3D图网络¶
3D图网络将GNN扩展到具有空间几何的数据,其中必须正确处理旋转和平移。本文涵盖几何图、SE(3)/E(n)-等变性、SchNet、DimeNet、EGNN、张量场网络,以及在分子性质预测、蛋白质结构、材料科学和药物发现中的应用——这些架构能够从3D物理世界中学习。
-
文件3和文件4中的GNN作用于抽象图:节点具有特征,边编码连接性,但没有3D空间的概念。社交网络图没有几何结构。但GNN许多最有影响力的应用涉及存在于物理3D空间中的数据:分子、蛋白质、晶体、点云。对于这些数据,节点的空间位置携带了抽象GNN所忽略的关键信息。
-
挑战在于3D数据具有几何对称性(文件1):旋转一个分子不会改变其性质,平移同样如此。一个3D GNN必须尊重这些对称性。当你旋转分子时,能量预测发生变化,这在物理上就是错误的。
几何图¶
-
几何图是嵌入3D空间的图。每个节点 \(i\) 除了其特征向量 \(\mathbf{h}_i\) 外,还有一个位置 \(\mathbf{r}_i \in \mathbb{R}^3\)。边可以由空间接近性(连接距离 \(r_{\text{cut}}\) 内的节点)定义,而不是由明确的化学键定义。
-
对于分子,几何图以原子为节点(具有特征:元素类型、电荷等),以化学键为边。3D位置 \(\mathbf{r}_i\) 是原子的坐标,由量子力学或实验测量(X射线晶体学、冷冻电镜)确定。
-
对于点云(来自LiDAR或3D扫描仪,第8章和第11章),每个点是一个节点,具有位置和可选的额外特征(颜色、强度)。边连接附近的点,形成k-最近邻(kNN)图或半径图。
-
消息传递中关键的几何量包括:
-
原子间距离:\(d_{ij} = \|\mathbf{r}_i - \mathbf{r}_j\|\)。距离对于旋转和平移具有不变性。具有相同原子间距离的两个分子,无论取向如何,形状都是相同的。
-
键角:节点 \(i\) 处向量 \(\mathbf{r}_j - \mathbf{r}_i\) 与 \(\mathbf{r}_k - \mathbf{r}_i\) 之间的夹角 \(\theta_{ijk}\)。角度捕获了成对距离之外的局部几何形状。
-
二面角(扭转角):由 \((i, j, k)\) 和 \((j, k, l)\) 确定的平面之间的夹角 \(\phi_{ijkl}\)。二面角捕捉了结构在3D中如何扭转,这对于蛋白质主链几何形状至关重要。
-
相对位置向量:\(\mathbf{r}_{ij} = \mathbf{r}_j - \mathbf{r}_i\)。这些量是平移不变的,但不是旋转不变的。使用它们需要等变(而不仅仅是不变)的架构。
-
SE(3) 和 E(n) 等变性¶
-
3D物理数据的对称群是欧几里得群 \(E(3)\),由所有旋转、反射和平移组成。子群 \(SE(3)\)(特殊欧几里得群)包括旋转和平移,但不包括反射。
-
一个3D GNN 应该满足:
- 对于标量输出(能量、结合亲和力)是平移不变的:将所有原子平移相同的向量不应改变预测。
- 对于标量输出是旋转不变的:旋转分子不应改变其能量。
- 对于向量/张量输出(力、偶极矩)是旋转等变的:旋转分子应使预测的力向量以相同的方式旋转。
- 形式化地,对于标量预测 \(f\) 和旋转 \(R \in SO(3)\):
- 对于向量预测 \(\mathbf{F}\):
-
这些约束直接体现了文件1中的不变性/等变性框架,现在具体应用于3D旋转和平移群。
-
两种设计方法:
- 不变架构:仅使用不变的几何特征(距离、角度)作为消息传递的输入。内部表示是标量(不变的)。简单高效,但无法在不破坏对称性的情况下产生向量输出。
- 等变架构:在整个网络中保持向量(以及更高阶张量)表示,确保每一层都是等变的。表达能力更强,可以自然地预测向量和张量,但更复杂。
SchNet:基于距离的消息传递¶
-
SchNet(Schütt 等人,2017)是基础性的不变3D GNN。其关键创新是连续滤波器卷积:不是使用固定的边类型集(如分子GNN中的键类型),而是直接从原子间距离生成消息滤波器。
-
距离 \(d_{ij}\) 首先通过径向基函数(RBF)扩展为一个特征向量:
-
每个基函数是一个以 \(\mu_k\) 为中心、宽度为 \(\gamma_k\) 的高斯函数。这类似于距离的可学习位置编码:连续的距离被映射到一个高维特征空间,网络可以在其中学习距离依赖的相互作用。中心 \(\mu_k\) 通常从0到截断半径均匀分布。
-
SchNet 中从 \(j\) 到 \(i\) 的消息为:
-
其中 \(W_{\text{filter}}\) 是一个将RBF扩展映射到滤波器向量的MLP,\(\odot\) 是逐元素乘法(Hadamard积,第2章)。滤波器依赖于距离,因此邻近原子与远处原子的相互作用方式不同。逐元素乘法类似于门控机制(第6章):依赖于距离的滤波器控制每个特征维度有多少信息能够通过。
-
由于SchNet只使用距离(不变的),整个模型自动对旋转和平移保持不变。除了这个设计选择外,不需要对对称性进行特殊处理。
DimeNet 和 SphereNet:角度与二面角¶
-
仅凭距离无法完全确定3D结构。两种不同的分子构象可能具有相同的成对距离,但键角不同(这就是“距离几何歧义性”问题)。DimeNet(Gasteiger 等人,2020)将键角纳入消息传递中。
-
DimeNet 使用定向消息传递:消息沿着有向边流动,边 \((j \to i)\) 上的消息受边 \((k \to j)\) 和 \((j \to i)\) 之间角度的影响:
-
角度 \(\theta_{kji}\) 使用球贝塞尔函数和球谐函数(球面上角度信息的自然基,类似于距离的RBF)进行扩展。这使得模型在保持不变性的同时能够获取方向信息。
-
SphereNet(Liu 等人,2022)进一步包含二面角 \(\phi_{lkji}\),捕获完整的3D扭转结构。层次结构如下:
- 距离 → 捕捉成对接近程度
- 角度 → 捕捉局部几何形状(弯曲 vs 直线)
- 二面角 → 捕捉3D扭转(对蛋白质主链、药物结合至关重要)
-
每个级别的增加都以计算复杂度为代价带来更高的几何分辨率(距离为 \(O(|E|)\),角度为 \(O(|E| \cdot k)\),二面角为 \(O(|E| \cdot k^2)\),其中 \(k\) 是平均度)。
E(n) 等变 GNN (EGNN)¶
-
EGNN(Satorras 等人,2021)采用等变方法:不是仅使用不变特征,而是在每一层中同时更新节点特征和节点位置,始终保持等变性。
-
EGNN 对节点 \(i\) 的更新:
-
关键在于位置更新:节点位置通过相对位置向量 \((\mathbf{r}_i - \mathbf{r}_j)\) 的加权和来调整。权重来自消息函数 \(\phi_r\),该函数仅依赖于不变量(特征和距离)。这种构造可证明是等变的:如果所有输入位置被旋转 \(R\),则所有输出位置也被相同的 \(R\) 旋转。
-
EGNN 的精妙之处在于,它不需要显式使用球谐函数或不可约表示就能实现等变性。相对位置向量携带方向信息,不变的消息函数控制如何使用这些方向信息。
-
简洁性也有代价:EGNN 仅使用向量表示(阶数为1)。若不扩展,它无法表示高阶张量,如四极矩或应力张量。
张量场网络与高阶表示¶
-
张量场网络(Thomas 等人,2018)及其后继模型(SE(3)-Transformer、MACE、Equiformer)使用旋转群不可约表示的完整机制来构建等变层。
-
在表示论中(联系第2章的线性代数),3D中的旋转可以分解为由整数阶 \(\ell\) 刻画的不可约分量:
- \(\ell = 0\):标量(1个分量,不变)。能量、电荷。
- \(\ell = 1\):向量(3个分量,像位置向量一样旋转)。力、偶极矩。
- \(\ell = 2\):二阶对称无迹张量(5个分量)。四极矩、应力张量。
- 更高 \(\ell\):捕捉日益复杂的角结构。
-
这些被称为球张量,它们在旋转 \(R\) 下通过 Wigner-D 矩阵 \(D^\ell(R)\) 变换:标量不变,向量通过 \(R\) 旋转,二阶张量通过更复杂的矩阵旋转。
-
使用球张量的等变消息传递利用 Clebsch-Gordan 张量积来组合不同阶的特征:
-
Clebsch-Gordan 系数 \(C\) 是固定的数学常数,确保张量积是等变的。这是矩阵乘法的 SO(3)-等变类比。
-
MACE(Batatia 等人,2022)使用高阶消息(多个邻居特征的乘积)以更少的消息传递层实现高精度。通过构建多体相互作用(距离对应2体,角度对应3体,张量积对应多体),MACE 高效地捕获了复杂的原子间相互作用。
-
Equiformer(Liao & Smidt,2023)将等变球张量特征与 Transformer 注意力机制(文件4)相结合,创建了 SE(3)-等变图 Transformer。注意力分数从不变量特征计算,而值聚合则在等变张量特征上进行。
应用¶
-
分子性质预测:给定分子的3D结构,预测其能量、力、偶极矩、HOMO-LUMO 能隙、毒性、溶解度等性质。这是3D GNN 最成熟的应用。在量子化学数据集(QM9、OC20)上训练的模型,在许多性质上达到了化学精度,使得对数百万候选分子进行虚拟筛选成为可能。
-
分子动力学加速:用量子力学(密度泛函理论,DFT)计算原子间力极其昂贵(对于 \(n\) 个电子,复杂度为 \(O(n^3)\))。经过训练来预测力的3D GNN 可以在分子动力学模拟中替代 DFT,实现 \(10^3\)–\(10^6\) 倍的加速,同时保持接近 DFT 的精度。这使得模拟更大体系和更长时间尺度成为可能,揭示传统方法无法观察到的现象。
-
蛋白质结构:蛋白质是由氨基酸组成的链,折叠成复杂的3D结构。蛋白质主链是一个几何图,其中节点是残基,边连接空间上邻近的残基。3D GNN 用于蛋白质功能预测、结合位点识别和蛋白质设计(逆折叠:给定目标结构,预测氨基酸序列)。AlphaFold 使用几何和基于图的推理从序列预测蛋白质结构。
-
材料科学与催化:晶体材料具有周期性的3D结构。GNN 对重复单元晶胞进行建模,并预测材料性质:带隙、形成能、机械强度。开放催化剂项目(OC20/OC22)对用于预测催化表面吸附能的 GNN 进行基准测试,加速寻找可再生能源的新催化剂。
-
药物发现:3D GNN 预测药物分子如何与靶点蛋白结合。结合亲和力取决于药物与蛋白结合口袋之间的3D形状互补性和化学相互作用。像 DiffDock 这样的模型使用等变 GNN 结合扩散模型(第8章)来预测结合姿态(药物在蛋白口袋中的3D取向)。
图生成¶
-
上述所有架构都是分析现有图。图生成则是创建新的图:设计具有所需性质的分子,生成用于测试的合成社交网络,或提出新的蛋白质结构。这是图级别预测的生成式对应任务。
-
挑战在于图是离散的、大小可变且组合爆炸的。生成一个图意味着决定创建多少个节点、它们具有什么特征以及哪些节点对之间相连。可能的图空间随节点数的增加呈超指数增长。
-
自回归生成一次生成一个节点(或一条边)。GraphRNN(You 等人,2018)顺序生成图:一个 RNN 维护一个状态,每一步生成一个新节点,并决定将其连接到哪些现有节点。生成顺序给本来无序的图强加了一个人为的序列,但 BFS 顺序通过保持最近生成的节点相关性而有所帮助。
-
基于VAE的生成将图编码到连续的潜在空间中(使用 GNN 编码器),然后从采样的潜在向量解码生成新图。GraphVAE 一次性生成概率性邻接矩阵 \(\hat{A} \in [0, 1]^{n \times n}\),但其规模为 \(O(n^2)\),并产生必须进行阈值处理的稠密输出。潜在空间允许平滑插值:在两个分子嵌入之间移动可以生成化学上有效的中间结构。
-
基于扩散的生成将扩散框架(第8章)应用于图。前向过程逐渐向节点特征和边结构添加噪声。反向过程学习去噪,从噪声中生成有效图。DiGress(Vignac 等人,2023)将离散扩散应用于节点类型和边类型,自然地处理了图数据的类别性质。
-
对于分子生成,关键约束是化学有效性:生成的分子必须遵守价键规则(碳形成4个键,氧形成2个键,等等)。像 结树 VAE (JT-VAE) 这样的方法将分子分解为有效的子结构(环、链、官能团),并通过组装这些构建块来生成,从而通过构造保证有效性。
-
目标导向生成针对特定性质进行优化:生成一个与靶点蛋白具有高结合亲和力、低毒性和良好溶解度的分子。这在一个循环中结合了图生成与性质预测(使用3D GNN作为性质评估器):生成 → 评估 → 改进。强化学习(第6章)或贝叶斯优化引导对化学空间的搜索。
-
DiffDock(Corso 等人,2023)使用 SE(3)-等变扩散来预测药物分子如何对接进入蛋白质结合口袋。该模型通过从随机放置中去噪来生成3D结合姿态(药物相对于蛋白质的位置和取向),结合了本文中的3D等变网络和第8章中的扩散框架。
编码任务(使用 CoLab 或 notebook)¶
-
构建一个简单的基于原子间距离的不变3D消息传递层。将其应用于一个小分子(水:H-O-H),并验证输出对旋转具有不变性。
import jax import jax.numpy as jnp # 水分子:O在原点,两个H原子 positions = jnp.array([[0.0, 0.0, 0.0], # O [0.96, 0.0, 0.0], # H1 [-0.24, 0.93, 0.0]]) # H2 # 节点特征:[原子序数] features = jnp.array([[8.0], [1.0], [1.0]]) # 计算成对距离(不变量) def pairwise_distances(pos): diff = pos[:, None, :] - pos[None, :, :] return jnp.sqrt(jnp.sum(diff**2, axis=-1) + 1e-8) # 简单的基于距离的消息传递 def invariant_message_pass(features, positions): dists = pairwise_distances(positions) # 4个中心的RBF扩展 centres = jnp.array([0.5, 1.0, 1.5, 2.0]) rbf = jnp.exp(-5.0 * (dists[:, :, None] - centres[None, None, :]) ** 2) # 消息:由距离依赖的滤波器加权的特征 messages = jnp.einsum("ij,jd->id", rbf.sum(axis=-1), features) return messages output1 = invariant_message_pass(features, positions) # 绕z轴旋转分子90度 R = jnp.array([[0, -1, 0], [1, 0, 0], [0, 0, 1]], dtype=float) rotated_positions = (R @ positions.T).T output2 = invariant_message_pass(features, rotated_positions) print(f"原始输出:\n{output1}") print(f"\n旋转后输出:\n{output2}") print(f"\n不变性: {jnp.allclose(output1, output2, atol=1e-5)}") -
计算三个原子之间的键角,并验证其对旋转具有不变性。
import jax.numpy as jnp def bond_angle(r_i, r_j, r_k): """计算节点j处边j->i和j->k之间的夹角。""" v1 = r_i - r_j v2 = r_k - r_j cos_angle = jnp.dot(v1, v2) / (jnp.linalg.norm(v1) * jnp.linalg.norm(v2)) return jnp.arccos(jnp.clip(cos_angle, -1, 1)) # 三个原子 r1 = jnp.array([1.0, 0.0, 0.0]) r2 = jnp.array([0.0, 0.0, 0.0]) r3 = jnp.array([0.0, 1.0, 0.0]) angle_original = bond_angle(r1, r2, r3) print(f"原始角度: {jnp.degrees(angle_original):.1f}°") # 应用随机旋转 R = jnp.array([[0.36, 0.48, -0.80], [-0.80, 0.60, 0.00], [0.48, 0.64, 0.60]]) r1_rot, r2_rot, r3_rot = R @ r1, R @ r2, R @ r3 angle_rotated = bond_angle(r1_rot, r2_rot, r3_rot) print(f"旋转后角度: {jnp.degrees(angle_rotated):.1f}°") print(f"不变性: {jnp.allclose(angle_original, angle_rotated, atol=1e-4)}") -
演示等变位置更新(EGNN风格)。使用距离加权的相对向量更新节点位置,并验证等变性。
import jax import jax.numpy as jnp def egnn_position_update(positions, features): """简单的EGNN风格等变位置更新。""" n = positions.shape[0] new_positions = jnp.zeros_like(positions) for i in range(n): shift = jnp.zeros(3) for j in range(n): if i != j: r_ij = positions[i] - positions[j] d_ij = jnp.linalg.norm(r_ij) # 基于距离的权重(简单形式:逆距离) weight = 1.0 / (d_ij + 1.0) # 按特征相似度缩放 feat_sim = jnp.dot(features[i], features[j]) shift = shift + weight * feat_sim * r_ij new_positions = new_positions.at[i].set(positions[i] + 0.1 * shift) return new_positions # 3个原子 pos = jnp.array([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [0.0, 1.0, 0.0]]) feat = jnp.array([[1.0, 0.5], [0.5, 1.0], [0.8, 0.3]]) # 更新位置 pos_new = egnn_position_update(pos, feat) # 现在旋转输入,进行更新,检查输出是否一致地旋转 R = jnp.array([[0.0, -1.0, 0.0], [1.0, 0.0, 0.0], [0.0, 0.0, 1.0]]) pos_rot = (R @ pos.T).T pos_new_from_rot = egnn_position_update(pos_rot, feat) # 应该等于先旋转原始输出 pos_new_then_rot = (R @ pos_new.T).T print(f"先更新后旋转:\n{jnp.round(pos_new_then_rot, 4)}") print(f"\n先旋转后更新:\n{jnp.round(pos_new_from_rot, 4)}") print(f"\n等变性: {jnp.allclose(pos_new_then_rot, pos_new_from_rot, atol=1e-4)}")