Python model module: Discriminator() examples

The following 13 code examples, extracted from open-source Python projects, illustrate how to use model.Discriminator(). Each snippet is an excerpt and assumes its enclosing file's imports (typically numpy as np, tensorflow as tf or torch, plus project-local modules such as model).
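There is no single model.Discriminator() signature: each project defines its own class inside its local model module. As orientation, here is a minimal hypothetical sketch of the recurring pattern in the examples below (the constructor argument, load(), and discriminate() are assumptions that vary per project): construct the discriminator, restore weights where supported, score a batch, and filter rows by score.

# Minimal, hypothetical usage sketch; constructor arguments and the
# load()/discriminate() methods differ from project to project.
import numpy as np
from model import Discriminator

disc = Discriminator("runs/my_model")     # some projects pass a directory,
disc.load()                               # others a session plus image sizes
scores = disc.discriminate(np.zeros((8, 100)), batch_size=8)
keep = np.where(scores > 0.5)[0]          # common use: keep high-scoring rows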

Project: DeepWorks    Author: daigo0927    | project source | file source
def __init__(self,
                 z_dim, image_size,
                 lr_d, lr_g):

        self.sess = tf.Session()

        self.z_dim = z_dim
        self.image_size = image_size

        self.gen = GeneratorDeconv(input_size = z_dim,
                                   image_size = image_size)
        self.disc = Discriminator()

        self._build_graph(lr_d = lr_d, lr_g = lr_g)

        self.saver = tf.train.Saver()
        self.sess.run(tf.global_variables_initializer())
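A hedged instantiation sketch for this wrapper (the class name GAN and all argument values are assumptions; the snippet shows only its __init__):

gan = GAN(z_dim=100, image_size=64, lr_d=2e-4, lr_g=2e-4)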
Project: DeepWorks    Author: daigo0927    | project source | file source
def __init__(self,
                 label_size,
                 z_dim, image_size,
                 lr_d, lr_g):

        self.sess = tf.Session()

        self.label_size = label_size
        self.z_dim = z_dim
        self.image_size = image_size

        self.gen = GeneratorDeconv(input_size = z_dim+label_size,
                                   image_size = image_size)
        self.disc = Discriminator()

        self._build_graph(lr_d = lr_d, lr_g = lr_g)

        self.saver = tf.train.Saver()
        self.sess.run(tf.global_variables_initializer())
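The only difference from the unconditional wrapper above is that the generator consumes the noise vector concatenated with a one-hot class label, matching input_size = z_dim+label_size. A hedged sketch of building such an input (batch, z_dim, and label_size are hypothetical values):

# Hypothetical conditional generator input: noise plus a one-hot label.
import numpy as np
batch, z_dim, label_size = 32, 100, 10
z = np.random.normal(size=(batch, z_dim))
labels = np.eye(label_size)[np.random.randint(label_size, size=batch)]
gen_input = np.concatenate([z, labels], axis=1)   # width z_dim + label_size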
Project: latplan    Author: guicho271828    | project source | file source
def main():
    import sys
    if len(sys.argv) == 1:
        sys.exit("{} [directory]".format(sys.argv[0]))

    directory = sys.argv[1]
    sd = Discriminator("{}/_sd".format(directory)).load()
    ae = ConvolutionalGumbelAE2(directory).load()

    input = "generated_states.csv"
    print("loading {}".format("{}/{}".format(directory,input)), end='...', flush=True)
    states = np.loadtxt("{}/{}".format(directory,input),dtype=np.uint8)
    print("done.")
    zs      = states.view()
    total   = states.shape[0]
    N       = states.shape[1]
    batch   = 500000
    output = "generated_states2.csv"
    try:
        print(ae.local(output))
        with open(ae.local(output), 'wb') as f:
            print("original states:",total)
            for i in range(total//batch+1):
                _zs = zs[i*batch:(i+1)*batch]
                _result = sd.discriminate(_zs,batch_size=5000).round().astype(np.uint8)
                _zs_filtered = _zs[np.where(_result > 0)[0],:]
                print("reduced  states:",len(_zs_filtered),"/",len(_zs))

                _xs = ae.decode_binary(_zs_filtered[:20],batch_size=5000).round().astype(np.uint8)
                ae.plot(_xs,path="generated_states_filtered{}.png".format(i))

                np.savetxt(f,_zs_filtered,"%d",delimiter=" ")

    except KeyboardInterrupt:
        print("dump stopped")
Project: latplan    Author: guicho271828    | project source | file source
def main():
    import numpy.random as random
    from trace import trace

    import sys
    if len(sys.argv) == 1:
        sys.exit("{} [directory]".format(sys.argv[0]))

    directory = sys.argv[1]
    directory_ad = "{}_ad/".format(directory)
    print("loading the Discriminator", end='...', flush=True)
    ad = Discriminator(directory_ad).load()
    print("done.")
    name = "generated_actions.csv"

    print("loading {}".format("{}/generated_states.csv".format(directory)), end='...', flush=True)
    states  = np.loadtxt("{}/generated_states.csv".format(directory),dtype=np.uint8)
    print("done.")
    total   = states.shape[0]
    N       = states.shape[1]
    actions = np.pad(states,((0,0),(0,N)),"constant")  # each row: [generated state | slot for current state s, filled in the loop]

    acc = 0

    try:
        print(ad.local(name))
        with open(ad.local(name), 'wb') as f:
            for i,s in enumerate(states):
                print("Iteration {}/{} base: {}".format(i,total,i*total), end=' ')
                actions[:,N:] = s
                ys            = ad.discriminate(actions,batch_size=400000)
                valid_actions = actions[np.where(ys > 0.8)]
                acc           += len(valid_actions)
                print(len(valid_actions),acc)
                np.savetxt(f,valid_actions,"%d")
    except KeyboardInterrupt:
        print("dump stopped")
Project: latplan    Author: guicho271828    | project source | file source
def main():
    import numpy.random as random
    from trace import trace

    import sys
    if len(sys.argv) == 1:
        sys.exit("{} [directory]".format(sys.argv[0]))

    directory = sys.argv[1]
    directory_ad = "{}_ad/".format(directory)
    discriminator = Discriminator(directory_ad).load()
    name = "generated_actions.csv"

    N = discriminator.net.input_shape[1]
    lowbit  = 20
    highbit = N - lowbit
    print("batch size: {}".format(2**lowbit))

    xs   = (((np.arange(2**lowbit )[:,None] & (1 << np.arange(N)))) > 0).astype(int)
    # xs_h = (((np.arange(2**highbit)[:,None] & (1 << np.arange(highbit)))) > 0).astype(int)

    try:
        print(discriminator.local(name))
        with open(discriminator.local(name), 'wb') as f:
            for i in range(2**highbit):
                print("Iteration {}/{} base: {}".format(i,2**highbit,i*(2**lowbit)), end=' ')
                # h = np.binary_repr(i*(2**lowbit), width=N)
                # print(h)
                # xs_h = np.unpackbits(np.array([i*(2**lowbit)],dtype=int))
                xs_h = (((np.array([i])[:,None] & (1 << np.arange(highbit)))) > 0).astype(int)
                xs[:,lowbit:] = xs_h
                # print(xs_h)
                # print(xs[:10])
                ys = discriminator.discriminate(xs,batch_size=100000)
                ind = np.where(ys > 0.5)
                valid_xs = xs[ind]
                print(len(valid_xs))
                np.savetxt(f,valid_xs,"%d")
    except KeyboardInterrupt:
        print("dump stopped")
Project: latplan    Author: guicho271828    | project source | file source
def main():
    import numpy.random as random
    from trace import trace

    import sys
    if len(sys.argv) == 1:
        sys.exit("{} [directory]".format(sys.argv[0]))

    directory = sys.argv[1]
    directory_ad = "{}_ad/".format(directory)
    print("loading the Discriminator", end='...', flush=True)
    ad = Discriminator(directory_ad).load()
    print("done.")
    name = "generated_actions.csv"

    print("loading {}".format("{}/generated_states2.csv".format(directory)), end='...', flush=True)
    states  = np.loadtxt("{}/generated_states2.csv".format(directory),dtype=np.uint8)
    print("done.")
    total   = states.shape[0]
    N       = states.shape[1]
    actions = np.pad(states,((0,0),(0,N)),"constant")  # each row: [generated state | slot for current state s, filled in the loop]

    acc = 0

    try:
        print(ad.local(name))
        with open(ad.local(name), 'wb') as f:
            for i,s in enumerate(states):
                print("Iteration {}/{} base: {}".format(i,total,i*total), end=' ')
                actions[:,N:] = s
                ys            = ad.discriminate(actions,batch_size=400000)
                valid_actions = actions[np.where(ys > 0.8)]
                acc           += len(valid_actions)
                print(len(valid_actions),acc)
                np.savetxt(f,valid_actions,"%d")
    except KeyboardInterrupt:
        print("dump stopped")
Project: pytorch-tutorial    Author: yunjey    | project source | file source
def build_model(self):
        """Build generator and discriminator."""
        self.generator = Generator(z_dim=self.z_dim,
                                   image_size=self.image_size,
                                   conv_dim=self.g_conv_dim)
        self.discriminator = Discriminator(image_size=self.image_size,
                                           conv_dim=self.d_conv_dim)
        self.g_optimizer = optim.Adam(self.generator.parameters(),
                                      self.lr, [self.beta1, self.beta2])
        self.d_optimizer = optim.Adam(self.discriminator.parameters(),
                                      self.lr, [self.beta1, self.beta2])

        if torch.cuda.is_available():
            self.generator.cuda()
            self.discriminator.cuda()
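For context, one adversarial training step with the pair built above would typically look like the following hedged sketch. The data handling, logits convention, and noise sampling are assumptions, not the tutorial's own train() code:

# Hypothetical single GAN training step using the generator, discriminator,
# and optimizers created in build_model() above.
import torch
import torch.nn.functional as F

def train_step(solver, images):
    z = torch.randn(images.size(0), solver.z_dim)
    if torch.cuda.is_available():
        images, z = images.cuda(), z.cuda()

    # Discriminator update: push real scores toward 1, fake scores toward 0.
    d_real = solver.discriminator(images)
    d_fake = solver.discriminator(solver.generator(z).detach())
    d_loss = F.binary_cross_entropy_with_logits(d_real, torch.ones_like(d_real)) \
           + F.binary_cross_entropy_with_logits(d_fake, torch.zeros_like(d_fake))
    solver.d_optimizer.zero_grad()
    d_loss.backward()
    solver.d_optimizer.step()

    # Generator update: make generated images score as real.
    d_fake = solver.discriminator(solver.generator(z))
    g_loss = F.binary_cross_entropy_with_logits(d_fake, torch.ones_like(d_fake))
    solver.g_optimizer.zero_grad()
    g_loss.backward()
    solver.g_optimizer.step()
    return d_loss.item(), g_loss.item()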
Project: DeepLearning    Author: Wanwannodao    | project source | file source
def __init__(self, input_shape):
        self.batch_size = input_shape[0]

        self.D = Discriminator(self.batch_size)
        self.G = Generator(self.batch_size)


        self.X = tf.placeholder(shape=input_shape, dtype=tf.float32, name="X")

        self.gen_img = self.G()

        self.g_loss = 0.5*(tf.reduce_mean( (self.D(self.G(reuse=True)) - 1.0)**2 ))
        self.d_loss = 0.5*(tf.reduce_mean((self.D(self.X, reuse=True) - 1.0)**2 )\
                           + tf.reduce_mean( (self.D( self.G(reuse=True), reuse=True))**2 ) )

        g_opt = tf.train.AdamOptimizer(learning_rate=4e-3,beta1=0.5)
        d_opt = tf.train.AdamOptimizer(learning_rate=1e-3,beta1=0.5)

        g_grads_and_vars = g_opt.compute_gradients(self.g_loss)
        d_grads_and_vars = d_opt.compute_gradients(self.d_loss)

        g_grads_and_vars = [[grad, var] for grad, var in g_grads_and_vars \
                            if grad is not None and var.name.startswith("G")]
        d_grads_and_vars = [[grad, var] for grad, var in d_grads_and_vars \
                            if grad is not None and var.name.startswith("D")]

        self.g_train_op = g_opt.apply_gradients(g_grads_and_vars)
        self.d_train_op = d_opt.apply_gradients(d_grads_and_vars)
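The two objectives above are the least-squares GAN (LSGAN) losses: the discriminator pushes D(X) toward 1 and D(G(z)) toward 0, while the generator pushes D(G(z)) toward 1. A hedged driver sketch for this class (the class name LSGAN, num_steps, and next_batch are assumptions not shown in the snippet):

# Hypothetical training driver; G() appears to sample its own noise inside
# the graph, so only the real-image placeholder X needs feeding.
model = LSGAN(input_shape=[64, 32, 32, 3])
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for step in range(num_steps):
        images = next_batch(model.batch_size)    # real batch matching input_shape
        _, d_loss = sess.run([model.d_train_op, model.d_loss], {model.X: images})
        _, g_loss = sess.run([model.g_train_op, model.g_loss])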
Project: streetview    Author: ydnaandy123    | project source | file source
def main(_):
    if not os.path.exists(FLAGS.checkpoint_dir):
        os.makedirs(FLAGS.checkpoint_dir)
    if not os.path.exists(FLAGS.sample_dir):
        os.makedirs(FLAGS.sample_dir)

    # Do not take all memory
    gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.80)
    with tf.Session(config=tf.ConfigProto(gpu_options=gpu_options)) as sess:
        if FLAGS.dataset == 'cityscapes':
            print('Select CITYSCAPES')
            mask_dir = CITYSCAPES_mask_dir
            syn_dir = CITYSCAPES_syn_dir
            FLAGS.output_size_h, FLAGS.output_size_w, FLAGS.is_crop = 192, 512, False
            FLAGS.dataset_dir = CITYSCAPES_dir
        elif FLAGS.dataset == 'inria':
            print('Select INRIAPerson')
            FLAGS.output_size_h, FLAGS.output_size_w, FLAGS.is_crop = 160, 96, False
            FLAGS.dataset_dir = INRIA_dir
        elif FLAGS.dataset == 'indoor':
            print('Select indoor')
            syn_dir = CITYSCAPES_syn_dir
            FLAGS.output_size_h, FLAGS.output_size_w, FLAGS.is_crop = 256, 256, False
            FLAGS.dataset_dir = indoor_dir

        discriminator = Discriminator(sess, batch_size=FLAGS.batch_size,
                                      output_size_h=FLAGS.output_size_h,
                                      output_size_w=FLAGS.output_size_w,
                                      c_dim=FLAGS.c_dim,
                                      dataset_name=FLAGS.dataset,
                                      checkpoint_dir=FLAGS.checkpoint_dir,
                                      dataset_dir=FLAGS.dataset_dir)

        if FLAGS.mode == 'test':
            print('Testing!')
            discriminator.test(FLAGS, syn_dir)
        elif FLAGS.mode == 'train':
            print('Train!')
            discriminator.train(FLAGS, syn_dir)
        elif FLAGS.mode == 'complete':
            print('Complete!')
Project: latplan    Author: guicho271828    | project source | file source
def main():
    import numpy.random as random
    from trace import trace

    import sys
    if len(sys.argv) == 1:
        sys.exit("{} [directory]".format(sys.argv[0]))

    directory = sys.argv[1]
    directory_ad = "{}_ad/".format(directory)
    print("loading the Discriminator", end='...', flush=True)
    ad = Discriminator(directory_ad).load()
    print("done.")

    # valid_states  = load("{}/states.csv".format(directory))
    valid_actions = load("{}/actions.csv".format(directory))
    threshold = maxdiff(valid_actions)
    print("maxdiff:",threshold)

    states  = load("{}/generated_states.csv".format(directory))

    path = "{}/generated_actions.csv".format(directory)

    total   = states.shape[0]
    N       = states.shape[1]
    acc = 0

    try:
        print(path)
        with open(path, 'wb') as f:
            for i,s in enumerate(states):
                print("Iteration {}/{} base: {}".format(i,total,i*total), end=' ')
                diff = np.sum(np.abs(states - s),axis=1)      # Hamming distance from each state to s
                neighbors = states[np.where(diff<threshold)]  # keep only states within the observed max action diff
                tmp_actions = np.pad(neighbors,((0,0),(0,N)),"constant")
                tmp_actions[:,N:] = s
                ys            = ad.discriminate(tmp_actions,batch_size=400000)
                valid_actions = tmp_actions[np.where(ys > 0.8)]
                acc           += len(valid_actions)
                print(len(neighbors),len(valid_actions),acc)
                np.savetxt(f,valid_actions,"%d")
    except KeyboardInterrupt:
        print("dump stopped")
Project: latplan    Author: guicho271828    | project source | file source
def grid_search(path, train_in, train_out, test_in, test_out):
    # perform random trials on possible combinations
    network = Discriminator
    best_error = float('inf')
    best_params = None
    best_ae     = None
    results = []
    print("Network: {}".format(network))
    try:
        import itertools
        names  = [ k for k, _ in parameters.items()]
        values = [ v for _, v in parameters.items()]
        all_params = list(itertools.product(*values))
        random.shuffle(all_params)
        [ print(r) for r in all_params]
        for i,params in enumerate(all_params):
            config.reload_session()
            params_dict = { k:v for k,v in zip(names,params) }
            print("{}/{} Testing model with parameters=\n{}".format(i, len(all_params), params_dict))
            ae = learn_model(path, train_in,train_out,test_in,test_out,
                             network=curry(network, parameters=params_dict),
                             params_dict=params_dict)
            error = ae.net.evaluate(test_in,test_out,batch_size=100,verbose=0)
            results.append({'error':error, **params_dict})
            print("Evaluation result for:\n{}\nerror = {}".format(params_dict,error))
            print("Current results:")
            results.sort(key=lambda result: result['error'])
            [ print(r) for r in results]
            if error < best_error:
                print("Found a better parameter:\n{}\nerror:{} old-best:{}".format(
                    params_dict,error,best_error))
                del best_ae
                best_params = params_dict
                best_error = error
                best_ae = ae
            else:
                del ae
        print("Best parameter:\n{}\nerror: {}".format(best_params,best_error))
    finally:
        print(results)
    best_ae.save()
    with open(best_ae.local("grid_search.log"), 'a') as f:
        import json
        f.write("\n")
        json.dump(results, f)
    return best_ae,best_params,best_error
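grid_search expects a module-level parameters dict (read in the names/values lines above), mapping each hyperparameter name to its candidate values; itertools.product then enumerates every combination. A hypothetical example of its shape, with illustrative keys and values only:

# Hypothetical hyperparameter grid consumed by grid_search().
parameters = {
    'layer'      : [300, 1000],
    'dropout'    : [0.4, 0.7],
    'batch_size' : [1000],
    'lr'         : [0.001, 0.0001],
}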
Project: Conditional-GAN    Author: m516825    | project source | file source
def train(self):
        batch_num = (self.data.length + self.FLAGS.batch_size - 1) // self.FLAGS.batch_size  # ceil(length / batch_size)

        print("Start training WGAN...\n")

        for t in range(self.FLAGS.iter):

            d_cost = 0
            g_cost = 0

            for d_ep in range(self.d_epoch):

                img, tags, _, w_img, w_tags = self.data.next_data_batch(self.FLAGS.batch_size)
                z = self.data.next_noise_batch(len(tags), self.FLAGS.z_dim)

                feed_dict = {
                    self.seq:tags,
                    self.img:img,
                    self.z:z,
                    self.w_seq:w_tags,
                    self.w_img:w_img
                }

                _, loss = self.sess.run([self.d_updates, self.d_loss], feed_dict=feed_dict)

                d_cost += loss/self.d_epoch

            z = self.data.next_noise_batch(len(tags), self.FLAGS.z_dim)
            feed_dict = {
                self.img:img,
                self.w_seq:w_tags,
                self.w_img:w_img,
                self.seq:tags,
                self.z:z
            }

            _, loss, step = self.sess.run([self.g_updates, self.g_loss, self.global_step], feed_dict=feed_dict)

            current_step = tf.train.global_step(self.sess, self.global_step)

            g_cost = loss

            if current_step % self.FLAGS.display_every == 0:
                print("Epoch {}, Current_step {}".format(self.data.epoch, current_step))
                print("Discriminator loss :{}".format(d_cost))
                print("Generator loss     :{}".format(g_cost))
                print("---------------------------------")

            if current_step % self.FLAGS.checkpoint_every == 0:
                path = self.saver.save(self.sess, self.checkpoint_prefix, global_step=current_step)
                print ("\nSaved model checkpoint to {}\n".format(path))

            if current_step % self.FLAGS.dump_every == 0:
                self.eval(current_step)
                print("Dump test image")
Project: streetview    Author: ydnaandy123    | project source | file source
def main(_):

    pp.pprint(flags.FLAGS.__flags)

    if not os.path.exists(FLAGS.checkpoint_dir):
        os.makedirs(FLAGS.checkpoint_dir)
    if not os.path.exists(FLAGS.sample_dir):
        os.makedirs(FLAGS.sample_dir)

    # Do not take all memory
    gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.30)
    # sess = tf.Session(config=tf.ConfigProto(gpu_options=gpu_options))

    with tf.Session(config=tf.ConfigProto(gpu_options=gpu_options)) as sess:
        # w/ y label
        if FLAGS.dataset == 'mnist':
            dcgan = DCGAN(sess, image_size=FLAGS.image_size, batch_size=FLAGS.batch_size, y_dim=10, output_size=28,
                          c_dim=1, dataset_name=FLAGS.dataset,
                          checkpoint_dir=FLAGS.checkpoint_dir)
        # w/o y label
        else:
            if FLAGS.dataset == 'cityscapes':
                print('Select CITYSCAPES')
                mask_dir = CITYSCAPES_mask_dir
                syn_dir = CITYSCAPES_syn_dir_2
                FLAGS.output_size_h, FLAGS.output_size_w, FLAGS.is_crop = 192, 512, False
                FLAGS.dataset_dir = CITYSCAPES_dir
            elif FLAGS.dataset == 'inria':
                print('Select INRIAPerson')
                FLAGS.output_size_h, FLAGS.output_size_w, FLAGS.is_crop = 160, 96, False
                FLAGS.dataset_dir = INRIA_dir

            discriminator = Discriminator(sess, batch_size=FLAGS.batch_size,
                                          output_size_h=FLAGS.output_size_h,
                                          output_size_w=FLAGS.output_size_w,
                                          c_dim=FLAGS.c_dim,
                                          dataset_name=FLAGS.dataset,
                                          checkpoint_dir=FLAGS.checkpoint_dir,
                                          dataset_dir=FLAGS.dataset_dir)

        if FLAGS.mode == 'test':
            print('Testing!')
            discriminator.test(FLAGS, syn_dir)
        elif FLAGS.mode == 'train':
            print('Train!')
            discriminator.train(FLAGS, syn_dir)
        elif FLAGS.mode == 'complete':
            print('Complete!')