Python torchvision.datasets module: CIFAR10 example source code

We have extracted the following 10 code examples from open-source Python projects to illustrate how to use torchvision.datasets.CIFAR10.
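
Before the project examples, here is a minimal sketch of the pattern they all share: build a transform pipeline, construct datasets.CIFAR10, and wrap it in a DataLoader. The root path, batch size, and (0.5, 0.5, 0.5) normalization values are placeholders, not taken from any particular project below.

from torch.utils.data import DataLoader
from torchvision import datasets, transforms

# Placeholder normalization; several projects below use dataset-specific statistics instead.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

# download=True fetches CIFAR-10 into root if it is not already there.
train_set = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
test_set = datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)

train_loader = DataLoader(train_set, batch_size=64, shuffle=True, num_workers=2)
test_loader = DataLoader(test_set, batch_size=64, shuffle=False, num_workers=2)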

Project: convNet.pytorch    Author: eladhoffer
def get_dataset(name, split='train', transform=None,
                target_transform=None, download=True, datasets_path=__DATASETS_DEFAULT_PATH):
    train = (split == 'train')
    root = os.path.join(datasets_path, name)
    if name == 'cifar10':
        return datasets.CIFAR10(root=root,
                                train=train,
                                transform=transform,
                                target_transform=target_transform,
                                download=download)
    elif name == 'cifar100':
        return datasets.CIFAR100(root=root,
                                 train=train,
                                 transform=transform,
                                 target_transform=target_transform,
                                 download=download)
    elif name == 'mnist':
        return datasets.MNIST(root=root,
                              train=train,
                              transform=transform,
                              target_transform=target_transform,
                              download=download)
    elif name == 'stl10':
        return datasets.STL10(root=root,
                              split=split,
                              transform=transform,
                              target_transform=target_transform,
                              download=download)
    elif name == 'imagenet':
        if train:
            root = os.path.join(root, 'train')
        else:
            root = os.path.join(root, 'val')
        return datasets.ImageFolder(root=root,
                                    transform=transform,
                                    target_transform=target_transform)
Project: bigBatch    Author: eladhoffer
def get_dataset(name, split='train', transform=None,
                target_transform=None, download=True, datasets_path=__DATASETS_DEFAULT_PATH):
    train = (split == 'train')
    root = os.path.join(datasets_path, name)
    if name == 'cifar10':
        return datasets.CIFAR10(root=root,
                                train=train,
                                transform=transform,
                                target_transform=target_transform,
                                download=download)
    elif name == 'cifar100':
        return datasets.CIFAR100(root=root,
                                 train=train,
                                 transform=transform,
                                 target_transform=target_transform,
                                 download=download)
    elif name == 'mnist':
        return datasets.MNIST(root=root,
                              train=train,
                              transform=transform,
                              target_transform=target_transform,
                              download=download)
    elif name == 'stl10':
        return datasets.STL10(root=root,
                              split=split,
                              transform=transform,
                              target_transform=target_transform,
                              download=download)
    elif name == 'imagenet':
        if train:
            root = os.path.join(root, 'train')
        else:
            root = os.path.join(root, 'val')
        return datasets.ImageFolder(root=root,
                                    transform=transform,
                                    target_transform=target_transform)
Project: generative_zoo    Author: DL-IT
def CIFAR10_loader(root, image_size, normalize=True):
    """
        Load the CIFAR-10 torchvision dataset object for a given image size.
        Args:
            root        = Location of the dataset folder if it is already downloaded and ready to use; otherwise the dataset will be downloaded to this location
            image_size  = Size to which every image is rescaled
            normalize   = Whether to normalize the images. Default is True
    """
    transformations = [transforms.Scale(image_size), transforms.ToTensor()]
    if normalize:
        transformations.append(transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)))
    cifar10_data    = dset.CIFAR10(root=root, download=True, transform=transforms.Compose(transformations))
    return cifar10_data
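
A hedged usage sketch of the loader above; the root path, image size, and batch size are hypothetical, and it assumes the snippet's own imports (torchvision.datasets as dset, torchvision.transforms as transforms) plus torch:

from torch.utils.data import DataLoader

cifar10_data = CIFAR10_loader(root='./data', image_size=32, normalize=True)  # hypothetical root
loader = DataLoader(cifar10_data, batch_size=64, shuffle=True)
images, labels = next(iter(loader))  # images: (64, 3, 32, 32), scaled to [-1, 1] by Normalize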
Project: gan-error-avoidance    Author: aleju
def __init__(self, opt):
        transform_list = []

        if (opt.crop_height > 0) and (opt.crop_width > 0):
            # CenterCrop takes a single size or one (height, width) tuple
            transform_list.append(transforms.CenterCrop((opt.crop_height, opt.crop_width)))
        elif opt.crop_size > 0:
            transform_list.append(transforms.CenterCrop(opt.crop_size))

        transform_list.append(transforms.Scale(opt.image_size))
        transform_list.append(transforms.CenterCrop(opt.image_size))

        transform_list.append(transforms.ToTensor())

        if opt.dataset == 'cifar10':
            dataset1 = datasets.CIFAR10(root = opt.dataroot, download = True,
                transform = transforms.Compose(transform_list))
            dataset2 = datasets.CIFAR10(root = opt.dataroot, train = False,
                transform = transforms.Compose(transform_list))
            def get_data(k):
                if k < len(dataset1):
                    return dataset1[k][0]
                else:
                    return dataset2[k - len(dataset1)][0]
        else:
            if opt.dataset in ['imagenet', 'folder', 'lfw']:
                dataset = datasets.ImageFolder(root = opt.dataroot,
                    transform = transforms.Compose(transform_list))
            elif opt.dataset == 'lsun':
                dataset = datasets.LSUN(db_path = opt.dataroot, classes = [opt.lsun_class + '_train'],
                    transform = transforms.Compose(transform_list))
            def get_data(k):
                return dataset[k][0]

        data_index = torch.load(os.path.join(opt.dataroot, 'data_index.pt'))
        train_index = data_index['train']

        self.opt = opt
        self.get_data = get_data
        self.train_index = data_index['train']
        self.counter = 0
Project: pytorch-reverse-gan    Author: yxlao
def get_dataloader(opt):
    if opt.dataset in ['imagenet', 'folder', 'lfw']:
        # folder dataset
        dataset = dset.ImageFolder(root=opt.dataroot,
                                   transform=transforms.Compose([
                                       transforms.Scale(opt.imageScaleSize),
                                       transforms.CenterCrop(opt.imageSize),
                                       transforms.ToTensor(),
                                       transforms.Normalize((0.5, 0.5, 0.5),
                                                            (0.5, 0.5, 0.5)),
                                   ]))
    elif opt.dataset == 'lsun':
        dataset = dset.LSUN(db_path=opt.dataroot, classes=['bedroom_train'],
                            transform=transforms.Compose([
                                transforms.Scale(opt.imageScaleSize),
                                transforms.CenterCrop(opt.imageSize),
                                transforms.ToTensor(),
                                transforms.Normalize((0.5, 0.5, 0.5),
                                                     (0.5, 0.5, 0.5)),
                            ]))
    elif opt.dataset == 'cifar10':
        dataset = dset.CIFAR10(root=opt.dataroot, download=True,
                               transform=transforms.Compose([
                                   transforms.Scale(opt.imageSize),
                                   transforms.ToTensor(),
                                   transforms.Normalize((0.5, 0.5, 0.5),
                                                        (0.5, 0.5, 0.5)),
                               ])
                               )
    assert dataset
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=opt.batch_size,
                                             shuffle=True,
                                             num_workers=int(opt.workers))
    return dataloader
Project: pytorch-playground    Author: aaron-xichen
def get10(batch_size, data_root='/tmp/public_dataset/pytorch', train=True, val=True, **kwargs):
    data_root = os.path.expanduser(os.path.join(data_root, 'cifar10-data'))
    num_workers = kwargs.setdefault('num_workers', 1)
    kwargs.pop('input_size', None)
    print("Building CIFAR-10 data loader with {} workers".format(num_workers))
    ds = []
    if train:
        train_loader = torch.utils.data.DataLoader(
            datasets.CIFAR10(
                root=data_root, train=True, download=True,
                transform=transforms.Compose([
                    transforms.Pad(4),
                    transforms.RandomCrop(32),
                    transforms.RandomHorizontalFlip(),
                    transforms.ToTensor(),
                    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
                ])),
            batch_size=batch_size, shuffle=True, **kwargs)
        ds.append(train_loader)
    if val:
        test_loader = torch.utils.data.DataLoader(
            datasets.CIFAR10(
                root=data_root, train=False, download=True,
                transform=transforms.Compose([
                    transforms.ToTensor(),
                    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
                ])),
            batch_size=batch_size, shuffle=False, **kwargs)
        ds.append(test_loader)
    ds = ds[0] if len(ds) == 1 else ds
    return ds
Project: dawn-bench-models    Author: stanford-futuredata
def infer(dataset_dir, run_dir, output_file, start, end, repeat, log2,
          cpu, gpu, append, models):

    transform_test = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(MEAN, STD)
    ])

    testset = datasets.CIFAR10(root=dataset_dir, train=False, download=True,
                               transform=transform_test)
    models = models or os.listdir(run_dir)
    output_path = os.path.join(run_dir, output_file)
    assert not os.path.exists(output_path) or append
    for model in models:
        model_dir = os.path.join(run_dir, model)
        paths = glob(f"{model_dir}/*/checkpoint_best_model.t7")
        assert len(paths) > 0
        path = os.path.abspath(paths[0])

        print(f'Model: {model}')
        print(f'Path: {path}')

        if cpu:
            print('With CPU:')
            engine = PyTorchEngine(path, use_cuda=False, arch=model)
            infer_cifar10(testset, engine, start=start, end=end, log2=log2,
                          repeat=repeat, output=output_path)

        if gpu and torch.cuda.is_available():
            print('With GPU:')
            engine = PyTorchEngine(path, use_cuda=True, arch=model)
            # Warmup
            time_batch_size(testset, 1, engine.pred, engine.use_cuda, repeat=1)

            infer_cifar10(testset, engine, start=start, end=end, log2=log2,
                          repeat=repeat, output=output_path)
Project: DenseNet    Author: kevinzakka
def get_test_loader(data_dir,
                    name,
                    batch_size,
                    shuffle=True,
                    num_workers=4,
                    pin_memory=False):
    """
    Utility function for loading and returning a multi-process 
    test iterator over the CIFAR-10 dataset.

    If using CUDA, num_workers should be set to 1 and pin_memory to True.

    Params
    ------
    - data_dir: path directory to the dataset.
    - name: string specifying which dataset to load. Can be `cifar10`,
      or `cifar100`.
    - batch_size: how many samples per batch to load.
    - shuffle: whether to shuffle the dataset after every epoch.
    - num_workers: number of subprocesses to use when loading the dataset.
    - pin_memory: whether to copy tensors into CUDA pinned memory. Set it to
      True if using GPU.

    Returns
    -------
    - data_loader: test set iterator.
    """
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])

    # define transform
    transform = transforms.Compose([
        transforms.ToTensor(),
        normalize
    ])

    if name == 'cifar10':
        dataset = datasets.CIFAR10(root=data_dir, 
                                   train=False, 
                                   download=True,
                                   transform=transform)
    else:
        dataset = datasets.CIFAR100(root=data_dir, 
                                    train=False, 
                                    download=True,
                                    transform=transform)

    data_loader = torch.utils.data.DataLoader(dataset, 
                                              batch_size=batch_size, 
                                              shuffle=shuffle, 
                                              num_workers=num_workers,
                                              pin_memory=pin_memory)

    return data_loader
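
A hedged usage sketch of get_test_loader above; the data directory and batch size are hypothetical, and num_workers/pin_memory follow the docstring's CUDA advice:

test_loader = get_test_loader(data_dir='./data', name='cifar10', batch_size=128,
                              shuffle=False, num_workers=1, pin_memory=True)
for images, labels in test_loader:
    pass  # images are normalized with the ImageNet statistics defined above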
Project: MMD-GAN    Author: OctoberChang
def get_data(args, train_flag=True):
    transform = transforms.Compose([
        transforms.Scale(args.image_size),
        transforms.CenterCrop(args.image_size),
        transforms.ToTensor(),
        transforms.Normalize(
            (0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])

    if args.dataset in ['imagenet', 'folder', 'lfw']:
        dataset = dset.ImageFolder(root=args.dataroot,
                                   transform=transform)

    elif args.dataset == 'lsun':
        dataset = dset.LSUN(db_path=args.dataroot,
                            classes=['bedroom_train'],
                            transform=transform)

    elif args.dataset == 'cifar10':
        dataset = dset.CIFAR10(root=args.dataroot,
                               download=True,
                               train=train_flag,
                               transform=transform)

    elif args.dataset == 'cifar100':
        dataset = dset.CIFAR100(root=args.dataroot,
                                download=True,
                                train=train_flag,
                                transform=transform)

    elif args.dataset == 'mnist':
        dataset = dset.MNIST(root=args.dataroot,
                             download=True,
                             train=train_flag,
                             transform=transform)

    elif args.dataset == 'celeba':
        imdir = 'train' if train_flag else 'val'
        dataroot = os.path.join(args.dataroot, imdir)
        if args.image_size != 64:
            raise ValueError('the image size for the CelebA dataset needs to be 64!')

        dataset = FolderWithImages(root=dataroot,
                                   input_transform=transforms.Compose([
                                       ALICropAndScale(),
                                       transforms.ToTensor(),
                                       transforms.Normalize(
                                           (0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
                                   ]),
                                   target_transform=transforms.ToTensor()
                                   )
    else:
        raise ValueError("Unknown dataset %s" % (args.dataset))
    return dataset
Project: optnet    Author: locuslab
def get_loaders(args):
    kwargs = {'num_workers': 1, 'pin_memory': True} if args.cuda else {}
    if args.dataset == 'mnist':
        trainLoader = torch.utils.data.DataLoader(
            dset.MNIST('data/mnist', train=True, download=True,
                           transform=transforms.Compose([
                               transforms.ToTensor(),
                               transforms.Normalize((0.1307,), (0.3081,))
                           ])),
            batch_size=args.batchSz, shuffle=True, **kwargs)
        testLoader = torch.utils.data.DataLoader(
            dset.MNIST('data/mnist', train=False, transform=transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize((0.1307,), (0.3081,))
            ])),
            batch_size=args.batchSz, shuffle=False, **kwargs)
    elif args.dataset == 'cifar-10':
        normMean = [0.49139968, 0.48215827, 0.44653124]
        normStd = [0.24703233, 0.24348505, 0.26158768]
        normTransform = transforms.Normalize(normMean, normStd)

        trainTransform = transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normTransform
        ])
        testTransform = transforms.Compose([
            transforms.ToTensor(),
            normTransform
        ])

        trainLoader = DataLoader(
            dset.CIFAR10(root='data/cifar', train=True, download=True,
                        transform=trainTransform),
            batch_size=args.batchSz, shuffle=True, **kwargs)
        testLoader = DataLoader(
            dset.CIFAR10(root='data/cifar', train=False, download=True,
                        transform=testTransform),
            batch_size=args.batchSz, shuffle=False, **kwargs)
    else:
        assert(False)

    return trainLoader, testLoader