StemGNN(Spatio-Temporal Evolutionary Map Graph Neural Network) 是一种针对时空数据预测设计的模型,融合图神经网络(GNN)与时间序列分析,广泛用于交通流量预测、气象预报、股票市场分析等场景。
一 图卷积神经网络的学习(GCN)
StemGNN是在GCN的基础上改进而成的,所以先对GCN进行学习。
1.1图卷积的特别之处
传统卷积具有局限性,只能处理欧几里得数据,像图像,文本,而无法处理非欧几里得数据。并且图数据能自然表达关系与交互,适合应用于节点分类,连接预测和图分类。
1.1.1什么叫欧几里得数据:
存在于欧几里得空间中的数据,数据具有明确的几何结构和距离度量方式,通常可以用坐标表示。
欧几里得数据的特性:
坐标表示:数据点可以用向量形式表示
可度量的距离:点与点之间的距离可通过欧氏距离公式计算。
线性运算:支持向量加减、内积(点积)等线性代数操作。
规则的结构:数据具有网格状或连续的几何结构(如图像、时间序列、传统表格数据)。
1.1.2什么叫非欧几里得数据:
非欧几里得数据通常存在于不规则或复杂结构中,无法直接用欧氏距离衡量相似性。
eg:图数据(Graph Data):如社交网络、分子结构,节点间通过边连接,关系依赖拓扑结构。
流形数据(Manifold Data):如球面或环面,局部类似欧氏空间,但全局结构复杂。
序列数据(如文本、语音):具有时间或顺序依赖关系,需考虑上下文。
特性 | GCN | 传统CNN |
---|---|---|
数据形式 | 图结构(非欧几里得) | 网格数据(图像等) |
卷积核 | 基于图结构的动态邻域聚合 | 固定大小的滑动窗口 |
应用场景 | 社交网络、分子结构等 | 图像、视频、文本 |
局部感受野 | 一阶或多阶邻居 | 固定邻域(如3x3) |
1.2图卷积的核心原理
1.2.1卷积的泛化:从信号处理到图结构,使用图傅里叶变换
传统信号处理中的傅里叶变换将信号从时域转换到频域,卷积定理表明时域卷积等价于频域乘法。
GCN将此思想推广到图上:
(1)图傅里叶变换:基于图拉普拉斯矩阵 的特征分解。
这里要插一下,为啥要进行傅里叶变换呢?
在信号处理和数学中,对卷积进行傅里叶变换的主要目的是利用卷积定理,将复杂的时域卷积运算转化为频域的简单乘积运算,从而大幅降低计算复杂度。
(2)谱域图卷积:定义频域的滤波器,对信号(节点特征)进行滤波。
1.2.1拉普拉斯矩阵 (Laplacian Matrix)
什么是拉普拉斯矩阵?
拉普拉斯矩阵的核心作用是将复杂的图结构转化为线性代数可处理的形式。
拉普拉斯矩阵通常有两种形式:组合拉普拉斯矩阵和归一化拉普拉斯矩阵。
(1) 组合拉普拉斯矩阵(Combinatorial Laplacian)
对于一个无向图 ,其拉普拉斯矩阵
定义为:
D 是度矩阵(对角矩阵,对角线元素为每个节点的度数)(度数(Degree) 是指一个节点(顶点)直接连接的边的数量)
A 是邻接矩阵(元素 Aij=1 表示节点 i 和 j 相连,否则为 0)。
(2) 归一化拉普拉斯矩阵(Normalized Laplacian)
对称归一化:
随机游走归一化:
拉普拉斯矩阵性质:
(1)半正定性:所有特征值大于零。
(2)特征值与图的连通性:零特征值的重数等于图中连通分量的数量。若图是连通的,则只有一个零特征值。
(3)二次型与图的总变差
对于任意向量 ,有:
该式衡量了信号 在图上的平滑程度(相邻节点差异越小,值越小)。
eg:
节点集合 V={1,2,3};边集合 E={(1,2),(2,3)}
邻接矩阵 A为:
因为边集合给出,ij = 1,2;ij = 2,1; ij= 2,3; ij= 3,2都相连所以是1
度矩阵 D 为:
因为2连接了1,3所以第二个节点度数为2,第一三节点度数为1。
组合拉普拉斯矩阵 L 为:
求出特征值为 0,1,3,表明图是连通的(仅一个零特征值)
1.3 GCN组成
主要组件:
输入层:节点特征矩阵 (N为节点数,d为特征维度)。
邻接矩阵 A:描述节点连接关系。
图卷积层:逐层传播,聚合邻域信息。
输出层:根据任务目标设计(如节点分类用Softmax)。
1.4 GCN的变体
模型 | 核心改进 | 适用场景 |
---|---|---|
GraphSAGE | 采样邻居+聚合函数(Mean/LSTM/Pooling) | 大规模图(避免全图计算) |
GAT | 注意力机制加权邻域聚合 | 动态或异质关系(如用户兴趣动态变化) |
GIN | 基于Weisfeiler-Lehman同构测试的理论增强 | 图分类任务(如分子分类) |
APPNP | 结合PageRank与传播机制,缓解过平滑 | 需要深层网络的复杂图结构 |
1.5GCN的优缺点
优点:
图结构建模能力:显式编码节点间依赖,捕捉复杂关系。
参数共享:所有节点共享同一权重矩阵 W,扩展性强。
端到端学习:无需手工设计图特征。
缺点:
过平滑(Over-smoothing):深层GCN会导致不同类节点特征趋同。
邻域范围受限:多层后可能引入噪声(实际应用中通常不超过3层)。
计算复杂度:存储密集的邻接矩阵对大规模图不友好(需用稀疏矩阵优化)。
import torch
from torch_geometric.datasets import Planetoid
from torch_geometric.nn import GCNConv
# 加载数据集
dataset = Planetoid(root='./data/Cora', name='Cora')
data = dataset[0]
# 定义GCN模型
class GCN(torch.nn.Module):
def __init__(self, in_channels, hidden_channels, out_channels):
super().__init__()
self.conv1 = GCNConv(in_channels, hidden_channels)
self.conv2 = GCNConv(hidden_channels, out_channels)
def forward(self, x, edge_index):
x = self.conv1(x, edge_index).relu()
x = torch.nn.functional.dropout(x, training=self.training)
x = self.conv2(x, edge_index)
return torch.nn.functional.log_softmax(x, dim=1)
# 初始化模型与优化器
model = GCN(dataset.num_features, 16, dataset.num_classes)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)
# 训练过程
model.train()
for epoch in range(200):
optimizer.zero_grad()
out = model(data.x, data.edge_index)
loss = torch.nn.functional.nll_loss(out[data.train_mask], data.y[data.train_mask])
loss.backward()
optimizer.step()
# 测试准确率
model.eval()
pred = model(data.x, data.edge_index).argmax(dim=1)
correct = (pred[data.test_mask] == data.y[data.test_mask]).sum()
accuracy = correct / data.test_mask.sum()
print(f'Test Accuracy: {accuracy:.4f}')
二 时空图神经网络学习
2.1StemGNN的目标
通过联合建模时空演化规律,实现对复杂动态系统的精准预测。核心创新点为:
(1)动态图结构学习:根据数据自动生成时序变化的图邻接矩阵。、
(2)频域-时域联合分解:结合傅里叶变换与小波分析提取多尺度特征。
2.2StemGNN框架
StemGNN的架构分为三部分:
(1)时空特征分解层:将原始数据分解为多个频带的时空分量。
输入:时空数据矩阵(T时间步,N节点,D特征)。
通过快速傅里叶变换(FFT)或移动平均(MA)提取多尺度分量。
例如,将时间序列分解为趋势(Trend)、周期(Seasonal)和残差(Residual):
(2)动态图学习模块:学习不同频带对应的动态空间关系(Graph Learner)。
为每个特征分量生成独立的邻接矩阵 。
利用神经网络学习节点间的动态相关性。
例如,通过学习函数 生成邻接矩阵:
不同频带(如高低频)可能对应不同的空间交互模式。
(3)时空卷积层:对每个频带分别进行时空卷积并融合最终结果。
空间卷积:对每个分量的图结构应用GNN。
(GNN是图神经网络,只是说GCN图卷积网络是GNN中最典型的之一)
时间卷积:使用扩张因果卷积(TCN)或Transformer捕捉时间依赖。
融合输出:将各分量的预测结果合并,得到最终输出。
其中, 为时空卷积,
为分解,
为分解后的频带数量。
特性 | 传统模型 | StemGNN |
---|---|---|
时空建模 | 分离处理(如GCN+LSTM) | 联合建模,自适应交互 |
图结构 | 静态或固定邻接矩阵 | 动态生成,分频带学习 |
特征分解 | 单一输入直接处理 | 频域分解,捕捉多尺度模式 |
长期依赖 | RNN梯度消失/爆炸 | 扩张卷积或Transformer缓解长程问题 |
三 应用场景
3.1 交通流量预测
输入:路网传感器的时间流量数据,节点为监测点。
预测目标:未来1小时各站点的流量。
动态图示例:工作日上午的城区间相关性更强。
3.2 股市波动分析
输入:多只股票的历史价格变动。
动态图:公司间的关联性随市场新闻动态变化。
输出:未来股价趋势或风险评估。
3.3气象预测
输入:多个气象站的气压、温度、风速数据。
动态图:台风路径导致的区域影响变化。
四 代码实现
4.1动态图学习
import torch
import torch.nn as nn
class GraphLearner(nn.Module):
def __init__(self, input_dim, hidden_dim, node_num):
super().__init__()
self.node_num = node_num
self.fc1 = nn.Linear(input_dim, hidden_dim)
self.fc2 = nn.Linear(hidden_dim, node_num * node_num)
def forward(self, x):
# x: [batch, T, N, D], 取最后一个时间步特征
x_last = x[:, -1, :, :] # [batch, N, D]
batch_size = x_last.size(0)
h = torch.relu(self.fc1(x_last)) # [batch, N, hidden]
adj = self.fc2(h) # [batch, N, N*N]
adj = adj.view(batch_size, self.node_num, self.node_num)
adj = torch.softmax(adj, dim=-1) # 归一化为概率邻接矩阵
return adj
4.2空间卷积和时间卷积
class SpatioTemporalBlock(nn.Module):
def __init__(self, node_num, input_dim, gcn_hidden, tcn_hidden):
super().__init__()
# 空间卷积(GAT示例)
self.gat = GATConv(input_dim, gcn_hidden, heads=1)
# 时间卷积(TCN示例)
self.tcn = nn.Conv1d(gcn_hidden, tcn_hidden, kernel_size=3, dilation=2, padding=2)
def forward(self, x, adj):
# x: [batch, T, N, D]
batch, T, N, D = x.shape
x = x.permute(0, 2, 1, 3) # [batch, N, T, D]
h_space = []
for t in range(T):
# 对每个时间步应用GAT
h_t = self.gat(x[:, :, t, :], adj) # [batch, N, gcn_hidden]
h_space.append(h_t)
h_space = torch.stack(h_space, dim=2) # [batch, N, T, gcn_hidden]
# 时间卷积
h_time = h_space.permute(0, 3, 2, 1) # [batch, gcn_hidden, T, N]
h_time = h_time.reshape(batch, -1, T) # [batch, gcn_hidden*N, T]
h_time = self.tcn(h_time) # [batch, tcn_hidden, T]
return h_time.permute(0, 2, 1)
五 问题和解决办法
计算复杂度:动态图生成和多分量处理增加计算成本。
长时序训练:分解层可能丢失部分高频细节。
方法 | 描述 |
---|---|
轻量图学习 | 简化邻接矩阵生成网络(如低秩分解)。 |
自适应分解 | 使用小波变换替代固定频带分解。 |
分层融合 | 跨分量的特征交互(类似Inception结构)。 |
StemGNN通过动态图生成与时空频带分解,显著提升了复杂时空数据的建模能力。其核心价值在于:
(1)自适应的空间关系捕获,避免静态图的假设偏差。
(2)联合优化时间与空间特征,挖掘深层次的耦合规律。