系列文章目录
文章目录
前言
pytorch 基础有了,机器学习的基础不牢靠。想直接进入计算机视觉领域难度很大,一个新领域,开头最难。我们应该抱着学习的热忱,哪里不懂就去查,一定记得回顾。一头雾水,没关系,只要动起来,多看,多理解,我相信零散的知识会织出一件漂亮的衣裳。 本文结合代码和人工智能的解释进行讲解。
一、slot attention 源码的搭建结构
网络搭建的类一般只有两个函数,一个是初始化函数,另一个便是前向传播函数。初始化函数主要设置一些参数和操作算子,这些参数有的默认值,直接赋值,当然也可以训练时更改。前向传播有两个任务,第一个便是根据我们输入的数据集来给初始化函数中的参数赋值,另一个任务是按照一定的顺序调用操作算子实现我们需要的功能。下面给一张结构图。

这个代码我感觉和 transformer 有很多经验之处,查找知识的时候,很多小点都会与 transformer 中的组件进行对比。其中初始化中的基础参数很好理解,数据集的输入与形状,槽的数目和形状设置等等。可学习参数的设置就需要有一定的经验,要有机器学习的基础,在初始化时要提升模型的稳定性,多样性(高斯分布),σ结合数据集设置取值范围。
需要安装的库:
import torch
from torch import nn
from torch.nn import init
二、网络初始化函数 init
网络初始化函数分成两个部分。参数部分:网络基本参数、可学习参数。操作算子部分:注意力机制计算、GRU更新门、归一化层和前馈网络。
函数代码:
def __init__(self, num_slots, dim, iters = 3, eps = 1e-8, hidden_dim = 128):
super().__init__()
self.dim = dim # 初始化维度
self.num_slots = num_slots # 初始化数目
self.iters = iters # 迭代次数
self.eps = eps # 精度
self.scale = dim ** -0.5 #
self.slots_mu = nn.Parameter(torch.randn(1, 1, dim)) # 这个维度有什么用?
self.slots_logsigma = nn.Parameter(torch.zeros(1, 1, dim)) #
init.xavier_uniform_(self.slots_logsigma) #
self.to_q = nn.Linear(dim, dim)
self.to_k = nn.Linear(dim, dim)
self.to_v = nn.Linear(dim, dim)
self.gru = nn.GRUCell(dim, dim)
hidden_dim = max(dim, hidden_dim)
self.mlp = nn.Sequential(
nn.Linear(dim, hidden_dim),
nn.ReLU(inplace = True),
nn.Linear(hidden_dim, dim)
)
self.norm_input = nn.LayerNorm(dim)
self.norm_slots = nn.LayerNorm(dim)
self.norm_pre_ff = nn.LayerNorm(dim)
1. 网络基本参数
网络基本参数包括槽的数量,特征的维度,迭代的次数、数值稳定项和隐藏层的维度。这里面介绍两个eps和scale。
1) eps
eps是一个非常小的正数,对结果几乎没有影响。核心作用:
- 在注意力权重计算中,防止出现除零错误。
- 某些权重为零的时候,保证分母不会为零。
有些会使用clamp(min=eps)
代替直接相加,常常用在概率归一化、注意力计算、梯度计算场景,保证数值稳定的标准实践。
2)scale
这个参数用于点积注意力的放缩,点积运算就是矩阵乘法,前一个矩阵的列要求等于后一个矩阵的行,而方阵的对应元素相乘叫做哈达玛积。
- 数学原理
公式表示:scale = 1 d i m \frac{1}{\sqrt{dim}} dim1 ,其中 dim 是查询(Q)和键(K)向量的维度。- 作用: 缩放点积运算后的结果,防止其值过大或者过小。
- 稳定梯度: 防止 softmax 层输入过大,防止梯度消失。
- 控制数值范围: 使得注意力机制更加平滑,避免极端值出现。
- 理论保障: 符合点积的方差分析,确保模型训练的稳定性。
- 为什么要放缩?
- 当 dim 较大时,因为每个维度贡献的累加,点积的结果会倾向于变得很大,导致 soft_max 层梯度变得非常小,类似于 one-hot 编码,这被称为点积规模过大问题。
- dim 很大时,K * V 的可能非常大,soft_max 层接近 0 或 1, 模型难以学习。
- 放缩后: k ⋅ v d i m \frac{k \cdot v}{\sqrt{dim}} dimk⋅v 这个值保持在 [ − d i m , d i m ] [-\sqrt{dim},\sqrt{dim}] [−dim,dim] 附近,soft_max的输出更加平滑。
- 其他方案
- 有些模型可以使用可学习的放缩因子,而不是固定的,如 Perceiver IO.
- 在稀疏注意力中,放缩因子可能会调整来适应不同的注意力模式。
2. 可学习的初始槽位参数
课学sigma习的参数主要有 slot_mu
和 slot_logsigma
,简称 μ 和 σ。
- 为什么维度是(1,1,dim)?
- 广播机制: Pytorch 会自动将(1, 1, dim)广播到 (batch_size, num_slots, dim) ,避免重复储存,而且这里的运算也是张量,而不是数字型的标量。
- 内存效率: 相比较(num_slots, dim),这种方式更加节省内存,尤其是 slot 的数量较大的时候。
- 参数含义及初始化
参数 | 作用 | 初始化 |
---|---|---|
μ | 高斯分布的均值 | 标准正态分布 randn |
σ | 高斯分布的对数标准差 | 初始化为0,后用 Xavier 初始化 |
- 为什么使用 Xavier 对 σ 实现对数初始化?
- 保持方差稳定: Xavier 初始化根据输入输出维度调整初始化范围,避免了梯度爆炸或者消失。
- 标准的高斯分布,标准差为 0 ,此时使用 Xavier ,不是固定的 0 ,是为了让模型能够学习到更灵活的初始分布。
- 为什么不对 μ 使用 Xavier ?
- 高斯分布本身已经具备合理的初始多样性。
- 模型后续会通过注意力机制迅速调整槽位。
- 超参数选择建议:
- dim 较大时: 建议调小 σ 初始范围,避免噪声过大。
- num_slots 较多时: 改用(num_slots, dim)而非广播,增强槽位差异性。
3. 注意力计算层
注意力计算是一个很好工具,也是 transformer 底层的核心功能,在这里大家可以跳到相应的章节进行 Q, K, V 查询键值的学习,点击。大家关注一下张量的变化,实现的过程,此处解释一下含义。
为什么三个角色独立变化?
-
角色分离:
- Query: 表示当前关注的问题,比如槽位。
- Key: 表示被检索的索引,比如说特征。
- Value: 表示被提取的内容,比如特征的语义信息。
-
灵活性: 独立的参数允许模型学习不同的投影。
4. GRU 更新门
GRUCell 门控循环单元,用于在 slot attention 中更新槽位的状态。
updated_slot = GRUCell(aggregated_info, current_slot)
在每次更新迭代中,根据槽位和注意力聚合结果,动态更新槽位状态。
aggregated_info
是注意力加权的值 V 聚合后的结果。current_slot
是上一轮槽位的状态。
- 为什么使用 GRU 而不使用普通的MLP ?
方法 | 优点 | 适用场景 |
---|---|---|
GRU | 显示建模时序依赖、门控防止梯度消失 | 需要迭代更新的场景 |
MLP | 结构简单 | 单步前向推理 |
- 门控机制: GRU 的更新门 update gate 和重置门 reset gate 可以稳定的处理长期依赖。
- 在原文中也验证了 GRU 的效果更好。
-
设计关键点
- 维度一致: 输入层和隐藏层都是 dim,避免了槽位更新前后位置不变。
- 独立更新: 每个槽位通过独立的 GRU 更新,保证槽位的特异性。
- 迭代特性: GRU 的隐藏层在多次迭代中逐步优化
-
超参数选择建议
- dim 较大,可以尝试增加 GRU隐藏层的维度,但是需要保证输入和输出的兼容性。
- num_slot 较大,进行深层迭代时(迭代次数 > 8),建议采用 LayerNorm 稳定训练。
5. 归一化和前馈神经网络
三个模块的归一化就看下面这张表吧。
归一化层 | 功能 | 作用 |
---|---|---|
self.norm_input | 特征归一化 | 稳定初始分布 |
self.norm_slots | 槽位状态归一化 | 防止数值爆炸 |
self.norm_pre_ff | 前馈网络前归一化 | 在MLP前标准化,提升梯度稳定性 |
前馈网络 MLP
这是一个经典的前馈神经网络,MLP 通过两头的升维和降维之间的配合,使用ReLU激活函数,拟合复杂的函数关系,提升模型的表达能力。MLP 输入维度和输出维度一致性,确保了与自注意力机制、GRU 组件的兼容性,维持整个模块进行端到端的可训练性。输入输出同维度,便于与残差连接(如 output = x + mlp(x))结合。
三、前向传播函数 forward
1. 输入处理与参数初始化
b, n, d, device, dtype = *inputs.shape, inputs.device, inputs.dtype
n_s = num_slots if num_slots is not None else self.num_slots
这里 b 表示输入的 batch 号码,n 是输入的数量,d 表示输入的维度。device 和 type 分别表示运行的地方,GPU 还是 CPU.输入数据的类型。n_s 看后面的判断就知道是槽的数量。
mu = self.slots_mu.expand(b, n_s, -1)
sigma = self.slots_logsigma.exp().expand(b, n_s, -1)
slots = mu + sigma * torch.randn(mu.shape, device = device, dtype = dtype)
inputs = self.norm_input(inputs)
k, v = self.to_k(inputs), self.to_v(inputs)
这里进行 μ 和 σ 的初始化,根据这两个参数采用高斯分布来初始化 slot 的数量。expand() 函数是广播函数,我们的 μ 和 σ 都是(1,1,dim) 这里根据实际数据集广播到(batch,slot 的数量,数据维度)此处的 -1 表示输入维度保持不变,我们在 init 函数中已经传入了维度。self.norm_input(input)
对输入特征进行归一化,后面顺带计算 k,v。
2. 循环更新slot
for _ in range(self.iters):
slots_prev = slots # 传入上一个slots
slots = self.norm_slots(slots)
q = self.to_q(slots)
dots = torch.einsum('bid,bjd->bij', q, k) * self.scale
attn = dots.softmax(dim=1) + self.eps
attn = attn / attn.sum(dim=-1, keepdim=True)
updates = torch.einsum('bjd,bij->bid', v, attn)
slots = self.gru(
updates.reshape(-1, d),
slots_prev.reshape(-1, d)
)
迭代更新slot,把上一轮的slot作为初始值,然后进行归一化,在计算Q。K,V 使用输入的 input 计算,有意思啊。点积运算计算 q,k 的乘积,采用爱因斯坦积运算。然后使用GRU了来更新槽,当然我们需要把形状reshape到一个二维平面上面。
slots = slots.reshape(b, -1, d)
slots = slots + self.mlp(self.norm_pre_ff(slots))
形状恢复到三维,然后进行卷积,提升模型的表达能力。
此时我们还有一个疑问,那就是数据的形状变化,我后续会详细解答。