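我们从 Python 开源项目中提取了以下 8 个代码示例，用于说明如何使用 torch.trtrs()。注意：torch.trtrs() 自 PyTorch 1.1 起已被弃用（更名为 torch.triangular_solve()），较新版本的 PyTorch 中已不再提供；新代码应改用 torch.linalg.solve_triangular()。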
def _ssor_preconditioner(self, lhs_mat, mat): if lhs_mat.ndimension() == 2: DL = lhs_mat.tril() D = lhs_mat.diag() upper_part = (1 / D).expand_as(DL).mul(DL.t()) Minv_times_mat = torch.trtrs(torch.trtrs(mat, DL, upper=False)[0], upper_part)[0] elif lhs_mat.ndimension() == 3: if mat.size(0) == 1 and lhs_mat.size(0) != 1: mat = mat.expand(*([lhs_mat.size(0)] + list(mat.size())[1:])) Minv_times_mat = mat.new(*mat.size()) for i in range(lhs_mat.size(0)): DL = lhs_mat[i].tril() D = lhs_mat[i].diag() upper_part = (1 / D).expand_as(DL).mul(DL.t()) Minv_times_mat[i].copy_(torch.trtrs(torch.trtrs(mat[i], DL, upper=False)[0], upper_part)[0]) else: raise RuntimeError('Invalid number of dimensions') return Minv_times_mat
def pre_factor_kkt(Q, G, A):
    """Perform all one-time factorizations and cache relevant matrix products.

    Args:
        Q: square matrix factored with torch.potrf, so presumably symmetric
            positive definite.
        G: matrix multiplied against Q^{-1} as G Q^{-1} G^T.
        A: matrix multiplied against Q^{-1} as A Q^{-1} A^T. It is only used when
            get_sizes reports neq > 0.

    Returns:
        (U_Q, U_S, R):
            U_Q: torch.potrf(Q), the upper Cholesky factor of Q.
            U_S: a (neq + nineq) x (neq + nineq) matrix. Its first neq rows hold
                [U11 | U12], the partial Cholesky factor of S shown below; the
                remaining rows are zero.
            R: G Q^{-1} G^T - U12^T U12, or just G Q^{-1} G^T when neq == 0. The
                D^{-1} term of S is not included (presumably the caller adds it
                and finishes the factorization later -- confirm).
    """
    # get_sizes is defined elsewhere; presumably it returns (nineq, nz, neq, nBatch).
    nineq, nz, neq, _ = get_sizes(G, A)

    # S = [ A Q^{-1} A^T        A Q^{-1} G^T           ]
    #     [ G Q^{-1} A^T        G Q^{-1} G^T + D^{-1} ]

    U_Q = torch.potrf(Q)

    # partial cholesky of S matrix
    U_S = torch.zeros(neq + nineq, neq + nineq).type_as(Q)

    G_invQ_GT = torch.mm(G, torch.potrs(G.t(), U_Q))
    R = G_invQ_GT  # alias: the in-place update below also modifies G_invQ_GT
    if neq > 0:
        invQ_AT = torch.potrs(A.t(), U_Q)
        A_invQ_AT = torch.mm(A, invQ_AT)
        G_invQ_AT = torch.mm(G, invQ_AT)

        # TODO: torch.potrf sometimes says the matrix is not PSD but
        # numpy does? I filed an issue at
        # https://github.com/pytorch/pytorch/issues/199
        try:
            U11 = torch.potrf(A_invQ_AT)
        except RuntimeError:
            # np.linalg.cholesky returns the LOWER factor L (A = L L^T). torch.potrf
            # returns the UPPER factor U (A = U^T U), which the code below assumes,
            # so transpose to match. torch.from_numpy keeps numpy's dtype, avoiding
            # the float32 round trip that torch.Tensor(...) imposes on float64 data.
            L11 = np.linalg.cholesky(A_invQ_AT.cpu().numpy())
            U11 = torch.from_numpy(L11).type_as(A_invQ_AT).t()

        # TODO: torch.trtrs is currently not implemented on the GPU
        # and we are using gesv as a workaround.
        # Solves U11^T U12 = A Q^{-1} G^T for U12.
        U12 = torch.gesv(G_invQ_AT.t(), U11.t())[0]
        U_S[:neq, :neq] = U11
        U_S[:neq, neq:] = U12
        R -= torch.mm(U12.t(), U12)

    return U_Q, U_S, R
def test_trtrs(self):
    """Gradcheck and gradgradcheck torch.trtrs for every (upper, transpose, unitriangular) combo.

    The coefficient matrix is S x S. Right-hand sides have S + 1 and then S - 1
    columns. S is a size constant referenced from outside this block (not shown here).
    """
    def run_checks(n, n_rhs):
        coeffs = Variable(torch.rand(n, n), requires_grad=True)
        rhs = Variable(torch.rand(n, n_rhs), requires_grad=True)
        for upper, transpose, unitriangular in product((True, False), repeat=3):
            # The closure reads the current flags; both checks run before the loop advances.
            def func(tri, rhs_in):
                return torch.trtrs(rhs_in, tri, upper, transpose, unitriangular)

            gradcheck(func, [coeffs, rhs])
            gradgradcheck(func, [coeffs, rhs])

    for n_rhs in (S + 1, S - 1):
        run_checks(S, n_rhs)
def test_trtrs(self):
    """Check torch.trtrs on U, L, U' and L' systems, plus repeated solves into result buffers.

    Each solve is validated either by its residual b.dist(op(T) @ x) or by agreement
    with a solve against an explicitly transposed matrix.
    """
    # Both matrices are written row by row here and transposed after construction.
    a = torch.Tensor(((6.80, -2.11, 5.66, 5.97, 8.23),
                      (-6.05, -3.30, 5.36, -4.44, 1.08),
                      (-0.45, 2.58, -2.70, 0.27, 9.04),
                      (8.32, 2.71, 4.35, -7.17, 2.14),
                      (-9.67, -5.14, -7.26, 6.08, -6.87))).t()
    b = torch.Tensor(((4.02, 6.19, -8.22, -7.57, -3.03),
                      (-1.56, 4.00, -8.67, 1.75, 2.86),
                      (9.81, -4.09, -4.57, -8.61, 8.99))).t()
    upper_tri = torch.triu(a)
    lower_tri = torch.tril(a)
    # NOTE(review): 1e-12 presumably relies on the suite's default tensor type being
    # double -- confirm.
    tol = 1e-12

    def solve(coeffs, *flags):
        # Positional flags follow torch.trtrs: (upper, transpose, unitriangular).
        return torch.trtrs(b, coeffs, *flags)[0]

    def check_residual(coeffs, applied, *flags):
        # Solve against `coeffs` and require applied @ x to reproduce b.
        x = solve(coeffs, *flags)
        self.assertLessEqual(b.dist(torch.mm(applied, x)), tol)
        return x

    # U x = b, relying on defaults and then spelling every flag out.
    check_residual(upper_tri, upper_tri)
    check_residual(upper_tri, upper_tri, True, False, False)

    # L x = b
    check_residual(lower_tri, lower_tri, False)
    check_residual(lower_tri, lower_tri, False, False, False)

    # U' x = b via the transpose flag, then cross-checked against transposing by hand.
    check_residual(upper_tri, upper_tri.t(), True, True)
    x = check_residual(upper_tri, upper_tri.t(), True, True, False)
    self.assertLessEqual(x.dist(solve(upper_tri.t(), False, False)), tol)

    # L' x = b via the transpose flag, then cross-checked against transposing by hand.
    check_residual(lower_tri, lower_tri.t(), False, True)
    x = check_residual(lower_tri, lower_tri.t(), False, True, False)
    self.assertLessEqual(x.dist(solve(lower_tri.t(), True, False)), tol)

    # Buffer reuse: here the result tensors are passed first (presumably an older
    # PyTorch calling convention), and the first one receives the solution. A second
    # call after clearing it must reproduce the same values.
    reference = torch.trtrs(b, a)[0]
    coeff_out = torch.Tensor()
    solution_out = torch.Tensor()
    torch.trtrs(solution_out, coeff_out, b, a)
    self.assertEqual(reference, solution_out, 0)
    solution_out.zero_()
    torch.trtrs(solution_out, coeff_out, b, a)
    self.assertEqual(reference, solution_out, 0)
def test_trtrs(self):
    """Check torch.trtrs on U, L, U' and L' systems, plus repeated solves into out= buffers.

    Each solve is validated either by its residual b.dist(op(T) @ x) or by agreement
    with a solve against an explicitly transposed matrix.
    """
    # Both matrices are written row by row here and transposed after construction.
    a = torch.Tensor(((6.80, -2.11, 5.66, 5.97, 8.23),
                      (-6.05, -3.30, 5.36, -4.44, 1.08),
                      (-0.45, 2.58, -2.70, 0.27, 9.04),
                      (8.32, 2.71, 4.35, -7.17, 2.14),
                      (-9.67, -5.14, -7.26, 6.08, -6.87))).t()
    b = torch.Tensor(((4.02, 6.19, -8.22, -7.57, -3.03),
                      (-1.56, 4.00, -8.67, 1.75, 2.86),
                      (9.81, -4.09, -4.57, -8.61, 8.99))).t()
    upper_tri = torch.triu(a)
    lower_tri = torch.tril(a)
    # NOTE(review): 1e-12 presumably relies on the suite's default tensor type being
    # double -- confirm.
    tol = 1e-12

    def solve(coeffs, *flags):
        # Positional flags follow torch.trtrs: (upper, transpose, unitriangular).
        return torch.trtrs(b, coeffs, *flags)[0]

    def check_residual(coeffs, applied, *flags):
        # Solve against `coeffs` and require applied @ x to reproduce b.
        x = solve(coeffs, *flags)
        self.assertLessEqual(b.dist(torch.mm(applied, x)), tol)
        return x

    # U x = b, relying on defaults and then spelling every flag out.
    check_residual(upper_tri, upper_tri)
    check_residual(upper_tri, upper_tri, True, False, False)

    # L x = b
    check_residual(lower_tri, lower_tri, False)
    check_residual(lower_tri, lower_tri, False, False, False)

    # U' x = b via the transpose flag, then cross-checked against transposing by hand.
    check_residual(upper_tri, upper_tri.t(), True, True)
    x = check_residual(upper_tri, upper_tri.t(), True, True, False)
    self.assertLessEqual(x.dist(solve(upper_tri.t(), False, False)), tol)

    # L' x = b via the transpose flag, then cross-checked against transposing by hand.
    check_residual(lower_tri, lower_tri.t(), False, True)
    x = check_residual(lower_tri, lower_tri.t(), False, True, False)
    self.assertLessEqual(x.dist(solve(lower_tri.t(), True, False)), tol)

    # Buffer reuse through out=: the first out tensor receives the solution. A second
    # call after clearing it must reproduce the same values.
    reference = torch.trtrs(b, a)[0]
    coeff_out = torch.Tensor()
    solution_out = torch.Tensor()
    torch.trtrs(b, a, out=(solution_out, coeff_out))
    self.assertEqual(reference, solution_out, 0)
    solution_out.zero_()
    torch.trtrs(b, a, out=(solution_out, coeff_out))
    self.assertEqual(reference, solution_out, 0)