我们从 Python 开源项目中提取了以下 25 个代码示例，用于说明如何使用 torch.LongStorage()。
def test_MaxUnpool2d_output_size(self):
    """Check how MaxUnpool2d validates its ``output_size`` argument.

    Unpooling the indices of a pooled 6x6 input (with its peak at [4][4])
    into the default output size is expected to raise RuntimeError.

    For a pooled 5x5 input:
    - Every (h, w) with both sides in 4..6 must be accepted. The size is
      given as a tuple (h == 4), as a LongStorage (h == 5), or as a
      LongStorage that also carries the two leading dims (h == 6).
    - Every other (h, w) in 3..9 must raise ValueError.
    """
    pool = nn.MaxPool2d(3, stride=2, return_indices=True)
    unpool = nn.MaxUnpool2d(3, stride=2)

    big_input = torch.rand(1, 1, 6, 6)
    big_input[0][0][4][4] = 100
    big_out, big_idx = pool(Variable(big_input))
    self.assertRaises(RuntimeError, lambda: unpool(big_out, big_idx))

    small_input = torch.rand(1, 1, 5, 5)
    for row in range(0, 4, 2):
        for col in range(0, 4, 2):
            small_input[:, :, row, col] = 100
    small_out, small_idx = pool(Variable(small_input))

    for h in range(3, 10):
        for w in range(3, 10):
            if not (4 <= h <= 6 and 4 <= w <= 6):
                # assertRaises calls the lambda immediately, so binding
                # the loop variables late is harmless here.
                self.assertRaises(
                    ValueError,
                    lambda: unpool(small_out, small_idx, (h, w)))
                continue
            if h == 5:
                requested = torch.LongStorage((h, w))
            elif h == 6:
                requested = torch.LongStorage((1, 1, h, w))
            else:
                requested = (h, w)
            unpool(small_out, small_idx, output_size=requested)
def make_tensor_reader(typename):
    """Build a reader for serialized legacy (torch7) tensors of *typename*.

    The returned ``read_tensor(reader, version)`` pulls, in this order:
    the dimension count, the size array, the stride array, the storage
    offset and the storage object. It then returns a tensor of the class
    that ``get_python_class(typename)`` resolves to. An empty instance is
    returned when the storage is missing or any of the dimension count,
    size length or stride length is zero. ``version`` is accepted but not
    used.
    """
    tensor_cls = get_python_class(typename)

    def read_tensor(reader, version):
        # Layout mirrors torch7's serializer:
        # https://github.com/torch/torch7/blob/master/generic/Tensor.c#L1243
        # NOTE(review): ``reader`` presumably consumes a sequential stream,
        # so the order of the reads below must be preserved.
        num_dims = reader.read_int()
        sizes = torch.LongStorage(reader.read_long_array(num_dims))
        strides = torch.LongStorage(reader.read_long_array(num_dims))
        # Presumably stored 1-based (Lua indexing); shift to 0-based.
        offset = reader.read_long() - 1
        backing = reader.read()

        if backing is None or 0 in (num_dims, len(sizes), len(strides)):
            return tensor_cls()
        return tensor_cls().set_(backing, offset, torch.Size(sizes), tuple(strides))

    return read_tensor
def test_MaxUnpool2d_output_size(self):
    """Exercise MaxUnpool2d's ``output_size`` checks.

    First, indices taken from a pooled 6x6 input (peak at [4][4]) are
    expected to raise RuntimeError when unpooled to the default size.

    Then, for a pooled 5x5 input, sizes whose height and width both lie
    in 4..6 are accepted. They are passed as a tuple, a LongStorage, or a
    LongStorage prefixed with (1, 1), depending on h. All other sizes in
    3..9 x 3..9 must raise ValueError.
    """
    pool_layer = nn.MaxPool2d(3, stride=2, return_indices=True)
    unpool_layer = nn.MaxUnpool2d(3, stride=2)

    large = torch.rand(1, 1, 6, 6)
    large[0][0][4][4] = 100
    large_out, large_idx = pool_layer(Variable(large))
    self.assertRaises(RuntimeError,
                      lambda: unpool_layer(large_out, large_idx))

    small = torch.rand(1, 1, 5, 5)
    for r in range(0, 4, 2):
        for c in range(0, 4, 2):
            small[:, :, r, c] = 100
    small_out, small_idx = pool_layer(Variable(small))

    for h in range(3, 10):
        for w in range(3, 10):
            accepted = 4 <= h <= 6 and 4 <= w <= 6
            if accepted:
                out_size = (h, w)
                if h == 5:
                    out_size = torch.LongStorage(out_size)
                elif h == 6:
                    out_size = torch.LongStorage((1, 1) + out_size)
                unpool_layer(small_out, small_idx, output_size=out_size)
            else:
                self.assertRaises(
                    ValueError,
                    lambda: unpool_layer(small_out, small_idx, (h, w)))
def test_repeat(self):
    """Check ``Tensor.repeat`` with unpacked ints and with a ``torch.Size``.

    Repeating an 8x4 tensor by (3, 1, 1) must yield shape [3, 8, 4]
    either way. Averaging the three copies along dim 0 must give back
    the original tensor.
    """
    source = torch.rand(8, 4)
    reps = (3, 1, 1)
    reps_as_size = torch.Size(reps)
    expected_shape = [3, 8, 4]

    self.assertEqual(source.repeat(*reps).size(), expected_shape,
                     'Error in repeat')
    self.assertEqual(source.repeat(reps_as_size).size(), expected_shape,
                     'Error in repeat using LongStorage')

    repeated = source.repeat(*reps)
    self.assertEqual(repeated.size(), expected_shape,
                     'Error in repeat using result')
    repeated = source.repeat(reps_as_size)
    self.assertEqual(repeated.size(), expected_shape,
                     'Error in repeat using result and LongStorage')

    averaged = repeated.mean(0).view(8, 4)
    self.assertEqual(averaged, source, 'Error in repeat (not equal)')
def test_ConvTranspose2d_output_size(self):
    """Check how ConvTranspose2d validates its ``output_size`` argument.

    The layer is built as ConvTranspose2d(3, 4, 3, 3, 0, 2) and applied
    to a (2, 3, 6, 6) input.

    - Every (h, w) with both sides in 18..20 must be accepted, and the
      output's spatial size must equal (h, w). The size is given as a
      tuple (h == 18), a LongStorage (h == 19), or a LongStorage that
      also carries the (2, 4) batch/channel dims (h == 20).
    - Every other (h, w) in 15..21 must raise ValueError.
    """
    m = nn.ConvTranspose2d(3, 4, 3, 3, 0, 2)
    i = Variable(torch.randn(2, 3, 6, 6))
    for h in range(15, 22):
        for w in range(15, 22):
            if 18 <= h <= 20 and 18 <= w <= 20:
                size = (h, w)
                if h == 19:
                    size = torch.LongStorage(size)
                elif h == 20:
                    # Full (N, C, H, W) form; (2, 4) matches the input
                    # batch size and the layer's out_channels.
                    size = torch.LongStorage((2, 4) + size)
                # Pass the constructed ``size`` (not a fresh tuple) so the
                # LongStorage variants are actually exercised.
                output = m(i, output_size=size)
                self.assertEqual(output.size()[2:], (h, w))
            else:
                # assertRaises calls the lambda immediately, so binding
                # the loop variables late is harmless here.
                self.assertRaises(ValueError, lambda: m(i, (h, w)))
def test_repeat(self):
    """Check ``Tensor.repeat`` with unpacked ints and with a ``torch.Size``.

    Repeating an 8x4 tensor by (3, 1, 1) must yield shape [3, 8, 4]
    either way. The mean of the three copies along dim 0 must match the
    original tensor; the largest absolute difference is compared with 0.
    """
    tensor = torch.rand(8, 4)
    size = (3, 1, 1)
    torch_size = torch.Size(size)
    target = [3, 8, 4]
    self.assertEqual(tensor.repeat(*size).size(), target, 'Error in repeat')
    self.assertEqual(tensor.repeat(torch_size).size(), target,
                     'Error in repeat using LongStorage')
    result = tensor.repeat(*size)
    self.assertEqual(result.size(), target, 'Error in repeat using result')
    result = tensor.repeat(torch_size)
    self.assertEqual(result.size(), target,
                     'Error in repeat using result and LongStorage')
    self.assertEqual((result.mean(0).view(8, 4) - tensor).abs().max(), 0,
                     'Error in repeat (not equal)')
def test_element_size(self):
    """Check Storage/Tensor element sizes agree and are plausible.

    For each scalar type:
    - The Storage and the Tensor of that type must report the same
      ``element_size()``.
    - Every size must be positive.

    The sizes must also satisfy portable lower bounds and orderings:
    byte/char exactly 1, short >= 2, int >= 2 and >= short,
    long >= 4 and >= int, double >= float. These bounds do not pin
    exact sizes on any particular platform.
    """
    # Names avoid shadowing the builtins ``int``/``float`` (and py2 ``long``).
    byte_size = torch.ByteStorage().element_size()
    char_size = torch.CharStorage().element_size()
    short_size = torch.ShortStorage().element_size()
    int_size = torch.IntStorage().element_size()
    long_size = torch.LongStorage().element_size()
    float_size = torch.FloatStorage().element_size()
    double_size = torch.DoubleStorage().element_size()

    # A Storage and a Tensor of the same scalar type must agree.
    self.assertEqual(byte_size, torch.ByteTensor().element_size())
    self.assertEqual(char_size, torch.CharTensor().element_size())
    self.assertEqual(short_size, torch.ShortTensor().element_size())
    self.assertEqual(int_size, torch.IntTensor().element_size())
    self.assertEqual(long_size, torch.LongTensor().element_size())
    self.assertEqual(float_size, torch.FloatTensor().element_size())
    self.assertEqual(double_size, torch.DoubleTensor().element_size())

    for elem_size in (byte_size, char_size, short_size, int_size,
                      long_size, float_size, double_size):
        self.assertGreater(elem_size, 0)

    # These tests are portable, not necessarily strict for your system.
    self.assertEqual(byte_size, 1)
    self.assertEqual(char_size, 1)
    self.assertGreaterEqual(short_size, 2)
    self.assertGreaterEqual(int_size, 2)
    self.assertGreaterEqual(int_size, short_size)
    self.assertGreaterEqual(long_size, 4)
    self.assertGreaterEqual(long_size, int_size)
    self.assertGreaterEqual(double_size, float_size)
def test_repeat(self):
    """Verify ``Tensor.repeat`` given varargs ints and given a ``torch.Size``.

    Both call forms must turn an 8x4 tensor into shape [3, 8, 4] for
    repeats (3, 1, 1). Averaging the repeated copies over dim 0 must
    reproduce the input; the maximum absolute difference is compared
    with 0.
    """
    tensor = torch.rand(8, 4)
    size = (3, 1, 1)
    torch_size = torch.Size(size)
    target = [3, 8, 4]
    self.assertEqual(tensor.repeat(*size).size(), target, 'Error in repeat')
    self.assertEqual(tensor.repeat(torch_size).size(), target,
                     'Error in repeat using LongStorage')
    result = tensor.repeat(*size)
    self.assertEqual(result.size(), target, 'Error in repeat using result')
    result = tensor.repeat(torch_size)
    self.assertEqual(result.size(), target,
                     'Error in repeat using result and LongStorage')
    self.assertEqual((result.mean(0).view(8, 4) - tensor).abs().max(), 0,
                     'Error in repeat (not equal)')
def _tensor_str(self):
    """Format a tensor as a sequence of 2-D matrix slices.

    The leading ``ndimension() - 2`` dimensions are walked like an
    odometer. For every index combination, a header ``(i,j,...,.,.) =``
    is emitted, followed by the corresponding 2-D submatrix rendered by
    ``_matrix_str``.

    When the tensor has at least ``PRINT_OPTS.threshold`` elements,
    leading dimensions longer than ``2 * PRINT_OPTS.edgeitems`` show only
    their first and last ``edgeitems`` indices. Before the next header,
    a row of '...' marks which dimensions were skipped, and a row of
    vertical ellipses marks which dimensions wrapped back to 0.

    NOTE(review): presumably only called for tensors with more than two
    dimensions. With exactly two, ``counter`` would be empty and
    ``counter[counter.size()-1]`` would index -1. TODO confirm at the
    call site.

    Returns:
        The accumulated string.
    """
    n = PRINT_OPTS.edgeitems
    # Whether the last / second-to-last dim is longer than 2 * edgeitems.
    has_hdots = self.size()[-1] > 2*n
    has_vdots = self.size()[-2] > 2*n
    print_full_mat = not has_hdots and not has_vdots
    # min_sz=3 is requested only when either of the last two dims may be
    # elided (presumably to leave room for '...' — confirm in _number_format).
    formatter = _number_format(self, min_sz=3 if not print_full_mat else 0)
    print_dots = self.numel() >= PRINT_OPTS.threshold

    # Field width for indices in headers: at least 2, or the widest size.
    dim_sz = max(2, max(len(str(x)) for x in self.size()))
    dim_fmt = "{:^" + str(dim_sz) + "}"
    dot_fmt = u"{:^" + str(dim_sz+1) + "}"

    # Odometer over the leading dims. The last digit starts at -1 so the
    # first increment produces the all-zeros index.
    counter_dim = self.ndimension() - 2
    counter = torch.LongStorage(counter_dim).fill_(0)
    counter[counter.size()-1] = -1
    finished = False
    strt = ''
    while True:
        nrestarted = [False for i in counter]
        nskipped = [False for i in counter]
        # Increment from the last leading dim, carrying into earlier dims.
        for i in _range(counter_dim - 1, -1, -1):
            counter[i] += 1
            # In summary mode, jump from the first n indices to the last n.
            if print_dots and counter[i] == n and self.size(i) > 2*n:
                counter[i] = self.size(i) - n
                nskipped[i] = True
            if counter[i] == self.size(i):
                # Overflow of dim 0 means every combination was visited.
                if i == 0:
                    finished = True
                counter[i] = 0
                nrestarted[i] = True
            else:
                break
        if finished:
            break
        elif print_dots:
            if any(nskipped):
                # One column per leading dim: '...' where it was skipped.
                for hdot in nskipped:
                    strt += dot_fmt.format('...') if hdot \
                        else dot_fmt.format('')
                strt += '\n'
            if any(nrestarted):
                # One column per leading dim: vertical ellipsis where it wrapped.
                strt += ' '
                for vdot in nrestarted:
                    strt += dot_fmt.format(u'\u22EE' if vdot else '')
                strt += '\n'
        if strt != '':
            strt += '\n'
        strt += '({},.,.) = \n'.format(
            ','.join(dim_fmt.format(i) for i in counter))
        # Select each leading index in turn to reach the 2-D slice.
        submatrix = reduce(lambda t, i: t.select(0, i), counter, self)
        strt += _matrix_str(submatrix, ' ', formatter, print_dots)
    return strt
def _tensor_str(self):
    """Render a tensor as one 2-D matrix block per leading-index combination.

    The first ``ndimension() - 2`` dimensions are enumerated with an
    odometer-style counter. Each combination produces a header of the
    form ``(i,j,...,.,.) =`` followed by the 2-D slice, formatted by
    ``_matrix_str``.

    If the tensor has at least ``PRINT_OPTS.threshold`` elements, any
    leading dimension longer than ``2 * PRINT_OPTS.edgeitems`` prints
    only its first and last ``edgeitems`` indices. Before the next
    header, marker rows are inserted: '...' under skipped dimensions and
    a vertical ellipsis under dimensions that wrapped to 0.

    NOTE(review): presumably only reached for tensors with more than two
    dimensions. With ndim == 2, ``counter`` is empty and
    ``counter[counter.size() - 1]`` would index -1. TODO confirm with
    the caller.

    Returns:
        The accumulated string.
    """
    n = PRINT_OPTS.edgeitems
    # Whether the last / second-to-last dim exceeds 2 * edgeitems.
    has_hdots = self.size()[-1] > 2 * n
    has_vdots = self.size()[-2] > 2 * n
    print_full_mat = not has_hdots and not has_vdots
    # A minimum width of 3 is requested only if the last two dims may be
    # elided (presumably room for '...' — confirm in _number_format).
    formatter = _number_format(self, min_sz=3 if not print_full_mat else 0)
    print_dots = self.numel() >= PRINT_OPTS.threshold

    # Header index width: at least 2, else the widest size in digits.
    dim_sz = max(2, max(len(str(x)) for x in self.size()))
    dim_fmt = "{:^" + str(dim_sz) + "}"
    dot_fmt = u"{:^" + str(dim_sz + 1) + "}"

    # Leading-dim odometer. The last digit is primed to -1 so the first
    # increment lands on index (0, ..., 0).
    counter_dim = self.ndimension() - 2
    counter = torch.LongStorage(counter_dim).fill_(0)
    counter[counter.size() - 1] = -1
    finished = False
    strt = ''
    while True:
        nrestarted = [False for i in counter]
        nskipped = [False for i in counter]
        # Advance the last leading dim; on overflow carry leftwards.
        for i in _range(counter_dim - 1, -1, -1):
            counter[i] += 1
            # Summary mode: skip from index n straight to size - n.
            if print_dots and counter[i] == n and self.size(i) > 2 * n:
                counter[i] = self.size(i) - n
                nskipped[i] = True
            if counter[i] == self.size(i):
                # Carry out of dim 0: all combinations have been printed.
                if i == 0:
                    finished = True
                counter[i] = 0
                nrestarted[i] = True
            else:
                break
        if finished:
            break
        elif print_dots:
            if any(nskipped):
                # Mark skipped dims with '...' in their column.
                for hdot in nskipped:
                    strt += dot_fmt.format('...') if hdot \
                        else dot_fmt.format('')
                strt += '\n'
            if any(nrestarted):
                # Mark wrapped dims with a vertical ellipsis in their column.
                strt += ' '
                for vdot in nrestarted:
                    strt += dot_fmt.format(u'\u22EE' if vdot else '')
                strt += '\n'
        if strt != '':
            strt += '\n'
        strt += '({},.,.) = \n'.format(
            ','.join(dim_fmt.format(i) for i in counter))
        # Narrow by each leading index in turn to obtain the 2-D slice.
        submatrix = reduce(lambda t, i: t.select(0, i), counter, self)
        strt += _matrix_str(submatrix, ' ', formatter, print_dots)
    return strt