小样本学习算法

1. 原型网络 (Prototypical Networks)

核心思想

原型网络提出"类原型"概念,将每个类别表示为支持集中所有样本嵌入的平均向量。预测时通过计算查询样本与各类原型的欧式距离实现分类。

核心:
ck=1∣Sk∣∑(xi,yi)∈Skfϕ(xi)c_k = \frac{1}{|S_k|} \sum_{(x_i,y_i) \in S_k} f_\phi(x_i)ck=Sk1(xi,yi)Skfϕ(xi)
p(y=k∣x)=exp⁡(−d(fϕ(x),ck))∑jexp⁡(−d(fϕ(x),cj))p(y=k|x) = \frac{\exp(-d(f_\phi(x), c_k))}{\sum_j \exp(-d(f_\phi(x), c_j))}p(y=kx)=jexp(d(fϕ(x),cj))exp(d(fϕ(x),ck))

算法流程

输入
支撑集 Support Set
嵌入函数f_φ提取特征
计算类原型向量c_k
查询集 Query Set
计算特征嵌入
计算与各类原型的距离
Softmax分类

关键技术

  1. 嵌入空间设计

    • 使用CNN(如Conv64F)提取特征
    • 特征空间需满足:类内距离 < 类间距离
  2. 距离度量

    • 欧氏距离:d(z,ck)=∥z−ck∥22d(z, c_k) = \|z - c_k\|^2_2d(z,ck)=zck22
    • 线性分类器等价性:−∥z−ck∥2=2ckTz−ckTck+zTz-\|z - c_k\|^2 = 2c_k^Tz - c_k^Tc_k + z^Tzzck2=2ckTzckTck+zTz
  3. 训练策略

    • 小样本情景:5-way 1-shot等设定
    • 损失函数:负对数似然损失
      L=−log⁡pϕ(y=k∣x) \mathcal{L} = -\log p_\phi(y=k|x) L=logpϕ(y=kx)

创新点

  • 无参数分类器:无需额外训练线性层
  • 端到端训练:联合优化嵌入函数和分类逻辑
  • 泛化性强:适用于零样本学习场景

2. 关系网络 (Relation Network)

核心思想

关系网络将相似度度量本身设计为可学习的神经网络(关系模块),直接预测查询样本与类原型的相似度得分,而非使用固定距离函数。

计算关系分数:
r(gθ(ck,fϕ(x)))r(g_\theta(c_k, f_\phi(x)))r(gθ(ck,fϕ(x)))
其中ckc_kck为类原型,xxx为查询样本

双模块架构

关系模块
特征提取模块
类原型构建
查询样本特征
关系函数g_θ
关系得分
嵌入函数f_φ
输入图像
特征向量

关键技术

  1. 关系函数设计

    • 输入:类原型向量与查询样本嵌入的拼接
    • 结构:2-3层全连接网络
    • 输出:[0,1]区间的相似度得分
  2. 损失函数

    • 均方误差损失:
      L=∑(rk,q−I(yq=k))2 \mathcal{L} = \sum (r_{k,q} - \mathbb{I}(y_q=k))^2 L=(rk,qI(yq=k))2
  3. 特征融合
    g(ck,xq)=σ(W2[ReLU(W1[ck;xq]+b1)]+b2) g(c_k,x_q) = \sigma(W_2[\text{ReLU}(W_1[c_k;x_q]+b_1)] + b_2) g(ck,xq)=σ(W2[ReLU(W1[ck;xq]+b1)]+b2)

创新点

  1. 灵活相似度学习:代替固定距离度量
  2. 端到端训练:支持模块协同优化
  3. 跨域适应:可处理不同模态数据

3. R2D2算法

核心思想

R2D2结合了线性回归的效率和端到端训练的适应性:

  1. 使用闭式解计算类别权重
  2. 通过可微的平方损失实现端到端训练
  3. 设计双层优化框架实现元学习

核心公式:
W∗=arg⁡min⁡W∥XW−Y∥F2+λ∥W∥F2W^* = \arg\min_W \|XW - Y\|_F^2 + \lambda \|W\|_F^2W=argminWXWYF2+λWF2
闭式解:W∗=XT(XXT+λI)−1YW^* = X^T(XX^T + \lambda I)^{-1}YW=XT(XXT+λI)1Y

算法流程

训练任务嵌入模型分类器采样支持集S提取特征f_φ(x_i)构建特征矩阵X计算闭式解W*采样查询集Q提取查询特征预测查询标签回传梯度更新参数φ训练任务嵌入模型分类器

关键技术

  1. 基于平方损失的可微分类器

    • 损失函数:L=∥XW−Y∥F2+λ∥W∥F2\mathcal{L} = \|XW - Y\|_F^2 + \lambda \|W\|_F^2L=XWYF2+λWF2
    • 闭式解:W=X⊤(XX⊤+λI)−1YW = X^\top (XX^\top + \lambda I)^{-1}YW=X(XX+λI)1Y
  2. 特征提取器优化

    • 最小化查询损失:Lquery(fφ)=∑xq,yqℓ(fφ(xq)⊤W∗,yq)\mathcal{L}_{query}(f_φ) = \sum_{x_q,y_q} \ell(f_φ(x_q)^\top W^*, y_q)Lquery(fφ)=xq,yq(fφ(xq)W,yq)
  3. 正则化策略

    • L2正则:防止小样本过拟合
    • 自适应正则系数:平衡偏差方差

创新点

  1. 高效闭式解:避免迭代优化分类器
  2. 元学习框架:支持跨任务泛化
  3. 特征-分类器协同优化:提升特征表示质量
  4. 二阶优化兼容:支持Hessian矩阵计算

算法对比分析

特性原型网络关系网络R2D2
分类方式最近类原型学习相似度闭式回归
参数更新嵌入函数嵌入+关系函数嵌入函数
距离度量欧氏距离可学习函数Mahalanobis距离
计算复杂度O(n)O(n+k)O(d^3)
训练策略端到端端到端双层优化
特征质量中等
小样本表现简单场景优复杂场景好平衡性好

应用场景分析

1. 原型网络适用场景

  • 计算资源受限的移动端
  • 类别极度不平衡场景
  • 新增类别频繁的开放世界识别
  • 跨模态检索:如图文匹配

2. 关系网络适用场景

  • 细粒度图像识别(鸟类/花卉等)
  • 工业缺陷检测的少量样本建模
  • 医学影像分析中的病灶识别
  • 艺术风格鉴定

3. R2D2适用场景

  • 大规模类别的元学习
  • 多模态融合任务
  • 实时推理需求场景
  • 增量学习系统

核心差异

                    +-----------------+
                    | 特征表示质量     |
                    +-------+---------+
                            |
                    +-------v---------+
                    | 距离度量适应性   |
                    +-------+---------+
                            |
            +---------------+---------------+
            |                               |
    +-------v-------+             +---------v---------+
    | 固定距离度量    |             | 自适应度量        |
    | (欧氏/余弦)     |             | (神经网络学习)     |
    +-------+-------+             +---------+---------+
            |                               |
    +-------v-------+             +---------v---------+
    | 原型网络       |             | 关系网络/R2D2     |
    | (简单高效)     |             | (高精度复杂)       |
    +---------------+             +-------------------+

这三种算法构成了小样本学习的核心方法论体系,各有侧重解决不同场景下的少样本学习问题。原型网络因其简洁有效成为入门基线,关系网络通过可学习的相似度度量提升性能上限,R2D2则在理论框架上实现突破,通过双层优化实现闭式解分类。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值