tf.train.latest_checkpoint()报错,路径错误
最近使用这个tf.train.latest_checkpoint()的时候出现了问题。这里记一下。这部分的代码如下,主要就是用了tf.train.latest_checkpoint()来寻找最新训练出的模型文件。sess=tf.Session(config=config)model=LSTMDSSM(...)If os.path.exists(train_dir+"checkpoint"):mod
·
最近使用这个tf.train.latest_checkpoint()的时候出现了问题。这里记一下。
这部分的代码如下,主要就是用了tf.train.latest_checkpoint()来寻找最新训练出的模型文件。
sess=tf.Session(config=config)
model=LSTMDSSM(...)
If os.path.exists(train_dir+"checkpoint"):
model.saver.restore(sess,tf.train.latest_checkpoint(train_dir))
else:
print("Can't find the checkpoint going to stop")
出错的原因是:我是先在一个有gpu的机器上训练模型,后面又把训练出的模型挪到另外一台机器上执行推理。
我原先以为这个函数是会在存放模型的地方进行查询,按照文件时间或者名字之类的找到最新的模型文件。所以我修改了存放模型的train_dir,想着这样就能正确推理了。结果报错的时候,报错信息里竟然出现了训练用机器上的路径。
然后我去看checkpoint文件之后,发现tf.train.latest_checkpoint()函数应当是读取了这个checkpoints文件,读取里面存储的最新训练的模型文件。如果要使用这个函数,需要注意这点。
更多推荐
已为社区贡献1条内容
所有评论(0)