我们从Python开源项目中提取了以下5个代码示例,用于说明如何使用torch.pstrf()。
def test_pstrf(self): def checkPsdCholesky(a, uplo, inplace): if inplace: u = torch.Tensor(a.size()) piv = torch.IntTensor(a.size(0)) args = [u, piv, a] else: args = [a] if uplo is not None: args += [uplo] u, piv = torch.pstrf(*args) if uplo is False: a_reconstructed = torch.mm(u, u.t()) else: a_reconstructed = torch.mm(u.t(), u) piv = piv.long() a_permuted = a.index_select(0, piv).index_select(1, piv) self.assertEqual(a_permuted, a_reconstructed, 1e-14) dimensions = ((5, 1), (5, 3), (5, 5), (10, 10)) for dim in dimensions: m = torch.Tensor(*dim).uniform_() a = torch.mm(m, m.t()) # add a small number to the diagonal to make the matrix numerically positive semidefinite for i in range(m.size(0)): a[i][i] = a[i][i] + 1e-7 for inplace in (True, False): for uplo in (None, True, False): checkPsdCholesky(a, uplo, inplace)
def test_pstrf(self): def checkPsdCholesky(a, uplo, inplace): if inplace: u = torch.Tensor(a.size()) piv = torch.IntTensor(a.size(0)) kwargs = {'out': (u, piv)} else: kwargs = {} args = [a] if uplo is not None: args += [uplo] u, piv = torch.pstrf(*args, **kwargs) if uplo is False: a_reconstructed = torch.mm(u, u.t()) else: a_reconstructed = torch.mm(u.t(), u) piv = piv.long() a_permuted = a.index_select(0, piv).index_select(1, piv) self.assertEqual(a_permuted, a_reconstructed, 1e-14) dimensions = ((5, 1), (5, 3), (5, 5), (10, 10)) for dim in dimensions: m = torch.Tensor(*dim).uniform_() a = torch.mm(m, m.t()) # add a small number to the diagonal to make the matrix numerically positive semidefinite for i in range(m.size(0)): a[i][i] = a[i][i] + 1e-7 for inplace in (True, False): for uplo in (None, True, False): checkPsdCholesky(a, uplo, inplace)