[原创] 如何打印出TensorFlow保存的checkpoint里的参数名

TensorFlow版本:1.14.0
Python版本:3.6.8

checkpoint文件是TensorFlow保存出来的一种模型文件格式。通常save下来的时候会得到4个文件,例如:

checkpoint
model.ckpt-1.data-00000-of-00001
model.ckpt-1.index
model.ckpt-1.meta

如何查看这些文件里的模型参数名称呢?

[1] 方法一:用TensorFlow自带的 inspect_checkpoint 工具
首先 cd 到那4个model文件所在的目录,然后执行:

python -m tensorflow.python.tools.inspect_checkpoint --file_name=model.ckpt-1 "$@"

其中,model.ckpt-1 是输出的checkpoint文件名中的一部分(后3个文件的最后一个“点”号前面的部分)。
输出:
pi/actor13/bias (DT_FLOAT) [32]
pi/actor13/kernel (DT_FLOAT) [101,32]
pi/actor15/bias (DT_FLOAT) [8]
pi/actor15/kernel (DT_FLOAT) [32,8]
pi/ap/bias (DT_FLOAT) [2]
pi/ap/kernel (DT_FLOAT) [8,2]
 
# Total number of params: 3546
上面类似于 pi/actor13/kernel 的部分就是一个参数名。
文章来源:https://www.codelast.com/
[2] 方法二:写Python程序实现

  1. checkpoint_path = '/path_to_those_four_files/model.ckpt-1'  
  2. reader = tf.train.NewCheckpointReader(checkpoint_path)  
  3. var_to_shape_map = reader.get_variable_to_shape_map()  
  4.   
  5. for key in var_to_shape_map:  
  6.     print("tensor name: ", key)  
  7.     print(reader.get_tensor(key))  # 打印出Tensor的值  

其中,checkpoint_path 变量中的 path_to_those_four_files 这一部分指的是那4个model文件所在的目录路径,后面的 “model.ckpt-1” 是输出的checkpoint文件名中的一部分,并不是完整的文件名,这一点要注意,很多文章没有说清楚,会让人搞混。
文章来源:https://www.codelast.com/
部分输出内容类似于:

tensor_name:  pi/actor15/bias
[ 0.01485237 -0.04058828  0.00179128 -0.00357329 -0.05909787  0.00424578
 -0.03840631 -0.00575123]
tensor_name:  pi/ap/bias
[ 0.00746095 -0.00746095]

...

(完)
文章来源:https://www.codelast.com/
➤➤ 版权声明 ➤➤ 
转载需注明出处:codelast.com 
感谢关注我的微信公众号(微信扫一扫):

wechat qrcode of codelast

发表评论