我们从 Python 开源项目中提取了以下 5 个代码示例，用于说明如何使用 `models.Generator()`。
def main():
    """Script entry point: load the input images and run ``models.Generator``.

    Builds a ``Config`` from the parsed command-line arguments, makes sure the
    output directory exists, builds the model named by ``config.model``, loads
    the original and style images, and finally calls ``Generator.generate``.
    """
    config = Config(parse_args())

    # Create the output directory; an already-existing directory is not an error.
    os.makedirs(config.output_dir, exist_ok=True)

    # Instantiate the network selected by config.model.
    model = models.generate_model(config.model)

    # The original image is always loaded with size [width, height]. The style
    # image gets the same size unless config.no_resize_style is set, in which
    # case None is passed (presumably "keep native size" -- confirm in load_image).
    original_img = load_image(config.original_image, [config.width, config.height])
    if config.no_resize_style:
        style_size = None
    else:
        style_size = [config.width, config.height]
    style_img = load_image(config.style_image, style_size)

    # Hand the model and both images to the generator and run it.
    generator = models.Generator(model, original_img, style_img, config)
    generator.generate(config)
def train(args):
    """Train a Wasserstein GAN on CIFAR-10 using the Chainer trainer API.

    Args:
        args: parsed command-line namespace. This function reads ``nz``
            (passed to ``GaussianNoiseGenerator``, presumably the latent
            size), ``batch_size``, ``epochs`` and ``gpu`` (passed as the
            updater's ``device``).
    """
    nz = args.nz
    batch_size = args.batch_size
    epochs = args.epochs
    gpu = args.gpu

    # get_cifar10(scale=2) yields pixels in [0, 2]; subtracting 1 maps them to
    # [-1, 1], the output range of a tanh generator.
    # (Named train_data so it does not shadow this function's own name.)
    train_data, _ = datasets.get_cifar10(withlabel=False, ndim=3, scale=2)
    train_data -= 1.0
    train_iter = iterators.SerialIterator(train_data, batch_size)

    # Noise batches for the generator; GaussianNoiseGenerator(0, 1, nz)
    # presumably samples N(0, 1) vectors of length nz.
    z_iter = RandomNoiseIterator(GaussianNoiseGenerator(0, 1, nz), batch_size)

    # RMSprop with lr=5e-5 for both networks, as in the WGAN paper.
    optimizer_generator = optimizers.RMSprop(lr=0.00005)
    optimizer_critic = optimizers.RMSprop(lr=0.00005)
    optimizer_generator.setup(Generator())
    optimizer_critic.setup(Critic())

    updater = WassersteinGANUpdater(
        iterator=train_iter,
        noise_iterator=z_iter,
        optimizer_generator=optimizer_generator,
        optimizer_critic=optimizer_critic,
        device=gpu,
    )

    trainer = training.Trainer(updater, stop_trigger=(epochs, 'epoch'))
    trainer.extend(extensions.ProgressBar())
    trainer.extend(extensions.LogReport(trigger=(1, 'iteration')))
    # Project-specific extension run once per epoch (presumably dumps samples).
    trainer.extend(GeneratorSample(), trigger=(1, 'epoch'))
    trainer.extend(extensions.PrintReport([
        'epoch', 'iteration', 'critic/loss', 'critic/loss/real',
        'critic/loss/fake', 'generator/loss',
    ]))
    trainer.run()
def __init__(self, nic, noc, ngf, ndf, beta=0.5, lamb=100, lr=1e-3, cuda=True, crayon=False):
    """Build the pix2pix generator/discriminator pair, their optimizers and losses.

    Args:
        nic: Number of input channels.
        noc: Number of output channels.
        ngf: Number of generator filters.
        ndf: Number of discriminator filters.
        beta: Adam ``beta1`` for both optimizers (``beta2`` is fixed at 0.999).
        lamb: Weight on the L1 term in the objective.
        lr: Adam learning rate for both optimizers.
        cuda: Stored as ``self.cuda``; presumably controls whether
            ``self.cudafy`` moves objects to the GPU.
        crayon: If true, log to a PyCrayon server at localhost:8889 under the
            experiment name ``'pix2pix'``.
    """
    self.cuda = cuda
    self.start_epoch = 0
    self.crayon = crayon

    if crayon:
        self.cc = CrayonClient(hostname="localhost", port=8889)
        try:
            self.logger = self.cc.create_experiment('pix2pix')
        except Exception:
            # Presumably the experiment already exists from an earlier run:
            # drop it and recreate it. Catch Exception rather than using a bare
            # except so KeyboardInterrupt/SystemExit still propagate.
            self.cc.remove_experiment('pix2pix')
            self.logger = self.cc.create_experiment('pix2pix')

    self.gen = self.cudafy(Generator(nic, noc, ngf))
    self.dis = self.cudafy(Discriminator(nic, noc, ndf))

    # Optimizers are built from the parameters of the already-cudafied
    # modules, so they need no device move of their own (torch.optim
    # optimizers have no .cuda() method).
    self.gen_optim = optim.Adam(
        self.gen.parameters(), lr=lr, betas=(beta, 0.999))
    self.dis_optim = optim.Adam(
        self.dis.parameters(), lr=lr, betas=(beta, 0.999))

    # Loss functions
    self.criterion_bce = nn.BCELoss()
    self.criterion_mse = nn.MSELoss()
    self.criterion_l1 = nn.L1Loss()
    self.lamb = lamb
def train(self, loader, c_epoch):
    """Run one training epoch of the pix2pix conditional GAN.

    For every batch:
    - The discriminator is updated first, using BCE on the (x, y) pair
      labelled 1 and on the (x, detached G(x)) pair labelled 0.
    - The generator is then updated, using BCE on (x, G(x)) labelled 1 plus
      ``self.lamb`` times the L1 distance between G(x) and y.

    Args:
        loader: iterable of batches supporting ``len()``. ``features[0]`` is
            fed to the generator and ``features[1]`` is the L1 target.
        c_epoch: current epoch number, used only in progress messages.
    """
    def _scalar(loss):
        # ``.data[0]`` raises IndexError on 0-dim tensors in newer PyTorch;
        # use ``.item()`` when available and fall back for old versions.
        try:
            return loss.item()
        except AttributeError:
            return loss.data[0]

    self.dis.train()
    self.gen.train()
    self.reset_gradients()

    max_idx = len(loader)
    for idx, features in enumerate(tqdm(loader)):
        orig_x = Variable(self.cudafy(features[0]))
        orig_y = Variable(self.cudafy(features[1]))

        # ---------------- Discriminator ----------------
        # NOTE(review): ``volatile`` is not an nn.Module flag in PyTorch; this
        # assignment only matters if Discriminator reads it itself.
        self.dis.volatile = False

        # Real pair, concatenated along dim 1, labelled 1.
        dis_real = self.dis(torch.cat((orig_x, orig_y), 1))
        real_labels = Variable(self.cudafy(torch.ones(dis_real.size())))
        dis_real_loss = self.criterion_bce(dis_real, real_labels)

        # Fake pair labelled 0; detach so this loss does not reach the generator.
        gen_y = self.gen(orig_x)
        dis_fake = self.dis(torch.cat((orig_x, gen_y.detach()), 1))
        fake_labels = Variable(self.cudafy(torch.zeros(dis_fake.size())))
        dis_fake_loss = self.criterion_bce(dis_fake, fake_labels)

        dis_loss = dis_real_loss + dis_fake_loss
        dis_loss.backward()
        self.dis_optim.step()
        self.reset_gradients()

        # ---------------- Generator ----------------
        self.dis.volatile = True
        # The generator is rewarded when the discriminator labels its output real.
        dis_gen = self.dis(torch.cat((orig_x, gen_y), 1))
        gen_labels = Variable(self.cudafy(torch.ones(dis_gen.size())))
        gen_loss = (self.criterion_bce(dis_gen, gen_labels)
                    + self.lamb * self.criterion_l1(gen_y, orig_y))
        gen_loss.backward()
        self.gen_optim.step()
        # gen_loss.backward() also wrote gradients into the discriminator's
        # parameters; clear them so they do not accumulate into the next
        # discriminator update (assumes reset_gradients zeros both networks).
        self.reset_gradients()

        # Pycrayon logging (optional)
        if self.crayon:
            self.logger.add_scalar_value('pix2pix_gen_loss', _scalar(gen_loss))
            self.logger.add_scalar_value('pix2pix_dis_loss', _scalar(dis_loss))

        if idx % 50 == 0:
            tqdm.write('Epoch: {} [{}/{}]\t'
                       'D Loss: {:.4f}\t'
                       'G Loss: {:.4f}'.format(
                           c_epoch, idx, max_idx,
                           _scalar(dis_loss), _scalar(gen_loss)
                       ))