查看关于 rlpyt 的更多文章请点击这里。
rlpyt 是BAIR(Berkeley Artificial Intelligence Research,伯克利人工智能研究所)开源的一个强化学习(RL)框架。我之前写了一篇它的简介。 如果你想用这个框架来开发自己的强化学习程序(尤其是那些不属于Atari游戏领域的强化学习程序),那么需要对它的源码有一定的了解。
本文简要分析一下在rlpyt中,强化学习模型的参数是在什么地方被更新、怎么被更新的。
▶▶ 概述
模型参数是在Algorithm模块的optimize_agent()函数里被更新的,它在Runner类(例如 MinibatchRl)的train()函数里被调用。
文章来源:https://www.codelast.com/
▶▶ Runner类的调用
以MinibatchRl这个Runner类为例,它的 train() 函数中有这么一句:
opt_info = self.algo.optimize_agent(itr, samples)
其中,self.algo 就是一个Algorithm类的对象,这里的optimize_agent()函数会用采样得到的一批数据(samples)更新一次模型参数。
文章来源:https://www.codelast.com/
▶▶ Algorithm类更新模型参数的实现
在前文中提到了rlpyt有一个模块叫做Algorithm,它们位于项目的 rlpyt/algos/ 路径下:
├── base.py├── dqn│ ├── cat_dqn.py│ ├── dqn.py│ └── r2d1.py├── pg│ ├── a2c.py│ ├── base.py│ └── ppo.py├── qpg│ ├── ddpg.py│ ├── sac.py│ ├── sac_v.py│ └── td3.py└── utils.py
文章来源:https://www.codelast.com/
以DQN为例(rlpyt/algos/dqn/dqn.py),其optimize_agent()函数有这么几句:
self.optimizer.zero_grad() # 将所有参数的梯度都置零 loss, td_abs_errors = self.loss(samples_from_replay) loss.backward() # 误差反向传播计算参数梯度 grad_norm = torch.nn.utils.clip_grad_norm_(self.agent.parameters(), self.clip_grad_norm) self.optimizer.step() # 通过梯度做一步参数更新
加上注释的几句就是主要的模型参数更新逻辑。其中,self.optimizer其实就是PyTorch的optimzer对象(例如 torch.optim.Adam),用于优化神经网络的参数。
但是乍一看,这几句optimizer的操作,貌似和模型(torch.nn.Module)的参数没有关系?
所以这就涉及到另一个问题:optimizer和model是怎么关联上的?
在DQN.optim_initialize()函数中创建了 self.optimizer 对象:
self.optimizer = self.OptimCls(self.agent.parameters(), lr=self.learning_rate, **self.optim_kwargs)
其中,self.OptimCls 就是PyTorch的optimzer类,例如 torch.optim.Adam。其构造函数可以接受一个 params 参数:
def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0, amsgrad=False):
官方文档对 params参数的说明:
params (iterable): iterable of parameters to optimize or dicts defining parameter groups
在创建 self.optimizer 对象的时候,传入了一个 self.agent.parameters() 参数,这个函数的实现在 BaseAgent.parameters() 这里:
def parameters(self): """Parameters to be optimized (overwrite in subclass if multiple models).""" return self.model.parameters()
其中,self.model 就是 torch.nn.Module 类型的对象,其 parameters() 函数返回的就是模型要优化的参数。
于是 model 就这样和 optimizer 关联起来了。
文章来源:https://www.codelast.com/
这一节就到这,且听下回分解。
文章来源:https://www.codelast.com/
➤➤ 版权声明 ➤➤
转载需注明出处:codelast.com
感谢关注我的微信公众号(微信扫一扫):