[原创] 强化学习框架 rlpyt:如何使用预训练(pre-trained)的model

转载需注明出处:https://www.codelast.com/

查看关于 rlpyt 的更多文章请点击这里

rlpyt 是BAIR(Berkeley Artificial Intelligence Research,伯克利人工智能研究所)开源的一个强化学习(RL)框架。我之前写了一篇它的简介。 
本文描述了在 rlpyt 框架下,如何使用一个预训练过的(pre-trained)model作为起点,来训练自己的RL模型的过程。

▶▶ 什么是预训练模型
引用一篇文章

简单来说,预训练模型(pre-trained model)是前人为了解决类似问题所创造出来的模型。你在解决问题的时候,不用从零开始训练一个新模型,可以从在类似问题中训练过的模型入手。
比如说,如果你想做一辆自动驾驶汽车,可以花数年时间从零开始构建一个性能优良的图像识别算法,也可以从Google在ImageNet数据集上训练得到的inception model(一个预训练模型)起步,来识别图像。
一个预训练模型可能对于你的应用中并不是100%的准确对口,但是它可以为你节省大量功夫。
训练一个强化学习模型也可能会需要消耗大量计算资源,尤其是你手上没有强大算力的时候,靠一台普通电脑去train一个model可能会用掉很长时间,因此,在别人已经train好的model的基础上继续train自己的model是一个好办法。
文章来源:https://www.codelast.com/
▶▶ rlpyt 对预训练模型的支持
以使用 DQN 算法的 example_1 为例,class DQN(RlAlgorithm) 的 __init__() 函数有一个 initial_optim_state_dict 参数:

initial_optim_state_dict=None,

另外,AtariDqnAgent 类的其中一个父类:DqnAgent,它又有一个父类 BaseAgent,在 __init__() 初始化的时候也有一个 initial_model_state_dict 参数:

def __init__(self, ModelCls=None, model_kwargs=None, initial_model_state_dict=None):

这两个地方,就是当你使用预训练模型的时候需要传入的参数。
为什么会有两个参数?它们有什么区别?
前一个是Optimizer(优化器,例如 torch.optim.Adam)的 state_dict,其包含的参数有 learning rate 等。
 后一个是model的 state_dict,其包含的参数有 model 的 weight、bias 等。
直观点,来个图(图片可放大):
pre-trained model
从图中可以清楚地看到model里存储的数据,optimizer_state_dict 就是 Optimizer 的 state_dict,agent_state_dict 就是model的 state_dict。
文章来源:https://www.codelast.com/
▶▶ 代码实操:加载预训练模型
首先我们要有一个预训练模型文件,因此,我们先把没有修改过代码的 example_1 运行一段时间,生成一个 params.pkl 模型文件,假设此文件路径为:/home/codelast/rlpyt/data/local/20191111/example_1/run_0/params.pkl
现在修改 example_1.py,可以加载预训练模型了:

# 加载预训练模型
model_loaded = torch.load('/home/codelast/rlpyt/data/local/20191111/example_1/run_0/params.pkl')
optimizer_state_dict = model_loaded['optimizer_state_dict']
agent_state_dict = model_loaded['agent_state_dict']

algo = DQN(min_steps_learn=1e3, initial_optim_state_dict=optimizer_state_dict)
agent = AtariDqnAgent(initial_model_state_dict=agent_state_dict['model'])

其他代码无需修改,就这么简单!
再重新运行修改过的example,现在就已经是在pre-trained model的基础上继续进行的训练了。

发表评论

电子邮件地址不会被公开。 必填项已用*标注