2022-02-05
如何在中实现双向rnn
单层双向 rnn
单层双向 rnn()
中已经提供了双向rnn接口,就是tf.nn.c_rnn()。我们来看看这个接口是如何使用的。
1 bidirectional_dynamic_rnn( 2 cell_fw, #前向 rnn cell 3 cell_bw, #反向 rnn cell 4 inputs, #输入序列. 5 sequence_length=None,# 序列长度 6 initial_state_fw=None,#前向rnn_cell的初始状态 7 initial_state_bw=None,#反向rnn_cell的初始状态 8 dtype=None,#数据类型 9 parallel_iterations=None, 10 swap_memory=False, 11 time_major=False, 12 scope=None 13 )
返回值:一个 tuple(, ),其中,是一个 tuple(, )。关于总和,如果 =True 那么它们都是,反之亦然。如果需要,直接使用 tf.(, 2) 即可。
如何使用:
c_rnn 使用和
n是非常相似的. 定义前向和反向rnn_cell 定义前向和反向rnn_cell的初始状态 准备好序列 调用bidirectional_dynamic_rnn import tensorflow as tf from tensorflow.contrib import rnn cell_fw = rnn.LSTMCell(10) cell_bw = rnn.LSTMCell(10) initial_state_fw = cell_fw.zero_state(batch_size) initial_state_bw = cell_bw.zero_state(batch_size) seq = ... seq_length = ... (outputs, states)=tf.nn.bidirectional_dynamic_rnn(cell_fw, cell_bw, seq, seq_length, initial_state_fw,initial_state_bw) out = tf.concat(outputs, 2)
查看代码
# ....
多层双向 rnn
多层双向 rnn()
单层双向rnn可以通过上面的方法很方便的实现,但是多层双向rnn不能传给c_rnn。
要了解原因,我们需要查看 c_rnn 的源代码片段。
1 with vs.variable_scope(scope or "bidirectional_rnn"): 2 # Forward direction 3 with vs.variable_scope("fw") as fw_scope: 4 output_fw, output_state_fw = dynamic_rnn( 5 cell=cell_fw, inputs=inputs, sequence_length=sequence_length, 6 initial_state=initial_state_fw, dtype=dtype, 7 parallel_iterations=parallel_iterations, swap_memory=swap_memory, 8 time_major=time_major, scope=fw_scope)
这只是一小部分代码,但足以看出bi-rnn实际上是通过-rnn实现的,如果我们使用它,则忽略了每一层之间不同方向之间的交互。所以我们可以自己实现一个Tool函数,通过多次调用c_rnn来实现多层双向RNN。这是我对多层双向 RNN 的简化版本的实现。如有错误请指出
c_rnn源码探索
上面我们已经看到了正向过程的代码实现,我们来看看剩下的反向部分的实现。
事实上,反向过程做了两次
1. 第一次:输入序列被处理,然后发送做一个操作。
2. 第二次:进行上述返回,保证正反转输出时间正确。
1 def _reverse(input_, seq_lengths, seq_dim, batch_dim): 2 if seq_lengths is not None: 3 return array_ops.reverse_sequence( 4 input=input_, seq_lengths=seq_lengths, 5 seq_dim=seq_dim, batch_dim=batch_dim) 6 else: 7 return array_ops.reverse(input_, axis=[seq_dim]) 8 9 with vs.variable_scope("bw") as bw_scope: 10 inputs_reverse = _reverse( 11 inputs, seq_lengths=sequence_length, 12 seq_dim=time_dim, batch_dim=batch_dim) 13 tmp, output_state_bw = dynamic_rnn( 14 cell=cell_bw, inputs=inputs_reverse, sequence_length=sequence_length, 15 initial_state=initial_state_bw, dtype=dtype, 16 parallel_iterations=parallel_iterations, swap_memory=swap_memory, 17 time_major=time_major, scope=bw_scope) 18 19 output_bw = _reverse( 20 tmp, seq_lengths=sequence_length, 21 seq_dim=time_dim, batch_dim=batch_dim) 22 23 outputs = (output_fw, output_bw) 24 output_states = (output_state_fw, output_state_bw) 25 26 return (outputs, output_states)
tf。
反转序列的一部分
1 reverse_sequence( 2 input,#输入序列,将被reverse的序列 3 seq_lengths,#1Dtensor,表示输入序列长度 4 seq_axis=None,# 哪维代表序列 5 batch_axis=None, #哪维代表 batch 6 name=None, 7 seq_dim=None, 8 batch_dim=None 9 )
官网上的例子很好,我就直接贴在这里了:
1 # Given this: 2 batch_dim = 0 3 seq_dim = 1 4 input.dims = (4, 8, ...) 5 seq_lengths = [7, 2, 3, 5] 6 7 # then slices of input are reversed on seq_dim, but only up to seq_lengths: 8 output[0, 0:7, :, ...] = input[0, 7:0:-1, :, ...] 9 output[1, 0:2, :, ...] = input[1, 2:0:-1, :, ...] 10 output[2, 0:3, :, ...] = input[2, 3:0:-1, :, ...] 11 output[3, 0:5, :, ...] = input[3, 5:0:-1, :, ...] 12 13 # while entries past seq_lens are copied through: 14 output[0, 7:, :, ...] = input[0, 7:, :, ...] 15 output[1, 2:, :, ...] = input[1, 2:, :, ...] 16 output[2, 3:, :, ...] = input[2, 3:, :, ...] 17 output[3, 2:, :, ...] = input[3, 2:, :, ...]
示例 2:
1 # Given this: 2 batch_dim = 2 3 seq_dim = 0 4 input.dims = (8, ?, 4, ...) 5 seq_lengths = [7, 2, 3, 5] 6 7 # then slices of input are reversed on seq_dim, but only up to seq_lengths: 8 output[0:7, :, 0, :, ...] = input[7:0:-1, :, 0, :, ...] 9 output[0:2, :, 1, :, ...] = input[2:0:-1, :, 1, :, ...] 10 output[0:3, :, 2, :, ...] = input[3:0:-1, :, 2, :, ...] 11 output[0:5, :, 3, :, ...] = input[5:0:-1, :, 3, :, ...] 12 13 # while entries past seq_lens are copied through: 14 output[7:, :, 0, :, ...] = input[7:, :, 0, :, ...] 15 output[2:, :, 1, :, ...] = input[2:, :, 1, :, ...] 16 output[3:, :, 2, :, ...] = input[3:, :, 2, :, ...] 17 output[2:, :, 3, :, ...] = input[2:, :, 3, :, ...]
分类:
技术要点:
相关文章:
© 版权声明
本站下载的源码均来自公开网络收集转发二次开发而来,
若侵犯了您的合法权益,请来信通知我们1413333033@qq.com,
我们会及时删除,给您带来的不便,我们深表歉意。
下载用户仅供学习交流,若使用商业用途,请购买正版授权,否则产生的一切后果将由下载用户自行承担,访问及下载者下载默认同意本站声明的免责申明,请合理使用切勿商用。
THE END
暂无评论内容