出现这个问题是因为网络中存在BatchNormalization模块,它需要多于1个数据来计算平均值,当batch只有一个数据时会报错。
如果使用pytorch,可以在获取数据集时,将DataLoader中drop_last设置为True。把不够一个批次的数据丢弃。

原文 https://blog.csdn.net/sinat_39307513/article/details/87917537

Logo

腾讯云面向开发者汇聚海量精品云计算使用和开发经验,营造开放的云计算技术生态圈。

更多推荐