本文为原创文章,未经本人允许,禁止转载。转载请注明出处。
1.tf.train.Saver()
tf.train.Saver()
用于保存和加载模型。
1
saver=tf.train.Saver()
tf.train.Saver()
参数见下:
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
def __init__(
self,
var_list=None,
reshape=False,
sharded=False,
max_to_keep=5,
keep_checkpoint_every_n_hours=10000.0,
name=None,
restore_sequentially=False,
saver_def=None,
builder=None,
defer_build=False,
allow_empty=False,
write_version=saver_pb2.SaverDef.V2,
pad_step_number=False,
save_relative_paths=False,
filename=None)
部分参数解释:
var_list
:指定要保存和恢复的变量。max_to_keep
:是经常会用到的一个参数。用于设置保存模型的个数(默认为max_to_keep=5
,即保存最近的5个模型)。若max_to_keep
设置为None或0,则保存所有的模型。keep_checkpoint_every_n_hours
:每n个小时保存一次模型。
1.1.saver.save()
1
2
3
4
5
6
7
8
9
10
11
def save(
self,
sess,
save_path,
global_step=None,
latest_filename=None,
meta_graph_suffix="meta",
write_meta_graph=True,
write_state=True,
strip_default_attrs=False,
save_debug_info=False)
部分参数解释:
sess
:Session。save_path
:模型保存路径。例如:saver.save(sess, 'net/my_net.ckpt')
。global_step
:用来给模型文件名添加数字标记。例如:saver.save(sess, 'my-model', global_step=0)
,保存得到的模型文件名为:'my-model-0'
。
1.2.saver.restore()
1
def restore(self, sess, save_path)
参数解释:
sess
:Session。save_path
:模型路径。例如:saver.restore(sess, 'net/my_net.ckpt')
。
导入模型之前,必须重新再定义一遍变量。但是并不需要全部变量都重新进行定义,只定义我们需要的变量就行了。
可以使用tf.train.latest_checkpoint()
来自动获取最后一次保存的模型。如:
1
2
model_file=tf.train.latest_checkpoint(checkpoint_dir, latest_filename=None)
saver.restore(sess,model_file)
2.ckpt模型
使用saver.save()
将模型保存为ckpt格式,会生成以下四个文件:
my_net.ckpt.meta
:保存了Tensorflow计算图的结构,即网络结构。my_net.ckpt.index
和my_net.ckpt.data-00000-of-00001
:保存了所有变量的取值。checkpoint
:保存了一个目录下所有的模型文件列表。