首页
学习
活动
专区
圈层
工具
发布
首页
学习
活动
专区
圈层
工具
MCP广场
社区首页 >专栏 >Prioritized Experience Replay (DQN)——让DQN变得更会学习

Prioritized Experience Replay (DQN)——让DQN变得更会学习

作者头像
CristianoC
发布于 2020-05-31 07:15:29
发布于 2020-05-31 07:15:29
1.8K00
代码可运行
举报
运行总次数:0
代码可运行

目录

1.前言2.算法2.1 SumTree有效抽样2.2 Memory类2.3 更新方法对比结果

1.前言

这次我们还是使用MountainCar来进行实验,因为这次我们不需要重度改变它的reward了。所以只要是没有拿到小旗子reward=-1,拿到小旗子时,我们定义它获得了+10的reward。比起之前DQN中,这个reward定义更加准确。如果使用这种reward定义方式,可以想象Natural DQN会花很长时间学习。因为记忆库中只有很少很少的+10的reward可以学习,正负样本不一样。而使用Prioritized replay,就会重视这种少量,但值得学习的样本。

接下来我们就来看看他是怎么做到的。

2.算法

这一套算法的重点就在我们batch抽样的时候并不是随机抽样的,而是按照Memory中的样本优先级来抽。所以这能更有效地找到我们需要学习的样本。

那么样本的优先级是怎么定呢?原来我们可以用到TD-error,也就是Q现实-Q估计来规定优先学习的成都。如果TD-error越大,就代表我们的预测精度还有很多上升空间,那么这个样本就越需要被学习,也就是优先级p越高。

有了TD-error就有了优先级p,那我们如何有效地根据p来抽样呢?如果每次抽样都需要针对p对所有样本排序,这将会是一件非常消耗计算能力的事情,好在我们还有其他方法,这种方法不会对得到的样本进行排序,这就是论文中说到的SumTree

SumTree是一种树形结构,每片输液存储每个样本的优先级p,每个树枝节点只有两个分叉,节点的值是两个分叉的合,所以SumTree的顶端就是所有p的合。正如下面图片,最下面一层树叶存储样本的p。叶子上一层最左边的13=3+10,按这个规律相加,顶层的roor就是全部p的合了。

抽样的时,我们会将p的总和除以batch size,分成batch size那么多区间,(n=sum(p)/batch_size)。如果将所有node的priority加起来是42的话,我们如果抽6个样本,这时的区间拥有的priority可能是这样的: [0-7], [7-14], [14-21], [21-28], [28-35], [35-42] 然后在每个区间里随机选取一个数。比如在第4个区间[21-28]选到了24,就按照这个24从最顶上的42开始往下搜索。首先看到最顶上42下面有两个child nodes,拿着手中的24对比左边的chlid29,如果左边的chlid比自己手中的值大,那我们就走左边这条路,接着再对比29下面的左边那个点13,这时,手中的24比13大,那我们就走右边的路,并且将手中的值根据13修改一下,变成24-13=11.接着拿11和13右下角的12比,结果12比11大,那我们就选12当做这次选到的priority,并且也选择12对应的数据。

2.1 SumTree有效抽样

代码语言:javascript
代码运行次数:0
运行
AI代码解释
复制
 1class SumTree(object):
 2    # 建立 tree 和 data,
 3    # 因为 SumTree 有特殊的数据结构,
 4    # 所以两者都能用一个一维 np.array 来存储
 5    def __init__(self, capacity):
 6
 7    # 当有新 sample 时, 添加进 tree 和 data
 8    def add(self, p, data):
 9
10    # 当 sample 被 train, 有了新的 TD-error, 就在 tree 中更新
11    def update(self, tree_idx, p):
12
13    # 根据选取的 v 点抽取样本
14    def get_leaf(self, v):
15
16    # 获取 sum(priorities)
17    @property
18    def totoal_p(self):

2.2 Memory类

代码语言:javascript
代码运行次数:0
运行
AI代码解释
复制
 1class Memory(object):
 2    # 建立 SumTree 和各种参数
 3    def __init__(self, capacity):
 4
 5    # 存储数据, 更新 SumTree
 6    def store(self, transition):
 7
 8    # 抽取 sample
 9    def sample(self, n):
10
11    # train 完被抽取的 samples 后更新在 tree 中的 sample 的 priority
12    def batch_update(self, tree_idx, abs_errors):

具体完整的代码我在最后会附上我github的链接,这里说一下这个关于ISweight到底怎么算。需要提到一点是,代码中的计算方法是经过了简化的,将论文中的步骤合并了一些,比如:prob = p / self.tree.total_p; ISWeights = np.power(prob/min_prob, -self.beta)

在paper 中, ISWeight = (N*Pj)^(-beta) / maxi_wi里面的maxi_wi是为了 normalize ISWeight, 所以我们先把他放在一边. 所以单纯的importance sampling 就是(N*Pj)^(-beta),那 maxi_wi = maxi[(N*Pi)^(-beta)].

如果将这两个式子合并,

ISWeight = (N*Pj)^(-beta) / maxi[ (N*Pi)^(-beta) ]

而且如果将maxi[ (N*Pi)^(-beta)]中的 (-beta) 提出来, 这就变成了mini[ (N*Pi) ] ^ (-beta)

看出来了吧, 有的东西可以抵消掉的. 最后

ISWeight = (Pj / mini[Pi])^(-beta)

这样我们就有了代码中的样子.

还有代码中的alpha是一个决定我们要使用多少 ISweight 的影响, 如果alpha = 0,我们就没使用到任何 Importance Sampling.

2.3 更新方法

我们在_init_中加一个prioritized参数来表示DQN是否具备prioritized能力。为了对比的需要,我们的tf.Session()也单独传入,并移除原本在DQN代码中的这一句:self.sess.run(tf.global_variables_initializer())

搭建神经网络时,我们发现DQN with Prioritized replay只多了一个ISWeights,这个正是刚刚算法中提到的Importance-Sampling Weights,用来恢复被Prioritized replay打乱的抽样概率分布。

代码语言:javascript
代码运行次数:0
运行
AI代码解释
复制
 1class DQNPrioritizedReplay:
 2    def _build_net(self)
 3        ...
 4        # self.prioritized 时 eval net 的 input 多加了一个 ISWeights
 5        self.s = tf.placeholder(tf.float32, [None, self.n_features], name='s')  # input
 6        self.q_target = tf.placeholder(tf.float32, [None, self.n_actions], name='Q_target')  # for calculating loss
 7        if self.prioritized:
 8            self.ISWeights = tf.placeholder(tf.float32, [None, 1], name='IS_weights')
 9
10        ...
11        # 为了得到 abs 的 TD error 并用于修改这些 sample 的 priority, 我们修改如下
12        with tf.variable_scope('loss'):
13            if self.prioritized:
14                self.abs_errors = tf.reduce_sum(tf.abs(self.q_target - self.q_eval), axis=1)    # for updating Sumtree
15                self.loss = tf.reduce_mean(self.ISWeights * tf.squared_difference(self.q_target, self.q_eval))
16            else:
17                self.loss = tf.reduce_mean(tf.squared_difference(self.q_target, self.q_eval))

因为和Natural DQN使用的Memory不一样,所以在存储transition的时候方式也略不相同

代码语言:javascript
代码运行次数:0
运行
AI代码解释
复制
 1class DQNPrioritizedReplay:
 2    def store_transition(self, s, a, r, s_):
 3        if self.prioritized:    # prioritized replay
 4            transition = np.hstack((s, [a, r], s_))
 5            self.memory.store(transition)
 6        else:       # random replay
 7            if not hasattr(self, 'memory_counter'):
 8                self.memory_counter = 0
 9            transition = np.hstack((s, [a, r], s_))
10            index = self.memory_counter % self.memory_size
11            self.memory[index, :] = transition
12            self.memory_counter += 1

我们在learn()部分的改变也在如下展示:

代码语言:javascript
代码运行次数:0
运行
AI代码解释
复制
 1class DQNPrioritizedReplay:
 2    def learn(self):
 3        ...
 4        # 相对于 DQN 代码, 改变的部分
 5        if self.prioritized:
 6            tree_idx, batch_memory, ISWeights = self.memory.sample(self.batch_size)
 7        else:
 8            sample_index = np.random.choice(self.memory_size, size=self.batch_size)
 9            batch_memory = self.memory[sample_index, :]
10
11        ...
12
13        if self.prioritized:
14            _, abs_errors, self.cost = self.sess.run([self._train_op, self.abs_errors, self.loss],
15                                         feed_dict={self.s: batch_memory[:, :self.n_features],
16                                                    self.q_target: q_target,
17                                                    self.ISWeights: ISWeights})
18            self.memory.batch_update(tree_idx, abs_errors)   # update priority
19        else:
20            _, self.cost = self.sess.run([self._train_op, self.loss],
21                                         feed_dict={self.s: batch_memory[:, :self.n_features],
22                                                    self.q_target: q_target})
23
24        ...

对比结果

运行我Github中的这个MountainCar脚本,我们就不难发现,我们都从两种方法最初拿到第一个R+=10奖励的时候算起,看看经历过一次R+=10后,他们有没有好好利用这次的奖励,可以看出,有 Prioritized replay的可以高效地利用这些不常拿到的奖励,并好好学习他们。所以Prioritized replay 会更快结束每个 episode, 很快就到达了小旗子。

本文参与 腾讯云自媒体同步曝光计划,分享自微信公众号。
原始发表:2019-07-24,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 计算机视觉漫谈 微信公众号,前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体同步曝光计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
暂无评论
推荐阅读
编辑精选文章
换一批
DQN三大改进(二)-Prioritised replay
Prioritised replay原文:https://arxiv.org/pdf/1511.05952.pdf 代码地址:https://github.com/princewen/tensorflow_practice/tree/master/Prioritized_Replay_DQN_demo 如果大家觉得代码排版较乱,可以参考原文:https://www.jianshu.com/p/db14fdc67d2c 1、背景 这篇文章我们会默认大家已经了解了DQN的相关知识,如果大家对于DQN还不是很了解
石晓文
2018/04/11
2.9K0
DQN三大改进(二)-Prioritised replay
Double DQN——解决DQN中的过估计问题
本篇教程是基于Deep Q network(DQN)的教程,缩减了在DQN方面的介绍,着重强调Double DQN和DQN的不同之处。
CristianoC
2020/06/02
2K0
实战深度强化学习DQN-理论和实践
1、Q-learning回顾 Q-learning 的 算法过程如下图所示: 在Q-learning中,我们维护一张Q值表,表的维数为:状态数S * 动作数A,表中每个数代表在当前状态S下可以采用动作
石晓文
2018/04/11
3K0
实战深度强化学习DQN-理论和实践
DQN系列(2): Double DQN算法原理与实现
论文地址: https://arxiv.org/pdf/1509.06461.pdf
深度强化学习实验室
2020/02/11
2.2K0
DQN系列(2): Double DQN算法原理与实现
强化学习(十一) Prioritized Replay DQN
    在强化学习(十)Double DQN (DDQN)中,我们讲到了DDQN使用两个Q网络,用当前Q网络计算最大Q值对应的动作,用目标Q网络计算这个最大动作对应的目标Q值,进而消除贪婪法带来的偏差。今天我们在DDQN的基础上,对经验回放部分的逻辑做优化。对应的算法是Prioritized Replay DQN。
刘建平Pinard
2018/10/22
1.1K0
强化学习(十一) Prioritized Replay DQN
强化学习算法总结(一)——从零到DQN变体
中对应价值最大的动作的Q值进行更新,注意这里只是更新,并不会真的执行这个价值最大的动作。这里的更新策略(评估策略)与我们的行为策略(
CristianoC
2021/04/16
2.8K0
强化学习算法总结(一)——从零到DQN变体
Rainbow:整合DQN六种改进的深度强化学习方法!
在2013年DQN首次被提出后,学者们对其进行了多方面的改进,其中最主要的有六个,分别是: Double-DQN:将动作选择和价值估计分开,避免价值过高估计 Dueling-DQN:将Q值分解为状态价值和优势函数,得到更多有用信息 Prioritized Replay Buffer:将经验池中的经验按照优先级进行采样 Multi-Step Learning:使得目标价值估计更为准确 Distributional DQN(Categorical DQN):得到价值分布 NoisyNet:增强模型的探索能力
石晓文
2019/01/02
3.6K0
强化学习-DQN
之前两篇文章介绍的内容其实都属于策略网络,即用神经网络去模拟在给定状态s下,每个动作a的执行概率。这篇用到的DQN则属于值函数网络,在这一大类里又可以分为:状态值函数和状态-动作值函数,DQN属于后者,即用神经网络去模拟在给定状态s和动作a的情况下,回报的期望。
luxuantao
2021/02/24
8930
强化学习系列案例 | 训练智能体玩Flappy Bird游戏
Flappy Bird是一款简单操作的手机游戏,在游戏中有一只飞翔的小鸟,在飞行中会遇到管道障碍物,玩家需要操控小鸟往上飞,飞行过程中不能坠地也不能触碰障碍物,不断的实行动作会飞的越来越高;如果不采取飞行动作,则会快速下降。因此玩家要使用合适的策略控制小鸟飞行,使小鸟不会坠地同时能顺利地穿越障碍物。本案例使用强化学习算法DQN训练智能体,使其最终学会玩Flappy Bird游戏。
数据酷客
2020/04/24
3.1K1
强化学习系列案例 | 训练智能体玩Flappy Bird游戏
【强化学习】DQN 在运筹学中的应用
前段时间给出了 Q-Learning 在排班调度中的应用,现在给出 DQN 的实现。
阿泽 Crz
2020/11/09
1K0
【深度强化学习】DQN训练超级玛丽闯关
上一期 MyEncyclopedia公众号文章 通过代码学Sutton强化学习:从Q-Learning 演化到 DQN,我们从原理上讲解了DQN算法,这一期,让我们通过代码来实现DQN 在任天堂经典的超级玛丽游戏中的自动通关吧。本系列将延续通过代码学Sutton 强化学习系列,逐步通过代码实现经典深度强化学习应用在各种游戏环境中。本文所有代码在
黄博的机器学习圈子
2020/12/11
1.4K0
【深度强化学习】DQN训练超级玛丽闯关
【强化学习】DQN 的各种改进
DQN 发表于 NIPS 2013,在此之后 DeepMind 不断对 DQN 进行改进,首先在 2015 年初发布了 Nature 文章,提出了 Nature 版本的 DQN,然后接下来在 2015 年一年内提出了 Double DQN,Prioritied Replay,还有 Dueling Network 三种主要方法,又极大的提升了 DQN 的性能,目前的改进型 DQN 算法在 Atari 游戏的平均得分是 Nature 版 DQN 的三倍之多。因此,在本文中,我们将介绍一下各个改进的方法,并在最后给出用 Nature-DQN 的实现方法。
阿泽 Crz
2020/11/17
3.5K0
【强化学习】DQN 的各种改进
深度强化学习 | DQN训练超级玛丽闯关
本系列将延续通过代码学Sutton 强化学习系列,逐步通过代码实现经典深度强化学习应用在各种游戏环境中。本文所有代码在
NewBeeNLP
2021/03/03
1.5K0
深度强化学习 | DQN训练超级玛丽闯关
【强化学习】DQN 的三种改进在运筹学中的应用
这篇文章主要介绍 DQN 的三种改进:Nature DQN、Double DQN、Dueling DQN 在运筹学中的应用,并给出三者的对比,同时也会给出不同增量学习下的效果。
阿泽 Crz
2020/12/11
1.5K0
【强化学习】DQN 的三种改进在运筹学中的应用
深度强化学习-DDPG算法原理和实现
基于值的强化学习算法的基本思想是根据当前的状态,计算采取每个动作的价值,然后根据价值贪心的选择动作。如果我们省略中间的步骤,即直接根据当前的状态来选择动作。基于这种思想我们就引出了强化学习中另一类很重要的算法,即策略梯度(Policy Gradient)。之前我们已经介绍过策略梯度的基本思想和实现了,大家可以有选择的进行预习和复习:
用户1332428
2023/03/28
1K0
深度强化学习-DDPG算法原理和实现
【RL Base】强化学习核心算法:深度Q网络(DQN)算法
深度Q网络(DQN)是深度强化学习的核心算法之一,由Google DeepMind在2015年的论文《Playing Atari with Deep Reinforcement Learning》中提出。DQN通过结合深度学习和强化学习,利用神经网络近似Q值函数,在高维、连续状态空间的环境中表现出了强大的能力。
不去幼儿园
2024/12/03
5310
【RL Base】强化学习核心算法:深度Q网络(DQN)算法
PaddlePaddle版Flappy-Bird—使用DQN算法实现游戏智能
刚刚举行的 WAVE SUMMIT 2019 深度学习开发者峰会上,PaddlePaddle 发布了 PARL 1.1 版本,这一版新增了 IMPALA、A3C、A2C 等一系列并行算法。作者重新测试了一遍内置 example,发现卷积速度也明显加快,从 1.0 版本的训练一帧需大约 1 秒优化到了 0.15 秒(配置:win8,i5-6200U,GeForce-940M,batch-size=32)。
用户1386409
2019/06/06
7580
PaddlePaddle版Flappy-Bird—使用DQN算法实现游戏智能
深度强化学习经验回放(Experience Replay Buffer)的三点高性能修改建议:随机采样、减少保存的数据量、简化计算等
我们使用 Numpy 库在内存里、使用 PyTorch 库在显存里 创建了一整块连续的空间,对比了 List 和 Tuple 的方案。结果:连续存储空间的明显更节省时间。因此,DRL 库的 ReplayBuffer 有必要围绕着 连续内存空间 来设计。
汀丶人工智能
2023/10/11
1.6K0
深度强化学习经验回放(Experience Replay Buffer)的三点高性能修改建议:随机采样、减少保存的数据量、简化计算等
强化学习从基础到进阶-案例与实践4.2:深度Q网络DQN-Cart pole游戏展示
比如本项目的Cart pole小游戏中,agent就是动图中的杆子,杆子有向左向右两种action。
汀丶人工智能
2023/06/30
4120
深度强化学习-Policy Gradient基本实现
在之前的几篇文章中,我们介绍了基于价值Value的强化学习算法Deep Q Network。有关DQN算法以及各种改进算法的原理和实现,可以参考之前的文章: 实战深度强化学习DQN-理论和实践: DQN三大改进(一)-Double DQN DQN三大改进(二)-Prioritised replay DQN三大改进(三)-Dueling Network 基于值的强化学习算法的基本思想是根据当前的状态,计算采取每个动作的价值,然后根据价值贪心的选择动作。如果我们省略中间的步骤,即直接根据当前的状态来选择动作,也
石晓文
2018/04/11
1.9K0
深度强化学习-Policy Gradient基本实现
推荐阅读
相关推荐
DQN三大改进(二)-Prioritised replay
更多 >
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档