我们从 Python 开源项目中提取了以下 12 个代码示例，用于说明如何使用 torch.inverse()。
def test_inverse(self):
    """Check ``torch.inverse`` against the identity, directly and via ``out=``.

    NOTE(review): ``self.assertEqual(x, y, prec, message)`` with a positional
    precision argument assumes the project's own TestCase, not plain
    ``unittest.TestCase`` -- confirm against the enclosing test class.
    """
    M = torch.randn(5, 5)
    MI = torch.inverse(M)
    E = torch.eye(5)
    # NOTE(review): this asserts the inverse is NOT contiguous, i.e. it relies
    # on the backend handing back a transposed (column-major) layout; confirm
    # this still holds for the torch version in use.
    self.assertFalse(MI.is_contiguous(), 'MI is contiguous')
    # The 1e-8 tolerance presumably assumes a double default dtype -- TODO confirm.
    self.assertEqual(E, torch.mm(M, MI), 1e-8, 'inverse value')
    self.assertEqual(E, torch.mm(MI, M), 1e-8, 'inverse value')

    MII = torch.Tensor(5, 5)
    # The legacy result-first call ``torch.inverse(MII, M)`` is not accepted
    # by current torch (``out`` is keyword-only); pass the buffer via ``out=``.
    torch.inverse(M, out=MII)
    self.assertFalse(MII.is_contiguous(), 'MII is contiguous')
    self.assertEqual(MII, MI, 0, 'inverse value in-place')
    # Second call, now that MII is already transposed by the first call.
    torch.inverse(M, out=MII)
    self.assertFalse(MII.is_contiguous(), 'MII is contiguous')
    self.assertEqual(MII, MI, 0, 'inverse value in-place')
def test_inverse(self):
    """Verify ``torch.inverse`` both as a plain call and through ``out=``.

    The plain result must multiply with the input to the identity from either
    side.  The ``out=`` path is run twice into the same buffer and must match
    the plain result exactly each time.
    """
    mat = torch.randn(5, 5)
    mat_inv = torch.inverse(mat)
    identity = torch.eye(5)

    # Expected to come back non-contiguous (transposed strides).
    self.assertFalse(mat_inv.is_contiguous(), 'MI is contiguous')
    for left, right in ((mat, mat_inv), (mat_inv, mat)):
        self.assertEqual(identity, torch.mm(left, right), 1e-8, 'inverse value')

    out_buf = torch.Tensor(5, 5)
    # Pass 1 writes into a fresh buffer; pass 2 reuses the buffer, which by
    # then already carries the transposed layout produced by pass 1.
    for _ in range(2):
        torch.inverse(mat, out=out_buf)
        self.assertFalse(out_buf.is_contiguous(), 'MII is contiguous')
        self.assertEqual(out_buf, mat_inv, 0, 'inverse value in-place')
def tset_potri(self):
    """Check ``torch.potri`` (inverse from a Cholesky factor) against ``torch.inverse``.

    NOTE(review): unittest's default loader only collects methods whose names
    start with ``test``, so this ``tset_`` method never runs.  If that is
    unintentional, rename it to ``test_potri``.
    NOTE(review): the ``'U'``/``'L'`` arguments follow the legacy torch7
    convention.  Newer PyTorch releases take a boolean ``upper`` and replace
    potrf/potri with ``torch.cholesky`` / ``torch.cholesky_inverse``.  Confirm
    against the torch version in use.
    """
    a = torch.Tensor(((6.80, -2.11, 5.66, 5.97, 8.23),
                      (-6.05, -3.30, 5.36, -4.44, 1.08),
                      (-0.45, 2.58, -2.70, 0.27, 9.04),
                      (8.32, 2.71, 4.35, -7.17, 2.14),
                      (-9.67, -5.14, -7.26, 6.08, -6.87))).t()
    # Make sure 'a' is symmetric PSD.  This needs the matrix product a @ a^T:
    # the element-wise ``a * a.t()`` is symmetric but not PSD in general (for
    # this data its 2x2 principal minor on indices 0 and 4 is negative), which
    # makes the Cholesky factorization fail.
    a = torch.mm(a, a.t())

    # Reference: invert directly.
    inv0 = torch.inverse(a)

    # Default factorization, then the explicit upper and lower triangles.
    for uplo in ((), ('U',), ('L',)):
        chol = torch.potrf(a, *uplo)
        inv1 = torch.potri(chol, *uplo)
        self.assertLessEqual(inv0.dist(inv1), 1e-12,
                             'potri mismatch for uplo=%r' % (uplo,))
def tset_potri(self):
    """Check ``torch.potri`` (inverse from a Cholesky factor) against ``torch.inverse``.

    NOTE(review): unittest's default loader only collects methods whose names
    start with ``test``, so this ``tset_`` method never runs.  If that is
    unintentional, rename it to ``test_potri``.
    NOTE(review): the ``'U'``/``'L'`` arguments follow the legacy torch7
    convention.  Newer PyTorch releases take a boolean ``upper`` and replace
    potrf/potri with ``torch.cholesky`` / ``torch.cholesky_inverse``.  Confirm
    against the torch version in use.
    """
    a = torch.Tensor(((6.80, -2.11, 5.66, 5.97, 8.23),
                      (-6.05, -3.30, 5.36, -4.44, 1.08),
                      (-0.45, 2.58, -2.70, 0.27, 9.04),
                      (8.32, 2.71, 4.35, -7.17, 2.14),
                      (-9.67, -5.14, -7.26, 6.08, -6.87))).t()
    # Make sure 'a' is symmetric PSD.  This needs the matrix product a @ a^T:
    # the element-wise ``a * a.t()`` is symmetric but not PSD in general (for
    # this data its 2x2 principal minor on indices 0 and 4 is negative), which
    # makes the Cholesky factorization fail.
    a = torch.mm(a, a.t())

    # Reference: invert directly.
    inv0 = torch.inverse(a)

    # Default factorization, then the explicit upper and lower triangles.
    for uplo in ((), ('U',), ('L',)):
        chol = torch.potrf(a, *uplo)
        inv1 = torch.potri(chol, *uplo)
        self.assertLessEqual(inv0.dist(inv1), 1e-12,
                             'potri mismatch for uplo=%r' % (uplo,))
def forward(ctx, input):
    """Invert ``input`` and stash the result on the autograd context.

    Args:
        ctx: autograd context used to hand tensors to the backward pass.
        input: the matrix to invert (passed straight to ``torch.inverse``).

    Returns:
        ``torch.inverse(input)``.
    """
    result = torch.inverse(input)
    # Keep the inverse so the backward pass can reuse it instead of re-inverting.
    ctx.save_for_backward(result)
    return result
def backward(ctx, grad_output):
    """Backpropagate through a matrix inverse.

    With ``Y = A^{-1}`` and ``G = grad_output``, the gradient w.r.t. ``A`` is
    ``-Y^T @ G @ Y^T``.  It uses a plain transpose (no conjugation), i.e. it is
    written for real-valued matrices.

    Args:
        ctx: autograd context; assumed to hold exactly one saved tensor, the
            inverse computed in the forward pass.
        grad_output: gradient of the loss w.r.t. the inverse.

    Returns:
        Gradient of the loss w.r.t. the matrix that was inverted.
    """
    # ``saved_variables`` is a deprecated alias of ``saved_tensors``.
    inverse, = ctx.saved_tensors
    # Transpose only the last two dims and use matmul, so batched (..., n, n)
    # inputs work as well as a single 2-D matrix (identical result for 2-D).
    inverse_t = inverse.transpose(-2, -1)
    return -torch.matmul(inverse_t, torch.matmul(grad_output, inverse_t))