PyG-GAT+Mini理解与实现

本文详细介绍了图注意力网络(GAT)的原理,包括其与GCN的区别,创新点,以及在有向图和inductive任务中的优势。通过分析论文和PyTorch Geometric(PyG)的实现,探讨了GAT的运算方式和源码核心理解。此外,还展示了如何在PyG中实现GAT,包括内置模型的使用和自定义GAT层的多种实现方式,适合于不同规模的数据集训练。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

目录

零、简介

一、GAT讲解

学习资源

动机

创新

图数据结构的两种“特征”

GAT的两种运算方式(masked只针对一阶邻居)

公式

论文原文

分析 

与GCN的联系与区别

为什么GAT适用于有向图?

为什么GAT适用于inductive任务?

源码实现核心理解

模型

GAL层

二、GAT的PyG内置的模型实现

 pyg对GATConv的实现

分析

完整代码

最简单2层模型的实现

三、自定义GAT层实现(基于PyG)

单层注意力

多层注意力

通用版(自定义)-fullbatch

通用版(!!!基于内置)-minibatch小型数据集训练

通用版(!!!基于内置)-minibatch大数据集OGB训练



零、简介

本人前一段学习的GCN看的差不多了,现在准备看GAT,于是开始搜集相关资料学习。其中本论文是ICLR2018 论文,Graph Attention Network在GNN中非常重要,再之前图卷积网络GCN的基础之上引入了注意力机制,非常实用。

一、GAT讲解

学习资源

动机

  • GCN 假设图是无向的,因为利用了对称的拉普拉斯矩阵 (只有邻接矩阵 A 是对称的,拉普拉斯矩阵才可以正交分解),不能直接用于有向图。GCN 的作者为了处理有向图,需要对 Graph 结构进行调整,要把有向边划分成两个节点放入 Graph 中。例如 e1、e2 为两个节点,r 为 e1,e2 的有向关系,则需要把 r 划分为两个关系节点 r1 和 r2 放入图中。连接 (e1, r1)、(e2, r2)。
  • GCN 不能处理动态图,GCN 在训练时依赖于具体的图结构,测试的时候也要在相同的图上进行。因此只能处理 transductive 任务,不能处理 inductive 任务。transductive 指训练和测试的时候基于相同的图结构,例如在一个社交网络上,知道一部分人的类别,预测另一部分人的类别。inductive 指训练和测试使用不同的图结构,例如在一个社交网络上训练,在另一个社交网络上预测。
  • GCN 不能为每个邻居分配不同的权重,GCN 在卷积时对所有邻居节点均一视同仁,不能根据节点重要性分配不同的权重。

2018 年图注意力网络 GAT 被提出,用于解决 GCN 的上述问题,论文是《GRAPH ATTENTION NETWORKS》。GAT 采用了 Attention 机制,可以为不同节点分配不同权重,训练时依赖于成对的相邻节点,而不依赖具体的网络结构,可以用于 inductive 任务。

创新

图数据结构的两种“特征”

GAT的两种运算方式(masked只针对一阶邻居)

还有一件事件需要提前说清楚:GAT本质上可以有两种运算方式的,这也是原文中作者提到的

公式

论文原文

分析 

具体参考

和所有的attention mechanism一样,GAT的计算也分为两步走:

1、 计算注意力系数(attention coefficient)

2、 加权求和(aggregate)

3、多头注意力拼接

源码:https://pytorch-geometric.readthedocs.io/en/latest/_modules/torch_geometric/nn/conv/gat_conv.html#GATConv.forward

与GCN的联系与区别

无独有偶,我们可以发现本质上而言:GCN与GAT都是将邻居顶点的特征聚合到中心顶点上(一种aggregate运算),利用graph上的local stationary学习新的顶点特征表达。不同的是GCN利用了拉普拉斯矩阵,GAT利用attention系数。一定程度上而言,GAT会更强,因为 顶点特征之间的相关性被更好地融入到模型中。

为什么GAT适用于有向图?

我认为最根本的原因是GAT的运算方式是逐顶点的运算(node-wise),这一点可从公式(1)—公式(3)中很明显地看出。每一次运算都需要循环遍历图上的所有顶点来完成。逐顶点运算意味着,摆脱了拉普利矩阵的束缚,使得有向图问题迎刃而解。

为什么GAT适用于inductive任务?

GAT中重要的学习参数是 W 与 a(.)  ,因为上述的逐顶点运算方式,这两个参数仅与1.1节阐述的顶点特征相关,与图的结构毫无关系。所以测试任务中改变图的结构,对于GAT影响并不大,只需要改变  ,重新计算即可。

与此相反的是,GCN是一种全图的计算方式,一次计算就更新全图的节点特征。学习的参数很大程度与图结构相关,这使得GCN在inductive任务上遇到困境。

源码实现核心理解

模型

class GAT(nn.Module):
    def __init__(self, nfeat, nhid, nclass, dropout, alpha, nheads):
        """Dense version of GAT."""
        super(GAT, self).__init__()
        self.dropout = dropout

        self.attentions = [GraphAttentionLayer(nfeat, nhid, dropout=dropout, alpha=alpha, concat=True) for _ in range(nheads)]
        for i, attention in enumerate(self.attentions):
            self.add_module('attention_{}'.format(i), attention)
        #nhid*nheads=8*8   nclass=7  ->GraphAttentionLayer.W:(64,7)
        self.out_att = GraphAttentionLayer(nhid * nheads, nclass, dropout=dropout, alpha=alpha, concat=False)

    def forward(self, x, adj):#cora数据集:x:(2708,1433)节点属性,2708个节点,每个节点1433个属性 adj:(2708,2708)节点构造的邻接矩阵
        x = F.dropout(x, self.dropout, training=self.training)# 需要将模型整体的training状态参数传入dropout函数
        #8个head,每个生成(2708,8),然后执行cat操作后得到x:(2708,64)
        #每个att(x, adj)都是公式1-4的执行过程  torch.cat就是公式5,也是公式6的括号里面部分
        x = torch.cat([att(x, adj) for att in self.attentions], dim=1)
        x = F.dropout(x, self.dropout, training=self.training)#x:(2708,64)
        x = F.elu(self.out_att(x, adj))#执行前x:(2708,64),执行完后x:(2708,7)  公式6
        return F.log_softmax(x, dim=1)

GAL层

class GraphAttentionLayer(nn.Module):
    """
    Simple GAT layer, similar to https://arxiv.org/abs/1710.10903
    """
    。。。
    def forward(self, input, adj):
        h = torch.mm(input, self.W)#公式1:WH部分
        #W:(1433,8) input(2708,1433)->h:(2708,8) 直接通过torch.mm矩阵乘完成全连接运算
        #最后还有一次self-attention,由self.out_att(x, adj)传入得到的self.W:(64,8),传入的input为(2707,64)
        N = h.size()[0]#N:2708
        '''
        repeat(1,N)不指定axis时逐个元素复制repeats次,形成一个行向量,这里就是第一维度不变,后面第二个维度乘以N
        repeat(N,1)不指定axis时逐个元素复制repeats次,形成一个行向量,这里就是第二维度不变,前面第一个维度乘以N
        view(-1,h,c),view前面的维度(这里指的是所有元素的个数)除以h再除以c得到的数值自动补到-1的位置,后面维度不变;这里N*N就是2708,8*2708,第二个维度就是前面(2708*8*2708/2708,8*2708)等于8
        cat是第二个维度(dim=1)进行拼接,也就是8+8=16
        a就是拼接的维度,即(16,1)
        squeeze(2)就是第三维度为1的话降维,结果就是剩下前两维

        已知h:(2708,8),N=2708
        h.repeat(1, N):(2708,8*2708).view->(2708*2708,8)
        h.repeat(N, 1)->(2708*2708,8)

        a_input=torch.cat([h.repeat(1, N).view(N * N, -1), h.repeat(N, 1)], dim=1)->(2708*2708,16).view->2708,2708,2*8=(2708,2708,16)
        self.a:(16,1)偏移量b
        e=torch.matmul(a_input, self.a):(2708,2708,16)*(16,1)->(2708,2708,1).squeeze->e:2708,2708
        
        #具体的合并操作讲解如下
        a_input为公式3中的a[WHi||WHj],cat操作就是合并,图1左部分WHi和WHj合并,是每个节点属性分别与所有节点
        (2708个节点包括自己)属性的合并操作:
        结果如下:对于节点1,分别与节点1属性,节点2属性,。。。一直到节点2708属性合并:
        [节点1: [[节点1的8个属性,节点1的8个属性]
                [节点1的8个属性,节点2的8个属性]
                。。。。。。
                [节点1的8个属性,节点2708的8个属性]]     共2708个
        节点2:        
                [[节点2的8个属性,节点1的8个属性]
                [节点2的8个属性,节点2的8个属性]
                。。。。。。
                [节点2的8个属性,节点2708的8个属性]]     共2708个
        。。。
        节点2708:        
                [[节点2708的8个属性,节点1的8个属性]
                [节点2708的8个属性,节点2的8个属性]
                。。。。。。
                [节点2708的8个属性,节点2708的8个属性]]     共2708个  
        ]所以a_input是(2708,2708,16)
        这表明每个节点的8个属性(是经过mm运算将1403个属性集成为8个)与所有节点都进行了聚集。
        
        #注意力系数这里也做了改动,只分配给一阶邻居
        self-attention会将注意力系数分配到图中所有的节点上,这种做法显然会丢失结构信息。
        为了解决这一问题,本文使用了一种masked attention的方式——仅将注意力分配到节点的邻节点集上。
        也就是后面的根据邻接关系更新这个e:  attention = torch.where(adj > 0, e, zero_vec)
        '''
        a_input = torch.cat([h.repeat(1, N).view(N * N, -1), h.repeat(N, 1)], dim=1).view(N, -1, 2 * self.out_features)#公式3:两个WH的拼接部分
        e = self.leakyrelu(torch.matmul(a_input, self.a).squeeze(2))#公式3里面的leakyrelu部分(将前面两个WH拼接的结果接上一个a神经网络后再relu) 即公式1的e
        #造了一个e:(2708,2708)维度,元素为全0的张量
        zero_vec = -9e15*torch.ones_like(e)
        #根据邻接矩阵adj,找出e中大于0的元素组成张量
        '''
        torch.where(condition,x,y)
        out = x,if condition is 1
            = y ,if condition is 0
        '''
        #也就是根据邻接矩阵adj来组建attention(2708,2708),若对应adj位置大于0,取e中对应位置元素,若小于0取为zero_vec中对应位置元素
        attention = torch.where(adj > 0, e, zero_vec) #公式3注意力分配判断,只是将注意力分配到邻居节点,而不是所有节点
        #真正的公式3,单层全连接处理后只考虑邻接矩阵后的激活函数,就相当于一个考虑了邻居的权重系数
        attention = F.softmax(attention, dim=1) #公式3最后做的softmax
        attention = F.dropout(attention, self.dropout, training=self.training)
        h_prime = torch.matmul(attention, h)#公式4,h_prime(2708,8)
        #将这个权重系数attention与Wh相乘,即在原先节点属性运算时考虑了邻接节点的权重关系。
        #相当于公式4的括号里面部分,只考虑了一阶邻接节点

        if self.concat:
            return F.elu(h_prime)#公式4
        else:
            return h_prime

二、GAT的PyG内置的模型实现

可以参考

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

静静喜欢大白

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值