我们从Python开源项目中，提取了以下11个代码示例，用于说明如何使用torch.lt()。
def test_logical(self): x = torch.rand(100, 100) * 2 - 1 xx = x.clone() xgt = torch.gt(x, 1) xlt = torch.lt(x, 1) xeq = torch.eq(x, 1) xne = torch.ne(x, 1) neqs = xgt + xlt all = neqs + xeq self.assertEqual(neqs.sum(), xne.sum(), 0) self.assertEqual(x.nelement(), all.sum())
def test_comparison_ops(self):
    """Check tensor comparison operators against scalar comparisons.

    For each of ``==``, ``!=``, ``<``, ``<=``, ``>`` and ``>=``, the
    elementwise tensor result must match comparing each pair of scalars.
    """
    import operator

    x = torch.randn(5, 5)
    y = torch.randn(5, 5)
    # randn practically never produces ties; copy one row so the equality
    # cases of ==, <= and >= are actually exercised.
    y[0] = x[0]

    for op in (operator.eq, operator.ne, operator.lt,
               operator.le, operator.gt, operator.ge):
        result = op(x, y)  # dispatches to the tensor's __eq__/__lt__/...
        for idx in iter_indices(x):
            # Convert to Python scalars: indexing yields 0-d tensors in
            # current PyTorch, and assertIs on two distinct tensor objects
            # would compare identity rather than the boolean outcome.
            expected = op(float(x[idx]), float(y[idx]))
            self.assertIs(expected, bool(result[idx]), msg=op.__name__)
def lesser(x: T.FloatTensor, y: T.FloatTensor) -> T.ByteTensor:
    """
    Compare two tensors elementwise using strict less-than.

    Args:
        x: A tensor (left-hand side of the comparison).
        y: A tensor (right-hand side of the comparison).

    Returns:
        tensor (of bools): True at every position where x < y.

    NOTE(review): the return annotation says ByteTensor; depending on the
    PyTorch version, torch.lt may return a bool tensor instead - confirm.

    """
    mask = torch.lt(x, y)
    return mask
def reward(sample_solution, USE_CUDA=False):
    """
    Negated reward for the sorting task.

    The reward is the length of the longest strictly increasing run of
    consecutive elements, divided by the sequence length. Input sequences
    must all be the same length.

    Example:
        input       | output
        ====================
        [1 4 3 5 2] | [5 1 2 3 4]

        The output's longest increasing run is [1 2 3 4], a reward of 4/5 = 0.8.

    The reward lies in [1/sourceL, 1], and this function returns its
    NEGATION, i.e. a value in [-1, -1/sourceL]. The NEGATION presumably lets
    callers minimize it as a cost; confirm against callers.

    Args:
        sample_solution: list of len sourceL of [batch_size] Tensors, where
            sample_solution[i] holds the i-th element of every sequence.
        USE_CUDA: kept for backward compatibility. The working tensors are
            created on the same device as sample_solution[0], so CPU and
            GPU inputs both work regardless of this flag.

    Returns:
        [batch_size] tensor containing the negated trajectory rewards.
    """
    batch_size = sample_solution[0].size(0)
    sourceL = len(sample_solution)
    # Match the input device so the comparisons below never mix CPU and GPU.
    device = sample_solution[0].device

    longest = torch.ones(batch_size, device=device)
    current = torch.ones(batch_size, device=device)
    for prev, nxt in zip(sample_solution, sample_solution[1:]):
        increasing = torch.lt(prev, nxt)
        # Extend the current run where still increasing; otherwise restart
        # a new run of length 1.
        current = torch.where(increasing, current + 1, torch.ones_like(current))
        longest = torch.max(longest, current)  # elementwise maximum
    return -torch.div(longest, sourceL)