tensorflow中的batch_normalization实现

2022-02-06

中实现了两个主要功能

1)tf.nn.

2)tf.nn.

tf.nn.主要用于计算均值和方差值,用于后续的tf.nn。

tf.nn.(x, 轴,…)

主要有两个参数:输入数据;计算均值和方差的维度轴,轴的值是一个列表,可以传入多个维度

返回值:均值和

tf.nn.(x, mean, , , scala, )

主要参数:输入数据;意思是;;和scala,这两个参数是要学习的参数,所以只要给定初始值,一般=0,scala=1;保证为0时除法依然是,设置为较小的值

输出:bn 处理数据

具体代码如下:

import tensorflow as tf
import numpy as np
X = tf.constant(np.random.uniform(1, 10, size=(3, 3)), dtype=tf.float32)
axis = list(range(len(X.get_shape()) - 1))
mean, variance = tf.nn.moments(X, axis)
print(axis)
X_batch = tf.nn.batch_normalization(X, mean, variance, 0, 1, 0.001)
init = tf.global_variables_initializer()
with tf.Session() as sess:
    sess.run(init)
    mean, variance, X_batch = sess.run([mean, variance, X_batch])
    print(mean)
    print(variance)
    print(X_batch)
输出:

轴:[0]

意思是:[5. 3. 4. ]

: [3. 1. 3.]

: [[-0. -1. 0.]

[-1. 0. -1. ]

[ 1. 1. 0.]]

 

分类:

技术要点:

相关文章:

© 版权声明
THE END
喜欢就支持一下吧
点赞10赞赏 分享
评论 抢沙发
头像
欢迎您留下宝贵的见解!
提交
头像

昵称

取消
昵称表情代码图片

    暂无评论内容