我们从Python开源项目中,提取了以下4个代码示例,用于说明如何使用torchvision.datasets.LSUN。
def LSUN_loader(root, image_size, classes=['bedroom'], normalize=True):
    """Build a torchvision LSUN dataset restricted to the given classes.

    Args:
        root: If your dataset is downloaded and ready to use, the location
            of that folder; else, the dataset will be downloaded to this
            location.
        image_size: Target side length; images are scaled then
            center-cropped to ``image_size`` x ``image_size``.
        classes: LSUN scene categories. Default is ``['bedroom']``; other
            available classes are 'bridge', 'church_outdoor', 'classroom',
            'conference_room', 'dining_room', 'kitchen', 'living_room',
            'restaurant' and 'tower'.
        normalize: Whether to normalize images to [-1, 1]. Default True.

    Returns:
        A ``dset.LSUN`` dataset instance.
    """
    transformations = [transforms.Scale(image_size),
                       transforms.CenterCrop(image_size),
                       transforms.ToTensor()]
    if normalize:
        transformations.append(
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)))
    # Bug fix: the original looped `for c in classes: c = c + '_train'`,
    # which only rebound the loop variable and never changed `classes`,
    # so LSUN was built without the required '<class>_train' suffix.
    train_classes = [c + '_train' for c in classes]
    lsun_data = dset.LSUN(db_path=root, classes=train_classes,
                          transform=transforms.Compose(transformations))
    return lsun_data
def __init__(self, opt):
    """Set up the dataset(s) named by ``opt.dataset`` and an index-based
    ``get_data`` accessor, plus the precomputed train index.

    Args:
        opt: Options namespace; reads ``crop_height``, ``crop_width``,
            ``crop_size``, ``image_size``, ``dataset``, ``dataroot`` and,
            for LSUN, ``lsun_class``.

    Side effects:
        Loads ``data_index.pt`` from ``opt.dataroot`` via ``torch.load``.
    """
    transform_list = []
    if (opt.crop_height > 0) and (opt.crop_width > 0):
        # Bug fix: the original passed `transforms.CenterCrop(opt.crop_height,
        # crop_width)` — `crop_width` was an undefined name (NameError), and
        # CenterCrop takes a single size or an (h, w) tuple, not two args.
        transform_list.append(
            transforms.CenterCrop((opt.crop_height, opt.crop_width)))
    elif opt.crop_size > 0:
        transform_list.append(transforms.CenterCrop(opt.crop_size))
    transform_list.append(transforms.Scale(opt.image_size))
    transform_list.append(transforms.CenterCrop(opt.image_size))
    transform_list.append(transforms.ToTensor())

    if opt.dataset == 'cifar10':
        dataset1 = datasets.CIFAR10(root=opt.dataroot, download=True,
                                    transform=transforms.Compose(transform_list))
        dataset2 = datasets.CIFAR10(root=opt.dataroot, train=False,
                                    transform=transforms.Compose(transform_list))

        def get_data(k):
            # Index the train split first, then continue into the test split.
            if k < len(dataset1):
                return dataset1[k][0]
            return dataset2[k - len(dataset1)][0]
    else:
        if opt.dataset in ['imagenet', 'folder', 'lfw']:
            dataset = datasets.ImageFolder(
                root=opt.dataroot,
                transform=transforms.Compose(transform_list))
        elif opt.dataset == 'lsun':
            dataset = datasets.LSUN(
                db_path=opt.dataroot,
                classes=[opt.lsun_class + '_train'],
                transform=transforms.Compose(transform_list))

        def get_data(k):
            return dataset[k][0]

    data_index = torch.load(os.path.join(opt.dataroot, 'data_index.pt'))
    self.opt = opt
    self.get_data = get_data
    # (Removed the dead local `train_index`; the value is kept on self.)
    self.train_index = data_index['train']
    self.counter = 0
def get_dataloader(opt):
    """Create a shuffled DataLoader for the dataset named by ``opt.dataset``.

    Args:
        opt: Options namespace; reads ``dataset``, ``dataroot``,
            ``imageScaleSize``, ``imageSize``, ``batch_size`` and
            ``workers``.

    Returns:
        A ``torch.utils.data.DataLoader`` over the selected dataset.

    Raises:
        ValueError: If ``opt.dataset`` is not one of the supported names.
    """
    if opt.dataset in ['imagenet', 'folder', 'lfw']:
        # folder dataset
        dataset = dset.ImageFolder(
            root=opt.dataroot,
            transform=transforms.Compose([
                transforms.Scale(opt.imageScaleSize),
                transforms.CenterCrop(opt.imageSize),
                transforms.ToTensor(),
                transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
            ]))
    elif opt.dataset == 'lsun':
        dataset = dset.LSUN(
            db_path=opt.dataroot, classes=['bedroom_train'],
            transform=transforms.Compose([
                transforms.Scale(opt.imageScaleSize),
                transforms.CenterCrop(opt.imageSize),
                transforms.ToTensor(),
                transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
            ]))
    elif opt.dataset == 'cifar10':
        dataset = dset.CIFAR10(
            root=opt.dataroot, download=True,
            transform=transforms.Compose([
                transforms.Scale(opt.imageSize),
                transforms.ToTensor(),
                transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
            ]))
    else:
        # Bug fix: the original fell through to `assert dataset`, which
        # raised NameError (`dataset` unbound) for unknown names and is
        # stripped entirely under `python -O`. Raise an explicit error
        # instead, matching the style of get_data() below.
        raise ValueError("Unknown dataset %s" % (opt.dataset))
    dataloader = torch.utils.data.DataLoader(dataset,
                                             batch_size=opt.batch_size,
                                             shuffle=True,
                                             num_workers=int(opt.workers))
    return dataloader
def get_data(args, train_flag=True):
    """Build and return the dataset named by ``args.dataset``.

    Args:
        args: Options namespace; reads ``dataset``, ``dataroot`` and
            ``image_size``.
        train_flag: Select the training split where the dataset has one
            (cifar10 / cifar100 / mnist / celeba). Default True.

    Returns:
        A dataset object ready to be wrapped in a DataLoader.

    Raises:
        ValueError: For an unsupported dataset name, or for 'celeba'
            when ``args.image_size`` is not 64.
    """
    # Shared pipeline: scale, square center-crop, tensor, map to [-1, 1].
    common = transforms.Compose([
        transforms.Scale(args.image_size),
        transforms.CenterCrop(args.image_size),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])

    name = args.dataset
    if name in ['imagenet', 'folder', 'lfw']:
        return dset.ImageFolder(root=args.dataroot, transform=common)
    if name == 'lsun':
        return dset.LSUN(db_path=args.dataroot,
                         classes=['bedroom_train'], transform=common)
    if name == 'cifar10':
        return dset.CIFAR10(root=args.dataroot, download=True,
                            train=train_flag, transform=common)
    if name == 'cifar100':
        return dset.CIFAR100(root=args.dataroot, download=True,
                             train=train_flag, transform=common)
    if name == 'mnist':
        return dset.MNIST(root=args.dataroot, download=True,
                          train=train_flag, transform=common)
    if name == 'celeba':
        split_dir = 'train' if train_flag else 'val'
        celeba_root = os.path.join(args.dataroot, split_dir)
        if args.image_size != 64:
            raise ValueError('the image size for CelebA dataset need to be 64!')
        return FolderWithImages(
            root=celeba_root,
            input_transform=transforms.Compose([
                ALICropAndScale(),
                transforms.ToTensor(),
                transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
            ]),
            target_transform=transforms.ToTensor())
    raise ValueError("Unknown dataset %s" % (args.dataset))