写点什么

强化学习—DQN:不讲前世,就论今生

用户头像
打工人!
关注
发布于: 2021 年 04 月 04 日
强化学习—DQN:不讲前世,就论今生

前言

相信小可爱们点进这篇文章,要么是对强化学习(Reinforcement learning-RL)有一定的了解,要么是想要了解强化学习的魅力所在,要么是了解了很多基础知识,但是不知道代码如何写。今天我就以最经典和基础的算法(DQN)带大家一探强化学习的强大基因,不讲前世(不讲解公式推导),只讨论今生(通俗语言+图示讲解 DQN 算法流程,以及代码如何实现)。


预备知识

如果小可爱们只想了解 DQN 算法的流程,那么跟着我的步伐,一点一点看下去就可以。如果你想要使用代码实现算法并亲眼看到它的神奇之处,也可以在本文中找到答案。

本文代码实现基于 tensorflow1.8,想看懂代码的,建议熟悉 tf 的基础知识


下面进入正题:首先想想为什么叫它强化学习呢?强化,强化,再强化??? 哦!想起来了,好像和我们日常学习中的强化练习、《英语强化习题集》有点相似,考过研的小可爱们都知道张宇老师还有强化课程和相关习题册。哈哈,是不是想起来在高三或考研期间被各种强化学习资料支配的恐惧呢?不用担心,在本文中没有强化资料。我们再思考一下,小可爱们当初为了考取理想的分数,不断做题强化,是不是就是为了提高做题速度、逻辑思考能力和正确率,从而提高成绩?通过不断做题强化,我们学到了很多知识,提高了分数,那这些指标是不是我们的收益?在进行强化训练整个过程中,我们是不是一次次地和试题进行较量,拼个你死我活,接着试题答案要么给我们当头一击吃个 0 蛋、要么赏给我们颗糖吃吃,顿时心里美滋滋。试题答案给了我们反馈,同时我们在接收反馈之后,反思自己,找到自己的不足并纠正,以使在以后面对这些题时,可以主动避开错误的解法,尽可能拿更高的分数。慢慢地,我们找到了做题套路,解题能力得到了强化提高。

强化学习基本思想理解

强化学习的基本思想也是一样的道理。下面是强化学习框架:

强化学习框架图

现在把图中的 Agent(智能体)当做小可爱自己,Environment(环境)当做大量的试题集(包含大量的不同的试题),State s(状态)当做当前时刻小可爱正在做的某一试题,Action(动作)当做小可爱解题过程和步骤,Reward(奖励)当做试题答案给小可爱的作答打分,State s'当做在试题集中小可爱当前所做题的下一个试题【这道题小可爱做完了,也看到得了多少分,就会去做下一道试题,是不是就意味着状态从当前状态转移到了下一个状态】。小可爱们按照这种思路再次读一读上面一段话,对照着强化学习框架,是不是就明白了智能体与环境交互的过程,也明白了强化学习的基本思想了呢?

强化学习中的探索和利用理解

如果小可爱们在读了上面的讲解之后,已经对强化学习基本思想有了进一步的理解,那么恭喜小可爱顺利通过第一关。如果还没有更加明白,不要慌!跟着我,在学习这一部分探索和利用问题的同时,你也可以进一步理解和巩固强化学习基本思想。

在强化学习算法中,当智能体与环境进行交互时有一个探索的概念和一个利用的概念。顾名思义,探索的意思就是智能体以前在某些状态没有做过这种动作【状态和动作是一一对应的,在某一状态下,做出某一动作】,它要尝试解锁一些新动作,说不定这个被解锁的动作就可以帮助智能体得到更高的分。利用的意思就是智能体从以前做过的所有动作中,选择一个可以帮助他获得更高分的动作。也就是利用以前的经历去选择出最优动作。

下面以小可爱做一道数学题为例来说明这个问题,顺便加深对强化学习基本思想的理解。这是一道非常难的一道数学题,满分 100 分,有上百个步骤,同时也有很多种解法,这对于第一次接触这种题的小可爱来说,太难了。但是这并难不倒聪明的小可爱们,花了一天时间,大概写了 100 个步骤,终于把这道题做出来了。在对比了答案之后,事情不尽人意,只做对了几步,得了一点点步骤分。一看只得了这么一点分,小可爱们归纳整理做这道题的经验并反思。又开始重新做这道题,并试图在某些步骤尝试不同的解法【这是不是探索呢?】,最终花费的时间比第一次少了许多,并且得分也比以前高了。这让小可爱们甚是兴奋,更加自信了。小可爱们试图又重新做这道题,基于前几次的做题经验【这里是不是包含了利用呢?】以及探索使用不同的方法,发现有时候在某些步骤上用比较复杂的方法,但是却为后面的解题大大减少了计算量,从而减小了错误概率。最后,尽管用时稍微长些,但是得分更高了。就这样一次又一次地重新做这道题,小可爱们慢慢地摸索到了这道题的最佳做题套路,在用时和得分之间寻找到了平衡,使得最终所摸索到的关于这道题的做题技巧性价比最高。

看完上面一段话,小可爱们是不是对于强化学习中的探索和利用以及基本思想有了更加深刻地理解了呢?恭喜你,小可爱,又完美通过一关。

图解 DQN 算法

这个部分,将使用图解和文字结合的方法对 DQN 算法进行剖析。先上个图:

DQN算法结构图

如果看不懂上面的图,没关系。下面该部分将再次用图解的方式对 DQN 算法流程进行更细致的剖析,上图整活:

DQN算法深入剖析图


高能预警!

注意上图中的 eval 网络和 target 网络结构是完全一样的,但是权重参数不是实时同步的,target 网络权重参数更新是落后于 eval 网络的权重参数更新的。

从以上两个图中我们可以知道,DQN 算法是分为两个部分的。

  • 第一部分—经验收集阶段:智能体想要从幼儿园进化到大学水平,学习成绩更好那就必须要有经验啊。那这一部分就是智能体与环境进行交互收集经验的阶段。注意:这一部分只会用到 eval 网络,智能体从环境中观察到状态 s 输入到 eval 网络,网络会输出每一个动作对应的 Q 值。智能体在选择动作时一般使用ε贪心策略来选择动作—【在整个训练初始阶段,大概率随机选择动作(探索),随着训练轮数的增加,小概率随机选择动作,大概率选择最大 Q 值对应的动作(利用)】。在选择动作后,会作用到环境,环境会相应的反馈给智能体一个奖励 r,同时环境从当前状态 s 转移到 s'每进行一次交互,都会把(s,a,r,s')保存到经验池中。

  • 第二部分—训练学习阶段:在交互次数达到预先设定的值后,就进入到了训练学习阶段。这个阶段是智能体真正利用以前的经验来进行反思提高的(训练)。智能体会从经验池中随机抽取 batch-size 组的数据,把这组数据中的 s 批量输入到 eval 网络,s'批量输入到 target 网络中,得到该组数据中每一条数据对应的 q_eval 和 q_target。然后计算该组数据的损失,然后进行反向传播进行训练。训练结束后,再次进入到第一部分。

以上两个部分交替进行,直到达到预先设定的训练轮数。相信各位小可爱在看了本部分的介绍,明白了 DQN 算法的具体流程,恭喜你,又一次成功通关。

代码实现—基于 tensorflow1 版本


import numpy as npimport tensorflow as tf
np.random.seed(1)tf.set_random_seed(1)

# Deep Q Network off-policyclass DeepQNetwork: def __init__( self, n_actions, n_features, learning_rate=0.01, reward_decay=0.9, e_greedy=0.9, replace_target_iter=300, memory_size=500, batch_size=32, e_greedy_increment=None, output_graph=False, ): self.n_actions = n_actions self.n_features = n_features self.lr = learning_rate self.gamma = reward_decay self.epsilon_max = e_greedy self.replace_target_iter = replace_target_iter self.memory_size = memory_size self.batch_size = batch_size self.epsilon_increment = e_greedy_increment self.epsilon = 0 if e_greedy_increment is not None else self.epsilon_max
# total learning step self.learn_step_counter = 0
# initialize zero memory [s, a, r, s_] self.memory = np.zeros((self.memory_size, n_features * 2 + 2))
# consist of [target_net, evaluate_net] self._build_net()
t_params = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='target_net') e_params = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='eval_net')
with tf.variable_scope('hard_replacement'): self.target_replace_op = [tf.assign(t, e) for t, e in zip(t_params, e_params)]
self.sess = tf.Session()
if output_graph: # $ tensorboard --logdir=logs tf.summary.FileWriter("logs/", self.sess.graph)
self.sess.run(tf.global_variables_initializer()) self.cost_his = []
def _build_net(self): # ------------------ all inputs ------------------------ self.s = tf.placeholder(tf.float32, [None, self.n_features], name='s') # input State self.s_ = tf.placeholder(tf.float32, [None, self.n_features], name='s_') # input Next State self.r = tf.placeholder(tf.float32, [None, ], name='r') # input Reward self.a = tf.placeholder(tf.int32, [None, ], name='a') # input Action
w_initializer, b_initializer = tf.random_normal_initializer(0., 0.3), tf.constant_initializer(0.1)
# ------------------ build evaluate_net ------------------ with tf.variable_scope('eval_net'): e1 = tf.layers.dense(self.s, 20, tf.nn.relu, kernel_initializer=w_initializer, bias_initializer=b_initializer, name='e1') self.q_eval = tf.layers.dense(e1, self.n_actions, kernel_initializer=w_initializer, bias_initializer=b_initializer, name='q')
# ------------------ build target_net ------------------ with tf.variable_scope('target_net'): t1 = tf.layers.dense(self.s_, 20, tf.nn.relu, kernel_initializer=w_initializer, bias_initializer=b_initializer, name='t1') self.q_next = tf.layers.dense(t1, self.n_actions, kernel_initializer=w_initializer, bias_initializer=b_initializer, name='t2')
with tf.variable_scope('q_target'): q_target = self.r + self.gamma * tf.reduce_max(self.q_next, axis=1, name='Qmax_s_') # shape=(None, ) self.q_target = tf.stop_gradient(q_target) with tf.variable_scope('q_eval'): a_indices = tf.stack([tf.range(tf.shape(self.a)[0], dtype=tf.int32), self.a], axis=1) self.q_eval_wrt_a = tf.gather_nd(params=self.q_eval, indices=a_indices) # shape=(None, ) with tf.variable_scope('loss'): self.loss = tf.reduce_mean(tf.squared_difference(self.q_target, self.q_eval_wrt_a, name='TD_error')) with tf.variable_scope('train'): self._train_op = tf.train.RMSPropOptimizer(self.lr).minimize(self.loss)
def store_transition(self, s, a, r, s_): if not hasattr(self, 'memory_counter'): self.memory_counter = 0 transition = np.hstack((s, [a, r], s_)) # replace the old memory with new memory index = self.memory_counter % self.memory_size self.memory[index, :] = transition self.memory_counter += 1
def choose_action(self, observation): # to have batch dimension when feed into tf placeholder observation = observation[np.newaxis, :]
if np.random.uniform() < self.epsilon: # forward feed the observation and get q value for every actions actions_value = self.sess.run(self.q_eval, feed_dict={self.s: observation}) action = np.argmax(actions_value) else: action = np.random.randint(0, self.n_actions) return action
def learn(self): # check to replace target parameters if self.learn_step_counter % self.replace_target_iter == 0: self.sess.run(self.target_replace_op) print('\ntarget_params_replaced\n')
# sample batch memory from all memory if self.memory_counter > self.memory_size: sample_index = np.random.choice(self.memory_size, size=self.batch_size) else: sample_index = np.random.choice(self.memory_counter, size=self.batch_size) batch_memory = self.memory[sample_index, :]
_, cost = self.sess.run( [self._train_op, self.loss], feed_dict={ self.s: batch_memory[:, :self.n_features], self.a: batch_memory[:, self.n_features], self.r: batch_memory[:, self.n_features + 1], self.s_: batch_memory[:, -self.n_features:], })
self.cost_his.append(cost)
# increasing epsilon self.epsilon = self.epsilon + self.epsilon_increment if self.epsilon < self.epsilon_max else self.epsilon_max self.learn_step_counter += 1
def plot_cost(self): import matplotlib.pyplot as plt plt.plot(np.arange(len(self.cost_his)), self.cost_his) plt.ylabel('Cost') plt.xlabel('training steps') plt.show()
复制代码


发布于: 2021 年 04 月 04 日阅读数: 227
用户头像

打工人!

关注

打工人! 2019.11.10 加入

InfoQ年度最佳内容获得者。 InfoQ签约作者 本人打工人一枚,自动化和控制专业入坑人一枚。目前在研究深度强化学习(DRL)技术。准备入坑互联网小白一枚。喜欢了解科技前沿技术,喜欢拍照。

评论 (9 条评论)

发布
用户头像
不错不错,作者大大继续努力更新呀
2021 年 04 月 15 日 08:46
回复
用户头像
强啊老李
2021 年 04 月 13 日 21:55
回复
用户头像
不错,学习了!
2021 年 04 月 13 日 19:23
回复
用户头像
受益匪浅,感谢!
2021 年 04 月 13 日 19:18
回复
用户头像
写的不错,受益匪浅!
2021 年 04 月 13 日 19:16
回复
用户头像
看完之后醍醐灌顶
2021 年 04 月 12 日 15:15
回复
用户头像
生动形象,感谢分享!!!!!!
2021 年 04 月 12 日 15:14
回复
用户头像
不错不错
2021 年 04 月 12 日 10:12
回复
谢谢!
2021 年 04 月 12 日 10:18
回复
没有更多了
强化学习—DQN:不讲前世,就论今生