当前位置: 首页 > news >正文

图神经网络实战(20)——时空图神经网络

图神经网络实战(20)——时空图神经网络

    • 0. 前言
    • 1. 动态图
    • 2. 预测网络流量
    • 3. EvolveGCN 架构
      • 3.1 EvolveGCN-H
      • 3.2 EvolveGCN-O
      • 3.3 模型选用技巧
    • 4. 构建 EvolveGCN
      • 4.1 数据集分析
      • 4.2 模型构建与训练
    • 小结
    • 系列链接

0. 前言

在经典图神经网络 (Graph Neural Networks, GNN) 中,我们只考虑了边和特征不会发生变化的图数据。然而,在现实世界中,多数应用充满动态性。例如,在社交网络中,人们会关注或取关其他用户,个人资料也会随着时间的推移而发生变化。这种动态性无法用经典 GNN 架构来表示。必须嵌入一个新的时间维度,将静态图转化为动态图 (dynamic graph)。然后,这些动态图将作为一类新 GNN 的输入——时空图神经网络 (Temporal Graph Neural Networks, TGNN,或 Spatio-Temporal GNN)。
在本节中,我们将介绍两种包含时空信息的动态图,并重点关注时间序列预测,这也是时空 GNN 的主要应用领域。然后介绍网络流量预测应用,利用时间信息来改进结果,获得可靠预测。

1. 动态图

动态图 (dynamic graph) 和时空图神经网络 (Temporal Graph Neural Networks, TGNN) 开启了各种新应用,如交通和网络流量预测、动作分类、流行病预测、链接预测、电力系统预测等。时空预测十分契合动态图数据,因为我们可以利用历史数据来预测系统的未来行为。
在本节中,我们将专注于具有时间维度的图,可以分为两类:

  • 带有时间信号的静态图: 基础图并未改变,但特征和标签会随时间变化
  • 具有时间信号的动态图: 图的拓扑结构、特征和标签会随时间变化

在第一种情况下,图的拓扑结构是静态的。例如,它可以表示城市之间的交通网络,用于交通预测:特征随时间变化,但连接保持不变。
在第二种情况下,节点和连接是动态的。例如,它可以表示社交网络,其中用户之间的链接可能随时间出现或消失。这种变体更通用,但实现起来也更加困难。

接下来,我们将了解如何使用 PyTorch Geometric Temporal 处理具有时间信号的图数据。

2. 预测网络流量

在本节中,我们将使用时空图神经网络 (Temporal Graph Neural Networks, TGNN) 预测维基百科文章的流量(解决具有时间信号的静态图问题)。这一回归任务在图卷积网络一节中已有介绍,但使用的是没有时间信号的静态数据集进行流量预测,模型并没有关于之前历史时刻的实例信息,因此模型无法了解当前流量是增加还是减少。在本节中,我们可以改进这一模型,使其包含有关之前历史时刻的实例信息。
我们将首先介绍 TGNN 架构及其两个变体,然后使用 PyTorch Geometric Temporal 进行实现。

3. EvolveGCN 架构

在本节中,我们将使用 EvolveGCN 架构预测维基百科文章的流量。该架构由 Pareja 等人于 2019 年提出,是图神经网络 (Graph Neural Networks, GNN) 与循环神经网络 (Recurrent Neural Network, RNN) 的结合。之前的方法,如图卷积递归网络等,是将 RNN 与图卷积算子结合起来计算节点嵌入。而 EvolveGCN 则是将 RNN 应用于图卷积网络 (Graph Convolutional Network, GCN) 参数本身。顾名思义,GCN 会随着时间的推移而不断变化,从而产生相关的时空节点嵌入,此过程如下所示:

EvolveGCN

EvolveGCN 架构有两种变体:

  • EvolveGCN-H:递归神经网络同时考虑之前时刻的 GCN 参数和当前的节点嵌入值
  • EvolveGCN-O:递归神经网络只考虑之前时刻的 GCN 参数

3.1 EvolveGCN-H

EvolveGCN-H 通常使用门控递归单元 (Gated Recurrent Unit, GRU),而非普通 RNNGRU 是长短期记忆 (Long Short-Term Memory, LSTM) 单元的简化版本,能以较少的参数实现类似的性能。它由重置门、更新门和单元状态组成。在 EvolveGCN-H 架构中,GRU 在时间 t t t 更新 GCN l l l 层的权重矩阵:
W t ( l ) = G R U ( H t ( l ) , W t − 1 ( l ) ) W_t^{(l)}=GRU (H_t^{(l)},W_{t-1}^{(l)}) Wt(l)=GRU(Ht(l),Wt1(l))
其中, H t ( l ) H_t^{(l)} Ht(l) 表示第 l l l 层在时间 t t t 产生的节点嵌入, W t − 1 ( l ) W_{t-1}^{(l)} Wt1(l) 是第 l l l 层在上一个时间步的权重矩阵。
由此产生的 GCN 权重矩阵将用于计算下一层的节点嵌入:
H t ( l + 1 ) = G C N ( A t , H t ( l ) , W t ( l ) ) = D ~ − 1 2 A ~ T D ~ − 1 2 H t ( l ) W t ( l ) T \begin{aligned} H_t^{(l+1)}=&GCN(A_t,H_t^{(l)},W_t^{(l)}) \\ =&\widetilde D^{-\frac 1 2}\widetilde A^T\widetilde D^{-\frac 1 2}H_t^{(l)}W_t^{{(l)}^T} \end{aligned} Ht(l+1)==GCN(At,Ht(l),Wt(l))D 21A TD 21Ht(l)Wt(l)T
其中, A ~ \widetilde A A 是包括自循环的邻接矩阵, D ~ \widetilde D D 是带有自循环的度矩阵。可以用下图总结上述步骤。

EvolveGCN-H

EvolveGCN-H 可以通过接收两个扩展的 GRU 来实现:

  • 输入和隐藏状态是矩阵而非向量,以正确存储 GCN 权重矩阵。
  • 输入的列维度必须与隐藏状态的列维度一致,因此需要对节点嵌入矩阵 H t ( l ) H_t^{(l)} Ht(l) 进行汇总,只保留合适数量的列数

3.2 EvolveGCN-O

EvolveGCN-O 变体不需要 EvolveGCN-H 中所用扩展,EvolveGCN-O 是基于 LSTM 网络来模拟输入输出关系的。不需要为 LSTM 提供隐藏状态,因为 LSTM 已经包含了一个可以记忆先前值的单元。这种机制简化了更新步骤,具体步骤如下:
W t ( l ) = L S T M ( W t − 1 ( l ) ) W_t^{(l)}=LSTM(W_{t-1}^{(l)}) Wt(l)=LSTM(Wt1(l))
生成的 GCN 权重矩阵将以同样的方式用于生成下一层的节点嵌入:
H t ( l + 1 ) = G C N ( A t , H t ( l ) , W t ( l ) ) = D ~ − 1 2 A ~ T D ~ − 1 2 H t ( l ) W t ( l ) T \begin{aligned} H_t^{(l+1)}=&GCN(A_t,H_t^{(l)},W_t^{(l)}) \\ =&\widetilde D^{-\frac 1 2}\widetilde A^T\widetilde D^{-\frac 1 2}H_t^{(l)}W_t^{{(l)}^T} \end{aligned} Ht(l+1)==GCN(At,Ht(l),Wt(l))D 21A TD 21Ht(l)Wt(l)T
由于时间维度完全依赖于 LSTM 网络,因此这种实现方式更为简单。下图显示了 EvolveGCN-O 如何更新权重矩阵 W t − 1 ( l ) W_{t-1}^{(l)} Wt1(l) 并计算节点嵌入 H t ( l + 1 ) H_t^{(l+1)} Ht(l+1)

EvolveGCN-O

3.3 模型选用技巧

在实际应用中,具体采用哪种变体模型往往取决于数据:

  • 当节点特征至关重要时,EvolveGCN-H 效果更好,因为它的 RNN 明确包含了节点嵌入
  • 当图结构起重要作用时,EvolveGCN-O 效果更好,因为它更关注拓扑变化

需要注意的是,这些选择标准往往是理论性的,因此在应用时同时测试这两种变体能够得到更好模型。接下来,我们将在网络流量预测中实现这两个模型,以对比不同模型在实际应用的性能差异。

4. 构建 EvolveGCN

4.1 数据集分析

在本节中,为了对具有时间信号的静态图上执行网络流量预测,我们将使用 WikiMaths 数据集,该数据集由 1,068 篇文章组成,以节点表示。节点特征对应于过去每天的访问量(默认为八个特征)。边带有权重,权重表示从源页面到目标页面的链接数量。我们的目标是预测 2019316 日至 2021315 日期间这些维基百科页面的每日用户访问量,共计 731 个快照。每个快照都是一个图,描述了系统在某一特定时间的状态。
Gephi 制作的 WikiMaths 表示如下所示,其中节点的大小和颜色与它们的连接数成正比。

WikiMaths 数据集

PyTorch Geometric 本身并不支持带有时间信号的静态或动态图,但可以通过使用名为 PyTorch Geometric Temporal 的扩展解决此问题,同时也实现了各种时空图神经网络 (Temporal Graph Neural Networks, TGNN) 层。在本节中,我们将使用 PyTorch Geometric Temporal 库来简化代码实现并专注于模型应用。

(1)shell 中使用 pip 安装 PyTorch Geometric Temporal 库:

pip install torch-geometric-temporal

(2) 导入 WikiMaths 数据集加载器 WikiMathDatasetLoader、使用 temporal_signal_split 进行时间感知的训练-测试数据集分割,以及 GNNEvolveGCNH

from torch_geometric_temporal.signal import temporal_signal_split
from torch_geometric_temporal.dataset import WikiMathsDatasetLoader
from torch_geometric_temporal.nn.recurrent import EvolveGCNHimport pandas as pd
import matplotlib.pyplot as plt
from tqdm import tqdm

(3) 加载 WikiMaths 数据集,它是一个 StaticGraphTemporalSignal 对象。在该对象中,dataset[0] 描述的是 t = 0 时的图(在此上下文种也称快照),dataset[500] 描述的是 t = 500 时的图。创建一个训练-测试分割集,比例为 0.5,训练集由较早时间段的快照组成,而测试集则包含较晚时间段的快照:

dataset = WikiMathsDatasetLoader().get_dataset()
train_dataset, test_dataset = temporal_signal_split(dataset, train_ratio=0.5)
print(dataset[0])
print(dataset[500])# Data(x=[1068, 8], edge_index=[2, 27079], edge_attr=[27079], y=[1068])
# Data(x=[1068, 8], edge_index=[2, 27079], edge_attr=[27079], y=[1068])

(4) 图是静态的,因此节点和边的维度不会发生变化。但是,这些张量中包含的值却不同。很难可视化 1068 个节点中每个节点的值。为了更好地理解这个数据集,我们可以计算每个快照的平均值和标准差,移动平均值也有助于平滑短期波动:

mean_cases = [snapshot.y.mean().item() for snapshot in dataset]
std_cases = [snapshot.y.std().item() for snapshot in dataset]
df = pd.DataFrame(mean_cases, columns=['mean'])
df['std'] = pd.DataFrame(std_cases, columns=['std'])
df['rolling'] = df['mean'].rolling(7).mean()

matplotlib 绘制这些时间序列,以直观地进行展示:

plt.figure(figsize=(10,5))
plt.plot(df['mean'], 'k-', label='Mean')
plt.plot(df['rolling'], 'g-', label='Moving average')
plt.grid(linestyle=':')
plt.fill_between(df.index, df['mean']-df['std'], df['mean']+df['std'], color='r', alpha=0.1)
plt.axvline(x=360, color='b', linestyle='--')
plt.text(360, 1.5, 'Train/test split', rotation=-90, color='b')
plt.xlabel('Time (days)')
plt.ylabel('Normalized number of visits')
plt.legend(loc='upper right')
plt.show()

时间序列

可以看到,数据呈现出周期性模式,我们希望 TGNN 能够学习这些模式。接下来,我们实现该架构,并观察其性能表现。

4.2 模型构建与训练

(1) 时空图神经网络 (Temporal Graph Neural Networks, TGNN) 需要两个参数作为输入:节点数 (node_count) 和输入维度 (dim_in)。TGNN 只有两层:EvolveGCN-H 层和线性层,线性层为每个节点输出预测值:

import torchclass TemporalGNN(torch.nn.Module):def __init__(self, node_count, dim_in):super().__init__()self.recurrent = EvolveGCNH(node_count, dim_in)self.linear = torch.nn.Linear(dim_in, 1)

(2) forward() 方法中对输入应用 EvolveGCN-H 层和线性层,并使用 ReLU 激活函数:

    def forward(self, x, edge_index, edge_weight):h = self.recurrent(x, edge_index, edge_weight).relu()h = self.linear(h)return h

(3) 创建 TemporalGNN 实例,并使用 WikiMaths 数据集的节点数和输入维度作为参数。使用 Adam 优化器对其进行训练:

model = TemporalGNN(dataset[0].x.shape[0], dataset[0].x.shape[1])
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
model.train()

(4) 打印模型信息,观察 EvolveGCNH 所包含的网络层:

print(model)

网络架构

可以看到,EvolveGCNH 包含三个层,TopKPooling 将输入矩阵汇总为八列;门控递归单元 (Gated Recurrent Unit, GRU) 更新图卷积网络 (Graph Convolutional Network, GCN) 权重矩阵;GCNConv 生成新的节点嵌入。最后,使用线性层输出图中每个节点的预测值。

(5) 创建训练循环,在训练集的每个快照上训练模型。通过反向传播计算每个快照的损失:

for epoch in tqdm(range(50)):for i, snapshot in enumerate(train_dataset):y_pred = model(snapshot.x, snapshot.edge_index, snapshot.edge_attr)loss = torch.mean((y_pred-snapshot.y)**2)loss.backward()optimizer.step()optimizer.zero_grad()

(6) 在测试集上对训练后的模型进行评估。对整个测试集的均方误差 (Mean squared error, MSE) 取平均值,得出最终得分:

# Evaluation
model.eval()
loss = 0
for i, snapshot in enumerate(test_dataset):y_pred = model(snapshot.x, snapshot.edge_index, snapshot.edge_attr)mse = torch.mean((y_pred-snapshot.y)**2)loss += mse
loss = loss / (i+1)
print(f'MSE = {loss.item():.4f}')# MSE = 0.7819

(7) 可以看到,MSE 损失值为 0.7819。接下来,绘制模型在先前图上预测的平均值,以对其进行解释。对预测值进行平均,并将它们存储在一个列表中,然后,将它们添加到先前的图中:

y_preds = [model(snapshot.x, snapshot.edge_index, snapshot.edge_attr).squeeze().detach().numpy().mean() for snapshot in test_dataset]plt.figure(figsize=(10,5), dpi=300)
plt.plot(df['mean'], 'k-', label='Mean')
plt.plot(df['rolling'], 'g-', label='Moving average')
plt.plot(range(360,722), y_preds, 'r-', label='Prediction')
plt.grid(linestyle=':')
plt.fill_between(df.index, df['mean']-df['std'], df['mean']+df['std'], color='r', alpha=0.1)
plt.axvline(x=360, color='b', linestyle='--')
plt.text(360, 1.5, 'Train/test split', rotation=-90, color='b')
plt.xlabel('Time (days)')
plt.ylabel('Normalized number of visits')
plt.legend(loc='upper right')
plt.show()

预测结果

可以看到,预测值与数据的总体趋势一致。考虑到数据集的规模有限,可以称得上是一个非常好的结果。

(8) 最后,绘制散点图,以显示单个快照的预测值和真实值之间的差异:

import seaborn as snsy_pred = model(test_dataset[0].x, test_dataset[0].edge_index, test_dataset[0].edge_attr).detach().squeeze().numpy()plt.figure(figsize=(10,5), dpi=300)
sns.regplot(x=test_dataset[0].y.numpy(), y=y_pred)
plt.show()

真实值与预测值差异

可以看到,预测值和真实值之间存在适度的正相关。本节训练的模型并不是非常准确,但从上图可以看出,它能很好地理解数据的周期性。
EvolveGCN-O 变体的实现过程与 EvolveGCN-H 相似,只需要使用 PyTorch Geometric Temporal 中的 EvolveGCNO 层替代 EvolveGCNH 层。EvolveGCNO 层不需要节点数,所以只需要给它输入维度。其实现方法如下:

from torch_geometric_temporal.nn.recurrent import EvolveGCNOclass TemporalGNN(torch.nn.Module):def __init__(self, dim_in):super().__init__()self.recurrent = EvolveGCNO(dim_in, 1)self.linear = torch.nn.Linear(dim_in, 1)def forward(self, x, edge_index, edge_weight):h = self.recurrent(x, edge_index, edge_weight).relu()h = self.linear(h)return hmodel = TemporalGNN(dataset[0].x.shape[1])

总体而言,EvolveGCN-O 模型能够取得类似的结果,平均 MSE0.7524。在这种情况下,使用 GRULSTM 网络并不会影响预测结果,因为过去的访问数量(包含在节点特征中的 EvolveGCN-H )和页面之间的连接 (EvolveGCN-O) 都是重要的因素。因此,这种 GNN 架构特别适合这种流量预测任务。

小结

本节介绍了具有时空信息的图数据。这种时空成分在许多应用中都很有帮助,主要与时间序列预测有关。我们介绍了两种符合这种定义的图:静态图(特征随时间变化)和动态图(特征和拓扑结构会发生变化)。PyTorch Geometric TemporalPyTorch Geometric 的扩展,专门用于处理时空图神经网络。此外,我们实现了 EvolveGCN 架构,该架构使用 GRULSTM 网络更新 GCN 参数。应用此架构执行 Web 流量预测,并且在有限的数据集上取得了出色的结果。

系列链接

图神经网络实战(1)——图神经网络(Graph Neural Networks, GNN)基础
图神经网络实战(2)——图论基础
图神经网络实战(3)——基于DeepWalk创建节点表示
图神经网络实战(4)——基于Node2Vec改进嵌入质量
图神经网络实战(5)——常用图数据集
图神经网络实战(6)——使用PyTorch构建图神经网络
图神经网络实战(7)——图卷积网络(Graph Convolutional Network, GCN)详解与实现
图神经网络实战(8)——图注意力网络(Graph Attention Networks, GAT)
图神经网络实战(9)——GraphSAGE详解与实现
图神经网络实战(10)——归纳学习
图神经网络实战(11)——Weisfeiler-Leman测试
图神经网络实战(12)——图同构网络(Graph Isomorphism Network, GIN)
图神经网络实战(13)——经典链接预测算法
图神经网络实战(14)——基于节点嵌入预测链接
图神经网络实战(15)——SEAL链接预测算法
图神经网络实战(16)——经典图生成算法
图神经网络实战(17)——深度图生成模型
图神经网络实战(18)——消息传播神经网络
图神经网络实战(19)——异构图神经网络


http://www.mrgr.cn/news/70201.html

相关文章:

  • 数据库SQLite的使用
  • 「QT」顺序容器类 之 QVector 动态数组类
  • 产品经理如何提升项目管理能力
  • Apache服务安装
  • Python用CEEMDAN-LSTM-VMD金融股价数据预测及SVR、AR、HAR对比可视化
  • Android GPU纹理数据拷贝
  • ORB-SLAM2源码学习:Frame.cc: Frame::isInFrustum 判断地图点是否在当前帧的视野范围内
  • 【昱合昇天窗】电动采光排烟天窗功率
  • Python 列表:数据处理的强大工具
  • NX/UG 二次开发 获取注释信息
  • Redis 入门
  • go聊天系统项目-2 redis 验证用户id和密码
  • 0-1000 的数字里,恰好只有一个5的数的个数
  • 【AI技术】DH_Live部署方案
  • 量化交易系统开发-实时行情自动化交易-2.技术栈
  • 适合初学者和专家程序员的 AI 编码工具
  • 贯穿式学习MySQL
  • 歌曲去人声的轻松技巧,只需两步就能获取纯伴奏
  • 优化时钟网络之时钟偏移
  • [CKS] Audit Log Policy
  • 快速了解SpringBoot 统一功能处理
  • 集运行业破内卷:以差异化服务打造准时达品牌,重塑良性竞争生态
  • 双 11 数据可视化:Pyecharts 与 Matplotlib 绘制商品价格对比及动态饼图
  • 华大单片机跑历程IO口被写保护怎么解决
  • golang分布式缓存项目 Day3 HTTP服务端
  • 如何让 AI 更懂你:提示词的秘密