Tensorflow模型的格式

参考转载:
https://cloud.tencent.com/developer/article/1009979
https://blog.csdn.net/qq_27825451/article/details/105866464

模型的格式通常支持多种格式,主要有(*.ckpt)、(*.pb)、.

1. (*.ckpt)

训练模型时,需要将每次迭代的权重保存到磁盘,称为“”,如下图所示:

此格式文件是通过从 tf.train.Saver() 对象调用 saver.save() 生成的。它只包含几个对象的序列化数据,不包含图结构。因此,如果不向模型提供代码,就不可能重建计算。图形。

加载时,会调用 saver.(, )。

缺点:一是模型文件依赖,只能在其框架内使用;其次,在恢复模型之前需要重新定义网络结构,然后才能将变量的值恢复到网络中。

2. (*.pb)

此格式文件包含对象的序列化数据,包含计算图,从中可以获取所有 () 详细信息,还包含 () 和定义,但不包含值,因此只能从中恢复计算图它,但是仍然需要从中恢复一些经过训练的权重。以下代码实现了使用 *.pb 文件来构建计算图:

在一些例程中,使用*.pb文件作为预训练模型,与上述格式略有不同,属于()之后的文件,简称格式。此文件格式不包含节点。将所有节点转换为常量(其值是从中获取的),它就变成了格式。代码可以参考 //tools/.py

*.pb 是二进制文件,实际上支持文本格式(*.pbtxt),但是文本格式包含权重会占用大量磁盘空间,所以一般不使用。

3.

这是谷歌推荐的模型保存方式,它与语言无关,可以独立运行,一个封闭的序列化格式,任何语言都可以解析它,它允许其他语言和深度学习框架读取,继续训练,和转移模型。除了标记模型输入和输出参数之外,格式是 and 的组合。您可以从中提取和对象。

目录结构如下:

其中.pb(或.pbtxt)包含使用对象定义的计算图;包含附加文件;该目录包含由调用 save() API 的 tf.train.Saver() 对象生成的文件。

以下代码实现了保存:

方法一:

#在模型创建并保存
#1.1 在model中
创建signature def signature_def(self): inputs = {'char_inputs': tf.saved_model.utils.build_tensor_info(self.char_inputs), 'seg_inputs': tf.saved_model.utils.build_tensor_info(self.seg_inputs), 'dropout': tf.saved_model.utils.build_tensor_info(self.dropout)} outputs = {'decode_tags': tf.saved_model.utils.build_tensor_info(self.decode_tags)} return tf.saved_model.signature_def_utils.build_signature_def(inputs=inputs ,outputs=outputs ,method_name=tf.saved_model.signature_constants.PREDICT_METHOD_NAME)

#1.2 保存模型
def save_model(self, sess, signature, save_path):
        builder = tf.saved_model.builder.SavedModelBuilder(save_path)
        builder.add_meta_graph_and_variables(sess, [tf.saved_model.tag_constants.SERVING], {'predict': signature}, clear_devices=True)
        builder.save()


#1.3 保存模型中的signature参数使用get_signature方法创建
def get_signature(model):
signature = predict_signature_def(inputs={
'char_inputs': model.char_inputs,
'seg_inputs': model.seg_inputs,
'dropout': model.dropout},
outputs={'decode_tags': model.decode_tags}
)
return signature

方法二:

#在模型创建并保存
#2.1 在model中创建signature

def signature_def(self):
inputs = {'char_inputs': tf.saved_model.utils.build_tensor_info(self.char_inputs)
, 'seg_inputs': tf.saved_model.utils.build_tensor_info(self.seg_inputs)
, 'dropout': tf.saved_model.utils.build_tensor_info(self.dropout)}

outputs = {'decode_tags': tf.saved_model.utils.build_tensor_info(self.decode_tags)}
return tf.saved_model.signature_def_utils.build_signature_def(inputs=inputs
, outputs=outputs
,method_name=tf.saved_model.signature_constants.PREDICT_METHOD_NAME

#1.2 保存模型
def save_model(self, sess, signature, save_path):
builder = tf.saved_model.builder.SavedModelBuilder(save_path)
builder.add_meta_graph_and_variables(sess=sess
, tags=[tf.saved_model.tag_constants.SERVING]
, signature_def_map=signature
, clear_devices=True)
builder.save()

 

#1.3 保存模型中的signature参数使用get_signature方法创建
def get_signature(model):

    inputs = {
'char_inputs': model.char_inputs,
'seg_inputs': model.seg_inputs,
'dropout': model.dropout}
outputs = {'decode_tags': model.decode_tags}

signature = {
tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY:
tf.saved_model.signature_def_utils.predict_signature_def(inputs, outputs)
}

return signature


4、模型载入

# encoding = utf8

import tensorflow as tf
init = tf.global_variables_initializer()
with tf.Session() as sess:
sess.run(init)

model_path = 'xxxx'
meta_graph = tf.saved_model.loader.load(sess, [tf.saved_model.tag_constants.SERVING], model_path)

print('load succeed!')


  char_inputs_ = signature['predict'].inputs['char_inputs'].name
  seg_inputs_ = signature['predict'].inputs['seg_inputs'].name
  dropout_ = signature['predict'].inputs['dropout'].name
  decode_tags_ = signature['predict'].outputs['decode_tags'].name
  # get tensor   char_inputs = sess.graph.get_tensor_by_name(char_inputs_)   seg_inputs = sess.graph.get_tensor_by_name(seg_inputs_)   dropout = sess.graph.get_tensor_by_name(dropout_)   decode_tags = sess.graph.get_tensor_by_name(decode_tags_)   decode_tags_ = sess.run([decode_tags], feed_dict={char_inputs: inputs[1], seg_inputs:inputs[2], dropout:1.0 })


更多细节可以参考 tensorflow/python/saved_model/README.md。

5. 模式切换

6. 总结

本文总结了常见的模型格式和加载保存方法。官方在部署在线服务时推荐使用该格式(这些格式密切相关,可以使用提供的 API 相互转换。

© 版权声明
THE END
喜欢就支持一下吧
点赞5 分享
评论 抢沙发
头像
欢迎您留下宝贵的见解!
提交
头像

昵称

取消
昵称表情代码图片