我们从Python开源项目中提取了以下8个代码示例，用于说明如何使用 train.train()。
def parse_args():
    """Parse the Deep360Pilot command-line arguments.

    ``--save`` takes a value (e.g. ``-s true``, ``-s False``, ``-s 1``) and is
    converted to a real ``bool``; without a converter the raw string was kept,
    so ``-s False`` was truthy. Unrecognised values are rejected by argparse.

    Returns:
        tuple: ``(args, parser)`` -- the parsed ``argparse.Namespace`` and the
        ``ArgumentParser`` that produced it.
    """
    def _str2bool(value):
        # Map common textual booleans to bool and reject anything else, so a
        # typo cannot silently flip the flag.
        lowered = str(value).strip().lower()
        if lowered in ('1', 'true', 't', 'yes', 'y', 'on'):
            return True
        if lowered in ('0', 'false', 'f', 'no', 'n', 'off'):
            return False
        raise argparse.ArgumentTypeError(
            'expected a boolean value, got {0!r}'.format(value))

    parser = argparse.ArgumentParser(description='Deep360Pilot')
    parser.add_argument('--opt', dest='opt_method',
                        help='[Adam, Adadelta, RMSProp]', default='Adam')
    parser.add_argument('--root', dest='root_path',
                        help='root path of data', default='./')
    parser.add_argument('--data', dest='data_path',
                        help='data path of data', default='./data/')
    parser.add_argument('--mode', dest='mode',
                        help='[train, test, vid, pred]', required=True)
    parser.add_argument('--model', dest='model_path',
                        help='model path to load')
    parser.add_argument('--gpu', dest='gpu',
                        help='Choose which gpu to use', default='0')
    parser.add_argument('-n', '--name', dest='video_name',
                        help='youtube_id + _ + part')
    parser.add_argument('-d', '--domain', dest='domain',
                        help='skate, skiing, ...', required=True)
    parser.add_argument('-l', '--lambda', dest='lam',
                        help='movement tradeoff lambda, the higher the smoother.',
                        type=float, required=True)
    parser.add_argument('-b', '--boxnum', dest='boxnum',
                        help='boxes number, Use integer, [8, 16, 32]',
                        type=int, required=True)
    parser.add_argument('-p', '--phase', dest='phase',
                        help='phase [classify, regress]', required=True)
    # The default (a non-string False) is not passed through `type`, so an
    # omitted -s still yields False.
    parser.add_argument('-s', '--save', dest='save',
                        help='save images for debug', default=False,
                        type=_str2bool)
    group = parser.add_mutually_exclusive_group()
    group.add_argument('--debug', dest='debug',
                       help='Start debug mode or not', action='store_true')
    args = parser.parse_args()
    return args, parser
def main():
    """Parse command-line options, seed PyTorch and run ``train.train``.

    ``--use_cuda`` is parsed as a real boolean. With ``type=bool`` argparse
    would call ``bool()`` on the raw string, so any non-empty value
    (including ``"False"``) meant True.
    """
    def _str2bool(value):
        # Strict textual-boolean converter; unknown values are an argparse error.
        lowered = str(value).strip().lower()
        if lowered in ('1', 'true', 't', 'yes', 'y', 'on'):
            return True
        if lowered in ('0', 'false', 'f', 'no', 'n', 'off'):
            return False
        raise argparse.ArgumentTypeError(
            'expected a boolean value, got {0!r}'.format(value))

    parser = argparse.ArgumentParser(description='PyTorch YOLO')
    parser.add_argument('--use_cuda', type=_str2bool, default=False, help='use cuda or not')
    parser.add_argument('--epochs', type=int, default=10, help='Epochs')
    parser.add_argument('--batch_size', type=int, default=1, help='Batch size')
    parser.add_argument('--lr', type=float, default=1e-3, help='Learning rate')
    parser.add_argument('--seed', type=int, default=1234, help='Random seed')
    args = parser.parse_args()

    torch.manual_seed(args.seed)
    # Enable cuDNN benchmark mode only when CUDA use was requested.
    torch.backends.cudnn.benchmark = args.use_cuda
    train.train(args)
def main(_):
    """Optionally train a DCGAN configured by ``FLAGS``, then sample from it.

    Args:
        _: Unused positional argument (kept for the runner's ``main(argv)``
            calling convention).

    Raises:
        IOError: if a checkpoint exists but ``dcgan.load()`` reports failure,
            or if no checkpoint exists while ``FLAGS.train`` is false
            (sampling needs trained weights).
    """
    def _ensure_dir(path):
        # Create *path* if missing. `exist_ok` is not used so this also runs
        # on Python 2, which the surrounding code targets.
        if not os.path.exists(path):
            os.makedirs(path)

    pp.pprint(FLAGS.__flags)

    # training/inference
    with tf.Session() as sess:
        dcgan = DCGAN(sess, FLAGS)

        # Path checks: make sure every output directory exists.
        _ensure_dir(FLAGS.checkpoint_dir)
        model_dir = dcgan.get_model_dir()
        _ensure_dir(os.path.join(FLAGS.log_dir, model_dir))
        _ensure_dir(os.path.join(FLAGS.sample_dir, model_dir))

        # Load a checkpoint if one is found.
        if dcgan.checkpoint_exists():
            print("Loading checkpoints...")
            if not dcgan.load():
                raise IOError("Could not read checkpoints from {0}!".format(
                    FLAGS.checkpoint_dir))
            print("success!")
        else:
            if not FLAGS.train:
                raise IOError("No checkpoints found but need for sampling!")
            print("No checkpoints found. Training from scratch.")
            # NOTE(review): load() is called even though no checkpoint exists;
            # presumably it initialises the model variables -- confirm.
            dcgan.load()

        # Train the DCGAN.
        if FLAGS.train:
            train(dcgan)

        # Inference / visualization.
        print("Generating samples...")
        inference.sample_images(dcgan)
        print("Generating visualizations of z...")
        inference.visualize_z(dcgan)
def main(_):
    """Build a multi-GPU DCGAN from ``FLAGS``, train or load it, then visualize.

    Args:
        _: Unused positional argument (kept for the runner's ``main(argv)``
            calling convention).

    Raises:
        ValueError: if ``FLAGS.dataset`` is ``'mnist'``, which this
            configuration rejects.
    """
    pp.pprint(flags.FLAGS.__flags)

    # Create output directories. `exist_ok` is not used so this also runs on
    # Python 2, which the surrounding code targets.
    for path in (FLAGS.checkpoint_dir, FLAGS.sample_dir):
        if not os.path.exists(path):
            os.makedirs(path)

    session_config = tf.ConfigProto(allow_soft_placement=True,
                                    log_device_placement=False)
    with tf.Session(config=session_config) as sess:
        if FLAGS.dataset == 'mnist':
            # Explicit raise rather than `assert False`: asserts are stripped
            # under `python -O`, which would let an unsupported dataset through.
            raise ValueError("The 'mnist' dataset is not supported here")

        dcgan = DCGAN(
            sess,
            image_size=FLAGS.image_size,
            batch_size=FLAGS.batch_size,
            sample_size=64,
            z_dim=8192,
            d_label_smooth=.25,
            generator_target_prob=.75 / 2.,
            out_stddev=.075,
            out_init_b=-.45,
            image_shape=[FLAGS.image_width, FLAGS.image_width, 3],
            dataset_name=FLAGS.dataset,
            is_crop=FLAGS.is_crop,
            checkpoint_dir=FLAGS.checkpoint_dir,
            sample_dir=FLAGS.sample_dir,
            generator=Generator(),
            train_func=train,
            discriminator_func=discriminator,
            build_model_func=build_model,
            config=FLAGS,
            devices=["gpu:0", "gpu:1", "gpu:2", "gpu:3"],  # , "gpu:4"]
        )

        if FLAGS.is_train:
            print("TRAINING")
            dcgan.train(FLAGS)
            print("DONE TRAINING")
        else:
            dcgan.load(FLAGS.checkpoint_dir)

        # NOTE(review): the meaning of visualization option 2 is defined by
        # visualize(), which is not shown here -- confirm before changing.
        OPTION = 2
        visualize(sess, dcgan, FLAGS, OPTION)
def main(job_id, params):
    """Log the hyper-parameters, then run one training job with them.

    Args:
        job_id: Not used in the body; kept for the ``main(job_id, params)``
            calling convention of whatever invokes this.
        params: Mapping of keyword arguments, pretty-printed to the log and
            then forwarded to ``train``.

    Returns:
        The value returned by ``train(**params)`` (presumably a validation
        error).
    """
    pretty_options = pprint.pformat(params)
    logger.info("Model options:\n{}".format(pretty_options))
    return train(**params)
def train_word2vec(corpus_file, out_file, **kwargs):
    """Forward *corpus_file*, *out_file* and any extra options to ``word2vec.train``.

    Thin wrapper: keyword arguments are passed through unchanged and the
    result of ``word2vec.train`` is discarded, so this returns ``None``.
    (Presumably *corpus_file* is the input corpus and *out_file* the output
    model path -- confirm against ``word2vec.train``.)
    """
    positional = (corpus_file, out_file)
    word2vec.train(*positional, **kwargs)
def train_model(db_file, entity_db_file, vocab_file, word2vec, **kwargs):
    """Load the training resources and hand them to ``train.train``.

    Args:
        db_file: Passed to ``AbstractDB`` together with mode ``'r'``.
        entity_db_file: Passed to ``EntityDB.load``.
        vocab_file: Passed to ``Vocab.load``.
        word2vec: Passed to ``ModelReader`` when truthy; a falsy value means
            no pre-trained vectors (``None`` is forwarded instead).
        **kwargs: Forwarded unchanged to ``train.train``.
    """
    # Built left to right, in the same order the resources are passed on.
    # NOTE(review): the AbstractDB handle is never explicitly closed here --
    # confirm whether it needs closing.
    resources = (
        AbstractDB(db_file, 'r'),
        EntityDB.load(entity_db_file),
        Vocab.load(vocab_file),
        ModelReader(word2vec) if word2vec else None,
    )
    train.train(*resources, **kwargs)
def main(outputName):
    """Load the treebank splits and grid-search RNTN hyper-parameters.

    One model is trained via ``train.train`` for every combination of the
    module-level collections ``miniBatchSize``, ``adagradResetNbIter``,
    ``learningRate`` and ``regularisationTerm``. ``nbEpoch`` is the same for
    every run.

    Args:
        outputName: Forwarded to ``train.train``; the start-up banner
            describes it as where recording happens.
    """
    # Local import keeps this script entry point self-contained.
    import itertools

    print("Welcome into RNTN implementation 0.6 (recording will be on ", outputName, ")")

    # Fixed seeds so runs can be replicated.
    random.seed("MetaMind")
    np.random.seed(7)

    print("Parsing dataset, creating dictionary...")
    vocabulary.initVocab()

    datasets = {}
    datasets['training'] = utils.loadDataset("trees/train.txt")
    print("Training loaded !")
    datasets['testing'] = utils.loadDataset("trees/test.txt")
    print("Testing loaded !")
    datasets['validating'] = utils.loadDataset("trees/dev.txt")
    print("Validation loaded !")
    print("Datasets loaded !")
    print("Nb of words", vocabulary.vocab.length())

    # No data transformation (normalisation, outlier removal, ...) is done here.

    # Grid search over the hyper-parameters. Full k-fold cross-validation
    # would take too long, so each configuration is just trained and tested.
    # itertools.product yields combinations in the same order as four nested
    # loops (miniBatchSize outermost, regularisationTerm innermost). It assumes
    # the collections can be iterated more than once (lists/tuples).
    grid = itertools.product(miniBatchSize, adagradResetNbIter,
                             learningRate, regularisationTerm)
    for mBS, aRNI, lR, rT in grid:
        params = {
            "nbEpoch": nbEpoch,
            "learningRate": lR,
            "regularisationTerm": rT,
            "adagradResetNbIter": aRNI,
            "miniBatchSize": mBS,
        }
        # No need to reset the vocabulary values (they live in model.L and are
        # reset automatically), nor the datasets (outputs are recomputed at
        # each iteration).
        model, errors = train.train(outputName, datasets, params)

    # TODO: Plot the cross-validation curve.
    # TODO: Plot a heat map of the hyper-parameter costs to help tune them.
    # TODO: For the final training only, validate the last model, e.g.
    #       model.computeError(datasets['validating'], True)

    print("The End. Thank you for using this program!")