最近使用这个tf.train.latest_checkpoint()的时候出现了问题。这里记一下。

这部分的代码如下,主要就是用了tf.train.latest_checkpoint()来寻找最新训练出的模型文件。

sess=tf.Session(config=config)
model=LSTMDSSM(...)

If os.path.exists(train_dir+"checkpoint"):
	model.saver.restore(sess,tf.train.latest_checkpoint(train_dir))
else:
    print("Can't find the checkpoint going to stop")

出错的原因是:我是先在一个有gpu的机器上训练模型,后面又把训练出的模型挪到另外一台机器上执行推理。

我原先以为这个函数是会在存放模型的地方进行查询,按照文件时间或者名字之类的找到最新的模型文件。所以我修改了存放模型的train_dir,想着这样就能正确推理了。结果报错的时候,报错信息里竟然出现了训练用机器上的路径。

然后我去看checkpoint文件之后,发现tf.train.latest_checkpoint()函数应当是读取了这个checkpoints文件,读取里面存储的最新训练的模型文件。如果要使用这个函数,需要注意这点。

Logo

腾讯云面向开发者汇聚海量精品云计算使用和开发经验,营造开放的云计算技术生态圈。

更多推荐