查看关于 rlpyt 的更多文章请点击这里。
rlpyt 是BAIR(Berkeley Artificial Intelligence Research,伯克利人工智能研究所)开源的一个强化学习(RL)框架。我之前写了一篇它的简介。
本文描述了在 rlpyt 框架下,如何使用一个预训练过的(pre-trained)model作为起点,来训练自己的RL模型的过程。
▶▶ 什么是预训练模型
引用一篇文章:
简单来说,预训练模型(pre-trained model)是前人为了解决类似问题所创造出来的模型。你在解决问题的时候,不用从零开始训练一个新模型,可以从在类似问题中训练过的模型入手。比如说,如果你想做一辆自动驾驶汽车,可以花数年时间从零开始构建一个性能优良的图像识别算法,也可以从Google在ImageNet数据集上训练得到的inception model(一个预训练模型)起步,来识别图像。一个预训练模型可能对于你的应用中并不是100%的准确对口,但是它可以为你节省大量功夫。
文章来源:https://www.codelast.com/
▶▶ rlpyt 对预训练模型的支持
以使用 DQN 算法的 example_1 为例,class DQN(RlAlgorithm) 的 __init__() 函数有一个 initial_optim_state_dict 参数:
initial_optim_state_dict=None,
另外,AtariDqnAgent 类的其中一个父类:DqnAgent,它又有一个父类 BaseAgent,在 __init__() 初始化的时候也有一个 initial_model_state_dict 参数:
def __init__(self, ModelCls=None, model_kwargs=None, initial_model_state_dict=None):
这两个地方,就是当你使用预训练模型的时候需要传入的参数。
但为什么会有两个参数?它们有什么区别?
✔ 前一个是Optimizer(优化器,例如 torch.optim.Adam)的 state_dict,其包含的参数有 learning rate 等。
✔ 后一个是model的 state_dict,其包含的参数有 model 的 weight、bias 等。
直观点,来个图(图片可放大):
从图中可以清楚地看到model里存储的数据,optimizer_state_dict 就是 Optimizer 的 state_dict,agent_state_dict 就是model的 state_dict。
文章来源:https://www.codelast.com/
▶▶ 代码实操:加载预训练模型
首先我们要有一个预训练模型文件,因此,我们先把没有修改过代码的 example_1 运行一段时间,生成一个 params.pkl 模型文件,假设此文件路径为:/home/codelast/rlpyt/data/local/20191111/example_1/run_0/params.pkl
现在修改 example_1.py,可以加载预训练模型了:
# 加载预训练模型 model_loaded = torch.load('/home/codelast/rlpyt/data/local/20191111/example_1/run_0/params.pkl') optimizer_state_dict = model_loaded['optimizer_state_dict'] agent_state_dict = model_loaded['agent_state_dict'] algo = DQN(min_steps_learn=1e3, initial_optim_state_dict=optimizer_state_dict) agent = AtariDqnAgent(initial_model_state_dict=agent_state_dict['model'])
其他代码无需修改,就这么简单!
再重新运行修改过的example,现在就已经是在pre-trained model的基础上继续进行的训练了。
文章来源:https://www.codelast.com/
➤➤ 版权声明 ➤➤
转载需注明出处:codelast.com
感谢关注我的微信公众号(微信扫一扫):