1153 lines
35 KiB
Python
1153 lines
35 KiB
Python
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
|
|
|
|
#
|
|
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
|
|
# you may not use this file except in compliance with the License.
|
|
|
|
# You may obtain a copy of the License at
|
|
|
|
#
|
|
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
#
|
|
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
|
|
# See the License for the specific language governing permissions and
|
|
|
|
# limitations under the License.
|
|
|
|
# ==============================================================================
|
|
|
|
|
|
"""RNN helpers for TensorFlow models.
|
|
@@bidirectional_dynamic_rnn
|
|
@@dynamic_rnn
|
|
@@raw_rnn
|
|
@@static_rnn
|
|
@@static_state_saving_rnn
|
|
@@static_bidirectional_rnn
|
|
"""
|
|
|
|
from __future__ import absolute_import
|
|
from __future__ import division
|
|
from __future__ import print_function
|
|
|
|
from tensorflow.python.framework import dtypes
|
|
from tensorflow.python.framework import ops
|
|
from tensorflow.python.framework import tensor_shape
|
|
from tensorflow.python.ops import array_ops
|
|
from tensorflow.python.ops import control_flow_ops
|
|
from tensorflow.python.ops import math_ops
|
|
from tensorflow.python.ops import rnn_cell_impl
|
|
from tensorflow.python.ops import tensor_array_ops
|
|
from tensorflow.python.ops import variable_scope as vs
|
|
from tensorflow.python.util import nest
|
|
import tensorflow as tf
|
|
|
|
|
|
def _like_rnncell_(cell):
|
|
"""Checks that a given object is an RNNCell by using duck typing."""
|
|
|
|
conditions = [hasattr(cell, "output_size"), hasattr(cell, "state_size"),
|
|
|
|
hasattr(cell, "zero_state"), callable(cell)]
|
|
|
|
return all(conditions)
|
|
|
|
|
|
# pylint: disable=protected-access
|
|
|
|
_concat = rnn_cell_impl._concat
|
|
try:
|
|
_like_rnncell = rnn_cell_impl._like_rnncell
|
|
except Exception as e:
|
|
_like_rnncell = _like_rnncell_
|
|
|
|
|
|
# pylint: enable=protected-access
|
|
|
|
|
|
def _transpose_batch_time(x):
|
|
"""Transpose the batch and time dimensions of a Tensor.
|
|
Retains as much of the static shape information as possible.
|
|
Args:
|
|
x: A tensor of rank 2 or higher.
|
|
Returns:
|
|
x transposed along the first two dimensions.
|
|
Raises:
|
|
ValueError: if `x` is rank 1 or lower.
|
|
"""
|
|
|
|
x_static_shape = x.get_shape()
|
|
|
|
if x_static_shape.ndims is not None and x_static_shape.ndims < 2:
|
|
raise ValueError(
|
|
|
|
"Expected input tensor %s to have rank at least 2, but saw shape: %s" %
|
|
|
|
(x, x_static_shape))
|
|
|
|
x_rank = array_ops.rank(x)
|
|
|
|
x_t = array_ops.transpose(
|
|
|
|
x, array_ops.concat(
|
|
|
|
([1, 0], math_ops.range(2, x_rank)), axis=0))
|
|
|
|
x_t.set_shape(
|
|
|
|
tensor_shape.TensorShape([
|
|
|
|
x_static_shape[1].value, x_static_shape[0].value
|
|
|
|
]).concatenate(x_static_shape[2:]))
|
|
|
|
return x_t
|
|
|
|
|
|
def _best_effort_input_batch_size(flat_input):
|
|
"""Get static input batch size if available, with fallback to the dynamic one.
|
|
Args:
|
|
flat_input: An iterable of time major input Tensors of shape [max_time,
|
|
batch_size, ...]. All inputs should have compatible batch sizes.
|
|
Returns:
|
|
The batch size in Python integer if available, or a scalar Tensor otherwise.
|
|
Raises:
|
|
ValueError: if there is any input with an invalid shape.
|
|
"""
|
|
|
|
for input_ in flat_input:
|
|
|
|
shape = input_.shape
|
|
|
|
if shape.ndims is None:
|
|
continue
|
|
|
|
if shape.ndims < 2:
|
|
raise ValueError(
|
|
|
|
"Expected input tensor %s to have rank at least 2" % input_)
|
|
|
|
batch_size = shape[1].value
|
|
|
|
if batch_size is not None:
|
|
return batch_size
|
|
|
|
# Fallback to the dynamic batch size of the first input.
|
|
|
|
return array_ops.shape(flat_input[0])[1]
|
|
|
|
|
|
def _infer_state_dtype(explicit_dtype, state):
|
|
"""Infer the dtype of an RNN state.
|
|
Args:
|
|
explicit_dtype: explicitly declared dtype or None.
|
|
state: RNN's hidden state. Must be a Tensor or a nested iterable containing
|
|
Tensors.
|
|
Returns:
|
|
dtype: inferred dtype of hidden state.
|
|
Raises:
|
|
ValueError: if `state` has heterogeneous dtypes or is empty.
|
|
"""
|
|
|
|
if explicit_dtype is not None:
|
|
|
|
return explicit_dtype
|
|
|
|
elif nest.is_sequence(state):
|
|
|
|
inferred_dtypes = [element.dtype for element in nest.flatten(state)]
|
|
|
|
if not inferred_dtypes:
|
|
raise ValueError("Unable to infer dtype from empty state.")
|
|
|
|
all_same = all([x == inferred_dtypes[0] for x in inferred_dtypes])
|
|
|
|
if not all_same:
|
|
raise ValueError(
|
|
|
|
"State has tensors of different inferred_dtypes. Unable to infer a "
|
|
|
|
"single representative dtype.")
|
|
|
|
return inferred_dtypes[0]
|
|
|
|
else:
|
|
|
|
return state.dtype
|
|
|
|
|
|
# pylint: disable=unused-argument
|
|
|
|
def _rnn_step(
|
|
|
|
time, sequence_length, min_sequence_length, max_sequence_length,
|
|
|
|
zero_output, state, call_cell, state_size, skip_conditionals=False):
|
|
"""Calculate one step of a dynamic RNN minibatch.
|
|
Returns an (output, state) pair conditioned on the sequence_lengths.
|
|
When skip_conditionals=False, the pseudocode is something like:
|
|
if t >= max_sequence_length:
|
|
return (zero_output, state)
|
|
if t < min_sequence_length:
|
|
return call_cell()
|
|
# Selectively output zeros or output, old state or new state depending
|
|
# on if we've finished calculating each row.
|
|
new_output, new_state = call_cell()
|
|
final_output = np.vstack([
|
|
zero_output if time >= sequence_lengths[r] else new_output_r
|
|
for r, new_output_r in enumerate(new_output)
|
|
])
|
|
final_state = np.vstack([
|
|
state[r] if time >= sequence_lengths[r] else new_state_r
|
|
for r, new_state_r in enumerate(new_state)
|
|
])
|
|
return (final_output, final_state)
|
|
Args:
|
|
time: Python int, the current time step
|
|
sequence_length: int32 `Tensor` vector of size [batch_size]
|
|
min_sequence_length: int32 `Tensor` scalar, min of sequence_length
|
|
max_sequence_length: int32 `Tensor` scalar, max of sequence_length
|
|
zero_output: `Tensor` vector of shape [output_size]
|
|
state: Either a single `Tensor` matrix of shape `[batch_size, state_size]`,
|
|
or a list/tuple of such tensors.
|
|
call_cell: lambda returning tuple of (new_output, new_state) where
|
|
new_output is a `Tensor` matrix of shape `[batch_size, output_size]`.
|
|
new_state is a `Tensor` matrix of shape `[batch_size, state_size]`.
|
|
state_size: The `cell.state_size` associated with the state.
|
|
skip_conditionals: Python bool, whether to skip using the conditional
|
|
calculations. This is useful for `dynamic_rnn`, where the input tensor
|
|
matches `max_sequence_length`, and using conditionals just slows
|
|
everything down.
|
|
Returns:
|
|
A tuple of (`final_output`, `final_state`) as given by the pseudocode above:
|
|
final_output is a `Tensor` matrix of shape [batch_size, output_size]
|
|
final_state is either a single `Tensor` matrix, or a tuple of such
|
|
matrices (matching length and shapes of input `state`).
|
|
Raises:
|
|
ValueError: If the cell returns a state tuple whose length does not match
|
|
that returned by `state_size`.
|
|
"""
|
|
|
|
# Convert state to a list for ease of use
|
|
|
|
flat_state = nest.flatten(state)
|
|
|
|
flat_zero_output = nest.flatten(zero_output)
|
|
|
|
def _copy_one_through(output, new_output):
|
|
|
|
# If the state contains a scalar value we simply pass it through.
|
|
|
|
if output.shape.ndims == 0:
|
|
return new_output
|
|
|
|
copy_cond = (time >= sequence_length)
|
|
|
|
with ops.colocate_with(new_output):
|
|
return array_ops.where(copy_cond, output, new_output)
|
|
|
|
def _copy_some_through(flat_new_output, flat_new_state):
|
|
|
|
# Use broadcasting select to determine which values should get
|
|
|
|
# the previous state & zero output, and which values should get
|
|
|
|
# a calculated state & output.
|
|
|
|
flat_new_output = [
|
|
|
|
_copy_one_through(zero_output, new_output)
|
|
|
|
for zero_output, new_output in zip(flat_zero_output, flat_new_output)]
|
|
|
|
flat_new_state = [
|
|
|
|
_copy_one_through(state, new_state)
|
|
|
|
for state, new_state in zip(flat_state, flat_new_state)]
|
|
|
|
return flat_new_output + flat_new_state
|
|
|
|
def _maybe_copy_some_through():
|
|
|
|
"""Run RNN step. Pass through either no or some past state."""
|
|
|
|
new_output, new_state = call_cell()
|
|
|
|
nest.assert_same_structure(state, new_state)
|
|
|
|
flat_new_state = nest.flatten(new_state)
|
|
|
|
flat_new_output = nest.flatten(new_output)
|
|
|
|
return control_flow_ops.cond(
|
|
|
|
# if t < min_seq_len: calculate and return everything
|
|
|
|
time < min_sequence_length, lambda: flat_new_output + flat_new_state,
|
|
|
|
# else copy some of it through
|
|
|
|
lambda: _copy_some_through(flat_new_output, flat_new_state))
|
|
|
|
# TODO(ebrevdo): skipping these conditionals may cause a slowdown,
|
|
|
|
# but benefits from removing cond() and its gradient. We should
|
|
|
|
# profile with and without this switch here.
|
|
|
|
if skip_conditionals:
|
|
|
|
# Instead of using conditionals, perform the selective copy at all time
|
|
|
|
# steps. This is faster when max_seq_len is equal to the number of unrolls
|
|
|
|
# (which is typical for dynamic_rnn).
|
|
|
|
new_output, new_state = call_cell()
|
|
|
|
nest.assert_same_structure(state, new_state)
|
|
|
|
new_state = nest.flatten(new_state)
|
|
|
|
new_output = nest.flatten(new_output)
|
|
|
|
final_output_and_state = _copy_some_through(new_output, new_state)
|
|
|
|
else:
|
|
|
|
empty_update = lambda: flat_zero_output + flat_state
|
|
|
|
final_output_and_state = control_flow_ops.cond(
|
|
|
|
# if t >= max_seq_len: copy all state through, output zeros
|
|
|
|
time >= max_sequence_length, empty_update,
|
|
|
|
# otherwise calculation is required: copy some or all of it through
|
|
|
|
_maybe_copy_some_through)
|
|
|
|
if len(final_output_and_state) != len(flat_zero_output) + len(flat_state):
|
|
raise ValueError("Internal error: state and output were not concatenated "
|
|
|
|
"correctly.")
|
|
|
|
final_output = final_output_and_state[:len(flat_zero_output)]
|
|
|
|
final_state = final_output_and_state[len(flat_zero_output):]
|
|
|
|
for output, flat_output in zip(final_output, flat_zero_output):
|
|
output.set_shape(flat_output.get_shape())
|
|
|
|
for substate, flat_substate in zip(final_state, flat_state):
|
|
substate.set_shape(flat_substate.get_shape())
|
|
|
|
final_output = nest.pack_sequence_as(
|
|
|
|
structure=zero_output, flat_sequence=final_output)
|
|
|
|
final_state = nest.pack_sequence_as(
|
|
|
|
structure=state, flat_sequence=final_state)
|
|
|
|
return final_output, final_state
|
|
|
|
|
|
def _reverse_seq(input_seq, lengths):
|
|
"""Reverse a list of Tensors up to specified lengths.
|
|
Args:
|
|
input_seq: Sequence of seq_len tensors of dimension (batch_size, n_features)
|
|
or nested tuples of tensors.
|
|
lengths: A `Tensor` of dimension batch_size, containing lengths for each
|
|
sequence in the batch. If "None" is specified, simply reverses
|
|
the list.
|
|
Returns:
|
|
time-reversed sequence
|
|
"""
|
|
|
|
if lengths is None:
|
|
return list(reversed(input_seq))
|
|
|
|
flat_input_seq = tuple(nest.flatten(input_) for input_ in input_seq)
|
|
|
|
flat_results = [[] for _ in range(len(input_seq))]
|
|
|
|
for sequence in zip(*flat_input_seq):
|
|
|
|
input_shape = tensor_shape.unknown_shape(
|
|
|
|
ndims=sequence[0].get_shape().ndims)
|
|
|
|
for input_ in sequence:
|
|
input_shape.merge_with(input_.get_shape())
|
|
|
|
input_.set_shape(input_shape)
|
|
|
|
# Join into (time, batch_size, depth)
|
|
|
|
s_joined = array_ops.stack(sequence)
|
|
|
|
# Reverse along dimension 0
|
|
|
|
s_reversed = array_ops.reverse_sequence(s_joined, lengths, 0, 1)
|
|
|
|
# Split again into list
|
|
|
|
result = array_ops.unstack(s_reversed)
|
|
|
|
for r, flat_result in zip(result, flat_results):
|
|
r.set_shape(input_shape)
|
|
|
|
flat_result.append(r)
|
|
|
|
results = [nest.pack_sequence_as(structure=input_, flat_sequence=flat_result)
|
|
|
|
for input_, flat_result in zip(input_seq, flat_results)]
|
|
|
|
return results
|
|
|
|
|
|
#
|
|
# def bidirectional_dynamic_rnn(cell_fw, cell_bw, inputs, sequence_length=None,
|
|
#
|
|
# initial_state_fw=None, initial_state_bw=None,
|
|
#
|
|
# dtype=None, parallel_iterations=None,
|
|
#
|
|
# swap_memory=False, time_major=False, scope=None):
|
|
#
|
|
# """Creates a dynamic version of bidirectional recurrent neural network.
|
|
#
|
|
#
|
|
#
|
|
# Takes input and builds independent forward and backward RNNs. The input_size
|
|
#
|
|
# of forward and backward cell must match. The initial state for both directions
|
|
#
|
|
# is zero by default (but can be set optionally) and no intermediate states are
|
|
#
|
|
# ever returned -- the network is fully unrolled for the given (passed in)
|
|
#
|
|
# length(s) of the sequence(s) or completely unrolled if length(s) is not
|
|
#
|
|
# given.
|
|
#
|
|
#
|
|
#
|
|
# Args:
|
|
#
|
|
# cell_fw: An instance of RNNCell, to be used for forward direction.
|
|
#
|
|
# cell_bw: An instance of RNNCell, to be used for backward direction.
|
|
#
|
|
# inputs: The RNN inputs.
|
|
#
|
|
# If time_major == False (default), this must be a tensor of shape:
|
|
#
|
|
# `[batch_size, max_time, ...]`, or a nested tuple of such elements.
|
|
#
|
|
# If time_major == True, this must be a tensor of shape:
|
|
#
|
|
# `[max_time, batch_size, ...]`, or a nested tuple of such elements.
|
|
#
|
|
# sequence_length: (optional) An int32/int64 vector, size `[batch_size]`,
|
|
#
|
|
# containing the actual lengths for each of the sequences in the batch.
|
|
#
|
|
# If not provided, all batch entries are assumed to be full sequences; and
|
|
#
|
|
# time reversal is applied from time `0` to `max_time` for each sequence.
|
|
#
|
|
# initial_state_fw: (optional) An initial state for the forward RNN.
|
|
#
|
|
# This must be a tensor of appropriate type and shape
|
|
#
|
|
# `[batch_size, cell_fw.state_size]`.
|
|
#
|
|
# If `cell_fw.state_size` is a tuple, this should be a tuple of
|
|
#
|
|
# tensors having shapes `[batch_size, s] for s in cell_fw.state_size`.
|
|
#
|
|
# initial_state_bw: (optional) Same as for `initial_state_fw`, but using
|
|
#
|
|
# the corresponding properties of `cell_bw`.
|
|
#
|
|
# dtype: (optional) The data type for the initial states and expected output.
|
|
#
|
|
# Required if initial_states are not provided or RNN states have a
|
|
#
|
|
# heterogeneous dtype.
|
|
#
|
|
# parallel_iterations: (Default: 32). The number of iterations to run in
|
|
#
|
|
# parallel. Those operations which do not have any temporal dependency
|
|
#
|
|
# and can be run in parallel, will be. This parameter trades off
|
|
#
|
|
# time for space. Values >> 1 use more memory but take less time,
|
|
#
|
|
# while smaller values use less memory but computations take longer.
|
|
#
|
|
# swap_memory: Transparently swap the tensors produced in forward inference
|
|
#
|
|
# but needed for back prop from GPU to CPU. This allows training RNNs
|
|
#
|
|
# which would typically not fit on a single GPU, with very minimal (or no)
|
|
#
|
|
# performance penalty.
|
|
#
|
|
# time_major: The shape format of the `inputs` and `outputs` Tensors.
|
|
#
|
|
# If true, these `Tensors` must be shaped `[max_time, batch_size, depth]`.
|
|
#
|
|
# If false, these `Tensors` must be shaped `[batch_size, max_time, depth]`.
|
|
#
|
|
# Using `time_major = True` is a bit more efficient because it avoids
|
|
#
|
|
# transposes at the beginning and end of the RNN calculation. However,
|
|
#
|
|
# most TensorFlow data is batch-major, so by default this function
|
|
#
|
|
# accepts input and emits output in batch-major form.
|
|
#
|
|
# scope: VariableScope for the created subgraph; defaults to
|
|
#
|
|
# "bidirectional_rnn"
|
|
#
|
|
#
|
|
#
|
|
# Returns:
|
|
#
|
|
# A tuple (outputs, output_states) where:
|
|
#
|
|
# outputs: A tuple (output_fw, output_bw) containing the forward and
|
|
#
|
|
# the backward rnn output `Tensor`.
|
|
#
|
|
# If time_major == False (default),
|
|
#
|
|
# output_fw will be a `Tensor` shaped:
|
|
#
|
|
# `[batch_size, max_time, cell_fw.output_size]`
|
|
#
|
|
# and output_bw will be a `Tensor` shaped:
|
|
#
|
|
# `[batch_size, max_time, cell_bw.output_size]`.
|
|
#
|
|
# If time_major == True,
|
|
#
|
|
# output_fw will be a `Tensor` shaped:
|
|
#
|
|
# `[max_time, batch_size, cell_fw.output_size]`
|
|
#
|
|
# and output_bw will be a `Tensor` shaped:
|
|
#
|
|
# `[max_time, batch_size, cell_bw.output_size]`.
|
|
#
|
|
# It returns a tuple instead of a single concatenated `Tensor`, unlike
|
|
#
|
|
# in the `bidirectional_rnn`. If the concatenated one is preferred,
|
|
#
|
|
# the forward and backward outputs can be concatenated as
|
|
#
|
|
# `tf.concat(outputs, 2)`.
|
|
#
|
|
# output_states: A tuple (output_state_fw, output_state_bw) containing
|
|
#
|
|
# the forward and the backward final states of bidirectional rnn.
|
|
#
|
|
#
|
|
#
|
|
# Raises:
|
|
#
|
|
# TypeError: If `cell_fw` or `cell_bw` is not an instance of `RNNCell`.
|
|
#
|
|
# """
|
|
#
|
|
#
|
|
#
|
|
# if not _like_rnncell(cell_fw):
|
|
#
|
|
# raise TypeError("cell_fw must be an instance of RNNCell")
|
|
#
|
|
# if not _like_rnncell(cell_bw):
|
|
#
|
|
# raise TypeError("cell_bw must be an instance of RNNCell")
|
|
#
|
|
#
|
|
#
|
|
# with vs.variable_scope(scope or "bidirectional_rnn"):
|
|
#
|
|
# # Forward direction
|
|
#
|
|
# with vs.variable_scope("fw") as fw_scope:
|
|
#
|
|
# output_fw, output_state_fw = dynamic_rnn(
|
|
#
|
|
# cell=cell_fw, inputs=inputs, sequence_length=sequence_length,
|
|
#
|
|
# initial_state=initial_state_fw, dtype=dtype,
|
|
#
|
|
# parallel_iterations=parallel_iterations, swap_memory=swap_memory,
|
|
#
|
|
# time_major=time_major, scope=fw_scope)
|
|
#
|
|
#
|
|
#
|
|
# # Backward direction
|
|
#
|
|
# if not time_major:
|
|
#
|
|
# time_dim = 1
|
|
#
|
|
# batch_dim = 0
|
|
#
|
|
# else:
|
|
#
|
|
# time_dim = 0
|
|
#
|
|
# batch_dim = 1
|
|
#
|
|
#
|
|
#
|
|
# def _reverse(input_, seq_lengths, seq_dim, batch_dim):
|
|
#
|
|
# if seq_lengths is not None:
|
|
#
|
|
# return array_ops.reverse_sequence(
|
|
#
|
|
# input=input_, seq_lengths=seq_lengths,
|
|
#
|
|
# seq_dim=seq_dim, batch_dim=batch_dim)
|
|
#
|
|
# else:
|
|
#
|
|
# return array_ops.reverse(input_, axis=[seq_dim])
|
|
#
|
|
#
|
|
#
|
|
# with vs.variable_scope("bw") as bw_scope:
|
|
#
|
|
# inputs_reverse = _reverse(
|
|
#
|
|
# inputs, seq_lengths=sequence_length,
|
|
#
|
|
# seq_dim=time_dim, batch_dim=batch_dim)
|
|
#
|
|
# tmp, output_state_bw = dynamic_rnn(
|
|
#
|
|
# cell=cell_bw, inputs=inputs_reverse, sequence_length=sequence_length,
|
|
#
|
|
# initial_state=initial_state_bw, dtype=dtype,
|
|
#
|
|
# parallel_iterations=parallel_iterations, swap_memory=swap_memory,
|
|
#
|
|
# time_major=time_major, scope=bw_scope)
|
|
#
|
|
#
|
|
#
|
|
# output_bw = _reverse(
|
|
#
|
|
# tmp, seq_lengths=sequence_length,
|
|
#
|
|
# seq_dim=time_dim, batch_dim=batch_dim)
|
|
#
|
|
#
|
|
#
|
|
# outputs = (output_fw, output_bw)
|
|
#
|
|
# output_states = (output_state_fw, output_state_bw)
|
|
#
|
|
#
|
|
#
|
|
# return (outputs, output_states)
|
|
#
|
|
|
|
|
|
def dynamic_rnn(cell, inputs, att_scores=None, sequence_length=None, initial_state=None,
|
|
|
|
dtype=None, parallel_iterations=None, swap_memory=False,
|
|
|
|
time_major=False, scope=None):
|
|
"""Creates a recurrent neural network specified by RNNCell `cell`.
|
|
Performs fully dynamic unrolling of `inputs`.
|
|
Example:
|
|
```python
|
|
# create a BasicRNNCell
|
|
rnn_cell = tf.nn.rnn_cell.BasicRNNCell(hidden_size)
|
|
# 'outputs' is a tensor of shape [batch_size, max_time, cell_state_size]
|
|
# defining initial state
|
|
initial_state = rnn_cell.zero_state(batch_size, dtype=tf.float32)
|
|
# 'state' is a tensor of shape [batch_size, cell_state_size]
|
|
outputs, state = tf.nn.dynamic_rnn(rnn_cell, input_data,
|
|
initial_state=initial_state,
|
|
dtype=tf.float32)
|
|
```
|
|
```python
|
|
# create 2 LSTMCells
|
|
rnn_layers = [tf.nn.rnn_cell.LSTMCell(size) for size in [128, 256]]
|
|
# create a RNN cell composed sequentially of a number of RNNCells
|
|
multi_rnn_cell = tf.nn.rnn_cell.MultiRNNCell(rnn_layers)
|
|
# 'outputs' is a tensor of shape [batch_size, max_time, 256]
|
|
# 'state' is a N-tuple where N is the number of LSTMCells containing a
|
|
# tf.contrib.rnn.LSTMStateTuple for each cell
|
|
outputs, state = tf.nn.dynamic_rnn(cell=multi_rnn_cell,
|
|
inputs=data,
|
|
dtype=tf.float32)
|
|
```
|
|
Args:
|
|
cell: An instance of RNNCell.
|
|
inputs: The RNN inputs.
|
|
If `time_major == False` (default), this must be a `Tensor` of shape:
|
|
`[batch_size, max_time, ...]`, or a nested tuple of such
|
|
elements.
|
|
If `time_major == True`, this must be a `Tensor` of shape:
|
|
`[max_time, batch_size, ...]`, or a nested tuple of such
|
|
elements.
|
|
This may also be a (possibly nested) tuple of Tensors satisfying
|
|
this property. The first two dimensions must match across all the inputs,
|
|
but otherwise the ranks and other shape components may differ.
|
|
In this case, input to `cell` at each time-step will replicate the
|
|
structure of these tuples, except for the time dimension (from which the
|
|
time is taken).
|
|
The input to `cell` at each time step will be a `Tensor` or (possibly
|
|
nested) tuple of Tensors each with dimensions `[batch_size, ...]`.
|
|
sequence_length: (optional) An int32/int64 vector sized `[batch_size]`.
|
|
Used to copy-through state and zero-out outputs when past a batch
|
|
element's sequence length. So it's more for correctness than performance.
|
|
initial_state: (optional) An initial state for the RNN.
|
|
If `cell.state_size` is an integer, this must be
|
|
a `Tensor` of appropriate type and shape `[batch_size, cell.state_size]`.
|
|
If `cell.state_size` is a tuple, this should be a tuple of
|
|
tensors having shapes `[batch_size, s] for s in cell.state_size`.
|
|
dtype: (optional) The data type for the initial state and expected output.
|
|
Required if initial_state is not provided or RNN state has a heterogeneous
|
|
dtype.
|
|
parallel_iterations: (Default: 32). The number of iterations to run in
|
|
parallel. Those operations which do not have any temporal dependency
|
|
and can be run in parallel, will be. This parameter trades off
|
|
time for space. Values >> 1 use more memory but take less time,
|
|
while smaller values use less memory but computations take longer.
|
|
swap_memory: Transparently swap the tensors produced in forward inference
|
|
but needed for back prop from GPU to CPU. This allows training RNNs
|
|
which would typically not fit on a single GPU, with very minimal (or no)
|
|
performance penalty.
|
|
time_major: The shape format of the `inputs` and `outputs` Tensors.
|
|
If true, these `Tensors` must be shaped `[max_time, batch_size, depth]`.
|
|
If false, these `Tensors` must be shaped `[batch_size, max_time, depth]`.
|
|
Using `time_major = True` is a bit more efficient because it avoids
|
|
transposes at the beginning and end of the RNN calculation. However,
|
|
most TensorFlow data is batch-major, so by default this function
|
|
accepts input and emits output in batch-major form.
|
|
scope: VariableScope for the created subgraph; defaults to "rnn".
|
|
Returns:
|
|
A pair (outputs, state) where:
|
|
outputs: The RNN output `Tensor`.
|
|
If time_major == False (default), this will be a `Tensor` shaped:
|
|
`[batch_size, max_time, cell.output_size]`.
|
|
If time_major == True, this will be a `Tensor` shaped:
|
|
`[max_time, batch_size, cell.output_size]`.
|
|
Note, if `cell.output_size` is a (possibly nested) tuple of integers
|
|
or `TensorShape` objects, then `outputs` will be a tuple having the
|
|
same structure as `cell.output_size`, containing Tensors having shapes
|
|
corresponding to the shape data in `cell.output_size`.
|
|
state: The final state. If `cell.state_size` is an int, this
|
|
will be shaped `[batch_size, cell.state_size]`. If it is a
|
|
`TensorShape`, this will be shaped `[batch_size] + cell.state_size`.
|
|
If it is a (possibly nested) tuple of ints or `TensorShape`, this will
|
|
be a tuple having the corresponding shapes. If cells are `LSTMCells`
|
|
`state` will be a tuple containing a `LSTMStateTuple` for each cell.
|
|
Raises:
|
|
TypeError: If `cell` is not an instance of RNNCell.
|
|
ValueError: If inputs is None or an empty list.
|
|
"""
|
|
|
|
if not _like_rnncell(cell):
|
|
raise TypeError("cell must be an instance of RNNCell")
|
|
|
|
# By default, time_major==False and inputs are batch-major: shaped
|
|
|
|
# [batch, time, depth]
|
|
|
|
# For internal calculations, we transpose to [time, batch, depth]
|
|
|
|
flat_input = nest.flatten(inputs)
|
|
|
|
if not time_major:
|
|
# (B,T,D) => (T,B,D)
|
|
|
|
flat_input = [ops.convert_to_tensor(input_) for input_ in flat_input]
|
|
|
|
flat_input = tuple(_transpose_batch_time(input_) for input_ in flat_input)
|
|
|
|
parallel_iterations = parallel_iterations or 32
|
|
|
|
if sequence_length is not None:
|
|
|
|
sequence_length = math_ops.to_int32(sequence_length)
|
|
|
|
if sequence_length.get_shape().ndims not in (None, 1):
|
|
raise ValueError(
|
|
|
|
"sequence_length must be a vector of length batch_size, "
|
|
|
|
"but saw shape: %s" % sequence_length.get_shape())
|
|
|
|
sequence_length = array_ops.identity( # Just to find it in the graph.
|
|
|
|
sequence_length, name="sequence_length")
|
|
|
|
# Create a new scope in which the caching device is either
|
|
|
|
# determined by the parent scope, or is set to place the cached
|
|
|
|
# Variable using the same placement as for the rest of the RNN.
|
|
|
|
with vs.variable_scope(scope or "rnn",reuse=tf.AUTO_REUSE) as varscope:#TODO:user defined reuse
|
|
|
|
if varscope.caching_device is None:
|
|
varscope.set_caching_device(lambda op: op.device)
|
|
|
|
batch_size = _best_effort_input_batch_size(flat_input)
|
|
|
|
if initial_state is not None:
|
|
|
|
state = initial_state
|
|
|
|
else:
|
|
|
|
if not dtype:
|
|
raise ValueError("If there is no initial_state, you must give a dtype.")
|
|
|
|
state = cell.zero_state(batch_size, dtype)
|
|
|
|
def _assert_has_shape(x, shape):
|
|
|
|
x_shape = array_ops.shape(x)
|
|
|
|
packed_shape = array_ops.stack(shape)
|
|
|
|
return control_flow_ops.Assert(
|
|
|
|
math_ops.reduce_all(math_ops.equal(x_shape, packed_shape)),
|
|
|
|
["Expected shape for Tensor %s is " % x.name,
|
|
|
|
packed_shape, " but saw shape: ", x_shape])
|
|
|
|
if sequence_length is not None:
|
|
# Perform some shape validation
|
|
|
|
with ops.control_dependencies(
|
|
|
|
[_assert_has_shape(sequence_length, [batch_size])]):
|
|
sequence_length = array_ops.identity(
|
|
|
|
sequence_length, name="CheckSeqLen")
|
|
|
|
inputs = nest.pack_sequence_as(structure=inputs, flat_sequence=flat_input)
|
|
|
|
(outputs, final_state) = _dynamic_rnn_loop(
|
|
|
|
cell,
|
|
|
|
inputs,
|
|
|
|
state,
|
|
|
|
parallel_iterations=parallel_iterations,
|
|
|
|
swap_memory=swap_memory,
|
|
|
|
att_scores=att_scores,
|
|
|
|
sequence_length=sequence_length,
|
|
|
|
dtype=dtype)
|
|
|
|
# Outputs of _dynamic_rnn_loop are always shaped [time, batch, depth].
|
|
|
|
# If we are performing batch-major calculations, transpose output back
|
|
|
|
# to shape [batch, time, depth]
|
|
|
|
if not time_major:
|
|
# (T,B,D) => (B,T,D)
|
|
|
|
outputs = nest.map_structure(_transpose_batch_time, outputs)
|
|
|
|
return (outputs, final_state)
|
|
|
|
|
|
def _dynamic_rnn_loop(cell,
|
|
|
|
inputs,
|
|
|
|
initial_state,
|
|
|
|
parallel_iterations,
|
|
|
|
swap_memory,
|
|
|
|
att_scores=None,
|
|
|
|
sequence_length=None,
|
|
|
|
dtype=None):
|
|
"""Internal implementation of Dynamic RNN.
|
|
Args:
|
|
cell: An instance of RNNCell.
|
|
inputs: A `Tensor` of shape [time, batch_size, input_size], or a nested
|
|
tuple of such elements.
|
|
initial_state: A `Tensor` of shape `[batch_size, state_size]`, or if
|
|
`cell.state_size` is a tuple, then this should be a tuple of
|
|
tensors having shapes `[batch_size, s] for s in cell.state_size`.
|
|
parallel_iterations: Positive Python int.
|
|
swap_memory: A Python boolean
|
|
sequence_length: (optional) An `int32` `Tensor` of shape [batch_size].
|
|
dtype: (optional) Expected dtype of output. If not specified, inferred from
|
|
initial_state.
|
|
Returns:
|
|
Tuple `(final_outputs, final_state)`.
|
|
final_outputs:
|
|
A `Tensor` of shape `[time, batch_size, cell.output_size]`. If
|
|
`cell.output_size` is a (possibly nested) tuple of ints or `TensorShape`
|
|
objects, then this returns a (possibly nsted) tuple of Tensors matching
|
|
the corresponding shapes.
|
|
final_state:
|
|
A `Tensor`, or possibly nested tuple of Tensors, matching in length
|
|
and shapes to `initial_state`.
|
|
Raises:
|
|
ValueError: If the input depth cannot be inferred via shape inference
|
|
from the inputs.
|
|
"""
|
|
|
|
state = initial_state
|
|
|
|
assert isinstance(parallel_iterations, int), "parallel_iterations must be int"
|
|
|
|
state_size = cell.state_size
|
|
|
|
flat_input = nest.flatten(inputs)
|
|
|
|
flat_output_size = nest.flatten(cell.output_size)
|
|
|
|
# Construct an initial output
|
|
|
|
input_shape = array_ops.shape(flat_input[0])
|
|
|
|
time_steps = input_shape[0]
|
|
|
|
batch_size = _best_effort_input_batch_size(flat_input)
|
|
|
|
inputs_got_shape = tuple(input_.get_shape().with_rank_at_least(3)
|
|
|
|
for input_ in flat_input)
|
|
|
|
const_time_steps, const_batch_size = inputs_got_shape[0].as_list()[:2]
|
|
|
|
for shape in inputs_got_shape:
|
|
|
|
if not shape[2:].is_fully_defined():
|
|
raise ValueError(
|
|
|
|
"Input size (depth of inputs) must be accessible via shape inference,"
|
|
|
|
" but saw value None.")
|
|
|
|
got_time_steps = shape[0].value
|
|
|
|
got_batch_size = shape[1].value
|
|
|
|
if const_time_steps != got_time_steps:
|
|
raise ValueError(
|
|
|
|
"Time steps is not the same for all the elements in the input in a "
|
|
|
|
"batch.")
|
|
|
|
if const_batch_size != got_batch_size:
|
|
raise ValueError(
|
|
|
|
"Batch_size is not the same for all the elements in the input.")
|
|
|
|
# Prepare dynamic conditional copying of state & output
|
|
|
|
def _create_zero_arrays(size):
|
|
|
|
size = _concat(batch_size, size)
|
|
|
|
return array_ops.zeros(
|
|
|
|
array_ops.stack(size), _infer_state_dtype(dtype, state))
|
|
|
|
flat_zero_output = tuple(_create_zero_arrays(output)
|
|
|
|
for output in flat_output_size)
|
|
|
|
zero_output = nest.pack_sequence_as(structure=cell.output_size,
|
|
|
|
flat_sequence=flat_zero_output)
|
|
|
|
if sequence_length is not None:
|
|
min_sequence_length = math_ops.reduce_min(sequence_length)
|
|
|
|
max_sequence_length = math_ops.reduce_max(sequence_length)
|
|
|
|
time = array_ops.constant(0, dtype=dtypes.int32, name="time")
|
|
|
|
with ops.name_scope("dynamic_rnn") as scope:
|
|
|
|
base_name = scope
|
|
|
|
def _create_ta(name, dtype):
|
|
|
|
return tensor_array_ops.TensorArray(dtype=dtype,
|
|
|
|
size=time_steps,
|
|
|
|
tensor_array_name=base_name + name)
|
|
|
|
output_ta = tuple(_create_ta("output_%d" % i,
|
|
|
|
_infer_state_dtype(dtype, state))
|
|
|
|
for i in range(len(flat_output_size)))
|
|
|
|
input_ta = tuple(_create_ta("input_%d" % i, flat_input[i].dtype)
|
|
|
|
for i in range(len(flat_input)))
|
|
|
|
input_ta = tuple(ta.unstack(input_)
|
|
|
|
for ta, input_ in zip(input_ta, flat_input))
|
|
|
|
def _time_step(time, output_ta_t, state, att_scores=None):
|
|
|
|
"""Take a time step of the dynamic RNN.
|
|
Args:
|
|
time: int32 scalar Tensor.
|
|
output_ta_t: List of `TensorArray`s that represent the output.
|
|
state: nested tuple of vector tensors that represent the state.
|
|
Returns:
|
|
The tuple (time + 1, output_ta_t with updated flow, new_state).
|
|
"""
|
|
|
|
input_t = tuple(ta.read(time) for ta in input_ta)
|
|
|
|
# Restore some shape information
|
|
|
|
for input_, shape in zip(input_t, inputs_got_shape):
|
|
input_.set_shape(shape[1:])
|
|
|
|
input_t = nest.pack_sequence_as(structure=inputs, flat_sequence=input_t)
|
|
|
|
if att_scores is not None:
|
|
|
|
att_score = att_scores[:, time, :]
|
|
|
|
call_cell = lambda: cell(input_t, state, att_score)
|
|
|
|
else:
|
|
|
|
call_cell = lambda: cell(input_t, state)
|
|
|
|
if sequence_length is not None:
|
|
|
|
(output, new_state) = _rnn_step(
|
|
|
|
time=time,
|
|
|
|
sequence_length=sequence_length,
|
|
|
|
min_sequence_length=min_sequence_length,
|
|
|
|
max_sequence_length=max_sequence_length,
|
|
|
|
zero_output=zero_output,
|
|
|
|
state=state,
|
|
|
|
call_cell=call_cell,
|
|
|
|
state_size=state_size,
|
|
|
|
skip_conditionals=True)
|
|
|
|
else:
|
|
|
|
(output, new_state) = call_cell()
|
|
|
|
# Pack state if using state tuples
|
|
|
|
output = nest.flatten(output)
|
|
|
|
output_ta_t = tuple(
|
|
|
|
ta.write(time, out) for ta, out in zip(output_ta_t, output))
|
|
|
|
if att_scores is not None:
|
|
|
|
return (time + 1, output_ta_t, new_state, att_scores)
|
|
|
|
else:
|
|
|
|
return (time + 1, output_ta_t, new_state)
|
|
|
|
if att_scores is not None:
|
|
|
|
_, output_final_ta, final_state, _ = control_flow_ops.while_loop(
|
|
|
|
cond=lambda time, *_: time < time_steps,
|
|
|
|
body=_time_step,
|
|
|
|
loop_vars=(time, output_ta, state, att_scores),
|
|
|
|
parallel_iterations=parallel_iterations,
|
|
|
|
swap_memory=swap_memory)
|
|
|
|
else:
|
|
|
|
_, output_final_ta, final_state = control_flow_ops.while_loop(
|
|
|
|
cond=lambda time, *_: time < time_steps,
|
|
|
|
body=_time_step,
|
|
|
|
loop_vars=(time, output_ta, state),
|
|
|
|
parallel_iterations=parallel_iterations,
|
|
|
|
swap_memory=swap_memory)
|
|
|
|
# Unpack final output if not using output tuples.
|
|
|
|
final_outputs = tuple(ta.stack() for ta in output_final_ta)
|
|
|
|
# Restore some shape information
|
|
|
|
for output, output_size in zip(final_outputs, flat_output_size):
|
|
shape = _concat(
|
|
|
|
[const_time_steps, const_batch_size], output_size, static=True)
|
|
|
|
output.set_shape(shape)
|
|
|
|
final_outputs = nest.pack_sequence_as(
|
|
|
|
structure=cell.output_size, flat_sequence=final_outputs)
|
|
|
|
return (final_outputs, final_state) |