查看关于 rlpyt 的更多文章请点击这里。
rlpyt 是BAIR(Berkeley Artificial Intelligence Research,伯克利人工智能研究所)开源的一个强化学习(RL)框架。我之前写了一篇它的简介。 如果你想用这个框架来开发自己的强化学习程序(尤其是那些不属于Atari游戏领域的强化学习程序),那么需要对它的源码有一定的了解。本文尝试从 rlpyt 自带的一个实例来分析它的部分源码,希望能帮助到一小部分人。
要先声明一下:rlpyt 的源码比较复杂,想要充分理解全部模块需要下很大的功夫,本系列“源码分析”文章,并没有把 rlpyt 的源码全部分析一遍,而只是分析了它的“冰山一角”,主要目的是让读者能了解它的基本结构及基本运作方式。
▶▶ 切入点
为了能看懂 rlpyt 的源码,最正确的做法就是先找一个可以运行的例子,再从它下手开始分析。下面就从 examples/example_1.py 这个 rlpyt 自带的例子开始干。
在分析这个例子之前,你可以把它运行一下,看看程序的输出,从而有一个感性的印象。
文章来源:https://www.codelast.com/
▶▶ 为什么有很多看似未定义的类成员变量(self.xxx)?
阅读 rlpyt 代码的时候,你会发现在很多class里面,都会有很多“cannot find declaration”的类成员变量(self.xxx),无论你用任何IDE都不能点击跳转,在IDE中不能导航,这就在一定程度上增加了理解难度。例如下面这个例子:
事实上,这些看似未定义的变量,是通过可变参数的形式定义的,而不是在 __init__() 中显式定义的。
文章来源:https://www.codelast.com/
拿上面这个例子来说,它的 __init__() 函数如下:
def __init__( self, algo, agent, sampler, n_steps, seed=None, affinity=None, log_interval_steps=1e5, ): n_steps = int(n_steps) log_interval_steps = int(log_interval_steps) affinity = dict() if affinity is None else affinity save__init__args(locals())
显然 __init__() 里面并没有定义 self.affinity = affinity,但是后面的函数里使用 self.affinity Python为什么不报错?这是因为 save__init__args(locals()) 这个“幕后英雄”帮你做了这件事。
文章来源:https://www.codelast.com/
▶▶ Python的 local() 函数
locals() 函数会以字典(dict)类型返回当前位置的全部局部变量。
举个例子:
class A(object): def __init__(self): var1 = 5 var2 = 3 local_vars = locals() print(local_vars) a = A()
输出:
{'var2': 3, 'var1': 5, 'self': <__main__.A object at 0x7f337bb3b6d8>}
在 __init__() 中,var1 和 var2 是两个局部变量,而 locals() 函数把所有局部变量的变量名、变量值都获取到,并且保存在了一个dict里,因此最后print出来的内容里包含了它们。
文章来源:https://www.codelast.com/
▶▶ save__init__args() 函数的功能
这个函数只会在 __init_() 里使用,用于把可变参数保存到对象属性中:
def save__init__args(values, underscore=False, overwrite=False, subclass_only=False): """ Use in __init__() only; assign all args/kwargs to instance attributes. To maintain precedence of args provided to subclasses, call this in the subclass before super().__init__() if save__init__args() also appears in base class, or use overwrite=True. With subclass_only==True, only args/kwargs listed in current subclass apply. """ prefix = "_" if underscore else "" self = values['self'] args = list() Classes = type(self).mro() if subclass_only: Classes = Classes[:1] for Cls in Classes: # class inheritances if '__init__' in vars(Cls): args += getfullargspec(Cls.__init__).args[1:] for arg in args: attr = prefix + arg if arg in values and (not hasattr(self, attr) or overwrite): setattr(self, attr, values[arg])
文章来源:https://www.codelast.com/
来看一下关键的几句代码:
self = values['self']
在类的__init__()里调用locals()函数时,会把"self"这个item放进dict里面,因此values['self']取到的就是那个类对象。
文章来源:https://www.codelast.com/
Classes = type(self).mro()
type(self)返回对象的类型,mro()函数返回该类型的方法解析顺序(MRO,Method Resolution Order)列表。Python是一种支持多重继承的语言,而 rlpyt 大量使用了Python的多重继承。使用类的mro()函数,Python可以计算出一个MRO列表,它代表了类继承的顺序。
文章来源:https://www.codelast.com/
if subclass_only: Classes = Classes[:1]
当用户指定了 subclass_only 为 True 时,Classes = Classes[:1] 就是取MRO列表的第一个元素,由于在MRO列表中,子类永远在父类前面,因此这里取第一个元素就是子类。
文章来源:https://www.codelast.com/
for Cls in Classes: # class inheritances if '__init__' in vars(Cls): args += getfullargspec(Cls.__init__).args[1:]
这里对MRO列表中的所有类进行遍历。vars() 函数返回对象object的属性和属性值的字典对象,对定义了 __init()__ 函数的类来说,if 判断的结果为 True,就会执行下面的那条语句,对没有定义 __init__() 函数的类来说则不会。那么 args += ... 那里又是干了什么呢?
getfullargspec() 函数可以获取一个可调用对象(callable object)的名称和默认值。把 Cls.__init__ 作为 getfullargspec() 的参数,就是把 Cls 这个类的 __init__() 函数里的所有参数信息都取出来放到一个dict里,举个例子,下面的代码:
from inspect import getfullargspec class A(object): def __init__(self, arg1, arg2='test_value', arg3=999): var1 = 5 var2 = 3 local_vars = locals() print(local_vars) print(getfullargspec(A.__init__)) print(getfullargspec(A.__init__).args[1:])
输出:
FullArgSpec(args=['self', 'arg1', 'arg2', 'arg3'], varargs=None, varkw=None, defaults=('test_value', 999), kwonlyargs=[], kwonlydefaults=None, annotations={})['arg1', 'arg2', 'arg3']
可见,getfullargspec(A.__init__).args[1:] 的作用是拿到了 class A 的所有参数名:arg1, arg2, arg3。
所以 for Cls in Classes: 那个循环的功能,就是把某个class的所有父类(同时也包括它自己)的参数列表给一鼓脑地取出来保存到 args 这个变量里。
文章来源:https://www.codelast.com/
for arg in args: attr = prefix + arg if arg in values and (not hasattr(self, attr) or overwrite): setattr(self, attr, values[arg])
这里是把每一个参数,都set到 self 这个类对应的对象的属性里,从而可以用 self.xxx 取到参数值。
总结一下 save__init__args() 这个函数:为什么要这么做?仔细看 rlpyt 的代码会发现,它使用了非常多的Python多重继续特性,而且它通常不会把所有参数都显式地在继承关系之间传递,这样可能是做到了简化代码的目的。正因为这样,才需要在 save__init__args() 这个函数里面,把所有参数都全部获取到,再保存到对象的属性里——要不然很多看似没有定义过的 self.xxx,就真的找不到了。
这种做法很 tricky,作者这样做有他的想法,从上面的分析我们也看到,这样做确实也有意义,如果不能理解就再读三遍。
文章来源:https://www.codelast.com/
▶▶ save__init__args()配合可变参数,达到神奇目的
save__init__args() 函数配合Python可变参数使用,使得 rlpyt 的大量参数可以在多个class之间方便传递。比如说 MinibatchRl 这个类,其 __init__() 方法里使用了Python可变参数 kwargs:
def __init__(self, log_traj_window=100, **kwargs): super().__init__(**kwargs)
这里的可变参数 **kwargs,是在调用者那里传入的,例如:
runner = MinibatchRl( algo=algo, agent=agent, sampler=sampler, n_steps=50e6, log_interval_steps=1e5, affinity=affinity, )
这些参数被组织成了字典(dict)的形式放到 kwargs 里,再通过 MinibatchRl 的父类 MinibatchRlBase 的 __init__() 方法里的 save__init__args(locals()),全部保存到了类对象的属性中,所以在 MinibatchRl 的其他函数里面,就可以愉快地使用 self.agent,self.sampler 之类的对象了,其实它们是从定义 MinibatchRl 类的对象开始,一级级把参数传递进去的。
文章来源:https://www.codelast.com/
这一节就到这,且听下回分解。
文章来源:https://www.codelast.com/
➤➤ 版权声明 ➤➤
转载需注明出处:codelast.com
感谢关注我的微信公众号(微信扫一扫):