PyTorch的F.dropout为什么要加self.training?
解释了PyTorch的dropout层在torch.nn和torch.nn.functional中的不同应用,通过Module的training属性来调整应用,Module的train和eval方法如何改变training的值。
写作业的时候发现老师强调 F.dropout()
必须要传入 self.training
,感到疑惑,所以上网搜寻了一下解释,最终明白了情况。
dropout方法出自Dropout: A Simple Way to Prevent Neural Networks from Overfitting,证明该方法有效的文献:Improving neural networks by preventing co-adaptation of feature detectors。
dropout方法是将输入Tensor的元素按伯努利分布随机置0,具体原理此处不赘,以后待补。总之就是训练的时候要用dropout,验证/测试的时候要关dropout。在PyTorch中的实现,是在训练阶段时输出直接乘以
1
1
−
p
\frac{1}{1-p}
1−p1,测试阶段就直接当恒等函数1来用2。
Dropout一般适合于全连接层部分,而卷积层由于其参数并不是很多,所以不需要dropout,加上的话对模型的泛化能力并没有太大的影响。
我们一般在网络的最开始和结束的时候使用全连接层,而hidden layers则是网络中的卷积层。所以一般情况,在全连接层部分,采用较大概率的dropout而在卷积层采用低概率或者不采用dropout。3
以下介绍Module的training属性,F(torch.nn.functional).dropout 和 nn(torch.nn).Dropout 中相应操作的实现方式,以及Module的training属性受train()
和eval()
方法影响而改变的机制。
文章目录
1. Module的training属性
见torch.nn.Module官方文档
是Module的属性,布尔值,返回Module是否处于训练状态。也就是说在训练时training就是True。
默认为True,也就是Module初始化时默认为训练状态。
2. torch.nn.functional.dropout的入参training
torch.nn.functional.dropout官方文档
torch.nn.functional.dropout(input, p=0.5, training=True, inplace=False)
入参training默认为True,置True时应用Dropout,置False时不用。
因此在调用F.dropout()
时,直接将self.training传入函数,就可以在训练时应用dropout,评估时关闭dropout。
示例代码:
x=F.dropout(x,p,self.training)
3. torch.nn.Dropout不需要手动开关
torch.nn.Dropout(p=0.5, inplace=False)
其源代码为(Dropout源码):
class Dropout(_DropoutNd):
def forward(self, input: Tensor) -> Tensor:
return F.dropout(input, self.p, self.training, self.inplace)
就这个类相当于将 F.dropout()
进行了包装,内置传入了self.training,就不用像在 F.dropout()
里需要手动传参,也能实现在训练时应用dropout,评估时关闭dropout。
示例代码:
m = nn.Dropout(p=0.2)
input = torch.randn(20, 16)
output = m(input)
4. Module的train()和eval()方法改变self.training
torch.nn.Module.train官方文档
train(mode=True)
如果入参为True,则将Module设置为training mode,training随之变为True;反之则设置为evaluation mode,training为False。
torch.nn.Module.eval官方文档
eval()
将Module设置为evaluation mode,相当于 self.train(False)
5. 除正文中已列文档外的参考资料
- PyTorch 有哪些坑/bug? - 雷杰的回答 - 知乎 那时候F.dropout的training默认置False,更容易错了……
- F.dropout源代码
- (深度学习)Pytorch之dropout训练_junbaba_的博客-CSDN博客
- torch.nn.Module中的training属性详情,与Module.train()和Module.eval()的关系_chaiiiiiiiiiiiiiiiii的博客-CSDN博客
恒等函数 identity function identity function维基百科 ↩︎
测试阶段的源代码比较直接,可以看出来(来源:pytorch/symbolic_opset9.py at master · pytorch/pytorch):
↩︎
更多推荐
所有评论(0)