我们从Python开源项目中,提取了以下50个代码示例,用于说明如何使用tensorflow.python.util.nest.map_structure()。
def batch(states, batch_size=None):
    """Combine a collection of state structures into a batch, padding if needed.

    Args:
      states: A collection of individual nested state structures, all sharing
        the same nested structure.
      batch_size: The desired final batch size. If the nested state structure
        that results from combining the states is smaller than this, every
        leaf is padded with zeros along its leading (stacking) axis.

    Returns:
      A single state structure that results from stacking the structures in
      `states`, with padding if needed.

    Raises:
      ValueError: If the number of input states is larger than `batch_size`.
    """
    if batch_size and len(states) > batch_size:
        raise ValueError('Combined state is larger than the requested batch size')

    def stack_and_pad(*leaves):
        stacked = np.stack(leaves)
        if batch_size:
            # Build a padded copy rather than calling ndarray.resize(): the
            # in-place resize raises ValueError whenever anything else holds a
            # reference to the array (debuggers, coverage tracers, profilers).
            missing_rows = batch_size - stacked.shape[0]
            if missing_rows > 0:
                pad_width = [(0, missing_rows)] + [(0, 0)] * (stacked.ndim - 1)
                stacked = np.pad(stacked, pad_width, mode='constant')
        return stacked

    return tf_nest.map_structure(stack_and_pad, *states)
def __call__(self, inputs, state, scope=None):
    """Run the wrapped cell and average its inputs with its outputs.

    For every leaf the returned output is ``0.5 * (input + output)``.

    Args:
      inputs: cell inputs.
      state: cell state.
      scope: optional cell scope.

    Returns:
      Tuple of (averaged outputs, new cell state).

    Raises:
      TypeError: If cell inputs and outputs have different structure (type).
      ValueError: If cell inputs and outputs have different structure (value).
    """
    cell_out, next_state = self._cell(inputs, state, scope=scope)
    nest.assert_same_structure(inputs, cell_out)

    def _check_compatible(x, y):
        # The static shape of each input leaf must agree with its output leaf.
        x.get_shape().assert_is_compatible_with(y.get_shape())

    nest.map_structure(_check_compatible, inputs, cell_out)

    def _half_sum(x, y):
        return math_ops.scalar_mul(0.5, x + y)

    averaged = nest.map_structure(_half_sum, inputs, cell_out)
    return averaged, next_state
def _create(self):
    """Project the bridge inputs through a dense layer into the decoder state.

    Each bridge input leaf is reshaped to ``[self.batch_size, depth]``. The
    leaves are concatenated along axis 1 and fed through a fully connected
    layer whose width is the summed decoder state size. The result is then
    split and packed back into the structure of ``self.decoder_state_size``.
    """
    def _flatten_depth(tensor):
        return tf.reshape(tensor, [self.batch_size, _total_tensor_depth(tensor)])

    # Concat bridge inputs on the depth dimensions
    reshaped = nest.map_structure(_flatten_depth, self._bridge_input)
    concatenated = tf.concat(nest.flatten([reshaped]), axis=1)

    split_sizes = nest.flatten(self.decoder_state_size)

    # Pass bridge inputs through a fully connected layer
    projected = tf.contrib.layers.fully_connected(
        concatenated,
        num_outputs=sum(split_sizes),
        activation_fn=self._activation_fn,
        weights_initializer=tf.truncated_normal_initializer(
            stddev=self.parameter_init),
        biases_initializer=tf.zeros_initializer(),
        scope=None)

    # Shape back into required state size
    pieces = tf.split(projected, split_sizes, axis=1)
    return nest.pack_sequence_as(self.decoder_state_size, pieces)
def gnmt_residual_fn(inputs, outputs):
    """Residual function that handles different input and output inner dims.

    ``inputs`` is split along its last axis into a leading slice as wide as
    ``outputs`` and a remainder. Only the leading slice is added to
    ``outputs``; the remainder is discarded.

    NOTE(review): the two-name unpacking of the ``map_structure`` result
    presumes ``inputs``/``outputs`` are single tensors — confirm for nested use.

    Args:
      inputs: cell inputs, i.e. the actual inputs concatenated with the
        attention vector.
      outputs: cell outputs.

    Returns:
      outputs + actual inputs
    """
    def _split_off_extra(inp, out):
        out_width = out.get_shape().as_list()[-1]
        in_width = inp.get_shape().as_list()[-1]
        return tf.split(inp, [out_width, in_width - out_width], axis=-1)

    actual_inputs, _ = nest.map_structure(_split_off_extra, inputs, outputs)

    def _check_compatible(inp, out):
        inp.get_shape().assert_is_compatible_with(out.get_shape())

    nest.assert_same_structure(actual_inputs, outputs)
    nest.map_structure(_check_compatible, actual_inputs, outputs)
    return nest.map_structure(lambda a, b: a + b, actual_inputs, outputs)
def setUp(self):
    """Build encoder/decoder cells and a random ``EncoderOutput`` fixture."""
    super(BridgeTest, self).setUp()
    self.batch_size = 4
    self.encoder_cell = MultiRNNCell([GRUCell(4, None), GRUCell(8, None)])
    self.decoder_cell = MultiRNNCell([LSTMCell(16, None), GRUCell(8, None)])

    def _random_float_tensor(*shape):
        return tf.convert_to_tensor(value=np.random.randn(*shape), dtype=tf.float32)

    final_encoder_state = nest.map_structure(
        lambda size: _random_float_tensor(self.batch_size, size),
        self.encoder_cell.state_size)
    self.encoder_outputs = EncoderOutput(
        outputs=_random_float_tensor(self.batch_size, 10, 16),
        attention_values=_random_float_tensor(self.batch_size, 10, 16),
        attention_values_length=np.full([self.batch_size], 10),
        final_state=final_encoder_state)
def _create(self):
    """Map the bridge inputs onto the decoder state with a dense projection.

    Each bridge input leaf is reshaped to ``[self.batch_size, depth]`` and the
    leaves are concatenated along axis 1. The result goes through the
    project's ``fully_connected`` helper with ``self._mode``,
    ``self._reuse`` and ``self._activation_fn``. The output is split into
    the sizes of ``self.decoder_state_size`` and packed into that structure.
    """
    def _to_matrix(tensor):
        return tf.reshape(tensor, [self.batch_size, _total_tensor_depth(tensor)])

    reshaped = nest.map_structure(_to_matrix, self._bridge_input)
    joined = tf.concat(nest.flatten([reshaped]), 1)

    split_sizes = nest.flatten(self.decoder_state_size)
    dense_out = fully_connected(
        joined, sum(split_sizes), self._mode, self._reuse,
        activation=self._activation_fn)

    return nest.pack_sequence_as(
        self.decoder_state_size, tf.split(dense_out, split_sizes, axis=1))
def __call__(self, inputs, state, scope=None):
    """Run the wrapped cell, then combine its inputs and outputs.

    The combination uses ``self._residual_fn`` when it is truthy. Otherwise
    it uses a default that checks structure and shape compatibility and
    returns the leaf-wise sum ``input + output``.

    Args:
      inputs: cell inputs.
      state: cell state.
      scope: optional cell scope.

    Returns:
      Tuple of cell outputs and new state.

    Raises:
      TypeError: If cell inputs and outputs have different structure (type).
      ValueError: If cell inputs and outputs have different structure (value).
    """
    outputs, new_state = self._cell(inputs, state, scope=scope)

    def _check_compatible(inp, out):
        inp.get_shape().assert_is_compatible_with(out.get_shape())

    def _add_residual(inputs, outputs):
        nest.assert_same_structure(inputs, outputs)
        nest.map_structure(_check_compatible, inputs, outputs)
        return nest.map_structure(lambda a, b: a + b, inputs, outputs)

    residual_fn = self._residual_fn if self._residual_fn else _add_residual
    return (residual_fn(inputs, outputs), new_state)
def __call__(self, inputs, state, scope=None):
    """Run the wrapped cell and add its inputs to its outputs (residual).

    Args:
      inputs: cell inputs.
      state: cell state.
      scope: optional cell scope.

    Returns:
      Tuple of cell outputs and new state.

    Raises:
      TypeError: If cell inputs and outputs have different structure (type).
      ValueError: If cell inputs and outputs have different structure (value).
    """
    cell_output, next_state = self._cell(inputs, state, scope=scope)
    nest.assert_same_structure(inputs, cell_output)

    def _check_compatible(inp, out):
        # The static shape of each input leaf must agree with its output leaf.
        inp.get_shape().assert_is_compatible_with(out.get_shape())

    nest.map_structure(_check_compatible, inputs, cell_output)
    summed = nest.map_structure(lambda a, b: a + b, inputs, cell_output)
    return (summed, next_state)
def __init__(self, training, cell, embedding, start_tokens, end_token,
             initial_state, beam_width, output_layer=None,
             gold_sequence=None, gold_sequence_length=None):
    """Set up the beam-search-optimization decoder.

    Args:
      training: Python bool. When true, `gold_sequence` and
        `gold_sequence_length` are required, and the gold sequence is
        transposed and unstacked into an int32 TensorArray.
      cell: RNN cell used for decoding.
      embedding: embedding parameters looked up with `tf.nn.embedding_lookup`
        to embed token ids.
      start_tokens: start token ids; `tf.size(start_tokens)` is used as the
        batch size.
      end_token: end-of-sequence token id.
      initial_state: initial cell state, mapped together with
        `cell.state_size`.
      beam_width: number of beams.
      output_layer: optional layer applied to cell outputs; when given, its
        `units` define the output size.
      gold_sequence: gold token ids indexed `[batch, time]` (transposed with
        perm `[1, 0]` below); required when training.
      gold_sequence_length: gold sequence lengths; required when training.

    Raises:
      ValueError: if `training` is true and `gold_sequence` or
        `gold_sequence_length` is None.
    """
    self._training = training
    self._cell = cell
    self._output_layer = output_layer
    self._embedding_fn = lambda ids: tf.nn.embedding_lookup(embedding, ids)
    # Without an output layer the raw cell outputs are used directly, so the
    # output size is the cell's own output_size.
    self._output_size = (output_layer.units if output_layer is not None
                         else self._cell.output_size)
    self._batch_size = tf.size(start_tokens)
    self._beam_width = beam_width
    self._tiled_initial_cell_state = nest.map_structure(
        self._maybe_split_batch_beams, initial_state, self._cell.state_size)
    self._start_tokens = start_tokens
    self._tiled_start_tokens = self._maybe_tile_batch(start_tokens)
    self._end_token = end_token
    self._original_gold_sequence = gold_sequence
    self._gold_sequence = gold_sequence
    self._gold_sequence_length = gold_sequence_length

    if training:
        # Explicit check rather than `assert`, which is stripped under -O.
        if self._gold_sequence is None or self._gold_sequence_length is None:
            raise ValueError(
                "gold_sequence and gold_sequence_length are required when training")
        self._max_time = int(self._gold_sequence.shape[1])
        # transpose gold sequence to be time major and make it into a TensorArray
        self._gold_sequence = tf.TensorArray(dtype=tf.int32, size=self._max_time)
        self._gold_sequence = self._gold_sequence.unstack(
            tf.transpose(gold_sequence, [1, 0]))
def step(self, time, inputs, state: BeamSearchOptimizationDecoderState, name=None):
    """Perform a decoding step.

    Beam-shaped inputs and cell state are merged into the batch dimension.
    The cell and optional output layer are run, the outputs and new state
    are split back per beam, and one beam search step is applied.

    Args:
      time: scalar `int32` tensor.
      inputs: A (structure of) input tensors.
      state: A (structure of) state tensors and TensorArrays.
      name: Name scope for any created operations.

    Returns:
      `(outputs, next_state, next_inputs, finished)`.
    """
    with tf.name_scope(name, "BeamSearchOptimizationDecoderStep", (time, inputs, state)):
        cell_state = state.cell_state
        with tf.name_scope('merge_cell_input'):
            inputs = nest.map_structure(
                lambda x: self._merge_batch_beams(x, s=x.shape[2:]), inputs)
        with tf.name_scope('merge_cell_state'):
            cell_state = nest.map_structure(
                self._maybe_merge_batch_beams, cell_state, self._cell.state_size)

        cell_outputs, next_cell_state = self._cell(inputs, cell_state)
        if self._output_layer is not None:
            cell_outputs = self._output_layer(cell_outputs)

        with tf.name_scope('split_cell_outputs'):
            cell_outputs = nest.map_structure(
                self._split_batch_beams, cell_outputs, self._output_size)
        with tf.name_scope('split_cell_state'):
            next_cell_state = nest.map_structure(
                self._maybe_split_batch_beams, next_cell_state, self._cell.state_size)

        beam_search_output, beam_search_state = self._beam_search_step(
            time=time,
            logits=cell_outputs,
            next_cell_state=next_cell_state,
            beam_state=state)

        finished = beam_search_state.finished
        sample_ids = beam_search_output.predicted_ids
        next_inputs = self._embedding_fn(sample_ids)

    return (beam_search_output, beam_search_state, next_inputs, finished)
def add_decoder_op(self, enc_final_state, enc_hidden_states, output_embed_matrix, training):
    """Build the decoder cell stack and decode with ``Seq2SeqDecoder``.

    Args:
      enc_final_state: final encoder state. When the hidden sizes match, it is
        flattened and repacked into the decoder cell's ``state_size`` structure.
      enc_hidden_states: encoder hidden states. Their last dimension is read
        as the encoder hidden size (rank 3 presumed from the tensordot axes —
        confirm).
      output_embed_matrix: output embedding matrix forwarded to ``decode``.
      training: Python bool forwarded to ``decode``.

    Returns:
      The result of ``Seq2SeqDecoder.decode``.

    Raises:
      AssertionError: if encoder and decoder hidden sizes differ (unless
        assertions are disabled).
    """
    cell_dec = tf.contrib.rnn.MultiRNNCell(
        [self.make_rnn_cell(i, True) for i in range(self.config.rnn_layers)])

    encoder_hidden_size = int(enc_hidden_states.get_shape()[-1])
    decoder_hidden_size = int(cell_dec.output_size)

    # if encoder and decoder have different sizes, add a projection layer
    if encoder_hidden_size != decoder_hidden_size:
        # Mismatched sizes are rejected; the projection below only runs when
        # assertions are disabled (python -O).
        assert False, (encoder_hidden_size, decoder_hidden_size)
        with tf.variable_scope('hidden_projection'):
            kernel = tf.get_variable('kernel', (encoder_hidden_size, decoder_hidden_size),
                                     dtype=tf.float32)

            # apply a relu to the projection for good measure
            enc_final_state = nest.map_structure(
                lambda x: tf.nn.relu(tf.matmul(x, kernel)), enc_final_state)
            # kernel is (encoder_hidden_size, decoder_hidden_size): the encoder
            # feature axis (2) must contract against kernel axis 0.
            enc_hidden_states = tf.nn.relu(
                tf.tensordot(enc_hidden_states, kernel, [[2], [0]]))
    else:
        # flatten and repack the state
        enc_final_state = nest.pack_sequence_as(cell_dec.state_size,
                                                nest.flatten(enc_final_state))

    if self.config.connect_output_decoder:
        cell_dec = ParentFeedingCellWrapper(cell_dec, enc_final_state)
    else:
        cell_dec = InputIgnoringCellWrapper(cell_dec, enc_final_state)

    if self.config.apply_attention:
        attention = LuongAttention(self.config.decoder_hidden_size, enc_hidden_states,
                                   self.input_length_placeholder,
                                   probability_fn=tf.nn.softmax)
        cell_dec = AttentionWrapper(cell_dec, attention,
                                    cell_input_fn=lambda inputs, _: inputs,
                                    attention_layer_size=self.config.decoder_hidden_size,
                                    initial_cell_state=enc_final_state)
        enc_final_state = cell_dec.zero_state(self.batch_size, dtype=tf.float32)

    decoder = Seq2SeqDecoder(self.config, self.input_placeholder,
                             self.input_length_placeholder, self.output_placeholder,
                             self.output_length_placeholder, self.batch_number_placeholder)
    return decoder.decode(cell_dec, enc_final_state, self.config.grammar.output_size,
                          output_embed_matrix, training)
def _create_zero_outputs(size, dtype, batch_size):
    """Create a zero outputs Tensor structure.

    For each paired (size, dtype) leaf this returns ``array_ops.zeros`` whose
    shape is ``batch_size`` concatenated with that size. A size that is
    already a Tensor is used as-is; otherwise it is converted to an int32
    constant.
    """
    def _as_shape_tensor(shape):
        if isinstance(shape, ops.Tensor):
            return shape
        return constant_op.constant(
            tensor_shape.TensorShape(shape).as_list(),
            dtype=dtypes.int32,
            name="zero_suffix_shape")

    def _zeros_for(shape, leaf_dtype):
        full_shape = array_ops.concat(([batch_size], _as_shape_tensor(shape)), axis=0)
        return array_ops.zeros(full_shape, dtype=leaf_dtype)

    return nest.map_structure(_zeros_for, size, dtype)
def __init__(self, inputs, sequence_length, time_major=False, name=None):
    """Initializer.

    Args:
      inputs: A (structure of) input tensors.
      sequence_length: An int32 vector tensor.
      time_major: Python bool. Whether the tensors in `inputs` are time major.
        If `False` (default), they are assumed to be batch major.
      name: Name scope for any created operations.

    Raises:
      ValueError: if `sequence_length` is not a 1D tensor.
    """
    with ops.name_scope(name, "TrainingHelper", [inputs, sequence_length]):
        time_major_inputs = ops.convert_to_tensor(inputs, name="inputs")
        if not time_major:
            time_major_inputs = nest.map_structure(_transpose_batch_time, time_major_inputs)
        self._input_tas = nest.map_structure(_unstack_ta, time_major_inputs)

        lengths = ops.convert_to_tensor(sequence_length, name="sequence_length")
        self._sequence_length = lengths
        if lengths.get_shape().ndims != 1:
            raise ValueError(
                "Expected sequence_length to be a vector, but received shape: %s"
                % lengths.get_shape())

        def _first_step_zeros(inp):
            # Zeros shaped like one time step (index 0 of the time axis).
            return array_ops.zeros_like(inp[0, :])

        self._zero_inputs = nest.map_structure(_first_step_zeros, time_major_inputs)
        self._batch_size = array_ops.size(sequence_length)
def initialize(self, name=None):
    """Return ``(finished, next_inputs)`` for the first decoding step.

    A sequence whose length is 0 is finished immediately. If every sequence
    is finished, the precomputed zero inputs are returned instead of reading
    step 0 from the input TensorArrays.
    """
    with ops.name_scope(name, "TrainingHelperInitialize"):
        finished = math_ops.equal(0, self._sequence_length)

        def _read_first_step():
            return nest.map_structure(lambda ta: ta.read(0), self._input_tas)

        next_inputs = control_flow_ops.cond(
            math_ops.reduce_all(finished),
            lambda: self._zero_inputs,
            _read_first_step)
        return (finished, next_inputs)
def next_inputs(self, time, outputs, state, name=None, **unused_kwargs):
    """next_inputs_fn for TrainingHelper.

    A sequence is finished once ``time + 1 >= sequence_length``. Inputs for
    ``time + 1`` are read from the input TensorArrays unless every sequence
    is finished, in which case zero inputs are returned. ``state`` is passed
    through unchanged.
    """
    with ops.name_scope(name, "TrainingHelperNextInputs", [time, outputs, state]):
        upcoming = time + 1
        finished = (upcoming >= self._sequence_length)
        everything_done = math_ops.reduce_all(finished)
        next_inputs = control_flow_ops.cond(
            everything_done,
            lambda: self._zero_inputs,
            lambda: nest.map_structure(lambda ta: ta.read(upcoming), self._input_tas))
        return (finished, next_inputs, state)
def _enumerated_map_structure(map_fn, *args, **kwargs):
    """Like ``nest.map_structure``, but also passes a running leaf index.

    ``map_fn`` is called as ``map_fn(i, *leaves)``, where ``i`` counts the
    calls made so far, starting at 0, in ``nest.map_structure``'s visiting
    order. ``kwargs`` are forwarded to ``nest.map_structure`` itself.
    """
    counter = 0

    def _with_index(*leaves, **leaf_kwargs):
        nonlocal counter
        result = map_fn(counter, *leaves, **leaf_kwargs)
        counter += 1
        return result

    return nest.map_structure(_with_index, *args, **kwargs)
def extract_state(batched_states, i):
    """Extracts a single state from a batch of states.

    Args:
      batched_states: A nested structure with entries whose first dimensions
        all equal N.
      i: The index of the state to extract.

    Returns:
      A structure identical to `batched_states` in which every entry is
      indexed with ``[i]`` along its first dimension. That leading (batch)
      dimension is therefore removed from each entry.
    """
    def _select(batched_leaf):
        return batched_leaf[i]

    return tf_nest.map_structure(_select, batched_states)
def zero_state(self, batch_size, dtype):
    """Return an initial ``AttentionWrapperState`` for ``batch_size`` rows.

    The cell state is ``self._initial_cell_state`` when one was given,
    otherwise the wrapped cell's zero state. Each cell-state leaf is passed
    through an ``identity`` op that depends on an assertion that
    ``batch_size`` equals the attention mechanism's batch size.
    """
    with ops.name_scope(type(self).__name__ + "ZeroState", values=[batch_size]):
        if self._initial_cell_state is None:
            cell_state = self._cell.zero_state(batch_size, dtype)
        else:
            cell_state = self._initial_cell_state

        error_message = (
            "When calling zero_state of AttentionWrapper %s: " % self._base_name +
            "Non-matching batch sizes between the memory "
            "(encoder output) and the requested batch size. Are you using "
            "the BeamSearchDecoder? If so, make sure your encoder output has "
            "been tiled to beam_width via tf.contrib.seq2seq.tile_batch, and "
            "the batch_size= argument passed to zero_state is "
            "batch_size * beam_width.")
        batch_check = check_ops.assert_equal(
            batch_size, self._attention_mechanism.batch_size, message=error_message)
        with ops.control_dependencies([batch_check]):
            cell_state = nest.map_structure(
                lambda s: array_ops.identity(s, name="checked_cell_state"), cell_state)

        alignment_history = (
            tensor_array_ops.TensorArray(dtype=dtype, size=0, dynamic_size=True)
            if self._alignment_history else ())

        return AttentionWrapperState(
            cell_state=cell_state,
            time=array_ops.zeros([], dtype=dtypes.int32),
            attention=_zero_state_tensors(self._attention_size, batch_size, dtype),
            alignments=self._attention_mechanism.initial_alignments(batch_size, dtype),
            alignment_history=alignment_history)
def _build(self, inputs):
    """Apply the base module and add its inputs back onto its outputs."""
    module_out = self._base_module(inputs)
    return nest.map_structure(lambda x, y: x + y, inputs, module_out)
def _build(self, inputs, prev_state):
    """Run the base core and add its inputs onto its outputs (residual).

    Returns the leaf-wise sum ``input + output`` and the core's new state.
    """
    core_out, next_state = self._base_core(inputs, prev_state)
    summed = nest.map_structure(lambda x, y: x + y, inputs, core_out)
    return summed, next_state
def _build(self, inputs, prev_state):
    """Run the base core and concatenate its inputs onto its outputs.

    Input and output leaves are joined along the last axis. While
    ``self._input_shape`` is falsy, it is set to ``inputs.get_shape()[1:]``.
    """
    if not self._input_shape:
        self._input_shape = inputs.get_shape()[1:]
    core_out, next_state = self._base_core(inputs, prev_state)

    def _join(inp, out):
        return tf.concat((inp, out), -1)

    return nest.map_structure(_join, inputs, core_out), next_state
def _build(self, inputs):
    """Pass ``inputs`` through ``tf.identity``, leaf by leaf.

    ``NotATensor`` and ``tf.SparseTensor`` inputs are returned untouched, as
    is any input when ``self.no_nest`` is truthy.
    """
    if isinstance(inputs, (NotATensor, tf.SparseTensor)) or self.no_nest:
        return inputs
    return nest.map_structure(tf.identity, inputs)
def _get_shape_without_batch_dimension(tensor_nest):
    """Converts Tensor nest to a TensorShape nest, removing batch dimension.

    Each leaf is indexed with ``[0]`` and its static shape is taken, which
    drops the leading (batch) dimension.
    """
    return nest.map_structure(lambda t: t[0].get_shape(), tensor_nest)
def _create(self):
    """Return float32 zeros of shape ``[batch_size, size]`` per decoder state leaf."""
    def _zeros(size):
        return tf.zeros([self.batch_size, size], dtype=tf.float32)

    return nest.map_structure(_zeros, self.decoder_state_size)
def finalize(self, outputs, final_state):
    """Backtrack the beams and add a unit batch axis to the outputs.

    Returns a ``FinalBeamDecoderOutput`` and ``final_state`` unchanged.
    """
    # Gather according to beam search result
    gathered_ids = beam_search.gather_tree(outputs.predicted_ids,
                                           outputs.beam_parent_ids)

    # We're using a batch size of 1, so we add an extra dimension to
    # convert tensors to [1, beam_width, ...] shape. This way Tensorflow
    # doesn't confuse batch_size with beam_width
    def _add_unit_axis(tensor):
        return tf.expand_dims(tensor, 1)

    expanded_outputs = nest.map_structure(_add_unit_axis, outputs)
    result = FinalBeamDecoderOutput(
        predicted_ids=_add_unit_axis(gathered_ids),
        beam_search_output=expanded_outputs)
    return result, final_state
def _build(self, initial_state, helper):
    """Tile the initial state, set up the wrapped decoder, and build.

    Each leaf of ``initial_state`` is tiled with multiples
    ``[self.batch_size, 1]``. The tiled state and ``helper`` are given to
    ``self.decoder._setup``. The parent ``_build`` is then called with the
    wrapped decoder's ``initial_state`` and ``helper``.
    """
    tiled_state = nest.map_structure(
        lambda t: tf.tile(t, [self.batch_size, 1]), initial_state)
    self.decoder._setup(tiled_state, helper)
    parent = super(BeamSearchDecoder, self)
    return parent._build(self.decoder.initial_state, self.decoder.helper)
def step(self, time_, inputs, state, name=None):
    """Run one wrapped-decoder step followed by one beam search step.

    The decoder state and output are reordered with ``tf.gather`` by the
    beam parent ids. The helper's ``next_inputs`` is then asked for the
    inputs of the next step.

    Returns:
      ``(outputs, next_state, next_inputs, finished)``
    """
    decoder_state, beam_state = state

    # Call the original decoder
    decoder_output, decoder_state, _, _ = self.decoder.step(time_, inputs, decoder_state)

    # Perform a step of beam search
    bs_output, beam_state = beam_search.beam_search_step(
        time_=time_,
        logits=decoder_output.logits,
        beam_state=beam_state,
        config=self.config)

    # Shuffle everything according to beam search result
    parents = bs_output.beam_parent_ids

    def _reorder(tensor):
        return tf.gather(tensor, parents)

    decoder_state = nest.map_structure(_reorder, decoder_state)
    decoder_output = nest.map_structure(_reorder, decoder_output)

    outputs = BeamDecoderOutput(
        logits=tf.zeros([self.config.beam_width, self.config.vocab_size]),
        predicted_ids=bs_output.predicted_ids,
        log_probs=beam_state.log_probs,
        scores=bs_output.scores,
        beam_parent_ids=parents,
        original_outputs=decoder_output)

    finished, next_inputs, next_state = self.decoder.helper.next_inputs(
        time=time_,
        outputs=decoder_output,
        state=(decoder_state, beam_state),
        sample_ids=bs_output.predicted_ids)
    next_inputs.set_shape([self.batch_size, None])

    return (outputs, next_state, next_inputs, finished)
def _create_zero_outputs(size, dtype, batch_size):
    """Create a zero outputs Tensor structure.

    Every paired (size, dtype) leaf becomes ``tf.zeros`` with shape
    ``batch_size`` concatenated with that size. Non-Tensor sizes are first
    turned into an int32 constant.
    """
    def _suffix_shape(s):
        return s if isinstance(s, tf.Tensor) else tf.constant(
            tf.TensorShape(s).as_list(), dtype=tf.int32, name="zero_suffix_shape")

    def _zeros(s, d):
        shape = tf.concat(([batch_size], _suffix_shape(s)), axis=0)
        return tf.zeros(shape, dtype=d)

    return nest.map_structure(_zeros, size, dtype)
def __init__(self, inputs, sequence_length, time_major=False, name=None):
    """Initializer.

    Args:
      inputs: A (structure of) input tensors.
      sequence_length: An int32 vector tensor.
      time_major: Python bool. Whether the tensors in `inputs` are time major.
        If `False` (default), they are assumed to be batch major.
      name: Name scope for any created operations.

    Raises:
      ValueError: if `sequence_length` is not a 1D tensor.
    """
    with tf.name_scope(name, "TrainingHelper", [inputs, sequence_length]):
        inputs = tf.convert_to_tensor(inputs, name="inputs")
        if not time_major:
            inputs = nest.map_structure(_transpose_batch_time, inputs)
        self._input_tas = nest.map_structure(_unstack_ta, inputs)

        self._sequence_length = tf.convert_to_tensor(
            sequence_length, name="sequence_length")
        length_shape = self._sequence_length.get_shape()
        if length_shape.ndims != 1:
            raise ValueError(
                "Expected sequence_length to be a vector, but received shape: %s"
                % length_shape)

        self._zero_inputs = nest.map_structure(
            lambda x: tf.zeros_like(x[0, :]), inputs)
        self._batch_size = tf.size(sequence_length)
def initialize(self, name=None):
    """Return ``(finished, next_inputs)`` for the first decoding step.

    Zero-length sequences start out finished. When all are finished, zero
    inputs are returned instead of step 0 of the input TensorArrays.
    """
    with tf.name_scope(name, "TrainingHelperInitialize"):
        finished = tf.equal(0, self._sequence_length)
        all_finished = tf.reduce_all(finished)

        def _zeros():
            return self._zero_inputs

        def _first_inputs():
            return nest.map_structure(lambda ta: ta.read(0), self._input_tas)

        next_inputs = tf.cond(all_finished, _zeros, _first_inputs)
    return (finished, next_inputs)
def next_inputs(self, time, outputs, state, name=None, **unused_kwargs):
    """next_inputs_fn for TrainingHelper.

    Returns ``(finished, next_inputs, state)``. A sequence is finished when
    ``time + 1 >= sequence_length``. If all are finished, ``next_inputs`` is
    the zero inputs; otherwise it is read from the TensorArrays at
    ``time + 1``.
    """
    with tf.name_scope(name, "TrainingHelperNextInputs", [time, outputs, state]):
        step_after = time + 1
        finished = (step_after >= self._sequence_length)
        all_finished = tf.reduce_all(finished)

        def _read_step(ta):
            return ta.read(step_after)

        def _gather_inputs():
            return nest.map_structure(_read_step, self._input_tas)

        next_inputs = tf.cond(all_finished, lambda: self._zero_inputs, _gather_inputs)
    return (finished, next_inputs, state)
def __init__(self, cell, residual_fn=None):
    """Constructs a `ResidualWrapper` for `cell`.

    Args:
      cell: An instance of `RNNCell`.
      residual_fn: (Optional) callable mapping ``(raw cell inputs, raw cell
        outputs)`` to the cell outputs of the residual network. When omitted,
        the element-wise sum ``inputs + outputs`` over the nested structure
        is used.
    """
    self._cell, self._residual_fn = cell, residual_fn
def _pack(*args):
    """Apply ``batch_repeat_pack`` to every leaf of each argument structure.

    Returns a list holding one packed structure per positional argument.
    """
    return [nest.map_structure(batch_repeat_pack, structure) for structure in args]
def _unpack(*args, repeats=1):
    """Apply ``batch_repeat_unpack(..., repeats=repeats)`` to every leaf.

    Returns a list holding one unpacked structure per positional argument.
    """
    def _unpack_leaf(item):
        return batch_repeat_unpack(item, repeats=repeats)

    return [nest.map_structure(_unpack_leaf, structure) for structure in args]
def _zero_state_tensors(state_size, batch_size, dtype):
    """Create tensors of zeros based on state_size, batch_size, and dtype."""
    def _zeros_with_batch(leaf_size):
        """Zeros shaped by ``_concat(batch_size, leaf_size)``, static shape pinned."""
        dynamic_shape = _concat(batch_size, leaf_size)
        static_shape = _concat(batch_size, leaf_size, static=True)
        zeros = array_ops.zeros(dynamic_shape, dtype=dtype)
        zeros.set_shape(static_shape)
        return zeros

    return nest.map_structure(_zeros_with_batch, state_size)
def add_decoder_op(self, enc_final_state, enc_hidden_states, output_embed_matrix, training):
    """Build a beam-search-optimization decoder and run ``dynamic_decode``.

    Args:
      enc_final_state: final encoder state. When hidden sizes match, it is
        repacked into the decoder cell's ``state_size`` structure and then
        tiled ``beam_width`` times.
      enc_hidden_states: encoder hidden states. Their last dimension is read
        as the encoder hidden size (rank 3 presumed from the tensordot axes —
        confirm).
      output_embed_matrix: embedding passed to the decoder.
      training: Python bool. It selects ``training_beam_size`` versus
        ``beam_size`` and whether gold sequences are passed to the decoder.

    Returns:
      The ``final_outputs`` of ``tf.contrib.seq2seq.dynamic_decode``
      (time major).

    Raises:
      AssertionError: if encoder and decoder hidden sizes differ (unless
        assertions are disabled).
      NotImplementedError: if ``self.config.use_grammar_constraints`` is set.
    """
    cell_dec = tf.contrib.rnn.MultiRNNCell(
        [self.make_rnn_cell(i, for_decoder=True) for i in range(self.config.rnn_layers)])

    encoder_hidden_size = int(enc_hidden_states.get_shape()[-1])
    decoder_hidden_size = int(cell_dec.output_size)

    # if encoder and decoder have different sizes, add a projection layer
    if encoder_hidden_size != decoder_hidden_size:
        # Mismatched sizes are rejected; the projection below only runs when
        # assertions are disabled (python -O).
        assert False, (encoder_hidden_size, decoder_hidden_size)
        with tf.variable_scope('hidden_projection'):
            kernel = tf.get_variable('kernel', (encoder_hidden_size, decoder_hidden_size),
                                     dtype=tf.float32)

            # apply a relu to the projection for good measure
            enc_final_state = nest.map_structure(
                lambda x: tf.nn.relu(tf.matmul(x, kernel)), enc_final_state)
            # kernel is (encoder_hidden_size, decoder_hidden_size): the encoder
            # feature axis (2) must contract against kernel axis 0.
            enc_hidden_states = tf.nn.relu(
                tf.tensordot(enc_hidden_states, kernel, [[2], [0]]))
    else:
        # flatten and repack the state
        enc_final_state = nest.pack_sequence_as(cell_dec.state_size,
                                                nest.flatten(enc_final_state))

    beam_width = self.config.training_beam_size if training else self.config.beam_size

    if self.config.apply_attention:
        attention = LuongAttention(
            decoder_hidden_size,
            tf.contrib.seq2seq.tile_batch(enc_hidden_states, beam_width),
            tf.contrib.seq2seq.tile_batch(self.input_length_placeholder, beam_width),
            probability_fn=tf.nn.softmax)
        cell_dec = AttentionWrapper(
            cell_dec, attention,
            cell_input_fn=lambda inputs, _: inputs,
            attention_layer_size=decoder_hidden_size,
            initial_cell_state=tf.contrib.seq2seq.tile_batch(enc_final_state, beam_width))
        enc_final_state = cell_dec.zero_state(self.batch_size * beam_width, dtype=tf.float32)
    else:
        enc_final_state = tf.contrib.seq2seq.tile_batch(enc_final_state, beam_width)

    linear_layer = tf_core_layers.Dense(self.config.output_size)
    go_vector = tf.ones((self.batch_size,), dtype=tf.int32) * self.config.grammar.start
    decoder = BeamSearchOptimizationDecoder(
        training, cell_dec, output_embed_matrix, go_vector,
        self.config.grammar.end, enc_final_state,
        beam_width=beam_width,
        output_layer=linear_layer,
        gold_sequence=self.output_placeholder if training else None,
        gold_sequence_length=(self.output_length_placeholder + 1) if training else None)

    if self.config.use_grammar_constraints:
        raise NotImplementedError("Grammar constraints are not implemented for the beam search yet")

    final_outputs, _, _ = tf.contrib.seq2seq.dynamic_decode(
        decoder, output_time_major=True, maximum_iterations=self.config.max_length)
    return final_outputs
def next_inputs(self, time, outputs, state, sample_ids, name=None):
    """Compute next inputs with scheduled sampling of the model outputs.

    Calls the parent helper's ``next_inputs`` to get ``finished``, the base
    (ground-truth) next inputs and ``state``. If every sequence is finished,
    the base next inputs are returned unchanged. Otherwise ``maybe_sample``
    chooses, per batch entry selected by ``sample_ids``, between the
    (optionally transformed) model outputs and the base next inputs.

    Returns:
      ``(finished, next_inputs, state)``
    """
    with ops.name_scope(name, "ScheduledOutputTrainingHelperNextInputs",
                        [time, outputs, state, sample_ids]):
        (finished, base_next_inputs, state) = (
            super(ScheduledOutputTrainingHelper, self).next_inputs(
                time=time,
                outputs=outputs,
                state=state,
                sample_ids=sample_ids,
                name=name))

        def maybe_sample():
            """Perform scheduled sampling."""

            def maybe_concatenate_auxiliary_inputs(outputs_, indices=None):
                """Concatenate outputs with auxiliary inputs, if they exist."""
                if self._auxiliary_input_tas is None:
                    return outputs_

                # Auxiliary inputs are read for the *next* time step.
                next_time = time + 1
                auxiliary_inputs = nest.map_structure(
                    lambda ta: ta.read(next_time), self._auxiliary_input_tas)
                if indices is not None:
                    # Keep only the rows that are being sampled.
                    auxiliary_inputs = array_ops.gather_nd(auxiliary_inputs, indices)
                return nest.map_structure(
                    lambda x, y: array_ops.concat((x, y), -1),
                    outputs_, auxiliary_inputs)

            if self._next_input_layer is None:
                # sample_ids is used as an element-wise selector (it is also
                # negated with logical_not below, so presumably boolean).
                return array_ops.where(
                    sample_ids, maybe_concatenate_auxiliary_inputs(outputs),
                    base_next_inputs)

            # Indices of sampled / non-sampled entries of sample_ids.
            where_sampling = math_ops.cast(
                array_ops.where(sample_ids), dtypes.int32)
            where_not_sampling = math_ops.cast(
                array_ops.where(math_ops.logical_not(sample_ids)), dtypes.int32)
            outputs_sampling = array_ops.gather_nd(outputs, where_sampling)
            inputs_not_sampling = array_ops.gather_nd(base_next_inputs,
                                                      where_not_sampling)
            # Only the sampled rows pass through the next-input layer.
            sampled_next_inputs = maybe_concatenate_auxiliary_inputs(
                self._next_input_layer(outputs_sampling),
                where_sampling)

            # Scatter both disjoint groups back into full-size tensors and sum
            # them; each position is filled by exactly one of the two.
            base_shape = array_ops.shape(base_next_inputs)
            return (array_ops.scatter_nd(indices=where_sampling,
                                         updates=sampled_next_inputs,
                                         shape=base_shape)
                    + array_ops.scatter_nd(indices=where_not_sampling,
                                           updates=inputs_not_sampling,
                                           shape=base_shape))

        all_finished = math_ops.reduce_all(finished)
        next_inputs = control_flow_ops.cond(
            all_finished, lambda: base_next_inputs, maybe_sample)
        return (finished, next_inputs, state)