这些天在做一个竞赛,其中的loss下降过程如下左图:
在这里插入图片描述

可以看到很明显的重复波段,而且这个波动是在epoch 交汇处,所以就让我思考,到底是什么地方出现了这个问题?难道是网络的灾难性遗忘导致的吗?
大概率不是,如果是灾难性遗忘,损失波动不会这么有规律。于是我思考是不是程序有问题,果不其然,当我记录下log的时候,发现如下:
在这里插入图片描述
可以发现上面红框中的 global steps 存在一个很大的跃迁,每个epoch内的global steps 都是以50 为单位变化,但是在 epoch 交汇出却出现了大概以100 step 为单位变化,所以导致最后计算出的loss差不多是之前的一倍,所以有那么一个陡升的过程。存在bug的代码如下:

在这里插入图片描述
因为这个部分是在 for 循环内的,所以如果只有一个 if(i+1) % config['logging_step'] == 0 就会导致一个epoch内不能完全除尽 config[logging_step] ,所以导致 cur_avg_loss 还有剩余就进入了下一个循环。解决方法就是在一个epoch 结束的时候,将cur_avg_loss 置零。

Logo

腾讯云面向开发者汇聚海量精品云计算使用和开发经验,营造开放的云计算技术生态圈。

更多推荐