查看关于 rlpyt 的更多文章请点击这里。
rlpyt 是BAIR(Berkeley Artificial Intelligence Research,伯克利人工智能研究所)开源的一个强化学习(RL)框架。我之前写了一篇它的简介。 本文是上一篇文章的续文,继续分析CpuSampler的源码。
我们已经知道了CpuSampler有两个父类:BaseSampler 和 ParallelSamplerBase。其中,BaseSampler主要是定义了一堆接口,没什么好说的,因此本文接着分析另一个父类 ParallelSamplerBase。在 ParallelSamplerBase 中,初始化函数 initialize() 做了很多重要的工作,已经够写一篇长长的文章来分析了,这正是本文的主要内容。
▶▶ 初始化函数 initialize() 做了哪些重要工作
一句话总结 initialize() 的重要功能:计算一些特殊参数的值,初始化agent,创建并行控制器,创建并启动多个worker进程。
✍ 这里说的“并行控制器”(parallel ctrl)是指用Python multiprocessing模块来实现并行功能的时候,需要使用一些变量来协调各个并行的进程,使它们可以正确运作。这些用于协调的变量就是“并行控制器”。
▶▶ 计算特殊参数的值
在并行模式下,有些参数(比如采样用的worker的数量)不是由用户直接设置的,而是计算出来的。而且这样的参数还挺多,所以有大段大段的代码都用来干这事了。
如果下面的代码没有注释的话,肯定会让人一头雾水:
n_envs_list = self._get_n_envs_list(affinity=affinity) # 用户设置的worker数不一定与environment数相匹配,这里会重新调整 self.n_worker = n_worker = len(n_envs_list) # 经过调整之后的worker数 B = self.batch_spec.B # environment实例的数量 global_B = B * world_size # "平行宇宙"概念下的environment实例的数量 env_ranks = list(range(rank * B, (rank + 1) * B)) # 含义可参考:https://www.codelast.com/?p=10932 self.world_size = world_size self.rank = rank if self.eval_n_envs > 0: # 在example_*.py中传入的参数 self.eval_n_envs_per = max(1, self.eval_n_envs // n_worker) # 计算每个worker至少承载几个evaluation的environment(至少1) self.eval_n_envs = eval_n_envs = self.eval_n_envs_per * n_worker # 保证至少有"worker数量"个eval environment实例 logger.log(f"Total parallel evaluation envs: {eval_n_envs}.") self.eval_max_T = eval_max_T = int(self.eval_max_steps // eval_n_envs)
最为“神奇”的就是 self._get_n_envs_list() 这个函数,它用来计算每个worker承载几个environment实例。这个说法是不是特别奇怪?原因是:用户可以指定environment实例的数量,也可以指定worker的数量,但这两个数量可能是不相等的,于是,要么worker数不够,要么worker数有多;在第1种情况下,一个worker需要带>1个environment实例,在第2种情况下,不需要那么多worker,所以要减少worker的数量,才能保证一个worker刚好带一个environment实例。
文章来源:https://www.codelast.com/
我给 self._get_n_envs_list() 函数加上了注释,相信足以让大家理解它的功能了:
def _get_n_envs_list(self, affinity=None, n_worker=None, B=None): """ 根据environment实例的数量(所谓的"B"),以及用户设定的用于采样的worker的数量(n_worker),来计算得到一个list,这个list的元素的总数, 就是最终的worker的数量;而这个list里的每个元素的值,分别是每个worker承载的environment实例的数量。 :param affinity: 一个字典(dict),包含硬件亲和性定义。 :param n_worker: 用户设定的用于采样的worker的数量。 :param B: environment实例的数量。 :return 一个list,其含义如上所述。 """ B = self.batch_spec.B if B is None else B # 参考BatchSpec类,可以认为B是environment实例的数量 n_worker = len(affinity["workers_cpus"]) if n_worker is None else n_worker # worker的数量(不超过物理CPU数否则在别处报错) """ 当environment实例的数量<worker的数量时,例如有8个worker(即8个物理CPU),5个environment实例,每一个物理CPU运行一个environment, 那么此时会有3个物理CPU多余,此时就会把worker的数量设置成和environment实例数量一样,使得每个CPU都刚好运行一个environment实例。 """ if B < n_worker: logger.log(f"WARNING: requested fewer envs ({B}) than available worker " f"processes ({n_worker}). Using fewer workers (but maybe better to " "increase sampler's `batch_B`.") n_worker = B n_envs_list = [B // n_worker] * n_worker """ 当environment实例的数量不是worker数量的整数倍时,每个worker被分配到的environment实例的数量是不均等的。 """ if not B % n_worker == 0: logger.log("WARNING: unequal number of envs per process, from " f"batch_B {self.batch_spec.B} and n_worker {n_worker} " "(possible suboptimal speed).") for b in range(B % n_worker): n_envs_list[b] += 1 return n_envs_list
文章来源:https://www.codelast.com/
▶▶ 初始化agent
agent对象只有一个!并不是每一个worker进程都对应到不同的agent对象!这是理解CpuSampler时需要知晓的一个重要概念。
agent通过以下代码初始化(ParallelSamplerBase.initialize() 函数):
env = self.EnvCls(**self.env_kwargs) self._agent_init(agent, env, global_B=global_B, env_ranks=env_ranks) examples = self._build_buffers(env, bootstrap_value) env.close() del env
可以看到,这里初始化了environment对象,并把它作为一个参数传给了agent初始化函数 self._agent_init(),事实上,在 self._agent_init() 函数里,只用到了 env 对象的 spaces 这个属性,而没有引用整个 env 对象,因此在使用完之后,使用 env.close() 以及 del env 来清理掉env不会有问题。
self._build_buffers() 是一个非常复杂的操作,它的主要功能是创建强化学习中必备的replay buffer。直觉上,有人可能认为replay buffer这个东西,不就是创建一个list或者类似的数据结构就能搞定的吗?但实际上不是这么简单,从这个函数一级级点进去就会发现代码还不少,而且它里面甚至还用到了Python multiprocessing,所以创建replay buffer的实现就不在本文分析了。
文章来源:https://www.codelast.com/
self._agent_init() 函数的实现很简单:
def _agent_init(self, agent, env, global_B=1, env_ranks=None): agent.initialize(env.spaces, share_memory=True, global_B=global_B, env_ranks=env_ranks) self.agent = agent
在这里看到:agent初始化之后,赋值给了 self.agent,这就是 CpuSampler 中唯一使用的 agent 对象。
文章来源:https://www.codelast.com/
▶▶ 创建并行控制器
并行控制器(parallel ctrl)用于协调多个采样用的worker进程。
在 initialize() 里,创建并行控制器的代码只有一句:
def _build_parallel_ctrl(self, n_worker): """ 创建用于控制并行训练过程的一些数据结构。 multiprocessing.RawValue:不存在lock的多进程间共享值。 multiprocessing.Barrier:一种简单的同步原语,用于固定数目的进程相互等待。当所有进程都调用wait以后,所有进程会同时开始执行。 multiprocessing.Queue:用于多进程间数据传递的消息队列。 :param n_worker: 真正的worker数(不一定等于用户设置的那个原始值)。 """ self.ctrl = AttrDict( quit=mp.RawValue(ctypes.c_bool, False), barrier_in=mp.Barrier(n_worker + 1), barrier_out=mp.Barrier(n_worker + 1), do_eval=mp.RawValue(ctypes.c_bool, False), itr=mp.RawValue(ctypes.c_long, 0), ) self.traj_infos_queue = mp.Queue() self.eval_traj_infos_queue = mp.Queue() self.sync = AttrDict(stop_eval=mp.RawValue(ctypes.c_bool, False))
这里AttrDict是一个“扩展的”dict,mp就是Python multiprocessing模块,而Python multiprocessing是一个巨大的话题,我自己也只是初步了解,所以没办法讲透彻,这里只举两个例子,来说明这些并行控制器的作用:
✔ ctrl.quit 可以理解为一个bool类型的进程间共享变量。在 minibatch_rl.py 中,训练完成后,会执行 shutdown(),它会调用 sampler.shutdown(),从而会把 ctrl.quit 的值设置为True;同时,在 worker.py 中会看到,当检测到 ctrl.quit 的值为True时,会退出采样过程。所有采样的worker进程都受这个变量控制。所以这样就做到了在主进程中控制并行跑的worker进程。
✔ multiprocessing.Queue() 用于在多进程间传递消息。在每个采样的worker进程中,会把收集到的trajectory info放到同一个traj_infos_queue中,在主进程中会把汇总的trajectory info进一步处理成统计数据,然后记日志、打印到屏幕上,等等。
文章来源:https://www.codelast.com/
▶▶ 创建并启动多个worker进程
worker进程用于采样(agent与environment交互得到的)数据。
在创建这些进程之前,需要先为它们构建所需的参数:
common_kwargs = self._assemble_common_kwargs(affinity, global_B) workers_kwargs = self._assemble_workers_kwargs(affinity, seed, n_envs_list)
为什么需要分成 common_kwargs 以及 workers_kwargs 两个参数?这是因为:对每个worker进程来说,有些参数是通用的,有些参数是不通用的(例如,每个worker使用的CPU数量、承载的environment实例的数量等),因此,rlpyt把它们分成了两拨,分别放在两个对象里。
在准备好了参数之后,就开始创建多个worker进程,并把它们启动起来了:
# 创建一批子进程 target = sampling_process if worker_process is None else worker_process self.workers = [mp.Process(target=target, kwargs=dict(common_kwargs=common_kwargs, worker_kwargs=w_kwargs)) for w_kwargs in workers_kwargs] # 启动子进程 for w in self.workers: w.start() self.ctrl.barrier_out.wait() # Wait for workers ready (e.g. decorrelate).
在这里,使用的是 multiprocessing.Process() 来创建的进程,target 为进程函数名,进程函数是可以自行指定的,rlpyt也提供了默认的实现,即 worker.py 中的 sampling_process() 函数。采样进程的实现代码 worker.py 虽然不长,但要完全看懂并不容易,所以留到后面的文章再分析。
在worker进程启动之后,它就进入了持续的采样过程。注意上面代码的最后一句 self.ctrl.barrier_out.wait(),这里使用了 multiprocessing的Barrier来控制各个worker进程同步。由于 barrier_out 创建的时候是这样的:
barrier_out=mp.Barrier(n_worker + 1)
所以,它需要 n_worker + 1 个 wait() 才能让所有进程同时“解锁”(即同时开始执行),在 initialize() 函数里的 self.ctrl.barrier_out.wait() 算一个,每个worker函数——即 sampling_process()——里也分别有一个 barrier_out.wait(),所有这些 wait() 加起来刚好是 n_worker + 1 个,这使得 initialize() 函数执行完,所有 worker 就会“跑起来”开始采样。
文章来源:https://www.codelast.com/
这一节就到这,且听下回分解。
文章来源:https://www.codelast.com/
➤➤ 版权声明 ➤➤
转载需注明出处:codelast.com
感谢关注我的微信公众号(微信扫一扫):