目录
通用版(!!!基于内置)-minibatch大数据集OGB训练
零、简介
本人前一段学习的GCN看的差不多了,现在准备看GAT,于是开始搜集相关资料学习。其中本论文是ICLR2018 论文,Graph Attention Network在GNN中非常重要,再之前图卷积网络GCN的基础之上引入了注意力机制,非常实用。
一、GAT讲解
学习资源
- 论文地址:https://arxiv.org/pdf/1710.10903.pdf
- 源码实现:https://github.com/Diego999/pyGAT
- pyg实现:https://pytorch-geometric.readthedocs.io/en/latest/modules/nn.html#torch_geometric.nn.conv.GATConv
- https://pytorch-geometric.readthedocs.io/en/latest/_modules/torch_geometric/nn/conv/gat_conv.html#GATConv.forward-卷积实现(见手册)
- https://github.com/rusty1s/pytorch_geometric/blob/master/torch_geometric/nn/conv/gat_conv.py-同上(见git)
- https://github.com/rusty1s/pytorch_geometric/blob/master/examples/gat.py-模型实现(git)
动机
- GCN 假设图是无向的,因为利用了对称的拉普拉斯矩阵 (只有邻接矩阵 A 是对称的,拉普拉斯矩阵才可以正交分解),不能直接用于有向图。GCN 的作者为了处理有向图,需要对 Graph 结构进行调整,要把有向边划分成两个节点放入 Graph 中。例如 e1、e2 为两个节点,r 为 e1,e2 的有向关系,则需要把 r 划分为两个关系节点 r1 和 r2 放入图中。连接 (e1, r1)、(e2, r2)。
- GCN 不能处理动态图,GCN 在训练时依赖于具体的图结构,测试的时候也要在相同的图上进行。因此只能处理 transductive 任务,不能处理 inductive 任务。transductive 指训练和测试的时候基于相同的图结构,例如在一个社交网络上,知道一部分人的类别,预测另一部分人的类别。inductive 指训练和测试使用不同的图结构,例如在一个社交网络上训练,在另一个社交网络上预测。
- GCN 不能为每个邻居分配不同的权重,GCN 在卷积时对所有邻居节点均一视同仁,不能根据节点重要性分配不同的权重。
2018 年图注意力网络 GAT 被提出,用于解决 GCN 的上述问题,论文是《GRAPH ATTENTION NETWORKS》。GAT 采用了 Attention 机制,可以为不同节点分配不同权重,训练时依赖于成对的相邻节点,而不依赖具体的网络结构,可以用于 inductive 任务。
创新
图数据结构的两种“特征”
GAT的两种运算方式(masked只针对一阶邻居)
还有一件事件需要提前说清楚:GAT本质上可以有两种运算方式的,这也是原文中作者提到的
公式
论文原文
分析
具体参考
和所有的attention mechanism一样,GAT的计算也分为两步走:
1、 计算注意力系数(attention coefficient)
2、 加权求和(aggregate)
3、多头注意力拼接
与GCN的联系与区别
无独有偶,我们可以发现本质上而言:GCN与GAT都是将邻居顶点的特征聚合到中心顶点上(一种aggregate运算),利用graph上的local stationary学习新的顶点特征表达。不同的是GCN利用了拉普拉斯矩阵,GAT利用attention系数。一定程度上而言,GAT会更强,因为 顶点特征之间的相关性被更好地融入到模型中。
为什么GAT适用于有向图?
我认为最根本的原因是GAT的运算方式是逐顶点的运算(node-wise),这一点可从公式(1)—公式(3)中很明显地看出。每一次运算都需要循环遍历图上的所有顶点来完成。逐顶点运算意味着,摆脱了拉普利矩阵的束缚,使得有向图问题迎刃而解。
为什么GAT适用于inductive任务?
GAT中重要的学习参数是 W 与 a(.) ,因为上述的逐顶点运算方式,这两个参数仅与1.1节阐述的顶点特征相关,与图的结构毫无关系。所以测试任务中改变图的结构,对于GAT影响并不大,只需要改变 ,重新计算即可。
与此相反的是,GCN是一种全图的计算方式,一次计算就更新全图的节点特征。学习的参数很大程度与图结构相关,这使得GCN在inductive任务上遇到困境。
源码实现核心理解
模型
class GAT(nn.Module):
def __init__(self, nfeat, nhid, nclass, dropout, alpha, nheads):
"""Dense version of GAT."""
super(GAT, self).__init__()
self.dropout = dropout
self.attentions = [GraphAttentionLayer(nfeat, nhid, dropout=dropout, alpha=alpha, concat=True) for _ in range(nheads)]
for i, attention in enumerate(self.attentions):
self.add_module('attention_{}'.format(i), attention)
#nhid*nheads=8*8 nclass=7 ->GraphAttentionLayer.W:(64,7)
self.out_att = GraphAttentionLayer(nhid * nheads, nclass, dropout=dropout, alpha=alpha, concat=False)
def forward(self, x, adj):#cora数据集:x:(2708,1433)节点属性,2708个节点,每个节点1433个属性 adj:(2708,2708)节点构造的邻接矩阵
x = F.dropout(x, self.dropout, training=self.training)# 需要将模型整体的training状态参数传入dropout函数
#8个head,每个生成(2708,8),然后执行cat操作后得到x:(2708,64)
#每个att(x, adj)都是公式1-4的执行过程 torch.cat就是公式5,也是公式6的括号里面部分
x = torch.cat([att(x, adj) for att in self.attentions], dim=1)
x = F.dropout(x, self.dropout, training=self.training)#x:(2708,64)
x = F.elu(self.out_att(x, adj))#执行前x:(2708,64),执行完后x:(2708,7) 公式6
return F.log_softmax(x, dim=1)
GAL层
class GraphAttentionLayer(nn.Module):
"""
Simple GAT layer, similar to https://arxiv.org/abs/1710.10903
"""
。。。
def forward(self, input, adj):
h = torch.mm(input, self.W)#公式1:WH部分
#W:(1433,8) input(2708,1433)->h:(2708,8) 直接通过torch.mm矩阵乘完成全连接运算
#最后还有一次self-attention,由self.out_att(x, adj)传入得到的self.W:(64,8),传入的input为(2707,64)
N = h.size()[0]#N:2708
'''
repeat(1,N)不指定axis时逐个元素复制repeats次,形成一个行向量,这里就是第一维度不变,后面第二个维度乘以N
repeat(N,1)不指定axis时逐个元素复制repeats次,形成一个行向量,这里就是第二维度不变,前面第一个维度乘以N
view(-1,h,c),view前面的维度(这里指的是所有元素的个数)除以h再除以c得到的数值自动补到-1的位置,后面维度不变;这里N*N就是2708,8*2708,第二个维度就是前面(2708*8*2708/2708,8*2708)等于8
cat是第二个维度(dim=1)进行拼接,也就是8+8=16
a就是拼接的维度,即(16,1)
squeeze(2)就是第三维度为1的话降维,结果就是剩下前两维
已知h:(2708,8),N=2708
h.repeat(1, N):(2708,8*2708).view->(2708*2708,8)
h.repeat(N, 1)->(2708*2708,8)
a_input=torch.cat([h.repeat(1, N).view(N * N, -1), h.repeat(N, 1)], dim=1)->(2708*2708,16).view->2708,2708,2*8=(2708,2708,16)
self.a:(16,1)偏移量b
e=torch.matmul(a_input, self.a):(2708,2708,16)*(16,1)->(2708,2708,1).squeeze->e:2708,2708
#具体的合并操作讲解如下
a_input为公式3中的a[WHi||WHj],cat操作就是合并,图1左部分WHi和WHj合并,是每个节点属性分别与所有节点
(2708个节点包括自己)属性的合并操作:
结果如下:对于节点1,分别与节点1属性,节点2属性,。。。一直到节点2708属性合并:
[节点1: [[节点1的8个属性,节点1的8个属性]
[节点1的8个属性,节点2的8个属性]
。。。。。。
[节点1的8个属性,节点2708的8个属性]] 共2708个
节点2:
[[节点2的8个属性,节点1的8个属性]
[节点2的8个属性,节点2的8个属性]
。。。。。。
[节点2的8个属性,节点2708的8个属性]] 共2708个
。。。
节点2708:
[[节点2708的8个属性,节点1的8个属性]
[节点2708的8个属性,节点2的8个属性]
。。。。。。
[节点2708的8个属性,节点2708的8个属性]] 共2708个
]所以a_input是(2708,2708,16)
这表明每个节点的8个属性(是经过mm运算将1403个属性集成为8个)与所有节点都进行了聚集。
#注意力系数这里也做了改动,只分配给一阶邻居
self-attention会将注意力系数分配到图中所有的节点上,这种做法显然会丢失结构信息。
为了解决这一问题,本文使用了一种masked attention的方式——仅将注意力分配到节点的邻节点集上。
也就是后面的根据邻接关系更新这个e: attention = torch.where(adj > 0, e, zero_vec)
'''
a_input = torch.cat([h.repeat(1, N).view(N * N, -1), h.repeat(N, 1)], dim=1).view(N, -1, 2 * self.out_features)#公式3:两个WH的拼接部分
e = self.leakyrelu(torch.matmul(a_input, self.a).squeeze(2))#公式3里面的leakyrelu部分(将前面两个WH拼接的结果接上一个a神经网络后再relu) 即公式1的e
#造了一个e:(2708,2708)维度,元素为全0的张量
zero_vec = -9e15*torch.ones_like(e)
#根据邻接矩阵adj,找出e中大于0的元素组成张量
'''
torch.where(condition,x,y)
out = x,if condition is 1
= y ,if condition is 0
'''
#也就是根据邻接矩阵adj来组建attention(2708,2708),若对应adj位置大于0,取e中对应位置元素,若小于0取为zero_vec中对应位置元素
attention = torch.where(adj > 0, e, zero_vec) #公式3注意力分配判断,只是将注意力分配到邻居节点,而不是所有节点
#真正的公式3,单层全连接处理后只考虑邻接矩阵后的激活函数,就相当于一个考虑了邻居的权重系数
attention = F.softmax(attention, dim=1) #公式3最后做的softmax
attention = F.dropout(attention, self.dropout, training=self.training)
h_prime = torch.matmul(attention, h)#公式4,h_prime(2708,8)
#将这个权重系数attention与Wh相乘,即在原先节点属性运算时考虑了邻接节点的权重关系。
#相当于公式4的括号里面部分,只考虑了一阶邻接节点
if self.concat:
return F.elu(h_prime)#公式4
else:
return h_prime
二、GAT的PyG内置的模型实现
可以参考
- https://pytorch-geometric.readthedocs.io/en/latest/modules/nn.html#torch_geometric.nn.conv.GATConv-文档分析
- https://pytorch-geometric.readthedocs.io/en/latest/_modules/torch_geometric/nn/conv/gat_conv.html#GATConv.forward-实现
- https://github.com/rusty1s/pytorch_geometric/blob/master/tor