小编典典

在 PyTorch 中保存训练模型的最佳方法是什么?

all

我一直在寻找在 PyTorch 中保存训练模型的替代方法。到目前为止,我找到了两种选择。

  1. torch.save()保存模型和torch.load()加载模型。
  2. model.state_dict()保存训练好的模型,model.load_state_dict()加载保存的模型。

我遇到过这个讨论,其中建议使用方法 2 而不是方法 1。

我的问题是,为什么首选第二种方法?仅仅是因为torch.nn模块具有这两个功能,我们被鼓励使用它们吗?


阅读 104

收藏
2022-04-20

共1个答案

小编典典

在他们的 github repo
上找到了这个页面,我将在这里复制粘贴内容。


保存模型的推荐方法

序列化和恢复模型有两种主要方法。

第一个(推荐)只保存和加载模型参数:

torch.save(the_model.state_dict(), PATH)

然后后来:

the_model = TheModelClass(*args, **kwargs)
the_model.load_state_dict(torch.load(PATH))

第二个保存并加载整个模型:

torch.save(the_model, PATH)

然后后来:

the_model = torch.load(PATH)

但是在这种情况下,序列化的数据绑定到特定的类和使用的确切目录结构,因此在其他项目中使用时,或者经过一些严重的重构后,它可能会以各种方式中断。


更新 :另请参阅PyTorch
教程中的保存和加载模型部分

2022-04-20