我们从 Python 开源项目中提取了以下 27 个代码示例，用于说明如何使用 torch.FloatStorage()。
def test_serialization(self):
    """Round-trip tensors and storages through torch.save / torch.load.

    The payload ``b`` holds two 5x5 float tensors, each listed twice
    (``b[0] is b[2]`` and ``b[1] is b[3]``), then the storage of the first
    tensor and a ``[1:4]`` slice of that storage.  The round trip is run
    twice: once handing torch.save/torch.load the open file object and once
    handing them the file's name.
    """
    a = [torch.randn(5, 5).float() for _ in range(2)]
    b = [a[i % 2] for i in range(4)]
    b += [a[0].storage()]
    b += [a[0].storage()[1:4]]
    for use_name in (False, True):
        with tempfile.NamedTemporaryFile() as f:
            handle = f.name if use_name else f
            torch.save(b, handle)
            f.seek(0)  # rewind so the file object can be read back
            c = torch.load(handle)

        # NOTE(review): the third positional argument goes to the enclosing
        # TestCase's assertEqual -- presumably a tolerance; confirm there.
        self.assertEqual(b, c, 0)
        for loaded in c[:4]:
            self.assertIsInstance(loaded, torch.FloatTensor)
        self.assertIsInstance(c[4], torch.FloatStorage)

        # Entries that were the same object before saving must still alias
        # each other after loading: writing through one shows in the other.
        c[0].fill_(10)
        self.assertEqual(c[0], c[2], 0)
        self.assertEqual(c[4], torch.FloatStorage(25).fill_(10), 0)
        c[1].fill_(20)
        self.assertEqual(c[1], c[3], 0)
        # c[5] was saved as storage[1:4], so it is compared against that
        # same slice of the full storage c[4].
        self.assertEqual(c[4][1:4], c[5], 0)
def test_serialization_backwards_compat(self):
    """Load a legacy serialized file and compare it with a rebuilt payload.

    The expected payload ``b`` is rebuilt locally: two 5x5 float tensors made
    from ``arange`` (each listed twice), the storage of the first tensor and
    a ``[1:4]`` slice of that storage.  ``download_file`` is defined outside
    this block; here its return value is used as the path given to
    torch.load.
    """
    a = [torch.arange(1 + i, 26 + i).view(5, 5).float() for i in range(2)]
    b = [a[i % 2] for i in range(4)]
    b += [a[0].storage()]
    b += [a[0].storage()[1:4]]
    path = download_file('https://download.pytorch.org/test_data/legacy_serialized.pt')
    # NOTE(review): torch.load unpickles the file; this is only acceptable
    # because the file comes from a trusted, fixed URL.
    c = torch.load(path)

    # NOTE(review): the third positional argument goes to the enclosing
    # TestCase's assertEqual -- presumably a tolerance; confirm there.
    self.assertEqual(b, c, 0)
    for loaded in c[:4]:
        self.assertIsInstance(loaded, torch.FloatTensor)
    self.assertIsInstance(c[4], torch.FloatStorage)

    # Entries that alias each other in ``b`` must still alias after loading.
    c[0].fill_(10)
    self.assertEqual(c[0], c[2], 0)
    self.assertEqual(c[4], torch.FloatStorage(25).fill_(10), 0)
    c[1].fill_(20)
    self.assertEqual(c[1], c[3], 0)
    self.assertEqual(c[4][1:4], c[5], 0)
def test_from_file(self):
    """Check that two storages created from the same file see each other's writes.

    Two ``torch.FloatStorage.from_file`` storages are created over the same
    temporary file and wrapped in tensors; the test then writes through each
    tensor in turn and asserts the other one observes the same values.
    """
    size = 10000
    with tempfile.NamedTemporaryFile() as f:
        # NOTE(review): the second argument (True) presumably requests a
        # shared mapping -- confirm against the from_file documentation.
        s1 = torch.FloatStorage.from_file(f.name, True, size)
        t1 = torch.FloatTensor(s1).copy_(torch.randn(size))

        # A second storage over the same file must show the data written via t1.
        s2 = torch.FloatStorage.from_file(f.name, True, size)
        t2 = torch.FloatTensor(s2)
        self.assertEqual(t1, t2, 0)

        # Writes made through t1 must be visible through t2.
        rnum = random.uniform(-1, 1)
        t1.fill_(rnum)
        self.assertEqual(t1, t2, 0)

        # Writes made through t2 must be visible through t1.
        rnum = random.uniform(-1, 1)
        t2.fill_(rnum)
        self.assertEqual(t1, t2, 0)
def test_type_conversions(self):
    """Check the concrete types produced by chained conversion calls.

    The same chain (``float``, ``cuda``, ``cuda().float()``, back to ``cpu``,
    then ``int``) is checked first on a tensor and then on its storage.
    The test calls ``.cuda()``, so it needs a CUDA-capable environment.
    """
    x = torch.randn(5, 5)
    # NOTE(review): expecting torch.cuda.DoubleTensor from x.cuda() assumes
    # torch.randn yields a double tensor in this torch version -- confirm.
    self.assertIs(type(x.float()), torch.FloatTensor)
    self.assertIs(type(x.cuda()), torch.cuda.DoubleTensor)
    self.assertIs(type(x.cuda().float()), torch.cuda.FloatTensor)
    self.assertIs(type(x.cuda().float().cpu()), torch.FloatTensor)
    self.assertIs(type(x.cuda().float().cpu().int()), torch.IntTensor)

    # Same conversions, applied to the underlying storage.
    y = x.storage()
    self.assertIs(type(y.float()), torch.FloatStorage)
    self.assertIs(type(y.cuda()), torch.cuda.DoubleStorage)
    self.assertIs(type(y.cuda().float()), torch.cuda.FloatStorage)
    self.assertIs(type(y.cuda().float().cpu()), torch.FloatStorage)
    self.assertIs(type(y.cuda().float().cpu().int()), torch.IntStorage)
def assign():
    """Assign the string '1' to the slice [1:-1] of a fresh 10-element FloatStorage.

    NOTE(review): presumably passed as a callable to an assertRaises-style
    check by the surrounding test -- confirm against the caller.
    """
    storage = torch.FloatStorage(10)
    storage[1:-1] = '1'
def test_element_size(self):
    """Check element_size() of every storage type against its tensor type.

    Each storage's element size must equal the matching tensor's element
    size and be positive.  A final group of assertions checks portable
    lower bounds and orderings between the sizes.
    """
    # Local names carry a ``_size`` suffix so they do not shadow the
    # builtins ``int`` and ``float``.
    byte_size = torch.ByteStorage().element_size()
    char_size = torch.CharStorage().element_size()
    short_size = torch.ShortStorage().element_size()
    int_size = torch.IntStorage().element_size()
    long_size = torch.LongStorage().element_size()
    float_size = torch.FloatStorage().element_size()
    double_size = torch.DoubleStorage().element_size()

    sizes_and_tensors = (
        (byte_size, torch.ByteTensor),
        (char_size, torch.CharTensor),
        (short_size, torch.ShortTensor),
        (int_size, torch.IntTensor),
        (long_size, torch.LongTensor),
        (float_size, torch.FloatTensor),
        (double_size, torch.DoubleTensor),
    )

    for size, tensor_type in sizes_and_tensors:
        self.assertEqual(size, tensor_type().element_size())

    for size, _ in sizes_and_tensors:
        self.assertGreater(size, 0)

    # These tests are portable, not necessarily strict for your system.
    self.assertEqual(byte_size, 1)
    self.assertEqual(char_size, 1)
    self.assertGreaterEqual(short_size, 2)
    self.assertGreaterEqual(int_size, 2)
    self.assertGreaterEqual(int_size, short_size)
    self.assertGreaterEqual(long_size, 4)
    self.assertGreaterEqual(long_size, int_size)
    self.assertGreaterEqual(double_size, float_size)
def test_from_buffer(self):
    """Check Storage.from_buffer for byte, short, int and float storages.

    The integer cases reinterpret the bytes ``01 02 03 04``; the float case
    reinterprets ``40 10 00 00`` read as big-endian.
    """
    a = bytearray([1, 2, 3, 4])
    self.assertEqual(torch.ByteStorage.from_buffer(a).tolist(), [1, 2, 3, 4])

    # Big-endian 16-bit: 0x0102 == 258 and 0x0304 == 772.
    shorts = torch.ShortStorage.from_buffer(a, 'big')
    self.assertEqual(shorts.size(), 2)
    self.assertEqual(shorts.tolist(), [258, 772])

    # Little-endian 32-bit: 0x04030201 == 67305985.
    ints = torch.IntStorage.from_buffer(a, 'little')
    self.assertEqual(ints.size(), 1)
    self.assertEqual(ints[0], 67305985)

    # 0x40100000 is the IEEE-754 single-precision encoding of 2.25.
    f = bytearray([0x40, 0x10, 0x00, 0x00])
    floats = torch.FloatStorage.from_buffer(f, 'big')
    self.assertEqual(floats.size(), 1)
    self.assertEqual(floats[0], 2.25)
def test_serialization(self):
    """Round-trip tensors, storages and storage views through torch.save/load.

    Payload layout (indices into ``b`` and the loaded list ``c``):
      0-3  two 5x5 float tensors, each listed twice (0 is 2, 1 is 3)
      4    storage of the first tensor
      5    ``[1:4]`` slice of that storage
      6    an int tensor built from ``torch.arange(1, 11)``
      7    tuple ``(t1.storage(), t1.storage(), t2.storage())``
      8    ``[0:2]`` slice of the first tensor's storage

    The round trip is run with an open file object and with a file name.
    On Windows the by-name case uses a separate temporary path instead of
    the NamedTemporaryFile's own name.
    """
    import os  # local import: only needed for the Windows temp-file cleanup

    a = [torch.randn(5, 5).float() for _ in range(2)]
    b = [a[i % 2] for i in range(4)]
    b += [a[0].storage()]
    b += [a[0].storage()[1:4]]
    b += [torch.arange(1, 11).int()]
    t1 = torch.FloatTensor().set_(a[0].storage()[1:4], 0, (3,), (1,))
    t2 = torch.FloatTensor().set_(a[0].storage()[1:4], 0, (3,), (1,))
    b += [(t1.storage(), t1.storage(), t2.storage())]
    b += [a[0].storage()[0:2]]
    for use_name in (False, True):
        extra_path = None
        with tempfile.NamedTemporaryFile(delete=True) as f:
            handle = f.name if use_name else f
            if sys.platform == 'win32' and use_name:
                # mkstemp creates the file atomically (unlike the deprecated
                # mktemp); close the descriptor so torch can reopen by name.
                fd, extra_path = tempfile.mkstemp()
                os.close(fd)
                handle = extra_path
            try:
                torch.save(b, handle)
                f.seek(0)  # rewind so the file object can be read back
                c = torch.load(handle)
            finally:
                # Remove the extra Windows-only file so it does not leak.
                if extra_path is not None:
                    os.remove(extra_path)

        # NOTE(review): the third positional argument goes to the enclosing
        # TestCase's assertEqual -- presumably a tolerance; confirm there.
        self.assertEqual(b, c, 0)
        for loaded in c[:4]:
            self.assertIsInstance(loaded, torch.FloatTensor)
        self.assertIsInstance(c[4], torch.FloatStorage)

        # Entries that alias each other in ``b`` must still alias after loading.
        c[0].fill_(10)
        self.assertEqual(c[0], c[2], 0)
        self.assertEqual(c[4], torch.FloatStorage(25).fill_(10), 0)
        c[1].fill_(20)
        self.assertEqual(c[1], c[3], 0)
        # c[5] was saved as storage[1:4], so it is compared against that
        # same slice of the full storage c[4].
        self.assertEqual(c[4][1:4], c[5], 0)

        # check that serializing the same storage view object unpickles
        # it as one object not two (and vice versa)
        views = c[7]
        self.assertEqual(views[0]._cdata, views[1]._cdata)
        self.assertEqual(views[0], views[2])
        self.assertNotEqual(views[0]._cdata, views[2]._cdata)

        rootview = c[8]
        self.assertEqual(rootview.data_ptr(), c[0].data_ptr())
def test_serialization_backwards_compat(self):
    """Load a legacy serialized file and compare it with a rebuilt payload.

    The expected payload ``b`` is rebuilt locally: two 5x5 float tensors made
    from ``arange`` (each listed twice), the storage of the first tensor and
    a ``[1:4]`` slice of that storage.  The legacy file is fetched into a
    ``data`` directory next to this module via ``download_file`` (defined
    outside this block).  If that call returns a falsy value the test warns
    and returns early without checking anything.
    """
    a = [torch.arange(1 + i, 26 + i).view(5, 5).float() for i in range(2)]
    b = [a[i % 2] for i in range(4)]
    b += [a[0].storage()]
    b += [a[0].storage()[1:4]]

    data_url = 'https://download.pytorch.org/test_data/legacy_serialized.pt'
    data_dir = os.path.join(os.path.dirname(__file__), 'data')
    test_file_path = os.path.join(data_dir, 'legacy_serialized.pt')
    succ = download_file(data_url, test_file_path)
    if not succ:
        warnings.warn(("Couldn't download the test file for backwards compatibility! "
                       "Tests will be incomplete!"), RuntimeWarning)
        return

    # NOTE(review): torch.load unpickles the file; this is only acceptable
    # because the file comes from a trusted, fixed URL.
    c = torch.load(test_file_path)

    # NOTE(review): the third positional argument goes to the enclosing
    # TestCase's assertEqual -- presumably a tolerance; confirm there.
    self.assertEqual(b, c, 0)
    for loaded in c[:4]:
        self.assertIsInstance(loaded, torch.FloatTensor)
    self.assertIsInstance(c[4], torch.FloatStorage)

    # Entries that alias each other in ``b`` must still alias after loading.
    c[0].fill_(10)
    self.assertEqual(c[0], c[2], 0)
    self.assertEqual(c[4], torch.FloatStorage(25).fill_(10), 0)
    c[1].fill_(20)
    self.assertEqual(c[1], c[3], 0)
    # c[5] was saved as storage[1:4], so it is compared against that
    # same slice of the full storage c[4].
    self.assertEqual(c[4][1:4], c[5], 0)
def test_serialization(self):
    """Round-trip tensors, storages and storage views through torch.save/load.

    Payload layout (indices into ``b`` and the loaded list ``c``):
      0-3  two 5x5 float tensors, each listed twice (0 is 2, 1 is 3)
      4    storage of the first tensor
      5    ``[1:4]`` slice of that storage
      6    an int tensor built from ``torch.arange(1, 11)``
      7    tuple ``(t1.storage(), t1.storage(), t2.storage())``
      8    ``[0:2]`` slice of the first tensor's storage

    The round trip is run with an open file object and with a file name.
    """
    a = [torch.randn(5, 5).float() for _ in range(2)]
    b = [a[i % 2] for i in range(4)]
    b += [a[0].storage()]
    b += [a[0].storage()[1:4]]
    b += [torch.arange(1, 11).int()]
    t1 = torch.FloatTensor().set_(a[0].storage()[1:4], 0, (3,), (1,))
    t2 = torch.FloatTensor().set_(a[0].storage()[1:4], 0, (3,), (1,))
    b += [(t1.storage(), t1.storage(), t2.storage())]
    b += [a[0].storage()[0:2]]
    for use_name in (False, True):
        with tempfile.NamedTemporaryFile() as f:
            handle = f.name if use_name else f
            torch.save(b, handle)
            f.seek(0)  # rewind so the file object can be read back
            c = torch.load(handle)

        # NOTE(review): the third positional argument goes to the enclosing
        # TestCase's assertEqual -- presumably a tolerance; confirm there.
        self.assertEqual(b, c, 0)
        for loaded in c[:4]:
            self.assertIsInstance(loaded, torch.FloatTensor)
        self.assertIsInstance(c[4], torch.FloatStorage)

        # Entries that alias each other in ``b`` must still alias after loading.
        c[0].fill_(10)
        self.assertEqual(c[0], c[2], 0)
        self.assertEqual(c[4], torch.FloatStorage(25).fill_(10), 0)
        c[1].fill_(20)
        self.assertEqual(c[1], c[3], 0)
        # c[5] was saved as storage[1:4], so it is compared against that
        # same slice of the full storage c[4].
        self.assertEqual(c[4][1:4], c[5], 0)

        # check that serializing the same storage view object unpickles
        # it as one object not two (and vice versa)
        views = c[7]
        self.assertEqual(views[0]._cdata, views[1]._cdata)
        self.assertEqual(views[0], views[2])
        self.assertNotEqual(views[0]._cdata, views[2]._cdata)

        rootview = c[8]
        self.assertEqual(rootview.data_ptr(), c[0].data_ptr())
def test_serialization(self):
    """Round-trip tensors, storages and storage views through torch.save/load.

    Payload layout (indices into ``b`` and the loaded list ``c``):
      0-3  two 5x5 float tensors, each listed twice (0 is 2, 1 is 3)
      4    storage of the first tensor
      5    ``[1:4]`` slice of that storage
      6    an int tensor built from ``torch.arange(1, 11)``
      7    tuple ``(t1.storage(), t1.storage(), t2.storage())``
      8    ``[0:2]`` slice of the first tensor's storage

    The round trip is run with an open file object and with a file name.
    """
    a = [torch.randn(5, 5).float() for _ in range(2)]
    b = [a[i % 2] for i in range(4)]
    b += [a[0].storage()]
    b += [a[0].storage()[1:4]]
    b += [torch.arange(1, 11).int()]
    t1 = torch.FloatTensor().set_(a[0].storage()[1:4], 0, (3,), (1,))
    t2 = torch.FloatTensor().set_(a[0].storage()[1:4], 0, (3,), (1,))
    b += [(t1.storage(), t1.storage(), t2.storage())]
    b += [a[0].storage()[0:2]]
    for use_name in (False, True):
        with tempfile.NamedTemporaryFile() as f:
            handle = f.name if use_name else f
            torch.save(b, handle)
            f.seek(0)  # rewind so the file object can be read back
            c = torch.load(handle)

        # NOTE(review): the third positional argument goes to the enclosing
        # TestCase's assertEqual -- presumably a tolerance; confirm there.
        self.assertEqual(b, c, 0)
        for loaded in c[:4]:
            self.assertIsInstance(loaded, torch.FloatTensor)
        self.assertIsInstance(c[4], torch.FloatStorage)

        # Entries that alias each other in ``b`` must still alias after loading.
        c[0].fill_(10)
        self.assertEqual(c[0], c[2], 0)
        self.assertEqual(c[4], torch.FloatStorage(25).fill_(10), 0)
        c[1].fill_(20)
        self.assertEqual(c[1], c[3], 0)
        self.assertEqual(c[4][1:4], c[5], 0)

        # check that serializing the same storage view object unpickles
        # it as one object not two (and vice versa)
        views = c[7]
        self.assertEqual(views[0]._cdata, views[1]._cdata)
        self.assertEqual(views[0], views[2])
        self.assertNotEqual(views[0]._cdata, views[2]._cdata)

        rootview = c[8]
        self.assertEqual(rootview.data_ptr(), c[0].data_ptr())
def test_serialization(self):
    """Round-trip tensors, storages and storage views through torch.save/load.

    Payload layout (indices into ``b`` and the loaded list ``c``):
      0-3  two 5x5 float tensors, each listed twice (0 is 2, 1 is 3)
      4    storage of the first tensor
      5    ``[1:4]`` slice of that storage
      6    an int tensor built from ``torch.arange(1, 11)``
      7    tuple ``(t1.storage(), t1.storage(), t2.storage())``
      8    ``[0:2]`` slice of the first tensor's storage

    The round trip is run with an open file object and with a file name;
    the by-name variant is skipped on Windows.
    """
    a = [torch.randn(5, 5).float() for _ in range(2)]
    b = [a[i % 2] for i in range(4)]
    b += [a[0].storage()]
    b += [a[0].storage()[1:4]]
    b += [torch.arange(1, 11).int()]
    t1 = torch.FloatTensor().set_(a[0].storage()[1:4], 0, (3,), (1,))
    t2 = torch.FloatTensor().set_(a[0].storage()[1:4], 0, (3,), (1,))
    b += [(t1.storage(), t1.storage(), t2.storage())]
    b += [a[0].storage()[0:2]]
    for use_name in (False, True):
        # Passing filename to torch.save(...) will cause the file to be opened twice,
        # which is not supported on Windows
        if sys.platform == "win32" and use_name:
            continue
        with tempfile.NamedTemporaryFile() as f:
            handle = f.name if use_name else f
            torch.save(b, handle)
            f.seek(0)  # rewind so the file object can be read back
            c = torch.load(handle)

        # NOTE(review): the third positional argument goes to the enclosing
        # TestCase's assertEqual -- presumably a tolerance; confirm there.
        self.assertEqual(b, c, 0)
        for loaded in c[:4]:
            self.assertIsInstance(loaded, torch.FloatTensor)
        self.assertIsInstance(c[4], torch.FloatStorage)

        # Entries that alias each other in ``b`` must still alias after loading.
        c[0].fill_(10)
        self.assertEqual(c[0], c[2], 0)
        self.assertEqual(c[4], torch.FloatStorage(25).fill_(10), 0)
        c[1].fill_(20)
        self.assertEqual(c[1], c[3], 0)
        self.assertEqual(c[4][1:4], c[5], 0)

        # check that serializing the same storage view object unpickles
        # it as one object not two (and vice versa)
        views = c[7]
        self.assertEqual(views[0]._cdata, views[1]._cdata)
        self.assertEqual(views[0], views[2])
        self.assertNotEqual(views[0]._cdata, views[2]._cdata)

        rootview = c[8]
        self.assertEqual(rootview.data_ptr(), c[0].data_ptr())