我们从Python开源项目中,提取了以下50个代码示例,用于说明如何使用tensorflow.TensorShape()。
def __call__(self, getter, *args, **kwargs):
    """Variable-creation hook that picks a caching device by variable size.

    Variables whose element count is below `self.small_variable_size_threshold`
    are cached on `self.device_for_small_variables`. Larger variables go to
    the entry of `self.devices` whose running total in `self.sizes` is
    currently smallest (first such index on ties). That total is then
    increased by the variable's element count.

    NOTE(review): `num_elements()` returns None for a not-fully-defined
    shape, and comparing None to the threshold raises TypeError on Python 3.

    Returns:
      Whatever `getter(*args, **kwargs)` returns, called with
      `caching_device` injected into kwargs.
    """
    num_elements = tf.TensorShape(kwargs['shape']).num_elements()
    if num_elements >= self.small_variable_size_threshold:
        # Index of the least-loaded device; min() keeps the first minimum.
        least_loaded = min(range(len(self.sizes)), key=self.sizes.__getitem__)
        chosen_device = self.devices[least_loaded]
        self.sizes[least_loaded] += num_elements
    else:
        chosen_device = self.device_for_small_variables
    kwargs['caching_device'] = chosen_device
    return getter(*args, **kwargs)

# To be used with custom_getter on tf.get_variable. Ensures the created variable
# is in LOCAL_VARIABLES and not GLOBAL_VARIABLES collection.
def __init__(self, num_units, activation=None, reuse=None, kernel_initializer=None,
             bias_initializer=None, layer_norm=False, state_keep_prob=None,
             input_keep_prob=None, input_size=None, final=False):
    """Configure a GRU cell with optional input/state dropout noise.

    Stores the configuration on private attributes. When `input_keep_prob`
    (resp. `state_keep_prob`) is given, it pre-creates uniform noise tensors
    of shape [1] + s. Here s is each structure element of `input_size`
    (resp. `num_units`), mapped via `DropoutGRUCell._enumerated_map_structure`.
    """
    super(DropoutGRUCell, self).__init__(_reuse=reuse)
    self._num_units = num_units
    self._activation = activation or tf.nn.tanh
    self._kernel_initializer = kernel_initializer
    self._bias_initializer = bias_initializer
    self._layer_norm = layer_norm
    self._state_keep_prob = state_keep_prob
    self._input_keep_prob = input_keep_prob
    self._final = final

    def noise_for(shape_like):
        # Leading 1 so the same mask is broadcast across the batch.
        full_shape = tf.concat(([1], tf.TensorShape(shape_like).as_list()), 0)
        return tf.random_uniform(full_shape)

    def ignore_index(_index, shape_like):
        return noise_for(shape_like)

    # Input noise is created before state noise, as in the original ordering.
    if input_keep_prob is not None:
        self._input_noise = DropoutGRUCell._enumerated_map_structure(ignore_index, input_size)
    if state_keep_prob is not None:
        self._state_noise = DropoutGRUCell._enumerated_map_structure(ignore_index, num_units)
def _sample(self, n_samples):
    """Draw `n_samples` integer category indices, cast to `self.dtype`.

    For rank-2 logits the transposed multinomial draws are returned
    directly. Otherwise the logits are flattened to [-1, n_categories],
    sampled, and reshaped to [n_samples] + batch_shape, with the static
    shape set where known.
    """
    already_flat = self.logits.get_shape().ndims == 2
    if already_flat:
        flat_logits = self.logits
    else:
        flat_logits = tf.reshape(self.logits, [-1, self.n_categories])
    # multinomial gives [num_rows, n_samples]; transpose puts samples first.
    draws = tf.cast(tf.transpose(tf.multinomial(flat_logits, n_samples)), self.dtype)
    if already_flat:
        return draws
    target_shape = tf.concat([[n_samples], self.batch_shape], 0)
    draws = tf.reshape(draws, target_shape)
    leading_dim = n_samples if isinstance(n_samples, int) else None
    draws.set_shape(tf.TensorShape([leading_dim]).concatenate(self.get_batch_shape()))
    return draws
def _sample(self, n_samples):
    """Draw binomial samples as sums of `n_experiments` Bernoulli draws.

    Each Bernoulli trial is a 2-class multinomial over
    [log(1 - p), log(p)], where p = sigmoid(logits). That holds because
    -softplus(l) = log(1 - sigmoid(l)) and l - softplus(l) = log(sigmoid(l)).
    Returns a tensor of shape [n_samples] + batch_shape cast to `self.dtype`.
    """
    num_trials = self.n_experiments
    if self.logits.get_shape().ndims == 1:
        flat_logits = self.logits
    else:
        flat_logits = tf.reshape(self.logits, [-1])
    log_q = -tf.nn.softplus(flat_logits)
    log_p = flat_logits + log_q
    two_class_logits = tf.stack([log_q, log_p], axis=-1)
    # [n_samples * n, num_rows] after the transpose.
    trial_draws = tf.transpose(tf.multinomial(two_class_logits, n_samples * num_trials))
    grouped_shape = tf.concat([[num_trials, n_samples], self.batch_shape], 0)
    counts = tf.reduce_sum(tf.reshape(trial_draws, grouped_shape), axis=0)
    leading_dim = n_samples if isinstance(n_samples, int) else None
    counts.set_shape(tf.TensorShape([leading_dim]).concatenate(self.get_batch_shape()))
    return tf.cast(counts, self.dtype)
def _sample(self, n_samples):
    """Draw Laplace samples with the inverse-CDF transform.

    Uses uniform noise u in the open interval (-1, 1) and computes
    loc - scale * sign(u) * log1p(-|u|). Gradients are stopped through
    loc/scale when the distribution is not reparameterized.
    """
    loc = self.loc
    scale = self.scale
    if not self.is_reparameterized:
        loc, scale = tf.stop_gradient(loc), tf.stop_gradient(scale)
    sample_shape = tf.concat([[n_samples], self.batch_shape], 0)
    # The lower bound must exclude -1 exactly, or log1p(-|u|) becomes -inf.
    open_lower = np.nextafter(self.dtype.as_numpy_dtype(-1.),
                              self.dtype.as_numpy_dtype(0.))
    u = tf.random_uniform(shape=sample_shape, minval=open_lower, maxval=1.,
                          dtype=self.dtype)
    samples = loc - scale * tf.sign(u) * tf.log1p(-tf.abs(u))
    leading_dim = n_samples if isinstance(n_samples, int) else None
    samples.set_shape(tf.TensorShape([leading_dim]).concatenate(self.get_batch_shape()))
    return samples
def _sample(self, n_samples):
    """Draw multivariate normal samples as cov_tril @ eps + mean.

    Mean and Cholesky factor are tiled `n_samples` times along a new
    leading axis. Standard normal noise is drawn with the same shape as the
    column-vector mean. The static shape is set to
    [n_samples] + batch_shape + value_shape.
    """
    mean = self.mean
    cov_tril = self.cov_tril
    if not self.is_reparameterized:
        mean, cov_tril = tf.stop_gradient(mean), tf.stop_gradient(cov_tril)

    def repeat_leading(t):
        multiples = tf.concat([[n_samples], tf.ones_like(tf.shape(t))], 0)
        return tf.tile(tf.expand_dims(t, 0), multiples)

    tiled_mean = repeat_leading(mean)
    tiled_cov = repeat_leading(cov_tril)
    # Make the mean a column vector so it lines up with matmul's output.
    tiled_mean = tf.expand_dims(tiled_mean, -1)
    eps = tf.random_normal(tf.shape(tiled_mean), dtype=self.dtype)
    samples = tf.squeeze(tf.matmul(tiled_cov, eps) + tiled_mean, -1)
    leading_dim = n_samples if isinstance(n_samples, int) else None
    static_shape = (tf.TensorShape([leading_dim])
                    .concatenate(self.get_batch_shape())
                    .concatenate(self.get_value_shape()))
    samples.set_shape(static_shape)
    return samples
def _sample(self, n_samples):
    """Draw one-hot categorical samples of depth `n_categories`.

    Integer indices come from a multinomial over flattened logits. For
    non-rank-2 logits they are reshaped to [n_samples] + batch_shape (with
    static shape set), then one-hot encoded with `self.dtype`.
    """
    logits_are_2d = self.logits.get_shape().ndims == 2
    flat_logits = (self.logits if logits_are_2d
                   else tf.reshape(self.logits, [-1, self.n_categories]))
    indices = tf.transpose(tf.multinomial(flat_logits, n_samples))
    if not logits_are_2d:
        indices = tf.reshape(indices, tf.concat([[n_samples], self.batch_shape], 0))
        leading_dim = n_samples if isinstance(n_samples, int) else None
        indices.set_shape(
            tf.TensorShape([leading_dim]).concatenate(self.get_batch_shape()))
    return tf.one_hot(indices, self.n_categories, dtype=self.dtype)
def _sample(self, n_samples):
    """Draw Gumbel-softmax (concrete) relaxed one-hot samples.

    Computes softmax((logits + g) / temperature), where g = -log(-log(u))
    and u comes from `open_interval_standard_uniform`. The static shape is
    [n_samples] + logits' static shape.
    """
    logits = self.logits
    temperature = self.temperature
    if not self.is_reparameterized:
        logits, temperature = tf.stop_gradient(logits), tf.stop_gradient(temperature)
    noise_shape = tf.concat([[n_samples], tf.shape(self.logits)], 0)
    u = open_interval_standard_uniform(noise_shape, self.dtype)
    # TODO: Add Gumbel distribution
    gumbel_noise = -tf.log(-tf.log(u))
    samples = tf.nn.softmax((logits + gumbel_noise) / temperature)
    leading_dim = n_samples if isinstance(n_samples, int) else None
    samples.set_shape(tf.TensorShape([leading_dim]).concatenate(logits.get_shape()))
    return samples
def _check_input_shape(self, given):
    """Convert `given` to a tensor and statically check broadcast compatibility.

    The check runs only when the given shape, batch shape and value shape
    are all truthy `TensorShape`s. A static broadcast failure is re-raised
    as ValueError with a descriptive message.

    Returns:
      `given` converted to a tensor of `self.dtype`.
    """
    given = tf.convert_to_tensor(given, dtype=self.dtype)
    err_msg = ("The given argument should be able to broadcast to "
               "match batch_shape + value_shape of the distribution.")
    given_shape = given.get_shape()
    if not (given_shape and self.get_batch_shape() and self.get_value_shape()):
        return given
    expected = tf.TensorShape(
        self.get_batch_shape().as_list() + self.get_value_shape().as_list())
    try:
        tf.broadcast_static_shape(given_shape, expected)
    except ValueError:
        raise ValueError(
            err_msg + " ({} vs. {} + {})".format(
                given_shape, self.get_batch_shape(), self.get_value_shape()))
    return given
def tf_obj_shape(input):  # NOTE: shadows builtin `input`; name kept for keyword callers.
    """
    Convert tf objects to shape tuple.

    Arguments:
        input: tf.TensorShape, tf.Tensor, tf.AttrValue or tf.NodeDef
            the corresponding tensorflow object

    Returns:
        tuple: shape of the tensorflow object

    Raises:
        ValueError: for a TensorShape of unknown rank.
        TypeError: for an unsupported input type, or a TensorShape with an
            unknown (None) dimension.
    """
    if isinstance(input, tf.TensorShape):
        # as_list() yields plain ints/None on both TF1 (Dimension) and TF2
        # (int) shapes, whereas `.value` exists only on TF1 Dimension objects.
        return tuple(int(d) for d in input.as_list())
    elif isinstance(input, tf.Tensor):
        return tf_obj_shape(input.get_shape())
    elif isinstance(input, tf.AttrValue):
        return tuple([int(d.size) for d in input.shape.dim])
    elif isinstance(input, tf.NodeDef):
        return tf_obj_shape(input.attr['shape'])
    else:
        raise TypeError("Input to `tf_obj_shape` has the wrong type.")
def autoformat_kernel_2d(strides):
    """Normalize a 2-D stride spec to the 4-element form [1, h, w, 1].

    Args:
      strides: an int (used for both spatial dims), or a tuple/list/
        tf.TensorShape of length 2 (spatial dims) or 4 (returned as a list).

    Returns:
      A 4-element list.

    Raises:
      ValueError: sequence length is not 2 or 4.
      TypeError: unsupported input type.
      Both subclass Exception, so callers catching Exception are unaffected.
    """
    if isinstance(strides, int):
        return [1, strides, strides, 1]
    # Short-circuit on plain sequences so tf is only consulted for other types.
    if isinstance(strides, (tuple, list)) or isinstance(strides, tf.TensorShape):
        if len(strides) == 2:
            return [1, strides[0], strides[1], 1]
        if len(strides) == 4:
            return [strides[0], strides[1], strides[2], strides[3]]
        raise ValueError("strides length error: " + str(len(strides))
                         + ", only a length of 2 or 4 is supported.")
    raise TypeError("strides format error: " + str(type(strides)))

# Auto format filter size
# Output shape: (rows, cols, input_depth, out_depth)
def autoformat_stride_3d(strides):
    """Normalize a 3-D stride spec to the 5-element form [1, d, h, w, 1].

    Args:
      strides: an int (used for all three spatial dims), or a tuple/list/
        tf.TensorShape of length 3 (spatial dims) or 5 (whose first and
        last entries must be 1).

    Returns:
      A 5-element list.

    Raises:
      ValueError: bad length, or a 5-element spec whose first/last entry is
        not 1. This used to be an `assert`, which is stripped under -O.
      TypeError: unsupported input type.
    """
    if isinstance(strides, int):
        return [1, strides, strides, strides, 1]
    # Short-circuit on plain sequences so tf is only consulted for other types.
    if isinstance(strides, (tuple, list)) or isinstance(strides, tf.TensorShape):
        if len(strides) == 3:
            return [1, strides[0], strides[1], strides[2], 1]
        if len(strides) == 5:
            if not strides[0] == strides[4] == 1:
                raise ValueError("Must have strides[0] = strides[4] = 1")
            return [strides[0], strides[1], strides[2], strides[3], strides[4]]
        raise ValueError("strides length error: " + str(len(strides))
                         + ", only a length of 3 or 5 is supported.")
    raise TypeError("strides format error: " + str(type(strides)))

# Auto format kernel for 3d convolution
def zero_state(self, batch_size, dtype):
    """Return zero-filled state tensor(s).

    Args:
      batch_size: int, float, or scalar Tensor giving the batch size.
      dtype: the data type to use for the state.

    Returns:
      If `state_size` is an int or TensorShape, an `N-D` tensor of shape
      `[batch_size x state_size]` filled with zeros. If `state_size` is a
      nested list or tuple, a nested list or tuple of the same structure of
      `2-D` tensors with shapes `[batch_size x s]` for each s in `state_size`.
    """
    # The scope name is kept for backwards compatibility.
    scope_name = type(self).__name__ + "ZeroState"
    with tf.name_scope(scope_name, values=[batch_size]):
        return rnn_cell_impl._zero_state_tensors(  # pylint: disable=protected-access
            self.state_size, batch_size, dtype)
def testMLPFinalCore(self):
    """A DeepRNN ending in an MLP core outputs the MLP's last layer size."""
    batch_size, sequence_length, input_size = 2, 3, 4
    mlp_last_layer_size = 17
    lstm_core = snt.LSTM(hidden_size=10)
    mlp_core = snt.nets.MLP(output_sizes=[6, 7, mlp_last_layer_size])
    deep_rnn = snt.DeepRNN([lstm_core, mlp_core], skip_connections=False)

    # Time-major input: [time, batch, features].
    inputs = tf.constant(
        np.random.randn(sequence_length, batch_size, input_size), dtype=tf.float32)
    start_state = deep_rnn.initial_state(batch_size=batch_size)
    output, _ = tf.nn.dynamic_rnn(
        deep_rnn, inputs, initial_state=start_state, time_major=True)

    expected = tf.TensorShape([sequence_length, batch_size, mlp_last_layer_size])
    self.assertEqual(output.get_shape(), expected)
def select_present(x, presence, batch_size=1, name='select_present'):
    """Reorder rows of `x` so present entries precede absent ones per batch item.

    Each element gets partition id 2*b (present) or 2*b + 1 (absent), where
    b is its batch index. `tf.dynamic_partition` then groups by that id,
    and the groups are concatenated and reshaped back to `tf.shape(x)`.
    Presumably `presence` has shape [batch, ...] that is a prefix of x's
    shape, as dynamic_partition requires — confirm with callers.

    Args:
      x: tensor to reorder.
      presence: mask; nonzero means present.
      batch_size: fallback batch size, used only when x's leading static
        dimension is unknown.
      name: variable scope name.
    """
    with tf.variable_scope(name):
        presence = 1 - tf.to_int32(presence)  # invert mask: present -> 0, absent -> 1
        bs = x.get_shape()[0]
        # In TF1 this is a tf.Dimension, and `Dimension != None` always returns
        # None (falsy), so the old check never fired. Unwrap to int/None first;
        # in TF2 it is already an int or None.
        bs = getattr(bs, 'value', bs)
        if bs is not None:
            batch_size = int(bs)
        num_partitions = 2 * batch_size
        # Offsets 0, 2, 4, ...: one pair of partitions per batch item.
        r = tf.range(0, num_partitions, 2)
        r.set_shape(tf.TensorShape([batch_size]))
        r = broadcast_against(r, presence)
        presence += r
        selected = tf.dynamic_partition(x, presence, num_partitions)
        selected = tf.concat(axis=0, values=selected)
        selected = tf.reshape(selected, tf.shape(x))
    return selected
def broadcast_against(tensor, against_expr):
    """Append trailing size-1 dims to `tensor` until its rank matches `against_expr`.

    :param tensor: tensor to be broadcast
    :param against_expr: tensor whose rank is the target
    :return: `tensor` expanded so that tf.rank(result) == tf.rank(against_expr)
    """
    def rank_too_low(data, current):
        return tf.less(tf.rank(current), tf.rank(data))

    def add_trailing_axis(data, current):
        return data, tf.expand_dims(current, -1)

    # The expanded tensor's shape changes every iteration, so its invariant is unknown.
    invariants = [against_expr.get_shape(), tf.TensorShape(None)]
    _, expanded = tf.while_loop(rank_too_low, add_trailing_axis,
                                [against_expr, tensor], invariants)
    return expanded
def __call__(self, getter, *args, **kwargs):
    """Variable-creation hook that balances caching devices by variable size.

    Small variables (element count below `self.small_variable_size_threshold`)
    are cached on `self.device_for_small_variables`. Other variables go to
    the `self.devices` entry with the smallest running total in
    `self.sizes`, and that total grows by the variable's element count.

    NOTE(review): an undefined shape makes `num_elements()` return None,
    which raises TypeError in the comparison on Python 3.
    """
    size = tf.TensorShape(kwargs['shape']).num_elements()
    if size < self.small_variable_size_threshold:
        kwargs['caching_device'] = self.device_for_small_variables
        return getter(*args, **kwargs)
    device_index, _ = min(enumerate(self.sizes), key=operator.itemgetter(1))
    self.sizes[device_index] += size
    kwargs['caching_device'] = self.devices[device_index]
    return getter(*args, **kwargs)

# To be used with custom_getter on tf.get_variable. Ensures the created variable
# is in LOCAL_VARIABLES and not GLOBAL_VARIABLES collection.
def get_probs_and_accuracy(preds, O):
    """Collapse per-MC-sample predictions into one probability per observation.

    preds holds consecutive blocks of `n_mc_smps` rows (module-level
    constant), one block per observation. Each row is normalized with
    logsumexp and column 1 is kept as the positive-class probability. The
    probabilities in each block are then averaged. Accuracy uses a 0.5
    cutoff. The per-observation probabilities can also be fed to sklearn
    for ROC / PR curves.

    Args:
      preds: 2-D logits; presumably [N * n_mc_smps, 2] — confirm.
      O: true labels, compared against int32 predictions, so presumably
        int32 of shape [N] — confirm.

    Returns:
      (probs, accuracy): probs has shape [N]; accuracy is a float32 scalar.
    """
    # Normalize, and drop a dim so only the positive-class probability remains.
    all_probs = tf.exp(preds[:, 1] - tf.reduce_logsumexp(preds, axis=1))
    # Actual number of observations, collapsing the MC samples (truncating).
    N = tf.cast(tf.shape(preds)[0] / n_mc_smps, tf.int32)
    # Mean over each consecutive group of n_mc_smps samples. This matches the
    # old while_loop of slices [i*n_mc_smps, n_mc_smps], but runs in O(N)
    # instead of re-concatenating the growing result every iteration.
    grouped = tf.reshape(all_probs[:N * n_mc_smps], [N, n_mc_smps])
    probs = tf.reduce_mean(grouped, axis=1)
    # Compare to truth with a 0.5 cutoff to get accuracy.
    correct_pred = tf.equal(tf.cast(tf.greater(probs, 0.5), tf.int32), O)
    accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32))
    return probs, accuracy
def batch_repeat_unpack(x, repeats=1, name=None):
    """Split a flattened leading axis back into (batches, repeats).

    Input x has shape (batches * repeats, ...) and the result has shape
    (batches, repeats, ...). The static leading dimension is
    x.shape[0] // repeats when known.
    """
    with tf.name_scope(name, "batch-repeat-unpack", values=[x]):
        # (batches * repeats, ...) -> (batches, repeats, ...)
        target_shape = tf.concat([[-1], [repeats], tf.shape(x)[1:]], axis=0)
        unpacked = tf.reshape(x, shape=target_shape)
        repeats_dim = tf.Dimension(repeats)
        static_leading = x.get_shape()[0] // repeats_dim
        static_shape = tf.TensorShape([static_leading, repeats_dim])
        unpacked.set_shape(static_shape.concatenate(x.get_shape()[1:]))
        return unpacked
def output_size(self):
    """Per-step output sizes: beam-wide scores/ids/parents plus scalar scores."""
    beam_width = self._beam_width
    return BeamSearchOptimizationDecoderOutput(
        scores=tf.TensorShape([beam_width]),
        predicted_ids=tf.TensorShape([beam_width]),
        parent_ids=tf.TensorShape([beam_width]),
        gold_score=tf.TensorShape(()),
        loss=tf.TensorShape(()),
    )
def _merge_batch_beams(self, t, s):
    """Merge a batch of beams into a single batch-by-beams axis.

    Reshapes t from [batch_size, beam_width, s] to
    [batch_size * beam_width, s].

    Args:
      t: Tensor of dimension [batch_size, beam_width, s].
      s: (possibly known) depth shape; used for the static shape.

    Returns:
      The reshaped tensor, with static shape [None] + s.
    """
    depth_dims = tf.shape(t)[2:]
    merged_shape = tf.concat(([self._batch_size * self._beam_width], depth_dims), axis=0)
    merged = tf.reshape(t, merged_shape)
    merged.set_shape(tf.TensorShape([None]).concatenate(s))
    return merged
def _split_batch_beams(self, t, s):
    """Split a batch-by-beams axis into a batch of beams.

    Reshapes t from [batch_size * beam_width, s] to
    [batch_size, beam_width, s].

    Args:
      t: Tensor of dimension [batch_size*beam_width, s].
      s: (Possibly known) depth shape.

    Returns:
      The reshaped tensor of dimension [batch_size, beam_width, s].

    Raises:
      ValueError: if the reshaped static shape is incompatible with
        [None, beam_width] + s.
    """
    depth_dims = tf.shape(t)[1:]
    split_shape = tf.concat(([self._batch_size, self._beam_width], depth_dims), axis=0)
    reshaped = tf.reshape(t, split_shape)
    beam_prefix = tf.TensorShape([None, self._beam_width])
    reshaped.set_shape(beam_prefix.concatenate(t.shape[1:]))
    expected_reshaped_shape = tf.TensorShape([None, self._beam_width]).concatenate(s)
    if reshaped.shape.is_compatible_with(expected_reshaped_shape):
        return reshaped
    raise ValueError("Unexpected behavior when reshaping between beam width "
                     "and batch size. The reshaped tensor has shape: %s. "
                     "We expected it to have shape "
                     "(batch_size, beam_width, depth) == %s. Perhaps you "
                     "forgot to create a zero_state with "
                     "batch_size=encoder_batch_size * beam_width?"
                     % (reshaped.shape, expected_reshaped_shape))
def _maybe_split_batch_beams(self, t, s):
    """Maybe splits the tensor from a batch by beams into a batch of beams.

    We do this so that we can use nest and not run into problems with shapes.

    Args:
      t: Tensor of dimension [batch_size*beam_width, s]
      s: Tensor, Python int, or TensorShape.

    Returns:
      t reshaped to [batch_size, beam_width, s] if t has rank >= 1,
      otherwise t unchanged.

    Raises:
      TypeError: If t is an instance of TensorArray.
      ValueError: If the rank of t is not statically known.
    """
    # Enforce the documented contract. Previously a TensorArray hit an
    # AttributeError on `.shape`, and an unknown rank hit a TypeError from
    # `None >= 1`.
    if isinstance(t, tf.TensorArray):
        raise TypeError("TensorArray state is not supported: %s" % (t,))
    if t.shape.ndims is None:
        raise ValueError(
            "Expected tensor (%s) to have known rank, but ndims == None." % (t,))
    if t.shape.ndims >= 1:
        return self._split_batch_beams(t, s)
    return t
def _maybe_merge_batch_beams(self, t, s):
    """Maybe merges a batch of beams into a single batch-by-beams axis.

    Reshapes t from [batch_size, beam_width, s] to
    [batch_size * beam_width, s] when t has rank >= 2.

    Args:
      t: Tensor of dimension [batch_size, beam_width, s]
      s: Tensor, Python int, or TensorShape.

    Returns:
      The merged tensor if t has rank >= 2, otherwise t unchanged.

    Raises:
      TypeError: If t is an instance of TensorArray.
      ValueError: If the rank of t is not statically known.
    """
    # Enforce the documented contract (see _maybe_split_batch_beams).
    if isinstance(t, tf.TensorArray):
        raise TypeError("TensorArray state is not supported: %s" % (t,))
    if t.shape.ndims is None:
        raise ValueError(
            "Expected tensor (%s) to have known rank, but ndims == None." % (t,))
    if t.shape.ndims >= 2:
        return self._merge_batch_beams(t, s)
    return t
def build(self, input_shape):
    """Create a projection kernel when the input depth differs from `_depth_size`.

    Raises:
      ValueError: if the last input dimension is not statically known.
    """
    input_shape = tf.TensorShape(input_shape)
    in_depth = input_shape[-1].value
    if in_depth is None:
        raise ValueError("Input to DotProductLayer must have the last dimension defined")
    if in_depth == self._depth_size:
        # Depths already agree, so no projection is needed.
        self._space_transform = None
        return
    self._space_transform = self.add_variable(
        'kernel', shape=(in_depth, self._depth_size), dtype=self.dtype, trainable=True)
def _compute_output_shape(self, input_shape):
    """Replace the last dimension of a rank>=2 input shape with `_output_size`."""
    shape = tf.TensorShape(input_shape).with_rank_at_least(2)
    leading = shape[:-1]
    return leading.concatenate(self._output_size)
def output_size(self):
    """Per-step output sizes: scalar beam fields plus the wrapped decoder's outputs."""
    scalar = tf.TensorShape
    return BeamDecoderOutput(
        logits=self.decoder.vocab_size,
        predicted_ids=scalar([]),
        log_probs=scalar([]),
        scores=scalar([]),
        beam_parent_ids=scalar([]),
        original_outputs=self.decoder.output_size,
    )
def output_size(self):
    """Per-step output sizes: vocab logits, a scalar id, and the cell output."""
    fields = dict(
        logits=self.vocab_size,
        predicted_ids=tf.TensorShape([]),
        cell_output=self.cell.output_size,
    )
    return DecoderOutput(**fields)
def output_size(self):
    """Per-step output sizes including attention scores and context.

    NOTE(review): `attention_scores` is a dynamic shape tensor
    (tf.shape(...)[1:-1]), not a TensorShape — confirm callers expect that.
    """
    values = self.attention_values
    return AttentionDecoderOutput(
        logits=self.vocab_size,
        predicted_ids=tf.TensorShape([]),
        cell_output=self.cell.output_size,
        attention_scores=tf.shape(values)[1:-1],
        attention_context=values.get_shape()[-1],
    )
def _two_element_tuple(int_or_tuple): """Converts `int_or_tuple` to height, width. Several of the functions that follow accept arguments as either a tuple of 2 integers or a single integer. A single integer indicates that the 2 values of the tuple are the same. This functions normalizes the input value by always returning a tuple. Args: int_or_tuple: A list of 2 ints, a single int or a tf.TensorShape. Returns: A tuple with 2 values. Raises: ValueError: If `int_or_tuple` it not well formed. """ if isinstance(int_or_tuple, (list, tuple)): if len(int_or_tuple) != 2: raise ValueError('Must be a list with 2 elements: %s' % int_or_tuple) return int(int_or_tuple[0]), int(int_or_tuple[1]) if isinstance(int_or_tuple, int): return int(int_or_tuple), int(int_or_tuple) if isinstance(int_or_tuple, tf.TensorShape): if len(int_or_tuple) == 2: return int_or_tuple[0], int_or_tuple[1] raise ValueError('Must be an int, a list with 2 elements or a TensorShape of ' 'length 2')
def _init_inception():
    """Download/extract the Inception graph, import it, and build global `softmax`.

    The archive at DATA_URL is downloaded into MODEL_DIR if missing, then
    extracted and its graph imported. Every op output whose leading static
    dimension is 1 gets that dimension relaxed to None, so the graph accepts
    an arbitrary minibatch size. Finally `softmax` is built from pool_3
    features and the graph's softmax/logits weight.
    """
    global softmax
    if not os.path.exists(MODEL_DIR):
        os.makedirs(MODEL_DIR)
    filename = DATA_URL.split('/')[-1]
    filepath = os.path.join(MODEL_DIR, filename)
    if not os.path.exists(filepath):
        def _progress(count, block_size, total_size):
            sys.stdout.write('\r>> Downloading %s %.1f%%' % (
                filename, float(count * block_size) / float(total_size) * 100.0))
            sys.stdout.flush()
        filepath, _ = urllib.request.urlretrieve(DATA_URL, filepath, _progress)
        print()
        statinfo = os.stat(filepath)
        print('Succesfully downloaded', filename, statinfo.st_size, 'bytes.')
    # Close the archive deterministically; it was previously never closed.
    # NOTE(review): extractall trusts member paths; fine for the fixed
    # DATA_URL, but unsafe for untrusted archives.
    with tarfile.open(filepath, 'r:gz') as archive:
        archive.extractall(MODEL_DIR)
    with tf.gfile.FastGFile(os.path.join(
            MODEL_DIR, 'classify_image_graph_def.pb'), 'rb') as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())
        _ = tf.import_graph_def(graph_def, name='')
    # Works with an arbitrary minibatch size.
    with tf.Session() as sess:
        pool3 = sess.graph.get_tensor_by_name('pool_3:0')
        ops = pool3.graph.get_operations()
        for op in ops:
            for o in op.outputs:
                shape = [s.value for s in o.get_shape()]
                # Relax a leading dimension of 1 (the baked-in batch) to None.
                new_shape = [None if (j == 0 and s == 1) else s
                             for j, s in enumerate(shape)]
                # Private-attribute hack: set_shape cannot relax a known dim.
                o._shape = tf.TensorShape(new_shape)
        w = sess.graph.get_operation_by_name("softmax/logits/MatMul").inputs[1]
        logits = tf.matmul(tf.squeeze(pool3), w)
        softmax = tf.nn.softmax(logits)
def cwt(wav, widthCwt, wavelet):
    """Continuous wavelet transform: one row per width 1..widthCwt.

    For each width i, `conv1DWavelet(wav, i, wavelet)` is reshaped to a
    [length, 1] column and appended in a while loop. The [length, widthCwt]
    result is transposed to [widthCwt, length].
    """
    length = wav.shape[0]
    signal = tf.reshape(tf.to_float(wav), [1, length, 1, 1])

    def append_width(width, matrix):
        column = tf.reshape(conv1DWavelet(signal, width, wavelet), [length, 1])
        return [1 + width, tf.concat([matrix, column], 1)]

    def widths_remaining(width, matrix):
        return tf.less_equal(width, widthCwt)

    start_width = tf.constant(1)
    empty = tf.zeros([length, 0], dtype='float32')
    # Column count grows each iteration, so its invariant is [length, None].
    _, columns = tf.while_loop(
        widths_remaining, append_width, [start_width, empty],
        shape_invariants=[start_width.get_shape(), tf.TensorShape([length, None])],
        back_prop=False,
        parallel_iterations=1024,
    )
    return tf.transpose(columns)

# ------------------------------------------------------
# wavelets
def get_distorted_inputs(original_image, bboxes, cfg, add_summaries):
    """Produce one distorted crop per bounding box via a while loop.

    Crops are accumulated in a TensorArray with element shape
    [1, INPUT_SIZE, INPUT_SIZE, 3] and concatenated at the end. When
    `add_summaries` is set, four intermediate images are logged as
    image summaries.
    """
    distorter = DistortedInputs(cfg, add_summaries)
    num_bboxes = tf.shape(bboxes)[0]
    size = cfg.INPUT_SIZE
    distorted_inputs = tf.TensorArray(
        dtype=tf.float32, size=num_bboxes,
        element_shape=tf.TensorShape([1, size, size, 3]))
    if not add_summaries:
        image_summaries = tf.constant([])
    else:
        image_summaries = tf.TensorArray(
            dtype=tf.float32, size=4,
            element_shape=tf.TensorShape([1, size, size, 3]))
    current_index = tf.constant(0, dtype=tf.int32)
    state = [original_image, bboxes, distorted_inputs, image_summaries, current_index]
    (original_image, bboxes, distorted_inputs,
     image_summaries, current_index) = tf.while_loop(
        cond=bbox_crop_loop_cond,
        body=distorter.apply,
        loop_vars=state,
        parallel_iterations=10,
        back_prop=False,
        swap_memory=False,
    )
    distorted_inputs = distorted_inputs.concat()
    if add_summaries:
        summary_tags = ['0.original_image', '1.image_with_random_crop',
                        '2.cropped_resized_image', '3.final_distorted_image']
        for slot, tag in enumerate(summary_tags):
            tf.summary.image(tag, image_summaries.read(slot))
    return distorted_inputs
def test_tensor_shape(self):
    """tf.TensorShape of ranks 0-3 converts to the matching tdt.TensorType."""
    for dims in [(), (1,), (1, 2), (1, 2, 3)]:
        self.assertConverts(tf.TensorShape(list(dims)), tdt.TensorType(dims))
def _create_queue(self, queue_id, ctor=tf.RandomShuffleQueue):
    """Build a shared queue of serialized loom weaver messages (scalar strings).

    Capacity is `self.queue_capacity`, or 4 * batch_size when that is falsy.
    """
    # Enqueuing workers turn inputs into serialized loom weaver messages,
    # represented as strings.
    capacity = self.queue_capacity or 4 * self.batch_size
    queue_name = 'tensorflow_fold_plan_queue%s' % queue_id
    return ctor(capacity=capacity,
                min_after_dequeue=0,
                dtypes=[tf.string],
                shapes=[tf.TensorShape([])],
                shared_name=queue_name)
def _get_value_shape(self):
    """Value shape [5] when fully defined, otherwise unknown rank."""
    dims = [5] if self._shape_fully_defined else None
    return tf.TensorShape(dims)
def _get_batch_shape(self):
    """Batch shape [2, 3, 4], with the leading dim unknown unless fully defined."""
    leading = 2 if self._shape_fully_defined else None
    return tf.TensorShape([leading, 3, 4])
def _get_value_shape(self):
    """Values are scalars: an empty (rank-0) shape."""
    scalar_shape = tf.TensorShape(())
    return scalar_shape
def _sample(self, n_samples):
    """Draw normal samples as eps * std + mean, shape [n_samples] + batch_shape."""
    mean = self.mean
    std = self.std
    if not self.is_reparameterized:
        mean, std = tf.stop_gradient(mean), tf.stop_gradient(std)
    sample_shape = tf.concat([[n_samples], self.batch_shape], 0)
    eps = tf.random_normal(sample_shape, dtype=self.dtype)
    samples = eps * std + mean
    leading_dim = n_samples if isinstance(n_samples, int) else None
    samples.set_shape(tf.TensorShape([leading_dim]).concatenate(self.get_batch_shape()))
    return samples
def _sample(self, n_samples):
    """Draw Bernoulli samples: 1 where uniform noise < sigmoid(logits)."""
    prob_one = tf.sigmoid(self.logits)
    sample_shape = tf.concat([[n_samples], self.batch_shape], 0)
    noise = tf.random_uniform(sample_shape, minval=0, maxval=1, dtype=self.param_dtype)
    samples = tf.cast(tf.less(noise, prob_one), dtype=self.dtype)
    leading_dim = n_samples if isinstance(n_samples, int) else None
    samples.set_shape(tf.TensorShape([leading_dim]).concatenate(self.get_batch_shape()))
    return samples
def _get_batch_shape(self):
    """All but the last logits dimension, or unknown rank if logits rank is unknown."""
    logits_shape = self.logits.get_shape()
    if not logits_shape:
        return tf.TensorShape(None)
    return logits_shape[:-1]
def _sample(self, n_samples):
    """Draw uniform samples on [minval, maxval) by affinely scaling U[0, 1)."""
    lo = self.minval
    hi = self.maxval
    if not self.is_reparameterized:
        lo, hi = tf.stop_gradient(lo), tf.stop_gradient(hi)
    sample_shape = tf.concat([[n_samples], self.batch_shape], 0)
    unit = tf.random_uniform(sample_shape, 0, 1, dtype=self.dtype)
    samples = unit * (hi - lo) + lo
    leading_dim = n_samples if isinstance(n_samples, int) else None
    samples.set_shape(tf.TensorShape([leading_dim]).concatenate(self.get_batch_shape()))
    return samples