小编典典

model.train() 在 PyTorch 中做了什么?

all

它来电forward()nn.Module?我认为当我们调用模型时,forward正在使用方法。为什么我们需要指定 train()?


阅读 3524

收藏
2022-09-02

共1个答案

小编典典

model.train()告诉您的模型您正在训练模型。这有助于通知诸如 Dropout 和 BatchNorm
等层,这些层旨在在训练和评估期间表现不同。例如,在训练模式下,BatchNorm 更新每个新批次的移动平均值;而对于评估模式,这些更新被冻结。

更多细节:
model.train()将模式设置为训练(参见源代码)。您可以致电model.eval()model.train(mode=False)告知您正在测试。期望train函数训练模型有点直观,但它并没有这样做。它只是设置模式。

2022-09-02