强化学习:Prioritized Experience Replay 学习笔记

一、什么是 Prioritized Experience Replay?🎯

1.1 核心概念

  • 传统经验回放的问题:传统经验回放是随机抽样,但有些经验(如游戏中罕见的胜利时刻)更重要却很少出现。
  • 优先经验回放:智能筛选重要经验优先学习,提升学习效率。

1.2 要解决的核心问题

场景问题PER的解决方案
稀疏奖励(如游戏通关)成功经验极少被抽到主动聚焦关键经验
关键决策点(如避免车祸)平等对待所有时刻重点学习危险时刻
样本不均衡简单场景反复学习动态调整学习重点

二、为什么需要优先学习?🚀

2.1 关键指标:TD-error

TD_error = |真实奖励 - 预测奖励| = |Q现实 - Q估计|
  • 物理意义:预测的不准确程度
  • 优先级标准
    • TD-error 大 → 预测不准 → 高优先级(需重点学习)
    • TD-error 小 → 预测较准 → 低优先级

📌 类比:误差越大,说明该经验学习价值越高,学它能快速提升,所以优先级 p 就越高 。

2.2 优先级计算

  • 比例优先级P(i)=piα∑kpkαP(i) = \frac{p_i^\alpha}{\sum_k p_k^\alpha}P(i)=kpkαpiα pi=∣TD−error∣+ϵp_i = |TD-error| + \epsilonpi=TDerror+ϵα\alphaα 控制优先级的影响程度(α=0\alpha=0α=0 退化为均匀采样,α=1\alpha=1α=1 为严格按优先级)
  • 排名优先级P(i)=1rank(i)αP(i) = \frac{1}{rank(i)^\alpha}P(i)=rank(i)α1 rank(i)rank(i)rank(i) 是样本按TD误差排序后的排名

三、核心技术:SumTree 数据结构🌳

3.1 结构图解

42
29
13
13
16
3
10
12
4
11
2
  • 叶子节点:存储经验的优先级值(如3,10,12,4…)
  • 父节点 = 左子节点 + 右子节点
  • 根节点 = 所有优先级总和

3.2 高效抽样过程

  • 将优先级总和按批次大小等分区间,即用总优先级除以 batch size。
  • 从 root 开始往下找:对比当前节点的左子节点值,比左子节点大,就去右子节点,同时更新当前值(减去左子节点值);否则去左子节点,直到找到叶子节点,对应的就是要采样的样本。

eg:假设优先级总和=42,抽3个样本:

  1. 划分区间:[0-14), [14-28), [28-42]
  2. 每个区间随机选数:如选5, 20, 35
  3. 从根节点向下搜索(以35为例):
    • 35 > 左子29 → 向右走,剩余值 = 35 - 29 = 6
    • 6 < 右子13 → 向左走
    • 6 < 左子11 → 到达叶子节点(优先级11的样本)

四、完整工作流程🔧

与环境交互
存储经验
赋予初始高优先级
是否训练?
SumTree抽样
计算TD-error
更新网络参数
更新样本优先级

4.1 具体步骤

  1. 初始化

    • 主网络 Q(s,a;θ)Q(s,a;\theta)Q(s,a;θ) 和目标网络 Q−(s,a;θ−)Q^-(s,a;\theta^-)Q(s,a;θ)
    • 优先经验回放缓冲区 DDD,使用SumTree数据结构高效维护优先级;
    • 设置参数:学习率 α\alphaα,折扣因子 γ\gammaγ,优先级指数 α\alphaα,重要性采样指数 β\betaβ,目标网络更新频率 CCC
  2. 对于每个episode

    • 初始化状态 s1s_1s1
    • 对于每个时间步 ttt
      • ϵ\epsilonϵ-贪婪策略选择动作 ata_tat
      • 执行动作 ata_tat,获得奖励 rtr_trt 和下一状态 st+1s_{t+1}st+1
      • 计算TD误差 δt\delta_tδt,并将 (st,at,rt,st+1,donet,δt)(s_t,a_t,r_t,s_{t+1},\text{done}_t, \delta_t)(st,at,rt,st+1,donet,δt) 存入缓冲区 DDD
      • DDD 中按优先级采样批次数据 BBB,同时获取重要性采样权重 wiw_iwi
      • 计算TD目标 yiy_iyi** 对于每条采样经验:
        • 如果 donei==True:yi=ridone_i == True: y_i = r_idonei==True:yi=ri (终止状态无未来奖励)
        • 否则:yi=ri+γ∗maxa′Q^(si+1,a′;θ−)y_i = r_i + γ * max_{a'} Q̂(s_{i+1}, a'; θ⁻)yi=ri+γmaxaQ^(si+1,a;θ)(使用目标网络计算最大未来 Q 值)
      • 更新主网络: 通过最小化加权损失函数 L(θ)=1∣B∣∑i∈Bwi⋅(yi−Q(si,ai;θ))2\mathcal{L}(\theta) = \frac{1}{|B|}\sum_{i\in B} w_i \cdot (y_i - Q(s_i,a_i;\theta))^2L(θ)=B1iBwi(yiQ(si,ai;θ))2 更新 θ\thetaθ
      • 更新样本优先级:pi=∣δi∣+ϵp_i = |\delta_i| + \epsilonpi=δi+ϵ
      • CCC 步更新目标网络:θ−←θ\theta^- \leftarrow \thetaθθ

五、重要性采样:避免偏差的秘诀🎯

5.1 为什么需要?

  • 高优先级样本被频繁抽样 → 导致过拟合
  • 就像只复习错题会忽略基础知识

5.2 解决方案:重要性权重

wi=(1N⋅P(i))β w_i = \left( \frac{1}{N \cdot P(i)} \right)^\beta wi=(NP(i)1)β

  • NNN:记忆库总样本数
  • P(i)P(i)P(i):样本被抽中的概率
  • β\betaβ:从 0.4→1.0 逐渐增加(后期更强调偏差修正)

5.3 实际应用

# 伪代码实现
max_weight = (min_p * N)**(-beta)  # 归一化基准

for sample in batch:
    p = sample_priority / total_p
    weight = (N * p)**(-beta) / max_weight
    loss = weight * (td_error)**2

六、算法优缺点总结⚖️

6.1 核心优势

  • 学习效率高:重要样本优先学(尤其在稀疏奖励场景),像 MountainCar 里的成功样本,能加速模型收敛,少走弯路。
  • 策略质量提升:关键样本多次学习,有助于模型精准掌握复杂决策逻辑,让最终策略更优,比如游戏中更会应对关键局面。

6.2 潜在挑战

问题解决方案
过度关注早期错误定期重置优先级
计算开销稍增使用SumTree保持高效
超参数敏感α=0.6,βstart=0.4\alpha=0.6, \beta_{start}=0.4α=0.6,βstart=0.4

七、实际应用建议💡

7.1 应用场景

  1. 稀疏奖励任务(如MountainCar)
  2. 安全关键决策(自动驾驶避障)
  3. 大规模训练(减少收敛时间)

7.2 参数设置指南

# 通用最佳实践
replay_buffer_size = 1,000,000  # 记忆库大小
batch_size = 64                 # 批次大小
alpha = 0.6                     # 优先级强度
beta_start = 0.4                # 初始偏差修正
beta_end = 1.0                  # 最终偏差修正
beta_annealing_steps = 1000000  # β从0.4→1.0的步数
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值