The following code examples, extracted from open-source Python projects, illustrate how to use tensorflow.matrix_inverse().
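Before the extracted examples, here is a minimal sketch of the call itself. This assumes TensorFlow 1.x, where tf.matrix_inverse is the graph-mode name of what later became tf.linalg.inv; the input must be a square float matrix, or a batch of them with shape [..., N, N].

import tensorflow as tf  # assumes TensorFlow 1.x

# A well-conditioned 2x2 matrix; tf.matrix_inverse also accepts a
# batch of square matrices ([..., N, N]) and inverts each one.
m = tf.constant([[2., 0.], [1., 4.]])
m_inv = tf.matrix_inverse(m)
identity_check = tf.matmul(m, m_inv)  # should be close to the identity

with tf.Session() as sess:
    print(sess.run(m_inv))
    print(sess.run(identity_check))
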
def __init__(self):
    self.vertices = tf.placeholder(tf.float32, [None, 3])
    self.normals = tf.placeholder(tf.float32, [None, 3])
    self.uvs = tf.placeholder(tf.float32, [None, 2])
    self.texture = tf.placeholder(tf.float32, [None, None, 3])

    default_light_dir = np.array([-1, -1, -1], dtype=np.float32)
    default_ambient = np.array([0.5, 0.5, 0.5], dtype=np.float32)
    default_diffuse = np.array([1, 1, 1], dtype=np.float32)
    default_wvp = np.eye(4, dtype=np.float32)

    self.light_dir = tf.placeholder_with_default(default_light_dir, [3])
    self.ambient = tf.placeholder_with_default(default_ambient, [3])
    self.diffuse = tf.placeholder_with_default(default_diffuse, [3])
    self.wvp = tf.placeholder_with_default(default_wvp, [4, 4])

    self.packed_texture = utils.pack_colors(self.texture, 2, False)
    self.iwvp = tf.matrix_inverse(self.wvp)

    self.varying_uv = [None, None, None]
    self.varying_norm = [None, None, None]

def backward_step_fn(self, params, inputs):
    """
    Backwards step over a batch, to be used in tf.scan
    :param params:
    :param inputs: (batch_size, variable dimensions)
    :return:
    """
    mu_back, Sigma_back = params
    mu_pred_tp1, Sigma_pred_tp1, mu_filt_t, Sigma_filt_t, A = inputs

    # J_t = tf.matmul(tf.reshape(tf.transpose(tf.matrix_inverse(Sigma_pred_tp1), [0, 2, 1]), [-1, self.dim_z]),
    #                 self.A)
    # J_t = tf.transpose(tf.reshape(J_t, [-1, self.dim_z, self.dim_z]), [0, 2, 1])
    J_t = tf.matmul(tf.transpose(A, [0, 2, 1]), tf.matrix_inverse(Sigma_pred_tp1))
    J_t = tf.matmul(Sigma_filt_t, J_t)

    mu_back = mu_filt_t + tf.matmul(J_t, mu_back - mu_pred_tp1)
    Sigma_back = Sigma_filt_t + tf.matmul(J_t, tf.matmul(Sigma_back - Sigma_pred_tp1, J_t, adjoint_b=True))

    return mu_back, Sigma_back

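For context: these are the standard Rauch-Tung-Striebel (RTS) smoother equations. The smoother gain is J_t = Sigma_{t|t} A^T Sigma_{t+1|t}^{-1}, and the smoothed moments are mu_{t|T} = mu_{t|t} + J_t (mu_{t+1|T} - mu_{t+1|t}) and Sigma_{t|T} = Sigma_{t|t} + J_t (Sigma_{t+1|T} - Sigma_{t+1|t}) J_t^T, which is exactly what the three tf.matmul lines compute for each batch element.
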
def _makeT(self, cp):
    with tf.variable_scope('_makeT'):
        cp = tf.reshape(cp, (-1, 2, self.Column_controlP_number * self.Row_controlP_number))
        cp = tf.cast(cp, 'float32')
        N_f = tf.shape(cp)[0]
        # c_s
        x, y = tf.linspace(-1., 1., self.Column_controlP_number), tf.linspace(-1., 1., self.Row_controlP_number)
        x, y = tf.meshgrid(x, y)
        xs, ys = tf.transpose(tf.reshape(x, (-1, 1))), tf.transpose(tf.reshape(y, (-1, 1)))
        cp_s = tf.concat([xs, ys], 0)
        cp_s_trans = tf.transpose(cp_s)
        # === Compute distance R
        xs_trans, ys_trans = (tf.transpose(tf.stack([xs], axis=2), perm=[1, 0, 2]),
                              tf.transpose(tf.stack([ys], axis=2), perm=[1, 0, 2]))
        xs, xs_trans = tf.meshgrid(xs, xs_trans)
        ys, ys_trans = tf.meshgrid(ys, ys_trans)
        Rx, Ry = tf.square(tf.subtract(xs, xs_trans)), tf.square(tf.subtract(ys, ys_trans))
        R = tf.add(Rx, Ry)
        R = tf.multiply(R, tf.log(tf.clip_by_value(R, 1e-10, 1e+10)))
        ones = tf.ones([tf.multiply(self.Row_controlP_number, self.Column_controlP_number), 1], tf.float32)
        ones_trans = tf.transpose(ones)
        zeros = tf.zeros([3, 3], tf.float32)
        Deltas1 = tf.concat([ones, cp_s_trans, R], 1)
        Deltas2 = tf.concat([ones_trans, cp_s], 0)
        Deltas2 = tf.concat([zeros, Deltas2], 1)
        Deltas = tf.concat([Deltas1, Deltas2], 0)
        # get deltas_inv
        Deltas_inv = tf.matrix_inverse(Deltas)
        Deltas_inv = tf.expand_dims(Deltas_inv, 0)
        Deltas_inv = tf.reshape(Deltas_inv, [-1])
        Deltas_inv_f = tf.tile(Deltas_inv, tf.stack([N_f]))
        Deltas_inv_f = tf.reshape(Deltas_inv_f,
                                  tf.stack([N_f, self.Column_controlP_number * self.Row_controlP_number + 3, -1]))
        cp_trans = tf.transpose(cp, perm=[0, 2, 1])
        zeros_f_In = tf.zeros([N_f, 3, 2], tf.float32)
        cp = tf.concat([cp_trans, zeros_f_In], 1)
        T = tf.transpose(tf.matmul(Deltas_inv_f, cp), [0, 2, 1])
        return T

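Here Deltas is the standard (N+3)x(N+3) thin-plate-spline system matrix built from the radial kernel U(r^2) = r^2 log(r^2) over the control-point grid (the clip_by_value guards against log 0 on the diagonal). Inverting it once with tf.matrix_inverse and tiling the result over the batch turns control-point positions into TPS coefficients with a single batched matmul. The two 3-D variants below extend the same construction with a z coordinate and a 4x4 (rather than 3x3) affine block.
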
def _makeT(self, cp):
    with tf.variable_scope('_makeT'):
        cp = tf.reshape(cp, (-1, 3, self.X_controlP_number * self.Y_controlP_number * self.Z_controlP_number))
        cp = tf.cast(cp, 'float32')
        N_f = tf.shape(cp)[0]
        # c_s
        x, y, z = (tf.linspace(-1., 1., self.X_controlP_number),
                   tf.linspace(-1., 1., self.Y_controlP_number),
                   tf.linspace(-1., 1., self.Z_controlP_number))
        x = tf.tile(x, [self.Y_controlP_number * self.Z_controlP_number])
        y = tf.tile(self._repeat(y, self.X_controlP_number, 'float32'), [self.Z_controlP_number])
        z = self._repeat(z, self.X_controlP_number * self.Y_controlP_number, 'float32')
        xs, ys, zs = (tf.transpose(tf.reshape(x, (-1, 1))),
                      tf.transpose(tf.reshape(y, (-1, 1))),
                      tf.transpose(tf.reshape(z, (-1, 1))))
        cp_s = tf.concat([xs, ys, zs], 0)
        cp_s_trans = tf.transpose(cp_s)  # (4*4*4)*3 -> 64 * 3
        # === Compute distance R
        xs_trans, ys_trans, zs_trans = (tf.transpose(tf.stack([xs], axis=2), perm=[1, 0, 2]),
                                        tf.transpose(tf.stack([ys], axis=2), perm=[1, 0, 2]),
                                        tf.transpose(tf.stack([zs], axis=2), perm=[1, 0, 2]))
        xs, xs_trans = tf.meshgrid(xs, xs_trans)
        ys, ys_trans = tf.meshgrid(ys, ys_trans)
        zs, zs_trans = tf.meshgrid(zs, zs_trans)
        Rx, Ry, Rz = (tf.square(tf.subtract(xs, xs_trans)),
                      tf.square(tf.subtract(ys, ys_trans)),
                      tf.square(tf.subtract(zs, zs_trans)))
        R = tf.add_n([Rx, Ry, Rz])
        R = tf.multiply(R, tf.log(tf.clip_by_value(R, 1e-10, 1e+10)))
        ones = tf.ones([self.Y_controlP_number * self.X_controlP_number * self.Z_controlP_number, 1], tf.float32)
        ones_trans = tf.transpose(ones)
        zeros = tf.zeros([4, 4], tf.float32)
        Deltas1 = tf.concat([ones, cp_s_trans, R], 1)
        Deltas2 = tf.concat([ones_trans, cp_s], 0)
        Deltas2 = tf.concat([zeros, Deltas2], 1)
        Deltas = tf.concat([Deltas1, Deltas2], 0)
        # get deltas_inv
        Deltas_inv = tf.matrix_inverse(Deltas)
        Deltas_inv = tf.expand_dims(Deltas_inv, 0)
        Deltas_inv = tf.reshape(Deltas_inv, [-1])
        Deltas_inv_f = tf.tile(Deltas_inv, tf.stack([N_f]))
        Deltas_inv_f = tf.reshape(Deltas_inv_f,
                                  tf.stack([N_f, self.X_controlP_number * self.Y_controlP_number * self.Z_controlP_number + 4, -1]))
        cp_trans = tf.transpose(cp, perm=[0, 2, 1])
        zeros_f_In = tf.zeros([N_f, 4, 3], tf.float32)
        cp = tf.concat([cp_trans, zeros_f_In], 1)
        T = tf.transpose(tf.matmul(Deltas_inv_f, cp), [0, 2, 1])
        return T

def _define_distance_to_clusters(self, data):
    """Defines the Mahalanobis distance to the assigned Gaussian."""
    # TODO(xavigonzalvo): reuse (input - mean) * cov^-1 * (input -
    # mean) from log probability function.
    self._all_scores = []
    for shard in data:
        all_scores = []
        shard = tf.expand_dims(shard, 0)
        for c in xrange(self._num_classes):
            if self._covariance_type == FULL_COVARIANCE:
                cov = self._covs[c, :, :]
            elif self._covariance_type == DIAG_COVARIANCE:
                cov = tf.diag(self._covs[c, :])
            inverse = tf.matrix_inverse(cov + self._min_var)
            inv_cov = tf.tile(
                tf.expand_dims(inverse, 0),
                tf.pack([self._num_examples, 1, 1]))
            diff = tf.transpose(shard - self._means[c, :, :], perm=[1, 0, 2])
            m_left = tf.batch_matmul(diff, inv_cov)
            all_scores.append(tf.sqrt(tf.batch_matmul(
                m_left, tf.transpose(diff, perm=[0, 2, 1]))))
        self._all_scores.append(tf.reshape(
            tf.concat(1, all_scores),
            tf.pack([self._num_examples, self._num_classes])))

    # Distance to the associated class.
    self._all_scores = tf.concat(0, self._all_scores)
    assignments = tf.concat(0, self.assignments())
    rows = tf.to_int64(tf.range(0, self._num_examples))
    indices = tf.concat(1, [tf.expand_dims(rows, 1),
                            tf.expand_dims(assignments, 1)])
    self._scores = tf.gather_nd(self._all_scores, indices)

def _define_distance_to_clusters(self, data):
    """Defines the Mahalanobis distance to the assigned Gaussian."""
    # TODO(xavigonzalvo): reuse (input - mean) * cov^-1 * (input -
    # mean) from log probability function.
    self._all_scores = []
    for shard in data:
        all_scores = []
        shard = tf.expand_dims(shard, 0)
        for c in xrange(self._num_classes):
            if self._covariance_type == FULL_COVARIANCE:
                cov = self._covs[c, :, :]
            elif self._covariance_type == DIAG_COVARIANCE:
                cov = tf.diag(self._covs[c, :])
            inverse = tf.matrix_inverse(cov + self._min_var)
            inv_cov = tf.tile(
                tf.expand_dims(inverse, 0),
                tf.stack([self._num_examples, 1, 1]))
            diff = tf.transpose(shard - self._means[c, :, :], perm=[1, 0, 2])
            m_left = tf.batch_matmul(diff, inv_cov)
            all_scores.append(tf.sqrt(tf.batch_matmul(
                m_left, tf.transpose(diff, perm=[0, 2, 1]))))
        self._all_scores.append(tf.reshape(
            tf.concat(1, all_scores),
            tf.stack([self._num_examples, self._num_classes])))

    # Distance to the associated class.
    self._all_scores = tf.concat(0, self._all_scores)
    assignments = tf.concat(0, self.assignments())
    rows = tf.to_int64(tf.range(0, self._num_examples))
    indices = tf.concat(1, [tf.expand_dims(rows, 1),
                            tf.expand_dims(assignments, 1)])
    self._scores = tf.gather_nd(self._all_scores, indices)

def _setup_optimizer(self):
    if self.optimizer == 'seq':
        self.optimize_op = tf.train.GradientDescentOptimizer(self.lr).minimize(self.loss)
    else:
        # W_ml is calculated by solving the normal equation
        xt = tf.transpose(self.x)
        x_xt = tf.matmul(xt, self.x)
        x_xt_inv = tf.matrix_inverse(x_xt)
        x_xt_inv_xt = tf.matmul(x_xt_inv, xt)
        self.w_ml = tf.matmul(x_xt_inv_xt, self.t)
        self.optimize_op = tf.assign(self.w, self.w_ml)

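The else branch above is the closed-form least-squares solution w_ml = (X^T X)^{-1} X^T t. A minimal standalone sketch of the same computation, using hypothetical data and TensorFlow 1.x:

import tensorflow as tf  # assumes TensorFlow 1.x

# Hypothetical design matrix X (with a bias column) and targets t,
# generated from y = 1 + 2x so the answer is known exactly.
X = tf.constant([[1., 0.], [1., 1.], [1., 2.]])
t = tf.constant([[1.], [3.], [5.]])

# w_ml = (X^T X)^{-1} X^T t -- the normal-equation solution.
xt = tf.transpose(X)
w_ml = tf.matmul(tf.matmul(tf.matrix_inverse(tf.matmul(xt, X)), xt), t)

with tf.Session() as sess:
    print(sess.run(w_ml))  # close to [[1.], [2.]] for this data
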
def _setup_model(self):
    # Setup mn, sn
    xt = tf.transpose(self.x)
    self.sn = tf.matrix_inverse(tf.matrix_inverse(self.s0) + (self.beta * tf.matmul(xt, self.x)))
    self.mn = tf.matmul(self.sn,
                        tf.matmul(tf.matrix_inverse(self.s0), self.m0) + self.beta * tf.matmul(xt, self.t))

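This is the standard Bayesian linear-regression posterior update (see e.g. Bishop, PRML, section 3.3): S_N = (S_0^{-1} + beta X^T X)^{-1} and m_N = S_N (S_0^{-1} m_0 + beta X^T t), where beta is the noise precision and (m_0, S_0) are the prior mean and covariance.
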
def regularized_inverse(mat, l=0.1):
    return tf.matrix_inverse(mat + l * Identity(int(mat.shape[0])))
    # TODO: this gives biased result when I use identity

def regularized_inverse(mat, l=0.1):
    return tf.matrix_inverse(mat + l * Identity(int(mat.shape[0])))
    # return tf.matrix_inverse(mat)
    # TODO: this gives biased result when I use identity

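Both variants compute (M + l*I)^{-1}, a Tikhonov-regularized (ridge) inverse: shifting the eigenvalues by l keeps the inverse well-defined when M is singular or ill-conditioned, at the cost of the bias the TODO complains about. A self-contained sketch, substituting tf.eye for the project-specific Identity helper:

import tensorflow as tf  # assumes TensorFlow 1.x

def regularized_inverse(mat, l=0.1):
    # (M + l*I)^{-1}: the shift keeps the op well-defined even
    # when mat itself is singular.
    n = int(mat.shape[0])
    return tf.matrix_inverse(mat + l * tf.eye(n))

singular = tf.constant([[1., 1.], [1., 1.]])  # rank-1, not invertible
with tf.Session() as sess:
    print(sess.run(regularized_inverse(singular)))
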
def learn_comb_orth(poses, dm_shape, reuse=None, _float_type=tf.float32):
    with tf.variable_scope("learn_comb", reuse=reuse):
        comb_matrix = tf.get_variable(
            "matrix", [dm_shape[0], dm_shape[1]],
            initializer=identity_initializer(0),
            dtype=_float_type, trainable=False
        )
        tf.add_to_collection(COMB_MATRIX_COLLECTION, comb_matrix)

        poses = tf.tensordot(poses, comb_matrix, [[2], [1]])
        poses = tf.transpose(poses, [0, 1, 3, 2])

        # Special update code
        def update_comb_mat(grad, lr):
            A = tf.matmul(tf.transpose(grad), comb_matrix) - \
                tf.matmul(tf.transpose(comb_matrix), grad)
            I = tf.constant(np.eye(dm_shape[0]), dtype=_float_type)
            t1 = I + lr / 2 * A
            t2 = I - lr / 2 * A
            Y = tf.matmul(tf.matmul(tf.matrix_inverse(t1), t2), comb_matrix)
            return tf.assign(comb_matrix, Y)

        # Visualization
        cb_min = tf.reduce_min(comb_matrix)
        cb_max = tf.reduce_max(comb_matrix)
        comb_matrix_image = (comb_matrix - cb_min) / (cb_max - cb_min) * 255.0
        comb_matrix_image = tf.cast(comb_matrix_image, tf.uint8)
        comb_matrix_image = tf.reshape(comb_matrix_image,
                                       [1, dm_shape[0], dm_shape[1], 1])

    return poses, comb_matrix_image, update_comb_mat

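The update inside update_comb_mat is a Cayley transform: A is skew-symmetric by construction, so (I + (lr/2)A)^{-1}(I - (lr/2)A) is orthogonal, and multiplying comb_matrix by it keeps the variable on the orthogonal group. A small NumPy check of that property, with hypothetical values:

import numpy as np

# Hypothetical gradient and learning rate, just to check the property.
rng = np.random.RandomState(0)
G = rng.randn(4, 4)
Q = np.eye(4)  # current orthogonal comb_matrix
lr = 0.1

A = G.T @ Q - Q.T @ G  # skew-symmetric by construction
cayley = np.linalg.inv(np.eye(4) + lr / 2 * A) @ (np.eye(4) - lr / 2 * A)
Y = cayley @ Q

print(np.allclose(Y.T @ Y, np.eye(4)))  # True: Y stays orthogonal
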
def test_MatrixInverse(self):
    t = tf.matrix_inverse(self.random(2, 3, 4, 3, 3), adjoint=False)
    self.check(t)
    t = tf.matrix_inverse(self.random(2, 3, 4, 3, 3), adjoint=True)
    self.check(t)

def inverse(opt, p):
    with tf.name_scope("inverse"):
        pMtrx = vec2mtrx(opt, p)
        pInvMtrx = tf.matrix_inverse(pMtrx)
        pInv = mtrx2vec(opt, pInvMtrx)
    return pInv

# convert warp parameters to matrix

def get_inv_quadratic_form(self, data, mean):
    tf_differences = tf.subtract(data, tf.expand_dims(mean, 0))
    tf_diff_times_inv_cov = tf.matmul(tf_differences, tf.matrix_inverse(self.tf_covariance_matrix))
    return tf.reduce_sum(tf_diff_times_inv_cov * tf_differences, 1)

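This returns, for each row x of data, the quadratic form (x - mu)^T Sigma^{-1} (x - mu), i.e. the squared Mahalanobis distance. A standalone sketch of the same computation, with hypothetical mu and Sigma, in TensorFlow 1.x:

import tensorflow as tf  # assumes TensorFlow 1.x

data = tf.constant([[1., 2.], [3., 4.]])
mean = tf.constant([1., 1.])
cov = tf.constant([[2., 0.], [0., 2.]])

diff = data - tf.expand_dims(mean, 0)  # (n, d)
quad = tf.reduce_sum(tf.matmul(diff, tf.matrix_inverse(cov)) * diff, 1)

with tf.Session() as sess:
    print(sess.run(quad))  # [0.5, 6.5]
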
def interatomic_distances(positions, cell, pbc, cutoff):
    with tf.variable_scope('distance'):
        # calculate heights
        # account for zero cell in case of no pbc
        c = tf.reduce_sum(tf.cast(pbc, tf.int32)) > 0
        icell = tf.cond(c, lambda: tf.matrix_inverse(cell),
                        lambda: tf.eye(3))
        height = 1. / tf.sqrt(tf.reduce_sum(tf.square(icell), 0))

        extent = tf.where(tf.cast(pbc, tf.bool),
                          tf.cast(tf.floor(cutoff / height), tf.int32),
                          tf.cast(tf.zeros_like(height), tf.int32))
        n_reps = tf.reduce_prod(2 * extent + 1)

        # replicate atoms
        r = tf.range(-extent[0], extent[0] + 1)
        v0 = tf.expand_dims(r, 1)
        v0 = tf.tile(v0,
                     tf.stack(((2 * extent[1] + 1) * (2 * extent[2] + 1), 1)))
        v0 = tf.reshape(v0, tf.stack((n_reps, 1)))

        r = tf.range(-extent[1], extent[1] + 1)
        v1 = tf.expand_dims(r, 1)
        v1 = tf.tile(v1, tf.stack((2 * extent[2] + 1, 2 * extent[0] + 1)))
        v1 = tf.reshape(v1, tf.stack((n_reps, 1)))

        v2 = tf.expand_dims(tf.range(-extent[2], extent[2] + 1), 1)
        v2 = tf.tile(v2,
                     tf.stack((1, (2 * extent[0] + 1) * (2 * extent[1] + 1))))
        v2 = tf.reshape(v2, tf.stack((n_reps, 1)))

        v = tf.cast(tf.concat((v0, v1, v2), axis=1), tf.float32)
        offset = tf.matmul(v, cell)
        offset = tf.expand_dims(offset, 0)

        # add axes
        positions = tf.expand_dims(positions, 1)
        rpos = positions + offset
        rpos = tf.expand_dims(rpos, 0)
        positions = tf.expand_dims(positions, 1)

        euclid_dist = tf.sqrt(
            tf.reduce_sum(tf.square(positions - rpos),
                          reduction_indices=3))
        return euclid_dist

def _makeT(self, cp):
    with tf.variable_scope('_makeT'):
        cp = tf.reshape(cp, (-1, 3, self.X_controlP_number * self.Y_controlP_number * self.Z_controlP_number))
        cp = tf.cast(cp, 'float32')
        N_f = tf.shape(cp)[0]
        # c_s
        x, y, z = (tf.linspace(-1., 1., self.X_controlP_number),
                   tf.linspace(-1., 1., self.Y_controlP_number),
                   tf.linspace(-1., 1., self.Z_controlP_number))
        x = tf.tile(x, [self.Y_controlP_number * self.Z_controlP_number])
        y = tf.tile(self._repeat(y, self.X_controlP_number, 'float32'), [self.Z_controlP_number])
        z = self._repeat(z, self.X_controlP_number * self.Y_controlP_number, 'float32')
        xs, ys, zs = (tf.transpose(tf.reshape(x, (-1, 1))),
                      tf.transpose(tf.reshape(y, (-1, 1))),
                      tf.transpose(tf.reshape(z, (-1, 1))))
        cp_s = tf.concat([xs, ys, zs], 0)
        cp_s_trans = tf.transpose(cp_s)  # (4*4*4)*3 -> 64 * 3
        # === Compute distance R
        xs_trans, ys_trans, zs_trans = (tf.transpose(tf.stack([xs], axis=2), perm=[1, 0, 2]),
                                        tf.transpose(tf.stack([ys], axis=2), perm=[1, 0, 2]),
                                        tf.transpose(tf.stack([zs], axis=2), perm=[1, 0, 2]))
        xs, xs_trans = tf.meshgrid(xs, xs_trans)
        ys, ys_trans = tf.meshgrid(ys, ys_trans)
        zs, zs_trans = tf.meshgrid(zs, zs_trans)
        Rx, Ry, Rz = (tf.square(tf.subtract(xs, xs_trans)),
                      tf.square(tf.subtract(ys, ys_trans)),
                      tf.square(tf.subtract(zs, zs_trans)))
        R = tf.add_n([Rx, Ry, Rz])
        # print("R", sess.run(R))
        R = tf.multiply(R, tf.log(tf.clip_by_value(R, 1e-10, 1e+10)))
        # print("R", sess.run(R))
        ones = tf.ones([self.Y_controlP_number * self.X_controlP_number * self.Z_controlP_number, 1], tf.float32)
        ones_trans = tf.transpose(ones)
        zeros = tf.zeros([4, 4], tf.float32)
        Deltas1 = tf.concat([ones, cp_s_trans, R], 1)
        Deltas2 = tf.concat([ones_trans, cp_s], 0)
        Deltas2 = tf.concat([zeros, Deltas2], 1)
        Deltas = tf.concat([Deltas1, Deltas2], 0)
        # print("Deltas", sess.run(Deltas))
        # get deltas_inv
        Deltas_inv = tf.matrix_inverse(Deltas)
        Deltas_inv = tf.expand_dims(Deltas_inv, 0)
        Deltas_inv = tf.reshape(Deltas_inv, [-1])
        Deltas_inv_f = tf.tile(Deltas_inv, tf.stack([N_f]))
        Deltas_inv_f = tf.reshape(Deltas_inv_f,
                                  tf.stack([N_f, self.X_controlP_number * self.Y_controlP_number * self.Z_controlP_number + 4, -1]))
        cp_trans = tf.transpose(cp, perm=[0, 2, 1])
        zeros_f_In = tf.zeros([N_f, 4, 3], tf.float32)
        cp = tf.concat([cp_trans, zeros_f_In], 1)
        # print("cp", sess.run(cp))
        # print("Deltas_inv_f", sess.run(Deltas_inv_f))
        T = tf.transpose(tf.matmul(Deltas_inv_f, cp), [0, 2, 1])
        # print("T", sess.run(T))
        return T

def inv(kron_a):
    """Computes the inverse of a given Kronecker-factorized matrix.

    Args:
      kron_a: `TensorTrain` object containing a matrix of size N x N,
        factorized into a Kronecker product of square matrices (all
        tt-ranks are 1 and all tt-cores are square). All the cores
        must be invertible.

    Returns:
      `TensorTrain` object, containing a TT-matrix of size N x N.

    Raises:
      ValueError if the tt-cores of the provided matrix are not square,
      or the tt-ranks are not 1.
    """
    if not _is_kron(kron_a):
        raise ValueError('The argument should be a Kronecker product '
                         '(tt-ranks should be 1)')
    shapes_defined = kron_a.get_shape().is_fully_defined()
    if shapes_defined:
        i_shapes = kron_a.get_raw_shape()[0]
        j_shapes = kron_a.get_raw_shape()[1]
    else:
        i_shapes = ops.raw_shape(kron_a)[0]
        j_shapes = ops.raw_shape(kron_a)[1]
    if shapes_defined:
        if i_shapes != j_shapes:
            raise ValueError('The argument should be a Kronecker product of '
                             'square matrices (tt-cores must be square)')

    inv_cores = []
    for core_idx in range(kron_a.ndims()):
        core = kron_a.tt_cores[core_idx]
        core_inv = tf.matrix_inverse(core[0, :, :, 0])
        inv_cores.append(tf.expand_dims(tf.expand_dims(core_inv, 0), -1))

    res_ranks = kron_a.get_tt_ranks()
    res_shape = kron_a.get_raw_shape()
    return TensorTrain(inv_cores, res_shape, res_ranks)

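Inverting each core separately works because the inverse of a Kronecker product is the Kronecker product of the inverses: (A x B)^{-1} = A^{-1} x B^{-1}. A quick NumPy check with hypothetical, well-conditioned factors:

import numpy as np

rng = np.random.RandomState(1)
A = rng.randn(2, 2) + 2 * np.eye(2)  # keep both factors well-conditioned
B = rng.randn(3, 3) + 2 * np.eye(3)

lhs = np.linalg.inv(np.kron(A, B))
rhs = np.kron(np.linalg.inv(A), np.linalg.inv(B))
print(np.allclose(lhs, rhs))  # True
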
def learn_comb_orth_rmsprop(poses, dm_shape, reuse=None, _float_type=tf.float32):
    with tf.variable_scope("learn_comb", reuse=reuse):
        comb_matrix = tf.get_variable(
            "matrix", [dm_shape[0], dm_shape[1]],
            initializer=identity_initializer(0),
            dtype=_float_type, trainable=False
        )
        comb_matrix_m = tf.get_variable(
            "matrix_momentum", [dm_shape[0], dm_shape[1]],
            initializer=tf.zeros_initializer(),
            dtype=_float_type, trainable=False
        )
        tf.add_to_collection(COMB_MATRIX_COLLECTION, comb_matrix)

        poses = tf.tensordot(poses, comb_matrix, [[2], [1]])
        poses = tf.transpose(poses, [0, 1, 3, 2])

        # Special update code
        def update_comb_mat(grad, lr):
            I = tf.constant(np.eye(dm_shape[0]), dtype=_float_type)

            # Momentum update
            momentum_op = tf.assign(comb_matrix_m,
                                    comb_matrix_m * 0.99 + (1 - 0.99) * tf.square(grad))
            with tf.control_dependencies([momentum_op]):
                # Matrix update
                scaled_grad = lr * grad / tf.sqrt(comb_matrix_m + 1.e-5)
                A = tf.matmul(tf.transpose(scaled_grad), comb_matrix) - \
                    tf.matmul(tf.transpose(comb_matrix), scaled_grad)
                t1 = I + 0.5 * A
                t2 = I - 0.5 * A
                Y = tf.matmul(tf.matmul(tf.matrix_inverse(t1), t2), comb_matrix)
                return tf.assign(comb_matrix, Y)

        # Visualization
        cb_min = tf.reduce_min(comb_matrix)
        cb_max = tf.reduce_max(comb_matrix)
        comb_matrix_image = (comb_matrix - cb_min) / (cb_max - cb_min) * 255.0
        comb_matrix_image = tf.cast(comb_matrix_image, tf.uint8)
        comb_matrix_image = tf.reshape(comb_matrix_image,
                                       [1, dm_shape[0], dm_shape[1], 1])

    return poses, comb_matrix_image, update_comb_mat

def forward_step_fn(self, params, inputs):
    """
    Forward step over a batch, to be used in tf.scan
    :param params:
    :param inputs: (batch_size, variable dimensions)
    :return:
    """
    mu_pred, Sigma_pred, _, _, alpha, u, state, buffer, _, _, _ = params
    y = tf.slice(inputs, [0, 0], [-1, self.dim_y])                  # (bs, dim_y)
    _u = tf.slice(inputs, [0, self.dim_y], [-1, self.dim_u])        # (bs, dim_u)
    mask = tf.slice(inputs, [0, self.dim_y + self.dim_u], [-1, 1])  # (bs, 1)

    # Mixture of C
    C = tf.matmul(alpha, tf.reshape(self.C, [-1, self.dim_y * self.dim_z]))  # (bs, k) x (k, dim_y*dim_z)
    C = tf.reshape(C, [-1, self.dim_y, self.dim_z])                          # (bs, dim_y, dim_z)
    C.set_shape([Sigma_pred.get_shape()[0], self.dim_y, self.dim_z])

    # Residual
    y_pred = tf.squeeze(tf.matmul(C, tf.expand_dims(mu_pred, 2)))  # (bs, dim_y)
    r = y - y_pred                                                 # (bs, dim_y)

    # project system uncertainty into measurement space
    S = tf.matmul(tf.matmul(C, Sigma_pred), C, transpose_b=True) + self.R  # (bs, dim_y, dim_y)
    S_inv = tf.matrix_inverse(S)
    K = tf.matmul(tf.matmul(Sigma_pred, C, transpose_b=True), S_inv)       # (bs, dim_z, dim_y)

    # For missing values, set the Kalman gain matrix to 0
    K = tf.multiply(tf.expand_dims(mask, 2), K)

    # Get current mu and Sigma
    mu_t = mu_pred + tf.squeeze(tf.matmul(K, tf.expand_dims(r, 2)))  # (bs, dim_z)
    I_KC = self._I - tf.matmul(K, C)                                 # (bs, dim_z, dim_z)
    Sigma_t = tf.matmul(tf.matmul(I_KC, Sigma_pred), I_KC, transpose_b=True) \
              + self._sast(self.R, K)                                # (bs, dim_z, dim_z)

    # Mixture of A
    alpha, state, u, buffer = self.alpha(tf.multiply(mask, y) + tf.multiply((1 - mask), y_pred),
                                         state, _u, buffer, reuse=True)      # (bs, k)
    A = tf.matmul(alpha, tf.reshape(self.A, [-1, self.dim_z * self.dim_z]))  # (bs, k) x (k, dim_z*dim_z)
    A = tf.reshape(A, [-1, self.dim_z, self.dim_z])                          # (bs, dim_z, dim_z)
    A.set_shape(Sigma_pred.get_shape())  # set shape to batch_size x dim_z x dim_z

    # Mixture of B
    B = tf.matmul(alpha, tf.reshape(self.B, [-1, self.dim_z * self.dim_u]))  # (bs, k) x (k, dim_z*dim_u)
    B = tf.reshape(B, [-1, self.dim_z, self.dim_u])                          # (bs, dim_z, dim_u)
    B.set_shape([A.get_shape()[0], self.dim_z, self.dim_u])

    # Prediction
    mu_pred = tf.squeeze(tf.matmul(A, tf.expand_dims(mu_t, 2))) \
              + tf.squeeze(tf.matmul(B, tf.expand_dims(u, 2)))
    Sigma_pred = tf.scalar_mul(self._alpha_sq,
                               tf.matmul(tf.matmul(A, Sigma_t), A, transpose_b=True) + self.Q)

    return mu_pred, Sigma_pred, mu_t, Sigma_t, alpha, u, state, buffer, A, B, C

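The measurement update here follows the textbook Kalman filter equations, S = C Sigma_pred C^T + R and K = Sigma_pred C^T S^{-1}, with the covariance updated in Joseph form, Sigma_t = (I - KC) Sigma_pred (I - KC)^T + K R K^T, which stays symmetric positive-semidefinite under roundoff (the project-specific _sast helper presumably supplies the K R K^T term).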