The following 29 code examples, extracted from open-source Python projects, illustrate how to use tensorflow.unsorted_segment_sum().
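Before the project examples, a minimal sketch of the op itself may help (TensorFlow 1.x API, matching the snippets below): every row of data whose entry in segment_ids carries the same id is summed into that id's output slot. The tensor names here are purely illustrative.

import tensorflow as tf

# Rows 0 and 2 share segment id 0, so their values are summed together.
data = tf.constant([[1., 2.], [3., 4.], [5., 6.]])
segment_ids = tf.constant([0, 1, 0])
summed = tf.unsorted_segment_sum(data, segment_ids, num_segments=2)

with tf.Session() as sess:
    print(sess.run(summed))  # [[6. 8.]
                             #  [3. 4.]]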
def segment_softmax(scores, segment_ids):
    """Given scores and a partition, converts scores to probs by performing
    softmax over all rows within a partition."""

    # Subtract max
    num_segments = tf.reduce_max(segment_ids) + 1
    if len(scores.get_shape()) == 2:
        max_per_partition = tf.unsorted_segment_max(
            tf.reduce_max(scores, axis=1), segment_ids, num_segments)
        scores -= tf.expand_dims(tf.gather(max_per_partition, segment_ids), axis=1)
    else:
        max_per_partition = tf.unsorted_segment_max(scores, segment_ids, num_segments)
        scores -= tf.gather(max_per_partition, segment_ids)

    # Compute probs
    scores_exp = tf.exp(scores)
    if len(scores.get_shape()) == 2:
        scores_exp_sum_per_partition = tf.unsorted_segment_sum(
            tf.reduce_sum(scores_exp, axis=1), segment_ids, num_segments)
        probs = scores_exp / tf.expand_dims(
            tf.gather(scores_exp_sum_per_partition, segment_ids), axis=1)
    else:
        scores_exp_sum_per_partition = tf.unsorted_segment_sum(
            scores_exp, segment_ids, num_segments)
        probs = scores_exp / tf.gather(scores_exp_sum_per_partition, segment_ids)

    return probs
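A hypothetical call, assuming scores is a 1-D float tensor and segment_ids assigns each score to a partition; probabilities then sum to one within each partition:

scores = tf.constant([1.0, 2.0, 3.0])
segment_ids = tf.constant([0, 0, 1])
probs = segment_softmax(scores, segment_ids)
# first two entries sum to 1.0; the single entry of partition 1 is 1.0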
def density_map(tensor, shape):
    """
    """

    height, width, channels = shape

    bins = max(height, width)

    # values = value_map(tensor, shape, keep_dims=True)
    # values = tf.minimum(tf.maximum(tensor, 0.0), 1.0)  # TODO: Get this to work with HDR data
    values = tensor

    # https://stackoverflow.com/a/34143927
    binned_values = tf.cast(tf.reshape(values * (bins - 1), [-1]), tf.int32)
    ones = tf.ones_like(binned_values, dtype=tf.int32)
    counts = tf.unsorted_segment_sum(ones, binned_values, bins)

    out = tf.gather(counts, tf.cast(values[:, :] * (bins - 1), tf.int32))

    return tf.ones(shape) * normalize(tf.cast(out, tf.float32))
def __call__(self, s_embed, s_src_pwr, s_mix_pwr, s_embed_flat=None):
    if s_embed_flat is None:
        s_embed_flat = tf.reshape(
            s_embed,
            [hparams.BATCH_SIZE, -1, hparams.EMBED_SIZE])
    with tf.variable_scope(self.name):
        s_src_assignment = tf.argmax(s_src_pwr, axis=1)
        s_indices = tf.reshape(
            s_src_assignment,
            [hparams.BATCH_SIZE, -1])
        fn_segmean = lambda _: tf.unsorted_segment_sum(
            _[0], _[1], hparams.MAX_N_SIGNAL)
        s_attractors = tf.map_fn(
            fn_segmean, (s_embed_flat, s_indices), hparams.FLOATX)
        s_attractors_wgt = tf.map_fn(
            fn_segmean, (tf.ones_like(s_embed_flat), s_indices), hparams.FLOATX)
        s_attractors /= (s_attractors_wgt + 1.)

    if hparams.DEBUG:
        self.debug_fetches = dict()

    # float[B, C, E]
    return s_attractors
def find_dup(a):
    """Find the duplicated elements in a 1-D tensor.

    Args:
        a: 1-D tensor.

    Return:
        more_than_one_vals: duplicated value in a.
        indexes_in_a: duplicated value's index in a.
        dups_in_a: duplicated value with duplicate in a.
    """
    unique_a_vals, unique_idx = tf.unique(a)
    count_a_unique = tf.unsorted_segment_sum(tf.ones_like(a),
                                             unique_idx,
                                             tf.shape(a)[0])

    more_than_one = tf.greater(count_a_unique, 1)
    more_than_one_idx = tf.squeeze(tf.where(more_than_one))
    more_than_one_vals = tf.squeeze(tf.gather(unique_a_vals, more_than_one_idx))

    not_duplicated, _ = tf.setdiff1d(a, more_than_one_vals)
    dups_in_a, indexes_in_a = tf.setdiff1d(a, not_duplicated)

    return more_than_one_vals, indexes_in_a, dups_in_a
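A hypothetical call with a small integer tensor, to show what the three return values look like:

a = tf.constant([1, 2, 2, 3, 3, 3])
vals, idx, dups = find_dup(a)
# vals -> [2, 3]           (values that occur more than once)
# idx  -> [1, 2, 3, 4, 5]  (positions of those values in a)
# dups -> [2, 2, 3, 3, 3]  (the duplicated entries themselves)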
def combine(self, expert_out, multiply_by_gates=True):
    """Sum together the expert output, weighted by the gates.

    The slice corresponding to a particular batch element `b` is computed
    as the sum over all experts `i` of the expert output, weighted by the
    corresponding gate values.  If `multiply_by_gates` is set to False, the
    gate values are ignored.

    Args:
      expert_out: a list of `num_experts` `Tensor`s, each with shape
        `[expert_batch_size_i, <extra_output_dims>]`.
      multiply_by_gates: a boolean

    Returns:
      a `Tensor` with shape `[batch_size, <extra_output_dims>]`.
    """
    # see comments on convert_gradient_to_tensor
    stitched = convert_gradient_to_tensor(tf.concat(expert_out, 0))
    if multiply_by_gates:
        stitched *= tf.expand_dims(self._nonzero_gates, 1)
    combined = tf.unsorted_segment_sum(stitched, self._batch_index,
                                       tf.shape(self._gates)[0])
    return combined
def combine(self, x):
    """Return the output from the experts.

    When one example goes to multiple experts, the outputs are summed.

    Args:
      x: a Tensor with shape [batch, num_experts, expert_capacity, depth]

    Returns:
      a `Tensor` with shape `[batch, length, depth]`
    """
    depth = tf.shape(x)[-1]
    x *= tf.expand_dims(self._nonpadding, -1)
    ret = tf.unsorted_segment_sum(
        x, self._flat_indices, num_segments=self._batch * self._length)
    ret = tf.reshape(ret, [self._batch, self._length, depth])
    return ret
def compute_mean(cluster_center, x, label, K, eta):
    """Compute Mean

    Input:
        x: embedding of size N x D
        label: cluster label of size N x 1
        K: number of clusters
        eta: weight of the previous cluster center in the moving average

    Output:
        cluster_center: cluster center of size K x D
    """
    tf_eps = tf.constant(1.0e-16)
    cluster_size = tf.expand_dims(
        tf.unsorted_segment_sum(tf.ones(label.get_shape()), label, K), 1)
    cluster_center_new = (1 - eta) * tf.unsorted_segment_sum(
        x, label, K) / (cluster_size + tf_eps) + eta * cluster_center

    return cluster_center.assign(cluster_center_new)
def _sum_attentions(attentions, document):
    assert static_rank(attentions) == 2 and static_rank(document) == 2

    num_entities = tf.reduce_max(document) + 1

    @func_scope()
    def _sum_attention(args):
        attentions, document = args
        assert static_rank(attentions) == 1 and static_rank(document) == 1
        return tf.unsorted_segment_sum(attentions, document, num_entities)

    attentions = tf.map_fn(_sum_attention,
                           [attentions, document],
                           dtype=FLAGS.float_type)

    return attentions[:, FLAGS.first_entity_index:FLAGS.last_entity_index + 1]
def xqa_crossentropy_loss(start_scores, end_scores, answer_span,
                          answer2support, support2question, use_sum=True):
    """Very common XQA loss function."""
    num_questions = tf.reduce_max(support2question) + 1

    start, end = answer_span[:, 0], answer_span[:, 1]

    start_probs = segment_softmax(start_scores, support2question)
    start_probs = tf.gather_nd(start_probs, tf.stack([answer2support, start], 1))

    # only start probs are normalized on multi-paragraph, end probs conditioned
    # on start only on per support level
    num_answers = tf.shape(answer_span)[0]
    is_aligned = tf.equal(tf.shape(end_scores)[0], num_answers)
    end_probs = tf.cond(
        is_aligned,
        lambda: tf.gather_nd(tf.nn.softmax(end_scores),
                             tf.stack([tf.range(num_answers, dtype=tf.int32), end], 1)),
        lambda: tf.gather_nd(segment_softmax(end_scores, support2question),
                             tf.stack([answer2support, end], 1))
    )

    answer2question = tf.gather(support2question, answer2support)

    # compute losses individually
    if use_sum:
        span_probs = tf.unsorted_segment_sum(
            start_probs, answer2question, num_questions) * tf.unsorted_segment_sum(
            end_probs, answer2question, num_questions)
    else:
        span_probs = tf.unsorted_segment_max(
            start_probs, answer2question, num_questions) * tf.unsorted_segment_max(
            end_probs, answer2question, num_questions)

    return -tf.reduce_mean(tf.log(tf.maximum(1e-6, span_probs + 1e-6)))
def data_group_avg(group_ids, data):
    # Sum each group
    sum_total = tf.unsorted_segment_sum(data, group_ids, 3)
    # Count each group
    num_total = tf.unsorted_segment_sum(tf.ones_like(data), group_ids, 3)
    # Calculate average
    avg_by_group = sum_total / num_total
    return avg_by_group
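A hypothetical call; note that the number of groups is hard-coded to 3 in this snippet, so group_ids should only contain the values 0, 1, and 2:

data = tf.constant([1.0, 2.0, 3.0, 4.0])
group_ids = tf.constant([0, 0, 1, 2])
avg = data_group_avg(group_ids, data)  # [1.5, 3.0, 4.0]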
def batch_segment_mean(s_data, s_indices, n):
    s_data_shp = tf.shape(s_data)
    s_data_flat = tf.reshape(
        s_data, [tf.reduce_prod(s_data_shp[:-1]), s_data_shp[-1]])
    s_indices_flat = tf.reshape(s_indices, [-1])
    s_results = tf.unsorted_segment_sum(s_data_flat, s_indices_flat, n)
    s_weights = tf.unsorted_segment_sum(
        tf.ones_like(s_indices_flat), s_indices_flat, n)
    return s_results / tf.cast(tf.expand_dims(s_weights, -1), hparams.FLOATX)
def __call__(self, s_embed, s_src_pwr, s_mix_pwr, s_embed_flat=None):
    if s_embed_flat is None:
        s_embed_flat = tf.reshape(
            s_embed,
            [hparams.BATCH_SIZE, -1, hparams.EMBED_SIZE])
    with tf.variable_scope(self.name):
        s_wgt = tf.reshape(
            s_mix_pwr, [hparams.BATCH_SIZE, -1, 1])
        s_wgt = tf.cast(
            tf.less(5., s_wgt), hparams.FLOATX)
        s_src_assignment = tf.argmax(s_src_pwr, axis=1)
        s_indices = tf.reshape(
            s_src_assignment,
            [hparams.BATCH_SIZE, -1])
        fn_segmean = lambda _: tf.unsorted_segment_sum(
            _[0], _[1], hparams.MAX_N_SIGNAL)
        s_attractors = tf.map_fn(fn_segmean, (
            s_embed_flat * s_wgt, s_indices), hparams.FLOATX)
        s_attractors_wgt = tf.map_fn(fn_segmean, (
            s_wgt, s_indices), hparams.FLOATX)
        s_attractors /= (s_attractors_wgt + hparams.EPS)

    # float[B, C, E]
    return s_attractors
def __call__(self, s_embed, s_src_pwr, s_mix_pwr, s_embed_flat=None):
    if s_embed_flat is None:
        s_embed_flat = tf.reshape(
            s_embed,
            [hparams.BATCH_SIZE, -1, hparams.EMBED_SIZE])
    with tf.variable_scope(self.name):
        s_wgt = tf.reshape(
            s_mix_pwr, [hparams.BATCH_SIZE, -1, 1])
        s_src_assignment = tf.argmax(s_src_pwr, axis=1)
        s_indices = tf.reshape(
            s_src_assignment,
            [hparams.BATCH_SIZE, -1])
        fn_segmean = lambda _: tf.unsorted_segment_sum(
            _[0], _[1], hparams.MAX_N_SIGNAL)
        s_attractors = tf.map_fn(fn_segmean, (
            s_embed_flat * s_wgt, s_indices), hparams.FLOATX)
        s_attractors_wgt = tf.map_fn(fn_segmean, (
            s_wgt, s_indices), hparams.FLOATX)
        s_attractors /= (s_attractors_wgt + hparams.EPS)

    if hparams.DEBUG:
        self.debug_fetches = dict()

    # float[B, C, E]
    return s_attractors
def _full_batch_training_op(self, inputs, cluster_idx_list, cluster_centers):
    """Creates an op for training for full batch case.

    Args:
      inputs: list of input Tensors.
      cluster_idx_list: A vector (or list of vectors). Each element in the
        vector corresponds to an input row in 'inp' and specifies the cluster
        id corresponding to the input.
      cluster_centers: Tensor Ref of cluster centers.

    Returns:
      An op for doing an update of full-batch k-means.
    """
    cluster_sums = []
    cluster_counts = []
    epsilon = tf.constant(1e-6, dtype=inputs[0].dtype)
    for inp, cluster_idx in zip(inputs, cluster_idx_list):
        with ops.colocate_with(inp):
            cluster_sums.append(tf.unsorted_segment_sum(inp,
                                                        cluster_idx,
                                                        self._num_clusters))
            cluster_counts.append(tf.unsorted_segment_sum(
                tf.reshape(tf.ones(tf.reshape(tf.shape(inp)[0], [-1])), [-1, 1]),
                cluster_idx,
                self._num_clusters))
    with ops.colocate_with(cluster_centers):
        new_clusters_centers = tf.add_n(cluster_sums) / (
            tf.cast(tf.add_n(cluster_counts), cluster_sums[0].dtype) + epsilon)
        if self._clusters_l2_normalized():
            new_clusters_centers = tf.nn.l2_normalize(new_clusters_centers, dim=1)
    return tf.assign(cluster_centers, new_clusters_centers)
def _apply_sparse(self, cache):
    """"""

    x_tm1, g_t, idxs = cache['x_tm1'], cache['g_t'], cache['idxs']
    idxs, idxs_ = tf.unique(idxs)
    g_t_ = tf.unsorted_segment_sum(g_t, idxs_, tf.size(idxs))
    updates = cache['updates']

    if self.mu > 0:
        m_t, t_m = self._sparse_moving_average(x_tm1, idxs, g_t_, 'm', beta=self.mu)
        m_t_ = tf.gather(m_t, idxs)
        m_bar_t_ = (1-self.gamma) * m_t_ + self.gamma * g_t_
        updates.extend([m_t, t_m])
    else:
        m_bar_t_ = g_t_

    if self.nu > 0:
        v_t, t_v = self._sparse_moving_average(x_tm1, idxs, g_t_**2, 'v', beta=self.nu)
        v_t_ = tf.gather(v_t, idxs)
        v_bar_t_ = tf.sqrt(v_t_ + self.epsilon)
        updates.extend([v_t, t_v])
    else:
        v_bar_t_ = 1

    s_t_ = self.learning_rate * m_bar_t_ / v_bar_t_
    cache['s_t'] = s_t_
    cache['g_t'] = g_t_
    cache['idxs'] = idxs
    return cache
def _apply_sparse(self, cache):
    """"""

    g_t, idxs = cache['g_t'], cache['idxs']
    idxs, idxs_ = tf.unique(idxs)
    g_t_ = tf.unsorted_segment_sum(g_t, idxs_, tf.size(idxs))

    cache['g_t'] = g_t_
    cache['idxs'] = idxs
    cache['s_t'] = self.learning_rate * g_t_
    return cache
def _rowwise_unsorted_segment_sum(values, indices, n):
    """UnsortedSegmentSum on each row.

    Args:
      values: a `Tensor` with shape `[batch_size, k]`.
      indices: an integer `Tensor` with shape `[batch_size, k]`.
      n: an integer.

    Returns:
      A `Tensor` with the same type as `values` and shape `[batch_size, n]`.
    """
    batch, k = tf.unstack(tf.shape(indices), num=2)
    indices_flat = tf.reshape(indices, [-1]) + tf.div(tf.range(batch * k), k) * n
    ret_flat = tf.unsorted_segment_sum(
        tf.reshape(values, [-1]), indices_flat, batch * n)
    return tf.reshape(ret_flat, [batch, n])
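A hypothetical call, summing each row of values into n=4 bins chosen by the matching row of indices:

values = tf.constant([[1.0, 2.0, 3.0],
                      [4.0, 5.0, 6.0]])
indices = tf.constant([[0, 0, 2],
                       [1, 1, 3]])
out = _rowwise_unsorted_segment_sum(values, indices, 4)
# row 0 -> [3., 0., 3., 0.]
# row 1 -> [0., 9., 0., 6.]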
def bucket_mean(data, bucket_ids, num_buckets):
    total = tf.unsorted_segment_sum(data, bucket_ids, num_buckets)
    count = tf.unsorted_segment_sum(tf.ones_like(data), bucket_ids, num_buckets)
    return total / count
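A hypothetical call; each bucket's mean is its sum divided by its element count, so an empty bucket yields 0/0 = NaN:

data = tf.constant([1.0, 2.0, 10.0])
bucket_ids = tf.constant([0, 0, 1])
means = bucket_mean(data, bucket_ids, 2)  # [1.5, 10.0]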
def _apply_sparse(self, cache):
    """"""

    x_tm1, g_t, idxs = cache['x_tm1'], cache['g_t'], cache['idxs']
    idxs, idxs_ = tf.unique(idxs)
    g_t_ = tf.unsorted_segment_sum(g_t, idxs_, tf.size(idxs))
    updates = cache['updates']

    if self.mu > 0:
        m_t, t_m = self._sparse_moving_average(x_tm1, idxs, g_t_, 'm', beta=self.mu)
        m_t_ = tf.gather(m_t, idxs)
        m_bar_t_ = (1-self.gamma) * m_t_ + self.gamma * g_t_
        updates.extend([m_t, t_m])
    else:
        m_bar_t_ = g_t_

    if self.nu > 0:
        v_t, t_v = self._sparse_moving_average(x_tm1, idxs, g_t_**2, 'v', beta=self.nu)
        v_t_ = tf.gather(v_t, idxs)
        v_bar_t_ = tf.sqrt(v_t_ + self.epsilon)
        updates.extend([v_t, t_v])
    else:
        v_bar_t_ = 1

    s_t_ = self.learning_rate * m_bar_t_ / v_bar_t_
    cache['s_t'] = tf.where(tf.is_finite(s_t_), s_t_, tf.zeros_like(s_t_))
    cache['g_t'] = g_t_
    cache['idxs'] = idxs
    return cache
def test_UnsortedSegmentSum(self):
    t = tf.unsorted_segment_sum(self.random(4, 2, 3), np.array([0, 2, 2, 1]), 3)
    self.check(t)
def kMeans(iterations, labelledSet, columnPrefix="cluster"):
    X = labelledSet.as_matrix()

    start_pos = tf.Variable(X[np.random.randint(X.shape[0], size=iterations), :],
                            dtype=tf.float32)
    centroids = tf.Variable(start_pos.initialized_value(), "S", dtype=tf.float32)
    points = tf.Variable(X, 'X', dtype=tf.float32)
    ones_like = tf.ones((points.get_shape()[0], 1))
    prev_assignments = tf.Variable(tf.zeros((points.get_shape()[0],), dtype=tf.int64))

    p1 = tf.matmul(
        tf.expand_dims(tf.reduce_sum(tf.square(points), 1), 1),
        tf.ones(shape=(1, iterations))
    )
    p2 = tf.transpose(tf.matmul(
        tf.reshape(tf.reduce_sum(tf.square(centroids), 1), shape=[-1, 1]),
        ones_like,
        transpose_b=True
    ))
    distance = tf.sqrt(tf.add(p1, p2) - 2 * tf.matmul(points, centroids, transpose_b=True))

    point_to_centroid_assignment = tf.argmin(distance, axis=1)

    total = tf.unsorted_segment_sum(points, point_to_centroid_assignment, iterations)
    count = tf.unsorted_segment_sum(ones_like, point_to_centroid_assignment, iterations)
    means = total / count

    is_continue = tf.reduce_any(tf.not_equal(point_to_centroid_assignment, prev_assignments))

    with tf.control_dependencies([is_continue]):
        loop = tf.group(centroids.assign(means),
                        prev_assignments.assign(point_to_centroid_assignment))

    sess = tf.Session()
    sess.run(tf.global_variables_initializer())

    has_changed, cnt = True, 0
    while has_changed and cnt < 300:
        cnt += 1
        has_changed, _ = sess.run([is_continue, loop])

    res = sess.run(point_to_centroid_assignment)
    return pandas.DataFrame(res, columns=[columnPrefix + "_" + str(iterations)])
def apply_factor(tensor, *args, **kwargs):
    scope = kwargs.pop("scope", "")
    with tf.name_scope(scope):
        n_args = len(args)

        if n_args == 0:
            tensor, output_size, error_symbol = tensor
            return one_hot(tensor, output_size, scope=scope)
        else:
            tensor, args = slice_out_int_literals(tensor, list(args))
            args, is_batched = make_batch_consistent(args)
            tensor, output_size, error_symbol = tensor

            # handle the case where all arguments were int literals
            tensor_dim_sizes = [dim.value for dim in tensor.get_shape()]
            if not tensor_dim_sizes:
                return one_hot(tensor, output_size, scope=scope)

            # Each arg is batch size x arg dim. Add dimensions to enable broadcasting.
            for i, arg in enumerate(args):
                for j in xrange(n_args):
                    if j == i:
                        continue
                    args[i] = tf.expand_dims(args[i], j + 1)

            # compute joint before tensor is applied
            joint = 1
            for arg in args:
                joint = joint * arg

            # prepare for unsorted_segment_sum
            joint = tf.reshape(joint, (-1, np.prod(tensor_dim_sizes)))
            joint = tf.transpose(joint, [1, 0])  # |tensor| x batch_size

            if error_symbol is not None:
                result = tf.unsorted_segment_sum(joint, tf.reshape(tensor, [-1]), output_size + 1)
                # assume error bin is last bin
                result = result[:output_size, :]
            else:
                result = tf.unsorted_segment_sum(joint, tf.reshape(tensor, [-1]), output_size)

            result = tf.transpose(result, [1, 0])
            if not is_batched:
                result = tf.squeeze(result)

            return result
def __init__(self, layer_type, input_size, target_size, num_hidden_units, activation_type, **kwargs):
    self.input_size = input_size
    self.target_size = target_size
    self.num_hidden_units = num_hidden_units

    self.square_initializer = tf.random_normal_initializer(0.0, np.sqrt(1.0 / num_hidden_units))
    self.non_square_initializer = tf.random_normal_initializer(0.0, np.sqrt(1.0 / num_hidden_units))
    self.bias_initializer = tf.constant_initializer(0.0)

    Layer = getattr(layers, layer_type)
    activation = getattr(tf.nn, activation_type)

    self.inputs = tf.placeholder(tf.float32, shape=[None, None, input_size], name='inputs')
    self.targets = tf.placeholder(tf.float32, shape=[None, None, target_size], name='targets')

    self.batch_size = tf.shape(self.inputs)[0]
    self.length = tf.shape(self.inputs)[1]

    valid_mask_incl_invalid_seqs = tf.logical_not(tf.is_nan(self.targets[0:, 0:, 0]))
    target_step_counts = tf.reduce_sum(tf.to_int32(valid_mask_incl_invalid_seqs), axis=[1],
                                       name='target_step_counts')
    valid_seq_mask = tf.greater(target_step_counts, 0, name='valid_seq_mask')
    self.valid_split_ind = tf.identity(tf.cumsum(target_step_counts)[:-1], name='valid_split_ind')

    valid_seq_ids_incl_invalid_seqs = tf.tile(tf.expand_dims(tf.range(0, self.batch_size), 1),
                                              [1, self.length])
    valid_seq_ids = tf.boolean_mask(valid_seq_ids_incl_invalid_seqs, valid_mask_incl_invalid_seqs,
                                    name='valid_seq_ids')
    self.valid_targets = tf.boolean_mask(self.targets, valid_mask_incl_invalid_seqs,
                                         name='valid_targets')

    with tf.variable_scope('rnn') as rnn_scope:
        inputs = self.inputs
        self._rnn_layer = Layer(inputs, self.num_hidden_units, activation, self.square_initializer,
                                self.non_square_initializer, self.bias_initializer, **kwargs)
        self.initial_rnn_states = self._rnn_layer.initial_states
        self.final_rnn_states = self._rnn_layer.final_states

    with tf.variable_scope('predictions') as predictions_scope:
        W = tf.get_variable('W', shape=[self.num_hidden_units, self.target_size],
                            initializer=self.non_square_initializer)
        b = tf.get_variable('b', shape=[self.target_size], initializer=self.bias_initializer)
        valid_rnn_outputs = tf.boolean_mask(self._rnn_layer.outputs, valid_mask_incl_invalid_seqs)
        self.valid_predictions = tf.nn.xw_plus_b(valid_rnn_outputs, W, b, name='valid_predictions')

    with tf.variable_scope('loss'):
        num_valid_seqs = tf.reduce_sum(tf.to_float(valid_seq_mask))
        stepwise_losses = self._compute_stepwise_losses()

        self.valid_stepwise_loss = tf.reduce_mean(stepwise_losses, name='stepwise_loss')
        self.valid_stepwise_loss_for_opt = tf.identity(num_valid_seqs * self.valid_stepwise_loss,
                                                       name='valid_stepwise_loss_for_opt')

        time_counts = tf.to_float(tf.expand_dims(target_step_counts, 1)) * tf.to_float(valid_mask_incl_invalid_seqs)
        valid_time_counts = tf.boolean_mask(time_counts, valid_mask_incl_invalid_seqs)
        seq_losses = tf.unsorted_segment_sum(stepwise_losses / valid_time_counts, valid_seq_ids,
                                             self.batch_size)

        self.valid_seq_losses = tf.boolean_mask(seq_losses, valid_seq_mask, name='valid_seq_losses')
        self.valid_seqwise_loss = tf.reduce_mean(self.valid_seq_losses, name='valid_seqwise_loss')
        self.valid_seqwise_loss_for_opt = tf.identity(num_valid_seqs * self.valid_seqwise_loss,
                                                      name='valid_seqwise_loss_for_opt')
def inference(documents, doc_mask, query, query_mask):
    embedding = tf.get_variable('embedding',
                                [FLAGS.vocab_size, FLAGS.embedding_size],
                                initializer=tf.random_uniform_initializer(minval=-0.05, maxval=0.05))

    regularizer = tf.nn.l2_loss(embedding)

    doc_emb = tf.nn.dropout(tf.nn.embedding_lookup(embedding, documents), FLAGS.dropout_keep_prob)
    doc_emb.set_shape([None, None, FLAGS.embedding_size])

    query_emb = tf.nn.dropout(tf.nn.embedding_lookup(embedding, query), FLAGS.dropout_keep_prob)
    query_emb.set_shape([None, None, FLAGS.embedding_size])

    with tf.variable_scope('document', initializer=orthogonal_initializer()):
        fwd_cell = tf.contrib.rnn.GRUCell(FLAGS.hidden_size)
        back_cell = tf.contrib.rnn.GRUCell(FLAGS.hidden_size)

        doc_len = tf.reduce_sum(doc_mask, reduction_indices=1)
        h, _ = tf.nn.bidirectional_dynamic_rnn(
            fwd_cell, back_cell, doc_emb, sequence_length=tf.to_int64(doc_len), dtype=tf.float32)
        # h_doc = tf.nn.dropout(tf.concat(2, h), FLAGS.dropout_keep_prob)
        h_doc = tf.concat(h, 2)

    with tf.variable_scope('query', initializer=orthogonal_initializer()):
        fwd_cell = tf.contrib.rnn.GRUCell(FLAGS.hidden_size)
        back_cell = tf.contrib.rnn.GRUCell(FLAGS.hidden_size)

        query_len = tf.reduce_sum(query_mask, reduction_indices=1)
        h, _ = tf.nn.bidirectional_dynamic_rnn(
            fwd_cell, back_cell, query_emb, sequence_length=tf.to_int64(query_len), dtype=tf.float32)
        # h_query = tf.nn.dropout(tf.concat(2, h), FLAGS.dropout_keep_prob)
        h_query = tf.concat(h, 2)

    M = tf.matmul(h_doc, h_query, adjoint_b=True)
    M_mask = tf.to_float(tf.matmul(tf.expand_dims(doc_mask, -1), tf.expand_dims(query_mask, 1)))

    alpha = softmax(M, 1, M_mask)
    beta = softmax(M, 2, M_mask)

    # query_importance = tf.expand_dims(tf.reduce_mean(beta, reduction_indices=1), -1)
    query_importance = tf.expand_dims(tf.reduce_sum(beta, 1) / tf.to_float(tf.expand_dims(doc_len, -1)), -1)

    s = tf.squeeze(tf.matmul(alpha, query_importance), [2])

    unpacked_s = zip(tf.unstack(s, FLAGS.batch_size), tf.unstack(documents, FLAGS.batch_size))
    y_hat = tf.stack([tf.unsorted_segment_sum(attentions, sentence_ids, FLAGS.vocab_size)
                      for (attentions, sentence_ids) in unpacked_s])

    return y_hat, regularizer
def __init__(self, requests, expert_capacity):
    """Create a TruncatingDispatcher.

    Args:
      requests: a boolean `Tensor` of shape `[batch, length, num_experts]`.
        Alternatively, a float or int Tensor containing zeros and ones.
      expert_capacity: a Scalar - maximum number of examples per expert per
        batch element.

    Returns:
      a TruncatingDispatcher
    """
    self._requests = tf.to_float(requests)
    self._expert_capacity = expert_capacity
    expert_capacity_f = tf.to_float(expert_capacity)
    self._batch, self._length, self._num_experts = tf.unstack(
        tf.shape(self._requests), num=3)

    # [batch, length, num_experts]
    position_in_expert = tf.cumsum(self._requests, axis=1, exclusive=True)
    # [batch, length, num_experts]
    self._gates = self._requests * tf.to_float(
        tf.less(position_in_expert, expert_capacity_f))
    batch_index = tf.reshape(
        tf.to_float(tf.range(self._batch)), [self._batch, 1, 1])
    length_index = tf.reshape(
        tf.to_float(tf.range(self._length)), [1, self._length, 1])
    expert_index = tf.reshape(
        tf.to_float(tf.range(self._num_experts)), [1, 1, self._num_experts])
    # position in a Tensor with shape [batch * num_experts * expert_capacity]
    flat_position = (
        position_in_expert +
        batch_index * (tf.to_float(self._num_experts) * expert_capacity_f) +
        expert_index * expert_capacity_f)
    # Tensor of shape [batch * num_experts * expert_capacity].
    # each element is an integer in [0, length)
    self._indices = tf.unsorted_segment_sum(
        data=tf.reshape((length_index + 1.0) * self._gates, [-1]),
        segment_ids=tf.to_int32(tf.reshape(flat_position, [-1])),
        num_segments=self._batch * self._num_experts * expert_capacity)
    self._indices = tf.reshape(
        self._indices, [self._batch, self._num_experts, expert_capacity])
    # Tensors of shape [batch, num_experts, expert_capacity].
    # each element is 0.0 or 1.0
    self._nonpadding = tf.minimum(self._indices, 1.0)
    # each element is an integer in [0, length)
    self._indices = tf.nn.relu(self._indices - 1.0)
    # self._flat_indices is [batch, num_experts, expert_capacity], with values
    # in [0, batch * length)
    self._flat_indices = tf.to_int32(
        self._indices +
        (tf.reshape(tf.to_float(tf.range(self._batch)), [-1, 1, 1]) *
         tf.to_float(self._length)))
    self._indices = tf.to_int32(self._indices)