我们从 Python 开源项目中提取了以下 50 个代码示例，用于说明如何使用 tensorflow.python.util.nest.flatten()。
def __init__(self, n_hiddens, hidden_transfer=default_activation, n_out=None,
             transfer=None, initializers=default_init):
    """Initialises the MLP.

    :param n_hiddens: int or an iterable of ints, number of hidden units in
        layers.
    :param hidden_transfer: callable or iterable; a transfer function for
        hidden layers or an iterable thereof. A single transfer function is
        reused for every hidden layer; otherwise the number of transfer
        functions must equal the number of hidden layers in `n_hiddens`.
    :param n_out: int or None, number of output units.
    :param transfer: callable or None, a transfer function for the output.
    :param initializers: stored unchanged as `self._initializers`.
    :raises ValueError: if the number of hidden transfer functions is neither
        one nor equal to the number of hidden layers.
    """
    super(MLP, self).__init__(self.__class__.__name__)
    self._n_hiddens = nest.flatten(n_hiddens)

    transfers = nest.flatten(hidden_transfer)
    n_layers = len(self._n_hiddens)
    if len(transfers) == 1:
        # Broadcast a single transfer function to every hidden layer.
        transfers = transfers * n_layers
    elif len(transfers) != n_layers:
        # Explicit check rather than `assert`, which is stripped under -O;
        # this also rejects an empty iterable instead of silently keeping [].
        raise ValueError(
            'Expected 1 or {} hidden transfer functions, got {}.'.format(
                n_layers, len(transfers)))
    self._hidden_transfers = transfers

    self._n_out = n_out
    self._transfer = transfer
    self._initializers = initializers
def _build(self):
    """Connects the module to the graph.

    Every leaf of `self._initial_state` is turned into a learnable state with
    `_single_learnable_state`. When a mask was supplied, the matching mask
    leaf is forwarded as the `learnable` argument.

    Returns:
      The learnable state, which has the same type, structure and shape as
      the `initial_state` passed to the constructor.
    """
    flat_initial_state = nest.flatten(self._initial_state)

    if self._mask is None:
        flat_learnable_state = [
            _single_learnable_state(leaf, state_id=idx)
            for idx, leaf in enumerate(flat_initial_state)
        ]
    else:
        flat_mask = nest.flatten(self._mask)
        flat_learnable_state = [
            _single_learnable_state(leaf, state_id=idx, learnable=is_learnable)
            for idx, (leaf, is_learnable) in enumerate(
                zip(flat_initial_state, flat_mask))
        ]

    return nest.pack_sequence_as(structure=self._initial_state,
                                 flat_sequence=flat_learnable_state)
def testRegularizers(self, trainable, state_size):
    """Checks that trainable initial states register one L1 regularizer each."""
    batch_size = 6

    # Set the attribute on the class itself, since we can't set properties
    # of abstract classes.
    snt.RNNCore.state_size = state_size
    flat_state_size = nest.flatten(state_size)
    num_states = len(flat_state_size)
    core = snt.RNNCore(name="dummy_core")

    # The same L1 regularizer object for every state leaf, packed into the
    # structure of `state_size`.
    l1_regularizer = tf.contrib.layers.l1_regularizer(scale=0.5)
    trainable_regularizers = nest.pack_sequence_as(
        structure=state_size, flat_sequence=[l1_regularizer] * num_states)

    core.initial_state(batch_size, dtype=tf.float32, trainable=trainable,
                       trainable_regularizers=trainable_regularizers)

    graph_regularizers = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)
    if not trainable:
        self.assertFalse(graph_regularizers)
        return

    for index in range(num_states):
        self.assertRegexpMatches(graph_regularizers[index].name,
                                 ".*l1_regularizer.*")
def _get_flat_core_sizes(cores): """Obtains the list flattened output sizes of a list of cores. Args: cores: list of cores to get the shapes from. Returns: List of lists that, for each core, contains the list of its output dimensions. """ core_sizes_lists = [] for core in cores: flat_output_size = nest.flatten(core.output_size) core_sizes_lists.append([tensor_shape.as_shape(size).as_list() for size in flat_output_size]) return core_sizes_lists
def _create(self):
    """Builds the decoder's initial state from the bridge inputs.

    Every bridge input is reshaped to `[self.batch_size, total_depth]`, all of
    them are concatenated on axis 1, and the result goes through a fully
    connected layer whose width is the total flat decoder state size.

    Returns:
      The projected state, split and packed into the structure of
      `self.decoder_state_size`.
    """
    def _to_batch_by_depth(tensor):
        return tf.reshape(tensor,
                          [self.batch_size, _total_tensor_depth(tensor)])

    # Concat bridge inputs on the depth dimension.
    reshaped_inputs = nest.map_structure(_to_batch_by_depth,
                                         self._bridge_input)
    bridge_input_concat = tf.concat(nest.flatten([reshaped_inputs]), axis=1)

    state_size_splits = nest.flatten(self.decoder_state_size)

    # Pass the bridge inputs through a single fully connected layer.
    initial_state_flat = tf.contrib.layers.fully_connected(
        bridge_input_concat,
        num_outputs=sum(state_size_splits),
        activation_fn=self._activation_fn,
        weights_initializer=tf.truncated_normal_initializer(
            stddev=self.parameter_init),
        biases_initializer=tf.zeros_initializer(),
        scope=None)

    # Shape back into the required state size.
    pieces = tf.split(initial_state_flat, state_size_splits, axis=1)
    return nest.pack_sequence_as(self.decoder_state_size, pieces)
def __init__(self, inpt, n_hidden, n_output, transfer_hidden=tf.nn.elu,
             transfer=None, hidden_weight_init=None, hidden_bias_init=None,
             weight_init=None, bias_init=None, name=None):
    """
    :param inpt: input tensor
    :param n_hidden: scalar or list, number of hidden units
    :param n_output: scalar, number of output units
    :param transfer_hidden: scalar or list, transfers for hidden units. If
        list, len must be == len(n_hidden); a single transfer is reused for
        every hidden layer.
    :param transfer: tf.Op or None
    :param hidden_weight_init: stored as `self.hidden_weight_init`
    :param hidden_bias_init: stored as `self.hidden_bias_init`
    :param weight_init: forwarded to the base-class constructor
    :param bias_init: forwarded to the base-class constructor
    :param name: forwarded to the base-class constructor
    :raises ValueError: if `transfer_hidden` has neither one element nor one
        element per hidden layer.
    """
    self.n_hidden = nest.flatten(n_hidden)
    self.n_output = n_output
    self.hidden_weight_init = hidden_weight_init
    self.hidden_bias_init = hidden_bias_init

    transfer_hidden = nest.flatten(transfer_hidden)
    if len(transfer_hidden) == 1:
        transfer_hidden *= len(self.n_hidden)
    elif len(transfer_hidden) != len(self.n_hidden):
        # Enforce the documented contract instead of failing later (or
        # silently dropping layers) when the counts disagree.
        raise ValueError(
            'Expected 1 or {} hidden transfers, got {}.'.format(
                len(self.n_hidden), len(transfer_hidden)))
    self.transfer_hidden = transfer_hidden
    self.transfer = transfer

    # The base constructor is called last. NOTE(review): it presumably builds
    # the layer from the attributes set above; confirm in the base class.
    super(MLP, self).__init__(inpt, name, weight_init, bias_init)
def _zero_state(self, img, att, presence, state, transform_features,
                transform_state=False):
    """Runs the RNN once on extracted image features to build a new state.

    `presence` is accepted but not used by this method.

    Args:
      img, att: passed to `self.extract_features`; element [1] of its result
        is used as the features.
      state: state passed to `self._propagate`; its structure is also used
        to repack the resulting hidden state.
      transform_features: if true, the features are reshaped to rows of width
        `self.n_units`, passed through an `AffineLayer`, and reshaped back.
      transform_state: if true, every flattened hidden-state component goes
        through its own `AffineLayer`.

    Returns:
      `(state, rnn_outputs)`. The variable scope used is stored in
      `self.rnn_vs`.
    """
    with tf.variable_scope(self.__class__.__name__) as scope:
        features = self.extract_features(img, att)[1]

        if transform_features:
            flat = tf.reshape(features, (-1, self.n_units))
            flat = AffineLayer(flat, self.n_units,
                               name='init_feature_transform').output
            features = tf.reshape(flat, tf.shape(features))

        rnn_outputs, hidden_state = self._propagate(features, state)

        flat_hidden = nest.flatten(hidden_state)
        if transform_state:
            flat_hidden = [
                AffineLayer(component, self.n_units,
                            name='init_state_transform_{}'.format(idx)).output
                for idx, component in enumerate(flat_hidden)
            ]

        state = nest.pack_sequence_as(structure=state,
                                      flat_sequence=flat_hidden)
        self.rnn_vs = scope
    return state, rnn_outputs
def _create(self):
    """Maps the bridge inputs to the decoder's initial state.

    Each bridge input is reshaped to `[self.batch_size, total_depth]`, all of
    them are concatenated on axis 1, and the result is projected by
    `fully_connected` to the total flat decoder state size.

    Returns:
      The projection split and packed into the structure of
      `self.decoder_state_size`.
    """
    def _flatten_depth(tensor):
        return tf.reshape(tensor,
                          [self.batch_size, _total_tensor_depth(tensor)])

    reshaped = nest.map_structure(_flatten_depth, self._bridge_input)
    concatenated = tf.concat(nest.flatten([reshaped]), 1)

    split_sizes = nest.flatten(self.decoder_state_size)
    projected = fully_connected(concatenated, sum(split_sizes), self._mode,
                                self._reuse, activation=self._activation_fn)

    pieces = tf.split(projected, split_sizes, axis=1)
    return nest.pack_sequence_as(self.decoder_state_size, pieces)
def call(self, inputs, state):
    """Long short-term memory cell with attention (LSTMA).

    Args:
      inputs: input tensor for this step. When `self._input_size` is None its
        static dimension 1 is used as the projection size (NOTE(review):
        presumably 2-D (batch, features); confirm).
      state: 3-tuple `(cell_state, attns, attn_states)`. `attn_states` is
        reshaped here to `[-1, self._attn_length, self._attn_size]`.

    Returns:
      `(output, new_state)`, where `new_state` is the 3-tuple
      `(cell_state, attns, attn_states)` and `attn_states` is flattened to
      `[-1, self._attn_length * self._attn_size]`.
    """
    state, attns, attn_states = state
    attn_states = array_ops.reshape(attn_states, [-1, self._attn_length, self._attn_size])
    input_size = self._input_size
    if input_size is None:
        input_size = inputs.get_shape().as_list()[1]
    # Mix the previous attention read into the cell input.
    inputs = _linear([inputs, attns], input_size, True)
    lstm_output, new_state = self._cell(inputs, state)
    # The attention query is the whole (possibly nested) cell state,
    # concatenated along axis 1.
    new_state_cat = array_ops.concat(nest.flatten(new_state), 1)
    new_attns, new_attn_states = self._attention(new_state_cat, attn_states)
    with tf.variable_scope("attn_output_projection"):
        output = _linear([lstm_output, new_attns], self._attn_size, True)
    # Append this step's output to the attention states along axis 1.
    new_attn_states = array_ops.concat(
        [new_attn_states, array_ops.expand_dims(output, 1)], 1)
    new_attn_states = array_ops.reshape(
        new_attn_states, [-1, self._attn_length * self._attn_size])
    new_state = (new_state, new_attns, new_attn_states)
    return output, new_state
def trainable_initial_state(batch_size, state_size, initializer=None,
                            name="initial_state"):
    """Creates a trainable initial RNN state, tiled across the batch.

    One variable of shape `[1, size]` is created for each leaf of
    `state_size`, then tiled to `[batch_size, size]`.

    Args:
      batch_size: number of rows each state variable is tiled to.
      state_size: int or nested structure of ints, the RNN state size.
      initializer: optional initializer factory. It is called with no
        arguments to obtain the initializer passed to `tf.get_variable`.
        Either a single factory, reused for every state component, or a
        nested structure with one factory per leaf of `state_size`.
        Defaults to `tf.zeros_initializer`.
      name: prefix for the variable names (`"{name}_{i}"`).

    Returns:
      The tiled state tensors packed into the structure of `state_size`.

    Raises:
      ValueError: if `initializer` has neither one leaf nor one leaf per
        leaf of `state_size`.
    """
    flat_state_size = nest.flatten(state_size)

    if not initializer:
        flat_initializer = (tf.zeros_initializer,) * len(flat_state_size)
    else:
        # Honour the caller's initializer(s); a single one is broadcast.
        flat_initializer = tuple(nest.flatten(initializer))
        if len(flat_initializer) == 1:
            flat_initializer *= len(flat_state_size)
        elif len(flat_initializer) != len(flat_state_size):
            raise ValueError(
                "Expected 1 or {} initializers, got {}.".format(
                    len(flat_state_size), len(flat_initializer)))

    tiled_states = []
    for i, (size, init) in enumerate(zip(flat_state_size, flat_initializer)):
        # `var_name` avoids shadowing the `name` prefix parameter.
        var_name = "{}_{}".format(name, i)
        initial_state_variable = tf.get_variable(
            var_name, shape=[1, size], initializer=init())
        tiled_states.append(tf.tile(initial_state_variable, [batch_size, 1],
                                    name=var_name + "_tiled"))
    return nest.pack_sequence_as(structure=state_size,
                                 flat_sequence=tiled_states)
def state_tuple_to_dict(state):
    """Returns a dict containing flattened `state`.

    Args:
      state: A `Tensor` or a nested tuple of `Tensors`. All of the `Tensor`s
        must have the same rank and agree on all dimensions except the last.

    Returns:
      A dict containing the `Tensor`s that make up `state`. The keys are the
      names from `_get_state_name(i)` (of the form "STATE_PREFIX_i"), where
      `i` is the place of the `Tensor` in a depth-first traversal of `state`.
      Each value is an identity op named after its key, or None when the
      component is None.
    """
    with ops.name_scope('state_tuple_to_dict'):
        result = {}
        for position, component in enumerate(nest.flatten(state)):
            key = _get_state_name(position)
            if component is None:
                result[key] = None
            else:
                result[key] = array_ops.identity(component, name=key)
        return result
def state_tuple_to_dict(state):
    """Returns a dict containing flattened `state`.

    Args:
      state: A `Tensor` or a nested tuple of `Tensors`. All of the `Tensor`s
        must have the same rank and agree on all dimensions except the last.

    Returns:
      A dict mapping `_get_state_name(i)` (of the form "STATE_PREFIX_i") to
      an identity of the i-th `Tensor` of a depth-first traversal of `state`,
      or to None when that component is None.
    """
    with ops.name_scope('state_tuple_to_dict'):
        state_dict = {}
        for index, tensor in enumerate(nest.flatten(state)):
            key = _get_state_name(index)
            state_dict[key] = (array_ops.identity(tensor, name=key)
                               if tensor is not None else None)
        return state_dict
def encode(self, inputs, input_length, _parses):
    """Encodes `inputs` with a multi-layer bidirectional RNN.

    Args:
      inputs: input sequence batch passed to `bidirectional_dynamic_rnn`.
      input_length: sequence lengths passed to `bidirectional_dynamic_rnn`.
      _parses: unused.

    Returns:
      `(outputs, output_state)`. `outputs` holds the forward and backward
      outputs concatenated on axis 2. `output_state` has the structure of the
      forward final state, and each leaf is the forward and backward leaf
      concatenated on axis 1.
    """
    with tf.name_scope('BiLSTMEncoder'):
        def _stacked_cell():
            return tf.contrib.rnn.MultiRNNCell(
                [self._make_rnn_cell(layer)
                 for layer in range(self._num_layers)])

        fw_cell_enc = _stacked_cell()
        bw_cell_enc = _stacked_cell()
        outputs, (fw_state, bw_state) = tf.nn.bidirectional_dynamic_rnn(
            fw_cell_enc, bw_cell_enc, inputs, input_length, dtype=tf.float32)

        # Concat each element of the final state so the result is compatible
        # with a unidirectional decoder.
        merged_leaves = [
            tf.concat((fw_leaf, bw_leaf), axis=1)
            for fw_leaf, bw_leaf in zip(nest.flatten(fw_state),
                                        nest.flatten(bw_state))
        ]
        output_state = nest.pack_sequence_as(fw_state, merged_leaves)
        return tf.concat(outputs, axis=2), output_state
def __init__(self, wrapped: tf.contrib.rnn.RNNCell, parent_state):
    """Wraps `wrapped` and keeps `parent_state` as a single tensor.

    Every leaf of `parent_state` is concatenated along axis 1 and stored as
    `self._flat_parent_state`.
    """
    super().__init__()
    self._wrapped = wrapped
    parent_leaves = nest.flatten(parent_state)
    self._flat_parent_state = tf.concat(parent_leaves, axis=1)
def __init__(self, wrapped: tf.contrib.rnn.RNNCell, constant_input):
    """Wraps `wrapped` and keeps `constant_input` as a single tensor.

    Every leaf of `constant_input` is concatenated along axis 1 and stored as
    `self._flat_constant_input`.
    """
    super().__init__()
    self._wrapped = wrapped
    input_leaves = nest.flatten(constant_input)
    self._flat_constant_input = tf.concat(input_leaves, axis=1)
def _infer_state_dtype(explicit_dtype, state): """Infer the dtype of an RNN state. Args: explicit_dtype: explicitly declared dtype or None. state: RNN's hidden state. Must be a Tensor or a nested iterable containing Tensors. Returns: dtype: inferred dtype of hidden state. Raises: ValueError: if `state` has heterogeneous dtypes or is empty. """ if explicit_dtype is not None: return explicit_dtype elif nest.is_sequence(state): inferred_dtypes = [element.dtype for element in nest.flatten(state)] if not inferred_dtypes: raise ValueError("Unable to infer dtype from empty state.") all_same = all([x == inferred_dtypes[0] for x in inferred_dtypes]) if not all_same: raise ValueError( "State has tensors of different inferred_dtypes. Unable to infer a " "single representative dtype.") return inferred_dtypes[0] else: return state.dtype
def make_fig(air, sess, checkpoint_dir=None, global_step=None, n_samples=10):
    """Plots inputs, canvases and glimpses of an AIR model for a few samples.

    The figure has `2 * air.max_steps + 1` rows and one column per sample,
    with at most `n_samples` columns:
      * row 0: the input images `air.obs`;
      * the next `air.max_steps` rows: `air.canvas` slices, with a red
        `rect_stn` box drawn from `air.where` when presence > .5;
      * the last `air.max_steps` rows: `air.glimpse` slices, each titled with
        the rounded presence and the value of
        `air.num_steps_distrib.prob()[..., 1:]` for that step.

    If `checkpoint_dir` is given, the figure is saved there as
    `progress_fig_{global_step}.png` (dpi=300) and all figures are closed.
    NOTE(review): when `checkpoint_dir` is None the figure is left open and
    is not returned.
    """
    n_steps = air.max_steps
    xx, pred_canvas, pred_crop, prob, pres, w = sess.run(
        [air.obs, air.canvas, air.glimpse,
         air.num_steps_distrib.prob()[..., 1:], air.presence, air.where])
    height, width = xx.shape[1:]

    bs = min(n_samples, air.batch_size)
    scale = 1.5
    figsize = scale * np.asarray((bs, 2 * n_steps + 1))
    # squeeze=False keeps `axes` 2-D even when bs == 1; otherwise
    # `axes[0]` would be a single Axes and the row loops below would fail.
    fig, axes = plt.subplots(2 * n_steps + 1, bs, figsize=figsize,
                             squeeze=False)

    for i, ax in enumerate(axes[0]):
        ax.imshow(xx[i], cmap='gray', vmin=0, vmax=1)

    for i, ax_row in enumerate(axes[1:1 + n_steps]):
        for j, ax in enumerate(ax_row):
            ax.imshow(pred_canvas[i, j], cmap='gray', vmin=0, vmax=1)
            if pres[i, j, 0] > .5:
                rect_stn(ax, width, height, w[i, j], 'r')

    for i, ax_row in enumerate(axes[1 + n_steps:]):
        for j, ax in enumerate(ax_row):
            ax.imshow(pred_crop[i, j], cmap='gray')
            ax.set_title('{:d} with p({:d}) = {:.02f}'.format(
                int(pres[i, j, 0]), i + 1, prob[j, i].squeeze()),
                fontsize=4 * scale)

    for ax in axes.flatten():
        ax.xaxis.set_visible(False)
        ax.yaxis.set_visible(False)

    if checkpoint_dir is not None:
        fig_name = osp.join(checkpoint_dir,
                            'progress_fig_{}.png'.format(global_step))
        fig.savefig(fig_name, dpi=300)
        plt.close('all')
def log_norm(expr_list, name):
    """Computes the mean of squares over all elements and logs it as a scalar.

    :param expr_list: tensor or nested structure of tensors.
    :param name: name of the `tf.summary.scalar` summary.
    :return: scalar tensor equal to the sum of squared elements divided by
        the total element count. No square root is taken.
    """
    leaves = nest.flatten(expr_list)
    total_count = sum((tf.reduce_prod(tf.shape(t)) for t in leaves), 0)
    sum_of_squares = sum((tf.reduce_sum(t ** 2) for t in leaves), 0.)
    mean_square = sum_of_squares / tf.to_float(total_count)
    tf.summary.scalar(name, mean_square)
    return mean_square
def _embed(self, inpt):
    """Flattens `inpt` with `snt.BatchFlatten` and feeds it through an MLP.

    The MLP uses `self._n_hidden` hidden units and has `self._n_param`
    outputs.
    """
    return snt.Sequential([
        snt.BatchFlatten(),
        MLP(self._n_hidden, n_out=self._n_param),
    ])(inpt)
def _build(self, img, what, where, presence_prob, state=None):
    """Predicts a scalar baseline from the image, latents and optional state.

    `what`, `where` and `presence_prob` are transposed with axes (1, 0, 2)
    and reshaped to `(batch_size, -1)`. The flattened image, those tensors
    and the leaves of `state` (if given) are concatenated on the last axis
    and passed through an MLP with a single output.
    """
    batch_size = int(img.get_shape()[0])

    def _transpose_and_flatten(tensor):
        return tf.reshape(tf.transpose(tensor, (1, 0, 2)), (batch_size, -1))

    parts = [_transpose_and_flatten(t) for t in (what, where, presence_prob)]
    if state is not None:
        parts.extend(nest.flatten(state))

    img_flat = tf.reshape(img, (batch_size, -1))
    baseline_inputs = tf.concat([img_flat] + parts, -1)
    return MLP(self._n_hidden, n_out=1)(baseline_inputs)
def initial_cell_state_from_embedding(cell, z, batch_size, name=None):
    """Computes an initial RNN `cell` state from an embedding, `z`.

    `z` is projected by a tanh dense layer to the total flat state size of
    `cell`, split per state component on axis 1, and packed into the
    structure of `cell.zero_state(batch_size, tf.float32)`.
    """
    structure = cell.zero_state(batch_size=batch_size, dtype=tf.float32)
    flat_state_sizes = tf_nest.flatten(cell.state_size)
    projected = tf.layers.dense(
        z,
        sum(flat_state_sizes),
        activation=tf.tanh,
        kernel_initializer=tf.random_normal_initializer(stddev=0.001),
        name=name)
    pieces = tf.split(projected, flat_state_sizes, axis=1)
    return tf_nest.pack_sequence_as(structure, pieces)
def _assert_sructures_equal(self, struct1, struct2):
    """Asserts two nested structures match in layout and in every leaf value."""
    tf_nest.assert_same_structure(struct1, struct2)
    leaf_pairs = zip(tf_nest.flatten(struct1), tf_nest.flatten(struct2))
    for left, right in leaf_pairs:
        np.testing.assert_array_equal(left, right)
def __init__(self, initial_state, mask=None, name="trainable_initial_state"):
    """Constructs the Module that introduces a trainable state in the graph.

    It receives an initial state that will be used as the initial values for
    the trainable variables that the module contains, and optionally a mask
    that indicates the parts of the initial state that should be learnable.

    Args:
      initial_state: tensor or arbitrarily nested iterables of tensors.
      mask: optional boolean mask. It should have the same nested structure
        as the given initial_state.
      name: module name.

    Raises:
      TypeError: if mask is not a list of booleans or None.
    """
    super(TrainableInitialState, self).__init__(name=name)

    # DeprecationWarning is ignored by default since Python 2.7, so turn it
    # on explicitly before emitting the warning.
    warnings.simplefilter("always", DeprecationWarning)
    warnings.warn("Use the trainable flag in initial_state instead.",
                  DeprecationWarning, stacklevel=2)

    if mask is not None:
        mask_is_boolean = all(isinstance(m, bool) for m in nest.flatten(mask))
        if not mask_is_boolean:
            raise TypeError("Mask should be None or a list of boolean values.")
        nest.assert_same_structure(initial_state, mask)

    self._mask = mask
    self._initial_state = initial_state
def _testACT(self, input_size, hidden_size, output_size, seq_len, batch_size,
             core, get_state_for_halting):
    """Runs an ACTCore over random inputs and checks dtypes, shapes and ranges.

    `hidden_size` is accepted but not used here.
    """
    threshold = 0.99
    act = pondering_rnn.ACTCore(core, output_size, threshold,
                                get_state_for_halting)
    seq_input = tf.random_uniform(shape=(seq_len, batch_size, input_size))
    initial_state = core.initial_state(batch_size)
    seq_output = tf.nn.dynamic_rnn(act, seq_input, time_major=True,
                                   initial_state=initial_state)

    # Every output tensor must keep the input dtype.
    for tensor in nest.flatten(seq_output):
        self.assertEqual(seq_input.dtype, tensor.dtype)

    with self.test_session() as sess:
        sess.run(tf.global_variables_initializer())
        (final_out, (iteration, r_t)), final_cumul_state = sess.run(seq_output)

        per_step_shape = (seq_len, batch_size, 1)
        self.assertEqual((seq_len, batch_size, output_size), final_out.shape)
        self.assertEqual(per_step_shape, iteration.shape)
        # Iteration counts must be whole numbers.
        self.assertTrue(np.all(iteration == np.floor(iteration)))

        halting_shape = tuple(
            get_state_for_halting(initial_state).get_shape().as_list())
        self.assertEqual(halting_shape,
                         get_state_for_halting(final_cumul_state).shape)

        self.assertEqual(per_step_shape, r_t.shape)
        # r_t values must lie in [0, threshold].
        self.assertTrue(np.all(r_t >= 0))
        self.assertTrue(np.all(r_t <= threshold))
def testInitialStateTuple(self, trainable, use_custom_initial_value,
                          state_size):
    """Checks structure, shapes and values of `RNNCore.initial_state`."""
    batch_size = 6

    # Set the attribute on the class itself, since we can't set properties
    # of abstract classes.
    snt.RNNCore.state_size = state_size
    flat_state_size = nest.flatten(state_size)
    core = snt.RNNCore(name="dummy_core")

    trainable_initializers = None
    if use_custom_initial_value:
        trainable_initializers = nest.pack_sequence_as(
            structure=state_size,
            flat_sequence=[tf.constant_initializer(2)] * len(flat_state_size))

    initial_state = core.initial_state(
        batch_size, dtype=tf.float32, trainable=trainable,
        trainable_initializers=trainable_initializers)

    nest.assert_same_structure(initial_state, state_size)
    flat_initial_state = nest.flatten(initial_state)
    for tensor, size in zip(flat_initial_state, flat_state_size):
        self.assertEqual(tensor.get_shape(), [batch_size, size])

    with self.test_session() as sess:
        tf.global_variables_initializer().run()
        flat_initial_state_value = sess.run(flat_initial_state)
        for value, size in zip(flat_initial_state_value, flat_state_size):
            if not trainable:
                expected = np.zeros([batch_size, size])
            elif use_custom_initial_value:
                expected = np.full([batch_size, size], 2.)
            else:
                # Trainable default state: every row must equal the first.
                expected = np.tile(value[0], (batch_size, 1))
            self.assertAllClose(value, expected)
def __init__(self, encoder_outputs, decoder_state_size):
    """Stores the encoder outputs and the decoder state size.

    `self.batch_size` is the dynamic size of axis 0 of the first leaf of
    `encoder_outputs.final_state`.
    """
    self.encoder_outputs = encoder_outputs
    self.decoder_state_size = decoder_state_size
    first_state_leaf = nest.flatten(self.encoder_outputs.final_state)[0]
    self.batch_size = tf.shape(first_state_leaf)[0]
def batch_size(self):
    """Dynamic batch size: axis 0 of the first leaf of `self.initial_state`."""
    first_leaf = nest.flatten([self.initial_state])[0]
    return tf.shape(first_leaf)[0]
def _assert_correct_outputs(self, initial_state_):
    """Asserts that every leaf of the evaluated initial state is all zeros."""
    for leaf in nest.flatten(initial_state_):
        self.assertAllEqual(leaf, np.zeros_like(leaf))
def _assert_correct_outputs(self, initial_state_):
    """Asserts that the initial state equals the encoder's final state.

    The structure must match both `self.decoder_cell.state_size` and
    `self.encoder_outputs.final_state`, and each leaf must equal the
    evaluated corresponding encoder-state leaf.
    """
    nest.assert_same_structure(initial_state_, self.decoder_cell.state_size)
    nest.assert_same_structure(
        initial_state_, self.encoder_outputs.final_state)
    encoder_leaves = nest.flatten(self.encoder_outputs.final_state)

    with self.test_session() as sess:
        encoder_values = sess.run(encoder_leaves)

    for decoded, encoded in zip(nest.flatten(initial_state_), encoder_values):
        self.assertAllEqual(decoded, encoded)
def __init__(self, encoder_outputs, decoder_state_size, params, mode, reuse):
    """Configures the bridge and stores the encoder outputs and state size.

    `self.batch_size` is the dynamic size of axis 0 of the first leaf of
    `encoder_outputs.final_state`.
    """
    Configurable.__init__(self, params, mode, reuse)
    self.encoder_outputs = encoder_outputs
    self.decoder_state_size = decoder_state_size
    first_state_leaf = nest.flatten(self.encoder_outputs.final_state)[0]
    self.batch_size = tf.shape(first_state_leaf)[0]
def nest_map(func, nested):
    """Applies `func` to every leaf of `nested`, preserving its structure.

    A non-sequence `nested` is passed to `func` directly.
    """
    if nest.is_sequence(nested):
        mapped = [func(leaf) for leaf in nest.flatten(nested)]
        return nest.pack_sequence_as(nested, mapped)
    return func(nested)
def __call__(self, inputs, state, scope=None):
    """Runs the wrapped cell independently on each of `self.beam_size` beams.

    Every leaf of `inputs` and `state` is unstacked along axis 1 into
    `self.beam_size` slices. `self.cell` is called once per slice, sharing
    variables after the first call. The per-beam outputs and next states are
    stacked back along axis 1.

    Args:
      inputs: tensor or nested structure of tensors. NOTE(review): each leaf
        presumably has shape (batch, beam_size, ...); confirm.
      state: tensor or nested structure of tensors with the same axis-1 beam
        layout.
      scope: optional variable scope forwarded to the wrapped cell.

    Returns:
      `(output, next_state)`, structured like the wrapped cell's results.
    """
    varscope = scope or tf.get_variable_scope()
    flat_inputs = nest.flatten(inputs)
    flat_state = nest.flatten(state)
    # Regroup from "per leaf, per beam" into "per beam, tuple of leaves".
    flat_inputs_unstacked = list(zip(*[tf.unstack(tensor, num=self.beam_size, axis=1)
                                       for tensor in flat_inputs]))
    flat_state_unstacked = list(zip(*[tf.unstack(tensor, num=self.beam_size, axis=1)
                                      for tensor in flat_state]))
    flat_output_unstacked = []
    flat_next_state_unstacked = []
    output_sample = None
    next_state_sample = None
    for i, (inputs_k, state_k) in enumerate(zip(flat_inputs_unstacked, flat_state_unstacked)):
        inputs_k = nest.pack_sequence_as(inputs, inputs_k)
        state_k = nest.pack_sequence_as(state, state_k)
        # The first beam creates the cell's variables; later beams re-enter
        # `varscope` with reuse=True so the variables are shared.
        if i == 0:
            output_k, next_state_k = self.cell(inputs_k, state_k, scope=scope)
        else:
            with tf.variable_scope(varscope, reuse=True):
                output_k, next_state_k = self.cell(inputs_k, state_k,
                                                   scope=varscope if scope is not None else None)
        flat_output_unstacked.append(nest.flatten(output_k))
        flat_next_state_unstacked.append(nest.flatten(next_state_k))
        # The last beam's results serve as structure templates for repacking.
        output_sample = output_k
        next_state_sample = next_state_k
    flat_output = [tf.stack(tensors, axis=1) for tensors in zip(*flat_output_unstacked)]
    flat_next_state = [tf.stack(tensors, axis=1) for tensors in zip(*flat_next_state_unstacked)]
    output = nest.pack_sequence_as(output_sample, flat_output)
    next_state = nest.pack_sequence_as(next_state_sample, flat_next_state)
    return output, next_state
def __call__(self, inputs, state, scope=None):
    """Long short-term memory cell with attention (LSTMA).

    `array_ops.concat` is called as `concat(1, values)`, i.e. with the axis
    first. This is the pre-1.0 TensorFlow argument order.

    Args:
      inputs: input tensor for this step. When `self._input_size` is None its
        static dimension 1 is used as the projection size.
      state: if `self._state_is_tuple`, a 3-tuple
        `(cell_state, attns, attn_states)`. Otherwise a single tensor whose
        columns are sliced into `self._cell.state_size` cell-state columns,
        then `self._attn_size` attention columns, then
        `self._attn_size * self._attn_length` attention-state columns.
      scope: variable scope; defaults to the class name.

    Returns:
      `(output, new_state)`, with `new_state` in the same layout (tuple or
      single concatenated tensor) as `state`.
    """
    with vs.variable_scope(scope or type(self).__name__):
        if self._state_is_tuple:
            state, attns, attn_states = state
        else:
            states = state
            state = array_ops.slice(states, [0, 0], [-1, self._cell.state_size])
            attns = array_ops.slice(
                states, [0, self._cell.state_size], [-1, self._attn_size])
            attn_states = array_ops.slice(
                states, [0, self._cell.state_size + self._attn_size],
                [-1, self._attn_size * self._attn_length])
        attn_states = array_ops.reshape(attn_states,
                                        [-1, self._attn_length, self._attn_size])
        input_size = self._input_size
        if input_size is None:
            input_size = inputs.get_shape().as_list()[1]
        # Mix the previous attention read into the cell input.
        inputs = _linear([inputs, attns], input_size, True)
        lstm_output, new_state = self._cell(inputs, state)
        # The attention query is the whole cell state as one tensor.
        if self._state_is_tuple:
            new_state_cat = array_ops.concat(1, nest.flatten(new_state))
        else:
            new_state_cat = new_state
        new_attns, new_attn_states = self._attention(new_state_cat, attn_states)
        with vs.variable_scope("AttnOutputProjection"):
            output = _linear([lstm_output, new_attns], self._attn_size, True)
        # Append this step's output to the attention states along axis 1.
        new_attn_states = array_ops.concat(1, [new_attn_states,
                                               array_ops.expand_dims(output, 1)])
        new_attn_states = array_ops.reshape(
            new_attn_states, [-1, self._attn_length * self._attn_size])
        new_state = (new_state, new_attns, new_attn_states)
        if not self._state_is_tuple:
            new_state = array_ops.concat(1, list(new_state))
        return output, new_state
def wrap_state(self, state):
    """Converts a plain cell state into this beam decoder's state format.

    A throwaway `BeamDecoderCellWrapper` with no inner cell, built from this
    wrapper's settings, creates the state via `_create_state`. It uses the
    dynamic axis-0 size and the dtype of `state`, or of its first leaf when
    `state` is nested.
    """
    dummy = BeamDecoderCellWrapper(None, self.batch_concat, self.num_classes,
                                   self.max_len, self.stop_token,
                                   self.beam_size, self.min_op, self.min_frac)
    reference = nest.flatten(state)[0] if nest.is_sequence(state) else state
    batch_size = tf.shape(reference)[0]
    dtype = reference.dtype
    return dummy._create_state(batch_size, dtype, cell_state=state)
def zero_state(self, batch_size, dtype):
    """Return zero-filled state tensor(s).

    Args:
      batch_size: int, float, or unit Tensor representing the batch size.
      dtype: the data type to use for the state.

    Returns:
      If `state_size` is an int or TensorShape, then the return value is a
      `N-D` tensor of shape `[batch_size x state_size]` filled with zeros.

      If `state_size` is a nested list or tuple, then the return value is
      a nested list or tuple (of the same structure) of `2-D` tensors with
      the shapes `[batch_size x s]` for each s in `state_size`.
    """
    def _zeros_for(size):
        # The dynamic shape includes `batch_size`; the static shape leaves
        # the batch dimension unknown (None).
        zeros = array_ops.zeros(
            array_ops.pack(_state_size_with_prefix(size, prefix=[batch_size])),
            dtype=dtype)
        zeros.set_shape(_state_size_with_prefix(size, prefix=[None]))
        return zeros

    state_size = self.state_size
    if not nest.is_sequence(state_size):
        return _zeros_for(state_size)

    flat_zeros = [_zeros_for(size) for size in nest.flatten(state_size)]
    return nest.pack_sequence_as(structure=state_size,
                                 flat_sequence=flat_zeros)
def __init__(self, initial_state, mask=None, name="trainable_initial_state"):
    """Constructs the Module that introduces a trainable state in the graph.

    It receives an initial state that will be used as the initial values for
    the trainable variables that the module contains, and optionally a mask
    that indicates the parts of the initial state that should be learnable.

    Args:
      initial_state: tensor or arbitrarily nested iterables of tensors.
      mask: optional boolean mask. It should have the same nested structure
        as the given initial_state.
      name: module name.

    Raises:
      TypeError: if mask is not a list of booleans or None.
    """
    super(TrainableInitialState, self).__init__(name=name)

    # DeprecationWarning is ignored by default since Python 2.7; enable it
    # so the warning below is actually shown.
    warnings.simplefilter("always", DeprecationWarning)
    warnings.warn("Use the trainable flag in initial_state instead.",
                  DeprecationWarning, stacklevel=2)

    if mask is not None:
        flat_mask = nest.flatten(mask)
        if any(not isinstance(entry, bool) for entry in flat_mask):
            raise TypeError("Mask should be None or a list of boolean values.")
        nest.assert_same_structure(initial_state, mask)

    self._mask = mask
    self._initial_state = initial_state
def _reverse_seq(input_seq, lengths):
    """Reverse a list of Tensors up to specified lengths.

    Args:
      input_seq: Sequence of seq_len tensors of dimension
        (batch_size, n_features) or nested tuples of tensors.
      lengths: A `Tensor` of dimension batch_size, containing lengths for
        each sequence in the batch. If "None" is specified, simply reverses
        the list.

    Returns:
      time-reversed sequence

    Raises:
      ValueError: if the static shapes of the tensors at the same nesting
        position are incompatible across time steps.
    """
    if lengths is None:
        return list(reversed(input_seq))

    # TODO(schuster, ebrevdo): Remove cast when reverse_sequence takes int32
    # The cast is loop-invariant, so it is done once here.
    lengths = math_ops.to_int64(lengths)

    flat_input_seq = tuple(nest.flatten(input_) for input_ in input_seq)
    flat_results = [[] for _ in range(len(input_seq))]
    # Each `sequence` holds the tensors at one nesting position across all
    # time steps.
    for sequence in zip(*flat_input_seq):
        input_shape = tensor_shape.unknown_shape(
            ndims=sequence[0].get_shape().ndims)
        # `merge_with` returns a new TensorShape and does not update
        # `input_shape` in place, so the result must be reassigned.
        for input_ in sequence:
            input_shape = input_shape.merge_with(input_.get_shape())
        for input_ in sequence:
            input_.set_shape(input_shape)

        # Join into (time, batch_size, depth)
        s_joined = array_ops.pack(sequence)

        # Reverse along dimension 0
        s_reversed = array_ops.reverse_sequence(s_joined, lengths, 0, 1)
        # Split again into list
        result = array_ops.unpack(s_reversed)
        for r, flat_result in zip(result, flat_results):
            r.set_shape(input_shape)
            flat_result.append(r)

    return [nest.pack_sequence_as(structure=input_, flat_sequence=flat_result)
            for input_, flat_result in zip(input_seq, flat_results)]