查看关于 rlpyt 的更多文章请点击这里。
rlpyt 是BAIR(Berkeley Artificial Intelligence Research,伯克利人工智能研究所)开源的一个强化学习(RL)框架。我之前写了一篇它的简介。 如果你想用这个框架来开发自己的强化学习程序(尤其是那些不属于Atari游戏领域的强化学习程序),那么需要对它的源码有一定的了解。本文尝试从 rlpyt 自带的一个实例来分析它的部分源码,希望能帮助到一小部分人。
▶▶ sampler的主要功能
训练强化学习模型需要训练数据,收集训练数据的工作就是由sampler类做的。
收集训练数据,就需要在environment中步进,因此environment的实例化工作也在sampler中完成。
在很多强化学习教程中,收集数据也叫采样数据,这也是sampler这个名字的由来。但需要注意的是,真正去做“收集数据”这个工作的,是一种叫做collector的class。sampler会在 initialize() 的时候,把 collector 对象也初始化。
所以 sampler 可以看做是在 collector 外面又包了一层。
文章来源:https://www.codelast.com/
▶▶ BatchSpec里的 T 和 B 的概念
在 SerialSampler 的 initialize() 函数里,会看到实例化 environment 的代码:
B = self.batch_spec.B envs = [self.EnvCls(**self.env_kwargs) for _ in range(B)]
这里会把 B 个 environment 对象构造出来。
我觉得作者起了一个非常不好的变量名:B。仔细看一下,self.batch_spec 这个变量是在 SerialSampler 的父类 BaseSampler 中赋值的:
self.batch_spec = BatchSpec(batch_T, batch_B)
而 BatchSpec 是一个父类为 namedtuple 的class:
class BatchSpec(namedtuple("BatchSpec", "T B")):
由 Python namedtuple 的性质可以知道,当用 BatchSpec(batch_T, batch_B) 构造一个对象的时候,该对象内部会生成两个成员变量 self.T 和 self.B,它们的值分别为 batch_T 和 batch_B。
这也是为什么 BatchSpec 类的 size() 函数可以这样写的原因:
@property def size(self): return self.T * self.B
从 BatchSpec 类的注释里可以知道,T 是时间步(time step)的概念,B 是 独立的trajectory分段的概念。
所谓时间步 T 是指agent与一个environment交互时,会按时间先后顺序不断地步进到下一个state,走一步即一个step。此值>=1。
所谓独立的trajectory分段,是指独立的trajectory的数量,即environment实例的数量。此值>=1。
说到这里,不难发现,environment按B的数量来实例化是有道理的。
进一步:
global_B = B * world_size
基于之前的文章里提到的 world_size 的概念,就可以看出来这里的 global_B 指的是多个“平行宇宙”下的所有 environment 的数量和。
文章来源:https://www.codelast.com/
▶▶ env_ranks的概念
env_ranks 又是一个“没有注释,又很难看懂是什么意思”的东西。
env_ranks = list(range(rank * B, (rank + 1) * B))
在 example_1 中,env_ranks 计算出来得到了一个 list:[0]。
这里得到的list是一个长度为 B 的list,B为environment的数量。你需要一层层挖下去才知道它是干嘛用的。
env_ranks 在 rlpyt/samplers/serial/sampler.py 的两个地方用到了:一个是 agent 的 initialize() 函数,另一个是 collector 类的构造函数,如下:
agent.initialize(envs[0].spaces, share_memory=False, global_B=global_B, env_ranks=env_ranks)
以及:
collector = self.CollectorCls( rank=0, envs=envs, samples_np=samples_np, batch_T=self.batch_spec.T, TrajInfoCls=self.TrajInfoCls, agent=agent, global_B=global_B, env_ranks=env_ranks, # Might get applied redundantly to agent. )
作者对第2种情况做了注释:“Might get applied redundantly to agent.” 这里的意思是:可能和agent(里面的逻辑)重复了。通过下面的分析可以知道,第1种情况和第2种情况最终会调用到同一个函数,因此它们确实是做了重复的工作。
文章来源:https://www.codelast.com/
分别看看这两个地方用 env_ranks 来做什么。
★ agent 的 initialize() 函数
在 DqnAgent 类的 initialize() 函数里,和 env_ranks 有关的代码,只有一个地方是有用的:
if env_ranks is not None: self.make_vec_eps(global_B, env_ranks)
这里调用的是 EpsilonGreedyAgentMixin 类的 make_vec_eps() 函数。巧合的是,这与下面的第2种情况相同,所以直接来分析第2种情况。
文章来源:https://www.codelast.com/
★ collector类的构造函数
example_1 使用的 collector 类是 CpuResetCollector,在这个类的代码中(rlpyt/samplers/parallel/cpu/collectors.py)并没有使用 env_ranks,但是在其父类 DecorrelatingStartCollector 的父类 BaseCollector(这句话很拗口,“父类的父类”)的 start_agent() 函数里面,我们就会看到使用了 env_ranks:
def start_agent(self): if getattr(self, "agent", None) is not None: # Not in GPU collectors. self.agent.collector_initialize( global_B=self.global_B, # Args used e.g. for vector epsilon greedy. env_ranks=self.env_ranks, ) self.agent.reset() self.agent.sample_mode(itr=0)
这里的 self.env_ranks 就是在 __init__() 里传入的,即 sampler 中传入的 env_ranks。
同时我们会看到,对 example_1 来说,if getattr(self, "agent", None) is not None 这个条件是满足的,因此这里会执行 agent.collector_initialize()。
example_1 的 agent 类是 DqnAgent,它有两个父类:BaseAgent 和 EpsilonGreedyAgentMixin,其中 BaseAgent 没有实现 collector_initialize() 函数:
def collector_initialize(self, global_B=1, env_ranks=None): """If need to initialize within CPU sampler (e.g. vector eps greedy)""" pass
而 EpsilonGreedyAgentMixin 类实现了 collector_initialize() 函数,所以最终调用的就是它(层层嵌套,已疯):
def collector_initialize(self, global_B=1, env_ranks=None): if env_ranks is not None: self.make_vec_eps(global_B, env_ranks)
所以这里的 make_vec_eps() 又是干了啥?
def make_vec_eps(self, global_B, env_ranks): if self.eps_final_min is not None and self.eps_final_min != self._eps_final_scalar: # vector epsilon. if self.alternating: # In FF case, sampler sets agent.alternating. assert global_B % 2 == 0 global_B = global_B // 2 # Env pairs will share epsilon. env_ranks = list(set([i // 2 for i in env_ranks])) self.eps_init = self._eps_init_scalar * torch.ones(len(env_ranks)) global_eps_final = torch.logspace( torch.log10(torch.tensor(self.eps_final_min)), torch.log10(torch.tensor(self._eps_final_scalar)), global_B) self.eps_final = global_eps_final[env_ranks] self.eps_sample = self.eps_init
可以看到这个函数就是为了计算 self.eps_final 以及 self.eps_sample 的值。
对 example_1 来说,self.eps_final_min 为 None,因此 make_vec_eps() 函数里最外层的 if 为 False,只有最后一句代码 self.eps_sample = self.eps_init 有实效,因此,env_ranks 在这里啥用也没有!
“你让我看了这么多字,结果就告诉我它没用?!” 真不好意思,事实就是这样。
文章来源:https://www.codelast.com/
但是,env_ranks 对 example_1 没用,在其他的场景下还是有用的啊,我讲了这么多废话,还是没有说清楚 env_ranks 到底是干嘛的。我说一下我的理解:对不同的environment实例,对它们用ε-greedy来选择action的时候,ε 可能是不同的。由于rlpyt在不同的并行模式下,会形成不同的“虚拟environment数量”的概念(比如在Alternating模式下,每两个environment构成的一个pair会共享相同的 ε 值,两个 environment 视为一个虚拟的environment),因此在各种场景下都要确定一个对应到实际场景下的、虚拟的environment数量,这就是env_ranks的含义。
再次强调,这只是我目前的理解,如果有一天我有了新的领悟,那我可能会回来修正这些表述。
文章来源:https://www.codelast.com/
▶▶ 收集训练数据发生的地方:obtain_samples() 函数
obtain_samples() 函数其实是调用了 collector 类的 collect_batch() 函数去收集训练数据:
def obtain_samples(self, itr): agent_inputs, traj_infos, completed_infos = self.collector.collect_batch( self.agent_inputs, self.traj_infos, itr) self.collector.reset_if_needed(agent_inputs) self.agent_inputs = agent_inputs self.traj_infos = traj_infos return self.samples_pyt, completed_infos
这里看上去有一点奇怪的是:收集到的数据 self.samples_pyt 并没有在 collect_batch() 函数中被更新,所以为什么每次收集一个batch的数据的时候,得到的 self.samples_pyt 都是最新的呢?
我觉得类似的现象在 rlpyt 中太多了,无形中增加了理解源码的难度。
文章来源:https://www.codelast.com/
要弄清楚这个问题,来看看 self.samples_pyt 是怎么定义的:在 initialize() 函数里有:
self.samples_pyt = samples_pyt
而 samples_pyt 是由另一个函数创建出来的:
samples_pyt, samples_np, examples = build_samples_buffer(agent, envs[0], self.batch_spec, bootstrap_value, agent_shared=False, env_shared=False, subprocess=False)
进这个函数里看一下就知道,samples_pyt 其实就是 samples_np 转成的对应 tensor 形式。而 PyTorch 和 NumPy array 是共享底层内存的,修改其中一个的数据会导致另一个也被修改,可以认为 samples_pyt 和 samples_np 在底层是对应到同一个东西(对这句话我持保留意见,目前还不能完全肯定这种说法正确,需要进一步理解 rlpyt 源码才能给出确定的答案,但姑且这么理解先)。
文章来源:https://www.codelast.com/
此外,samples_np 被传给了 collector 类的构造函数:
collector = self.CollectorCls( rank=0, envs=envs, samples_np=samples_np, batch_T=self.batch_spec.T, TrajInfoCls=self.TrajInfoCls, agent=agent, global_B=global_B, env_ranks=env_ranks, # Might get applied redundantly to agent. )
所以这就相当于把 samples_pyt 和 collector 类建立了联系。
再看一下 example_1 的 collector 类(即 CpuResetCollector)的 collect_batch() 函数,它在计算返回值的时候,果然有用到 samples_np:
agent_buf, env_buf = self.samples_np.agent, self.samples_np.env
经过这么一绕,obtain_samples() 函数中返回的 self.samples_pyt 就有意义了。
文章来源:https://www.codelast.com/
这一节就到这,且听下回分解。
文章来源:https://www.codelast.com/
➤➤ 版权声明 ➤➤
转载需注明出处:codelast.com
感谢关注我的微信公众号(微信扫一扫):