tensorflow2.0中保存、加载、克隆模型

1. 将模型保存并加载到磁盘1.1 保存并加载整个模型

保存整个模型:

保存模型

model.save()

tf.keras..(模型,)

注意:.的文件格式,如果不加后缀,则默认为格式,如果加后缀.h5,则为HDF5格式。后者比前者轻量级,但内容不如前者。

加载模型

tf.keras..()

注意:如果加载h5格式文件,可能会报错:’str’ has no ‘。这是由于h5py版本高,只能安装h5py的版本,即pip h5py==2.10.0。

例子

import tensorflow as tf
from tensorflow import keras
def get_model():

    model = keras.Sequential()
    model.add(keras.Input(shape=(1,)))
    model.add(keras.layers.Dense(10, keras.activations.relu))
    model.add(keras.layers.Dense(1))
    model.compile(optimizer='sgd',  loss='mse')
    return model
model_1 = get_model()
model_1.save("my_model.h5")
# 或者 model_1.save("my_model")
model_2 = tf.keras.models.load_model("my_model.h5")

# 或者 model_2 = tf.keras.models.load_model("my_model")

1.2只保存和加载参数保存参数

模型.()

注意:文件格式,如果不加后缀,则默认为格式,如果加后缀.h5,则为HDF5格式。具体区别可以看官方文档。当网络嵌套时,后者可能会出现问题。

加载参数

模型.()

例子

import tensorflow as tf
from tensorflow import keras
def get_model():
    model = keras.Sequential()

    model.add(keras.Input(shape=(1,)))
    model.add(keras.layers.Dense(10, keras.activations.relu))
    model.add(keras.layers.Dense(1))
    model.compile(optimizer='sgd',  loss='mse')
    return model
model_1 = get_model()
model_1.save_weights("my_model_weights.h5")
# 或者 model_1.save_weights("my_model_weights")
model_1.load_weights("my_model_weights.h5")
# 或者 model_1.load_weights("my_model_weights.h5")

图片[1]-tensorflow2.0中保存、加载、克隆模型-唐朝资源网

使用回调函数

模型参数也可以使用回调函数保存和加载

在训练过程中添加回调函数:

cp_callback = tf.keras.callbacks.ModelCheckpoint(
    filepath=文件路径,  # 文件存储理解
    save_weights_only=True/False,  # 是否只保留参数
    save_best_only=True/False  # 是否只保留最优结果
)
model.fit(
  ...
  callbacks=[cp_callback]

图片[2]-tensorflow2.0中保存、加载、克隆模型-唐朝资源网

) # 加载模型参数 model.load_weights(文件路径)

2. 克隆内存中的模型2.1 克隆整个模型

keras..(模型)

注意:这里的model只能是model或者model,不能是model

2.仅克隆2个参数

获取模型的参数

模型.()

为模型分配参数

模型.()

参考

© 版权声明
THE END
喜欢就支持一下吧
点赞5赞赏 分享
评论 抢沙发
头像
欢迎您留下宝贵的见解!
提交
头像

昵称

取消
昵称表情代码图片

    暂无评论内容