查看关于 rlpyt 的更多文章请点击这里。
rlpyt 是BAIR(Berkeley Artificial Intelligence Research,伯克利人工智能研究所)开源的一个强化学习(RL)框架。我之前写了一篇它的简介。
当你使用 rlpyt 来实现自己的强化学习程序时,可能会遇到类似于下面这样的错误:
RuntimeError: size mismatch, m1: [1 x 365], m2: [461 x 32] at /tmp/pip-req-build-_357f2zr/aten/src/TH/generic/THTensorMath.cpp:752
本文分析错误原因及解决办法。
▶▶ 错误原因
可能是由于observation space的期望shape与实际shape不匹配造成的。
observation space的期望shape定义在自己写的environment类中,例如:
self._observation_space = IntBox(
low=0, high=1,
shape=461,
dtype="int32")
里面的 shape 必须与输入network的特征向量的长度相同。
实际的shape,由自定义的environment类的 get_obs() 函数所决定:
def get_obs(self) -> np.ndarray:
observation: np.ndarray = xxx # 此处需要自己实现
return observation
文章来源:https://www.codelast.com/
▶▶ 解决办法
当出现上面的错误时,以串行模式断点调试上面的程序,在上面两处地方都加上断点,看看期望的shape以及实际的observation shape是不是不相等,如果不相等,就要去调查为什么实际的shape是错的了。解决这个问题以后,上面的问题就迎刃而解。
文章来源:https://www.codelast.com/
➤➤ 版权声明 ➤➤
转载需注明出处:codelast.com
感谢关注我的微信公众号(微信扫一扫):