slot attention 网络搭建的代码讲解

系列文章目录



前言

  pytorch 基础有了,机器学习的基础不牢靠。想直接进入计算机视觉领域难度很大,一个新领域,开头最难。我们应该抱着学习的热忱,哪里不懂就去查,一定记得回顾。一头雾水,没关系,只要动起来,多看,多理解,我相信零散的知识会织出一件漂亮的衣裳。 本文结合代码和人工智能的解释进行讲解。


一、slot attention 源码的搭建结构

  网络搭建的类一般只有两个函数,一个是初始化函数,另一个便是前向传播函数。初始化函数主要设置一些参数和操作算子,这些参数有的默认值,直接赋值,当然也可以训练时更改。前向传播有两个任务,第一个便是根据我们输入的数据集来给初始化函数中的参数赋值,另一个任务是按照一定的顺序调用操作算子实现我们需要的功能。下面给一张结构图。

图 1 网络总体架构

  这个代码我感觉和 transformer 有很多经验之处,查找知识的时候,很多小点都会与 transformer 中的组件进行对比。其中初始化中的基础参数很好理解,数据集的输入与形状,槽的数目和形状设置等等。可学习参数的设置就需要有一定的经验,要有机器学习的基础,在初始化时要提升模型的稳定性,多样性(高斯分布),σ结合数据集设置取值范围。
需要安装的库:

import torch
from torch import nn
from torch.nn import init

二、网络初始化函数 init

  网络初始化函数分成两个部分。参数部分:网络基本参数、可学习参数。操作算子部分:注意力机制计算、GRU更新门、归一化层和前馈网络。

函数代码:

    def __init__(self, num_slots, dim, iters = 3, eps = 1e-8, hidden_dim = 128):
        super().__init__()
        self.dim = dim  # 初始化维度
        self.num_slots = num_slots # 初始化数目
        self.iters = iters # 迭代次数
        self.eps = eps # 精度
        self.scale = dim ** -0.5 #

        self.slots_mu = nn.Parameter(torch.randn(1, 1, dim)) # 这个维度有什么用?

        self.slots_logsigma = nn.Parameter(torch.zeros(1, 1, dim)) #
        init.xavier_uniform_(self.slots_logsigma) #

        self.to_q = nn.Linear(dim, dim)
        self.to_k = nn.Linear(dim, dim)
        self.to_v = nn.Linear(dim, dim)

        self.gru = nn.GRUCell(dim, dim)

        hidden_dim = max(dim, hidden_dim)

        self.mlp = nn.Sequential(
            nn.Linear(dim, hidden_dim),
            nn.ReLU(inplace = True),
            nn.Linear(hidden_dim, dim)
        )

        self.norm_input  = nn.LayerNorm(dim)
        self.norm_slots  = nn.LayerNorm(dim)
        self.norm_pre_ff = nn.LayerNorm(dim)

1. 网络基本参数

  网络基本参数包括槽的数量,特征的维度,迭代的次数、数值稳定项和隐藏层的维度。这里面介绍两个eps和scale。

1) eps

  eps是一个非常小的正数,对结果几乎没有影响。核心作用:

  1. 在注意力权重计算中,防止出现除零错误。
  2. 某些权重为零的时候,保证分母不会为零。

有些会使用clamp(min=eps)代替直接相加,常常用在概率归一化、注意力计算、梯度计算场景,保证数值稳定的标准实践。

2)scale

  这个参数用于点积注意力的放缩,点积运算就是矩阵乘法,前一个矩阵的列要求等于后一个矩阵的行,而方阵的对应元素相乘叫做哈达玛积。

  1. 数学原理
    公式表示:scale = 1 d i m \frac{1}{\sqrt{dim}} dim 1 ,其中 dim 是查询(Q)和键(K)向量的维度。
    • 作用: 缩放点积运算后的结果,防止其值过大或者过小。
    • 稳定梯度: 防止 softmax 层输入过大,防止梯度消失。
    • 控制数值范围: 使得注意力机制更加平滑,避免极端值出现。
    • 理论保障: 符合点积的方差分析,确保模型训练的稳定性。
  2. 为什么要放缩?
    • 当 dim 较大时,因为每个维度贡献的累加,点积的结果会倾向于变得很大,导致 soft_max 层梯度变得非常小,类似于 one-hot 编码,这被称为点积规模过大问题。
    • dim 很大时,K * V 的可能非常大,soft_max 层接近 0 或 1, 模型难以学习。
    • 放缩后: k ⋅ v d i m \frac{k \cdot v}{\sqrt{dim}} dim kv 这个值保持在 [ − d i m , d i m ] [-\sqrt{dim},\sqrt{dim}] [dim ,dim ] 附近,soft_max的输出更加平滑。
  3. 其他方案
    • 有些模型可以使用可学习的放缩因子,而不是固定的,如 Perceiver IO.
    • 在稀疏注意力中,放缩因子可能会调整来适应不同的注意力模式。

2. 可学习的初始槽位参数

  课学sigma习的参数主要有 slot_muslot_logsigma ,简称 μ 和 σ。

  1. 为什么维度是(1,1,dim)?
    • 广播机制: Pytorch 会自动将(1, 1, dim)广播到 (batch_size, num_slots, dim) ,避免重复储存,而且这里的运算也是张量,而不是数字型的标量。
    • 内存效率: 相比较(num_slots, dim),这种方式更加节省内存,尤其是 slot 的数量较大的时候。
  2. 参数含义及初始化
参数作用初始化
μ高斯分布的均值标准正态分布 randn
σ高斯分布的对数标准差初始化为0,后用 Xavier 初始化
  1. 为什么使用 Xavier 对 σ 实现对数初始化?
    • 保持方差稳定: Xavier 初始化根据输入输出维度调整初始化范围,避免了梯度爆炸或者消失。
    • 标准的高斯分布,标准差为 0 ,此时使用 Xavier ,不是固定的 0 ,是为了让模型能够学习到更灵活的初始分布。
  2. 为什么不对 μ 使用 Xavier ?
    • 高斯分布本身已经具备合理的初始多样性。
    • 模型后续会通过注意力机制迅速调整槽位。
  3. 超参数选择建议:
    • dim 较大时: 建议调小 σ 初始范围,避免噪声过大。
    • num_slots 较多时: 改用(num_slots, dim)而非广播,增强槽位差异性。

3. 注意力计算层

  注意力计算是一个很好工具,也是 transformer 底层的核心功能,在这里大家可以跳到相应的章节进行 Q, K, V 查询键值的学习,点击。大家关注一下张量的变化,实现的过程,此处解释一下含义。

为什么三个角色独立变化?

  • 角色分离:

    • Query: 表示当前关注的问题,比如槽位。
    • Key: 表示被检索的索引,比如说特征。
    • Value: 表示被提取的内容,比如特征的语义信息。
  • 灵活性: 独立的参数允许模型学习不同的投影。

4. GRU 更新门

   GRUCell 门控循环单元,用于在 slot attention 中更新槽位的状态。

updated_slot = GRUCell(aggregated_info, current_slot)

在每次更新迭代中,根据槽位和注意力聚合结果,动态更新槽位状态。

  • aggregated_info 是注意力加权的值 V 聚合后的结果。
  • current_slot 是上一轮槽位的状态。
  1. 为什么使用 GRU 而不使用普通的MLP ?
方法优点适用场景
GRU显示建模时序依赖、门控防止梯度消失需要迭代更新的场景
MLP结构简单单步前向推理
  • 门控机制: GRU 的更新门 update gate 和重置门 reset gate 可以稳定的处理长期依赖。
  • 在原文中也验证了 GRU 的效果更好。
  1. 设计关键点

    • 维度一致: 输入层和隐藏层都是 dim,避免了槽位更新前后位置不变。
    • 独立更新: 每个槽位通过独立的 GRU 更新,保证槽位的特异性。
    • 迭代特性: GRU 的隐藏层在多次迭代中逐步优化
  2. 超参数选择建议

    • dim 较大,可以尝试增加 GRU隐藏层的维度,但是需要保证输入和输出的兼容性。
    • num_slot 较大,进行深层迭代时(迭代次数 > 8),建议采用 LayerNorm 稳定训练。

5. 归一化和前馈神经网络

  三个模块的归一化就看下面这张表吧。

归一化层功能作用
self.norm_input特征归一化稳定初始分布
self.norm_slots槽位状态归一化防止数值爆炸
self.norm_pre_ff前馈网络前归一化在MLP前标准化,提升梯度稳定性

前馈网络 MLP
  这是一个经典的前馈神经网络,MLP 通过两头的升维和降维之间的配合,使用ReLU激活函数,拟合复杂的函数关系,提升模型的表达能力。MLP 输入维度和输出维度一致性,确保了与自注意力机制、GRU 组件的兼容性,维持整个模块进行端到端的可训练性。输入输出同维度,便于与残差连接(如 output = x + mlp(x))结合。

三、前向传播函数 forward

1. 输入处理与参数初始化

b, n, d, device, dtype = *inputs.shape,  inputs.device,  inputs.dtype  
n_s = num_slots if num_slots is not None else self.num_slots  

   这里 b 表示输入的 batch 号码,n 是输入的数量,d 表示输入的维度。device 和 type 分别表示运行的地方,GPU 还是 CPU.输入数据的类型。n_s 看后面的判断就知道是槽的数量。

        mu = self.slots_mu.expand(b, n_s, -1)
        sigma = self.slots_logsigma.exp().expand(b, n_s, -1)

        slots = mu + sigma * torch.randn(mu.shape, device = device, dtype = dtype)

        inputs = self.norm_input(inputs)        
        k, v = self.to_k(inputs), self.to_v(inputs)

  这里进行 μ 和 σ 的初始化,根据这两个参数采用高斯分布来初始化 slot 的数量。expand() 函数是广播函数,我们的 μ 和 σ 都是(1,1,dim) 这里根据实际数据集广播到(batch,slot 的数量,数据维度)此处的 -1 表示输入维度保持不变,我们在 init 函数中已经传入了维度。self.norm_input(input) 对输入特征进行归一化,后面顺带计算 k,v。

2. 循环更新slot

        for _ in range(self.iters):
            slots_prev = slots # 传入上一个slots

            slots = self.norm_slots(slots)
            q = self.to_q(slots)

            dots = torch.einsum('bid,bjd->bij', q, k) * self.scale
            attn = dots.softmax(dim=1) + self.eps

            attn = attn / attn.sum(dim=-1, keepdim=True)

            updates = torch.einsum('bjd,bij->bid', v, attn)

            slots = self.gru(
                updates.reshape(-1, d),
                slots_prev.reshape(-1, d)
            )

  迭代更新slot,把上一轮的slot作为初始值,然后进行归一化,在计算Q。K,V 使用输入的 input 计算,有意思啊。点积运算计算 q,k 的乘积,采用爱因斯坦积运算。然后使用GRU了来更新槽,当然我们需要把形状reshape到一个二维平面上面。

            slots = slots.reshape(b, -1, d)
            slots = slots + self.mlp(self.norm_pre_ff(slots))

形状恢复到三维,然后进行卷积,提升模型的表达能力。

此时我们还有一个疑问,那就是数据的形状变化,我后续会详细解答。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值