Python torch 模块,median() 实例源码

我们从Python开源项目中,提取了以下18个代码示例,用于说明如何使用torch.median()

项目:kdnet.pytorch    作者:fxia22    | 项目源码 | 文件源码
def split_ps(point_set):
    """Split a point set into two halves at the median of its widest axis.

    The coordinate axis with the largest extent (max - min) is selected, the
    median of the points along that axis is used as the cut value, and each
    side is padded with a point lying exactly on the cut until it holds
    ``N // 2`` points.

    Args:
        point_set: 2-D tensor indexed as ``[point, coordinate]``.

    Returns:
        A tuple ``(left_ps, right_ps, dim)``: the points strictly above the
        cut (plus padding), the points strictly below the cut (plus
        padding), and the index of the split axis as an ``int``.
    """
    # Floor division keeps the half-size an integer: it is used below as a
    # repeat count, which a float (true division on Python 3) would break.
    num_points = point_set.size()[0] // 2

    # Per-axis extent, flattened so the code does not depend on whether the
    # reduction keeps the reduced dimension.
    diff = (point_set.max(dim=0)[0] - point_set.min(dim=0)[0]).view(-1)
    dim = int(torch.max(diff, 0)[1].view(-1)[0])

    column = point_set[:, dim]
    # Flatten-and-index yields a scalar cut value whatever shape the
    # reduction returns.
    cut = torch.median(column, 0)[0].view(-1)[0]

    # view(-1) (rather than squeeze) keeps a single match one-dimensional so
    # it can still be sliced and concatenated.
    left_idx = torch.nonzero(column > cut).view(-1)
    right_idx = torch.nonzero(column < cut).view(-1)
    middle_idx = torch.nonzero(column == cut).view(-1)

    # Pad the short side(s) by repeating the first point that sits on the cut.
    missing_left = num_points - torch.numel(left_idx)
    if missing_left > 0:
        left_idx = torch.cat([left_idx, middle_idx[0:1].repeat(missing_left)], 0)
    missing_right = num_points - torch.numel(right_idx)
    if missing_right > 0:
        right_idx = torch.cat([right_idx, middle_idx[0:1].repeat(missing_right)], 0)

    left_ps = torch.index_select(point_set, dim=0, index=left_idx)
    right_ps = torch.index_select(point_set, dim=0, index=right_idx)
    return left_ps, right_ps, dim
项目:kdnet.pytorch    作者:fxia22    | 项目源码 | 文件源码
def split_ps(point_set):
    """Split a point set into two halves at the median of its widest axis.

    The coordinate axis with the largest extent (max - min) is selected, the
    median of the points along that axis is used as the cut value, and each
    side is padded with a point lying exactly on the cut until it holds
    ``N // 2`` points.

    Args:
        point_set: 2-D tensor indexed as ``[point, coordinate]``.

    Returns:
        A tuple ``(left_ps, right_ps, dim)``: the points strictly above the
        cut (plus padding), the points strictly below the cut (plus
        padding), and the index of the split axis as an ``int``.
    """
    # Floor division keeps the half-size an integer: it is used below as a
    # repeat count, which a float (true division on Python 3) would break.
    num_points = point_set.size()[0] // 2

    # Per-axis extent, flattened so the code does not depend on whether the
    # reduction keeps the reduced dimension.
    diff = (point_set.max(dim=0)[0] - point_set.min(dim=0)[0]).view(-1)
    dim = int(torch.max(diff, 0)[1].view(-1)[0])

    column = point_set[:, dim]
    # Flatten-and-index yields a scalar cut value whatever shape the
    # reduction returns.
    cut = torch.median(column, 0)[0].view(-1)[0]

    # view(-1) (rather than squeeze) keeps a single match one-dimensional so
    # it can still be sliced and concatenated.
    left_idx = torch.nonzero(column > cut).view(-1)
    right_idx = torch.nonzero(column < cut).view(-1)
    middle_idx = torch.nonzero(column == cut).view(-1)

    # Pad the short side(s) by repeating the first point that sits on the cut.
    missing_left = num_points - torch.numel(left_idx)
    if missing_left > 0:
        left_idx = torch.cat([left_idx, middle_idx[0:1].repeat(missing_left)], 0)
    missing_right = num_points - torch.numel(right_idx)
    if missing_right > 0:
        right_idx = torch.cat([right_idx, middle_idx[0:1].repeat(missing_right)], 0)

    left_ps = torch.index_select(point_set, dim=0, index=left_idx)
    right_ps = torch.index_select(point_set, dim=0, index=right_idx)
    return left_ps, right_ps, dim
项目:kdnet.pytorch    作者:fxia22    | 项目源码 | 文件源码
def split_ps(point_set):
    """Split a point set into two halves at the median of its widest axis.

    The coordinate axis with the largest extent (max - min) is selected, the
    median of the points along that axis is used as the cut value, and each
    side is padded with a point lying exactly on the cut until it holds
    ``N // 2`` points.

    Args:
        point_set: 2-D tensor indexed as ``[point, coordinate]``.

    Returns:
        A tuple ``(left_ps, right_ps, dim)``: the points strictly above the
        cut (plus padding), the points strictly below the cut (plus
        padding), and the index of the split axis as an ``int``.
    """
    # Floor division keeps the half-size an integer: it is used below as a
    # repeat count, which a float (true division on Python 3) would break.
    num_points = point_set.size()[0] // 2

    # Per-axis extent; keepdim is explicit and the result is flattened so
    # the following steps work on a plain 1-D vector.
    max_value = point_set.max(dim=0, keepdim=True)[0]
    min_value = point_set.min(dim=0, keepdim=True)[0]
    diff = (max_value - min_value).view(-1)
    dim = int(torch.max(diff, dim=0, keepdim=True)[1].view(-1)[0])

    column = point_set[:, dim]
    # The reduction axis is passed explicitly: keepdim on its own is not an
    # accepted call form of torch.median.
    cut = torch.median(column, dim=0, keepdim=True)[0].view(-1)[0]

    # view(-1) (rather than squeeze) keeps a single match one-dimensional so
    # it can still be sliced and concatenated.
    left_idx = torch.nonzero(column > cut).view(-1)
    right_idx = torch.nonzero(column < cut).view(-1)
    middle_idx = torch.nonzero(column == cut).view(-1)

    # Pad the short side(s) by repeating the first point that sits on the cut.
    missing_left = num_points - torch.numel(left_idx)
    if missing_left > 0:
        left_idx = torch.cat([left_idx, middle_idx[0:1].repeat(missing_left)], 0)
    missing_right = num_points - torch.numel(right_idx)
    if missing_right > 0:
        right_idx = torch.cat([right_idx, middle_idx[0:1].repeat(missing_right)], 0)

    left_ps = torch.index_select(point_set, dim=0, index=left_idx)
    right_ps = torch.index_select(point_set, dim=0, index=right_idx)
    return left_ps, right_ps, dim
项目:kdnet.pytorch    作者:fxia22    | 项目源码 | 文件源码
def split_ps(point_set):
    """Split a point set into two halves at the median of its widest axis.

    Only the first three coordinate axes are candidates for the split axis;
    any further columns are carried along but never chosen. The median of
    the points along the chosen axis is the cut value, and each side is
    padded with a point lying exactly on the cut until it holds ``N // 2``
    points.

    Args:
        point_set: 2-D tensor indexed as ``[point, coordinate]``.

    Returns:
        A tuple ``(left_ps, right_ps, dim)``: the points strictly above the
        cut (plus padding), the points strictly below the cut (plus
        padding), and the index of the split axis as an ``int``.
    """
    # Floor division keeps the half-size an integer: it is used below as a
    # repeat count, which a float (true division on Python 3) would break.
    num_points = point_set.size()[0] // 2

    # Per-axis extent, flattened first so that the [:3] slice always selects
    # the first three axes (on a 1xD extent it would slice rows instead).
    diff = (point_set.max(dim=0)[0] - point_set.min(dim=0)[0]).view(-1)
    diff = diff[:3]
    dim = int(torch.max(diff, 0)[1].view(-1)[0])

    column = point_set[:, dim]
    # Flatten-and-index yields a scalar cut value whatever shape the
    # reduction returns.
    cut = torch.median(column, 0)[0].view(-1)[0]

    # view(-1) (rather than squeeze) keeps a single match one-dimensional so
    # it can still be sliced and concatenated.
    left_idx = torch.nonzero(column > cut).view(-1)
    right_idx = torch.nonzero(column < cut).view(-1)
    middle_idx = torch.nonzero(column == cut).view(-1)

    # Pad the short side(s) by repeating the first point that sits on the cut.
    missing_left = num_points - torch.numel(left_idx)
    if missing_left > 0:
        left_idx = torch.cat([left_idx, middle_idx[0:1].repeat(missing_left)], 0)
    missing_right = num_points - torch.numel(right_idx)
    if missing_right > 0:
        right_idx = torch.cat([right_idx, middle_idx[0:1].repeat(missing_right)], 0)

    left_ps = torch.index_select(point_set, dim=0, index=left_idx)
    right_ps = torch.index_select(point_set, dim=0, index=right_idx)
    return left_ps, right_ps, dim
项目:kdnet.pytorch    作者:fxia22    | 项目源码 | 文件源码
def split_ps(point_set):
    """Split a point set into two halves at the median of its widest axis.

    The coordinate axis with the largest extent (max - min) is selected, the
    median of the points along that axis is used as the cut value, and each
    side is padded with a point lying exactly on the cut until it holds
    ``N // 2`` points.

    Args:
        point_set: 2-D tensor indexed as ``[point, coordinate]``.

    Returns:
        A tuple ``(left_ps, right_ps, dim)``: the points strictly above the
        cut (plus padding), the points strictly below the cut (plus
        padding), and the index of the split axis as an ``int``.
    """
    # Floor division keeps the half-size an integer: it is used below as a
    # repeat count, which a float (true division on Python 3) would break.
    num_points = point_set.size()[0] // 2

    # Per-axis extent, flattened so the code does not depend on whether the
    # reduction keeps the reduced dimension.
    diff = (point_set.max(dim=0)[0] - point_set.min(dim=0)[0]).view(-1)
    dim = int(torch.max(diff, 0)[1].view(-1)[0])

    column = point_set[:, dim]
    # Flatten-and-index yields a scalar cut value whatever shape the
    # reduction returns.
    cut = torch.median(column, 0)[0].view(-1)[0]

    # view(-1) (rather than squeeze) keeps a single match one-dimensional so
    # it can still be sliced and concatenated.
    left_idx = torch.nonzero(column > cut).view(-1)
    right_idx = torch.nonzero(column < cut).view(-1)
    middle_idx = torch.nonzero(column == cut).view(-1)

    # Pad the short side(s) by repeating the first point that sits on the cut.
    missing_left = num_points - torch.numel(left_idx)
    if missing_left > 0:
        left_idx = torch.cat([left_idx, middle_idx[0:1].repeat(missing_left)], 0)
    missing_right = num_points - torch.numel(right_idx)
    if missing_right > 0:
        right_idx = torch.cat([right_idx, middle_idx[0:1].repeat(missing_right)], 0)

    left_ps = torch.index_select(point_set, dim=0, index=left_idx)
    right_ps = torch.index_select(point_set, dim=0, index=right_idx)
    return left_ps, right_ps, dim
项目:pytorch-dist    作者:apaszke    | 项目源码 | 文件源码
def test_median(self):
        """Compare torch.median with the middle element of torch.sort.

        Runs for an odd and an even size, once without a dim argument
        (checked against the row-wise sort) and once along dimension 0, and
        verifies that the input tensor is left untouched.
        """
        for size in (155, 156):
            x = torch.rand(size, size)
            x0 = x.clone()

            res1val, res1ind = torch.median(x)
            res2val, res2ind = torch.sort(x)
            # Position of the median within a sorted slice of length `size`;
            # integer arithmetic gives the same result on Python 2 and 3.
            ind = (size + 1) // 2 - 1

            # Both the median value and the index it was taken from must
            # match the middle column of the sorted result.
            self.assertEqual(res2val.select(1, ind), res1val.select(1, 0), 0)
            self.assertEqual(res2ind.select(1, ind), res1ind.select(1, 0), 0)

            # Test use of result tensor
            res2val = torch.Tensor()
            res2ind = torch.LongTensor()
            torch.median(res2val, res2ind, x)
            self.assertEqual(res2val, res1val, 0)
            self.assertEqual(res2ind, res1ind, 0)

            # Test non-default dim
            res1val, res1ind = torch.median(x, 0)
            res2val, res2ind = torch.sort(x, 0)
            self.assertEqual(res1val[0], res2val[ind], 0)
            self.assertEqual(res1ind[0], res2ind[ind], 0)

            # input unchanged
            self.assertEqual(x, x0, 0)
项目:pytorch    作者:tylergenter    | 项目源码 | 文件源码
def test_median(self):
        """Compare torch.median with the middle element of torch.sort.

        Runs for an odd and an even size, once without a dim argument
        (checked against the row-wise sort) and once along dimension 0, and
        verifies that the input tensor is left untouched.
        """
        for size in (155, 156):
            x = torch.rand(size, size)
            x0 = x.clone()

            res1val, res1ind = torch.median(x)
            res2val, res2ind = torch.sort(x)
            # Position of the median within a sorted slice of length `size`;
            # integer arithmetic gives the same result on Python 2 and 3.
            ind = (size + 1) // 2 - 1

            # Both the median value and the index it was taken from must
            # match the middle column of the sorted result.
            self.assertEqual(res2val.select(1, ind), res1val.select(1, 0), 0)
            self.assertEqual(res2ind.select(1, ind), res1ind.select(1, 0), 0)

            # Test use of result tensor
            res2val = torch.Tensor()
            res2ind = torch.LongTensor()
            torch.median(x, out=(res2val, res2ind))
            self.assertEqual(res2val, res1val, 0)
            self.assertEqual(res2ind, res1ind, 0)

            # Test non-default dim
            res1val, res1ind = torch.median(x, 0)
            res2val, res2ind = torch.sort(x, 0)
            self.assertEqual(res1val[0], res2val[ind], 0)
            self.assertEqual(res1ind[0], res2ind[ind], 0)

            # input unchanged
            self.assertEqual(x, x0, 0)
项目:pytorch-coriander    作者:hughperkins    | 项目源码 | 文件源码
def test_dim_reduction(self):
        """Check the ``keepdim`` contract of the dimension-reducing functions.

        For every listed reduction, the result with ``keepdim=False`` must
        equal the default result once the reduced dimension is re-inserted,
        and the number of dimensions must drop by one only when
        ``keepdim=False``. A one-element 1-D tensor is checked separately.
        """
        dim_red_fns = [
            "mean", "median", "mode", "norm", "prod",
            "std", "sum", "var", "max", "min"]

        def normfn_attr(t, dim, keepdim=True):
            # torch.norm takes the norm order (2) before the dimension.
            attr = getattr(torch, "norm")
            return attr(t, 2, dim, keepdim)

        for fn_name in dim_red_fns:
            x = torch.randn(3, 4, 5)
            fn_attr = getattr(torch, fn_name) if fn_name != "norm" else normfn_attr

            def fn(t, dim, keepdim=True):
                # Reduce the tensor that was passed in, not whatever `x`
                # happens to be bound to in the enclosing scope.
                ans = fn_attr(t, dim, keepdim)
                # Some reductions return a tuple; only its first element is
                # compared.
                return ans if not isinstance(ans, tuple) else ans[0]

            dim = random.randint(0, 2)
            self.assertEqual(fn(x, dim, False).unsqueeze(dim), fn(x, dim))
            self.assertEqual(x.ndimension() - 1, fn(x, dim, False).ndimension())
            self.assertEqual(x.ndimension(), fn(x, dim, True).ndimension())

            # check 1-d behavior
            x = torch.randn(1)
            dim = 0
            self.assertEqual(fn(x, dim), fn(x, dim, True))
            self.assertEqual(x.ndimension(), fn(x, dim).ndimension())
            self.assertEqual(x.ndimension(), fn(x, dim, True).ndimension())
项目:pytorch-coriander    作者:hughperkins    | 项目源码 | 文件源码
def test_median(self):
        """Compare torch.median with the middle element of torch.sort.

        Runs for an odd and an even size, once without a dim argument
        (checked against the row-wise sort) and once along dimension 0, and
        verifies that the input tensor is left untouched.
        """
        for size in (155, 156):
            x = torch.rand(size, size)
            x0 = x.clone()

            res1val, res1ind = torch.median(x, keepdim=False)
            res2val, res2ind = torch.sort(x)
            # Position of the median within a sorted slice of length `size`;
            # integer arithmetic gives the same result on Python 2 and 3.
            ind = (size + 1) // 2 - 1

            # Both the median value and the index it was taken from must
            # match the middle column of the sorted result.
            self.assertEqual(res2val.select(1, ind), res1val, 0)
            self.assertEqual(res2ind.select(1, ind), res1ind, 0)

            # Test use of result tensor
            res2val = torch.Tensor()
            res2ind = torch.LongTensor()
            torch.median(x, keepdim=False, out=(res2val, res2ind))
            self.assertEqual(res2val, res1val, 0)
            self.assertEqual(res2ind, res1ind, 0)

            # Test non-default dim
            res1val, res1ind = torch.median(x, 0, keepdim=False)
            res2val, res2ind = torch.sort(x, 0)
            self.assertEqual(res1val, res2val[ind], 0)
            self.assertEqual(res1ind, res2ind[ind], 0)

            # input unchanged
            self.assertEqual(x, x0, 0)
项目:pytorch    作者:ezyang    | 项目源码 | 文件源码
def _test_dim_reduction(self, cast):
        """Exercise the ``keepdim`` contract of the dimension-reducing functions.

        For each listed reduction, the default result must equal the
        ``keepdim=True`` result once the reduced dimension is re-inserted.
        ``cast`` is applied to every input tensor before it is reduced.
        """
        reduction_names = [
            "mean", "median", "mode", "norm", "prod",
            "std", "sum", "var", "max", "min"]

        def l2_norm(t, dim, keepdim=False):
            # torch.norm takes the norm order (2) before the dimension.
            return torch.norm(t, 2, dim, keepdim)

        for name in reduction_names:
            reducer = l2_norm if name == "norm" else getattr(torch, name)

            def fn(x, dim, keepdim=False):
                out = reducer(x, dim, keepdim=keepdim)
                # Some reductions return a tuple; only its first element is
                # compared.
                if isinstance(out, tuple):
                    return out[0]
                return out

            def check_multidim(x, dim):
                reduced = fn(x, dim)
                kept = fn(x, dim, keepdim=True)
                self.assertEqual(reduced.unsqueeze(dim), kept)
                self.assertEqual(x.ndimension() - 1, reduced.ndimension())
                self.assertEqual(x.ndimension(), kept.ndimension())

            # general case
            sample = cast(torch.randn(3, 4, 5))
            check_multidim(sample, random.randint(0, 2))

            # 1-d behavior: the result keeps one dimension either way
            sample = cast(torch.randn(1))
            self.assertEqual(fn(sample, 0), fn(sample, 0, keepdim=True))
            self.assertEqual(sample.ndimension(), fn(sample, 0).ndimension())
            self.assertEqual(sample.ndimension(), fn(sample, 0, keepdim=True).ndimension())

            # reducing a dimension whose size is one
            singleton_dim = random.randint(0, 2)
            shape = [1 if axis == singleton_dim else extent
                     for axis, extent in enumerate((3, 4, 5))]
            sample = cast(torch.randn(shape))
            check_multidim(sample, singleton_dim)
项目:pytorch    作者:ezyang    | 项目源码 | 文件源码
def test_median(self):
        """Compare torch.median with the middle element of torch.sort.

        Runs for an odd and an even size: over the whole flattened tensor,
        along dimension 1, and along dimension 0. Also verifies that the
        input tensor is left untouched.
        """
        for size in (155, 156):
            x = torch.rand(size, size)
            x0 = x.clone()

            # Median over all elements versus the middle of the flattened sort.
            nelem = x.nelement()
            res1val = torch.median(x)
            res2val, _ = torch.sort(x.view(nelem))
            ind = (nelem + 1) // 2 - 1

            self.assertEqual(res2val[ind], res1val, 0)

            res1val, res1ind = torch.median(x, dim=1, keepdim=False)
            res2val, res2ind = torch.sort(x)
            # Position of the median within a sorted slice of length `size`;
            # integer arithmetic gives the same result on Python 2 and 3.
            ind = (size + 1) // 2 - 1

            # Both the median value and the index it was taken from must
            # match the middle column of the sorted result.
            self.assertEqual(res2val.select(1, ind), res1val, 0)
            self.assertEqual(res2ind.select(1, ind), res1ind, 0)

            # Test use of result tensor
            res2val = torch.Tensor()
            res2ind = torch.LongTensor()
            torch.median(x, keepdim=False, out=(res2val, res2ind))
            self.assertEqual(res2val, res1val, 0)
            self.assertEqual(res2ind, res1ind, 0)

            # Test non-default dim
            res1val, res1ind = torch.median(x, 0, keepdim=False)
            res2val, res2ind = torch.sort(x, 0)
            self.assertEqual(res1val, res2val[ind], 0)
            self.assertEqual(res1ind, res2ind[ind], 0)

            # input unchanged
            self.assertEqual(x, x0, 0)
项目:pytorch    作者:pytorch    | 项目源码 | 文件源码
def _test_dim_reduction(self, cast):
        """Exercise the ``keepdim`` contract of the dimension-reducing functions.

        For each listed reduction, the default result must equal the
        ``keepdim=True`` result once the reduced dimension is re-inserted.
        ``cast`` is applied to every input tensor before it is reduced.
        """
        reduction_names = [
            "mean", "median", "mode", "norm", "prod",
            "std", "sum", "var", "max", "min"]

        def l2_norm(t, dim, keepdim=False):
            # torch.norm takes the norm order (2) before the dimension.
            return torch.norm(t, 2, dim, keepdim)

        for name in reduction_names:
            reducer = l2_norm if name == "norm" else getattr(torch, name)

            def fn(x, dim, keepdim=False):
                out = reducer(x, dim, keepdim=keepdim)
                # Some reductions return a tuple; only its first element is
                # compared.
                if isinstance(out, tuple):
                    return out[0]
                return out

            def check_multidim(x, dim):
                reduced = fn(x, dim)
                kept = fn(x, dim, keepdim=True)
                self.assertEqual(reduced.unsqueeze(dim), kept)
                self.assertEqual(x.ndimension() - 1, reduced.ndimension())
                self.assertEqual(x.ndimension(), kept.ndimension())

            # general case
            sample = cast(torch.randn(3, 4, 5))
            check_multidim(sample, random.randint(0, 2))

            # 1-d behavior: the result keeps one dimension either way
            sample = cast(torch.randn(1))
            self.assertEqual(fn(sample, 0), fn(sample, 0, keepdim=True))
            self.assertEqual(sample.ndimension(), fn(sample, 0).ndimension())
            self.assertEqual(sample.ndimension(), fn(sample, 0, keepdim=True).ndimension())

            # reducing a dimension whose size is one
            singleton_dim = random.randint(0, 2)
            shape = [1 if axis == singleton_dim else extent
                     for axis, extent in enumerate((3, 4, 5))]
            sample = cast(torch.randn(shape))
            check_multidim(sample, singleton_dim)
项目:pytorch    作者:pytorch    | 项目源码 | 文件源码
def test_median(self):
        """Compare torch.median with the middle element of torch.sort.

        Runs for an odd and an even size: over the whole flattened tensor,
        along dimension 1, and along dimension 0. Also verifies that the
        input tensor is left untouched.
        """
        for size in (155, 156):
            x = torch.rand(size, size)
            x0 = x.clone()

            # Median over all elements versus the middle of the flattened sort.
            nelem = x.nelement()
            res1val = torch.median(x)
            res2val, _ = torch.sort(x.view(nelem))
            ind = (nelem + 1) // 2 - 1

            self.assertEqual(res2val[ind], res1val, 0)

            res1val, res1ind = torch.median(x, dim=1, keepdim=False)
            res2val, res2ind = torch.sort(x)
            # Position of the median within a sorted slice of length `size`;
            # integer arithmetic gives the same result on Python 2 and 3.
            ind = (size + 1) // 2 - 1

            # Both the median value and the index it was taken from must
            # match the middle column of the sorted result.
            self.assertEqual(res2val.select(1, ind), res1val, 0)
            self.assertEqual(res2ind.select(1, ind), res1ind, 0)

            # Test use of result tensor
            res2val = torch.Tensor()
            res2ind = torch.LongTensor()
            torch.median(x, keepdim=False, out=(res2val, res2ind))
            self.assertEqual(res2val, res1val, 0)
            self.assertEqual(res2ind, res1ind, 0)

            # Test non-default dim
            res1val, res1ind = torch.median(x, 0, keepdim=False)
            res2val, res2ind = torch.sort(x, 0)
            self.assertEqual(res1val, res2val[ind], 0)
            self.assertEqual(res1ind, res2ind[ind], 0)

            # input unchanged
            self.assertEqual(x, x0, 0)
项目:kdnet.pytorch    作者:fxia22    | 项目源码 | 文件源码
def split_ps_reuse(point_set, level, pos, tree, cutdim):
    """Split a point set at the median of its widest axis and store the halves.

    The two halves are written to ``tree[level+1][pos*2]`` and
    ``tree[level+1][pos*2 + 1]``; the index of the split axis is written to
    both ``cutdim[level][pos*2]`` and ``cutdim[level][pos*2 + 1]``.

    Args:
        point_set: 2-D tensor indexed as ``[point, coordinate]``.
        level: index of the current level in ``tree`` and ``cutdim``.
        pos: position of this point set within its level.
        tree: nested indexable container receiving the two halves.
        cutdim: nested indexable container receiving the split axis.

    Returns:
        None. Results are stored in ``tree`` and ``cutdim``.
    """
    # Floor division keeps the half-size an integer: it is used below as a
    # repeat count, which a float (true division on Python 3) would break.
    num_points = point_set.size()[0] // 2

    max_value = point_set.max(dim=0)[0]
    # Minimum expressed as the negated maximum of the negated points.
    min_value = -(-point_set).max(dim=0)[0]

    # Per-axis extent, flattened so the code does not depend on whether the
    # reduction keeps the reduced dimension.
    diff = (max_value - min_value).view(-1)
    dim = int(torch.max(diff, 0)[1].view(-1)[0])

    column = point_set[:, dim]
    # Flatten-and-index yields a scalar cut value whatever shape the
    # reduction returns.
    cut = torch.median(column, 0)[0].view(-1)[0]

    # view(-1) (rather than squeeze) keeps a single match one-dimensional so
    # it can still be sliced and concatenated.
    left_idx = torch.nonzero(column > cut).view(-1)
    right_idx = torch.nonzero(column < cut).view(-1)
    middle_idx = torch.nonzero(column == cut).view(-1)

    # Pad the short side(s) by repeating the first point that sits on the cut.
    missing_left = num_points - torch.numel(left_idx)
    if missing_left > 0:
        left_idx = torch.cat([left_idx, middle_idx[0:1].repeat(missing_left)], 0)
    missing_right = num_points - torch.numel(right_idx)
    if missing_right > 0:
        right_idx = torch.cat([right_idx, middle_idx[0:1].repeat(missing_right)], 0)

    left_ps = torch.index_select(point_set, dim=0, index=left_idx)
    right_ps = torch.index_select(point_set, dim=0, index=right_idx)

    tree[level+1][pos * 2] = left_ps
    tree[level+1][pos * 2 + 1] = right_ps
    cutdim[level][pos * 2] = dim
    cutdim[level][pos * 2 + 1] = dim

    return
项目:kdnet.pytorch    作者:fxia22    | 项目源码 | 文件源码
def split_ps_reuse(point_set, level, pos, tree, cutdim):
    """Split a point set at the median of its widest axis and store the halves.

    The two halves are written to ``tree[level+1][pos*2]`` and
    ``tree[level+1][pos*2 + 1]``; the index of the split axis is written to
    both ``cutdim[level][pos*2]`` and ``cutdim[level][pos*2 + 1]``.

    Args:
        point_set: 2-D tensor indexed as ``[point, coordinate]``.
        level: index of the current level in ``tree`` and ``cutdim``.
        pos: position of this point set within its level.
        tree: nested indexable container receiving the two halves.
        cutdim: nested indexable container receiving the split axis.

    Returns:
        None. Results are stored in ``tree`` and ``cutdim``.
    """
    # Floor division keeps the half-size an integer: it is used below as a
    # repeat count, which a float (true division on Python 3) would break.
    num_points = point_set.size()[0] // 2

    max_value = point_set.max(dim=0)[0]
    # Minimum expressed as the negated maximum of the negated points.
    min_value = -(-point_set).max(dim=0)[0]

    # Per-axis extent, flattened so the code does not depend on whether the
    # reduction keeps the reduced dimension.
    diff = (max_value - min_value).view(-1)
    dim = int(torch.max(diff, 0)[1].view(-1)[0])

    column = point_set[:, dim]
    # Flatten-and-index yields a scalar cut value whatever shape the
    # reduction returns.
    cut = torch.median(column, 0)[0].view(-1)[0]

    # view(-1) (rather than squeeze) keeps a single match one-dimensional so
    # it can still be sliced and concatenated.
    left_idx = torch.nonzero(column > cut).view(-1)
    right_idx = torch.nonzero(column < cut).view(-1)
    middle_idx = torch.nonzero(column == cut).view(-1)

    # Pad the short side(s) by repeating the first point that sits on the cut.
    missing_left = num_points - torch.numel(left_idx)
    if missing_left > 0:
        left_idx = torch.cat([left_idx, middle_idx[0:1].repeat(missing_left)], 0)
    missing_right = num_points - torch.numel(right_idx)
    if missing_right > 0:
        right_idx = torch.cat([right_idx, middle_idx[0:1].repeat(missing_right)], 0)

    left_ps = torch.index_select(point_set, dim=0, index=left_idx)
    right_ps = torch.index_select(point_set, dim=0, index=right_idx)

    tree[level+1][pos * 2] = left_ps
    tree[level+1][pos * 2 + 1] = right_ps
    cutdim[level][pos * 2] = dim
    cutdim[level][pos * 2 + 1] = dim

    return
项目:kdnet.pytorch    作者:fxia22    | 项目源码 | 文件源码
def split_ps_reuse(point_set, level, pos, tree, cutdim):
    """Split a point set at the median of its widest axis and store the halves.

    Only the first three coordinate axes are candidates for the split axis;
    any further columns are carried along but never chosen. The two halves
    are written to ``tree[level+1][pos*2]`` and ``tree[level+1][pos*2 + 1]``;
    the index of the split axis is written to both ``cutdim[level][pos*2]``
    and ``cutdim[level][pos*2 + 1]``.

    Args:
        point_set: 2-D tensor indexed as ``[point, coordinate]``.
        level: index of the current level in ``tree`` and ``cutdim``.
        pos: position of this point set within its level.
        tree: nested indexable container receiving the two halves.
        cutdim: nested indexable container receiving the split axis.

    Returns:
        None. Results are stored in ``tree`` and ``cutdim``.
    """
    # Floor division keeps the half-size an integer: it is used below as a
    # repeat count, which a float (true division on Python 3) would break.
    num_points = point_set.size()[0] // 2

    max_value = point_set.max(dim=0)[0]
    # Minimum expressed as the negated maximum of the negated points.
    min_value = -(-point_set).max(dim=0)[0]

    # Per-axis extent, flattened so that the slice below selects the first
    # three axes whether or not the reduction keeps the reduced dimension.
    diff = (max_value - min_value).view(-1)
    diff = diff[:3]
    dim = int(torch.max(diff, 0)[1].view(-1)[0])

    column = point_set[:, dim]
    # Flatten-and-index yields a scalar cut value whatever shape the
    # reduction returns.
    cut = torch.median(column, 0)[0].view(-1)[0]

    # view(-1) (rather than squeeze) keeps a single match one-dimensional so
    # it can still be sliced and concatenated.
    left_idx = torch.nonzero(column > cut).view(-1)
    right_idx = torch.nonzero(column < cut).view(-1)
    middle_idx = torch.nonzero(column == cut).view(-1)

    # Pad the short side(s) by repeating the first point that sits on the cut.
    missing_left = num_points - torch.numel(left_idx)
    if missing_left > 0:
        left_idx = torch.cat([left_idx, middle_idx[0:1].repeat(missing_left)], 0)
    missing_right = num_points - torch.numel(right_idx)
    if missing_right > 0:
        right_idx = torch.cat([right_idx, middle_idx[0:1].repeat(missing_right)], 0)

    left_ps = torch.index_select(point_set, dim=0, index=left_idx)
    right_ps = torch.index_select(point_set, dim=0, index=right_idx)

    tree[level+1][pos * 2] = left_ps
    tree[level+1][pos * 2 + 1] = right_ps
    cutdim[level][pos * 2] = dim
    cutdim[level][pos * 2 + 1] = dim

    return
项目:kdnet.pytorch    作者:fxia22    | 项目源码 | 文件源码
def split_ps_reuse(point_set, level, pos, tree, cutdim):
    """Split a point set at the median of its widest axis and store the halves.

    The two halves are written to ``tree[level+1][pos*2]`` and
    ``tree[level+1][pos*2 + 1]``; the index of the split axis is written to
    both ``cutdim[level][pos*2]`` and ``cutdim[level][pos*2 + 1]``.

    Args:
        point_set: 2-D tensor indexed as ``[point, coordinate]``.
        level: index of the current level in ``tree`` and ``cutdim``.
        pos: position of this point set within its level.
        tree: nested indexable container receiving the two halves.
        cutdim: nested indexable container receiving the split axis.

    Returns:
        None. Results are stored in ``tree`` and ``cutdim``.
    """
    # Floor division keeps the half-size an integer: it is used below as a
    # repeat count, which a float (true division on Python 3) would break.
    num_points = point_set.size()[0] // 2

    max_value = point_set.max(dim=0)[0]
    # Minimum expressed as the negated maximum of the negated points.
    min_value = -(-point_set).max(dim=0)[0]

    # Per-axis extent, flattened so the code does not depend on whether the
    # reduction keeps the reduced dimension.
    diff = (max_value - min_value).view(-1)
    dim = int(torch.max(diff, 0)[1].view(-1)[0])

    column = point_set[:, dim]
    # Flatten-and-index yields a scalar cut value whatever shape the
    # reduction returns.
    cut = torch.median(column, 0)[0].view(-1)[0]

    # view(-1) (rather than squeeze) keeps a single match one-dimensional so
    # it can still be sliced and concatenated.
    left_idx = torch.nonzero(column > cut).view(-1)
    right_idx = torch.nonzero(column < cut).view(-1)
    middle_idx = torch.nonzero(column == cut).view(-1)

    # Pad the short side(s) by repeating the first point that sits on the cut.
    missing_left = num_points - torch.numel(left_idx)
    if missing_left > 0:
        left_idx = torch.cat([left_idx, middle_idx[0:1].repeat(missing_left)], 0)
    missing_right = num_points - torch.numel(right_idx)
    if missing_right > 0:
        right_idx = torch.cat([right_idx, middle_idx[0:1].repeat(missing_right)], 0)

    left_ps = torch.index_select(point_set, dim=0, index=left_idx)
    right_ps = torch.index_select(point_set, dim=0, index=right_idx)

    tree[level+1][pos * 2] = left_ps
    tree[level+1][pos * 2 + 1] = right_ps
    cutdim[level][pos * 2] = dim
    cutdim[level][pos * 2 + 1] = dim

    return
项目:pytorch    作者:ezyang    | 项目源码 | 文件源码
def test_keepdim_warning(self):
        """Check that reductions called without ``keepdim`` emit one warning.

        With the backward-compatibility keepdim warning enabled, each listed
        reduction must emit exactly one ``UserWarning`` mentioning "keepdim"
        when called without that argument, and its result must match both
        the explicit ``keepdim=False`` result and the squeezed
        ``keepdim=True`` result. Backward is run on every result.
        """
        torch.utils.backcompat.keepdim_warning.enabled = True
        # The flag is global: restore it even when an assertion fails so the
        # failure cannot leak warnings into later tests.
        try:
            x = Variable(torch.randn(3, 4), requires_grad=True)

            def run_backward(y):
                # Tuple results carry the values first; backprop through those.
                y_ = y[0] if isinstance(y, tuple) else y
                # check that backward runs smooth
                y_.backward(y_.data.new(y_.size()).normal_())

            def keepdim_check(f):
                with warnings.catch_warnings(record=True) as w:
                    warnings.simplefilter("always")
                    y = f(x, 1)
                    self.assertTrue(len(w) == 1)
                    self.assertTrue(issubclass(w[-1].category, UserWarning))
                    self.assertTrue("keepdim" in str(w[-1].message))
                    run_backward(y)
                    self.assertEqual(x.size(), x.grad.size())

                    # check against explicit keepdim
                    y2 = f(x, 1, keepdim=False)
                    self.assertEqual(y, y2)
                    run_backward(y2)

                    y3 = f(x, 1, keepdim=True)
                    if isinstance(y3, tuple):
                        y3 = (y3[0].squeeze(1), y3[1].squeeze(1))
                    else:
                        y3 = y3.squeeze(1)
                    self.assertEqual(y, y3)
                    run_backward(y3)

            keepdim_check(torch.sum)
            keepdim_check(torch.prod)
            keepdim_check(torch.mean)
            keepdim_check(torch.max)
            keepdim_check(torch.min)
            keepdim_check(torch.mode)
            keepdim_check(torch.median)
            keepdim_check(torch.kthvalue)
            keepdim_check(torch.var)
            keepdim_check(torch.std)
        finally:
            torch.utils.backcompat.keepdim_warning.enabled = False