查看关于 rlpyt 的更多文章请点击这里。
rlpyt 是BAIR(Berkeley Artificial Intelligence Research,伯克利人工智能研究所)开源的一个强化学习(RL)框架。我之前写了一篇它的简介。 本文是上一篇文章的续文,继续分析CpuSampler的源码。
本文将分析 CPU并行模式下的 ParallelSamplerBase 类的worker实现。
▶▶ worker的代码在哪
rlpyt/samplers/parallel/worker.py
▶▶ worker是做什么用的
用于采样agent与environment交互得到的数据。
文章来源:https://www.codelast.com/
▶▶ 代码分析
我直接在代码里加了大量注释:
def initialize_worker(rank, seed=None, cpu=None, torch_threads=None): """ 初始化采样用的worker。 :param rank: 采样进程的标识序号。 :param seed: 种子,一个整数值。 :param cpu: CPU序号,例如 0, 1, 2 等等。 :param torch_threads: CPU并发执行的线程数。 """ log_str = f"Sampler rank {rank} initialized" cpu = [cpu] if isinstance(cpu, int) else cpu p = psutil.Process() try: if cpu is not None: p.cpu_affinity(cpu) # 设置CPU亲和性(MacOS不支持) cpu_affin = p.cpu_affinity() except AttributeError: cpu_affin = "UNAVAILABLE MacOS" log_str += f", CPU affinity {cpu_affin}" torch_threads = (1 if torch_threads is None and cpu is not None else torch_threads) # Default to 1 to avoid possible MKL hang. if torch_threads is not None: torch.set_num_threads(torch_threads) # 设置CPU并发执行的线程数 log_str += f", Torch threads {torch.get_num_threads()}" if seed is not None: set_seed(seed) time.sleep(0.3) # (so the printing from set_seed is not intermixed) log_str += f", Seed {seed}" logger.log(log_str) def sampling_process(common_kwargs, worker_kwargs): """ Arguments fed from the Sampler class in master process. 采样进程函数。 :param common_kwargs: 各个worker通用的参数列表。 :param worker_kwargs: 各个worker可能不同的参数列表。 """ c, w = AttrDict(**common_kwargs), AttrDict(**worker_kwargs) initialize_worker(w.rank, w.seed, w.cpus, c.torch_threads) # 初始化用于training的environment实例和collector实例 envs = [c.EnvCls(**c.env_kwargs) for _ in range(w.n_envs)] collector = c.CollectorCls( rank=w.rank, envs=envs, samples_np=w.samples_np, batch_T=c.batch_T, TrajInfoCls=c.TrajInfoCls, agent=c.get("agent", None), # Optional depending on parallel setup. sync=w.get("sync", None), step_buffer_np=w.get("step_buffer_np", None), global_B=c.get("global_B", 1), env_ranks=w.get("env_ranks", None), ) agent_inputs, traj_infos = collector.start_envs(c.max_decorrelation_steps) # 这里会做收集(采样)第一批数据的工作 collector.start_agent() # collector的初始化 # 初始化用于evaluation的environment实例和collector实例 if c.get("eval_n_envs", 0) > 0: eval_envs = [c.EnvCls(**c.eval_env_kwargs) for _ in range(c.eval_n_envs)] eval_collector = c.eval_CollectorCls( rank=w.rank, envs=eval_envs, TrajInfoCls=c.TrajInfoCls, traj_infos_queue=c.eval_traj_infos_queue, max_T=c.eval_max_T, agent=c.get("agent", None), sync=w.get("sync", None), step_buffer_np=w.get("eval_step_buffer_np", None), ) else: eval_envs = list() ctrl = c.ctrl # 用于控制多个worker进程同时运行时能正确运作的控制器 ctrl.barrier_out.wait() # 每个worker都有一个wait(),加上ParallelSamplerBase.initialize()中的一个wait(),刚好n_worker+1个 while True: collector.reset_if_needed(agent_inputs) # Outside barrier? ctrl.barrier_in.wait() if ctrl.quit.value: # 在主进程中set了这个值为True时,所有worker进程会退出采样 break if ctrl.do_eval.value: # 在主进程的evaluate_agent()函数里set了这个值为True时,这里才会收集evaluation用的数据 eval_collector.collect_evaluation(ctrl.itr.value) # Traj_infos to queue inside. else: # 不是做evaluation agent_inputs, traj_infos, completed_infos = collector.collect_batch( agent_inputs, traj_infos, ctrl.itr.value) for info in completed_infos: c.traj_infos_queue.put(info) # 向所有worker进程共享的队列塞入当前worker的统计数据 ctrl.barrier_out.wait() # 清理environment for env in envs + eval_envs: env.close()
文章来源:https://www.codelast.com/
在worker的代码中,比较绕的就是,worker是怎么把采样到的数据返回放到replay buffer里的?
在上一篇文章中,我们知道 ParallelSamplerBase.initialize() 函数初始化了replay buffer:
examples = self._build_buffers(env, bootstrap_value)
以及:
def _build_buffers(self, env, bootstrap_value): self.samples_pyt, self.samples_np, examples = build_samples_buffer( self.agent, env, self.batch_spec, bootstrap_value, agent_shared=True, env_shared=True, subprocess=True) return examples
在这里,self.samples_np 对应的是replay buffer的存储对象。而 worker 的参数 workers_kwargs 初始化的时候,会把 self.samples_np 拆分成多个slice,并传入 worker:
samples_np=self.samples_np[:, slice_B],
在 worker 中,构造 collector 对象的时候,会把这个传入的 samples_np 再传给 collector 的构造函数。这样,replay buffer 就与 collector 关联起来了。
最后,在 collector.collect_batch() 的时候,会把采样得到的数据放入 samples_np 中,也就是相当于放到了 replay buffer 里。
文章来源:https://www.codelast.com/
这一节就到这,且听下回分解。
文章来源:https://www.codelast.com/
➤➤ 版权声明 ➤➤
转载需注明出处:codelast.com
感谢关注我的微信公众号(微信扫一扫):