我们从 Python 开源项目中，提取了以下 7 个代码示例，用于说明如何使用 torch.symeig()。
def test_symeig(self):
    """Check that torch.symeig round-trips a symmetric matrix: V diag(e) V' == A.

    Exercises contiguous outputs, outputs left non-contiguous by a prior
    call, and explicitly non-contiguous output tensors.

    Fix: the original used the long-removed positional out-style call
    ``torch.symeig(rese, resv, cov, True)``; rewritten with the ``out=``
    keyword form, consistent with the sibling example in this file.
    """
    xval = torch.rand(100, 3)
    cov = torch.mm(xval.t(), xval)  # symmetric 3x3 (X^T X)
    rese = torch.zeros(3)
    resv = torch.zeros(3, 3)

    # First call to symeig: output tensors start out contiguous.
    self.assertTrue(resv.is_contiguous(), 'resv is not contiguous')
    torch.symeig(cov.clone(), True, out=(rese, resv))
    ahat = torch.mm(torch.mm(resv, torch.diag(rese)), resv.t())
    self.assertEqual(cov, ahat, 1e-8, 'VeV\' wrong')

    # Second call to symeig: the first call is expected to leave resv
    # non-contiguous (asserted below), and symeig must still handle it.
    self.assertFalse(resv.is_contiguous(), 'resv is contiguous')
    torch.symeig(cov.clone(), True, out=(rese, resv))
    ahat = torch.mm(torch.mm(resv, torch.diag(rese)), resv.t())
    self.assertEqual(cov, ahat, 1e-8, 'VeV\' wrong')

    # Test with explicitly non-contiguous output tensors (views).
    X = torch.rand(5, 5)
    X = X.t() * X
    e = torch.zeros(4, 2).select(1, 1)
    v = torch.zeros(4, 2, 4)[:, 1]
    self.assertFalse(v.is_contiguous(), 'V is contiguous')
    self.assertFalse(e.is_contiguous(), 'E is contiguous')
    torch.symeig(X, True, out=(e, v))
    Xhat = torch.mm(torch.mm(v, torch.diag(e)), v.t())
    self.assertEqual(X, Xhat, 1e-8, 'VeV\' wrong')
def test_symeig(self):
    """Verify that the torch.symeig eigendecomposition reconstructs its input.

    Covers contiguous output tensors, outputs made non-contiguous by an
    earlier call, and output tensors that are non-contiguous views.
    """
    def reconstruct(vals, vecs):
        # For a symmetric matrix A: A == V * diag(e) * V^T.
        return torch.mm(torch.mm(vecs, torch.diag(vals)), vecs.t())

    data = torch.rand(100, 3)
    covariance = torch.mm(data.t(), data)
    eigvals = torch.zeros(3)
    eigvecs = torch.zeros(3, 3)

    # First call: freshly allocated outputs are contiguous.
    self.assertTrue(eigvecs.is_contiguous(), 'resv is not contiguous')
    torch.symeig(covariance.clone(), True, out=(eigvals, eigvecs))
    self.assertEqual(covariance, reconstruct(eigvals, eigvecs),
                     1e-8, 'VeV\' wrong')

    # Second call: the first call is expected to have left the
    # eigenvector output non-contiguous (asserted here).
    self.assertFalse(eigvecs.is_contiguous(), 'resv is contiguous')
    torch.symeig(covariance.clone(), True, out=(eigvals, eigvecs))
    self.assertEqual(covariance, reconstruct(eigvals, eigvecs),
                     1e-8, 'VeV\' wrong')

    # Explicitly non-contiguous output tensors (strided views).
    mat = torch.rand(5, 5)
    mat = mat.t() * mat
    vals_view = torch.zeros(4, 2).select(1, 1)
    vecs_view = torch.zeros(4, 2, 4)[:, 1]
    self.assertFalse(vecs_view.is_contiguous(), 'V is contiguous')
    self.assertFalse(vals_view.is_contiguous(), 'E is contiguous')
    torch.symeig(mat, True, out=(vals_view, vecs_view))
    self.assertEqual(mat, reconstruct(vals_view, vecs_view),
                     1e-8, 'VeV\' wrong')
def test_symeig(self):
    """Exercise torch.symeig on CUDA tensors, small (3x3) and large (257x257)."""
    for size in (3, 257):
        mat = torch.randn(size, size).cuda()
        mat = torch.mm(mat, mat.t())  # make it symmetric
        eigval, eigvec = torch.symeig(mat, eigenvectors=True)
        # Reconstruction V * diag(e) * V^T must match the input.
        self.assertEqual(
            mat, torch.mm(torch.mm(eigvec, eigval.diag()), eigvec.t()))
def step(self):
    """Run one preconditioned optimization step (K-FAC-style, judging by
    the Kronecker-factor attributes m_aa/m_gg and Q_a/Q_g).

    Gradients are projected into the eigenbasis of the per-layer factors,
    rescaled by the damped eigenvalue products, clipped via the kl_clip
    criterion, written back into ``p.grad``, and the wrapped optimizer
    is stepped.
    """
    # Fold L2 weight decay directly into the raw gradients.
    if self.weight_decay > 0:
        for p in self.model.parameters():
            p.grad.data.add_(self.weight_decay, p.data)

    updates = {}
    for m in self.modules:
        assert len(list(m.parameters())) == 1, \
            "Can handle only one parameter at the moment"
        classname = m.__class__.__name__
        p = next(m.parameters())
        la = self.damping + self.weight_decay  # damped denominator offset

        # Refresh the factor eigendecompositions every Tf steps.
        if self.steps % self.Tf == 0:
            # My asynchronous implementation exists, I will add it later.
            # Experimenting with different ways to this in PyTorch.
            self.d_a[m], self.Q_a[m] = torch.symeig(
                self.m_aa[m], eigenvectors=True)
            self.d_g[m], self.Q_g[m] = torch.symeig(
                self.m_gg[m], eigenvectors=True)
            # Suppress numerically negligible (or negative) eigenvalues.
            self.d_a[m].mul_((self.d_a[m] > 1e-6).float())
            self.d_g[m].mul_((self.d_g[m] > 1e-6).float())

        # Conv kernels are flattened to (out_channels, -1); linear
        # gradients are already matrices.
        if classname == 'Conv2d':
            p_grad_mat = p.grad.data.view(p.grad.data.size(0), -1)
        else:
            p_grad_mat = p.grad.data

        # Precondition in the eigenbasis:
        #   V = Q_g * ((Q_g^T G Q_a) / (d_g d_a^T + la)) * Q_a^T
        v1 = self.Q_g[m].t() @ p_grad_mat @ self.Q_a[m]
        v2 = v1 / (self.d_g[m].unsqueeze(1)
                   * self.d_a[m].unsqueeze(0) + la)
        v = self.Q_g[m] @ v2 @ self.Q_a[m].t()
        updates[p] = v.view(p.grad.data.size())

    # Global rescaling factor nu so the update stays within kl_clip.
    vg_sum = 0
    for p in self.model.parameters():
        vg_sum += (updates[p] * p.grad.data * self.lr * self.lr).sum()
    nu = min(1, math.sqrt(self.kl_clip / vg_sum))

    # Overwrite the gradients with the scaled preconditioned updates.
    for p in self.model.parameters():
        p.grad.data.copy_(updates[p])
        p.grad.data.mul_(nu)

    self.optim.step()
    self.steps += 1