我们从Python开源项目中提取了以下28个代码示例，用于说明如何使用torch.set_rng_state()。
def freeze_rng_state():
    """Snapshot the global torch RNG state(s) around a ``yield`` and restore them.

    This is a generator with a single bare ``yield``; presumably it is used
    through ``contextlib.contextmanager`` (the decorator is not visible here).
    The torch RNG state, and the CUDA RNG state when CUDA is available, are
    captured before the ``yield`` and put back after it. Random numbers drawn
    in between therefore do not change the sequence seen by later code.

    The restore runs in a ``finally`` clause, so the states are also put
    back when the wrapped body raises.
    """
    rng_state = torch.get_rng_state()
    # Query CUDA availability once so the save and restore branches agree.
    has_cuda = torch.cuda.is_available()
    cuda_rng_state = torch.cuda.get_rng_state() if has_cuda else None
    try:
        yield
    finally:
        if has_cuda:
            torch.cuda.set_rng_state(cuda_rng_state)
        torch.set_rng_state(rng_state)
def test_randperm(self): _RNGState = torch.get_rng_state() res1 = torch.randperm(100) res2 = torch.Tensor() torch.set_rng_state(_RNGState) torch.randperm(res2, 100) self.assertEqual(res1, res2, 0)
def test_RNGState(self): state = torch.get_rng_state() stateCloned = state.clone() before = torch.rand(1000) self.assertEqual(state.ne(stateCloned).long().sum(), 0, 0) torch.set_rng_state(state) after = torch.rand(1000) self.assertEqual(before, after, 0)
def test_boxMullerState(self): torch.manual_seed(123) odd_number = 101 seeded = torch.randn(odd_number) state = torch.get_rng_state() midstream = torch.randn(odd_number) torch.set_rng_state(state) repeat_midstream = torch.randn(odd_number) torch.manual_seed(123) reseeded = torch.randn(odd_number) self.assertEqual(midstream, repeat_midstream, 0, 'get_rng_state/set_rng_state not generating same sequence of normally distributed numbers') self.assertEqual(seeded, reseeded, 0, 'repeated calls to manual_seed not generating same sequence of normally distributed numbers')
def test_manual_seed(self): rng_state = torch.get_rng_state() torch.manual_seed(2) x = torch.randn(100) self.assertEqual(torch.initial_seed(), 2) torch.manual_seed(2) y = torch.randn(100) self.assertEqual(x, y) torch.set_rng_state(rng_state)
def test_randperm(self): _RNGState = torch.get_rng_state() res1 = torch.randperm(100) res2 = torch.LongTensor() torch.set_rng_state(_RNGState) torch.randperm(100, out=res2) self.assertEqual(res1, res2, 0)
def set_rng_state(new_state):
    r"""Load ``new_state`` into the default random number generator.

    Args:
        new_state (torch.ByteTensor): The generator state to install, e.g. a
            tensor previously returned by ``torch.get_rng_state()``.

    Returns:
        None. The state is applied in place on ``default_generator``.
    """
    generator = default_generator
    generator.set_state(new_state)
    return None
def test_exponential(self): rate = Variable(torch.randn(5, 5).abs(), requires_grad=True) rate_1d = Variable(torch.randn(1).abs(), requires_grad=True) self.assertEqual(Exponential(rate).sample().size(), (5, 5)) self.assertEqual(Exponential(rate).sample((7,)).size(), (7, 5, 5)) self.assertEqual(Exponential(rate_1d).sample((1,)).size(), (1, 1)) self.assertEqual(Exponential(rate_1d).sample().size(), (1,)) self.assertEqual(Exponential(0.2).sample((1,)).size(), (1,)) self.assertEqual(Exponential(50.0).sample((1,)).size(), (1,)) self._gradcheck_log_prob(Exponential, (rate,)) state = torch.get_rng_state() eps = rate.new(rate.size()).exponential_() torch.set_rng_state(state) z = Exponential(rate).rsample() z.backward(torch.ones_like(z)) self.assertEqual(rate.grad, -eps / rate**2) rate.grad.zero_() self.assertEqual(z.size(), (5, 5)) def ref_log_prob(idx, x, log_prob): m = rate.data.view(-1)[idx] expected = math.log(m) - m * x self.assertAlmostEqual(log_prob, expected, places=3) self._check_log_prob(Exponential(rate), ref_log_prob) # This is a randomized test.
def test_normal(self): mean = Variable(torch.randn(5, 5), requires_grad=True) std = Variable(torch.randn(5, 5).abs(), requires_grad=True) mean_1d = Variable(torch.randn(1), requires_grad=True) std_1d = Variable(torch.randn(1), requires_grad=True) mean_delta = torch.Tensor([1.0, 0.0]) std_delta = torch.Tensor([1e-5, 1e-5]) self.assertEqual(Normal(mean, std).sample().size(), (5, 5)) self.assertEqual(Normal(mean, std).sample_n(7).size(), (7, 5, 5)) self.assertEqual(Normal(mean_1d, std_1d).sample_n(1).size(), (1, 1)) self.assertEqual(Normal(mean_1d, std_1d).sample().size(), (1,)) self.assertEqual(Normal(0.2, .6).sample_n(1).size(), (1,)) self.assertEqual(Normal(-0.7, 50.0).sample_n(1).size(), (1,)) # sample check for extreme value of mean, std self._set_rng_seed(1) self.assertEqual(Normal(mean_delta, std_delta).sample(sample_shape=(1, 2)), torch.Tensor([[[1.0, 0.0], [1.0, 0.0]]]), prec=1e-4) self._gradcheck_log_prob(Normal, (mean, std)) self._gradcheck_log_prob(Normal, (mean, 1.0)) self._gradcheck_log_prob(Normal, (0.0, std)) state = torch.get_rng_state() eps = torch.normal(torch.zeros_like(mean), torch.ones_like(std)) torch.set_rng_state(state) z = Normal(mean, std).rsample() z.backward(torch.ones_like(z)) self.assertEqual(mean.grad, torch.ones_like(mean)) self.assertEqual(std.grad, eps) mean.grad.zero_() std.grad.zero_() self.assertEqual(z.size(), (5, 5)) def ref_log_prob(idx, x, log_prob): m = mean.data.view(-1)[idx] s = std.data.view(-1)[idx] expected = (math.exp(-(x - m) ** 2 / (2 * s ** 2)) / math.sqrt(2 * math.pi * s ** 2)) self.assertAlmostEqual(log_prob, math.log(expected), places=3) self._check_log_prob(Normal(mean, std), ref_log_prob) # This is a randomized test.