tensorflow学习笔记——模型持久化的原理,将CKPT转为pb文件,使用pb模型预测
由题目就可以看出,本节内容分为三部分,第一部分就是如何将训练好的模型持久化,并学习模型持久化的原理,第二部分就是如何将CKPT转化为pb文件,第三部分就是如何使用pb模型进行预测。
一,模型持久化
为了让训练得到的模型保存下来方便下次直接调用,我们需要将训练得到的神经网络模型持久化。下面学习通过TensorFlow程序来持久化一个训练好的模型,并从持久化之后的模型文件中还原被保存的模型,然后学习TensorFlow持久化的工作原理和持久化之后文件中的数据格式。
1,持久化代码实现
TensorFlow提供了一个非常简单的API来保存和还原一个神经网络模型。这个API就是 tf.train.Saver 类。使用 tf.train.saver() 保存模型时会产生多个文件,会把计算图的结构和图上参数取值分成了不同的文件存储。这种方式是在TensorFlow中是最常用的保存方式。
下面代码给出了保存TensorFlow计算图的方法:
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
#_*_coding:utf-8_*_
import tensorflow as tf
import os
# 声明两个变量并计算他们的和
v1 = tf.Variable(tf.constant(1.0, shape=[1]), name='v1')
v2 = tf.Variable(tf.constant(2.0, shape=[1]), name='v2')
result = v1 + v2
init_op = tf.global_variables_initializer()
# 声明 tf.train.Saver类用于保存模型
saver = tf.train.Saver()
with tf.Session() as sess:
sess.run(init_op)
# 将模型保存到model.ckpt文件中
model_path = 'model/model.ckpt'
saver.save(sess, model_path)
上面的代码实现了持久化一个简单的TensorFlow模型的功能。在这段代码中,通过saver.save 函数将TensorFlow模型保存到了 model/model.path 文件中。TensorFlow模型一般会保存在后缀为 .ckpt 的文件中,虽然上面的程序只指定了一个文件路径,但是这个文件目录下面会出现三个文件。这是因为TensorFlow会将计算图的结构和图上参数取值分开保存。
运行上面代码,我们查看model文件里面的文件如下:
下面解释一下文件分别是干什么的:
checkpoint文件是检查点文件,文件保存了一个目录下所有模型文件列表。
model.ckpt.data文件保存了TensorFlow程序中每一个变量的取值
model.ckpt.index文件则保存了TensorFlow程序中变量的索引
model.ckpt.meta文件则保存了TensorFlow计算图的结构(可以简单理解为神经网络的网络结构),该文件可以被 tf.train.import_meta_graph 加载到当前默认的图来使用。
下面代码给出加载这个模型的方法:
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
#_*_coding:utf-8_*_
import tensorflow as tf
#使用和保存模型代码中一样的方式来声明变量
v1 = tf.Variable(tf.constant(1.0, shape=[1]), name='v1')
v2 = tf.Variable(tf.constant(2.0, shape=[1]), name='v2')
result = v1 + v2
saver = tf.train.Saver()
with tf.Session() as sess:
# 加载已经保存的模型,并通过已经保存的模型中的变量的值来计算加法
model_path = 'model/model.ckpt'
saver.restore(sess, model_path)
print(sess.run(result))
# 结果如下:[3.]
这段加载模型的代码基本上和保存模型的代码是一样的。在加载模型的程序中也是先定义了TensorFlow计算图上所有运算,并声明了一个 tf.train.Saver类。两段代码唯一不同的是,在加载模型的代码中没有运行变量的初始化过程,而是将变量的值通过已经保存的模型加载出来。如果不希望重复定义图上的运算,也可以直接加载已经持久化的图,以下代码给出一个样例:
1
2
3
4
5
6
7
8
9
10
11
12
13
import tensorflow as tf
# 直接加载持久化的图
model_path = 'model/model.ckpt'
model_path1 = 'model/model.ckpt.meta'
saver = tf.train.import_meta_graph(model_path1)
with tf.Session() as sess:
saver.restore(sess, model_path)
# 通过张量的的名称来获取张量
print(sess.run(tf.get_default_graph().get_tensor_by_name('add:0')))
# 结果如下:[3.]
其上面给出的程序中,默认保存和加载了TensorFlow计算图上定义的所有变量。但是有时可能只需要保存或者加载部分变量。比如,可能有一个之前训练好的五层神经网络模型,现在想尝试一个六层神经网络,那么可以将前面五层神经网络中的参数直接加载到新的模型,而仅仅将最后一层神经网络重新训练。
为了保存或者加载部分变量,在声明 tf.train.Saver 类时可以提供一个列表来指定需要保存或者加载的变量。比如在加载模型的代码中使用 saver = tf.train.Saver([v1]) 命令来构建 tf.train.Saver 类,那么只有变量 v1 会被加载进来。如果运行修改后只加载了 v1 的代码会得到变量未初始化的错误:
1
2
tensorflow.python.framework.errors.FailedPreconditionError:Attempting to
use uninitialized value v2
因为 v2 没有被加载,所以v2在运行初始化之前是没有值的。除了可以选取需要被加载的变量,tf.train.Saver 类也支持在保存或者加载时给变量重命名。
下面给出一个简单的样例程序说明变量重命名是如何被使用的。
1
2
3
4
5
6
7
8
9
10
11
12
13
14
import tensorflow as tf
# 这里声明的变量名称和已经保存的模型中变量的的名称不同
v1 = tf.Variable(tf.constant(1.0, shape=[1]), name='other-v1')
v2 = tf.Variable(tf.constant(2.0, shape=[1]), name='other-v2')
# 如果直接使用 tf.train.Saver() 来加载模型会报变量找不到的错误,下面显示了报错信息
# tensorflow.python.framework.errors.FailedPreconditionError:Tensor name 'other-v2'
# not found in checkpoint file model/model.ckpt
# 使用一个字典来重命名变量就可以加载原来的模型了
# 这个字典指定了原来名称为 v1 的变量现在加载到变量 v1中(名称为 other-v1)
# 名称为v2 的变量加载到变量 v2中(名称为 other-v2)
saver = tf.train.Saver({'v1': v1, 'v2': v2})
在这个程序中,对变量 v1 和 v2 的名称进行了修改。如果直接通过 tf.train.Saver 默认的构造函数来加载保存的模型,那么程序会报变量找不到的错误,因为保存时候的变量名称和加载时变量的名称不一致。为了解决这个问题,Tensorflow 可以通过字典(dictionary)将模型保存时的变量名和需要加载的变量联系起来。这样做的主要目的之一就是方便使用变量的滑动平均值。在之前介绍了使用变量的滑动平均值可以让神经网络模型更加健壮(robust)。在TensorFlow中,每一个变量的滑动平均值是通过影子变量维护的,所以要获取变量的滑动平均值实际上就是获取这个影子变量的取值。如果在加载模型时将影子变量映射到变量本身,那么在使用训练好的模型时就不需要再调用函数来获取变量的滑动平均值了。这样就大大方便了滑动平均模型的时域。下面代码给出了一个保存滑动平均模型的样例:
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
import tensorflow as tf
v = tf.Variable(0, dtype=tf.float32, name='v')
# 在没有申明滑动平均模型时只有一个变量 v,所以下面语句只会输出 v:0
for variables in tf.global_variables():
print(variables.name)
ema = tf.train.ExponentialMovingAverage(0.99)
maintain_averages_op = ema.apply(tf.global_variables())
# 在申明滑动平均模型之后,TensorFlow会自动生成一个影子变量 v/ExponentialMovingAverage
# 于是下面的语句会输出 v:0 和 v/ExponentialMovingAverage:0
for variables in tf.global_variables():
print(variables.name)
saver = tf.train.Saver()
with tf.Session() as sess:
init_op = tf.global_variables_initializer()
sess.run(init_op)
sess.run(tf.assign(v, 10))
sess.run(maintain_averages_op)
# 保存时,TensorFlow会将v:0 和 v/ExponentialMovingAverage:0 两个变量都保存下来
saver.save(sess, 'model/modeltest.ckpt')
print(sess.run([v, ema.average(v)]))
# 输出结果 [10.0, 0.099999905]
下面代码给出了如何通过变量重命名直接读取变量的滑动平均值。从下面程序的输出可以看出,读取的变量 v 的值实际上是上面代码中变量 v 的滑动平均值。通过这个方法,就可以使用完全一样的代码来计算滑动平均模型前向传播的结果:
1
2
3
4
5
6
7
v = tf.Variable(0, dtype=tf.float32, name='v')
# 通过变量重命名将原来变量v的滑动平均值直接赋值给 V
saver = tf.train.Saver({'v/ExponentialMovingAverage': v})
with tf.Session() as sess:
saver.restore(sess, 'model/modeltest.ckpt')
print(sess.run(v))
# 输出 0.099999905 这个值就是原来模型中变量 v 的滑动平均值
为了方便加载时重命名滑动平均变量,tf.train.ExponentialMovingAverage 类提供了 variables_tp_restore 函数来生成 tf.train.Saver类所需要的变量重命名字典,一下代码给出了 variables_to_restore 函数的使用样例:
1
2
3
4
5
6
7
8
9
10
11
12
13
v = tf.Variable(0, dtype=tf.float32, name='v')
ema = tf.train.ExponentialMovingAverage(0.99)
# 通过使用 variables_to_restore 函数可以直接生成上面代码中提供的字典
# {'v/ExponentialMovingAverage': v}
# 下面代码会输出 {'v/ExponentialMovingAverage': }
print(ema.variables_to_restore())
saver = tf.train.Saver(ema.variables_to_restore())
with tf.Session() as sess:
saver.restore(sess, 'model/modeltest.ckpt')
print(sess.run(v))
# 输出 0.099999905 即原来模型中变量 v 的滑动平均值
使用 tf.train.Saver 会保存进行TensorFlow程序所需要的全部信息,然后有时并不需要某些信息。比如在测试或者离线预测时,只需要知道如何从神经网络的输出层经过前向传播计算得到输出层即可,而不需要类似于变量初始化,模型保存等辅助接点的信息。而且,将变量取值和计算图结构分成不同的文件存储有时候也不方便,于是TensorFlow提供了 convert_variables_to_constants 函数,通过这个函数可以将计算图中的变量及其取值通过常量的方式保存,这样整个TensorFlow计算图可以统一存放在一个文件中,该方法可以固化模型结构,而且保存的模型可以移植到Android平台。
convert_variables_to_constants固化模型结构
下面给出一个样例:
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
import tensorflow as tf
from tensorflow.python.framework import graph_util
v1 = tf.Variable(tf.constant(1.0, shape=[1]), name='v1')
v2 = tf.Variable(tf.constant(2.0, shape=[1]), name='v2')
result = v1 + v2
init_op = tf.global_variables_initializer()
with tf.Session() as sess:
sess.run(init_op)
# 导出当前计算图的GraphDef部分,只需要这一步就可以完成从输入层到输出层的过程
graph_def = tf.get_default_graph().as_graph_def()
# 将图中的变量及其取值转化为常量,同时将图中不必要的节点去掉
# 在下面,最后一个参数['add']给出了需要保存的节点名称
# add节点是上面定义的两个变量相加的操作
# 注意这里给出的是计算节点的的名称,所以没有后面的 :0
output_graph_def = graph_util.convert_variables_to_constants(sess, graph_def, (['add']))
# 将导出的模型存入文件
with tf.gfile.GFile('model/combined_model.pb', 'wb') as f:
f.write(output_graph_def.SerializeToString())
通过下面的程序可以直接计算定义加法运算的结果,当只需要得到计算图中某个节点的取值时,这提供了一个更加方便的方法,以后将使用这种方法来使用训练好的模型完成迁移学习。
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
import tensorflow as tf
from tensorflow.python.platform import gfile
with tf.Session() as sess:
model_filename = 'model/combined_model.pb'
# 读取保存的模型文件,并将文件解析成对应的GraphDef Protocol Buffer
with gfile.FastGFile(model_filename, 'rb') as f:
graph_def = tf.GraphDef()
graph_def.ParseFromString(f.read())
# 将graph_def 中保存的图加载到当前的图中,
# return_elements = ['add: 0'] 给出了返回的张量的名称
# 在保存的时候给出的是计算节点的名称,所以为add
# 在加载的时候给出的张量的名称,所以是 add:0
result = tf.import_graph_def(graph_def, return_elements=['add: 0'])
print(sess.run(result))
# 输出 [array([3.], dtype=float32)]
2,持久化原理及数据格式
上面学习了当调用 saver.save 函数时,TensorFlow程序会自动生成四个文件。TensorFlow模型的持久化就是通过这个四个文件完成的。这里我们详细学习一下这个三个文件中保存的内容以及数据格式。
TensorFlow是一个通过图的形式来表述计算的编程系统,TensorFlow程序中所有计算都会被表达为计算图上的节点。TensorFlow通过元图(MetaGraph)来记录计算图中节点的信息以及运行计算图中节点所需要的元数据。TensorFlow中元图是由 MetaGraphDef Protocol Buffer 定义的。MetaGraphDef 中的内容就构成了TensorFlow 持久化的第一个文件,以下代码给出了MetaGraphDef类型的定义:
1
2
3
4
5
6
7
message MetaGraphDef{
MeatInfoDef meta_info_def = 1;
GraphDef graph_def = 2;
SaverDef saver_def = 3;
map collection_def = 4;
map signature_def = 5;
}
从上面代码中可以看到,元图中主要记录了五类信息,下面结合变量相加样例的持久化结果,逐一介绍MetaGraphDef类型的每一个属性中存储的信息。保存 MetaGraphDef 信息的文件默认为以 .meta 为后缀名,在上面,文件 model.ckpt.meta 中存储的就是元图的数据。直接运行其样例得到的是一个二进制文件,无法直接查看。为了方便调试,TensorFlow提供了 export_meta_graph 函数,这函数支持以json格式导出 MetaGraphDef Protocol Buffer。下面代码展示了如何使用这个函数:
1
2
3
4
5
6
7
8
9
10
import tensorflow as tf
# 定义变量相加的计算
v1 = tf.Variable(tf.constant(1.0, shape=[1]), name='v1')
v2 = tf.Variable(tf.constant(2.0, shape=[1]), name='v2')
result = v1 + v2
saver = tf.train.Saver()
# 通过 export_meta_graph() 函数导出TensorFlow计算图的元图,并保存为json格式
saver.export_meta_graph('model/model.ckpt.meda.json', as_text=True)
通过上面给出的代码,我们可以将计算图元图以json的格式导出并存储在 model.ckpt.meda.json 文件中。下面给出这个文件的大概内容:
我们从JSON文件中可以看到确实是五类信息。下面结合这JSON文件的具体内容来学习一下TensorFlow中元图存储的信息。
1,meta_info_def属性
meta_info_def 属性是通过MetaInfoDef定义的。它记录了TensorFlow计算图中的元数据以及TensorFlow程序中所有使用到的运算方法的信息,下面是 MetaInfoDef Protocol Buffer 的定义:
1
2
3
4
5
6
7
8
9
message MetaInfoDef{
#saver没有特殊指定,默认属性都为空。meta_info_def属性里只有stripped_op_list属性不能为空。
#该属性不能为空
string meta_graph_version = 1;
#该属性记录了计算图中使用到的所有运算方法的信息,该函数只记录运算信息,不记录计算的次数
OpList stripped_op_list = 2;
google.protobuf.Any any_info = 3;
repeated string tags = 4;
}
TensorFlow计算图的元数据包括了计算图的版本号(meta_graph_version属性)以及用户指定的一些标签(tags属性)。如果没有在 saver中特殊指定,那么这些属性都默认为空。
在model.ckpt.meta.json文件中,meta_info_def 属性里只有 stripped_op_list属性是不为空的。stripped_op_list 属性记录了TensorFlow计算图上使用到的所有运算方法的信息。注意stripped_op_list 属性保存的是 TensorFlow 运算方法的信息,所以如果某一个运算在TensorFlow计算图中出现了多次,那么在 stripped_op_list 也只会出现一次。比如在 model.ckpt.meta.jspm 文件的 stripped_op_list 属性只有一个 Variable运算,但是这个运算在程序中被使用了两次。
stripped_op_list 属性的类型是 OpList。OpList 类型是一个 OpDef类型的列表,以下代码给出了 OpDef 类型的定义:
1
2
3
4
5
6
7
8
9
10
11
12
13
message opDef{
string name = 1;#定义了运算的名称
repeated ArgDef input_arg = 2; #定义了输入,属性是列表
repeated ArgDef output_arg =3; #定义了输出,属性是列表
repeated AttrDef attr = 4;#给出了其他运算的参数信息
string summary = 5;
string description = 6;
OpDeprecation deprecation = 8;
bool is_commutative = 18;
bool is_aggregate = 16
bool is_stateful = 17;
bool allows_uninitialized_input = 19;
};
OpDef 类型中前四个属性定义了一个运算最核心的信息。OpDef 中的第一个属性 name 定义了运算的名称,这也是一个运算唯一的标识符。在TensorFlow计算图元图的其他属性中,比如下面要学习的GraphDef属性,将通过运算名称来引用不同的运算。OpDef 的第二个和第三个属性为 input_arg 和 output_arg,他们定义了运算的输出和输入。因为输入输出都可以有多个,所以这两个属性都是列表。第四个属性Attr给出了其他的运算参数信息。在JSON文件中共定义了七个运算,下面将给出比较有代表性的一个运算来辅助说明OpDef 的数据结构。
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
op {
name: "Add"
input_arg{
name: "x"
type_attr:"T"
}
input_arg{
name: "y"
type_attr:"T"
}
output_arg{
name: "z"
type_attr:"T"
}
attr{
name:"T"
type:"type"
allow_values{
list{
type:DT_HALF
type:DT_FLOAT
...
}
}
}
}
上面给出了名称为Add的运算。这个运算有两个输入和一个输出,输入输出属性都指定了属性 type_attr,并且这个属性的值为 T。在OpDef的Attr属性中,必须要出现名称(name)为 T的属性。以上样例中,这个属性指定了运算输入输出允许的参数类型(allowed_values)。
2,graph_def 属性
graph_def 属性主要记录了TensorFlow 计算图上的节点信息。TensorFlow计算图的每一个节点对应了TensorFlow程序中一个运算,因为在 meta_info_def 属性中已经包含了所有运算的具体信息,所以 graph_def 属性只关注运算的连接结构。graph_def属性是通过 GraphDef Protocol Buffer 定义的,graph_def主要包含了一个 NodeDef类型的列表。一下代码给出了 graph_def 和NodeDef类型中包含的信息:
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
message GraphDef{
#GraphDef的主要信息存储在node属性中,他记录了Tensorflow计算图上所有的节点信息。
repeated NodeDef node = 1;
VersionDef versions = 4; #主要储存了Tensorflow的版本号
};
message NodeDef{
#NodeDef类型中有一个名称属性name,他是一个节点的唯一标识符,在程序中,通过节点的名称来获得相应的节点。
string name = 1;
'''
op属性给出了该节点使用的Tensorflow运算方法的名称。
通过这个名称可以在TensorFlow计算图元图的meta_info_def属性中找到该运算的具体信息。
'''
string op = 2;
'''
input属性是一个字符串列表,他定义了运算的输入。每个字符串的取值格式为弄的:src_output
node部分给出节点名称,src_output表明了这个输入是指定节点的第几个输出。
src_output=0时可以省略src_output部分
'''
repeated string input = 3;
#制定了处理这个运算的设备,可以是本地或者远程的CPU or GPU。属性为空时自动选择
string device = 4;
#制定了和当前运算有关的配置信息
map attr = 5;
};
GraphDef中的versions属性比较简单,它主要存储了TensorFlow的版本号。和其他属性类似,NodeDef 类型中有一个名称属性 name,它是一个节点的唯一标识符,在TensorFlow程序中可以通过节点的名称来获取响应节点。 NodeDef 类型中 的 device属性指定了处理这个运算的设备。运行TensorFlow运算的设备可以是本地机器的CPU或者GPU,当device属性为空时,TensorFlow在运行时会自动选取一个最适合的设备来运行这个运算,最后NodeDef类型中的Attr属性指定了和当前运算相关的配置信息。
下面列举了 model.ckpt.meta.json 文件中的一个计算节点来更加具体的了解graph_def属性:
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
graph def {
node {
name: "v1"
op: "Variable"
attr {
key:"_output_shapes"
value {
list{ shape { dim { size: 1 } } }
}
}
}
attr {
key :"dtype"
value {
type: DT_FLOAT
}
}
...
}
node {
name :"ad