tensorflow finuetuning 例子

2022-02-16

最近我研究了如何使用它。相比caffe,就麻烦多了。记录如下:

1.原理

原理很简单。使用在数据集A上训练的模型作为初始值,改变其部分结构,调用在另一个数据集B(学习率较小)上训练的过程。

一般来说,会用到以下条件

图片[1]-tensorflow finuetuning 例子-唐朝资源网

2.密钥代码

在数据集A上训练时,和普通的训练过程完全一样。然而,当它在数据集 B 上执行时,需要从先前训练的模型参数中恢复模型参数。这个地方比较关键,

需要注意的是,只恢复需要恢复的参数,其他参数不要恢复,否则会因为找不到声明而报错。以mnist为例,如果我想先训练一个0-7的8类分类器,网络结构如下:

conv1-conv2-fc8(其他未加权,忽略层)

然后我想用训练好的模型参数在一个 0-9 的 10 类分类器上做。网络结构如下:

conv1-conv2-fc10

然后当我从中恢复模型参数时,我只能恢复conv1-conv2。如果连fc8都恢复了,会因为找不到fc8的定义而报错

上面描述对应的代码如下:

图片[2]-tensorflow finuetuning 例子-唐朝资源网

1     if tf.train.latest_checkpoint('ckpts') is not None:
2         trainable_vars = tf.trainable_variables()
3         res_vars = [t for t in trainable_vars if t.name.startswith('conv')]
4         saver = tf.train.Saver(var_list=res_vars)

图片[3]-tensorflow finuetuning 例子-唐朝资源网

5 saver.restore(sess, tf.train.latest_checkpoint('ckpts')) 6 else: 7 saver = tf.train.Saver()

3.演示

图片[4]-tensorflow finuetuning 例子-唐朝资源网

用mnist写个简单的例子,大家可以试试,原来用现有的相关模型来做比从0开始训练收敛更快,准确率更高,

点我下载

分类:

技术要点:

相关文章:

© 版权声明
THE END
喜欢就支持一下吧
点赞11赞赏 分享
评论 抢沙发
头像
欢迎您留下宝贵的见解!
提交
头像

昵称

取消
昵称表情代码图片

    暂无评论内容