我们从 Python 开源项目中提取了以下 14 个代码示例,用于说明如何使用 torch.triu()。
def test_triu(self):
    """Check that torch.triu gives the same result whether it allocates its
    own output or writes into a caller-supplied ``out`` tensor."""
    x = torch.rand(SIZE, SIZE)
    res1 = torch.triu(x)
    res2 = torch.Tensor()
    # The legacy out-first form ``torch.triu(res2, x)`` is not accepted by
    # current PyTorch (the second positional argument is ``diagonal``), so the
    # destination is passed through the ``out`` keyword instead.
    torch.triu(x, out=res2)
    # NOTE(review): the third argument is presumably the allowed tolerance of
    # the project's TestCase.assertEqual -- confirm.
    self.assertEqual(res1, res2, 0)
def btriunpack(LU_data, LU_pivots, unpack_data=True, unpack_pivots=True):
    """Unpacks the data and pivots from a batched LU factorization (btrifact) of a tensor.

    Returns a tuple indexed by:
      0: The pivots.
      1: The L tensor.
      2: The U tensor.

    Arguments:
        LU_data (Tensor): The packed LU factorization data, of shape
            ``(nBatch, sz, sz)``. L is read from strictly below the diagonal
            and U from the diagonal and above.
        LU_pivots (Tensor): The packed LU factorization pivots, indexed as
            ``LU_pivots[i, j]`` for batch ``i`` and step ``j``. Values are
            treated as 1-based column indices.
        unpack_data (bool): Flag indicating if the data should be unpacked.
            When False, L and U are returned as None.
        unpack_pivots (bool): Flag indicating if the pivots should be unpacked.
            When False, the pivot matrix is returned as None.

    Returns:
        tuple: ``(P, L, U)``. ``P`` has shape ``(nBatch, sz, sz)``. It is
        built by applying the pivot swaps, in order, to the columns of the
        identity (presumably so that ``A = P @ L @ U`` for each batch entry;
        confirm against btrifact). ``L`` has a unit diagonal, and ``U`` is
        upper triangular including the diagonal.
    """
    nBatch, sz, _ = LU_data.size()

    if unpack_data:
        # Build the index masks with comparisons instead of ``.byte()``. The
        # result is a uint8 mask on old PyTorch releases and a torch.bool mask
        # on newer ones, where indexing with uint8 masks is deprecated.
        upper = torch.triu(torch.ones(sz, sz)).type_as(LU_data)
        diag = torch.eye(sz).type_as(LU_data)
        I_U = (upper != 0).unsqueeze(0).expand(nBatch, sz, sz)
        I_L = (upper == 0).unsqueeze(0).expand(nBatch, sz, sz)
        I_diag = (diag != 0).unsqueeze(0).expand(nBatch, sz, sz)
        L = LU_data.new(LU_data.size()).zero_()
        U = LU_data.new(LU_data.size()).zero_()
        L[I_diag] = 1.0  # the packed form does not store L's unit diagonal
        L[I_L] = LU_data[I_L]
        U[I_U] = LU_data[I_U]
    else:
        L = U = None

    if unpack_pivots:
        P = torch.eye(sz).type_as(LU_data).unsqueeze(0).repeat(nBatch, 1, 1)
        for i in range(nBatch):
            for j in range(sz):
                # Pivots are 1-based. int() turns the element (a 0-dim tensor
                # on newer PyTorch) into a plain Python index, so any integer
                # pivot dtype (e.g. int32) can be used for indexing.
                k = int(LU_pivots[i, j]) - 1
                if k == j:
                    continue  # swapping a column with itself is a no-op
                t = P[i, :, j].clone()
                P[i, :, j] = P[i, :, k]
                P[i, :, k] = t
    else:
        P = None

    return P, L, U
def test_triu(self):
    """torch.triu must fill an ``out=`` destination with exactly the tensor
    it returns when called without one."""
    matrix = torch.rand(SIZE, SIZE)
    expected = torch.triu(matrix)
    # start from an empty destination that triu fills through ``out``
    destination = torch.Tensor()
    torch.triu(matrix, out=destination)
    self.assertEqual(expected, destination, 0)
def btriunpack(LU_data, LU_pivots, unpack_data=True, unpack_pivots=True):
    r"""Unpacks the data and pivots from a batched LU factorization (btrifact) of a tensor.

    Returns a tuple indexed by:
      0: The pivots.
      1: The L tensor.
      2: The U tensor.

    Arguments:
        LU_data (Tensor): the packed LU factorization data, of shape
            ``(nBatch, sz, sz)``; L is read from strictly below the diagonal
            and U from the diagonal and above
        LU_pivots (Tensor): the packed LU factorization pivots, indexed as
            ``LU_pivots[i, j]`` for batch ``i`` and step ``j``; values are
            treated as 1-based column indices
        unpack_data (bool): flag indicating if the data should be unpacked;
            when False, L and U are returned as None
        unpack_pivots (bool): flag indicating if the pivots should be unpacked;
            when False, the pivot matrix is returned as None

    Returns:
        tuple: ``(P, L, U)``. ``P`` has shape ``(nBatch, sz, sz)`` and is
        built by applying the pivot swaps, in order, to the columns of the
        identity (presumably so that ``A = P @ L @ U`` for each batch entry;
        confirm against btrifact). ``L`` has a unit diagonal, and ``U`` is
        upper triangular including the diagonal.
    """
    nBatch, sz, _ = LU_data.size()

    if unpack_data:
        # Build the index masks with comparisons instead of ``.byte()``. The
        # result is a uint8 mask on old PyTorch releases and a torch.bool mask
        # on newer ones, where indexing with uint8 masks is deprecated.
        upper = torch.triu(torch.ones(sz, sz)).type_as(LU_data)
        diag = torch.eye(sz).type_as(LU_data)
        I_U = (upper != 0).unsqueeze(0).expand(nBatch, sz, sz)
        I_L = (upper == 0).unsqueeze(0).expand(nBatch, sz, sz)
        I_diag = (diag != 0).unsqueeze(0).expand(nBatch, sz, sz)
        L = LU_data.new(LU_data.size()).zero_()
        U = LU_data.new(LU_data.size()).zero_()
        L[I_diag] = 1.0  # the packed form does not store L's unit diagonal
        L[I_L] = LU_data[I_L]
        U[I_U] = LU_data[I_U]
    else:
        L = U = None

    if unpack_pivots:
        P = torch.eye(sz).type_as(LU_data).unsqueeze(0).repeat(nBatch, 1, 1)
        for i in range(nBatch):
            for j in range(sz):
                # Pivots are 1-based. int() turns the element (a 0-dim tensor
                # on newer PyTorch) into a plain Python index, so any integer
                # pivot dtype (e.g. int32) can be used for indexing.
                k = int(LU_pivots[i, j]) - 1
                if k == j:
                    continue  # swapping a column with itself is a no-op
                t = P[i, :, j].clone()
                P[i, :, j] = P[i, :, k]
                P[i, :, k] = t
    else:
        P = None

    return P, L, U
def test_trtrs(self):
    """Exercise torch.trtrs on upper/lower and plain/transposed triangular
    systems, and check that the ``out=`` form matches the returned result,
    including when the destination tensors are reused."""
    a = torch.Tensor(((6.80, -2.11, 5.66, 5.97, 8.23),
                      (-6.05, -3.30, 5.36, -4.44, 1.08),
                      (-0.45, 2.58, -2.70, 0.27, 9.04),
                      (8.32, 2.71, 4.35, -7.17, 2.14),
                      (-9.67, -5.14, -7.26, 6.08, -6.87))).t()
    b = torch.Tensor(((4.02, 6.19, -8.22, -7.57, -3.03),
                      (-1.56, 4.00, -8.67, 1.75, 2.86),
                      (9.81, -4.09, -4.57, -8.61, 8.99))).t()

    U = torch.triu(a)
    L = torch.tril(a)

    # NOTE(review): a 1e-12 residual bound presumably relies on the test
    # harness making double the default tensor type -- confirm.

    # solve Ux = b
    x = torch.trtrs(b, U)[0]
    self.assertLessEqual(b.dist(torch.mm(U, x)), 1e-12)
    x = torch.trtrs(b, U, True, False, False)[0]
    self.assertLessEqual(b.dist(torch.mm(U, x)), 1e-12)

    # solve Lx = b
    x = torch.trtrs(b, L, False)[0]
    self.assertLessEqual(b.dist(torch.mm(L, x)), 1e-12)
    x = torch.trtrs(b, L, False, False, False)[0]
    self.assertLessEqual(b.dist(torch.mm(L, x)), 1e-12)

    # solve U'x = b
    x = torch.trtrs(b, U, True, True)[0]
    self.assertLessEqual(b.dist(torch.mm(U.t(), x)), 1e-12)
    x = torch.trtrs(b, U, True, True, False)[0]
    self.assertLessEqual(b.dist(torch.mm(U.t(), x)), 1e-12)

    # solve U'x = b by manual transposition
    y = torch.trtrs(b, U.t(), False, False)[0]
    self.assertLessEqual(x.dist(y), 1e-12)

    # solve L'x = b
    x = torch.trtrs(b, L, False, True)[0]
    self.assertLessEqual(b.dist(torch.mm(L.t(), x)), 1e-12)
    x = torch.trtrs(b, L, False, True, False)[0]
    self.assertLessEqual(b.dist(torch.mm(L.t(), x)), 1e-12)

    # solve L'x = b by manual transposition
    y = torch.trtrs(b, L.t(), True, False)[0]
    self.assertLessEqual(x.dist(y), 1e-12)

    # Test reuse of caller-supplied output tensors. The legacy out-first call
    # ``torch.trtrs(tb, ta, b, a)`` is not accepted by current PyTorch, so the
    # destinations are passed via ``out=``, in the order of the returned tuple
    # (solution first).
    res1 = torch.trtrs(b, a)[0]
    ta = torch.Tensor()
    tb = torch.Tensor()
    torch.trtrs(b, a, out=(tb, ta))
    self.assertEqual(res1, tb, 0)
    # clobber the destination so the second call must rewrite it
    tb.zero_()
    torch.trtrs(b, a, out=(tb, ta))
    self.assertEqual(res1, tb, 0)
def test_trtrs(self):
    """Solve triangular systems with torch.trtrs in every upper/lower and
    transposed/plain combination, verify the residuals, and confirm that
    writing through ``out=`` (twice, into the same buffers) reproduces the
    returned solution."""
    a = torch.Tensor(((6.80, -2.11, 5.66, 5.97, 8.23),
                      (-6.05, -3.30, 5.36, -4.44, 1.08),
                      (-0.45, 2.58, -2.70, 0.27, 9.04),
                      (8.32, 2.71, 4.35, -7.17, 2.14),
                      (-9.67, -5.14, -7.26, 6.08, -6.87))).t()
    b = torch.Tensor(((4.02, 6.19, -8.22, -7.57, -3.03),
                      (-1.56, 4.00, -8.67, 1.75, 2.86),
                      (9.81, -4.09, -4.57, -8.61, 8.99))).t()

    upper = torch.triu(a)
    lower = torch.tril(a)
    tol = 1e-12

    def solve(system, residual_matrix, *flags):
        # Solve with ``system`` and the given positional flags, then check
        # that ``residual_matrix @ solution`` reproduces b.
        solution = torch.trtrs(b, system, *flags)[0]
        self.assertLessEqual(b.dist(torch.mm(residual_matrix, solution)), tol)
        return solution

    # upper-triangular system
    solve(upper, upper)
    solve(upper, upper, True, False, False)

    # lower-triangular system
    solve(lower, lower, False)
    solve(lower, lower, False, False, False)

    # transposed upper system, cross-checked against an explicit transpose
    solve(upper, upper.t(), True, True)
    via_flag = solve(upper, upper.t(), True, True, False)
    via_transpose = torch.trtrs(b, upper.t(), False, False)[0]
    self.assertLessEqual(via_flag.dist(via_transpose), tol)

    # transposed lower system, cross-checked against an explicit transpose
    solve(lower, lower.t(), False, True)
    via_flag = solve(lower, lower.t(), False, True, False)
    via_transpose = torch.trtrs(b, lower.t(), True, False)[0]
    self.assertLessEqual(via_flag.dist(via_transpose), tol)

    # out= buffers: the second pass zeroes the solution buffer first so a
    # stale value cannot satisfy the comparison
    reference = torch.trtrs(b, a)[0]
    matrix_buf = torch.Tensor()
    solution_buf = torch.Tensor()
    for attempt in range(2):
        if attempt:
            solution_buf.zero_()
        torch.trtrs(b, a, out=(solution_buf, matrix_buf))
        self.assertEqual(reference, solution_buf, 0)