def _get_flat_core_sizes(cores):
  """Collects the flattened output sizes of every core in `cores`.

  Args:
    cores: list of cores whose `output_size` attributes are inspected.

  Returns:
    A list with one entry per core. Each entry is itself a list holding, for
    every element of that core's flattened `output_size`, the dimensions of
    that element as a list.
  """
  return [
      [tensor_shape.as_shape(size).as_list()
       for size in nest.flatten(core.output_size)]
      for core in cores
  ]
def _state_size_with_prefix(state_size, prefix=None):
  """Converts an int or TensorShape size into a list, with optional prefix.

  The size specification is normalised through `tensor_shape.as_shape` and
  turned into a list of dimensions. When `prefix` is given, its entries are
  placed in front of those dimensions.

  Args:
    state_size: TensorShape or int that specifies the size of a tensor.
    prefix: optional list of leading dimensions to prepend.

  Returns:
    The list of dimensions of the resulting tensor size.

  Raises:
    TypeError: if `prefix` is given but is not a list.
  """
  dims = tensor_shape.as_shape(state_size).as_list()
  if prefix is None:
    return dims
  if not isinstance(prefix, list):
    raise TypeError("prefix of _state_size_with_prefix should be a list.")
  return prefix + dims
def is_compatible_with(self, other):
  """Returns True if this signature is compatible with `other`.

  A sparse `other` only matches a sparse `self` with a compatible dtype.
  A dense `other` requires a compatible dtype, compatible shapes (ignoring
  dimension 0) and a dense `self`.

  Args:
    other: the signature to compare against; its `is_sparse`, `dtype` and
      `shape` attributes are read.

  Returns:
    The (truthy or falsy) result of the compatibility check.
  """

  def _trailing_dims_match(own_shape, other_shape):
    """Checks shape compatibility while ignoring dimension 0."""
    other_shape = tensor_shape.as_shape(other_shape)
    # A shape whose dims are None is unknown, so it may still be compatible.
    if own_shape.dims is None or other_shape.dims is None:
      return True
    if own_shape.ndims != other_shape.ndims:
      return False
    # Drop the first pair: dimension 0 is deliberately not compared.
    dim_pairs = list(zip(own_shape.dims, other_shape.dims))[1:]
    return all(mine.is_compatible_with(theirs) for mine, theirs in dim_pairs)

  if not other.is_sparse:
    return (self.dtype.is_compatible_with(other.dtype) and
            _trailing_dims_match(self.shape, other.shape) and
            not self.is_sparse)
  return self.is_sparse and self.dtype.is_compatible_with(other.dtype)
def set_attr_shape(node, key, value):
  """Stores `value`, converted to a shape proto, in `node.attr[key]`.

  `value` is normalised with `tensor_shape.as_shape`, converted to its proto
  form, wrapped in an `AttrValue` and copied into the attribute entry.

  Args:
    node: object whose `attr` mapping holds the attribute to overwrite.
    key: name of the attribute entry.
    value: shape specification accepted by `tensor_shape.as_shape`.
  """
  try:
    # Look the entry up first so a missing key is hit before any conversion.
    copy_into_attr = node.attr[key].CopyFrom
    shape_proto = tensor_shape.as_shape(value).as_proto()
    copy_into_attr(attr_value_pb2.AttrValue(shape=shape_proto))
  except KeyError:
    # A KeyError raised anywhere above is ignored on purpose (best effort).
    pass
def trainable_initial_state(batch_size, state_size, dtype, initializers=None):
  """Builds an initial state out of variables tiled along the batch axis.

  One variable is created per element of the flattened `state_size`, with a
  leading dimension of 1 followed by that element's shape. Each variable is
  then tiled `batch_size` times along the leading dimension and the results
  are packed back into the structure of `state_size`.

  Args:
    batch_size: An int, or scalar int32 Tensor representing the batch size.
    state_size: A `TensorShape` or nested tuple of `TensorShape`s to use for
      the shape of the variables.
    dtype: The data type used to create the variables and thus the initial
      state.
    initializers: An optional container with the same structure as
      `state_size` holding one initializer per variable. When empty or
      omitted, `tf.zeros_initializer` is used for every variable.

  Returns:
    A `Tensor` or nested tuple of `Tensor`s with the same structure as
    `state_size`, where each `Tensor` is a tiled `Variable`.

  Raises:
    ValueError: if any of the passed initializers is not callable.
  """
  flat_sizes = nest.flatten(state_size)

  if initializers:
    nest.assert_same_structure(initializers, state_size)
    flat_inits = nest.flatten(initializers)
    if any(not callable(init_fn) for init_fn in flat_inits):
      raise ValueError("Not all the passed initializers are callable objects.")
  else:
    flat_inits = tuple(tf.zeros_initializer for _ in flat_sizes)

  # Variable names: a flat `namedtuple` contributes its field names, anything
  # else (plain or nested tuples, or a namedtuple with fewer fields than flat
  # entries) falls back to a running index. Nested `namedtuple`s are not
  # handled specially.
  num_states = len(flat_sizes)
  try:
    field_names = state_size._fields
    names = ["init_{}".format(field_names[idx]) for idx in xrange(num_states)]
  except (AttributeError, IndexError):
    names = ["init_state_{}".format(idx) for idx in xrange(num_states)]

  tiled_states = []
  for name, size, init_fn in zip(names, flat_sizes, flat_inits):
    # The variable carries a singleton batch dimension that is tiled below.
    variable_shape = [1] + tensor_shape.as_shape(size).as_list()
    variable = tf.get_variable(
        name, shape=variable_shape, dtype=dtype, initializer=init_fn)

    rank = variable.get_shape().ndims
    multiples = [batch_size] + [1] * (rank - 1)
    tiled_states.append(
        tf.tile(variable, multiples, name=(name + "_tiled")))

  return nest.pack_sequence_as(structure=state_size,
                               flat_sequence=tiled_states)
def _concat(prefix, suffix, static=False):
  """Concatenates two size specifications given as int, Tensor or TensorShape.

  Both `prefix` and `suffix` are converted to a dynamic form (an int32 Tensor,
  or None when a non-Tensor spec is not fully defined) and a static form.
  Depending on `static`, either the static forms are concatenated into a
  python list or the dynamic forms are concatenated into a Tensor.

  Args:
    prefix: The prefix; usually the batch size (and/or time step size).
      (TensorShape, int, or Tensor.)
    suffix: TensorShape, int, or Tensor.
    static: If `True`, return a python list with possibly unknown dimensions
      (or None when the concatenated shape has unknown rank). Otherwise
      return a `Tensor`.

  Returns:
    shape: the concatenation of prefix and suffix.

  Raises:
    ValueError: if `prefix` or `suffix` is a Tensor that is neither a scalar
      nor a vector.
    ValueError: if `static` is False and `prefix` or `suffix` has no dynamic
      form (i.e. it is not fully defined).
  """
  if not isinstance(prefix, ops.Tensor):
    prefix_shape = tensor_shape.as_shape(prefix)
    prefix_static = (None if prefix_shape.ndims is None
                     else prefix_shape.as_list())
    prefix_dynamic = None
    if prefix_shape.is_fully_defined():
      prefix_dynamic = constant_op.constant(prefix_shape.as_list(),
                                            dtype=dtypes.int32)
  else:
    prefix_dynamic = prefix
    prefix_static = tensor_util.constant_value(prefix)
    if prefix_dynamic.shape.ndims == 0:
      # Promote a scalar to a length-1 vector so it can be concatenated.
      prefix_dynamic = array_ops.expand_dims(prefix_dynamic, 0)
    elif prefix_dynamic.shape.ndims != 1:
      raise ValueError("prefix tensor must be either a scalar or vector, "
                       "but saw tensor: %s" % prefix_dynamic)

  if not isinstance(suffix, ops.Tensor):
    suffix_shape = tensor_shape.as_shape(suffix)
    suffix_static = (None if suffix_shape.ndims is None
                     else suffix_shape.as_list())
    suffix_dynamic = None
    if suffix_shape.is_fully_defined():
      suffix_dynamic = constant_op.constant(suffix_shape.as_list(),
                                            dtype=dtypes.int32)
  else:
    suffix_dynamic = suffix
    suffix_static = tensor_util.constant_value(suffix)
    if suffix_dynamic.shape.ndims == 0:
      # Promote a scalar to a length-1 vector so it can be concatenated.
      suffix_dynamic = array_ops.expand_dims(suffix_dynamic, 0)
    elif suffix_dynamic.shape.ndims != 1:
      raise ValueError("suffix tensor must be either a scalar or vector, "
                       "but saw tensor: %s" % suffix_dynamic)

  if static:
    merged = tensor_shape.as_shape(prefix_static).concatenate(suffix_static)
    return None if merged.ndims is None else merged.as_list()

  if prefix_dynamic is None or suffix_dynamic is None:
    raise ValueError("Provided a prefix or suffix of None: %s and %s"
                     % (prefix, suffix))
  return array_ops.concat((prefix_dynamic, suffix_dynamic), 0)