[原创] 强化学习框架 rlpyt 源码分析:(4) 收集训练数据的sampler类

查看关于 rlpyt 的更多文章请点击这里

rlpyt 是BAIR(Berkeley Artificial Intelligence Research,伯克利人工智能研究所)开源的一个强化学习(RL)框架。我之前写了一篇它的简介。 如果你想用这个框架来开发自己的强化学习程序(尤其是那些不属于Atari游戏领域的强化学习程序),那么需要对它的源码有一定的了解。本文尝试从 rlpyt 自带的一个实例来分析它的部分源码,希望能帮助到一小部分人。

▶▶ sampler的主要功能
训练强化学习模型需要训练数据,收集训练数据的工作就是由sampler类做的。
收集训练数据,就需要在environment中步进,因此environment的实例化工作也在sampler中完成。

在很多强化学习教程中,收集数据也叫采样数据,这也是sampler这个名字的由来。但需要注意的是,真正去做“收集数据”这个工作的,是一种叫做collector的class。sampler会在 initialize() 的时候,把 collector 对象也初始化。
所以 sampler 可以看做是在 collector 外面又包了一层。
文章来源:https://www.codelast.com/
▶▶ BatchSpec里的 T 和 B 的概念
在 SerialSampler 的 initialize() 函数里,会看到实例化 environment 的代码:

B = self.batch_spec.B
envs = [self.EnvCls(**self.env_kwargs) for _ in range(B)]

这里会把 B 个 environment 对象构造出来。
我觉得作者起了一个非常不好的变量名:B。仔细看一下,self.batch_spec 这个变量是在 SerialSampler 的父类 BaseSampler 中赋值的:

self.batch_spec = BatchSpec(batch_T, batch_B)

而 BatchSpec 是一个父类为 namedtuple 的class:

class BatchSpec(namedtuple("BatchSpec", "T B")):

由 Python namedtuple 的性质可以知道,当用 BatchSpec(batch_T, batch_B) 构造一个对象的时候,该对象内部会生成两个成员变量 self.T 和 self.B,它们的值分别为 batch_T 和 batch_B。
这也是为什么 BatchSpec 类的 size() 函数可以这样写的原因:

@property
def size(self):
    return self.T * self.B

从 BatchSpec 类的注释里可以知道,T 是时间步(time step)的概念,B 是 独立的trajectory分段的概念。
所谓时间步 T 是指agent与一个environment交互时,会按时间先后顺序不断地步进到下一个state,走一步即一个step。此值>=1。
所谓独立的trajectory分段,是指独立的trajectory的数量,即environment实例的数量。此值>=1。

说到这里,不难发现,environment按B的数量来实例化是有道理的。
进一步:

global_B = B * world_size

基于之前的文章里提到的 world_size 的概念,就可以看出来这里的 global_B 指的是多个“平行宇宙”下的所有 environment 的数量和。
文章来源:https://www.codelast.com/
▶▶ env_ranks的概念
env_ranks 又是一个“没有注释,又很难看懂是什么意思”的东西。

env_ranks = list(range(rank * B, (rank + 1) * B))

在 example_1 中,env_ranks 计算出来得到了一个 list:[0]。
这里得到的list是一个长度为 B 的list,B为environment的数量。你需要一层层挖下去才知道它是干嘛用的。

env_ranks 在 rlpyt/samplers/serial/sampler.py 的两个地方用到了:一个是 agent 的 initialize() 函数,另一个是 collector 类的构造函数,如下:

agent.initialize(envs[0].spaces, share_memory=False,
                 global_B=global_B, env_ranks=env_ranks)

以及:

collector = self.CollectorCls(
    rank=0,
    envs=envs,
    samples_np=samples_np,
    batch_T=self.batch_spec.T,
    TrajInfoCls=self.TrajInfoCls,
    agent=agent,
    global_B=global_B,
    env_ranks=env_ranks,  # Might get applied redundantly to agent.
)

作者对第2种情况做了注释:“Might get applied redundantly to agent.” 这里的意思是:可能和agent(里面的逻辑)重复了。通过下面的分析可以知道,第1种情况和第2种情况最终会调用到同一个函数,因此它们确实是做了重复的工作。
文章来源:https://www.codelast.com/
分别看看这两个地方用 env_ranks 来做什么。
agent 的 initialize() 函数
在 DqnAgent 类的 initialize() 函数里,和 env_ranks 有关的代码,只有一个地方是有用的:

if env_ranks is not None:
    self.make_vec_eps(global_B, env_ranks)

这里调用的是 EpsilonGreedyAgentMixin 类的 make_vec_eps() 函数。巧合的是,这与下面的第2种情况相同,所以直接来分析第2种情况。
文章来源:https://www.codelast.com/
 collector类的构造函数
example_1 使用的 collector 类是 CpuResetCollector,在这个类的代码中(rlpyt/samplers/parallel/cpu/collectors.py)并没有使用 env_ranks,但是在其父类 DecorrelatingStartCollector 的父类 BaseCollector(这句话很拗口,“父类的父类”)的 start_agent() 函数里面,我们就会看到使用了 env_ranks:

def start_agent(self):
    if getattr(self, "agent", None) is not None:  # Not in GPU collectors.
        self.agent.collector_initialize(
            global_B=self.global_B,  # Args used e.g. for vector epsilon greedy.
            env_ranks=self.env_ranks,
        )
        self.agent.reset()
        self.agent.sample_mode(itr=0)

这里的 self.env_ranks 就是在 __init__() 里传入的,即 sampler 中传入的 env_ranks。
同时我们会看到,对 example_1 来说,if getattr(self, "agent", None) is not None 这个条件是满足的,因此这里会执行 agent.collector_initialize()。
example_1 的 agent 类是 DqnAgent,它有两个父类:BaseAgent 和 EpsilonGreedyAgentMixin,其中 BaseAgent 没有实现 collector_initialize() 函数:

def collector_initialize(self, global_B=1, env_ranks=None):
    """If need to initialize within CPU sampler (e.g. vector eps greedy)"""
    pass

而 EpsilonGreedyAgentMixin 类实现了 collector_initialize() 函数,所以最终调用的就是它(层层嵌套,已疯):

def collector_initialize(self, global_B=1, env_ranks=None):
    if env_ranks is not None:
        self.make_vec_eps(global_B, env_ranks)

所以这里的 make_vec_eps() 又是干了啥?

def make_vec_eps(self, global_B, env_ranks):
    if self.eps_final_min is not None and self.eps_final_min != self._eps_final_scalar:  # vector epsilon.
        if self.alternating:  # In FF case, sampler sets agent.alternating.
            assert global_B % 2 == 0
            global_B = global_B // 2  # Env pairs will share epsilon.
            env_ranks = list(set([i // 2 for i in env_ranks]))
        self.eps_init = self._eps_init_scalar * torch.ones(len(env_ranks))
        global_eps_final = torch.logspace(
            torch.log10(torch.tensor(self.eps_final_min)),
            torch.log10(torch.tensor(self._eps_final_scalar)),
            global_B)
        self.eps_final = global_eps_final[env_ranks]
    self.eps_sample = self.eps_init

可以看到这个函数就是为了计算 self.eps_final 以及 self.eps_sample 的值。
对 example_1 来说,self.eps_final_min 为 None,因此 make_vec_eps() 函数里最外层的 if 为 False,只有最后一句代码 self.eps_sample = self.eps_init 有实效,因此,env_ranks 在这里啥用也没有
“你让我看了这么多字,结果就告诉我它没用?!” 真不好意思,事实就是这样。
文章来源:https://www.codelast.com/
但是,env_ranks 对 example_1 没用,在其他的场景下还是有用的啊,我讲了这么多废话,还是没有说清楚 env_ranks 到底是干嘛的。我说一下我的理解对不同的environment实例,对它们用ε-greedy来选择action的时候,ε 可能是不同的。由于rlpyt在不同的并行模式下,会形成不同的“虚拟environment数量”的概念(比如在Alternating模式下,每两个environment构成的一个pair会共享相同的 ε 值,两个 environment 视为一个虚拟的environment),因此在各种场景下都要确定一个对应到实际场景下的、虚拟的environment数量,这就是env_ranks的含义。
再次强调,这只是我目前的理解,如果有一天我有了新的领悟,那我可能会回来修正这些表述。
文章来源:https://www.codelast.com/
▶▶ 收集训练数据发生的地方:obtain_samples() 函数
obtain_samples() 函数其实是调用了 collector 类的 collect_batch() 函数去收集训练数据:

def obtain_samples(self, itr):
    agent_inputs, traj_infos, completed_infos = self.collector.collect_batch(
        self.agent_inputs, self.traj_infos, itr)
    self.collector.reset_if_needed(agent_inputs)
    self.agent_inputs = agent_inputs
    self.traj_infos = traj_infos
    return self.samples_pyt, completed_infos

这里看上去有一点奇怪的是:收集到的数据 self.samples_pyt 并没有在 collect_batch() 函数中被更新,所以为什么每次收集一个batch的数据的时候,得到的 self.samples_pyt 都是最新的呢?
我觉得类似的现象在 rlpyt 中太多了,无形中增加了理解源码的难度。
文章来源:https://www.codelast.com/
要弄清楚这个问题,来看看 self.samples_pyt 是怎么定义的:在 initialize() 函数里有:

self.samples_pyt = samples_pyt

而 samples_pyt 是由另一个函数创建出来的:

samples_pyt, samples_np, examples = build_samples_buffer(agent, envs[0],
    self.batch_spec, bootstrap_value, agent_shared=False,
    env_shared=False, subprocess=False)

进这个函数里看一下就知道,samples_pyt 其实就是 samples_np 转成的对应 tensor 形式。而 PyTorch 和 NumPy array 是共享底层内存的,修改其中一个的数据会导致另一个也被修改,可以认为 samples_pyt 和 samples_np 在底层是对应到同一个东西对这句话我持保留意见,目前还不能完全肯定这种说法正确,需要进一步理解 rlpyt 源码才能给出确定的答案,但姑且这么理解先)。
文章来源:https://www.codelast.com/
此外,samples_np 被传给了 collector 类的构造函数:

collector = self.CollectorCls(
    rank=0,
    envs=envs,
    samples_np=samples_np,
    batch_T=self.batch_spec.T,
    TrajInfoCls=self.TrajInfoCls,
    agent=agent,
    global_B=global_B,
    env_ranks=env_ranks,  # Might get applied redundantly to agent.
)

所以这就相当于把 samples_pyt 和 collector 类建立了联系。
再看一下 example_1 的 collector 类(即 CpuResetCollector)的 collect_batch() 函数,它在计算返回值的时候,果然有用到 samples_np:

agent_buf, env_buf = self.samples_np.agent, self.samples_np.env

经过这么一绕,obtain_samples() 函数中返回的 self.samples_pyt 就有意义了。
文章来源:https://www.codelast.com/
这一节就到这,且听下回分解。
文章来源:https://www.codelast.com/
➤➤ 版权声明 ➤➤ 
转载需注明出处:codelast.com 
感谢关注我的微信公众号(微信扫一扫):

wechat qrcode of codelast

发表评论