模型训练中loss下降陡升的原因
模型训练过程中的loss陡升,到底是什么问题导致的呢?是灾难性遗忘吗?抱歉,恕我直言,你现在写的代码和那么一点儿数据量,很难导致灾难性遗忘~ 那么大概率就是你写的bug导致的!
·
这些天在做一个竞赛,其中的loss下降过程如下左图:
可以看到很明显的重复波段,而且这个波动是在epoch 交汇处,所以就让我思考,到底是什么地方出现了这个问题?难道是网络的灾难性遗忘导致的吗?
大概率不是,如果是灾难性遗忘,损失波动不会这么有规律。于是我思考是不是程序有问题,果不其然,当我记录下log的时候,发现如下:
可以发现上面红框中的 global steps
存在一个很大的跃迁,每个epoch内的global steps 都是以50 为单位变化,但是在 epoch 交汇出却出现了大概以100 step 为单位变化,所以导致最后计算出的loss差不多是之前的一倍,所以有那么一个陡升的过程。存在bug的代码如下:
因为这个部分是在 for 循环内的,所以如果只有一个 if(i+1) % config['logging_step'] == 0
就会导致一个epoch内不能完全除尽 config[logging_step]
,所以导致 cur_avg_loss
还有剩余就进入了下一个循环。解决方法就是在一个epoch 结束的时候,将cur_avg_loss 置零。
更多推荐
所有评论(0)