诸神缄默不语-个人CSDN博文目录

写作业的时候发现老师强调 F.dropout() 必须要传入 self.training,感到疑惑,所以上网搜寻了一下解释,最终明白了情况。

dropout方法出自Dropout: A Simple Way to Prevent Neural Networks from Overfitting,证明该方法有效的文献:Improving neural networks by preventing co-adaptation of feature detectors
dropout方法是将输入Tensor的元素按伯努利分布随机置0,具体原理此处不赘,以后待补。总之就是训练的时候要用dropout,验证/测试的时候要关dropout。在PyTorch中的实现,是在训练阶段时输出直接乘以 1 1 − p \frac{1}{1-p} 1p1,测试阶段就直接当恒等函数1来用2

Dropout一般适合于全连接层部分,而卷积层由于其参数并不是很多,所以不需要dropout,加上的话对模型的泛化能力并没有太大的影响。
我们一般在网络的最开始和结束的时候使用全连接层,而hidden layers则是网络中的卷积层。所以一般情况,在全连接层部分,采用较大概率的dropout而在卷积层采用低概率或者不采用dropout。3

以下介绍Module的training属性,F(torch.nn.functional).dropout 和 nn(torch.nn).Dropout 中相应操作的实现方式,以及Module的training属性受train()eval()方法影响而改变的机制。

1. Module的training属性

torch.nn.Module官方文档
是Module的属性,布尔值,返回Module是否处于训练状态。也就是说在训练时training就是True。
默认为True,也就是Module初始化时默认为训练状态。

2. torch.nn.functional.dropout的入参training

torch.nn.functional.dropout官方文档

torch.nn.functional.dropout(input, p=0.5, training=True, inplace=False)
入参training默认为True,置True时应用Dropout,置False时不用。
因此在调用F.dropout()时,直接将self.training传入函数,就可以在训练时应用dropout,评估时关闭dropout。

示例代码:

x=F.dropout(x,p,self.training)

3. torch.nn.Dropout不需要手动开关

torch.nn.Dropout官方文档

torch.nn.Dropout(p=0.5, inplace=False)

其源代码为(Dropout源码):

class Dropout(_DropoutNd):
    def forward(self, input: Tensor) -> Tensor:
        return F.dropout(input, self.p, self.training, self.inplace)

就这个类相当于将 F.dropout() 进行了包装,内置传入了self.training,就不用像在 F.dropout() 里需要手动传参,也能实现在训练时应用dropout,评估时关闭dropout。

示例代码:

m = nn.Dropout(p=0.2)
input = torch.randn(20, 16)
output = m(input)

4. Module的train()和eval()方法改变self.training

torch.nn.Module.train官方文档
train(mode=True)
如果入参为True,则将Module设置为training mode,training随之变为True;反之则设置为evaluation mode,training为False。

torch.nn.Module.eval官方文档
eval()
将Module设置为evaluation mode,相当于 self.train(False)

5. 除正文中已列文档外的参考资料

  1. PyTorch 有哪些坑/bug? - 雷杰的回答 - 知乎 那时候F.dropout的training默认置False,更容易错了……
  2. F.dropout源代码
  3. (深度学习)Pytorch之dropout训练_junbaba_的博客-CSDN博客
  4. torch.nn.Module中的training属性详情,与Module.train()和Module.eval()的关系_chaiiiiiiiiiiiiiiiii的博客-CSDN博客

  1. 恒等函数 identity function identity function维基百科 ↩︎

  2. 测试阶段的源代码比较直接,可以看出来(来源:pytorch/symbolic_opset9.py at master · pytorch/pytorch):
    在这里插入图片描述 ↩︎

  3. 数据竞赛中如何优化深度学习模型 ↩︎

Logo

腾讯云面向开发者汇聚海量精品云计算使用和开发经验,营造开放的云计算技术生态圈。

更多推荐