当前位置: 首页 > news >正文

PyG教程:MessagePassing基类

PyG教程:MessagePassing基类

  • 一、引言
  • 二、如何自定义消息传递网络
    • 1.构造函数
    • 2.propagate函数
    • 3.message函数
    • 4.aggregate函数
    • 5.update函数
  • 三、代码实战
    • 1.图数据定义
    • 2.实现GNN的消息传递过程
    • 3.完整代码
    • 4.完整代码的精简版本
  • 四、总结
    • 1.MessagePassing各个函数的执行顺序
    • 2.参考资料

一、引言

PyG框架中提供了一个消息传递基类torch_geometric.nn.MessagePassing,它实现了消息传递的自动处理,继承该类可以简单方便的构建自己的消息传播GNN。

二、如何自定义消息传递网络

要自定义GNN模型,首先需要继承MessagePassing类,然后重写如下方法:

  • message(...):构建要传递的消息;
  • aggregate(...):将从源节点传递过来的消息聚合到目标结点;
  • update(...):更新节点的消息。

上述方法并不是一定都要自定义,若MessagePassing类默认实现满足你的需求,则可以不重写。

1.构造函数

继承MessagePassing类后,在构造函数中可以通过super().__init__方法来向基类MessagePassing传递参数,来指定消息传递过程中的一些行为。MessagePassing类的初始化函数如下:
在这里插入图片描述
参数说明:

参数名参数说明
aggr消息传递中的消息聚合方式,常用的包括summeanminmaxmul等等。default: sum
flow消息传播的方向,其中source_to_targe表示从源节点到目标节点、target_to_source表示从目标节点到源节点。default:source_to_target
node_dim传播的维度,default:-2
decomposed_layers这个参数没用过,我也还不知道,后面会更新。

2.propagate函数

在具体介绍消息传递的三个相关函数之前,首先先介绍propagate函数,该函数是消息传递的启动函数,调用该函数后依次会执行messageaggregateudpate函数来完成消息的传递聚合更新。该函数的声明如下:
在这里插入图片描述
参数说明:

参数名参数说明
edge_index边索引
size这个参数目前我理解的不是很透彻,后面透彻了补一下
**kwargs构建、聚合和更新消息所需的额外数据,都可以传入propagate函数,这些参数可以在消息传递过程中的三个函数中接收。

该函数一般会传入edge_index和特征x

3.message函数

message函数是用来构建节点的消息的。传递给propagate函数的tensor可以映射到中心(target)节点邻居(source)节点上,只需要在相应变量名后加上_ior_j即可,通常称_i为中心(target)节点,称_j为邻居(source)节点。

source节点和target节点的关系:
在这里插入图片描述
message实现源码:
在这里插入图片描述

从源码的默认实现可以看出,message传递的消息就是邻居节点自身的特征向量。

示例:

def forward(self, data):out = self.propagate(edge_index, x=x)passdef message(self, x_i, x_j, edge_index_i, edge_index_j):pass

该例子中利用propagate函数传递了两个参数edge_indexx,则message函数可以根据propagate函数中的两个参数构造自己的参数,上述message函数中的构造参数为:

  • x_i:中心节点(target)的特征向量组成的矩阵,注意该矩阵与图节点的矩阵x是不同的;
  • x_j:邻居节点(source)的特征向量组成的矩阵;
  • edge_index_i:中心节点的索引;
  • edge_index_j:邻居节点的索引。

注意,若flow='source_to_target',则消息将由邻居节点传向中心节点,若flow='target_to_source'则消息将从中心节点传向邻居节点,默认为第一种情况

4.aggregate函数

消息聚合函数aggregate用来聚合来自邻居的消息,常用的包括summeanmaxmin等,可以通过super().__init__()中的参数aggr来设定。该函数的第一个参数为message函数的返回值。

  • aggr='sum' 表示 和聚合,它会对每个特征维度计算所有邻居节点的消息的总和。
  • aggr='mean' 表示 平均值值聚合,它会对每个特征维度计算所有邻居节点的消息的平均值。
  • aggr='max' 表示 最大值聚合,它会对每个特征维度选择所有邻居节点的消息中的最大值。
  • aggr='min' 表示 最小值聚合,它会对每个特征维度选择所有邻居节点的消息中的最小值。

5.update函数

update函数用来更新节点的消息,aggregate函数的返回值作为该函数的第一个参数。

默认实现:
在这里插入图片描述

从默认实现可以看出update函数没有进行任何的操作,只是将raggregate函数的返回值返回了而已。

实际写代码的过程中,我们也不会去重写这个方法,而是,在forward函数中调用完propagate(…)函数后编写代码,代替update函数的功能。

三、代码实战

假设我们设计一个GNN模型,其中消息传递过程用公式表示如下:
X i ( k ) = X i ( k − 1 ) + ∑ j ∈ N ( i ) X j ( k − 1 ) (1) X_i^{(k)} = X_i^{(k-1)} + \sum _{j\in {\mathcal {N(i)}}} X_j^{(k-1) }\tag {1} Xi(k)=Xi(k1)+jN(i)Xj(k1)(1)

  • message生成的消息就是中心节点的邻居节点的特征向量。
  • aggregaet聚合消息的方式是sum,即把所有邻居节点的特征向量加起来。
  • update更新中心节点的方式是:将聚合得到的消息和中心节点自身的特征向量相加。

1.图数据定义

我们有如下数据:

import torch
from torch_geometric.data import Dataedge_index = torch.tensor([[0, 1],[1, 0]], dtype=torch.long)
x = torch.tensor([[-1, 1], [0, 1], [1, 1]], dtype=torch.float)
data = Data(x=x, edge_index=edge_index.contiguous())

在这里插入图片描述

2.实现GNN的消息传递过程

class MyConv(MessagePassing):def __init__(self):super().__init__(aggr='sum')def forward(self, data):out = self.propagate(data.edge_index, x=data.x)# out = out + x return outdef message(self, x_i, x_j, edge_index_i, edge_index_j):# 生成的消息就是邻居节点的特征向量,直接使用 x_j 访问获取就行return x_jdef aggregate(self, message, edge_index_i):# 这里只是写的样例,实际上一般不会重写这个方法,直接使用默认的就好了,只需要自己选择一下聚合的方式即可return super().aggregate(message, edge_index_i, dim_size=len(x))def update(self, aggregate, x):# 一般也不会重写这个方法的,update阶段可以在forward函数中调用完propagate(...)函数后编写代码。return x + aggregate

3.完整代码

import torch
from torch_geometric.data import Data
from torch_geometric.nn import MessagePassingclass MyConv(MessagePassing):def __init__(self):super().__init__(aggr='sum')def forward(self, data):out = self.propagate(data.edge_index, x=data.x)out = out + data.xreturn outdef message(self, x_i, x_j, edge_index_i, edge_index_j):# 生成的消息就是邻居节点的特征向量,直接使用 x_j 访问获取就行return x_j# def aggregate(self, message, edge_index_i):# 	return super().aggregate(message, edge_index_i, dim_size=len(x))# def update(self, aggregate, x):# 	return x + aggregateif __name__ == '__main__':edge_index = torch.tensor([[0, 1],[1, 0]], dtype=torch.long)x = torch.tensor([[-1, 1], [0, 1], [1, 1]], dtype=torch.float)data = Data(x=x, edge_index=edge_index.contiguous())myConv = MyConv()print(myConv(data))

4.完整代码的精简版本

import torch
from torch_geometric.data import Data
from torch_geometric.nn import MessagePassing
from torch_geometric.utils import add_self_loopsclass MyConv(MessagePassing):def __init__(self):super().__init__(aggr='sum')def forward(self, data):edge_index, _ = add_self_loops(data.edge_index, num_nodes=len(data.x))out = self.propagate(edge_index, x=data.x)return outif __name__ == '__main__':edge_index = torch.tensor([[0, 1],[1, 0]], dtype=torch.long)x = torch.tensor([[-1, 1], [0, 1], [1, 1]], dtype=torch.float)data = Data(x=x, edge_index=edge_index.contiguous())myConv = MyConv()print(myConv(data))

思考:大家可以根据上面讲解的细节,理解一下这个精简版本的代码的实现逻辑和过程。

四、总结

1.MessagePassing各个函数的执行顺序

在这里插入图片描述

2.参考资料

  • PyG: MessagePassing
  • PyG: Creating Message Passing Networks

http://www.mrgr.cn/news/78611.html

相关文章:

  • 解析类的泛型参数 Spring之GenericTypeResolver.resolveTypeArgument
  • Maven CMD命令
  • openssl生成ca证书
  • Instagram Reels与TikTok有什么区别?
  • 第四节:jsp内的request和response对象
  • 【大模型】基于LLaMA-Factory的模型高效微调
  • Java ConcurrentHashMap
  • HTTP 1
  • Java Collection
  • uniapp连接mqtt频繁断开原因和解决方法
  • 【组成原理】计算机硬件设计——ALU
  • Maven 配置
  • yolov8的深度学习环境安装(cuda12.4、ubuntu22.04)
  • Spring Boot使用JDK 21虚拟线程
  • 在shardingsphere执行存储过程
  • 机器学习实战:泰坦尼克号乘客生存率预测(数据处理+特征工程+建模预测)
  • vulnhub靶场之hackableⅡ
  • 【C语言】字符串左旋的三种解题方法详细分析
  • Jmeter进阶篇(29)AI+性能测试领域场景落地
  • Linux系统 进程
  • 三十二:网络爬虫的工作原理与应对方式
  • 记录学习《手动学习深度学习》这本书的笔记(一)
  • (Python)前缀和
  • OPTEE v4.4.0 FVP环境搭建(支持hafnium)
  • 【北京迅为】iTOP-4412全能版使用手册-第二十章 搭建和测试NFS服务器
  • Spring源码学习