[原创] 强化学习框架 rlpyt:如何保存训练过程中的所有model

查看关于 rlpyt 的更多文章请点击这里

rlpyt 是BAIR(Berkeley Artificial Intelligence Research,伯克利人工智能研究所)开源的一个强化学习(RL)框架。我之前写了一篇它的简介。 
本文描述了如何保存迭代训练过程的所有model,以及背后的逻辑。

▶▶ 迭代训练过程中产生的所有model,能全部保存下来吗
当然可以。以 example_1 为例,它有如下代码:

with logger_context(log_dir, run_ID, name, config, snapshot_mode="last"):
    runner.train()

只需要把 snapshot_mode="last" 改成 snapshot_mode="all",就可以把迭代过程中的所有模型全部保存到磁盘文件了。
“last”表示只保存最后一次迭代的model文件。
文章来源:https://www.codelast.com/
▶▶ 保存model的逻辑
model是在 logger.py 里的 save_itr_params() 函数里保存到磁盘文件的:

def save_itr_params(itr, params):
    if _snapshot_dir:
        if _snapshot_mode == 'all':
            file_name = osp.join(get_snapshot_dir(), 'itr_%d.pkl' % itr)
        elif _snapshot_mode == 'last':
            # override previous params
            file_name = osp.join(get_snapshot_dir(), 'params.pkl')
        elif _snapshot_mode == "gap":
            if itr == 0 or (itr + 1) % _snapshot_gap == 0:
                file_name = osp.join(get_snapshot_dir(), 'itr_%d.pkl' % itr)
            else:
                return
        elif _snapshot_mode == 'none':
            return
        else:
            raise NotImplementedError
        torch.save(params, file_name)  # 模型参数保存到文件

其根据 _snapshot_mode 变量来控制保存逻辑:
 all:保存所有迭代的model文件。
 last:只保存最后一次迭代的model文件。
 gap:每N次迭代保存一个model文件,N可以通过logger.set_snapshot_mode()函数来设置。
 none:不保存任何model文件。
文章来源:https://www.codelast.com/
而 _snapshot_mode,正是由 logger_context() 函数的 snapshot_mode 参数最终设置进去的。
文章来源:https://www.codelast.com/
➤➤ 版权声明 ➤➤ 
转载需注明出处:codelast.com 
感谢关注我的微信公众号(微信扫一扫):

wechat qrcode of codelast

发表评论