[原创] 强化学习框架 rlpyt 的 size mismatch 错误原因及解决办法

查看关于 rlpyt 的更多文章请点击这里

rlpyt 是BAIR(Berkeley Artificial Intelligence Research,伯克利人工智能研究所)开源的一个强化学习(RL)框架。我之前写了一篇它的简介。 
当你使用 rlpyt 来实现自己的强化学习程序时,可能会遇到类似于下面这样的错误:

RuntimeError: size mismatch, m1: [1 x 365], m2: [461 x 32] at /tmp/pip-req-build-_357f2zr/aten/src/TH/generic/THTensorMath.cpp:752

本文分析错误原因及解决办法。

▶▶ 错误原因
可能是由于observation space的期望shape与实际shape不匹配造成的。
observation space的期望shape定义在自己写的environment类中,例如:

self._observation_space = IntBox(
            low=0, high=1,
            shape=461,
            dtype="int32")

里面的 shape 必须与输入network的特征向量的长度相同。

实际的shape,由自定义的environment类的 get_obs() 函数所决定:

def get_obs(self) -> np.ndarray:
    observation: np.ndarray = xxx  # 此处需要自己实现
    return observation

文章来源:https://www.codelast.com/
▶▶ 解决办法
当出现上面的错误时,以串行模式断点调试上面的程序,在上面两处地方都加上断点,看看期望的shape以及实际的observation shape是不是不相等,如果不相等,就要去调查为什么实际的shape是错的了。解决这个问题以后,上面的问题就迎刃而解。
文章来源:https://www.codelast.com/
➤➤ 版权声明 ➤➤ 
转载需注明出处:codelast.com 
感谢关注我的微信公众号(微信扫一扫):

wechat qrcode of codelast

发表评论