查看关于 rlpyt 的更多文章请点击这里。
rlpyt 是BAIR(Berkeley Artificial Intelligence Research,伯克利人工智能研究所)开源的一个强化学习(RL)框架。我之前写了一篇它的简介。 如果你想用这个框架来开发自己的强化学习程序(尤其是那些不属于Atari游戏领域的强化学习程序),那么需要对它的源码有一定的了解。本文尝试从 rlpyt 自带的一个实例来分析它的部分源码,希望能帮助到一小部分人。
▶▶ 观察训练日志引出的问题
以 example_1 为例,在训练的过程中,程序会不断打印出类似于下面的日志(部分内容):
2019-11-08 20:38:42.067188 | StepsInEval 37962019-11-08 20:38:42.067216 | TrajsInEval 52019-11-08 20:38:42.067240 | CumEvalTime 23.12652019-11-08 20:38:42.067276 | CumTrainTime 2.646412019-11-08 20:38:42.067297 | Iteration 2492019-11-08 20:38:42.067315 | CumTime (s) 25.77292019-11-08 20:38:42.067333 | CumSteps 10002019-11-08 20:38:42.067350 | CumCompletedTrajs 12019-11-08 20:38:42.067368 | CumUpdates 02019-11-08 20:38:42.067385 | StepsPerSecond 386.0792019-11-08 20:38:42.067402 | UpdatesPerSecond 02019-11-08 20:38:42.067419 | ReplayRatio 02019-11-08 20:38:42.067436 | CumReplayRatio 02019-11-08 20:38:42.067453 | LengthAverage 759.22019-11-08 20:38:42.067480 | LengthStd 1.166192019-11-08 20:38:42.067499 | LengthMedian 7592019-11-08 20:38:42.067516 | LengthMin 7582019-11-08 20:38:42.067533 | LengthMax 7612019-11-08 20:38:42.067550 | ReturnAverage -212019-11-08 20:38:42.067567 | ReturnStd 02019-11-08 20:38:42.067584 | ReturnMedian -212019-11-08 20:38:42.067601 | ReturnMin -212019-11-08 20:38:42.067618 | ReturnMax -212019-11-08 20:38:42.067635 | NonzeroRewardsAverage 212019-11-08 20:38:42.067652 | NonzeroRewardsStd 02019-11-08 20:38:42.067669 | NonzeroRewardsMedian 212019-11-08 20:38:42.067686 | NonzeroRewardsMin 212019-11-08 20:38:42.067703 | NonzeroRewardsMax 212019-11-08 20:38:42.067720 | DiscountedReturnAverage -1.877712019-11-08 20:38:42.067737 | DiscountedReturnStd 0.02196052019-11-08 20:38:42.067754 | DiscountedReturnMedian -1.881362019-11-08 20:38:42.067771 | DiscountedReturnMin -1.900362019-11-08 20:38:42.067788 | DiscountedReturnMax -1.843922019-11-08 20:38:42.067805 | lossAverage nan2019-11-08 20:38:42.067822 | lossStd nan2019-11-08 20:38:42.067839 | lossMedian nan2019-11-08 20:38:42.067856 | lossMin nan2019-11-08 20:38:42.067873 | lossMax nan2019-11-08 20:38:42.067890 | gradNormAverage nan2019-11-08 20:38:42.067907 | gradNormStd nan2019-11-08 20:38:42.067924 | gradNormMedian nan2019-11-08 20:38:42.067941 | gradNormMin nan2019-11-08 20:38:42.067958 | gradNormMax nan2019-11-08 20:38:42.067975 | tdAbsErrAverage nan2019-11-08 20:38:42.067992 | tdAbsErrStd nan2019-11-08 20:38:42.068009 | tdAbsErrMedian nan2019-11-08 20:38:42.068026 | tdAbsErrMin nan2019-11-08 20:38:42.068043 | tdAbsErrMax nan
仔细看就会发现,最后的若干个模型指标都是“nan”,在训练了一段时间之后,这些值就变成了有意义的值,例如:
2019-11-08 20:40:40.941580 | lossAverage 0.01291652019-11-08 20:40:40.941597 | lossStd 0.01370612019-11-08 20:40:40.941614 | lossMedian 0.01503482019-11-08 20:40:40.941631 | lossMin 0.0001053232019-11-08 20:40:40.941648 | lossMax 0.06024072019-11-08 20:40:40.941665 | gradNormAverage 0.02839392019-11-08 20:40:40.941682 | gradNormStd 0.01682192019-11-08 20:40:40.941699 | gradNormMedian 0.03014822019-11-08 20:40:40.941716 | gradNormMin 0.006612182019-11-08 20:40:40.941732 | gradNormMax 0.0863342019-11-08 20:40:40.941749 | tdAbsErrAverage 0.05290542019-11-08 20:40:40.941766 | tdAbsErrStd 0.1684162019-11-08 20:40:40.941783 | tdAbsErrMedian 0.02332032019-11-08 20:40:40.941800 | tdAbsErrMin 8.33329e-052019-11-08 20:40:40.941817 | tdAbsErrMax 1
文章来源:https://www.codelast.com/
▶▶ nan 日志在哪记下来的
为了弄清楚上面的问题,我们要找到根源——打印“nan”日志的地方。上面那些显示为“nan”的日志,是 rlpyt/utils/logging/logger.py 的 record_tabular_misc_stat() 函数记录下来的:
def record_tabular_misc_stat(key, values, placement='back'): if placement == 'front': prefix = "" suffix = key else: prefix = key suffix = "" if len(values) > 0: record_tabular(prefix + "Average" + suffix, np.average(values)) record_tabular(prefix + "Std" + suffix, np.std(values)) record_tabular(prefix + "Median" + suffix, np.median(values)) record_tabular(prefix + "Min" + suffix, np.min(values)) record_tabular(prefix + "Max" + suffix, np.max(values)) else: record_tabular(prefix + "Average" + suffix, np.nan) record_tabular(prefix + "Std" + suffix, np.nan) record_tabular(prefix + "Median" + suffix, np.nan) record_tabular(prefix + "Min" + suffix, np.nan) record_tabular(prefix + "Max" + suffix, np.nan)
文章来源:https://www.codelast.com/
这个函数用来计算某些模型指标,这些模型指标有一个共同的特征:它们都可以计算平均值、标准差等统计值。这是什么意思?举个例子,有一个指标“CumTrainTime”(累积的训练时间),它就没有“平均值”的概念;而像 loss(损失函数的值)这种指标,它在多轮训练迭代过程中,是可以有“平均值”的概念的。
而类似于 loss 这种指标,还不止一个。为了简化代码,这里采用了拼接模型指标名称的做法,例如日志里的"lossAverage","gradNormAverage"之类的名称都是拼出来的,而不是直接写死,正如你上面看到的代码一样。
从上面的代码可见,当传入的“values”为空的时候,记下来的某些模型指标就会变成“nan”。
所以现在的问题变成了:在什么时候,传入的“values”会为空?
文章来源:https://www.codelast.com/
▶▶ logger的调用者 MinibatchRlEval 更新模型指标的逻辑
example_1 使用的 runner 是 MinibatchRlEval,它就是 logger 的调用者。在 MinibatchRlEval.train() 函数中定义了模型的训练、评估流程。
下面这句代码:
opt_info = self.algo.optimize_agent(itr, samples)
会把 loss 等参数收集到 opt_info 对象中,而下面这句代码:
self.store_diagnostics(itr, traj_infos, opt_info)
则会把 opt_info 更新到内存里。最后,这一句代码:
self.log_diagnostics(itr, eval_traj_infos, eval_time)
会把内存里的信息记录到日志,以及print到屏幕上。
所以,其实我们只要弄清楚 self.algo.optimize_agent() 返回 opt_info 的逻辑,就知道在什么情况下 loss 等指标为“nan”了。
文章来源:https://www.codelast.com/
▶▶ 找到根本原因:algorithm类更新模型指标的逻辑
example_1 使用的algorithm类是:
class DQN(RlAlgorithm):
它的 optimize_agent() 函数里有这样一段代码:
opt_info = OptInfo(*([] for _ in range(len(OptInfo._fields)))) if itr < self.min_itr_learn: return opt_info
这里的 opt_info 其实就是一个各字段为空list的 namedtuple 对象:
OptInfo(loss=[], gradNorm=[], tdAbsErr=[])
答案已经很明显了,当前模型训练的迭代次数 < self.min_itr_learn 的时候,就会造成 loss 等模型指标为“nan”。
self.min_itr_learn 是在 DQN.initialize() 函数里初始化的:
self.min_itr_learn = int(self.min_steps_learn // sampler_bs)
不用去管这个看似有点奇怪的逻辑,只需要知道:self.min_steps_learn 越大,“nan”打印出的次数就越多。
而 self.min_steps_learn 这个参数,是在 DQN 类对象构造的时候传入的(example_1.py):
algo = DQN(min_steps_learn=1e3)
所以,你只要改小这个值,就可以让“nan”出现的次数减少。
文章来源:https://www.codelast.com/
▶▶ 为什么要这样做,以及调整 min_steps_learn 参数的注意事项
rlpyt 为什么要用一个参数来控制模型指标的计算过程?其实它不是为了控制什么时候不显示“nan”,看 DQN.optimize_agent() 函数的这几句代码:
if samples is not None: samples_to_buffer = self.samples_to_buffer(samples) self.replay_buffer.append_samples(samples_to_buffer) opt_info = OptInfo(*([] for _ in range(len(OptInfo._fields)))) if itr < self.min_itr_learn: return opt_info
就会发现:当训练迭代次数没有达到 self.min_itr_learn 的时候,算法会一直把与environment交互得到的采样数据收集到 Replay Buffer 里面,如果 Replay Buffer 里的数据太少,没有达到预设的数量,那么开始优化策略网络也是没有意义的。当满足 irt >= self.min_itr_learn 的条件之后,后面才会进行反向传播之类的工作。
所以我认为,min_steps_learn 的值确实不能设置得太小。
文章来源:https://www.codelast.com/
这一节就到这,且听下回分解。
文章来源:https://www.codelast.com/
➤➤ 版权声明 ➤➤
转载需注明出处:codelast.com
感谢关注我的微信公众号(微信扫一扫):