Python tensorflow.python.util.nest 模块,flatten() 实例源码

我们从Python开源项目中,提取了以下50个代码示例,用于说明如何使用tensorflow.python.util.nest.flatten()

项目:attend_infer_repeat    作者:akosiorek    | 项目源码 | 文件源码
def __init__(self, n_hiddens, hidden_transfer=default_activation, n_out=None, transfer=None, initializers=default_init):
        """Initialises the MLP.

        :param n_hiddens: int or an iterable of ints, number of hidden units in layers
        :param hidden_transfer: callable or iterable; a transfer function for hidden layers or an iterable thereof.
            A single callable is reused for every hidden layer; otherwise the number of transfer
            functions must equal the number of hidden layers.
        :param n_out: int or None, number of output units
        :param transfer: callable or None, a transfer function for the output
        :param initializers: stored unchanged as `self._initializers`
        :raises ValueError: if the number of hidden transfer functions is neither 1 nor
            the number of hidden layers
        """

        super(MLP, self).__init__(self.__class__.__name__)
        self._n_hiddens = nest.flatten(n_hiddens)
        transfers = nest.flatten(hidden_transfer)
        if len(transfers) == 1:
            # Broadcast a single transfer function to every hidden layer.
            transfers *= len(self._n_hiddens)
        elif len(transfers) != len(self._n_hiddens):
            # Explicit check rather than `assert`, which is stripped under `python -O`.
            raise ValueError(
                'Expected 1 or {} hidden transfer functions, got {}.'.format(
                    len(self._n_hiddens), len(transfers)))
        # `transfers` is already a flat list, so no second flatten is needed.
        self._hidden_transfers = transfers
        self._n_out = n_out
        self._transfer = transfer
        self._initializers = initializers
项目:sonnet    作者:deepmind    | 项目源码 | 文件源码
def _build(self):
    """Builds the (optionally masked) learnable initial state.

    Returns:
      A structure with the same type, nesting and shape as the `initial_state`
        given to the constructor, whose leaves are produced by
        `_single_learnable_state`.
    """
    leaves = nest.flatten(self._initial_state)
    if self._mask is None:
      new_leaves = [_single_learnable_state(leaf, state_id=idx)
                    for idx, leaf in enumerate(leaves)]
    else:
      masks = nest.flatten(self._mask)
      new_leaves = []
      for idx, (leaf, learnable) in enumerate(zip(leaves, masks)):
        new_leaves.append(
            _single_learnable_state(leaf, state_id=idx, learnable=learnable))

    return nest.pack_sequence_as(structure=self._initial_state,
                                 flat_sequence=new_leaves)
项目:sonnet    作者:deepmind    | 项目源码 | 文件源码
def testRegularizers(self, trainable, state_size):
    """Checks that a trainable initial state registers one L1 regularizer per part.

    When `trainable` is False no regularization losses may be collected.
    """
    batch_size = 6

    # state_size is patched onto the class itself, since properties of the
    # abstract RNNCore can't be set on an instance.
    snt.RNNCore.state_size = state_size
    flat_state_size = nest.flatten(state_size)
    num_parts = len(flat_state_size)
    core = snt.RNNCore(name="dummy_core")
    l1 = tf.contrib.layers.l1_regularizer(scale=0.5)
    trainable_regularizers = nest.pack_sequence_as(
        structure=state_size, flat_sequence=[l1] * num_parts)

    core.initial_state(batch_size, dtype=tf.float32, trainable=trainable,
                       trainable_regularizers=trainable_regularizers)

    graph_regularizers = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)
    if trainable:
      for idx in range(num_parts):
        self.assertRegexpMatches(
            graph_regularizers[idx].name, ".*l1_regularizer.*")
    else:
      self.assertFalse(graph_regularizers)
项目:sonnet    作者:deepmind    | 项目源码 | 文件源码
def _get_flat_core_sizes(cores):
  """Obtains the list flattened output sizes of a list of cores.

  Args:
    cores: list of cores to get the shapes from.

  Returns:
    List of lists that, for each core, contains the list of its output
      dimensions.
  """
  core_sizes_lists = []
  for core in cores:
    flat_output_size = nest.flatten(core.output_size)
    core_sizes_lists.append([tensor_shape.as_shape(size).as_list()
                             for size in flat_output_size])
  return core_sizes_lists
项目:tensorflow_end2end_speech_recognition    作者:hirofumi0810    | 项目源码 | 文件源码
def _create(self):
        """Projects the bridge input into an initial decoder state.

        Each element of `self._bridge_input` is reshaped to
        `[batch_size, depth]` and concatenated on axis 1. The result is passed
        through one fully connected layer sized to the total decoder state and
        split back into the structure of `self.decoder_state_size`.
        """
        def _as_2d(tensor):
            return tf.reshape(tensor, [self.batch_size, _total_tensor_depth(tensor)])

        reshaped = nest.map_structure(_as_2d, self._bridge_input)
        concatenated = tf.concat(nest.flatten([reshaped]), axis=1)

        split_sizes = nest.flatten(self.decoder_state_size)

        # One dense projection produces every state component at once.
        projected = tf.contrib.layers.fully_connected(
            concatenated,
            num_outputs=sum(split_sizes),
            activation_fn=self._activation_fn,
            weights_initializer=tf.truncated_normal_initializer(
                stddev=self.parameter_init),
            biases_initializer=tf.zeros_initializer(),
            scope=None)

        pieces = tf.split(projected, split_sizes, axis=1)
        return nest.pack_sequence_as(self.decoder_state_size, pieces)
项目:hart    作者:akosiorek    | 项目源码 | 文件源码
def __init__(self, inpt, n_hidden, n_output, transfer_hidden=tf.nn.elu, transfer=None,
                 hidden_weight_init=None, hidden_bias_init=None,weight_init=None, bias_init=None,
                 name=None):
        """
        :param inpt: input tensor
        :param n_hidden: scalar or list, number of hidden units
        :param n_output: scalar, number of output units
        :param transfer_hidden: scalar or list, transfers for hidden units. If list, len must be == len(n_hidden).
        :param transfer: tf.Op or None
        :param hidden_weight_init: stored as `self.hidden_weight_init`
        :param hidden_bias_init: stored as `self.hidden_bias_init`
        :param weight_init: forwarded to the base-class constructor
        :param bias_init: forwarded to the base-class constructor
        :param name: forwarded to the base-class constructor
        :raises ValueError: if `transfer_hidden` has more than one entry and its
            length differs from that of `n_hidden`
        """

        self.n_hidden = nest.flatten(n_hidden)
        self.n_output = n_output
        self.hidden_weight_init = hidden_weight_init
        self.hidden_bias_init = hidden_bias_init

        transfer_hidden = nest.flatten(transfer_hidden)
        if len(transfer_hidden) == 1:
            # A single transfer function is shared by every hidden layer.
            transfer_hidden *= len(self.n_hidden)
        elif len(transfer_hidden) != len(self.n_hidden):
            # Enforce the documented contract instead of silently mis-pairing layers.
            raise ValueError(
                'transfer_hidden has {} entries but n_hidden has {}.'.format(
                    len(transfer_hidden), len(self.n_hidden)))
        self.transfer_hidden = transfer_hidden

        self.transfer = transfer
        # Called last: the attributes above are set before the base constructor runs.
        super(MLP, self).__init__(inpt, name, weight_init, bias_init)
项目:hart    作者:akosiorek    | 项目源码 | 文件源码
def _zero_state(self, img, att, presence, state, transform_features, transform_state=False):
        """Runs the RNN once over features extracted from `img` at `att`.

        Features and hidden-state components can optionally be passed through
        `AffineLayer` transforms. Returns `(state, rnn_outputs)`, where `state`
        has the structure of the `state` argument. The variable scope is
        recorded as `self.rnn_vs`. `presence` is not used here.
        """
        with tf.variable_scope(self.__class__.__name__) as scope:
            features = self.extract_features(img, att)[1]

            if transform_features:
                flat_features = tf.reshape(features, (-1, self.n_units))
                flat_features = AffineLayer(flat_features, self.n_units, name='init_feature_transform').output
                features = tf.reshape(flat_features, tf.shape(features))

            rnn_outputs, hidden = self._propagate(features, state)
            flat_hidden = nest.flatten(hidden)

            if transform_state:
                flat_hidden = [
                    AffineLayer(part, self.n_units, name='init_state_transform_{}'.format(idx)).output
                    for idx, part in enumerate(flat_hidden)]

            state = nest.pack_sequence_as(structure=state, flat_sequence=flat_hidden)
        self.rnn_vs = scope
        return state, rnn_outputs
项目:tefla    作者:openAGI    | 项目源码 | 文件源码
def _create(self):
        """Projects the concatenated bridge input into an initial decoder state.

        Each bridge input element is reshaped to `[batch_size, depth]` and
        concatenated on axis 1. It is then mapped through `fully_connected` to
        the total decoder state size and split into the structure of
        `self.decoder_state_size`.
        """
        def _flatten_depth(tensor):
            return tf.reshape(tensor, [self.batch_size, _total_tensor_depth(tensor)])

        pieces = nest.flatten([nest.map_structure(_flatten_depth, self._bridge_input)])
        merged = tf.concat(pieces, 1)

        sizes = nest.flatten(self.decoder_state_size)
        projected = fully_connected(
            merged,
            sum(sizes),
            self._mode,
            self._reuse,
            activation=self._activation_fn)
        return nest.pack_sequence_as(
            self.decoder_state_size, tf.split(projected, sizes, axis=1))
项目:odin    作者:imito    | 项目源码 | 文件源码
def call(self, inputs, state):
    """Long short-term memory cell with attention (LSTMA).

    Args:
      inputs: input tensor for this step. When `self._input_size` is None its
        size is read from `get_shape().as_list()[1]`, so it is presumably
        rank 2 (batch, depth) -- TODO confirm.
      state: tuple `(cell_state, attns, attn_states)`; `attn_states` is
        reshaped below to `[-1, self._attn_length, self._attn_size]`.

    Returns:
      `(output, new_state)`: `output` is a `_linear` projection of the wrapped
      cell's output and the new attention read-out to `self._attn_size`
      units. `new_state` is `(new_cell_state, new_attns, new_attn_states)`,
      with `new_attn_states` flattened to
      `[-1, self._attn_length * self._attn_size]`.
    """
    state, attns, attn_states = state
    # Restore the 3-D attention memory from its flat stored form.
    attn_states = array_ops.reshape(attn_states,
                                    [-1, self._attn_length, self._attn_size])
    input_size = self._input_size
    if input_size is None:
      input_size = inputs.get_shape().as_list()[1]
    # Mix the previous attention read-out into the cell input.
    inputs = _linear([inputs, attns], input_size, True)
    lstm_output, new_state = self._cell(inputs, state)
    # The attention query is the wrapped cell's new state, flattened and
    # concatenated along axis 1.
    new_state_cat = array_ops.concat(nest.flatten(new_state), 1)
    new_attns, new_attn_states = self._attention(new_state_cat, attn_states)
    with tf.variable_scope("attn_output_projection"):
      output = _linear([lstm_output, new_attns], self._attn_size, True)
    # Append this step's output to the attention memory. NOTE(review): for the
    # reshape below to hold exactly `attn_length` entries, `_attention`
    # presumably returns the memory with one entry dropped -- not visible
    # here; confirm.
    new_attn_states = array_ops.concat(
        [new_attn_states, array_ops.expand_dims(output, 1)], 1)
    new_attn_states = array_ops.reshape(
        new_attn_states, [-1, self._attn_length * self._attn_size])
    new_state = (new_state, new_attns, new_attn_states)
    return output, new_state
项目:pointer-network-tensorflow    作者:devsisters    | 项目源码 | 文件源码
def trainable_initial_state(batch_size, state_size,
                            initializer=None, name="initial_state"):
  """Creates a trainable initial RNN state tiled across the batch.

  One `[1, size]` variable is created per leaf of `state_size` and tiled to
  `[batch_size, size]`. The tiled tensors are packed back into the structure
  of `state_size`.

  Args:
    batch_size: number of rows each state variable is tiled to.
    state_size: int or nested structure of ints, the size of each state part.
    initializer: optional factory called with no arguments to obtain a variable
      initializer (used the same way as `tf.zeros_initializer`). It is used for
      every state part. If falsy, `tf.zeros_initializer` is used.
    name: prefix of the created variable names (`"{name}_{i}"`).

  Returns:
    A structure matching `state_size` holding the tiled state tensors.
  """
  flat_state_size = nest.flatten(state_size)

  # Honour the caller's initializer; fall back to zeros only when none is given.
  init_factory = initializer if initializer else tf.zeros_initializer
  flat_initializer = tuple(init_factory for _ in flat_state_size)

  # `range` works on both Python 2 and 3 (`xrange` is Python-2-only).
  names = ["{}_{}".format(name, i) for i in range(len(flat_state_size))]
  tiled_states = []

  for var_name, size, init in zip(names, flat_state_size, flat_initializer):
    shape_with_batch_dim = [1, size]
    initial_state_variable = tf.get_variable(
        var_name, shape=shape_with_batch_dim, initializer=init())

    tiled_state = tf.tile(initial_state_variable,
                          [batch_size, 1], name=(var_name + "_tiled"))
    tiled_states.append(tiled_state)

  return nest.pack_sequence_as(structure=state_size,
                               flat_sequence=tiled_states)
项目:tf-tutorial    作者:zchen0211    | 项目源码 | 文件源码
def _build(self):
    """Builds the (optionally masked) learnable initial state.

    Returns:
      A structure with the same type, nesting and shape as the `initial_state`
        given to the constructor, whose leaves are produced by
        `_single_learnable_state`.
    """
    leaves = nest.flatten(self._initial_state)
    if self._mask is None:
      new_leaves = [_single_learnable_state(leaf, state_id=idx)
                    for idx, leaf in enumerate(leaves)]
    else:
      masks = nest.flatten(self._mask)
      new_leaves = []
      for idx, (leaf, learnable) in enumerate(zip(leaves, masks)):
        new_leaves.append(
            _single_learnable_state(leaf, state_id=idx, learnable=learnable))

    return nest.pack_sequence_as(structure=self._initial_state,
                                 flat_sequence=new_leaves)
项目:tf-tutorial    作者:zchen0211    | 项目源码 | 文件源码
def _get_flat_core_sizes(cores):
  """Obtains the list flattened output sizes of a list of cores.

  Args:
    cores: list of cores to get the shapes from.

  Returns:
    List of lists that, for each core, contains the list of its output
      dimensions.
  """
  core_sizes_lists = []
  for core in cores:
    flat_output_size = nest.flatten(core.output_size)
    core_sizes_lists.append([tensor_shape.as_shape(size).as_list()
                             for size in flat_output_size])
  return core_sizes_lists
项目:tf-sparql    作者:derdav3    | 项目源码 | 文件源码
def _build(self):
    """Builds the (optionally masked) learnable initial state.

    Returns:
      A structure with the same type, nesting and shape as the `initial_state`
        given to the constructor, whose leaves are produced by
        `_single_learnable_state`.
    """
    leaves = nest.flatten(self._initial_state)
    if self._mask is None:
      new_leaves = [_single_learnable_state(leaf, state_id=idx)
                    for idx, leaf in enumerate(leaves)]
    else:
      masks = nest.flatten(self._mask)
      new_leaves = []
      for idx, (leaf, learnable) in enumerate(zip(leaves, masks)):
        new_leaves.append(
            _single_learnable_state(leaf, state_id=idx, learnable=learnable))

    return nest.pack_sequence_as(structure=self._initial_state,
                                 flat_sequence=new_leaves)
项目:tf-sparql    作者:derdav3    | 项目源码 | 文件源码
def _get_flat_core_sizes(cores):
  """Obtains the list flattened output sizes of a list of cores.

  Args:
    cores: list of cores to get the shapes from.

  Returns:
    List of lists that, for each core, contains the list of its output
      dimensions.
  """
  core_sizes_lists = []
  for core in cores:
    flat_output_size = nest.flatten(core.output_size)
    core_sizes_lists.append([tensor_shape.as_shape(size).as_list()
                             for size in flat_output_size])
  return core_sizes_lists
项目:DeepLearning_VirtualReality_BigData_Project    作者:rashmitripathi    | 项目源码 | 文件源码
def state_tuple_to_dict(state):
  """Flattens `state` into a dict keyed by state name.

  Args:
    state: A `Tensor` or a nested tuple of `Tensors`. All of the `Tensor`s must
    have the same rank and agree on all dimensions except the last.

  Returns:
    A dict that maps `_get_state_name(i)` to a named identity copy of the i-th
    leaf of `state`, in `nest.flatten` (depth-first) order. Leaves that are
    `None` map to `None`.
  """
  with ops.name_scope('state_tuple_to_dict'):
    state_dict = {}
    for index, component in enumerate(nest.flatten(state)):
      key = _get_state_name(index)
      if component is None:
        state_dict[key] = None
      else:
        state_dict[key] = array_ops.identity(component, name=key)
  return state_dict
项目:DeepLearning_VirtualReality_BigData_Project    作者:rashmitripathi    | 项目源码 | 文件源码
def state_tuple_to_dict(state):
  """Flattens `state` into a dict keyed by state name.

  Args:
    state: A `Tensor` or a nested tuple of `Tensors`. All of the `Tensor`s must
    have the same rank and agree on all dimensions except the last.

  Returns:
    A dict that maps `_get_state_name(i)` to a named identity copy of the i-th
    leaf of `state`, in `nest.flatten` (depth-first) order. Leaves that are
    `None` map to `None`.
  """
  def _copy(component, key):
    if component is None:
      return None
    return array_ops.identity(component, name=key)

  state_dict = {}
  with ops.name_scope('state_tuple_to_dict'):
    for index, component in enumerate(nest.flatten(state)):
      key = _get_state_name(index)
      state_dict[key] = _copy(component, key)
  return state_dict
项目:neural-combinatorial-rl-tensorflow    作者:devsisters    | 项目源码 | 文件源码
def trainable_initial_state(batch_size, state_size,
                            initializer=None, name="initial_state"):
  """Creates a trainable initial RNN state tiled across the batch.

  One `[1, size]` variable is created per leaf of `state_size` and tiled to
  `[batch_size, size]`. The tiled tensors are packed back into the structure
  of `state_size`.

  Args:
    batch_size: number of rows each state variable is tiled to.
    state_size: int or nested structure of ints, the size of each state part.
    initializer: optional factory called with no arguments to obtain a variable
      initializer (used the same way as `tf.zeros_initializer`). It is used for
      every state part. If falsy, `tf.zeros_initializer` is used.
    name: prefix of the created variable names (`"{name}_{i}"`).

  Returns:
    A structure matching `state_size` holding the tiled state tensors.
  """
  flat_state_size = nest.flatten(state_size)

  # Honour the caller's initializer; fall back to zeros only when none is given.
  init_factory = initializer if initializer else tf.zeros_initializer
  flat_initializer = tuple(init_factory for _ in flat_state_size)

  # `range` works on both Python 2 and 3 (`xrange` is Python-2-only).
  names = ["{}_{}".format(name, i) for i in range(len(flat_state_size))]
  tiled_states = []

  for var_name, size, init in zip(names, flat_state_size, flat_initializer):
    shape_with_batch_dim = [1, size]
    initial_state_variable = tf.get_variable(
        var_name, shape=shape_with_batch_dim, initializer=init())

    tiled_state = tf.tile(initial_state_variable,
                          [batch_size, 1], name=(var_name + "_tiled"))
    tiled_states.append(tiled_state)

  return nest.pack_sequence_as(structure=state_size,
                               flat_sequence=tiled_states)
项目:almond-nnparser    作者:Stanford-Mobisocial-IoT-Lab    | 项目源码 | 文件源码
def encode(self, inputs, input_length, _parses):
        """Runs a stacked bidirectional LSTM encoder.

        Returns `(outputs, state)`:
        - `outputs` holds the forward and backward outputs concatenated on axis 2.
        - `state` has the forward cell's structure, and each leaf is the
          forward and backward leaves concatenated on axis 1, so a
          unidirectional decoder can consume it.
        `_parses` is not used.
        """
        with tf.name_scope('BiLSTMEncoder'):
            def _stacked_cell():
                return tf.contrib.rnn.MultiRNNCell([self._make_rnn_cell(layer) for layer in range(self._num_layers)])

            fw_cell = _stacked_cell()
            bw_cell = _stacked_cell()

            outputs, (fw_state, bw_state) = tf.nn.bidirectional_dynamic_rnn(
                fw_cell, bw_cell, inputs, input_length, dtype=tf.float32)

            # Pair leaves of the two directions and join them on the feature axis.
            merged_leaves = [tf.concat((fw_leaf, bw_leaf), axis=1)
                             for fw_leaf, bw_leaf in zip(nest.flatten(fw_state), nest.flatten(bw_state))]
            merged_state = nest.pack_sequence_as(fw_state, merged_leaves)

            return tf.concat(outputs, axis=2), merged_state
项目:almond-nnparser    作者:Stanford-Mobisocial-IoT-Lab    | 项目源码 | 文件源码
def __init__(self, wrapped : tf.contrib.rnn.RNNCell, parent_state):
        """Wraps `wrapped`, keeping `parent_state` flattened and concatenated on axis 1."""
        super().__init__()
        self._wrapped = wrapped
        parent_leaves = nest.flatten(parent_state)
        self._flat_parent_state = tf.concat(parent_leaves, axis=1)
项目:almond-nnparser    作者:Stanford-Mobisocial-IoT-Lab    | 项目源码 | 文件源码
def __init__(self, wrapped : tf.contrib.rnn.RNNCell, constant_input):
        """Wraps `wrapped`, keeping `constant_input` flattened and concatenated on axis 1."""
        super().__init__()
        self._wrapped = wrapped
        input_leaves = nest.flatten(constant_input)
        self._flat_constant_input = tf.concat(input_leaves, axis=1)
项目:youtube-8m    作者:wangheda    | 项目源码 | 文件源码
def _infer_state_dtype(explicit_dtype, state):
    """Infer the dtype of an RNN state.

    Args:
      explicit_dtype: explicitly declared dtype or None.
      state: RNN's hidden state. Must be a Tensor or a nested iterable containing
        Tensors.

    Returns:
      dtype: inferred dtype of hidden state.

    Raises:
      ValueError: if `state` has heterogeneous dtypes or is empty.
    """
    if explicit_dtype is not None:
        return explicit_dtype
    elif nest.is_sequence(state):
        inferred_dtypes = [element.dtype for element in nest.flatten(state)]
        if not inferred_dtypes:
            raise ValueError("Unable to infer dtype from empty state.")
        all_same = all([x == inferred_dtypes[0] for x in inferred_dtypes])
        if not all_same:
            raise ValueError(
                "State has tensors of different inferred_dtypes. Unable to infer a "
                "single representative dtype.")
        return inferred_dtypes[0]
    else:
        return state.dtype
项目:attend_infer_repeat    作者:akosiorek    | 项目源码 | 文件源码
def make_fig(air, sess, checkpoint_dir=None, global_step=None, n_samples=10):
    """Plots observations, canvases and glimpses of an AIR model.

    The figure has `2 * air.max_steps + 1` rows and one column per sample:
    - Row 0 shows `air.obs`.
    - The next `air.max_steps` rows show `air.canvas` per step, with a red box
      from `air.where` drawn where presence > .5.
    - The last `air.max_steps` rows show `air.glimpse`, titled with the
      presence value and the step probability.

    Args:
      air: model exposing `obs`, `canvas`, `glimpse`, `num_steps_distrib`,
        `presence`, `where`, `max_steps` and `batch_size`.
      sess: session used to evaluate those tensors.
      checkpoint_dir: if given, the figure is saved there as
        `progress_fig_{global_step}.png` and all figures are closed.
      global_step: value used in the saved file name.
      n_samples: maximum number of samples (columns) to plot.

    Returns:
      The matplotlib figure.
    """
    n_steps = air.max_steps

    xx, pred_canvas, pred_crop, prob, pres, w = sess.run(
        [air.obs, air.canvas, air.glimpse, air.num_steps_distrib.prob()[..., 1:], air.presence, air.where])
    height, width = xx.shape[1:]

    bs = min(n_samples, air.batch_size)
    scale = 1.5
    figsize = scale * np.asarray((bs, 2 * n_steps + 1))
    # squeeze=False keeps `axes` 2-D even for a single column (bs == 1), so the
    # row/column indexing below always works.
    fig, axes = plt.subplots(2 * n_steps + 1, bs, figsize=figsize, squeeze=False)

    for i, ax in enumerate(axes[0]):
        ax.imshow(xx[i], cmap='gray', vmin=0, vmax=1)

    for i, ax_row in enumerate(axes[1:1 + n_steps]):
        for j, ax in enumerate(ax_row):
            ax.imshow(pred_canvas[i, j], cmap='gray', vmin=0, vmax=1)
            if pres[i, j, 0] > .5:
                rect_stn(ax, width, height, w[i, j], 'r')

    for i, ax_row in enumerate(axes[1 + n_steps:]):
        for j, ax in enumerate(ax_row):
            ax.imshow(pred_crop[i, j], cmap='gray')  # , vmin=0, vmax=1)
            ax.set_title('{:d} with p({:d}) = {:.02f}'.format(int(pres[i, j, 0]), i + 1, prob[j, i].squeeze()),
                         fontsize=4 * scale)

    for ax in axes.flatten():
        ax.xaxis.set_visible(False)
        ax.yaxis.set_visible(False)

    if checkpoint_dir is not None:
        fig_name = osp.join(checkpoint_dir, 'progress_fig_{}.png'.format(global_step))
        fig.savefig(fig_name, dpi=300)
        plt.close('all')

    # Return the figure so callers that don't save it can display or close it.
    return fig
项目:attend_infer_repeat    作者:akosiorek    | 项目源码 | 文件源码
def log_norm(expr_list, name):
    """Logs the mean squared value of all elements of `expr_list` as a scalar summary.

    :param expr_list: a tensor or nested structure of tensors; all leaves are pooled.
    :param name: name of the `tf.summary.scalar` summary.
    :return: scalar float32 tensor equal to the sum of squares of every element
        divided by the total element count. Despite the name, no square root is
        taken. If `expr_list` has no leaves, the result is 0/0 (NaN).
    """
    n_elems = 0
    norm = 0.
    for e in nest.flatten(expr_list):
        # Cast so integer and float64 leaves combine with the float32 accumulator.
        e = tf.to_float(e)
        n_elems += tf.size(e)
        norm += tf.reduce_sum(e ** 2)
    norm /= tf.to_float(n_elems)
    tf.summary.scalar(name, norm)
    return norm
项目:attend_infer_repeat    作者:akosiorek    | 项目源码 | 文件源码
def _embed(self, inpt):
        """Applies `snt.BatchFlatten` and then an MLP with `self._n_param` outputs to `inpt`."""
        pipeline = snt.Sequential([snt.BatchFlatten(),
                                   MLP(self._n_hidden, n_out=self._n_param)])
        return pipeline(inpt)
项目:attend_infer_repeat    作者:akosiorek    | 项目源码 | 文件源码
def _build(self, img, what, where, presence_prob, state=None):
        """Predicts a single-output baseline from the image, latents and optional state.

        `what`, `where` and `presence_prob` are rank 3, as the `(1, 0, 2)` transpose
        requires. Their first two axes are swapped, then each is reshaped to
        `(batch_size, -1)`. They are concatenated with the flattened image and any
        flattened `state` leaves, and fed to an MLP with one output.
        """
        batch_size = int(img.get_shape()[0])

        def _batch_major_flat(tensor):
            return tf.reshape(tf.transpose(tensor, (1, 0, 2)), (batch_size, -1))

        parts = [_batch_major_flat(t) for t in (what, where, presence_prob)]
        if state is not None:
            parts.extend(nest.flatten(state))

        features = tf.concat([tf.reshape(img, (batch_size, -1))] + parts, -1)
        return MLP(self._n_hidden, n_out=1)(features)
项目:magenta    作者:tensorflow    | 项目源码 | 文件源码
def initial_cell_state_from_embedding(cell, z, batch_size, name=None):
  """Computes an initial RNN `cell` state from an embedding, `z`.

  `z` is mapped by a tanh dense layer to the total flattened state size. The
  result is split per state component and packed into the structure of
  `cell.zero_state(...)`.
  """
  state_sizes = tf_nest.flatten(cell.state_size)
  template = cell.zero_state(batch_size=batch_size, dtype=tf.float32)
  projection = tf.layers.dense(
      z,
      sum(state_sizes),
      activation=tf.tanh,
      kernel_initializer=tf.random_normal_initializer(stddev=0.001),
      name=name)
  components = tf.split(projection, state_sizes, axis=1)
  return tf_nest.pack_sequence_as(template, components)
项目:magenta    作者:tensorflow    | 项目源码 | 文件源码
def _assert_sructures_equal(self, struct1, struct2):
    """Asserts both structures share nesting and have element-wise equal leaves."""
    tf_nest.assert_same_structure(struct1, struct2)
    leaves1 = tf_nest.flatten(struct1)
    leaves2 = tf_nest.flatten(struct2)
    for left, right in zip(leaves1, leaves2):
      np.testing.assert_array_equal(left, right)
项目:sonnet    作者:deepmind    | 项目源码 | 文件源码
def __init__(self, initial_state, mask=None, name="trainable_initial_state"):
    """Constructs the Module that introduces a trainable state in the graph.

    The given initial state provides the starting values of the module's
    trainable variables. An optional mask selects which parts of it are
    learnable.

    Args:
      initial_state: tensor or arbitrarily nested iterables of tensors.
      mask: optional boolean mask with the same nested structure as
       `initial_state`.
      name: module name.

    Raises:
      TypeError: if mask is not a list of booleans or None.
    """
    super(TrainableInitialState, self).__init__(name=name)

    # DeprecationWarning is ignored by default since Python 2.7, so it is
    # switched on here. Note this changes the process-wide warning filters.
    warnings.simplefilter("always", DeprecationWarning)
    warnings.warn("Use the trainable flag in initial_state instead.",
                  DeprecationWarning, stacklevel=2)

    if mask is not None:
      if not all(isinstance(flag, bool) for flag in nest.flatten(mask)):
        raise TypeError("Mask should be None or a list of boolean values.")
      nest.assert_same_structure(initial_state, mask)

    self._mask = mask
    self._initial_state = initial_state
项目:sonnet    作者:deepmind    | 项目源码 | 文件源码
def _testACT(self, input_size, hidden_size, output_size, seq_len, batch_size,
               core, get_state_for_halting):
    """Runs an ACTCore over random input and checks dtypes, shapes and ranges.

    `hidden_size` is accepted but not used here.
    """
    threshold = 0.99
    act = pondering_rnn.ACTCore(core, output_size, threshold,
                                get_state_for_halting)
    seq_input = tf.random_uniform(shape=(seq_len, batch_size, input_size))
    initial_state = core.initial_state(batch_size)
    seq_output = tf.nn.dynamic_rnn(act, seq_input, time_major=True,
                                   initial_state=initial_state)
    expected_dtype = seq_input.dtype
    for leaf in nest.flatten(seq_output):
      self.assertEqual(expected_dtype, leaf.dtype)

    with self.test_session() as sess:
      sess.run(tf.global_variables_initializer())
      (final_out, (iteration, r_t)), final_cumul_state = sess.run(seq_output)

    self.assertEqual((seq_len, batch_size, output_size), final_out.shape)
    self.assertEqual((seq_len, batch_size, 1), iteration.shape)
    # Iteration counts must be whole numbers.
    self.assertTrue(np.all(iteration == np.floor(iteration)))
    expected_state_shape = tuple(
        get_state_for_halting(initial_state).get_shape().as_list())
    self.assertEqual(expected_state_shape,
                     get_state_for_halting(final_cumul_state).shape)
    self.assertEqual((seq_len, batch_size, 1), r_t.shape)
    self.assertTrue(np.all(r_t >= 0))
    self.assertTrue(np.all(r_t <= threshold))
项目:sonnet    作者:deepmind    | 项目源码 | 文件源码
def testInitialStateTuple(self, trainable, use_custom_initial_value,
                            state_size):
    """Checks structure, static shape and values of `RNNCore.initial_state`.

    Expected values:
    - zeros when not trainable;
    - 2 everywhere when the custom constant initializer is used;
    - otherwise, every row equal to the first row.
    """
    batch_size = 6

    # state_size is patched onto the class itself, since properties of the
    # abstract RNNCore can't be set on an instance.
    snt.RNNCore.state_size = state_size
    flat_state_size = nest.flatten(state_size)
    core = snt.RNNCore(name="dummy_core")
    trainable_initializers = None
    if use_custom_initial_value:
      trainable_initializers = nest.pack_sequence_as(
          structure=state_size,
          flat_sequence=[tf.constant_initializer(2)] * len(flat_state_size))
    initial_state = core.initial_state(
        batch_size, dtype=tf.float32, trainable=trainable,
        trainable_initializers=trainable_initializers)

    nest.assert_same_structure(initial_state, state_size)
    flat_initial_state = nest.flatten(initial_state)

    for tensor, size in zip(flat_initial_state, flat_state_size):
      self.assertEqual(tensor.get_shape(), [batch_size, size])

    with self.test_session() as sess:
      tf.global_variables_initializer().run()
      values = sess.run(flat_initial_state)
      for value, size in zip(values, flat_state_size):
        if not trainable:
          expected = np.zeros([batch_size, size])
        elif use_custom_initial_value:
          expected = np.full([batch_size, size], 2.)
        else:
          # Learned state: every batch row must equal the first one.
          expected = np.tile(value[0], (batch_size, 1))
        self.assertAllClose(value, expected)
项目:tensorflow_end2end_speech_recognition    作者:hirofumi0810    | 项目源码 | 文件源码
def __init__(self, encoder_outputs, decoder_state_size):
        """Stores the encoder outputs and target state size, and derives the batch size.

        The batch size is dimension 0 of the first leaf of
        `encoder_outputs.final_state`.
        """
        self.encoder_outputs = encoder_outputs
        self.decoder_state_size = decoder_state_size
        first_leaf = nest.flatten(self.encoder_outputs.final_state)[0]
        self.batch_size = tf.shape(first_leaf)[0]
项目:tensorflow_end2end_speech_recognition    作者:hirofumi0810    | 项目源码 | 文件源码
def batch_size(self):
        """Returns dimension 0 of the first leaf of `self.initial_state` as a tensor."""
        first_leaf = nest.flatten([self.initial_state])[0]
        return tf.shape(first_leaf)[0]
项目:tefla    作者:openAGI    | 项目源码 | 文件源码
def _assert_correct_outputs(self, initial_state_):
        """Asserts that every leaf of the evaluated initial state is all zeros."""
        for leaf in nest.flatten(initial_state_):
            self.assertAllEqual(leaf, np.zeros_like(leaf))
项目:tefla    作者:openAGI    | 项目源码 | 文件源码
def _assert_correct_outputs(self, initial_state_):
        """Asserts the initial state equals the encoder's final state, leaf by leaf.

        Also checks that its structure matches both the decoder cell's
        `state_size` and `encoder_outputs.final_state`.
        """
        nest.assert_same_structure(initial_state_, self.decoder_cell.state_size)
        nest.assert_same_structure(
            initial_state_, self.encoder_outputs.final_state)

        encoder_leaves = nest.flatten(self.encoder_outputs.final_state)
        with self.test_session() as sess:
            encoder_values = sess.run(encoder_leaves)

        decoder_values = nest.flatten(initial_state_)
        for decoded, encoded in zip(decoder_values, encoder_values):
            self.assertAllEqual(decoded, encoded)
项目:tefla    作者:openAGI    | 项目源码 | 文件源码
def batch_size(self):
        """Returns dimension 0 of the first leaf of `self.initial_state` as a tensor."""
        leaves = nest.flatten([self.initial_state])
        return tf.shape(leaves[0])[0]
项目:tefla    作者:openAGI    | 项目源码 | 文件源码
def __init__(self, encoder_outputs, decoder_state_size, params, mode, reuse):
        """Configures the bridge and derives the batch size from the encoder's final state.

        The batch size is dimension 0 of the first leaf of
        `encoder_outputs.final_state`.
        """
        Configurable.__init__(self, params, mode, reuse)
        self.encoder_outputs = encoder_outputs
        self.decoder_state_size = decoder_state_size
        leaves = nest.flatten(self.encoder_outputs.final_state)
        self.batch_size = tf.shape(leaves[0])[0]
项目:neural-chat    作者:henriblancke    | 项目源码 | 文件源码
def nest_map(func, nested):
    """Applies `func` to every leaf of `nested`, preserving its structure.

    A non-sequence `nested` is treated as a single leaf, and `func(nested)` is returned.
    """
    if nest.is_sequence(nested):
        mapped = [func(leaf) for leaf in nest.flatten(nested)]
        return nest.pack_sequence_as(nested, mapped)
    return func(nested)
项目:neural-chat    作者:henriblancke    | 项目源码 | 文件源码
def __call__(self, inputs, state, scope=None):
        """Applies `self.cell` separately to each of `self.beam_size` entries.

        Every leaf of `inputs` and `state` is unstacked along axis 1 into
        `self.beam_size` slices. Axis 1 presumably indexes beam entries --
        confirm with callers. The wrapped cell runs once per entry with shared
        variables, and the per-entry outputs and next states are stacked back
        along axis 1.

        Returns:
          `(output, next_state)`, packed into the structures of the wrapped
          cell's output and next state.
        """
        varscope = scope or tf.get_variable_scope()

        flat_inputs = nest.flatten(inputs)
        flat_state = nest.flatten(state)

        # One tuple per beam entry, holding that entry's slice of every leaf.
        flat_inputs_unstacked = list(zip(*[tf.unstack(tensor, num=self.beam_size, axis=1) for tensor in flat_inputs]))
        flat_state_unstacked = list(zip(*[tf.unstack(tensor, num=self.beam_size, axis=1) for tensor in flat_state]))

        flat_output_unstacked = []
        flat_next_state_unstacked = []
        output_sample = None
        next_state_sample = None

        for i, (inputs_k, state_k) in enumerate(zip(flat_inputs_unstacked, flat_state_unstacked)):
            inputs_k = nest.pack_sequence_as(inputs, inputs_k)
            state_k = nest.pack_sequence_as(state, state_k)

            if i == 0:
                # First entry: the cell is called normally under `scope`.
                output_k, next_state_k = self.cell(inputs_k, state_k,
                                                   scope=scope)
            else:
                # Remaining entries reuse the variables of the first call.
                with tf.variable_scope(varscope, reuse=True):
                    output_k, next_state_k = self.cell(inputs_k, state_k,
                                                       scope=varscope if scope is not None else None)

            flat_output_unstacked.append(nest.flatten(output_k))
            flat_next_state_unstacked.append(nest.flatten(next_state_k))

            # Kept only as structure templates for repacking (the last entry wins).
            output_sample = output_k
            next_state_sample = next_state_k

        # Re-stack per-leaf slices along axis 1.
        flat_output = [tf.stack(tensors, axis=1) for tensors in zip(*flat_output_unstacked)]
        flat_next_state = [tf.stack(tensors, axis=1) for tensors in zip(*flat_next_state_unstacked)]

        output = nest.pack_sequence_as(output_sample, flat_output)
        next_state = nest.pack_sequence_as(next_state_sample, flat_next_state)

        return output, next_state
项目:lsdc    作者:febert    | 项目源码 | 文件源码
def __call__(self, inputs, state, scope=None):
    """Long short-term memory cell with attention (LSTMA).

    Args:
      inputs: input tensor for this step. When `self._input_size` is None its
        size is read from `get_shape().as_list()[1]` (presumably rank 2 --
        TODO confirm).
      state: either a tuple `(cell_state, attns, attn_states)` when
        `self._state_is_tuple`, or one tensor holding those three parts
        concatenated along axis 1.
      scope: optional variable scope; defaults to the class name.

    Returns:
      `(output, new_state)`, where `new_state` uses the same tuple or
      concatenated layout as the incoming `state`.
    """
    with vs.variable_scope(scope or type(self).__name__):
      if self._state_is_tuple:
        state, attns, attn_states = state
      else:
        # Split the single state tensor along axis 1 into
        # [cell_state | attns | flattened attn_states].
        states = state
        state = array_ops.slice(states, [0, 0], [-1, self._cell.state_size])
        attns = array_ops.slice(
            states, [0, self._cell.state_size], [-1, self._attn_size])
        attn_states = array_ops.slice(
            states, [0, self._cell.state_size + self._attn_size],
            [-1, self._attn_size * self._attn_length])
      # Restore the 3-D attention memory.
      attn_states = array_ops.reshape(attn_states,
                                      [-1, self._attn_length, self._attn_size])
      input_size = self._input_size
      if input_size is None:
        input_size = inputs.get_shape().as_list()[1]
      # Mix the previous attention read-out into the cell input.
      inputs = _linear([inputs, attns], input_size, True)
      lstm_output, new_state = self._cell(inputs, state)
      if self._state_is_tuple:
        # NOTE(review): `concat` takes the axis first here, which presumably
        # targets an older TensorFlow signature `concat(axis, values)`.
        new_state_cat = array_ops.concat(1, nest.flatten(new_state))
      else:
        new_state_cat = new_state
      new_attns, new_attn_states = self._attention(new_state_cat, attn_states)
      with vs.variable_scope("AttnOutputProjection"):
        output = _linear([lstm_output, new_attns], self._attn_size, True)
      # Append this step's output to the attention memory. NOTE(review):
      # `_attention` presumably drops one entry so the reshape below still
      # yields `attn_length` entries -- not visible here; confirm.
      new_attn_states = array_ops.concat(1, [new_attn_states,
                                             array_ops.expand_dims(output, 1)])
      new_attn_states = array_ops.reshape(
          new_attn_states, [-1, self._attn_length * self._attn_size])
      new_state = (new_state, new_attns, new_attn_states)
      if not self._state_is_tuple:
        # Re-pack into the single concatenated layout used on input.
        new_state = array_ops.concat(1, list(new_state))
      return output, new_state
项目:lsdc    作者:febert    | 项目源码 | 文件源码
def __call__(self, inputs, state, scope=None):
    """Long short-term memory cell with attention (LSTMA).

    Runs one step of the wrapped cell with the previous attention read-out
    fed back into its input. The state has three parts: the wrapped cell's
    state, the previous attention vector and the attention states. When
    `self._state_is_tuple` is true they arrive as a 3-tuple; otherwise they
    arrive concatenated along axis 1 and are sliced apart here.

    Args:
      inputs: input tensor for this step. Its static dimension 1 is used as
        the projected input size when `self._input_size` is None (assumed
        2-D, batch-major -- TODO confirm).
      state: 3-tuple or a single concatenated tensor, as described above.
      scope: optional variable scope; defaults to the class name.

    Returns:
      `(output, new_state)`. `output` is produced by `_linear` with
      `self._attn_size` as its size argument. `new_state` uses the same
      packing (tuple or concatenated) as `state`.
    """
    with vs.variable_scope(scope or type(self).__name__):
      attn_width = self._attn_size
      if self._state_is_tuple:
        cell_state, prev_attns, memory = state
      else:
        packed = state
        cell_width = self._cell.state_size
        cell_state = array_ops.slice(packed, [0, 0], [-1, cell_width])
        prev_attns = array_ops.slice(
            packed, [0, cell_width], [-1, attn_width])
        memory = array_ops.slice(
            packed, [0, cell_width + attn_width],
            [-1, attn_width * self._attn_length])
      # Restore the attention states to (batch, attn_length, attn_size).
      memory = array_ops.reshape(
          memory, [-1, self._attn_length, attn_width])

      step_input_width = self._input_size
      if step_input_width is None:
        step_input_width = inputs.get_shape().as_list()[1]
      projected_inputs = _linear([inputs, prev_attns], step_input_width, True)
      cell_output, next_cell_state = self._cell(projected_inputs, cell_state)

      # A tuple cell state is flattened and concatenated along axis 1 before
      # being handed to `self._attention`.
      if self._state_is_tuple:
        attention_query = array_ops.concat(1, nest.flatten(next_cell_state))
      else:
        attention_query = next_cell_state
      next_attns, next_memory = self._attention(attention_query, memory)

      with vs.variable_scope("AttnOutputProjection"):
        output = _linear([cell_output, next_attns], attn_width, True)

      # Append this step's output to the attention states, then flatten.
      next_memory = array_ops.concat(
          1, [next_memory, array_ops.expand_dims(output, 1)])
      next_memory = array_ops.reshape(
          next_memory, [-1, self._attn_length * attn_width])

      next_state = (next_cell_state, next_attns, next_memory)
      if self._state_is_tuple:
        return output, next_state
      return output, array_ops.concat(1, list(next_state))
项目:TextGAN    作者:ankitkv    | 项目源码 | 文件源码
def nest_map(func, nested):
    """Apply `func` to every leaf of `nested`, keeping its structure.

    A value that is not a nested sequence is passed straight to `func` and
    the result is returned unchanged.
    """
    if nest.is_sequence(nested):
        mapped_leaves = [func(leaf) for leaf in nest.flatten(nested)]
        return nest.pack_sequence_as(nested, mapped_leaves)
    return func(nested)
项目:TextGAN    作者:ankitkv    | 项目源码 | 文件源码
def wrap_state(self, state):
        """Wrap a plain cell state in the beam decoder's state structure.

        Builds a throw-away `BeamDecoderCellWrapper` (with no inner cell)
        from this wrapper's settings. It then calls that wrapper's
        `_create_state` with a batch size and dtype taken from `state`. For a
        nested sequence, both come from the first flattened leaf.

        Args:
          state: tensor or nested structure of tensors. The batch size is
            read from dimension 0 of the reference tensor.

        Returns:
          The result of `BeamDecoderCellWrapper._create_state` for `state`.
        """
        template = BeamDecoderCellWrapper(None, self.batch_concat, self.num_classes,
                                          self.max_len, self.stop_token, self.beam_size,
                                          self.min_op, self.min_frac)
        reference = nest.flatten(state)[0] if nest.is_sequence(state) else state
        return template._create_state(tf.shape(reference)[0], reference.dtype,
                                      cell_state=state)
项目:Deep-Reinforcement-Learning-for-Dialogue-Generation-in-tensorflow    作者:liuyuemaicha    | 项目源码 | 文件源码
def nest_map(func, nested):
    """Apply `func` to every leaf of `nested`, keeping its structure.

    A value that is not a nested sequence is passed straight to `func` and
    the result is returned unchanged.
    """
    if nest.is_sequence(nested):
        mapped_leaves = [func(leaf) for leaf in nest.flatten(nested)]
        return nest.pack_sequence_as(nested, mapped_leaves)
    return func(nested)
项目:Deep-Reinforcement-Learning-for-Dialogue-Generation-in-tensorflow    作者:liuyuemaicha    | 项目源码 | 文件源码
def __call__(self, inputs, state, scope=None):
        """Apply the wrapped cell independently to every beam slot.

        Each leaf tensor of `inputs` and `state` is unstacked along axis 1
        into `self.beam_size` slices. The wrapped cell runs once per slice:
        the first call creates its variables and the later calls reuse them.
        The per-slice results are then stacked back together along axis 1.

        Args:
          inputs: tensor or nested structure of tensors whose axis 1 has
            length `self.beam_size`.
          state: tensor or nested structure of tensors with the same beam
            layout as `inputs`.
          scope: optional variable scope passed through to the wrapped cell.

        Returns:
          `(output, next_state)`, structured like the wrapped cell's results,
          with the beam axis re-inserted at position 1.
        """
        # Scope enclosing every call to the wrapped cell. The first call
        # creates variables here; later calls re-enter it with reuse=True.
        outer_scope = tf.get_variable_scope()

        flat_inputs = nest.flatten(inputs)
        flat_state = nest.flatten(state)

        # Transpose [leaf][beam] -> [beam][leaf].
        flat_inputs_unstacked = list(zip(*[tf.unstack(tensor, num=self.beam_size, axis=1) for tensor in flat_inputs]))
        flat_state_unstacked = list(zip(*[tf.unstack(tensor, num=self.beam_size, axis=1) for tensor in flat_state]))

        flat_output_unstacked = []
        flat_next_state_unstacked = []
        output_sample = None
        next_state_sample = None

        for i, (inputs_k, state_k) in enumerate(zip(flat_inputs_unstacked, flat_state_unstacked)):
            inputs_k = nest.pack_sequence_as(inputs, inputs_k)
            state_k = nest.pack_sequence_as(state, state_k)

            if i == 0:
                output_k, next_state_k = self.cell(inputs_k, state_k, scope=scope)
            else:
                # Re-enter the same enclosing scope and forward the same
                # `scope` argument as the first call, so the cell resolves
                # the same variable names it just created. The previous code
                # entered `scope` here and then passed `scope` to the cell
                # again. For a cell that opens `variable_scope(scope)` itself,
                # a string scope then nests twice ("scope/scope/...") and the
                # reuse lookup fails.
                with tf.variable_scope(outer_scope, reuse=True):
                    output_k, next_state_k = self.cell(inputs_k, state_k, scope=scope)

            flat_output_unstacked.append(nest.flatten(output_k))
            flat_next_state_unstacked.append(nest.flatten(next_state_k))

            # The last slice's results serve as structure templates for repacking.
            output_sample = output_k
            next_state_sample = next_state_k

        flat_output = [tf.stack(tensors, axis=1) for tensors in zip(*flat_output_unstacked)]
        flat_next_state = [tf.stack(tensors, axis=1) for tensors in zip(*flat_next_state_unstacked)]

        output = nest.pack_sequence_as(output_sample, flat_output)
        next_state = nest.pack_sequence_as(next_state_sample, flat_next_state)

        return output, next_state
项目:diversity_based_attention    作者:PrekshaNema25    | 项目源码 | 文件源码
def zero_state(self, batch_size, dtype):
    """Create all-zero state tensor(s) shaped like `self.state_size`.

    Args:
      batch_size: int, float, or unit Tensor representing the batch size.
      dtype: the data type to use for the state.

    Returns:
      If `state_size` is not a nested sequence, a single zeros tensor whose
      dynamic shape is `[batch_size] + state_size`. Otherwise, a structure
      matching `state_size` with one such zeros tensor per leaf size. In
      both cases the static batch dimension is set to unknown (`None`).
    """
    def _zeros_for(size):
      # Build the dynamic shape from the real batch size, then relax the
      # static batch dimension to None.
      dynamic_shape = array_ops.pack(
          _state_size_with_prefix(size, prefix=[batch_size]))
      zeros = array_ops.zeros(dynamic_shape, dtype=dtype)
      zeros.set_shape(_state_size_with_prefix(size, prefix=[None]))
      return zeros

    state_size = self.state_size
    if not nest.is_sequence(state_size):
      return _zeros_for(state_size)
    return nest.pack_sequence_as(
        structure=state_size,
        flat_sequence=[_zeros_for(s) for s in nest.flatten(state_size)])
项目:diversity_based_attention    作者:PrekshaNema25    | 项目源码 | 文件源码
def _infer_state_dtype(explicit_dtype, state):
  """Infer the dtype of an RNN state.

  Args:
    explicit_dtype: explicitly declared dtype or None.
    state: RNN's hidden state. Must be a Tensor or a nested iterable containing
      Tensors.

  Returns:
    dtype: inferred dtype of hidden state.

  Raises:
    ValueError: if `state` has heterogeneous dtypes or is empty.
  """
  if explicit_dtype is not None:
    return explicit_dtype
  elif nest.is_sequence(state):
    inferred_dtypes = [element.dtype for element in nest.flatten(state)]
    if not inferred_dtypes:
      raise ValueError("Unable to infer dtype from empty state.")
    all_same = all([x == inferred_dtypes[0] for x in inferred_dtypes])
    if not all_same:
      raise ValueError(
          "State has tensors of different inferred_dtypes. Unable to infer a "
          "single representative dtype.")
    return inferred_dtypes[0]
  else:
    return state.dtype
项目:tf-tutorial    作者:zchen0211    | 项目源码 | 文件源码
def __init__(self, initial_state, mask=None, name="trainable_initial_state"):
    """Constructs the Module that introduces a trainable state in the graph.

    It receives an initial state that will be used as the initial values for
    the trainable variables that the module contains, and optionally a mask
    that indicates the parts of the initial state that should be learnable.

    Every construction emits a `DeprecationWarning` recommending the
    trainable flag in `initial_state` instead.

    Args:
      initial_state: tensor or arbitrarily nested iterables of tensors.
      mask: optional boolean mask. It should have the same nested structure as
       the given initial_state.
      name: module name.

    Raises:
      TypeError: if mask is not a list of booleans or None.
      Structure mismatches between `mask` and `initial_state` are reported
      by `nest.assert_same_structure`.
    """
    super(TrainableInitialState, self).__init__(name=name)

    # Since python 2.7, DeprecationWarning is ignored by default. Force this
    # particular warning to be shown. Doing so inside `catch_warnings`
    # restores the process-wide filters afterwards, instead of permanently
    # switching every DeprecationWarning in the program to "always".
    with warnings.catch_warnings():
      warnings.simplefilter("always", DeprecationWarning)
      warnings.warn("Use the trainable flag in initial_state instead.",
                    DeprecationWarning, stacklevel=2)

    if mask is not None:
      flat_mask = nest.flatten(mask)
      if not all(isinstance(m, bool) for m in flat_mask):
        raise TypeError("Mask should be None or a list of boolean values.")
      nest.assert_same_structure(initial_state, mask)

    self._mask = mask
    self._initial_state = initial_state
项目:ROLO    作者:Guanghan    | 项目源码 | 文件源码
def zero_state(self, batch_size, dtype):
    """Create all-zero state tensor(s) shaped like `self.state_size`.

    Args:
      batch_size: int, float, or unit Tensor representing the batch size.
      dtype: the data type to use for the state.

    Returns:
      If `state_size` is not a nested sequence, a single zeros tensor whose
      dynamic shape is `[batch_size] + state_size`. Otherwise, a structure
      matching `state_size` with one such zeros tensor per leaf size. In
      both cases the static batch dimension is set to unknown (`None`).
    """
    def _zeros_for(size):
      # Build the dynamic shape from the real batch size, then relax the
      # static batch dimension to None.
      dynamic_shape = array_ops.pack(
          _state_size_with_prefix(size, prefix=[batch_size]))
      zeros = array_ops.zeros(dynamic_shape, dtype=dtype)
      zeros.set_shape(_state_size_with_prefix(size, prefix=[None]))
      return zeros

    state_size = self.state_size
    if not nest.is_sequence(state_size):
      return _zeros_for(state_size)
    return nest.pack_sequence_as(
        structure=state_size,
        flat_sequence=[_zeros_for(s) for s in nest.flatten(state_size)])
项目:ROLO    作者:Guanghan    | 项目源码 | 文件源码
def _infer_state_dtype(explicit_dtype, state):
  """Infer the dtype of an RNN state.

  Args:
    explicit_dtype: explicitly declared dtype or None.
    state: RNN's hidden state. Must be a Tensor or a nested iterable containing
      Tensors.

  Returns:
    dtype: inferred dtype of hidden state.

  Raises:
    ValueError: if `state` has heterogeneous dtypes or is empty.
  """
  if explicit_dtype is not None:
    return explicit_dtype
  elif nest.is_sequence(state):
    inferred_dtypes = [element.dtype for element in nest.flatten(state)]
    if not inferred_dtypes:
      raise ValueError("Unable to infer dtype from empty state.")
    all_same = all([x == inferred_dtypes[0] for x in inferred_dtypes])
    if not all_same:
      raise ValueError(
          "State has tensors of different inferred_dtypes. Unable to infer a "
          "single representative dtype.")
    return inferred_dtypes[0]
  else:
    return state.dtype
项目:ROLO    作者:Guanghan    | 项目源码 | 文件源码
def _reverse_seq(input_seq, lengths):
  """Reverse a list of Tensors up to specified lengths.

  Args:
    input_seq: Sequence of seq_len tensors of dimension (batch_size, n_features)
               or nested tuples of tensors.
    lengths:   A `Tensor` of dimension batch_size, containing lengths for each
               sequence in the batch. If "None" is specified, simply reverses
               the list.

  Returns:
    time-reversed sequence
  """
  if lengths is None:
    return list(reversed(input_seq))

  flat_input_seq = tuple(nest.flatten(input_) for input_ in input_seq)

  flat_results = [[] for _ in range(len(input_seq))]
  for sequence in zip(*flat_input_seq):
    input_shape = tensor_shape.unknown_shape(
        ndims=sequence[0].get_shape().ndims)
    for input_ in sequence:
      input_shape.merge_with(input_.get_shape())
      input_.set_shape(input_shape)

    # Join into (time, batch_size, depth)
    s_joined = array_ops.pack(sequence)

    # TODO(schuster, ebrevdo): Remove cast when reverse_sequence takes int32
    if lengths is not None:
      lengths = math_ops.to_int64(lengths)

    # Reverse along dimension 0
    s_reversed = array_ops.reverse_sequence(s_joined, lengths, 0, 1)
    # Split again into list
    result = array_ops.unpack(s_reversed)
    for r, flat_result in zip(result, flat_results):
      r.set_shape(input_shape)
      flat_result.append(r)

  results = [nest.pack_sequence_as(structure=input_, flat_sequence=flat_result)
             for input_, flat_result in zip(input_seq, flat_results)]
  return results