我们从Python开源项目中,提取了以下12个代码示例,用于说明如何使用torch.svd()。
def test_svd(self):
    """Check ``torch.svd`` against its ``out=`` variant on fresh, reused and
    non-contiguous output tensors.

    Output tensors are passed through the keyword-only ``out=(U, S, V)``
    argument. The earlier form of this test passed them positionally,
    ``torch.svd(uu, ss, vv, a)``. Current PyTorch rejects that form because
    ``torch.svd`` accepts at most three positional arguments
    (``input, some, compute_uv``).
    """
    # Five rows of six values, transposed to a 6x5 input and cloned.
    a = torch.Tensor(((8.79, 6.11, -9.15, 9.57, -3.49, 9.84),
                      (9.93, 6.91, -7.93, 1.64, 4.02, 0.15),
                      (9.83, 5.04, 4.86, 8.83, 9.80, -8.99),
                      (5.45, -0.27, 4.85, 0.74, 10.00, -6.02),
                      (3.16, 7.98, 3.01, 5.80, 4.27, -5.31))).t().clone()
    u, s, v = torch.svd(a)
    uu = torch.Tensor()
    ss = torch.Tensor()
    vv = torch.Tensor()
    uuu, sss, vvv = torch.svd(a, out=(uu, ss, vv))
    # Both the filled out-tensors and the returned tuple must match exactly.
    self.assertEqual(u, uu, 0, 'torch.svd')
    self.assertEqual(u, uuu, 0, 'torch.svd')
    self.assertEqual(s, ss, 0, 'torch.svd')
    self.assertEqual(s, sss, 0, 'torch.svd')
    self.assertEqual(v, vv, 0, 'torch.svd')
    self.assertEqual(v, vvv, 0, 'torch.svd')

    # test reuse: tensors returned by a previous call serve as out targets
    X = torch.randn(4, 4)
    U, S, V = torch.svd(X)
    Xhat = torch.mm(U, torch.mm(S.diag(), V.t()))
    self.assertEqual(X, Xhat, 1e-8, 'USV\' wrong')
    self.assertFalse(U.is_contiguous(), 'U is contiguous')
    torch.svd(X, out=(U, S, V))
    Xhat = torch.mm(U, torch.mm(S.diag(), V.t()))
    self.assertEqual(X, Xhat, 1e-8, 'USV\' wrong')

    # test non-contiguous: each [:, 1] slice below is a strided view
    X = torch.randn(5, 5)
    U = torch.zeros(5, 2, 5)[:, 1]
    S = torch.zeros(5, 2)[:, 1]
    V = torch.zeros(5, 2, 5)[:, 1]
    self.assertFalse(U.is_contiguous(), 'U is contiguous')
    self.assertFalse(S.is_contiguous(), 'S is contiguous')
    self.assertFalse(V.is_contiguous(), 'V is contiguous')
    torch.svd(X, out=(U, S, V))
    Xhat = torch.mm(U, torch.mm(S.diag(), V.t()))
    self.assertEqual(X, Xhat, 1e-8, 'USV\' wrong')
def test_svd(self):
    """Verify ``torch.svd`` and its keyword-only ``out=`` argument.

    Three scenarios are covered:

    * the ``out=`` variant matches the plain call exactly;
    * factors returned by one call can be passed back as ``out`` targets;
    * non-contiguous ``out`` tensors still receive a valid decomposition.
    """
    def reconstruct(left, sing, right):
        # Rebuild the decomposed matrix as U @ diag(S) @ V^T.
        return torch.mm(left, torch.mm(sing.diag(), right.t()))

    rows = ((8.79, 6.11, -9.15, 9.57, -3.49, 9.84),
            (9.93, 6.91, -7.93, 1.64, 4.02, 0.15),
            (9.83, 5.04, 4.86, 8.83, 9.80, -8.99),
            (5.45, -0.27, 4.85, 0.74, 10.00, -6.02),
            (3.16, 7.98, 3.01, 5.80, 4.27, -5.31))
    source = torch.Tensor(rows).t().clone()

    ref_u, ref_s, ref_v = torch.svd(source)
    out_u, out_s, out_v = torch.Tensor(), torch.Tensor(), torch.Tensor()
    ret_u, ret_s, ret_v = torch.svd(source, out=(out_u, out_s, out_v))
    exact_pairs = ((ref_u, out_u), (ref_u, ret_u),
                   (ref_s, out_s), (ref_s, ret_s),
                   (ref_v, out_v), (ref_v, ret_v))
    for expected, actual in exact_pairs:
        self.assertEqual(expected, actual, 0, 'torch.svd')

    # Reuse: the factors from one call become the out targets of the next.
    square = torch.randn(4, 4)
    fac_u, fac_s, fac_v = torch.svd(square)
    self.assertEqual(square, reconstruct(fac_u, fac_s, fac_v), 1e-8, 'USV\' wrong')
    self.assertFalse(fac_u.is_contiguous(), 'U is contiguous')
    torch.svd(square, out=(fac_u, fac_s, fac_v))
    self.assertEqual(square, reconstruct(fac_u, fac_s, fac_v), 1e-8, 'USV\' wrong')

    # Non-contiguous targets: every [:, 1] slice is a strided view.
    square = torch.randn(5, 5)
    nc_u = torch.zeros(5, 2, 5)[:, 1]
    nc_s = torch.zeros(5, 2)[:, 1]
    nc_v = torch.zeros(5, 2, 5)[:, 1]
    for buf, message in ((nc_u, 'U is contiguous'),
                         (nc_s, 'S is contiguous'),
                         (nc_v, 'V is contiguous')):
        self.assertFalse(buf.is_contiguous(), message)
    torch.svd(square, out=(nc_u, nc_s, nc_v))
    self.assertEqual(square, reconstruct(nc_u, nc_s, nc_v), 1e-8, 'USV\' wrong')
def func():
    """Run one SVD of the module-level ``np_data`` and discard the result.

    NOTE(review): ``linalg`` and ``np_data`` are module-level names defined
    elsewhere in the file; presumably ``linalg`` is ``scipy.linalg`` -- confirm.
    """
    _ = linalg.svd(np_data)  # result intentionally unused
def func():
    """Run one SVD of ``np_data`` with the ``'gesvd'`` LAPACK driver.

    The result is discarded. ``linalg`` and ``np_data`` are module-level names
    defined elsewhere. The ``lapack_driver`` keyword suggests ``linalg`` is
    ``scipy.linalg`` -- confirm.
    """
    linalg.svd(np_data, lapack_driver='gesvd')
def func():
    """Run one SVD of ``np_data`` with the ``'gesdd'`` LAPACK driver.

    The result is discarded. ``linalg`` and ``np_data`` are module-level names
    defined elsewhere. The ``lapack_driver`` keyword suggests ``linalg`` is
    ``scipy.linalg`` -- confirm.
    """
    linalg.svd(np_data, lapack_driver='gesdd')
def func():
    """Draw an N x N uniform random matrix and run ``torch.svd`` on it.

    The matrix is created by ``torch.rand`` on the default device, and the
    decomposition is discarded. ``N`` is a module-level name defined
    elsewhere.
    """
    matrix = torch.rand((N, N))
    torch.svd(matrix)
def func():
    """Run ``torch.svd`` on an N x N uniform random matrix moved to the GPU.

    The matrix is drawn with ``torch.rand`` and copied with ``.cuda()``
    before decomposition. The result is discarded. ``N`` is a module-level
    name defined elsewhere.
    """
    torch.svd(torch.rand((N, N)).cuda())
    # CUDA work is queued asynchronously with respect to the host. Block
    # until it has finished, so that a host-side timer wrapped around this
    # call (presumably how it is used -- confirm) measures the whole SVD.
    torch.cuda.synchronize()
def partial_svd(matrix, n_eigenvecs=None):
    """Compute the SVD of a 2D ``matrix``, optionally truncated.

    A full SVD is always computed with ``torch.svd(matrix, some=False)``.
    It is then truncated to the leading ``n_eigenvecs`` singular triplets.
    No sparse eigendecomposition is used here.

    Parameters
    ----------
    matrix : 2D-array
    n_eigenvecs : int, optional, default is None
        If specified, number of singular vectors/values to return.
        If None, nothing is truncated: U and V are the full square factors
        and S holds every singular value.

    Returns
    -------
    U : 2D-array of shape (matrix.shape[0], n_eigenvecs)
        contains the left singular vectors (as columns)
    S : 1D-array of shape (n_eigenvecs, )
        contains the singular values of `matrix`
    V : 2D-array of shape (n_eigenvecs, matrix.shape[1])
        contains the right singular vectors (as rows), i.e. the leading rows
        of V^T, so ``U @ diag(S) @ V`` is the rank-``n_eigenvecs`` truncation
        when ``n_eigenvecs <= min(matrix.shape)``

    Raises
    ------
    ValueError
        If `matrix` is not 2-dimensional.
    """
    # Check that matrix is... a matrix!
    dim = ndim(matrix)
    if dim != 2:
        raise ValueError(
            'matrix should be a matrix. matrix.ndim is {} != 2'.format(dim))

    # some=False yields full (square) U and V, so n_eigenvecs=None returns
    # the complete decomposition.
    U, S, V = torch.svd(matrix, some=False)
    # torch.svd returns V itself (not V^T); transpose so the returned V holds
    # the right singular vectors as rows, matching the documented shape.
    U, S, V = U[:, :n_eigenvecs], S[:n_eigenvecs], V.t()[:n_eigenvecs, :]
    return U, S, V
def det(var):
    r"""Return the determinant of a 2D square Variable via ``var.det()``.

    .. note::
        Backward through `det` internally uses SVD results. So double
        backward through `det` must go through :meth:`~Tensor.svd`, which
        can be unstable in certain cases. Please see :meth:`~torch.svd`
        for details.

    Arguments:
        var (Variable): The input 2D square Variable.

    Raises:
        ValueError: if ``torch.is_tensor(var)`` is true.
    """
    if not torch.is_tensor(var):
        return var.det()
    raise ValueError("det is currently only supported on Variable")