根据PyTorch文档,在把PyTorch模型保存成文件的时候有两种方法,第一种是推荐的:
torch.save(the_model.state_dict(), PATH)
对应地,加载模型这样做:
the_model = TheModelClass(*args, **kwargs)
the_model.load_state_dict(torch.load(PATH))
另一种方法是不推荐的:
torch.save(the_model, PATH)
对应地,加载模型这样做:
the_model = torch.load(PATH)
文章来源:https://www.codelast.com/
这两者的区别:第1种方法只保存了模型的参数,而第2种方法保存了整个模型(结构+参数),所以第2种方法保存出来的文件体积会比第1种方法大。
使用第2种方法的话,序列化的数据将绑定到所使用的特定类和确切的目录结构,因此在其他项目中使用或经过一些大的重构后,它可能会失效。
文章来源:https://www.codelast.com/
➤➤ 版权声明 ➤➤
转载需注明出处:codelast.com
感谢关注我的微信公众号(微信扫一扫):