参考转载:
https://cloud.tencent.com/developer/article/1009979
https://blog.csdn.net/qq_27825451/article/details/105866464
模型的格式通常支持多种格式,主要有(*.ckpt)、(*.pb)、.
1. (*.ckpt)
训练模型时,需要将每次迭代的权重保存到磁盘,称为“”,如下图所示:
此格式文件是通过从 tf.train.Saver() 对象调用 saver.save() 生成的。它只包含几个对象的序列化数据,不包含图结构。因此,如果不向模型提供代码,就不可能重建计算。图形。
加载时,会调用 saver.(, )。
缺点:一是模型文件依赖,只能在其框架内使用;其次,在恢复模型之前需要重新定义网络结构,然后才能将变量的值恢复到网络中。
2. (*.pb)
此格式文件包含对象的序列化数据,包含计算图,从中可以获取所有 () 详细信息,还包含 () 和定义,但不包含值,因此只能从中恢复计算图它,但是仍然需要从中恢复一些经过训练的权重。以下代码实现了使用 *.pb 文件来构建计算图:
在一些例程中,使用*.pb文件作为预训练模型,与上述格式略有不同,属于()之后的文件,简称格式。此文件格式不包含节点。将所有节点转换为常量(其值是从中获取的),它就变成了格式。代码可以参考 //tools/.py
*.pb 是二进制文件,实际上支持文本格式(*.pbtxt),但是文本格式包含权重会占用大量磁盘空间,所以一般不使用。
3.
这是谷歌推荐的模型保存方式,它与语言无关,可以独立运行,一个封闭的序列化格式,任何语言都可以解析它,它允许其他语言和深度学习框架读取,继续训练,和转移模型。除了标记模型输入和输出参数之外,格式是 and 的组合。您可以从中提取和对象。
目录结构如下:
其中.pb(或.pbtxt)包含使用对象定义的计算图;包含附加文件;该目录包含由调用 save() API 的 tf.train.Saver() 对象生成的文件。
以下代码实现了保存:
方法一:
#在模型创建并保存
#1.1 在model中创建signature def signature_def(self): inputs = {'char_inputs': tf.saved_model.utils.build_tensor_info(self.char_inputs), 'seg_inputs': tf.saved_model.utils.build_tensor_info(self.seg_inputs), 'dropout': tf.saved_model.utils.build_tensor_info(self.dropout)} outputs = {'decode_tags': tf.saved_model.utils.build_tensor_info(self.decode_tags)} return tf.saved_model.signature_def_utils.build_signature_def(inputs=inputs ,outputs=outputs ,method_name=tf.saved_model.signature_constants.PREDICT_METHOD_NAME)
#1.2 保存模型 def save_model(self, sess, signature, save_path): builder = tf.saved_model.builder.SavedModelBuilder(save_path) builder.add_meta_graph_and_variables(sess, [tf.saved_model.tag_constants.SERVING], {'predict': signature}, clear_devices=True) builder.save()
#1.3 保存模型中的signature参数使用get_signature方法创建
def get_signature(model):
signature = predict_signature_def(inputs={
'char_inputs': model.char_inputs,
'seg_inputs': model.seg_inputs,
'dropout': model.dropout},
outputs={'decode_tags': model.decode_tags}
)
return signature
方法二:
#在模型创建并保存
#2.1 在model中创建signature
def signature_def(self):
inputs = {'char_inputs': tf.saved_model.utils.build_tensor_info(self.char_inputs)
, 'seg_inputs': tf.saved_model.utils.build_tensor_info(self.seg_inputs)
, 'dropout': tf.saved_model.utils.build_tensor_info(self.dropout)}
outputs = {'decode_tags': tf.saved_model.utils.build_tensor_info(self.decode_tags)}
return tf.saved_model.signature_def_utils.build_signature_def(inputs=inputs
, outputs=outputs
,method_name=tf.saved_model.signature_constants.PREDICT_METHOD_NAME
#1.2 保存模型
def save_model(self, sess, signature, save_path):
builder = tf.saved_model.builder.SavedModelBuilder(save_path)
builder.add_meta_graph_and_variables(sess=sess
, tags=[tf.saved_model.tag_constants.SERVING]
, signature_def_map=signature
, clear_devices=True)
builder.save()
#1.3 保存模型中的signature参数使用get_signature方法创建
def get_signature(model):
inputs = {
'char_inputs': model.char_inputs,
'seg_inputs': model.seg_inputs,
'dropout': model.dropout}
outputs = {'decode_tags': model.decode_tags}
signature = {
tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY:
tf.saved_model.signature_def_utils.predict_signature_def(inputs, outputs)
}
return signature
4、模型载入
# encoding = utf8
import tensorflow as tf
init = tf.global_variables_initializer()
with tf.Session() as sess:
sess.run(init)
model_path = 'xxxx'
meta_graph = tf.saved_model.loader.load(sess, [tf.saved_model.tag_constants.SERVING], model_path)
print('load succeed!')
char_inputs_ = signature['predict'].inputs['char_inputs'].name
seg_inputs_ = signature['predict'].inputs['seg_inputs'].name
dropout_ = signature['predict'].inputs['dropout'].name
decode_tags_ = signature['predict'].outputs['decode_tags'].name
# get tensor
char_inputs = sess.graph.get_tensor_by_name(char_inputs_)
seg_inputs = sess.graph.get_tensor_by_name(seg_inputs_)
dropout = sess.graph.get_tensor_by_name(dropout_)
decode_tags = sess.graph.get_tensor_by_name(decode_tags_)
decode_tags_ = sess.run([decode_tags], feed_dict={char_inputs: inputs[1], seg_inputs:inputs[2], dropout:1.0 })
更多细节可以参考 tensorflow/python/saved_model/README.md。
5. 模式切换
6. 总结
本文总结了常见的模型格式和加载保存方法。官方在部署在线服务时推荐使用该格式(这些格式密切相关,可以使用提供的 API 相互转换。
暂无评论内容