1. 原型网络 (Prototypical Networks)
核心思想
原型网络提出"类原型"概念,将每个类别表示为支持集中所有样本嵌入的平均向量。预测时通过计算查询样本与各类原型的欧式距离实现分类。
核心:
ck=1∣Sk∣∑(xi,yi)∈Skfϕ(xi)c_k = \frac{1}{|S_k|} \sum_{(x_i,y_i) \in S_k} f_\phi(x_i)ck=∣Sk∣1∑(xi,yi)∈Skfϕ(xi)
p(y=k∣x)=exp(−d(fϕ(x),ck))∑jexp(−d(fϕ(x),cj))p(y=k|x) = \frac{\exp(-d(f_\phi(x), c_k))}{\sum_j \exp(-d(f_\phi(x), c_j))}p(y=k∣x)=∑jexp(−d(fϕ(x),cj))exp(−d(fϕ(x),ck))
算法流程
关键技术
-
嵌入空间设计
- 使用CNN(如Conv64F)提取特征
- 特征空间需满足:类内距离 < 类间距离
-
距离度量
- 欧氏距离:d(z,ck)=∥z−ck∥22d(z, c_k) = \|z - c_k\|^2_2d(z,ck)=∥z−ck∥22
- 线性分类器等价性:−∥z−ck∥2=2ckTz−ckTck+zTz-\|z - c_k\|^2 = 2c_k^Tz - c_k^Tc_k + z^Tz−∥z−ck∥2=2ckTz−ckTck+zTz
-
训练策略
- 小样本情景:5-way 1-shot等设定
- 损失函数:负对数似然损失
L=−logpϕ(y=k∣x) \mathcal{L} = -\log p_\phi(y=k|x) L=−logpϕ(y=k∣x)
创新点
- 无参数分类器:无需额外训练线性层
- 端到端训练:联合优化嵌入函数和分类逻辑
- 泛化性强:适用于零样本学习场景
2. 关系网络 (Relation Network)
核心思想
关系网络将相似度度量本身设计为可学习的神经网络(关系模块),直接预测查询样本与类原型的相似度得分,而非使用固定距离函数。
计算关系分数:
r(gθ(ck,fϕ(x)))r(g_\theta(c_k, f_\phi(x)))r(gθ(ck,fϕ(x)))
其中ckc_kck为类原型,xxx为查询样本
双模块架构
关键技术
-
关系函数设计
- 输入:类原型向量与查询样本嵌入的拼接
- 结构:2-3层全连接网络
- 输出:[0,1]区间的相似度得分
-
损失函数
- 均方误差损失:
L=∑(rk,q−I(yq=k))2 \mathcal{L} = \sum (r_{k,q} - \mathbb{I}(y_q=k))^2 L=∑(rk,q−I(yq=k))2
- 均方误差损失:
-
特征融合
g(ck,xq)=σ(W2[ReLU(W1[ck;xq]+b1)]+b2) g(c_k,x_q) = \sigma(W_2[\text{ReLU}(W_1[c_k;x_q]+b_1)] + b_2) g(ck,xq)=σ(W2[ReLU(W1[ck;xq]+b1)]+b2)
创新点
- 灵活相似度学习:代替固定距离度量
- 端到端训练:支持模块协同优化
- 跨域适应:可处理不同模态数据
3. R2D2算法
核心思想
R2D2结合了线性回归的效率和端到端训练的适应性:
- 使用闭式解计算类别权重
- 通过可微的平方损失实现端到端训练
- 设计双层优化框架实现元学习
核心公式:
W∗=argminW∥XW−Y∥F2+λ∥W∥F2W^* = \arg\min_W \|XW - Y\|_F^2 + \lambda \|W\|_F^2W∗=argminW∥XW−Y∥F2+λ∥W∥F2
闭式解:W∗=XT(XXT+λI)−1YW^* = X^T(XX^T + \lambda I)^{-1}YW∗=XT(XXT+λI)−1Y
算法流程
关键技术
-
基于平方损失的可微分类器
- 损失函数:L=∥XW−Y∥F2+λ∥W∥F2\mathcal{L} = \|XW - Y\|_F^2 + \lambda \|W\|_F^2L=∥XW−Y∥F2+λ∥W∥F2
- 闭式解:W=X⊤(XX⊤+λI)−1YW = X^\top (XX^\top + \lambda I)^{-1}YW=X⊤(XX⊤+λI)−1Y
-
特征提取器优化
- 最小化查询损失:Lquery(fφ)=∑xq,yqℓ(fφ(xq)⊤W∗,yq)\mathcal{L}_{query}(f_φ) = \sum_{x_q,y_q} \ell(f_φ(x_q)^\top W^*, y_q)Lquery(fφ)=∑xq,yqℓ(fφ(xq)⊤W∗,yq)
-
正则化策略
- L2正则:防止小样本过拟合
- 自适应正则系数:平衡偏差方差
创新点
- 高效闭式解:避免迭代优化分类器
- 元学习框架:支持跨任务泛化
- 特征-分类器协同优化:提升特征表示质量
- 二阶优化兼容:支持Hessian矩阵计算
算法对比分析
特性 | 原型网络 | 关系网络 | R2D2 |
---|---|---|---|
分类方式 | 最近类原型 | 学习相似度 | 闭式回归 |
参数更新 | 嵌入函数 | 嵌入+关系函数 | 嵌入函数 |
距离度量 | 欧氏距离 | 可学习函数 | Mahalanobis距离 |
计算复杂度 | O(n) | O(n+k) | O(d^3) |
训练策略 | 端到端 | 端到端 | 双层优化 |
特征质量 | 中等 | 强 | 强 |
小样本表现 | 简单场景优 | 复杂场景好 | 平衡性好 |
应用场景分析
1. 原型网络适用场景
- 计算资源受限的移动端
- 类别极度不平衡场景
- 新增类别频繁的开放世界识别
- 跨模态检索:如图文匹配
2. 关系网络适用场景
- 细粒度图像识别(鸟类/花卉等)
- 工业缺陷检测的少量样本建模
- 医学影像分析中的病灶识别
- 艺术风格鉴定
3. R2D2适用场景
- 大规模类别的元学习
- 多模态融合任务
- 实时推理需求场景
- 增量学习系统
核心差异
+-----------------+
| 特征表示质量 |
+-------+---------+
|
+-------v---------+
| 距离度量适应性 |
+-------+---------+
|
+---------------+---------------+
| |
+-------v-------+ +---------v---------+
| 固定距离度量 | | 自适应度量 |
| (欧氏/余弦) | | (神经网络学习) |
+-------+-------+ +---------+---------+
| |
+-------v-------+ +---------v---------+
| 原型网络 | | 关系网络/R2D2 |
| (简单高效) | | (高精度复杂) |
+---------------+ +-------------------+
这三种算法构成了小样本学习的核心方法论体系,各有侧重解决不同场景下的少样本学习问题。原型网络因其简洁有效成为入门基线,关系网络通过可学习的相似度度量提升性能上限,R2D2则在理论框架上实现突破,通过双层优化实现闭式解分类。