[原创] 强化学习框架 rlpyt 源码分析:(3) 相当简洁又十分巧妙的EpsilonGreedy类

查看关于 rlpyt 的更多文章请点击这里

rlpyt 是BAIR(Berkeley Artificial Intelligence Research,伯克利人工智能研究所)开源的一个强化学习(RL)框架。我之前写了一篇它的简介。 如果你想用这个框架来开发自己的强化学习程序(尤其是那些不属于Atari游戏领域的强化学习程序),那么需要对它的源码有一定的了解。本文尝试从 rlpyt 自带的一个实例来分析它的部分源码,希望能帮助到一小部分人。

▶▶ EpsilonGreedy 类从哪来,做何用
agent 在 environment 里步进的时候,会根据policy network的计算结果,选择一个 action,再去根据这个 action 计算相应的 reward。对 example_1 来说,agent 类是 DqnAgent,其 step() 函数就是用于执行步进操作的:

@torch.no_grad()
def step(self, observation, prev_action, prev_reward):
    prev_action = self.distribution.to_onehot(prev_action)
    model_inputs = buffer_to((observation, prev_action, prev_reward),
        device=self.device)
    q = self.model(*model_inputs)
    q = q.cpu()
    action = self.distribution.sample(q)
    agent_info = AgentInfo(q=q)
    return AgentStep(action=action, agent_info=agent_info)

action = self.distribution.sample(q) 这里会用到 rlpyt/distributions/epsilon_greedy.py 里实现的 EpsilonGreedy 类,从名字上看,猜测它是 ε-greedy 算法的实现(实际上它就是)。
ε-greedy 是强化学习算法使用的一种探索策略。这里的目的是使用 ε-greedy 算法来选择 action。
文章来源:https://www.codelast.com/
▶▶ EpsilonGreedy 类详解
EpsilonGreedy 类有两个父类:DiscreteMixin 和 Distribution。其中 DiscreteMixin 实现了一些辅助功能的函数;Distribution 里基本是各种未实现的接口定义。
对 EpsilonGreedy 类本身来说,其精华在于只有短短5行代码的 sample() 函数:

def sample(self, q):
    arg_select = torch.argmax(q, dim=-1)
    mask = torch.rand(arg_select.shape) < self._epsilon
    arg_rand = torch.randint(low=0, high=q.shape[-1], size=(mask.sum(),))
    arg_select[mask] = arg_rand
    return arg_select

乍一看的感觉就是:这都是些什么乱七八糟的操作啊?完全不知道它在干嘛。
但从前文的分析,我们可以猜测出来这个函数是 ε-greedy 算法的实现,带着这个想法,我们读起代码来就有方向了。 
文章来源:https://www.codelast.com/
下面用一些实例辅助,来一行行分析代码,包你看懂!
sample() 函数的输入参数 q 是一个 tensor,因为从上面的分析知道,q 是policy network前向传播的计算结果。在这里,我们假设 q 为下面这个矩阵:

[[-0.2187, -0.2758,  0.4933,  1.0700],
[ 0.2689,  3.5079,  1.5640,  1.1730],
[-0.6858,  0.2571,  1.0396,  0.6344]]

现在再假设 sample() 函数用到的一个变量 self._epsilon = 0.3。这里要提一下,尽管这里我为了简单,用一个标量 0.3 来举例,但不代表 self._epsilon 一定要是个标量。如果仔细研读另一个类 EpsilonGreedyAgentMixin 的代码,会发现它调用了 EpsilonGreedy.set_epsilon() 函数:

self.distribution.set_epsilon(self.eps_sample)

而EpsilonGreedy.set_epsilon() 函数的定义为:

def set_epsilon(self, epsilon):
    self._epsilon = epsilon

此时 set 进去的 epsilon 有可能是一个 tensor 而不是一个 scalar。
记住这一点,我们继续用简单的scalar的情况来举例,即令 self._epsilon = 0.3。
文章来源:https://www.codelast.com/
 第1行代码:

arg_select = torch.argmax(q, dim=-1)

这句的功能是:返回指定的维度(dim,-1表示最后一个维度)上,值最大的那个数的index。
结果,arg_select 值为 [3, 1, 2],这是因为,对输入矩阵来说,第一行最大的值是 1.0700,其index为3;第二行最大的值是 3.5079,其index为0;第三行最大的值是 1.0396,其index为0,因此拼起来就是 [3, 0, 0]。
文章来源:https://www.codelast.com/
 第2行代码:

mask = torch.rand(arg_select.shape) < self._epsilon

会得到一个bool的矩阵,标识了torch.rand生成的随机数组里的每个元素是比self._epsilon大还是小。
结果,mask 值为[True, False, True],这是因为,此时 torch.rand(arg_select.shape)得到的一个随机矩阵是[0.2983, 0.4749, 0.2926] (由于是随机的,因此不是每次都是这个结果,这里仅拿某一次运行的结果作为例子来陈述),这个随机矩阵的3个数,分别和 self._epsilon 比小,得到的结果就是 [True, False, True]。
文章来源:https://www.codelast.com/
✔ 第3行代码最为复杂:

arg_rand = torch.randint(low=0, high=q.shape[-1], size=(mask.sum(),))

torch.randint()返回均匀分布的[low,high)之间的整数随机值,mask.sum()得到bool矩阵中True元素的个数(假设为x),因此得到的arg_rand是x个[low,high)之间的随机数。例如 print(torch.randint(0, 20, (6, ))) 的输出可能是:tensor([14,  4,  7, 17, 16,  3])。
mask.sum() 的值为 2,因为这等同于执行 torch.sum(mask),即计算 mask 这个 Tensor 上的所有元素的和,对元素为 bool 类型的情况,True为1,False为0,因此结果为2。
q.shape[-1] 的值为 4,因为 shape 为(3, 4),因此 shape[-1] 就是最后一个值,即 4。
因此 arg_rand 这一句执行的语句就是:torch.rand(low=0, high=4, size=(2, )),即在 [0, 4) 间随机取两个整数,结果为 [2, 3]。
文章来源:https://www.codelast.com/
 第4行代码:

arg_select[mask] = arg_rand

mask是一个bool的Tensor,把它传给另一个Tensor arg_select的时候,返回的是mask中为True的那些entry。
arg_select[mask] = arg_rand 这句在执行之前,arg_select为[3, 1, 2],mask为[True, False, True],arg_rand为[2, 3],对mask里为True的两个位置,找到arg_select里的对应位置,替换成arg_rand里的值,就是最后的结果:[2, 1, 3]。
文章来源:https://www.codelast.com/
从最后的结果 [2, 1, 3] 可以看到,它已经不能标识输入矩阵 q 的每一行的最大值的index了。
所以把上面的逻辑总结一遍,sample() 函数实现的功能就是:
找出输入矩阵某个维度上的最大值,然后按一定的机率(即epsilon)“不选取”那个值最大的index,最终得到一个具有“少量随机性”的最大值index矩阵。
这不正是 ε-greedy 算法干的事情吗?所以你明白 EpsilonGreedy 类为什么叫这个名字了吧。
文章来源:https://www.codelast.com/
这一节就到这,且听下回分解。
文章来源:https://www.codelast.com/
➤➤ 版权声明 ➤➤ 
转载需注明出处:codelast.com 
感谢关注我的微信公众号(微信扫一扫):

wechat qrcode of codelast

发表评论