我一直在寻找在 PyTorch 中保存训练模型的替代方法。到目前为止,我找到了两种选择。
我遇到过这个讨论,其中建议使用方法 2 而不是方法 1。
我的问题是,为什么首选第二种方法?仅仅是因为torch.nn模块具有这两个功能,我们被鼓励使用它们吗?
在他们的 github repo 上找到了这个页面,我将在这里复制粘贴内容。
序列化和恢复模型有两种主要方法。
第一个(推荐)只保存和加载模型参数:
torch.save(the_model.state_dict(), PATH)
然后后来:
the_model = TheModelClass(*args, **kwargs) the_model.load_state_dict(torch.load(PATH))
第二个保存并加载整个模型:
torch.save(the_model, PATH)
the_model = torch.load(PATH)
但是在这种情况下,序列化的数据绑定到特定的类和使用的确切目录结构,因此在其他项目中使用时,或者经过一些严重的重构后,它可能会以各种方式中断。
更新 :另请参阅PyTorch 教程中的保存和加载模型部分