[原创] 强化学习框架 rlpyt 源码分析:(9) 基于CPU的并行采样器CpuSampler

查看关于 rlpyt 的更多文章请点击这里

rlpyt 是BAIR(Berkeley Artificial Intelligence Research,伯克利人工智能研究所)开源的一个强化学习(RL)框架。我之前写了一篇它的简介。 本文是上一篇文章的续文,继续分析CpuSampler的源码。
我们已经知道了CpuSampler有两个父类:BaseSampler 和 ParallelSamplerBase。其中,BaseSampler主要是定义了一堆接口,没什么好说的,因此本文接着分析另一个父类 ParallelSamplerBase。在 ParallelSamplerBase 中,初始化函数 initialize() 做了很多重要的工作,已经够写一篇长长的文章来分析了,这正是本文的主要内容。

▶▶ 初始化函数 initialize() 做了哪些重要工作
一句话总结 initialize() 的重要功能:计算一些特殊参数的值,初始化agent,创建并行控制器,创建并启动多个worker进程。
这里说的“并行控制器”(parallel ctrl)是指用Python multiprocessing模块来实现并行功能的时候,需要使用一些变量来协调各个并行的进程,使它们可以正确运作。这些用于协调的变量就是“并行控制器”。

▶▶ 计算特殊参数的值
在并行模式下,有些参数(比如采样用的worker的数量)不是由用户直接设置的,而是计算出来的。而且这样的参数还挺多,所以有大段大段的代码都用来干这事了。
如果下面的代码没有注释的话,肯定会让人一头雾水:

n_envs_list = self._get_n_envs_list(affinity=affinity)  # 用户设置的worker数不一定与environment数相匹配,这里会重新调整
self.n_worker = n_worker = len(n_envs_list)  # 经过调整之后的workerB = self.batch_spec.B  # environment实例的数量
global_B = B * world_size  # "平行宇宙"概念下的environment实例的数量
env_ranks = list(range(rank * B, (rank + 1) * B))  # 含义可参考:https://www.codelast.com/?p=10932
self.world_size = world_size
self.rank = rank

if self.eval_n_envs > 0:  # example_*.py中传入的参数
    self.eval_n_envs_per = max(1, self.eval_n_envs // n_worker)  # 计算每个worker至少承载几个evaluationenvironment(至少1)
    self.eval_n_envs = eval_n_envs = self.eval_n_envs_per * n_worker  # 保证至少有"worker数量"eval environment实例
    logger.log(f"Total parallel evaluation envs: {eval_n_envs}.")
    self.eval_max_T = eval_max_T = int(self.eval_max_steps // eval_n_envs)

最为“神奇”的就是 self._get_n_envs_list() 这个函数,它用来计算每个worker承载几个environment实例。这个说法是不是特别奇怪?原因是:用户可以指定environment实例的数量,也可以指定worker的数量,但这两个数量可能是不相等的,于是,要么worker数不够,要么worker数有多;在第1种情况下,一个worker需要带>1个environment实例,在第2种情况下,不需要那么多worker,所以要减少worker的数量,才能保证一个worker刚好带一个environment实例。
文章来源:https://www.codelast.com/
我给 self._get_n_envs_list() 函数加上了注释,相信足以让大家理解它的功能了:

def _get_n_envs_list(self, affinity=None, n_worker=None, B=None):
    """
    根据environment实例的数量(所谓的"B"),以及用户设定的用于采样的worker的数量(n_worker),来计算得到一个list,这个list的元素的总数,
    就是最终的worker的数量;而这个list里的每个元素的值,分别是每个worker承载的environment实例的数量。

    :param affinity: 一个字典(dict),包含硬件亲和性定义。
    :param n_worker: 用户设定的用于采样的worker的数量。
    :param B: environment实例的数量。
    :return 一个list,其含义如上所述。
    """
    B = self.batch_spec.B if B is None else B  # 参考BatchSpec类,可以认为Benvironment实例的数量
    n_worker = len(affinity["workers_cpus"]) if n_worker is None else n_worker  # worker的数量(不超过物理CPU数否则在别处报错)
    """
    environment实例的数量<worker的数量时,例如有8worker(8个物理CPU)5environment实例,每一个物理CPU运行一个environment    那么此时会有3个物理CPU多余,此时就会把worker的数量设置成和environment实例数量一样,使得每个CPU都刚好运行一个environment实例。
    """
    if B < n_worker:
        logger.log(f"WARNING: requested fewer envs ({B}) than available worker "
            f"processes ({n_worker}). Using fewer workers (but maybe better to "
            "increase sampler's `batch_B`.")
        n_worker = B
    n_envs_list = [B // n_worker] * n_worker
    """
    environment实例的数量不是worker数量的整数倍时,每个worker被分配到的environment实例的数量是不均等的。
    """
    if not B % n_worker == 0:
        logger.log("WARNING: unequal number of envs per process, from "
            f"batch_B {self.batch_spec.B} and n_worker {n_worker} "
            "(possible suboptimal speed).")
        for b in range(B % n_worker):
            n_envs_list[b] += 1
    return n_envs_list

文章来源:https://www.codelast.com/
▶▶ 初始化agent
agent对象只有一个!并不是每一个worker进程都对应到不同的agent对象!这是理解CpuSampler时需要知晓的一个重要概念。
agent通过以下代码初始化(ParallelSamplerBase.initialize() 函数):

env = self.EnvCls(**self.env_kwargs)
self._agent_init(agent, env, global_B=global_B,
    env_ranks=env_ranks)
examples = self._build_buffers(env, bootstrap_value)
env.close()
del env

可以看到,这里初始化了environment对象,并把它作为一个参数传给了agent初始化函数 self._agent_init(),事实上,在 self._agent_init() 函数里,只用到了 env 对象的 spaces 这个属性,而没有引用整个 env 对象,因此在使用完之后,使用 env.close() 以及 del env 来清理掉env不会有问题。
self._build_buffers() 是一个非常复杂的操作,它的主要功能是创建强化学习中必备的replay buffer。直觉上,有人可能认为replay buffer这个东西,不就是创建一个list或者类似的数据结构就能搞定的吗?但实际上不是这么简单,从这个函数一级级点进去就会发现代码还不少,而且它里面甚至还用到了Python multiprocessing,所以创建replay buffer的实现就不在本文分析了。
文章来源:https://www.codelast.com/
self._agent_init() 函数的实现很简单:

def _agent_init(self, agent, env, global_B=1, env_ranks=None):
    agent.initialize(env.spaces, share_memory=True,
        global_B=global_B, env_ranks=env_ranks)
    self.agent = agent

在这里看到:agent初始化之后,赋值给了 self.agent,这就是 CpuSampler 中唯一使用的 agent 对象。
文章来源:https://www.codelast.com/
▶▶ 创建并行控制器
并行控制器(parallel ctrl)用于协调多个采样用的worker进程。
在 initialize() 里,创建并行控制器的代码只有一句:

def _build_parallel_ctrl(self, n_worker):
    """
    创建用于控制并行训练过程的一些数据结构。

    multiprocessing.RawValue:不存在lock的多进程间共享值。
    multiprocessing.Barrier:一种简单的同步原语,用于固定数目的进程相互等待。当所有进程都调用wait以后,所有进程会同时开始执行。
    multiprocessing.Queue:用于多进程间数据传递的消息队列。

    :param n_worker: 真正的worker(不一定等于用户设置的那个原始值)
    """
    self.ctrl = AttrDict(
        quit=mp.RawValue(ctypes.c_bool, False),
        barrier_in=mp.Barrier(n_worker + 1),
        barrier_out=mp.Barrier(n_worker + 1),
        do_eval=mp.RawValue(ctypes.c_bool, False),
        itr=mp.RawValue(ctypes.c_long, 0),
    )
    self.traj_infos_queue = mp.Queue()
    self.eval_traj_infos_queue = mp.Queue()
    self.sync = AttrDict(stop_eval=mp.RawValue(ctypes.c_bool, False))

这里AttrDict是一个“扩展的”dict,mp就是Python multiprocessing模块,而Python multiprocessing是一个巨大的话题,我自己也只是初步了解,所以没办法讲透彻,这里只举两个例子,来说明这些并行控制器的作用:
 ctrl.quit 可以理解为一个bool类型的进程间共享变量。在 minibatch_rl.py 中,训练完成后,会执行 shutdown(),它会调用 sampler.shutdown(),从而会把 ctrl.quit 的值设置为True;同时,在 worker.py 中会看到,当检测到 ctrl.quit 的值为True时,会退出采样过程。所有采样的worker进程都受这个变量控制。所以这样就做到了在主进程中控制并行跑的worker进程。
 multiprocessing.Queue() 用于在多进程间传递消息。在每个采样的worker进程中,会把收集到的trajectory info放到同一个traj_infos_queue中,在主进程中会把汇总的trajectory info进一步处理成统计数据,然后记日志、打印到屏幕上,等等。
文章来源:https://www.codelast.com/
▶▶ 创建并启动多个worker进程
worker进程用于采样(agent与environment交互得到的)数据。
在创建这些进程之前,需要先为它们构建所需的参数:

common_kwargs = self._assemble_common_kwargs(affinity, global_B)
workers_kwargs = self._assemble_workers_kwargs(affinity, seed, n_envs_list)

为什么需要分成 common_kwargs 以及 workers_kwargs 两个参数?这是因为:对每个worker进程来说,有些参数是通用的,有些参数是不通用的(例如,每个worker使用的CPU数量、承载的environment实例的数量等),因此,rlpyt把它们分成了两拨,分别放在两个对象里。

在准备好了参数之后,就开始创建多个worker进程,并把它们启动起来了:

# 创建一批子进程
target = sampling_process if worker_process is None else worker_process
self.workers = [mp.Process(target=target,
    kwargs=dict(common_kwargs=common_kwargs, worker_kwargs=w_kwargs))
    for w_kwargs in workers_kwargs]
# 启动子进程
for w in self.workers:
    w.start()

self.ctrl.barrier_out.wait()  # Wait for workers ready (e.g. decorrelate).

在这里,使用的是 multiprocessing.Process() 来创建的进程,target 为进程函数名,进程函数是可以自行指定的,rlpyt也提供了默认的实现,即 worker.py 中的 sampling_process() 函数。采样进程的实现代码 worker.py 虽然不长,但要完全看懂并不容易,所以留到后面的文章再分析。
在worker进程启动之后,它就进入了持续的采样过程。注意上面代码的最后一句 self.ctrl.barrier_out.wait(),这里使用了 multiprocessing的Barrier来控制各个worker进程同步。由于 barrier_out 创建的时候是这样的:

barrier_out=mp.Barrier(n_worker + 1)

所以,它需要 n_worker + 1 个 wait() 才能让所有进程同时“解锁”(即同时开始执行),在 initialize() 函数里的 self.ctrl.barrier_out.wait() 算一个,每个worker函数——即 sampling_process()——里也分别有一个 barrier_out.wait(),所有这些 wait() 加起来刚好是 n_worker + 1 个,这使得 initialize() 函数执行完,所有 worker 就会“跑起来”开始采样。
文章来源:https://www.codelast.com/
这一节就到这,且听下回分解。
文章来源:https://www.codelast.com/
➤➤ 版权声明 ➤➤ 
转载需注明出处:codelast.com 
感谢关注我的微信公众号(微信扫一扫):

wechat qrcode of codelast

发表评论