Python torchvision 模块,datasets 子模块实例源码

我们从Python开源项目中,提取了以下13个代码示例,用于说明如何使用 torchvision.datasets 模块

项目:pyro    作者:uber    | 项目源码 | 文件源码
def get_data_loader(dataset_name,
                    batch_size=1,
                    dataset_transforms=None,
                    is_training_set=True,
                    shuffle=True):
    """Build a ``DataLoader`` for a dataset class looked up by name.

    Args:
        dataset_name: attribute name resolved on ``datasets``.
        batch_size: forwarded to ``DataLoader``.
        dataset_transforms: optional list of extra transforms, applied
            after ``transforms.ToTensor()``; a falsy value means none.
        is_training_set: forwarded to the dataset as ``train``.
        shuffle: forwarded to ``DataLoader``.

    Returns:
        A ``DataLoader`` over the dataset stored under ``DATA_DIR``
        (constructed with ``download=True``).
    """
    extra_steps = dataset_transforms if dataset_transforms else []
    # ToTensor always runs first; caller-supplied steps follow it.
    pipeline = transforms.Compose([transforms.ToTensor()] + extra_steps)

    dataset_cls = getattr(datasets, dataset_name)
    data = dataset_cls(
        root=DATA_DIR,
        train=is_training_set,
        transform=pipeline,
        download=True,
    )
    return DataLoader(data, batch_size=batch_size, shuffle=shuffle)
项目:attention-transfer    作者:szagoruyko    | 项目源码 | 文件源码
def create_dataset(opt, mode):
    """Wrap a ``datasets`` split in a torchnet ``TensorDataset`` with transforms.

    Args:
        opt: options object; reads ``dataset``, ``data_root`` and
            ``randomcrop_pad``.
        mode: truthy selects the training split and the augmenting
            transform; falsy selects the test split and plain conversion.

    Returns:
        The ``TensorDataset`` with the chosen transform attached to field 0.
    """
    def to_float32(img):
        return img.astype(np.float32)

    def move_last_axis_first(img):
        # presumably HWC -> CHW; confirm against the dataset's array layout
        return img.transpose(2, 0, 1).astype(np.float32)

    convert = tnt.transform.compose([
        to_float32,
        T.Normalize([125.3, 123.0, 113.9], [63.0, 62.1, 66.7]),
        move_last_axis_first,
        torch.from_numpy,
    ])

    # Augmentation runs first, then the shared conversion pipeline.
    train_transform = tnt.transform.compose([
        T.RandomHorizontalFlip(),
        T.Pad(opt.randomcrop_pad, cv2.BORDER_REFLECT),
        T.RandomCrop(32),
        convert,
    ])

    dataset_cls = getattr(datasets, opt.dataset)
    source = dataset_cls(opt.data_root, train=mode, download=True)

    if mode:
        split, image_transform = 'train', train_transform
    else:
        split, image_transform = 'test', convert

    wrapped = tnt.dataset.TensorDataset([
        getattr(source, split + '_data'),
        getattr(source, split + '_labels'),
    ])
    return wrapped.transform({0: image_transform})
项目:Rocket-Launching    作者:zhougr1993    | 项目源码 | 文件源码
def create_dataset(opt, mode):
    """Create the train or test torchnet dataset named by ``opt.dataset``.

    Args:
        opt: options object; reads ``dataset``, ``data_root`` and
            ``randomcrop_pad``.
        mode: truthy for the training split (flip, reflect-pad, random
            crop, then conversion); falsy for the test split (conversion
            only).

    Returns:
        A ``tnt.dataset.TensorDataset`` of (data, labels) whose field 0 is
        passed through the selected transform.
    """
    channel_mean = [125.3, 123.0, 113.9]
    channel_std = [63.0, 62.1, 66.7]

    convert_steps = [
        lambda img: img.astype(np.float32),
        T.Normalize(channel_mean, channel_std),
        lambda img: img.transpose(2, 0, 1).astype(np.float32),
        torch.from_numpy,
    ]
    convert = tnt.transform.compose(convert_steps)

    augment_steps = [
        T.RandomHorizontalFlip(),
        T.Pad(opt.randomcrop_pad, cv2.BORDER_REFLECT),
        T.RandomCrop(32),
        convert,
    ]
    train_transform = tnt.transform.compose(augment_steps)

    raw = getattr(datasets, opt.dataset)(
        opt.data_root, train=mode, download=True)

    # The source dataset exposes <split>_data / <split>_labels attributes.
    prefix = 'train' if mode else 'test'
    fields = [getattr(raw, prefix + suffix) for suffix in ('_data', '_labels')]
    wrapped = tnt.dataset.TensorDataset(fields)

    chosen = train_transform if mode else convert
    return wrapped.transform({0: chosen})
项目:Rocket-Launching    作者:zhougr1993    | 项目源码 | 文件源码
def create_dataset(opt, mode):
    """Return a transformed torchnet dataset for one split of ``opt.dataset``.

    Args:
        opt: options object; reads ``dataset``, ``data_root`` and
            ``randomcrop_pad``.
        mode: truthy -> training split with augmentation; falsy -> test
            split with conversion only.

    Returns:
        ``TensorDataset([data, labels])`` with the transform bound to
        field 0.
    """
    normalize = T.Normalize([125.3, 123.0, 113.9], [63.0, 62.1, 66.7])
    convert = tnt.transform.compose([
        lambda arr: arr.astype(np.float32),
        normalize,
        lambda arr: arr.transpose(2, 0, 1).astype(np.float32),
        torch.from_numpy,
    ])

    flip = T.RandomHorizontalFlip()
    pad = T.Pad(opt.randomcrop_pad, cv2.BORDER_REFLECT)
    crop = T.RandomCrop(32)
    train_transform = tnt.transform.compose([flip, pad, crop, convert])

    factory = getattr(datasets, opt.dataset)
    loaded = factory(opt.data_root, train=mode, download=True)

    smode = 'test'
    field_transform = convert
    if mode:
        smode = 'train'
        field_transform = train_transform

    data = getattr(loaded, smode + '_data')
    labels = getattr(loaded, smode + '_labels')
    return tnt.dataset.TensorDataset([data, labels]).transform(
        {0: field_transform})
项目:Deep-learning-with-cats    作者:AlexiaJM    | 项目源码 | 文件源码
def __init__(self, *datasets):
        """Store the given datasets so they can be indexed together."""
        # *datasets already arrives as a tuple; tuple() returns it unchanged.
        self.datasets = tuple(datasets)
项目:Deep-learning-with-cats    作者:AlexiaJM    | 项目源码 | 文件源码
def __getitem__(self, i):
        """Return a tuple holding item ``i`` of every stored dataset, in order."""
        items = [dataset[i] for dataset in self.datasets]
        return tuple(items)
项目:Deep-learning-with-cats    作者:AlexiaJM    | 项目源码 | 文件源码
def __len__(self):
        """Return the length of the shortest stored dataset.

        Indexing stays valid for every dataset only up to the shortest one.
        When no datasets are stored the length is 0; ``min()`` over an
        empty sequence would otherwise raise ``ValueError``.
        """
        lengths = [len(d) for d in self.datasets]
        return min(lengths) if lengths else 0
项目:gan-error-avoidance    作者:aleju    | 项目源码 | 文件源码
def __init__(self, opt):
        """Set up image access for the configured dataset and load the train index.

        Args:
            opt: options object. Reads ``crop_height``, ``crop_width``,
                ``crop_size``, ``image_size``, ``dataset``, ``dataroot``
                and, for ``'lsun'``, ``lsun_class``.

        Raises:
            ValueError: if ``opt.dataset`` is not one of ``'cifar10'``,
                ``'imagenet'``, ``'folder'``, ``'lfw'`` or ``'lsun'``.
        """
        transform_list = []

        # Optional initial crop: an explicit height/width pair wins over a
        # square crop_size.
        if (opt.crop_height > 0) and (opt.crop_width > 0):
            # CenterCrop takes a single size argument; a non-square crop is
            # passed as an (h, w) pair.
            transform_list.append(
                transforms.CenterCrop((opt.crop_height, opt.crop_width)))
        elif opt.crop_size > 0:
            transform_list.append(transforms.CenterCrop(opt.crop_size))

        transform_list.append(transforms.Scale(opt.image_size))
        transform_list.append(transforms.CenterCrop(opt.image_size))

        transform_list.append(transforms.ToTensor())

        transform = transforms.Compose(transform_list)

        if opt.dataset == 'cifar10':
            # Both CIFAR10 splits are exposed as one index space: indices
            # below len(dataset1) hit the first split, the rest the second.
            dataset1 = datasets.CIFAR10(root=opt.dataroot, download=True,
                                        transform=transform)
            dataset2 = datasets.CIFAR10(root=opt.dataroot, train=False,
                                        transform=transform)

            def get_data(k):
                if k < len(dataset1):
                    return dataset1[k][0]
                return dataset2[k - len(dataset1)][0]
        else:
            if opt.dataset in ['imagenet', 'folder', 'lfw']:
                dataset = datasets.ImageFolder(root=opt.dataroot,
                                               transform=transform)
            elif opt.dataset == 'lsun':
                dataset = datasets.LSUN(db_path=opt.dataroot,
                                        classes=[opt.lsun_class + '_train'],
                                        transform=transform)
            else:
                # Fail here instead of leaving `dataset` unbound until the
                # first get_data() call.
                raise ValueError('Unknown dataset: %r' % (opt.dataset,))

            def get_data(k):
                return dataset[k][0]

        data_index = torch.load(os.path.join(opt.dataroot, 'data_index.pt'))

        self.opt = opt
        self.get_data = get_data
        self.train_index = data_index['train']
        self.counter = 0
项目:scalingscattering    作者:edouardoyallon    | 项目源码 | 文件源码
def create_dataset(opt, mode, fold=0):
    """Build a torchnet dataset from a split-based ``datasets`` class.

    Args:
        opt: options object; reads ``dataset`` and ``randomcrop_pad``.
        mode: truthy selects the ``'train'`` split with augmentation; falsy
            selects the ``'test'`` split with plain conversion.
        fold: when training and ``fold > -1``, only the sample indices on
            line ``fold`` of ``./stl10_binary/fold_indices.txt`` are kept;
            a negative value keeps the whole training split.

    Returns:
        ``TensorDataset([images, labels])`` with the selected transform
        bound to field 0.

    Raises:
        IndexError: if ``fold`` exceeds the number of lines in the fold file.
    """
    convert = tnt.transform.compose([
        lambda x: x.astype(np.float32),
        lambda x: x / 255.0,
        lambda x: x.transpose(2, 0, 1).astype(np.float32),
        torch.from_numpy,
    ])

    train_transform = tnt.transform.compose([
        cvtransforms.RandomHorizontalFlip(),
        cvtransforms.Pad(opt.randomcrop_pad, cv2.BORDER_REFLECT),
        cvtransforms.RandomCrop(96),
        convert,
    ])

    smode = 'train' if mode else 'test'
    ds = getattr(datasets, opt.dataset)('.', split=smode, download=True)

    # Move axis 1 to the end (presumably NCHW -> NHWC for the cv-style
    # transforms; confirm against the dataset's array layout).
    images = ds.data.transpose(0, 2, 3, 1)
    labels = ds.labels

    if mode and fold > -1:
        # Each line holds the whitespace-separated indices of one fold.
        # A concrete list is required: a lazy map() object (Python 3) is
        # not a valid NumPy index. The context manager closes the file.
        with open('./stl10_binary/fold_indices.txt') as fold_file:
            fold_lines = fold_file.read().splitlines()
        folds_idx = [int(v) for v in fold_lines[fold].split()]
        images = images[folds_idx]
        labels = labels[folds_idx]

    ds = tnt.dataset.TensorDataset([images, labels.tolist()])
    return ds.transform({0: train_transform if mode else convert})
项目:generative_models    作者:j-min    | 项目源码 | 文件源码
def get_dataset(config):
    """Return the dataset instance selected by ``config.dataset``.

    Names listed below are resolved on ``datasets`` and constructed with
    ``root``, ``train``, ``download=True`` and ``base_transform(config)``;
    any other name is delegated to ``get_custom_dataset(config)``.

    NOTE(review): some listed classes may not accept a ``train`` keyword
    (e.g. split-based ones) -- confirm against the installed torchvision.
    """
    builtin_names = (
        'LSUN',
        'CocoCaptions',
        'CocoDetection',
        'CIFAR10',
        'CIFAR100',
        'FashionMNIST',
        'MNIST',
        'STL10',
        'SVHN',
        'PhotoTour',
        'SEMEION',
    )

    if config.dataset not in builtin_names:
        return get_custom_dataset(config)

    dataset_cls = getattr(datasets, config.dataset)
    return dataset_cls(
        root=config.dataset_dir,
        train=config.isTrain,
        download=True,
        transform=base_transform(config),
    )
项目:gan-error-avoidance    作者:aleju    | 项目源码 | 文件源码
def _generate_images(self, nb_batches, g_fp, r_idx, opt, show_info, queue):
        """Sample images from a generator checkpoint and put them on ``queue``.

        Args:
            nb_batches: number of batches of ``opt.batch_size`` codes to draw.
            g_fp: path of the generator state dict to load.
            r_idx: forwarded to the generator as ``n_execute_lis_layers``.
            opt: options object describing the generator architecture and
                the batch/code sizes.
            show_info: when truthy, print the generator after building it.
            queue: receives one pickled dict with keys ``"g_fp"`` and
                ``"images"`` (a list of uint8 arrays, one per image).
        """
        # Imports are kept at function scope, exactly as in the original.
        import torch
        import torch.nn as nn
        import torch.optim as optim
        import torchvision
        import torchvision.datasets as datasets
        import torchvision.transforms as transforms
        from torch.autograd import Variable

        generator = GeneratorLearnedInputSpace(
            opt.width, opt.height, opt.nfeature, opt.nlayer,
            opt.code_size, opt.norm,
            n_lis_layers=opt.r_iterations, upscaling=opt.g_upscaling)
        if show_info:
            print("G:", generator)
        generator.cuda()
        generator.load_state_dict(torch.load(g_fp))
        generator.train()

        print("Generating images for checkpoint G'%s'..." % (g_fp,))
        collected = []
        for _batch in range(nb_batches):
            noise = torch.randn(opt.batch_size, opt.code_size).cuda()
            latent = Variable(noise, volatile=True)

            batch, _ = generator(latent, n_execute_lis_layers=r_idx)

            # Scale to byte range and move axis 1 to the end (presumably
            # NCHW -> NHWC; assumes outputs lie in [0, 1] -- TODO confirm).
            scaled = batch.data.cpu().numpy() * 255
            as_bytes = scaled.astype(np.uint8)
            collected.extend(as_bytes.transpose((0, 2, 3, 1)))

        payload = {
            "g_fp": g_fp,
            "images": collected
        }
        queue.put(pickle.dumps(payload, protocol=-1))
项目:scalingscattering    作者:edouardoyallon    | 项目源码 | 文件源码
def get_iterator(mode, opt):
    """Create a ``DataLoader`` over an ImageNet-style folder tree.

    Args:
        mode: truthy -> ``<imagenetpath>/train`` with random-sized crop and
            horizontal flip, shuffled; falsy -> ``<imagenetpath>/val`` with
            scale and center crop, not shuffled.
        opt: options object; reads ``imagenetpath``, ``N`` (crop size),
            ``batchSize``, ``nthread`` and, when training, ``max_samples``.

    Returns:
        A ``torch.utils.data.DataLoader`` over the selected image folder.

    Raises:
        RuntimeError: if ``imagenetpath``, ``N``, ``batchSize`` or
            ``nthread`` is ``None``.
    """
    if opt.imagenetpath is None:
        raise RuntimeError('Where is imagenet?')
    if opt.N is None:
        raise RuntimeError('Crop size not provided')
    if opt.batchSize is None:
        raise RuntimeError('Batch Size not provided ')
    if opt.nthread is None:
        raise RuntimeError('num threads?')

    def cvload(path):
        # OpenCV decodes to BGR channel order; convert to RGB.
        img = cv2.imread(path, cv2.IMREAD_COLOR)
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        return img

    convert = tnt.transform.compose([
        lambda x: x.astype(np.float32) / 255.0,
        cvtransforms.Normalize(mean=[0.485, 0.456, 0.406],
                               std=[0.229, 0.224, 0.225]),
        lambda x: x.transpose(2, 0, 1).astype(np.float32),
        torch.from_numpy,
    ])

    print("| setting up data loader...")
    if mode:
        traindir = os.path.join(opt.imagenetpath, 'train')
        train_transform = tnt.transform.compose([
            cvtransforms.RandomSizedCrop(opt.N),
            cvtransforms.RandomHorizontalFlip(),
            convert,
        ])
        if opt.max_samples > 0:
            # Project-local ImageFolder variant that accepts a sample cap.
            ds = datasmall.ImageFolder(traindir, train_transform,
                                       loader=cvload,
                                       maxSamp=opt.max_samples)
        else:
            ds = torchvision.datasets.ImageFolder(traindir, train_transform,
                                                  loader=cvload)
    else:
        # Keep the 256:224 scale-to-crop ratio. Floor division keeps the
        # size an integer on Python 3 as well and gives exactly 256 when
        # N == 224.
        crop_scale = 256 * opt.N // 224

        valdir = os.path.join(opt.imagenetpath, 'val')
        ds = torchvision.datasets.ImageFolder(valdir, tnt.transform.compose([
            cvtransforms.Scale(crop_scale),
            cvtransforms.CenterCrop(opt.N),
            convert,
        ]), loader=cvload)

    return torch.utils.data.DataLoader(ds,
                                       batch_size=opt.batchSize, shuffle=mode,
                                       num_workers=opt.nthread, pin_memory=False)
项目:scalingscattering    作者:edouardoyallon    | 项目源码 | 文件源码
def create_dataset(opt, mode):
    """Build a torchnet dataset, subsampling the training split per class.

    Args:
        opt: options object; reads ``dataset``, ``randomcrop_pad`` and, when
            training, ``seed`` and ``sampleSize``.
        mode: truthy -> training split reduced to ``opt.sampleSize // 10``
            samples per class, with augmentation; falsy -> full test split
            with plain conversion.

    Returns:
        ``TensorDataset([images, labels])`` with the selected transform
        bound to field 0.

    Raises:
        AssertionError: if ``opt.sampleSize`` is not a multiple of 10.
    """
    convert = tnt.transform.compose([
        lambda x: x.astype(np.float32),
        lambda x: x / 255.0,
        lambda x: x.transpose(2, 0, 1).astype(np.float32),
        torch.from_numpy,
    ])

    train_transform = tnt.transform.compose([
        cvtransforms.RandomHorizontalFlip(),
        cvtransforms.Pad(opt.randomcrop_pad, cv2.BORDER_REFLECT),
        cvtransforms.RandomCrop(32),
        convert,
    ])

    ds = getattr(datasets, opt.dataset)('.', train=mode, download=True)
    smode = 'train' if mode else 'test'
    if mode:
        from numpy.random import RandomState
        prng = RandomState(opt.seed)

        assert(opt.sampleSize % 10 == 0)

        # Floor division: a slice bound must be an integer, and `/` yields
        # a float on Python 3.
        # Assumes 10 classes with 5000 training samples each -- TODO
        # confirm for datasets other than the one this was written for.
        per_class = opt.sampleSize // 10
        random_permute = prng.permutation(np.arange(0, 5000))[0:per_class]

        labels = np.array(getattr(ds, 'train_labels'))
        data = getattr(ds, 'train_data')

        classes = np.unique(labels)
        inds_all = np.array([], dtype='int32')
        for cl in classes:
            # The same within-class positions are drawn for every class.
            inds = np.where(labels == cl)[0][random_permute]
            inds_all = np.r_[inds, inds_all]

        # NOTE(review): the transpose assumes channel-first storage;
        # confirm the layout of the dataset's arrays.
        ds = tnt.dataset.TensorDataset([
            data[inds_all, :].transpose(0, 2, 3, 1),
            labels[inds_all].tolist()])
    else:
        ds = tnt.dataset.TensorDataset([
            getattr(ds, smode + '_data').transpose(0, 2, 3, 1),
            getattr(ds, smode + '_labels')])
    return ds.transform({0: train_transform if mode else convert})