当前位置: 首页 > news >正文

DGL库之HGTConv的使用

DGL库之HGTConv的使用

  • 论文地址和异构图构建教程
  • HGTConv语法格式
  • HGTConv的使用

论文地址和异构图构建教程

论文地址:https://arxiv.org/pdf/2003.01332
异构图构建教程:异构图构建
异构图转同构图:异构图转同构图

HGTConv语法格式

dgl.nn.pytorch.conv.HGTConv(in_size, head_size, num_heads, num_ntypes, num_etypes, dropout=0.2, use_norm=False)

参数说明:

  • in_size (int): 输入节点特征的大小。
  • head_size (int): 输出头的大小。输出节点特征的大小为 head_size * num_heads。
  • num_heads (int): 头的数量。输出节点特征的大小为 head_size * num_heads。
  • num_ntypes (int): 节点类型的数量。
  • num_etypes (int): 边类型的数量。
  • dropout (可选, float): dropout 比率,用于防止过拟合。
  • use_norm (可选, bool): 如果为 True,则在输出节点特征上应用层归一化。
forward(g, x, ntype, etype, *, presorted=False)

参数说明:

  • g (DGLGraph): 输入的图对象。

  • x (torch.Tensor): 一个 2D 张量,表示节点特征。其形状应为 (num_nodes, in_size),num_nodes 是节点数量,in_size 是输入特征的维度。

  • ntype (torch.Tensor): 一个 1D 整数张量,表示节点类型。其形状应为 (num_nodes,),对应每个节点的类型索引。

  • etype (torch.Tensor): 一个 1D 整数张量,表示边类型。其形状应为 (num_edges,),对应每条边的类型索引。

  • presorted (bool, 可选): 指示输入图的节点和边是否已经按照类型排序。如果输入图是预排序的,则前向传播可能会更快。通过调用 to_homogeneous()创建的图会自动满足此条件。也可以使用 reorder_graph() 方法手动重新排序节点和边。

返回值:

  • 返回的新节点特征: 返回的特征是一个 2D 张量,其形状为 (num_nodes, head_size * num_heads),表示经过HGTConv 处理后的新节点特征,返回的张量类型为 torch.Tensor。

HGTConv的使用

使用的异构图如下:
在这里插入图片描述
在使用HGTConv时,一定要使用dgl.to_homogeneous将异构图转为同构图,否则不能使用,代码如下:

import dgl
import torch
import torch.nn as nn
import dgl.nn.pytorch# 定义一个简单的异构图
def create_hetero_graph():# 定义两个类型的节点:drug(药物)和 disease(疾病)data_dict = {('drug', 'd_interacts', 'drug'): (torch.tensor([0, 1]), torch.tensor([1, 2])),  # 药物间的相互作用('drug', 'g_interacts', 'gene'): (torch.tensor([0, 1]), torch.tensor([2, 3])),  # 药物与基因间的相互作用('drug', 'treats', 'disease'): (torch.tensor([1]), torch.tensor([2]))           # 药物与疾病的关系}# 创建一个异构图hetero_graph = dgl.heterograph(data_dict)# 设置节点和边的特征hetero_graph.nodes['drug'].data['h'] = torch.ones(3, 320)  # 假设药物特征是320维的hetero_graph.nodes['disease'].data['h'] = torch.zeros(3, 320)  # 假设疾病特征是320维的hetero_graph.nodes['gene'].data['h'] = torch.ones(4, 320)  # 假设基因特征是320维的return hetero_graph# 定义一个HGT模型类
class HGTModel(nn.Module):def __init__(self, in_dim, out_dim, num_heads, num_layers, num_node_types, num_edge_types, dropout=0.2):super(HGTModel, self).__init__()# 使用 dgl.nn.pytorch.conv.HGTConv 初始化 HGT 卷积层self.layers = nn.ModuleList()  # 创建一个空的层列表for _ in range(num_layers):layer = dgl.nn.pytorch.conv.HGTConv(in_dim,  # 输入维度out_dim,  # 输出维度num_heads,  # 注意力头的数量num_node_types,  # 节点类型数量num_edge_types,  # 边类型数量dropout=dropout  # dropout比率)self.layers.append(layer)  # 将层添加到列表中def forward(self, g):with g.local_scope():  # 创建一个局部作用域,‌确保对图的操作不会影响原始图。‌for layer in self.layers:# 使用HGTConv层进行卷积操作h = layer(g, g.ndata['h'], g.ndata['_TYPE'], g.edata['_TYPE'], presorted=True)g.ndata['h'] = h  # 更新节点特征return g.ndata['h']  # 返回最后一层的节点特征# 创建一个异构图
hetero_graph = create_hetero_graph()print('异构图为:\n',hetero_graph)  # 输出异构图的信息
# 将异构图转换为同构图
homogeneous_graph = dgl.to_homogeneous(hetero_graph, ndata=['h'])
print(f"节点特征矩阵为:\n{homogeneous_graph.ndata['h']}")  # 打印节点特征的类型# 创建模型并移动到 CPU 设备
hgt_model = HGTModel(in_dim=320, out_dim=80, num_heads=4, num_layers=2,num_node_types=3, num_edge_types=3, dropout=0.3).to(torch.device('cpu'))# 前向传播
output_features = hgt_model(homogeneous_graph)print("更新后的特征:\n", output_features)  # 输出特征的形状

结果如下:
在这里插入图片描述


http://www.mrgr.cn/news/47520.html

相关文章:

  • 每日算法Day14【删除二叉搜索树中的节点、修剪二叉搜索树、将有序数组转换为二叉搜索树、把二叉搜索树转换为累加树】
  • 相机和激光雷达的外参标定 - 无标定板版本
  • 为AI聊天工具添加一个知识系统 开发环境准备
  • Linux 高级路由 —— 筑梦之路
  • vscode支持ssh远程开发
  • 一文读懂「LoRA」:大型语言模型的低秩适应
  • JavaGuide(3)
  • IDM6.42下载器最新版本,提速你的网络生活!
  • Python的输入输出函数
  • 心觉:开发潜意识的详细流程和步骤是什么
  • 跟《经济学人》学英文:2024年10月05日这期 Workouts for the face are a growing business
  • 从0开始深度学习(7)——线性回归的简洁实现
  • 等保测评:如何建立有效的网络安全监测系统
  • 代码随想录算法训练营第四十六天 | 647. 回文子串,516.最长回文子序列
  • ssm基于Javaee的影视创作论坛的设计与实现
  • 论文《OneLLM:One Framework to Align All Modalities with Language》
  • [SAP ABAP] LIKE TABLE OF
  • netty详细说明ByteBuf的使用
  • 五、创建型(建造者模式)
  • 怎么快速申请CNAS认证
  • 【C++篇】虚境探微:多态的流动诗篇,解锁动态的艺术密码
  • PHP静态化和伪静态如何实现的
  • 计算机网络803-(4)网络层
  • 电动牙刷拆解学习
  • threejs-基础材质设置
  • EcoVadis认证内容有哪些?EcoVadis认证申请流程?