[原创] 强化学习框架 rlpyt 源码分析:(1) 随处可见的Python可变参数

查看关于 rlpyt 的更多文章请点击这里

rlpyt 是BAIR(Berkeley Artificial Intelligence Research,伯克利人工智能研究所)开源的一个强化学习(RL)框架。我之前写了一篇它的简介。 如果你想用这个框架来开发自己的强化学习程序(尤其是那些不属于Atari游戏领域的强化学习程序),那么需要对它的源码有一定的了解。本文尝试从 rlpyt 自带的一个实例来分析它的部分源码,希望能帮助到一小部分人。
要先声明一下:rlpyt 的源码比较复杂,想要充分理解全部模块需要下很大的功夫,本系列“源码分析”文章,并没有把 rlpyt 的源码全部分析一遍,而只是分析了它的“冰山一角”,主要目的是让读者能了解它的基本结构及基本运作方式。

▶▶ 切入点
为了能看懂 rlpyt 的源码,最正确的做法就是先找一个可以运行的例子,再从它下手开始分析。下面就从 examples/example_1.py 这个 rlpyt 自带的例子开始干。
在分析这个例子之前,你可以把它运行一下,看看程序的输出,从而有一个感性的印象。
文章来源:https://www.codelast.com/
▶▶ 为什么有很多看似未定义的类成员变量(self.xxx)?
阅读 rlpyt 代码的时候,你会发现在很多class里面,都会有很多“cannot find declaration”的类成员变量(self.xxx),无论你用任何IDE都不能点击跳转,在IDE中不能导航,这就在一定程度上增加了理解难度。例如下面这个例子:
rlpyt cannot find declaration
事实上,这些看似未定义的变量,是通过可变参数的形式定义的,而不是在 __init__() 中显式定义的。
文章来源:https://www.codelast.com/
拿上面这个例子来说,它的 __init__() 函数如下:

def __init__(
    self,
    algo,
    agent,
    sampler,
    n_steps,
    seed=None,
    affinity=None,
    log_interval_steps=1e5,
):
    n_steps = int(n_steps)
    log_interval_steps = int(log_interval_steps)
    affinity = dict() if affinity is None else affinity
    save__init__args(locals())

显然 __init__() 里面并没有定义 self.affinity = affinity,但是后面的函数里使用 self.affinity Python为什么不报错?这是因为 save__init__args(locals()) 这个“幕后英雄”帮你做了这件事。
文章来源:https://www.codelast.com/
▶▶ Python的 local() 函数

locals() 函数会以字典(dict)类型返回当前位置的全部局部变量

举个例子:

class A(object):
    def __init__(self):
        var1 = 5
        var2 = 3
        local_vars = locals()
        print(local_vars)


a = A()

输出:

{'var2': 3, 'var1': 5, 'self': <__main__.A object at 0x7f337bb3b6d8>}

在 __init__() 中,var1 和 var2 是两个局部变量,而 locals() 函数把所有局部变量的变量名、变量值都获取到,并且保存在了一个dict里,因此最后print出来的内容里包含了它们。
文章来源:https://www.codelast.com/
▶▶ save__init__args() 函数的功能
这个函数只会在 __init_() 里使用,用于把可变参数保存到对象属性中:

def save__init__args(values, underscore=False, overwrite=False, subclass_only=False):
    """
    Use in __init__() only; assign all args/kwargs to instance attributes.
    To maintain precedence of args provided to subclasses, call this in the
    subclass before super().__init__() if save__init__args() also appears in
    base class, or use overwrite=True.  With subclass_only==True, only args/kwargs
    listed in current subclass apply.
    """
    prefix = "_" if underscore else ""
    self = values['self']
    args = list()
    Classes = type(self).mro()
    if subclass_only:
        Classes = Classes[:1]
    for Cls in Classes:  # class inheritances
        if '__init__' in vars(Cls):
            args += getfullargspec(Cls.__init__).args[1:]
    for arg in args:
        attr = prefix + arg
        if arg in values and (not hasattr(self, attr) or overwrite):
            setattr(self, attr, values[arg])

文章来源:https://www.codelast.com/
来看一下关键的几句代码:

self = values['self']

在类的__init__()里调用locals()函数时,会把"self"这个item放进dict里面,因此values['self']取到的就是那个类对象。
文章来源:https://www.codelast.com/

Classes = type(self).mro()

type(self)返回对象的类型,mro()函数返回该类型的方法解析顺序(MRO,Method Resolution Order)列表。Python是一种支持多重继承的语言,而 rlpyt 大量使用了Python的多重继承。使用类的mro()函数,Python可以计算出一个MRO列表,它代表了类继承的顺序。
文章来源:https://www.codelast.com/

if subclass_only:
    Classes = Classes[:1]

当用户指定了 subclass_only 为 True 时,Classes = Classes[:1] 就是取MRO列表的第一个元素,由于在MRO列表中,子类永远在父类前面,因此这里取第一个元素就是子类。
文章来源:https://www.codelast.com/

for Cls in Classes:  # class inheritances
    if '__init__' in vars(Cls):
        args += getfullargspec(Cls.__init__).args[1:]

这里对MRO列表中的所有类进行遍历。vars() 函数返回对象object的属性和属性值的字典对象,对定义了 __init()__ 函数的类来说,if 判断的结果为 True,就会执行下面的那条语句,对没有定义 __init__() 函数的类来说则不会。那么 args += ... 那里又是干了什么呢?
getfullargspec() 函数可以获取一个可调用对象(callable object)的名称和默认值。把 Cls.__init__ 作为 getfullargspec() 的参数,就是把 Cls 这个类的 __init__() 函数里的所有参数信息都取出来放到一个dict里,举个例子,下面的代码:

from inspect import getfullargspec


class A(object):
    def __init__(self, arg1, arg2='test_value', arg3=999):
        var1 = 5
        var2 = 3
        local_vars = locals()
        print(local_vars)


print(getfullargspec(A.__init__))
print(getfullargspec(A.__init__).args[1:])

输出:

FullArgSpec(args=['self', 'arg1', 'arg2', 'arg3'], varargs=None, varkw=None, defaults=('test_value', 999), kwonlyargs=[], kwonlydefaults=None, annotations={})
['arg1', 'arg2', 'arg3']

可见,getfullargspec(A.__init__).args[1:] 的作用是拿到了 class A 的所有参数名:arg1, arg2, arg3。
所以 for Cls in Classes: 那个循环的功能,就是把某个class的所有父类(同时也包括它自己)的参数列表给一鼓脑地取出来保存到 args 这个变量里。
文章来源:https://www.codelast.com/

for arg in args:
    attr = prefix + arg
    if arg in values and (not hasattr(self, attr) or overwrite):
        setattr(self, attr, values[arg])

这里是把每一个参数,都set到 self 这个类对应的对象的属性里,从而可以用 self.xxx 取到参数值。
总结一下 save__init__args() 这个函数:为什么要这么做?仔细看 rlpyt 的代码会发现,它使用了非常多的Python多重继续特性,而且它通常不会把所有参数都显式地在继承关系之间传递,这样可能是做到了简化代码的目的。正因为这样,才需要在 save__init__args() 这个函数里面,把所有参数都全部获取到,再保存到对象的属性里——要不然很多看似没有定义过的 self.xxx,就真的找不到了。
这种做法很 tricky,作者这样做有他的想法,从上面的分析我们也看到,这样做确实也有意义,如果不能理解就再读三遍。
文章来源:https://www.codelast.com/
▶▶ save__init__args()配合可变参数,达到神奇目的
save__init__args() 函数配合Python可变参数使用,使得 rlpyt 的大量参数可以在多个class之间方便传递。比如说 MinibatchRl 这个类,其 __init__() 方法里使用了Python可变参数 kwargs:

def __init__(self, log_traj_window=100, **kwargs):
    super().__init__(**kwargs)

这里的可变参数 **kwargs,是在调用者那里传入的,例如:

runner = MinibatchRl(
    algo=algo,
    agent=agent,
    sampler=sampler,
    n_steps=50e6,
    log_interval_steps=1e5,
    affinity=affinity,
)

这些参数被组织成了字典(dict)的形式放到 kwargs 里,再通过 MinibatchRl 的父类 MinibatchRlBase 的 __init__() 方法里的 save__init__args(locals()),全部保存到了类对象的属性中,所以在 MinibatchRl 的其他函数里面,就可以愉快地使用 self.agent,self.sampler 之类的对象了,其实它们是从定义 MinibatchRl 类的对象开始,一级级把参数传递进去的。
文章来源:https://www.codelast.com/
这一节就到这,且听下回分解。
文章来源:https://www.codelast.com/
➤➤ 版权声明 ➤➤ 
转载需注明出处:codelast.com 
感谢关注我的微信公众号(微信扫一扫):

wechat qrcode of codelast

发表评论