The following code examples, extracted from open-source Python projects, illustrate how to use torch.kthvalue().
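As a quick orientation before the extracted tests, here is a minimal sketch of the basic call (the input values are illustrative, not taken from any of the projects below). torch.kthvalue(input, k) returns the k-th smallest element along a dimension (the last one by default) together with its index:

import torch

x = torch.tensor([3., 5., 4., 1., 1., 5.])
value, index = torch.kthvalue(x, 2)  # 2nd smallest element along the last dim
print(value)  # tensor(1.)
print(index)  # index of a 2nd-smallest element; with duplicates, which one is returned is unspecified

# per-row k-th smallest of a matrix
m = torch.rand(4, 6)
vals, idxs = torch.kthvalue(m, 3, dim=1)  # vals.shape == torch.Size([4])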
def test_kthvalue(self):
    SIZE = 50
    x = torch.rand(SIZE, SIZE, SIZE)
    x0 = x.clone()

    k = random.randint(1, SIZE)
    res1val, res1ind = torch.kthvalue(x, k)
    res2val, res2ind = torch.sort(x)
    self.assertEqual(res1val[:, :, 0], res2val[:, :, k - 1], 0)
    self.assertEqual(res1ind[:, :, 0], res2ind[:, :, k - 1], 0)

    # test use of result tensors
    # (older torch API: the result tensors are passed first; newer versions use out=)
    k = random.randint(1, SIZE)
    res1val = torch.Tensor()
    res1ind = torch.LongTensor()
    torch.kthvalue(res1val, res1ind, x, k)
    res2val, res2ind = torch.sort(x)
    self.assertEqual(res1val[:, :, 0], res2val[:, :, k - 1], 0)
    self.assertEqual(res1ind[:, :, 0], res2ind[:, :, k - 1], 0)

    # test non-default dim
    k = random.randint(1, SIZE)
    res1val, res1ind = torch.kthvalue(x, k, 0)
    res2val, res2ind = torch.sort(x, 0)
    self.assertEqual(res1val[0], res2val[k - 1], 0)
    self.assertEqual(res1ind[0], res2ind[k - 1], 0)

    # non-contiguous
    y = x.narrow(1, 0, 1)
    y0 = y.contiguous()
    k = random.randint(1, SIZE)
    res1val, res1ind = torch.kthvalue(y, k)
    res2val, res2ind = torch.kthvalue(y0, k)
    self.assertEqual(res1val, res2val, 0)
    self.assertEqual(res1ind, res2ind, 0)

    # check that the input wasn't modified
    self.assertEqual(x, x0, 0)

    # simple test case (with repetitions)
    y = torch.Tensor((3, 5, 4, 1, 1, 5))
    self.assertEqual(torch.kthvalue(y, 3)[0], torch.Tensor((3,)), 0)
    self.assertEqual(torch.kthvalue(y, 2)[0], torch.Tensor((1,)), 0)
def test_kthvalue(self):
    SIZE = 50
    x = torch.rand(SIZE, SIZE, SIZE)
    x0 = x.clone()

    k = random.randint(1, SIZE)
    res1val, res1ind = torch.kthvalue(x, k)
    res2val, res2ind = torch.sort(x)
    self.assertEqual(res1val[:, :, 0], res2val[:, :, k - 1], 0)
    self.assertEqual(res1ind[:, :, 0], res2ind[:, :, k - 1], 0)

    # test use of result tensors
    k = random.randint(1, SIZE)
    res1val = torch.Tensor()
    res1ind = torch.LongTensor()
    torch.kthvalue(x, k, out=(res1val, res1ind))
    res2val, res2ind = torch.sort(x)
    self.assertEqual(res1val[:, :, 0], res2val[:, :, k - 1], 0)
    self.assertEqual(res1ind[:, :, 0], res2ind[:, :, k - 1], 0)

    # test non-default dim
    k = random.randint(1, SIZE)
    res1val, res1ind = torch.kthvalue(x, k, 0)
    res2val, res2ind = torch.sort(x, 0)
    self.assertEqual(res1val[0], res2val[k - 1], 0)
    self.assertEqual(res1ind[0], res2ind[k - 1], 0)

    # non-contiguous
    y = x.narrow(1, 0, 1)
    y0 = y.contiguous()
    k = random.randint(1, SIZE)
    res1val, res1ind = torch.kthvalue(y, k)
    res2val, res2ind = torch.kthvalue(y0, k)
    self.assertEqual(res1val, res2val, 0)
    self.assertEqual(res1ind, res2ind, 0)

    # check that the input wasn't modified
    self.assertEqual(x, x0, 0)

    # simple test case (with repetitions)
    y = torch.Tensor((3, 5, 4, 1, 1, 5))
    self.assertEqual(torch.kthvalue(y, 3)[0], torch.Tensor((3,)), 0)
    self.assertEqual(torch.kthvalue(y, 2)[0], torch.Tensor((1,)), 0)
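Note how this version differs from the previous one only in how the preallocated result tensors are supplied: the older positional form torch.kthvalue(res1val, res1ind, x, k) became the out= keyword, which current PyTorch still accepts as a (values, indices) tuple. A minimal sketch, with shapes chosen here purely for illustration:

import torch

x = torch.rand(4, 6)
values = torch.empty(4)
indices = torch.empty(4, dtype=torch.long)
torch.kthvalue(x, 3, out=(values, indices))  # fills both tensors in place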
def test_kthvalue(self):
    SIZE = 50
    x = torch.rand(SIZE, SIZE, SIZE)
    x0 = x.clone()

    k = random.randint(1, SIZE)
    res1val, res1ind = torch.kthvalue(x, k, keepdim=False)
    res2val, res2ind = torch.sort(x)
    self.assertEqual(res1val[:, :], res2val[:, :, k - 1], 0)
    self.assertEqual(res1ind[:, :], res2ind[:, :, k - 1], 0)

    # test use of result tensors
    k = random.randint(1, SIZE)
    res1val = torch.Tensor()
    res1ind = torch.LongTensor()
    torch.kthvalue(x, k, keepdim=False, out=(res1val, res1ind))
    res2val, res2ind = torch.sort(x)
    self.assertEqual(res1val[:, :], res2val[:, :, k - 1], 0)
    self.assertEqual(res1ind[:, :], res2ind[:, :, k - 1], 0)

    # test non-default dim
    k = random.randint(1, SIZE)
    res1val, res1ind = torch.kthvalue(x, k, 0, keepdim=False)
    res2val, res2ind = torch.sort(x, 0)
    self.assertEqual(res1val, res2val[k - 1], 0)
    self.assertEqual(res1ind, res2ind[k - 1], 0)

    # non-contiguous
    y = x.narrow(1, 0, 1)
    y0 = y.contiguous()
    k = random.randint(1, SIZE)
    res1val, res1ind = torch.kthvalue(y, k)
    res2val, res2ind = torch.kthvalue(y0, k)
    self.assertEqual(res1val, res2val, 0)
    self.assertEqual(res1ind, res2ind, 0)

    # check that the input wasn't modified
    self.assertEqual(x, x0, 0)

    # simple test case (with repetitions)
    y = torch.Tensor((3, 5, 4, 1, 1, 5))
    self.assertEqual(torch.kthvalue(y, 3)[0], torch.Tensor((3,)), 0)
    self.assertEqual(torch.kthvalue(y, 2)[0], torch.Tensor((1,)), 0)
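This variant exercises keepdim=False, which removes the reduced dimension from the outputs (hence res1val[:, :] above instead of res1val[:, :, 0]). A short sketch of the shape difference, assuming current torch semantics:

import torch

x = torch.rand(3, 4, 5)
vals, _ = torch.kthvalue(x, 2, dim=1, keepdim=True)
print(vals.shape)  # torch.Size([3, 1, 5]) -- reduced dim kept with size 1
vals, _ = torch.kthvalue(x, 2, dim=1, keepdim=False)  # the default
print(vals.shape)  # torch.Size([3, 5])   -- reduced dim removed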
def test_keepdim_warning(self):
    torch.utils.backcompat.keepdim_warning.enabled = True
    x = Variable(torch.randn(3, 4), requires_grad=True)

    def run_backward(y):
        y_ = y
        if type(y) is tuple:
            y_ = y[0]
        # check that backward runs smoothly
        y_.backward(y_.data.new(y_.size()).normal_())

    def keepdim_check(f):
        with warnings.catch_warnings(record=True) as w:
            warnings.simplefilter("always")
            y = f(x, 1)
            self.assertTrue(len(w) == 1)
            self.assertTrue(issubclass(w[-1].category, UserWarning))
            self.assertTrue("keepdim" in str(w[-1].message))
            run_backward(y)
            self.assertEqual(x.size(), x.grad.size())

            # check against explicit keepdim
            y2 = f(x, 1, keepdim=False)
            self.assertEqual(y, y2)
            run_backward(y2)

            y3 = f(x, 1, keepdim=True)
            if type(y3) == tuple:
                y3 = (y3[0].squeeze(1), y3[1].squeeze(1))
            else:
                y3 = y3.squeeze(1)
            self.assertEqual(y, y3)
            run_backward(y3)

    keepdim_check(torch.sum)
    keepdim_check(torch.prod)
    keepdim_check(torch.mean)
    keepdim_check(torch.max)
    keepdim_check(torch.min)
    keepdim_check(torch.mode)
    keepdim_check(torch.median)
    keepdim_check(torch.kthvalue)
    keepdim_check(torch.var)
    keepdim_check(torch.std)

    torch.utils.backcompat.keepdim_warning.enabled = False