2022-02-04
目录
当用于执行异步计算时,队列是一种强大的机制。
为了了解队列,让我们看一个简单的例子。我们首先创建一个“先进先出”队列()并将其中的所有元素初始化为某个值。然后,我们构建一个图,该图从队列的前面获取一个元素,加 1,然后将其放回队列的后面。慢慢地,队列中元素的值会增加。
提供了两个类来帮助实现多线程:tf. 和 tf.. 类可用于同时停止多个工作线程,并向正在等待所有工作线程终止的程序报告异常。类用于协调多个工作线程,将多个张量同时推送到同一个队列中。
队列
当张量被异步计算时,队列,例如(先进先出,有序出队)和(随机出队)非常重要。
(, ,):创建一个队列,按照先进先出的顺序排列元素
同步执行队列
完成一个出队,+1、入队操作(同步操作):
import tensorflow as tf
# 同步操作,如队列,+1,出队列
# 创建一个队列
Q = tf.FIFOQueue(3, dtypes=tf.float32)
# 数据进队列
init_q = Q.enqueue_many([[1.0, 2.0, 3.0], ])
# 定义操作
de_q = Q.dequeue()
data = de_q + 1
en_q = Q.enqueue(data)
with tf.Session() as sess:
# 初始化队列
sess.run(init_q)
# 执行10次 +1 操作
for i in range(10):
sess.run(en_q)
# 取出数据
for i in range(Q.size().eval()):
print(Q.dequeue().eval())
输出:
5.0
6.0
5.0
当数据量较大时,入队操作从硬盘中读取数据,放入内存中。主线程在训练前需要等待入队操作完成。可以在会话中运行多个线程来实现异步读取。
队列管理器
该类创建了一组可以重复执行操作的线程,并且它们使用同一个线程来处理同步线程终止。此外,一个会运行一个,这将在收到异常报告时自动关闭队列。
您可以使用队列来实现上述结构。首先构建一个使用队列输入样本的图表。处理样本并将样本推入队列的增量操作。添加操作以从队列中删除样本。
tf.train.(queue, =None): 创建一个
异步执行队列
变量加1,进入队列,通过队列管理器将主线程出队的操作(异步操作):
# 异步操作,变量+1,入队,出队列
Q = tf.FIFOQueue(100, dtypes=tf.float32)
# 要做的事情
var = tf.Variable(0.0)
data = tf.assign_add(var, 1)
en_q = Q.enqueue(data)
# 队列管理器op
qr = tf.train.QueueRunner(Q, enqueue_ops=[en_q] * 5)
# 变量初始化op
init_op = tf.global_variables_initializer()
with tf.Session() as sess:
# 初始化变量
sess.run(init_op)
# 开始子线程
threads = qr.create_threads(sess, start=True)
# 主线程读取数据
for i in range(50):
print(sess.run(Q.dequeue()))
分析:此时出现的问题是在完成所需的出队操作后程序无法结束。需要实现线程之间的同步,终止其他线程。程序执行完成并出现以下错误:
tensorflow.python.framework.errors_impl.CancelledError: Enqueue operation was cancelled
[[{{node fifo_queue_enqueue}}]]
线程协调器
tf.train.():线程协调器,实现一个简单的机制来协调一组线程的终止
首先创建一个对象,然后创建一些使用该对象的线程。这些线程通常循环运行,直到 () 返回 True 才停止。任何线程都可以决定何时停止计算。它只需要调用(),其他线程的()就会返回True,然后全部停止。
加入线程协调器的程序:
# 异步操作,变量+1,入队,出队列
Q = tf.FIFOQueue(100, dtypes=tf.float32)
# 要做的事情
var = tf.Variable(0.0)
data = tf.assign_add(var, 1)
en_q = Q.enqueue(data)
# 队列管理器op
qr = tf.train.QueueRunner(Q, enqueue_ops=[en_q] * 5)
# 变量初始化op
init_op = tf.global_variables_initializer()
with tf.Session() as sess:
# 初始化变量
sess.run(init_op)
# 开启线程协调器
coord = tf.train.Coordinator()
# 开始子线程
threads = qr.create_threads(sess, coord=coord, start=True)
# 主线程读取数据
for i in range(50):
print(sess.run(Q.dequeue()))
# 请求停止线程
coord.request_stop()
coord.join()
分类:
技术要点:
相关文章:
暂无评论内容