Python torch 模块,std() 实例源码

我们从 Python 开源项目中提取了以下 8 个代码示例，用于说明如何使用 torch.std()。

项目:ktorch    作者:farizrahman4u    | 项目源码 | 文件源码
def std(x, axis=None, keepdims=False):
    """Build and apply a standard-deviation op to ``x``.

    The computation (``_std``) and its shape rule (``_compute_output_shape``)
    are handed to ``get_op`` together with ``[axis, keepdims]`` as arguments;
    the resulting op is then called on ``x``.

    Args:
        x: Input passed to the op returned by ``get_op``.
        axis: Axis to reduce over. ``None`` reduces over every element and
            yields a result whose declared shape is ``()``.
        keepdims: When true (and ``axis`` is not ``None``), the reduced axis
            is kept with length 1; otherwise it is removed.

    Returns:
        Whatever the op produced by ``get_op`` returns when applied to ``x``.
    """
    def _std(x, axis, keepdims):
        if axis is None:
            # Full reduction: there is no single axis to pass to torch.std or
            # torch.squeeze, and the shape rule below declares () for this
            # case regardless of keepdims.
            return torch.std(x)
        y = torch.std(x, axis)
        # Since keepdims argument of torch not functional
        return y if keepdims else torch.squeeze(y, axis)

    def _compute_output_shape(x, axis, keepdims):
        if axis is None:
            return ()

        shape = list(_get_shape(x))
        if keepdims:
            # Reduced axis survives with length 1.
            shape[axis] = 1
        else:
            # Reduced axis disappears from the shape.
            del shape[axis]

        return tuple(shape)

    return get_op(_std, output_shape=_compute_output_shape, arguments=[axis, keepdims])(x)
项目:torch_light    作者:ne7ermore    | 项目源码 | 文件源码
def forward(self, input):
        """Normalize ``input`` along its last dimension, then scale and shift.

        The per-row standard deviation is floored at ``self.eps`` before the
        division, and ``self.weight`` / ``self.bias`` are expanded to the
        shape of the normalized tensor.
        """
        centered = input - torch.mean(input, dim=-1, keepdim=True)
        spread = torch.std(input, dim=-1, keepdim=True)
        # Floor the deviation so constant rows do not divide by zero.
        spread = spread.clamp(min=self.eps)
        normalized = centered / spread
        gain = self.weight.expand_as(normalized)
        shift = self.bias.expand_as(normalized)
        return normalized * gain + shift
项目:torch_light    作者:ne7ermore    | 项目源码 | 文件源码
def forward(self, input):
        """Apply last-dimension normalization followed by an affine transform.

        Mean and standard deviation are taken over ``dim=-1`` with the reduced
        dimension kept; the deviation is clamped from below by ``self.eps``.
        """
        reduce_opts = dict(dim=-1, keepdim=True)
        mean = torch.mean(input, **reduce_opts)
        # Clamp instead of adding eps: the deviation never drops below eps.
        deviation = torch.clamp(torch.std(input, **reduce_opts), min=self.eps)
        scaled = (input - mean) / deviation
        return scaled * self.weight.expand_as(scaled) + self.bias.expand_as(scaled)
项目:pyro    作者:uber    | 项目源码 | 文件源码
def test_importance_guide(self):
        """Check the sampled mean and stddev when importance sampling uses ``self.guide``.

        Draws 1000 samples from the marginal of an ``Importance`` posterior
        (10000 samples) and compares their mean / standard deviation against
        ``self.mu_mean`` / ``self.mu_stddev`` within the given precisions.
        """
        importance = pyro.infer.Importance(self.model, guide=self.guide, num_samples=10000)
        sampler = pyro.infer.Marginal(importance)
        draws = [sampler() for _ in range(1000)]
        sample_mean = torch.mean(torch.cat(draws))
        sample_stddev = torch.std(torch.cat(draws), 0)
        mean_gap = torch.norm(sample_mean - self.mu_mean).data[0]
        self.assertEqual(0, mean_gap, prec=0.01)
        stddev_gap = torch.norm(sample_stddev - self.mu_stddev).data[0]
        self.assertEqual(0, stddev_gap, prec=0.1)
项目:pyro    作者:uber    | 项目源码 | 文件源码
def test_importance_prior(self):
        """Check the sampled mean and stddev when importance sampling has no guide.

        Same procedure as the guided variant, but ``guide=None`` is passed to
        ``Importance``; results are compared against ``self.mu_mean`` and
        ``self.mu_stddev`` within the given precisions.
        """
        importance = pyro.infer.Importance(self.model, guide=None, num_samples=10000)
        sampler = pyro.infer.Marginal(importance)
        draws = [sampler() for _ in range(1000)]
        sample_mean = torch.mean(torch.cat(draws))
        sample_stddev = torch.std(torch.cat(draws), 0)
        mean_gap = torch.norm(sample_mean - self.mu_mean).data[0]
        self.assertEqual(0, mean_gap, prec=0.01)
        stddev_gap = torch.norm(sample_stddev - self.mu_stddev).data[0]
        self.assertEqual(0, stddev_gap, prec=0.1)
项目:attention-is-all-you-need-pytorch    作者:jadore801120    | 项目源码 | 文件源码
def forward(self, z):
        """Normalize ``z`` over its last dimension and apply ``a_2`` / ``b_2``.

        Inputs whose dimension 1 has size 1 are returned untouched. Otherwise
        ``self.eps`` is added to the standard deviation before dividing, and
        the result is scaled by ``self.a_2`` and shifted by ``self.b_2``.
        """
        # Nothing to normalize when dimension 1 holds a single entry.
        if z.size(1) == 1:
            return z

        mean = torch.mean(z, keepdim=True, dim=-1)
        deviation = torch.std(z, keepdim=True, dim=-1)

        centered = z - mean.expand_as(z)
        denominator = deviation.expand_as(z) + self.eps
        normalized = centered / denominator

        gain = self.a_2.expand_as(normalized)
        offset = self.b_2.expand_as(normalized)
        return normalized * gain + offset
项目:OpenNMT-py    作者:OpenNMT    | 项目源码 | 文件源码
def forward(self, z):
        """Normalize ``z`` over dimension 1 and apply ``a_2`` / ``b_2``.

        Inputs whose dimension 1 has size 1 are returned untouched. Otherwise
        the mean and standard deviation over ``dim=1`` are used to normalize
        ``z`` (``self.eps`` is added to the deviation), and the result is
        scaled by ``self.a_2`` and shifted by ``self.b_2``.
        """
        if z.size(1) == 1:
            return z
        mu = torch.mean(z, dim=1)
        sigma = torch.std(z, dim=1)
        # HACK. PyTorch is changing behavior: depending on the version the
        # reduced dimension is either kept or dropped. Compare against z's
        # rank (not a hard-coded 1) so the dimension is restored for inputs
        # of any rank, and left alone when the reduction already kept it.
        if mu.dim() < z.dim():
            mu = mu.unsqueeze(1)
            sigma = sigma.unsqueeze(1)
        ln_out = (z - mu.expand_as(z)) / (sigma.expand_as(z) + self.eps)
        ln_out = ln_out.mul(self.a_2.expand_as(ln_out)) \
            + self.b_2.expand_as(ln_out)
        return ln_out
项目:pytorch    作者:ezyang    | 项目源码 | 文件源码
def test_keepdim_warning(self):
        """Verify that reductions called without ``keepdim`` emit one warning.

        With the backcompat ``keepdim_warning`` flag enabled, each reduction
        is called as ``f(x, 1)`` and must raise exactly one ``UserWarning``
        mentioning "keepdim". The result is then compared against the
        explicit ``keepdim=False`` call and the squeezed ``keepdim=True``
        call, and backward is run on each result.
        """
        torch.utils.backcompat.keepdim_warning.enabled = True
        # try/finally: a failing assertion must not leave the global warning
        # flag switched on for whatever runs after this test.
        try:
            x = Variable(torch.randn(3, 4), requires_grad=True)

            def run_backward(y):
                # Reductions that return a tuple are differentiated through
                # their first element.
                y_ = y[0] if isinstance(y, tuple) else y
                # check that backward runs smooth
                y_.backward(y_.data.new(y_.size()).normal_())

            def keepdim_check(f):
                with warnings.catch_warnings(record=True) as w:
                    warnings.simplefilter("always")
                    y = f(x, 1)
                    self.assertTrue(len(w) == 1)
                    self.assertTrue(issubclass(w[-1].category, UserWarning))
                    self.assertTrue("keepdim" in str(w[-1].message))
                    run_backward(y)
                    self.assertEqual(x.size(), x.grad.size())

                    # check against explicit keepdim
                    y2 = f(x, 1, keepdim=False)
                    self.assertEqual(y, y2)
                    run_backward(y2)

                    y3 = f(x, 1, keepdim=True)
                    if isinstance(y3, tuple):
                        y3 = (y3[0].squeeze(1), y3[1].squeeze(1))
                    else:
                        y3 = y3.squeeze(1)
                    self.assertEqual(y, y3)
                    run_backward(y3)

            reductions = (
                torch.sum,
                torch.prod,
                torch.mean,
                torch.max,
                torch.min,
                torch.mode,
                torch.median,
                torch.kthvalue,
                torch.var,
                torch.std,
            )
            for reduction in reductions:
                keepdim_check(reduction)
        finally:
            torch.utils.backcompat.keepdim_warning.enabled = False