无消息传递的图变换器中的图归纳偏差
人工智能咨询培训老师叶梓 转载标明出处
在处理小规模数据集时,图变换器的性能通常不尽如人意,特别是在需要明显的归纳偏好时。为了引入这些偏好,早期的图变换器一般会利用消息传递组件或位置编码。然而,依赖消息传递的图变换器在将研究成果应用到其他领域时遇到了难题,因为它们与其它领域的变换器有较大差异。针对这一挑战,来自麦吉尔大学、牛津大学工程科学系、麻省理工学院的CSAIL、MetaAI以及魁北克人工智能研究所(Mila)的研究人员们提出了一种新型的图变换器——图归纳偏差变换器(Graph Inductive bias Transformer,简称GRIT)。GRIT不依赖于显式的消息传递机制,而是通过三种关键设计决策来内嵌图归纳偏好,以此优化模型对图数据的处理能力。
方法
GRIT架构引入了一种新颖的灵活注意力机制和通用的相对位置编码方案,它不依赖任何显式的局部消息传递模块。这一设计包含三个核心决策:
-
学习随机游走相对位置编码:通过随机游走概率初始化的相对位置编码,相较于传统的最短路径距离编码,它具有更强的表达能力,能够捕捉更丰富的图传播矩阵信息。
-
灵活的注意力机制:该机制不仅能更新节点表示,还能更新节点对表示,从而更全面地捕获节点间的相对位置信息,并通过注意力层更新位置编码,进一步提升模型的表达能力。
-
整合节点度信息:为了保留节点的度信息,引入了自适应的度量缩放器,确保模型在每一层都能考虑到节点的连接丰富度。
通过结合随机游走相对位置编码(RRWP)和多层感知器(MLP),该方法能够近似最短路径距离或一般类别的图传播矩阵,从而证明了其强大的表达能力。这种编码方式能够捕获节点间多跳的相对位置信息,对深入理解图结构至关重要。
Figure 1和Figure 2的可视化显示,随着随机游走步数的增加,RRWP能够揭示更高阶的结构信息,如图中的团簇和星形模式。这表明RRWP能有效捕获图中的关键信息,并随着游走步数的增加,更好地突出社区结构,减少瓶颈。
在当前的自注意力机制设计中,通常基于节点级的位置编码和表示,但这并不足以完全捕获节点对之间的相对位置信息。因此,提出了一种新的方法来计算注意力分数,通过考虑节点对的相对表示,结合了一般条件层和GATv2的优点。
在每个变换器层中,模型更新节点表示 和节点对表示 ,这些表示最初使用初始节点特征和RRWP位置编码进行初始化。注意力计算过程如下:
其中,σ 是非线性激活函数,默认为ReLU;和是可学习的权重矩阵;⊙ 表示逐元素乘法;是带符号的平方根,通过减少大输入的幅度来稳定训练。
使用最近提出的Weisfeiler-Lehman类图同构测试(GD-WL)来证明,在变换器架构中,RRWP比传统的最短路径距离(SPD)具有更强的表达能力。
为了解决注意力机制在处理图结构数据时对节点度的不变性问题,引入了自适应的度量缩放器来维持度信息。在计算节点表示后,模型将度信息注入到节点表示中:
其中是节点i 的度, 是可学习的权重。这样做是为了在使用标准前馈网络(FFN)更新节点表示之前,先注入度量信息。
为了正确包含度信息,选择使用批量归一化而不是标准的层归一化,因为层归一化可能会抵消度量缩放器或总和聚合器的效果。
通过引入灵活的注意力机制和度信息注入,GRIT模型能够在不依赖显式消息传递的情况下,有效地处理图数据。这些机制的引入显著提高了模型对图结构的理解和表达能力,使得GRIT在多个图学习任务中取得了优异的性能。
想要掌握如何将大模型的力量发挥到极致吗?叶老师带您深入了解 Llama Factory —— 一款革命性的大模型微调工具。实战专家1小时讲解让您轻松上手,学习如何使用 Llama Factory 微调模型。
评论留言“参加”或扫描微信备注“参加”,即可参加线上直播分享,叶老师亲自指导,互动沟通,全面掌握Llama Factory。关注享粉丝福利,限时免费录播讲解。
LLaMA Factory 支持多种预训练模型和微调算法。它提供灵活的运算精度和优化算法选择,以及丰富的实验监控工具。开源特性和社区支持使其易于使用,适合各类用户快速提升模型性能。
实验
GRIT在五个来自Benchmarking GNNs的基准测试和两个Long-Range Graph Benchmark的测试中接受了评估。这些测试覆盖了节点分类、图分类和图回归等多种任务,特别注重图结构编码、节点聚类和长距离依赖性学习。测试的数据集包括ZINC-full图数据集(约25万个图)和PCQM4Mv2数据集(约370万个图)。
在基准测试中,GRIT与目前最先进的混合图变换器GraphGPS以及其他多种流行的图学习模型进行了比较,包括MPNNs、图变换器以及其它一些性能领先的图神经网络。
实验结果如Table 1和Table 2所示。表格展示了在不同数据集上使用不同模型的性能指标,包括平均绝对误差(MAE)和准确率(Accuracy)。在多数情况下,GRIT都获得了最优的平均性能,并在统计上显著优于其他模型。
在ZINC-full数据集的测试中,如Table 3所示,GRIT同样展现了其卓越的性能,与其他各种方法相比,GRIT取得了最佳的平均性能。
在PCQM4Mv2的大规模图回归基准测试中,如Table 4所示,GRIT与多种MPNNs和图变换器进行了比较。尽管GRIT使用的可学习参数更少,但其性能与GraphGPS和Graphormer相当。
消融实验如Table 5所示。实验结果表明,移除度量缩放器、更新机制的RRWP、将全局注意力替换为稀疏注意力、替换度量编码和注意力机制,以及将RRWP替换为RWSE或SPDPE等操作,都会导致性能下降。
对RRWP的参数K进行的敏感性分析结果显示,如Table 6所示,对于许多K值的选择,GRIT方法都是最先进或接近最先进,除了像K=2这样不合理的选择。
合成实验研究了GRIT模型的注意力模块是否能够模仿一般类别的图传播矩阵,如Table 7所示。实验结果表明,与其他图变换器相比,GRIT显著优于其他基线,能够更好地近似一般类别的图传播。
实验证明了GRIT模型能够在不依赖消息传递的情况下,通过整合图归纳偏差,在多个图数据集上实现最前沿的性能。
https://arxiv.org/pdf/2305.17589v1
GitHub - LiamMa/GRIT: This is an official implementation for "GRIT: Graph Inductive Biases in Transformers without Message Passing".