[原创] 强化学习框架 rlpyt 源码分析:(2) 掌管训练流程的runner类

查看关于 rlpyt 的更多文章请点击这里

rlpyt 是BAIR(Berkeley Artificial Intelligence Research,伯克利人工智能研究所)开源的一个强化学习(RL)框架。我之前写了一篇它的简介。 如果你想用这个框架来开发自己的强化学习程序(尤其是那些不属于Atari游戏领域的强化学习程序),那么需要对它的源码有一定的了解。本文尝试从 rlpyt 自带的一个实例来分析它的部分源码,希望能帮助到一小部分人。

▶▶ runner的主要功能
rlpyt 项目的“runners”目录下有这些代码文件:

├── runners
│   ├── __init__.py
│   ├── async_rl.py
│   ├── base.py
│   ├── minibatch_rl.py
│   └── sync_rl.py 
这些文件里的class,被 rlpyt 称为“runner”。它们的主要功能是控制RL训练流程,例如采样一批数据,训练模型,记录统计数据,评估模型等,这样的一个工作流,就是由runner来控制的。
文章来源:https://www.codelast.com/
▶▶ example_1使用的runner类

runner = MinibatchRlEval(
    algo=algo,
    agent=agent,
    sampler=sampler,
    n_steps=50e6,
    log_interval_steps=1e3,
    affinity=dict(cuda_idx=cuda_idx),
)
config = dict(game=game)
name = "dqn_" + game
log_dir = "example_1"
with logger_context(log_dir, run_ID, name, config, snapshot_mode="last"):
    runner.train()

example_1 只使用了 minibatch_rl.py 中的 MinibatchRlEval 这个class,它的主要功能是:按mini-batch来训练RL模型,并且每隔一定的周期就evaluate一次模型。
从代码可见,runner 依赖于 algorithm、agent、sampler,是因为 runner 需要控制 agent 模块设置训练模式、控制 sampler 模块采样、控制 algorithm 模块来优化神经网络参数。
相比之下,minibatch_rl.py 里还有一个class MinibatchRl,从它的名字我们也可以直观地猜测出来它的功能了:按mini-batch来训练模型,但不对模型做evaluation。
而 MinibatchRlEval和 MinibatchRl 的共性,则被抽象到了它们共同的父类 MinibatchRlBase 中。
文章来源:https://www.codelast.com/
▶▶ MinibatchRlEval 的训练逻辑
训练发生在 train() 函数中,我用比较详尽的注释来说明模型训练的逻辑:

def train(self):
    n_itr = self.startup()  # 调用startup()会导致调用父类的__init__()方法,从而会把外面的algo,agent,sampler传进去
    with logger.prefix(f"itr #0 "):
        eval_traj_infos, eval_time = self.evaluate_agent(0)  # 开始训练模型之前先evaluate一次
        self.log_diagnostics(0, eval_traj_infos, eval_time)  # 记录诊断信息(写日志)
    for itr in range(n_itr):  # 重复训练N轮
        with logger.prefix(f"itr #{itr} "):
            self.agent.sample_mode(itr)  # 设置成采样模式
            samples, traj_infos = self.sampler.obtain_samples(itr)  # 采样一批数据
            self.agent.train_mode(itr)  # 把神经网络module设置成训练模式,传进入的迭代次数其实没用
            opt_info = self.algo.optimize_agent(itr, samples)  # 训练模型,反向传播之类的工作就是在这里面做的
            self.store_diagnostics(itr, traj_infos, opt_info)  # 更新内存中的一些统计数据
            if (itr + 1) % self.log_interval_itrs == 0:  # 每迭代到记录一次日志的步数
                eval_traj_infos, eval_time = self.evaluate_agent(itr)  # 评估模型
                self.log_diagnostics(itr, eval_traj_infos, eval_time)  # 记录诊断信息(写日志)
    self.shutdown()  # 完成后的清理工作

这有几个略显奇怪的地方:
(1)训练的轮数 n_iter,不是简单地人为指定的,而是在父类 MinibatchRlBase 的 startup() 函数中,通过一个比较麻烦的方法计算出来的,实现在 get_n_itr() 函数里,需要好好看几遍理解一下。
(2)在正式开始循环训练N轮之前,先做了一次evaluation,为什么会这样操作?我认为是可以记录一些初始的状态(evaluation的时候会把各种统计数据记录到日志里),便于观察是否有明显问题。
文章来源:https://www.codelast.com/
▶▶ cpu_affinity的概念
在 MinibatchRlEval 的父类 MinibatchRlBase 的 startup() 函数中,会看到一些和 CPU亲和性(affinity)的代码,例如:

try:
    if self.affinity.get("master_cpus", None) is not None and self.affinity.get("set_affinity", True):
        p.cpu_affinity(self.affinity["master_cpus"])
    cpu_affin = p.cpu_affinity()
except AttributeError:
    cpu_affin = "UNAVAILABLE MacOS"
logger.log(f"Runner {getattr(self, 'rank', '')} master CPU affinity: {cpu_affin}.")

CPU亲和性(affinity)就是进程要在某个给定的CPU上尽量长时间地运行而不被迁移到其他处理器的倾向性。设置CPU亲和性可以在某些情况下提高程序性能。从 rlpyt 这里的逻辑可见,它为了优化程序速度做了很多和硬件交互的工作。
文章来源:https://www.codelast.com/
▶▶ world_size的概念
在 MinibatchRlEval 的父类 MinibatchRlBase 的 startup() 函数中,会看到一些和 world_size 相关的代码,而作者并没有解释 world_size 是个什么东西。通过下面这句:

self.itr_batch_size = self.sampler.batch_spec.size * world_size

可以推测:由于 self.sampler.batch_spec.size 是在所有environment实例上的所有time step的总和,我认为 world_size 是一个“平行宇宙”的概念——想像一下美剧《闪电侠》,这个世界存在 Earth 1,Earth 2,...,Earth N... 在“当前宇宙”内的发生的采样,它是算在 batch_spec.size内的,而像这样的场景,我们可以把它复制很多个出来,用所有这些创造出来的集合来训练RL模型。
文章来源:https://www.codelast.com/
▶▶ 哪些模块在runner中会被初始化(initialize)?
从 MinibatchRlEval 的父类 MinibatchRlBase 的 startup() 函数可以看到,sampler、algorithm两个模块是在 runner 里初始化的,而 agent 并没有被 initialize,这是为什么?因为 agent 在 sampler 类的 initialize() 函数里被初始化了,所以,initialize 了 sampler,自然就 initialize 了 agent——具体可查看 SerialSampler 类的代码。
文章来源:https://www.codelast.com/
这一节就到这,且听下回分解。
文章来源:https://www.codelast.com/
➤➤ 版权声明 ➤➤ 
转载需注明出处:codelast.com 
感谢关注我的微信公众号(微信扫一扫):

wechat qrcode of codelast

发表评论