我们从 Python 开源项目中提取了以下 9 个代码示例，用于说明如何使用 torch.dsmm()。
def test_dsmm(self): def test_shape(di, dj, dk): x = self._gen_sparse(2, 20, [di, dj])[0] y = self.randn(dj, dk) res = torch.dsmm(x, y) expected = torch.mm(x.to_dense(), y) self.assertEqual(res, expected) test_shape(7, 5, 3) test_shape(1000, 100, 100) test_shape(3000, 64, 300)
def test_interpolation():
    """Sparse interpolation of x**2 from a grid must reproduce x**2 to 1e-5 relative error."""
    points = torch.linspace(0.01, 1, 100)
    grid = torch.linspace(-0.05, 1.05, 50)

    # Build the sparse interpolation matrix from the returned (J, C) pair.
    idx, coef = Interpolation().interpolate(grid, points)
    interp_matrix = utils.toeplitz.index_coef_to_sparse(idx, coef, len(grid))

    grid_values = grid.pow(2)
    exact = points.pow(2)
    approx = torch.dsmm(interp_matrix, grid_values.unsqueeze(1)).squeeze()

    # The 1e-10 offset guards the division against a zero reference value.
    relative_error = torch.abs(approx - exact) / (exact + 1e-10)
    assert all(relative_error < 1e-5)
def _derivative_quadratic_form_factory(self, *args):
    """
    Build a closure that returns one gradient-like entry per element of ``args``.

    ``args`` must be one of:
      - ``(columns,)``
      - ``(columns, W_left, W_right)``
      - ``(columns, W_left, W_right, added_diag)``

    ``W_left`` / ``W_right`` are sparse matrices applied with ``torch.dsmm``.
    The returned closure takes ``(left_vectors, right_vectors)``, which are
    either 1-D vectors or 2-D matrices. It returns a tuple with:
      - the result of ``kp_sym_toeplitz_derivative_quadratic_form`` for ``columns``;
      - ``None`` for ``W_left`` and ``W_right``;
      - when given, a tensor for ``added_diag``.

    Raises:
        ValueError: (from the closure) if ``args`` has any other length.
            Previously the closure silently returned ``None`` in that case.
    """

    def _interpolate(interp, factor):
        # (interp^T @ factor^T)^T == factor @ interp, done via sparse-dense matmul.
        return torch.dsmm(interp.t(), factor.t()).t()

    def closure(left_vectors, right_vectors):
        # Promote vectors to single-row matrices so one code path handles both.
        if left_vectors.ndimension() == 1:
            left_factor = left_vectors.unsqueeze(0)
            right_factor = right_vectors.unsqueeze(0)
        else:
            left_factor = left_vectors
            right_factor = right_vectors

        if len(args) == 1:
            columns, = args
            return kp_sym_toeplitz_derivative_quadratic_form(columns, left_factor, right_factor),

        if len(args) == 3:
            columns, W_left, W_right = args
            diag_grad = None
        elif len(args) == 4:
            columns, W_left, W_right, added_diag = args
            diag_grad = columns.new(len(added_diag)).zero_()
            # Computed BEFORE interpolation, from the raw factors. Only entry 0
            # is filled; presumably added_diag holds one shared value -- confirm.
            diag_grad[0] = (left_factor * right_factor).sum()
        else:
            raise ValueError(
                "expected 1, 3 or 4 arguments (columns[, W_left, W_right[, added_diag]]), "
                "got %d" % len(args)
            )

        left_factor = _interpolate(W_left, left_factor)
        right_factor = _interpolate(W_right, right_factor)
        res = kp_sym_toeplitz_derivative_quadratic_form(columns, left_factor, right_factor)
        if diag_grad is None:
            return res, None, None
        return res, None, None, diag_grad

    return closure
def forward(self, dense):
    """Return ``self.sparse`` multiplied by ``dense``.

    A 3-D ``self.sparse`` goes through ``bdsmm``, presumably a batched
    sparse-dense matmul defined elsewhere in the project. Any other rank uses
    ``torch.dsmm``.
    """
    if self.sparse.ndimension() != 3:
        return torch.dsmm(self.sparse, dense)
    return bdsmm(self.sparse, dense)
def backward(self, grad_output):
    """Return the transposed sparse operand multiplied by ``grad_output``.

    A 3-D ``self.sparse`` is transposed on its last two dims and sent through
    ``bdsmm``. Any other rank uses ``.t()`` and ``torch.dsmm``.
    """
    sparse = self.sparse
    if sparse.ndimension() != 3:
        return torch.dsmm(sparse.t(), grad_output)
    # Swap the last two dims; dim 0 is left in place.
    return bdsmm(sparse.transpose(1, 2), grad_output)
def test_dsmm(self): def test_shape(di, dj, dk): x = self._gen_sparse(2, 20, [di, dj])[0] y = self.randn(dj, dk) res = torch.dsmm(x, y) expected = torch.mm(self.safeToDense(x), y) self.assertEqual(res, expected) test_shape(7, 5, 3) test_shape(1000, 100, 100) test_shape(3000, 64, 300)
def kp_interpolated_toeplitz_matmul(toeplitz_columns, tensor, interp_left=None, interp_right=None, noise_diag=None):
    r"""
    Multiply an interpolated Kronecker-product Toeplitz matrix by a vector or matrix.

    Computes ``(W_l (T_1 \otimes ... \otimes T_d) W_r^T + diag(s)) @ tensor``.
    Each ``T_i`` is a symmetric Toeplitz matrix, and ``W_l`` / ``W_r`` are the
    optional sparse interpolation matrices. Each interpolation matrix is applied
    independently; a missing one is treated as the identity. ``diag(s)`` is
    added only when ``noise_diag`` is given.

    Args:
        - toeplitz_columns (d x m matrix) - columns defining the d Toeplitz matrices T_i
        - tensor (p x k matrix or length-p vector) - vector (k=1) or matrix (k>1) to multiply with
        - interp_left (sparse matrix n x m, optional) - left interpolation matrix W_l
        - interp_right (sparse matrix p x m, optional) - right interpolation matrix W_r
        - noise_diag (length-p tensor, optional) - if not None, add diag(noise_diag) @ tensor

    Returns:
        - tensor: the product; 1-D if ``tensor`` was 1-D, otherwise 2-D
    """
    output_dims = tensor.ndimension()
    if output_dims == 1:
        # Work with a single-column matrix so vectors and matrices share one path.
        tensor = tensor.unsqueeze(1)

    noise_term = None
    if noise_diag is not None:
        # Row i of tensor scaled by noise_diag[i], i.e. diag(noise_diag) @ tensor.
        noise_term = noise_diag.unsqueeze(1).expand_as(tensor) * tensor

    rhs = tensor
    if interp_right is not None:
        # W_r^T @ tensor
        rhs = torch.dsmm(interp_right.t(), rhs)

    # (T_1 kron ... kron T_d) @ rhs. toeplitz_columns is passed twice; presumably
    # the helper takes (columns, rows), which coincide for symmetric T_i -- confirm.
    output = kronecker_product_toeplitz_matmul(toeplitz_columns, toeplitz_columns, rhs)

    if interp_left is not None:
        # W_l @ (T W_r^T tensor)
        output = torch.dsmm(interp_left, output)

    if noise_term is not None:
        # NOTE(review): the shapes only line up when the result has p rows (n == p).
        output = output + noise_term

    if output_dims == 1:
        output = output.squeeze(1)
    return output