[原创] 强化学习框架 rlpyt 源码分析:(7) 模型参数是在哪更新的

转载需注明出处:https://www.codelast.com/

查看关于 rlpyt 的更多文章请点击这里

rlpyt 是BAIR(Berkeley Artificial Intelligence Research,伯克利人工智能研究所)开源的一个强化学习(RL)框架。我之前写了一篇它的简介。 如果你想用这个框架来开发自己的强化学习程序(尤其是那些不属于Atari游戏领域的强化学习程序),那么需要对它的源码有一定的了解。
本文简要分析一下在rlpyt中,强化学习模型的参数是在什么地方被更新、怎么被更新的。

▶▶ 概述
模型参数是在Algorithm模块的optimize_agent()函数里被更新的,它在Runner类(例如 MinibatchRl)的train()函数里被调用。
文章来源:https://www.codelast.com/
▶▶ Runner类的调用
MinibatchRl这个Runner类为例,它的 train() 函数中有这么一句:

opt_info = self.algo.optimize_agent(itr, samples)

其中,self.algo 就是一个Algorithm类的对象,这里的optimize_agent()函数会用采样得到的一批数据(samples)更新一次模型参数。
文章来源:https://www.codelast.com/
▶▶ Algorithm类更新模型参数的实现
前文中提到了rlpyt有一个模块叫做Algorithm,它们位于项目的 rlpyt/algos/ 路径下:

├── base.py
├── dqn
│   ├── cat_dqn.py
│   ├── dqn.py
│   └── r2d1.py
├── pg
│   ├── a2c.py
│   ├── base.py
│   └── ppo.py
├── qpg
│   ├── ddpg.py
│   ├── sac.py
│   ├── sac_v.py
│   └── td3.py
└── utils.py
这些就是rlpyt里面的“算法”模块,它们实现了DQN,PPO等算法。
文章来源:https://www.codelast.com/
以DQN为例(rlpyt/algos/dqn/dqn.py),其optimize_agent()函数有这么几句:

self.optimizer.zero_grad()  # 将所有参数的梯度都置零
loss, td_abs_errors = self.loss(samples_from_replay)
loss.backward()  # 误差反向传播计算参数梯度
grad_norm = torch.nn.utils.clip_grad_norm_(self.agent.parameters(), self.clip_grad_norm)
self.optimizer.step()  # 通过梯度做一步参数更新

加上注释的几句就是主要的模型参数更新逻辑。其中,self.optimizer其实就是PyTorch的optimzer对象(例如 torch.optim.Adam),用于优化神经网络的参数。
但是乍一看,这几句optimizer的操作,貌似和模型(torch.nn.Module)的参数没有关系?
所以这就涉及到另一个问题:optimizer和model是怎么关联上的?
DQN.optim_initialize()函数中创建了 self.optimizer 对象:

self.optimizer = self.OptimCls(self.agent.parameters(),
    lr=self.learning_rate, **self.optim_kwargs)

其中,self.OptimCls 就是PyTorch的optimzer类,例如 torch.optim.Adam。其构造函数可以接受一个 params 参数:

def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8,
             weight_decay=0, amsgrad=False):

官方文档对 params参数的说明:

params (iterable): iterable of parameters to optimize or dicts defining parameter groups

在创建 self.optimizer 对象的时候,传入了一个 self.agent.parameters() 参数,这个函数的实现在 BaseAgent.parameters() 这里:

def parameters(self):
    """Parameters to be optimized (overwrite in subclass if multiple models)."""
    return self.model.parameters()

其中,self.model 就是 torch.nn.Module 类型的对象,其 parameters() 函数返回的就是模型要优化的参数。
于是 model 就这样和 optimizer 关联起来了。
文章来源:https://www.codelast.com/
这一节就到这,且听下回分解。

发表评论

电子邮件地址不会被公开。 必填项已用*标注