PyTorch Geometric(PYG)-实现小批量data类中__inc__与__cat_dim__的含义与作用

本文解析了PyTorch Geometric (PYG) 中__inc__与__cat_dim__的作用及实现方法,通过实例展示了如何自定义这两个函数来实现特定的数据拼接需求。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

PYG中实现小批量data类中__inc__与__cat_dim__的含义与作用

1.作用

此两个函数出现在pytorch geometric实现批量操作时,batch集行为的自定义修改方法,两种方法都是为了解决多个数据之间的拼接问题。

2.直观图解

  • 官方初始定义,均对某一属性值进行判定
def __inc__(self, key, value):
    if 'index' in key or 'face' in key:
        return self.num_nodes
    else:
        return 0

def __cat_dim__(self, key, value):
    if 'index' in key or 'face' in key:
        return 1
    else:
        return 0

返回值的具体含义见图:

  • __ inc __
    在这里插入图片描述
    即__ inc __返回值表示相应矩阵错位步数,一般用于边的邻接矩阵:
    在这里插入图片描述
  • __ cat_dim __
    在这里插入图片描述
    即__ cat_dim __返回值表示相应矩阵拼接的维度。按行或列拼接,一般用于节点或者结果矩阵的拼接

3.官方示例

官方默认batch拼接形式如下在这里插入图片描述
A为邻接矩阵,X为节点矩阵,Y为结果矩阵

此时对比官方示例,一目了然:

  • 成对数据结构
class PairData(Data):
    def __init__(self, edge_index_s, x_s, edge_index_t, x_t):
        super(PairData, self).__init__()
        self.edge_index_s = edge_index_s
        self.x_s = x_s
        self.edge_index_t = edge_index_t
        self.x_t = x_t
  • 批量时,要分成两个数据集来拼接,所以需要自定义边矩阵的拼接
def __inc__(self, key, value):
    if key == 'edge_index_s':
        return self.x_s.size(0)
    if key == 'edge_index_t':
        return self.x_t.size(0)
    else:
        return super().__inc__(key, value)
  • 测试
edge_index_s = torch.tensor([
    [0, 0, 0, 0],
    [1, 2, 3, 4],
])
x_s = torch.randn(5, 16)  # 5 nodes.
edge_index_t = torch.tensor([
    [0, 0, 0],
    [1, 2, 3],
])
x_t = torch.randn(4, 16)  # 4 nodes.

data = PairData(edge_index_s, x_s, edge_index_t, x_t)
data_list = [data, data]
loader = DataLoader(data_list, batch_size=2)
batch = next(iter(loader))

print(batch)
>>> Batch(edge_index_s=[2, 8], x_s=[10, 16],
          edge_index_t=[2, 6], x_t=[8, 16])

print(batch.edge_index_s)
>>> tensor([[0, 0, 0, 0, 5, 5, 5, 5],
            [1, 2, 3, 4, 6, 7, 8, 9]])

print(batch.edge_index_t)
>>> tensor([[0, 0, 0, 4, 4, 4],
            [1, 2, 3, 5, 6, 7]])

两个A矩阵,分别按照对角线扩展方式拼接了~!

### Graph Attention Network (GAT)实现应用 #### 背景介绍 Graph Attention Network (GAT)[^2] 是一种基于注意力机制的神经网络架构,专门用于处理图结构数据。它能够通过学习节点间的权重分配来捕捉复杂的节点关系和特征表示。 #### 使用 PyTorchPyG 实现 GAT 以下是使用 PyTorchPyTorch Geometric (PyG)实现 GAT 的具体方法: 1. **安装依赖库** 需要先确保已安装 `torch` 和 `torch_geometric` 库。 ```bash pip install torch torch-scatter torch-sparse torch-cluster torch-spline-conv torch-geometric ``` 2. **导入必要的模块** 下面是一些基本的 Python 导入语句: ```python import torch from torch.nn import Linear, Parameter from torch_geometric.nn.conv import MessagePassing from torch_geometric.utils import add_self_loops, degree ``` 3. **定义 GAT 层** 定义一个多头自注意力层(multi-head self-attention),这是 GAT 的核心部分之一。 ```python class GATConv(MessagePassing): def __init__(self, in_channels, out_channels, heads=1, concat=True, negative_slope=0.2, dropout=0.0): super(GATConv, self).__init__(aggr='add') self.in_channels = in_channels self.out_channels = out_channels self.heads = heads self.concat = concat self.negative_slope = negative_slope self.dropout = dropout self.lin = Linear(in_channels, heads * out_channels, bias=False) self.att = Parameter(torch.Tensor(1, heads, 2 * out_channels)) self.reset_parameters() def reset_parameters(self): self.lin.reset_parameters() torch.nn.init.xavier_uniform_(self.att) def forward(self, x, edge_index): H, C = self.heads, self.out_channels x = self.lin(x).view(-1, H, C) return self.propagate(edge_index, x=x, size=None) def message(self, x_j, x_i, index, ptr, size_i): alpha = (torch.cat([x_i, x_j], dim=-1) * self.att).sum(dim=-1) alpha = torch.leaky_relu(alpha, self.negative_slope) alpha = softmax(alpha, index, ptr, size_i) alpha = F.dropout(alpha, p=self.dropout, training=self.training) return x_j * alpha.unsqueeze(-1) ``` 4. **构建完整的 GAT 模型** 将多个 GAT 层堆叠起来形成一个完整的模型。 ```python class GAT(torch.nn.Module): def __init__(self, in_channels, hidden_channels, out_channels, num_heads): super(GAT, self).__init__() self.gat_conv1 = GATConv(in_channels, hidden_channels, heads=num_heads, concat=True) self.gat_conv2 = GATConv(hidden_channels * num_heads, out_channels, heads=1, concat=False) def forward(self, data): x, edge_index = data.x, data.edge_index x = self.gat_conv1(x, edge_index) x = torch.relu(x) x = self.gat_conv2(x, edge_index) return torch.log_softmax(x, dim=1) ``` 5. **训练过程** 训练过程中可以通过标准的交叉熵损失函数优化模型参数。 ```python model = GAT(dataset.num_features, hidden_channels=8, out_channels=dataset.num_classes, num_heads=8) optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4) model.train() for epoch in range(200): optimizer.zero_grad() out = model(data) loss = F.nll_loss(out[data.train_mask], data.y[data.train_mask]) loss.backward() optimizer.step() ``` 6. **测试性能** 测试阶段评估模型在验证集上的表现。 ```python model.eval() _, pred = model(data).max(dim=1) correct = int(pred[data.test_mask].eq(data.y[data.test_mask]).sum().item()) acc = correct / int(data.test_mask.sum()) print(f'Accuracy: {acc:.4f}') ``` #### 关键概念解析 - **Self-Attention Mechanism**: 自注意力机制允许每个节点关注其邻居的不同程度,从而动态调整重要性权重[^4]。 - **LeakyReLU Activation Function**: LeakyReLU 函数被用来引入非线性激活,在计算注意力分数时起到重要作用- **Multi-Head Attention**: 多头注意力似于集成学习的思想,有助于提高模型表达能力并稳定梯度下降过程[^3]。 ---
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值