Python train 模块,train() 实例源码

我们从 Python 开源项目中提取了以下 8 个代码示例，用于说明如何使用 train.train()。

项目:Deep360Pilot-CVPR17    作者:eborboihuc    | 项目源码 | 文件源码
def parse_args():
    """Parse the Deep360Pilot command-line arguments.

    ``--mode``, ``--domain``, ``--lambda``, ``--boxnum`` and ``--phase`` are
    required; every other option has a default (or defaults to ``None``).

    Returns:
        A ``(args, parser)`` tuple: the parsed ``argparse.Namespace`` and the
        ``ArgumentParser`` that produced it.
    """

    def _to_flag(value):
        # Without a converter, ``--save False`` would be stored as the
        # non-empty (and therefore truthy) string 'False'.  Map the usual
        # "off" spellings to False; every other value stays truthy, exactly
        # as it was when the raw string was stored.
        return str(value).strip().lower() not in ('', '0', 'false', 'no', 'off')

    parser = argparse.ArgumentParser(description='Deep360Pilot')
    parser.add_argument('--opt', dest='opt_method', help='[Adam, Adadelta, RMSProp]', default='Adam')
    parser.add_argument('--root', dest='root_path', help='root path of data', default='./')
    parser.add_argument('--data', dest='data_path', help='data path of data', default='./data/')
    parser.add_argument('--mode', dest='mode', help='[train, test, vid, pred]', required=True)
    parser.add_argument('--model', dest='model_path', help='model path to load')
    parser.add_argument('--gpu', dest='gpu', help='Choose which gpu to use', default='0')
    parser.add_argument('-n', '--name', dest='video_name', help='youtube_id + _ + part')
    parser.add_argument('-d', '--domain', dest='domain', help='skate, skiing, ...', required=True)
    parser.add_argument('-l', '--lambda', dest='lam', help='movement tradeoff lambda, the higher the smoother.', type=float, required=True)
    parser.add_argument('-b', '--boxnum', dest='boxnum', help='boxes number, Use integer, [8, 16, 32]', type=int, required=True)
    parser.add_argument('-p', '--phase', dest='phase', help='phase [classify, regress]', required=True)
    # The option still takes a value (e.g. ``-s True``), so existing command
    # lines keep working; the value is now converted to a real bool.
    parser.add_argument('-s', '--save', dest='save', help='save images for debug', type=_to_flag, default=False)

    group = parser.add_mutually_exclusive_group()
    group.add_argument('--debug', dest='debug', help='Start debug mode or not', action='store_true')

    args = parser.parse_args()

    return args, parser
项目:yolo-pytorch    作者:makora9143    | 项目源码 | 文件源码
def main():
    """Parse the command-line options, seed PyTorch and start training.

    Options: ``--use_cuda`` (boolean, default False), ``--epochs``,
    ``--batch_size``, ``--lr`` and ``--seed``.  The parsed namespace is passed
    unchanged to ``train.train``.
    """

    def _str_to_bool(value):
        # ``type=bool`` would call bool() on the raw string, so every
        # non-empty value -- including "False" and "0" -- enabled CUDA.
        text = value.strip().lower()
        if text in ('1', 'true', 't', 'yes', 'y', 'on'):
            return True
        if text in ('', '0', 'false', 'f', 'no', 'n', 'off'):
            return False
        raise argparse.ArgumentTypeError(
            'expected a boolean value (true/false), got {!r}'.format(value))

    parser = argparse.ArgumentParser(description='PyTorch YOLO')

    parser.add_argument('--use_cuda', type=_str_to_bool, default=False,
                        help='use cuda or not')
    parser.add_argument('--epochs', type=int, default=10,
                        help='Epochs')
    parser.add_argument('--batch_size', type=int, default=1,
                        help='Batch size')
    parser.add_argument('--lr', type=float, default=1e-3,
                        help='Learning rate')
    parser.add_argument('--seed', type=int, default=1234,
                        help='Random seed')

    args = parser.parse_args()

    torch.manual_seed(args.seed)
    # The cuDNN benchmark switch simply follows the --use_cuda option.
    torch.backends.cudnn.benchmark = args.use_cuda
    train.train(args)
项目:dcgan-tfslim    作者:mqtlam    | 项目源码 | 文件源码
def main(_):
    """Build a DCGAN in a TensorFlow session, optionally train it, then sample.

    The single positional argument is ignored.  All configuration is read
    from the module-level ``FLAGS`` object.

    Raises:
        IOError: if a checkpoint exists but ``dcgan.load()`` reports failure,
            or if no checkpoint exists and ``FLAGS.train`` is false.
    """
    pp.pprint(FLAGS.__flags)

    # training/inference
    with tf.Session() as sess:
        dcgan = DCGAN(sess, FLAGS)

        # path checks: create the checkpoint, log and sample directories
        log_dir = os.path.join(FLAGS.log_dir, dcgan.get_model_dir())
        sample_dir = os.path.join(FLAGS.sample_dir, dcgan.get_model_dir())
        for directory in (FLAGS.checkpoint_dir, log_dir, sample_dir):
            if not os.path.exists(directory):
                os.makedirs(directory)

        # load checkpoint if found
        # print() with a single string argument behaves the same on Python 2
        # and Python 3, unlike the print statement.
        if dcgan.checkpoint_exists():
            print("Loading checkpoints...")
            if dcgan.load():
                print("success!")
            else:
                raise IOError("Could not read checkpoints from {0}!".format(
                    FLAGS.checkpoint_dir))
        else:
            if not FLAGS.train:
                raise IOError("No checkpoints found but need for sampling!")
            print("No checkpoints found. Training from scratch.")
            # NOTE(review): load() is still called although no checkpoint
            # exists; presumably it also performs initialisation -- confirm
            # against DCGAN.load.
            dcgan.load()

        # train DCGAN
        if FLAGS.train:
            train(dcgan)

        # inference/visualization code goes here
        print("Generating samples...")
        inference.sample_images(dcgan)
        print("Generating visualizations of z...")
        inference.visualize_z(dcgan)
项目:streetview    作者:ydnaandy123    | 项目源码 | 文件源码
def main(_):
    """Create the output directories, build a DCGAN, then train or load it.

    The single positional argument is ignored; configuration comes from the
    module-level ``FLAGS``.  After training (``FLAGS.is_train``) or loading a
    checkpoint, ``visualize`` is called with option 2.

    Raises:
        AssertionError: if ``FLAGS.dataset`` is ``'mnist'``.
    """
    pp.pprint(flags.FLAGS.__flags)

    for directory in (FLAGS.checkpoint_dir, FLAGS.sample_dir):
        if not os.path.exists(directory):
            os.makedirs(directory)

    session_config = tf.ConfigProto(
        allow_soft_placement=True, log_device_placement=False)
    with tf.Session(config=session_config) as sess:
        if FLAGS.dataset == 'mnist':
            # Raised explicitly (same exception type as the former
            # ``assert False``) so the guard is not stripped under ``-O``.
            raise AssertionError(
                "dataset 'mnist' is not supported by this entry point")
        dcgan = DCGAN(sess, image_size=FLAGS.image_size, batch_size=FLAGS.batch_size,
                      sample_size=64,
                      z_dim=8192,
                      d_label_smooth=.25,
                      generator_target_prob=.75 / 2.,
                      out_stddev=.075,
                      out_init_b=-.45,
                      # image_width is used for both dimensions, as in the original call
                      image_shape=[FLAGS.image_width, FLAGS.image_width, 3],
                      dataset_name=FLAGS.dataset, is_crop=FLAGS.is_crop,
                      checkpoint_dir=FLAGS.checkpoint_dir,
                      sample_dir=FLAGS.sample_dir,
                      generator=Generator(),
                      train_func=train, discriminator_func=discriminator,
                      build_model_func=build_model, config=FLAGS,
                      devices=["gpu:0", "gpu:1", "gpu:2", "gpu:3"]  # , "gpu:4"]
                      )

        # print() with a single string argument behaves the same on Python 2
        # and Python 3, unlike the print statement.
        if FLAGS.is_train:
            print("TRAINING")
            dcgan.train(FLAGS)
            print("DONE TRAINING")
        else:
            dcgan.load(FLAGS.checkpoint_dir)

        OPTION = 2
        visualize(sess, dcgan, FLAGS, OPTION)
项目:rmn    作者:orhanf    | 项目源码 | 文件源码
def main(job_id, params):
    """Log the model options and run a single training job.

    Args:
        job_id: Not used by this function.
        params: Mapping of keyword arguments forwarded to ``train``.

    Returns:
        The value returned by ``train(**params)``.
    """
    options_dump = pprint.pformat(params)
    logger.info("Model options:\n{}".format(options_dump))
    return train(**params)
项目:ntee    作者:studio-ousia    | 项目源码 | 文件源码
def train_word2vec(corpus_file, out_file, **kwargs):
    """Forward all arguments unchanged to ``word2vec.train``.

    Args:
        corpus_file: Passed through as the first positional argument.
        out_file: Passed through as the second positional argument.
        **kwargs: Extra keyword arguments passed through as-is.
    """
    word2vec.train(corpus_file, out_file, **kwargs)
项目:ntee    作者:studio-ousia    | 项目源码 | 文件源码
def train_model(db_file, entity_db_file, vocab_file, word2vec, **kwargs):
    """Load the training inputs from disk and hand them to ``train.train``.

    Args:
        db_file: Path opened read-only through ``AbstractDB``.
        entity_db_file: Path given to ``EntityDB.load``.
        vocab_file: Path given to ``Vocab.load``.
        word2vec: Argument for ``ModelReader``; when falsy, no reader is
            created and ``None`` is passed to the trainer instead.
        **kwargs: Extra keyword arguments forwarded to ``train.train``.
    """
    abstract_db = AbstractDB(db_file, 'r')
    entities = EntityDB.load(entity_db_file)
    vocabulary = Vocab.load(vocab_file)

    # Optional pre-trained model reader
    pretrained = ModelReader(word2vec) if word2vec else None

    train.train(abstract_db, entities, vocabulary, pretrained, **kwargs)
项目:SentimentAnalysis    作者:Conchylicultor    | 项目源码 | 文件源码
def main(outputName):
    """Load the tree datasets and train once per hyperparameter combination.

    Seeds both ``random`` and ``numpy.random``, initialises the vocabulary,
    loads the training/testing/validating datasets, then calls
    ``train.train`` for every combination of the module-level
    ``miniBatchSize``, ``adagradResetNbIter``, ``learningRate`` and
    ``regularisationTerm`` collections.

    Args:
        outputName: Printed in the welcome message and passed to
            ``train.train`` as its first argument.
    """
    print("Welcome into RNTN implementation 0.6 (recording will be on ", outputName, ")")

    # Fixed seeds so that runs can be replicated
    random.seed("MetaMind")
    np.random.seed(7)

    print("Parsing dataset, creating dictionary...")
    # Dictionary initialisation
    vocabulary.initVocab()

    # Loading dataset: (key, tree file, message printed once loaded)
    splits = (
        ('training', "trees/train.txt", "Training loaded !"),
        ('testing', "trees/test.txt", "Testing loaded !"),
        ('validating', "trees/dev.txt", "Validation loaded !"),
    )
    datasets = {}
    for split_name, tree_file, loaded_message in splits:
        datasets[split_name] = utils.loadDataset(tree_file)
        print(loaded_message)

    print("Datasets loaded !")
    print("Nb of words", vocabulary.vocab.length())

    # Data transform (normalisation, outlier removal, ...) is not done here

    # Grid search on the hyperparameters (a complete k-fold cross validation
    # would take too long, so this is a plain train/test run per combination)
    for batch_size in miniBatchSize:
        for reset_iterations in adagradResetNbIter:
            for rate in learningRate:
                for reg_term in regularisationTerm:
                    params = {
                        "nbEpoch": nbEpoch,
                        "learningRate": rate,
                        "regularisationTerm": reg_term,
                        "adagradResetNbIter": reset_iterations,
                        "miniBatchSize": batch_size,
                    }
                    # No need to reset the vocabulary values (contained in model.L so automatically reset)
                    # Same for the training and testing set (output values recomputed at each iteration)
                    model, errors = train.train(outputName, datasets, params)

    # TODO: Plot the cross-validation curve
    # TODO: Plot a heat map of the hyperparameters cost to help tuning them ?

    ## Validate on the last computed model (Only used for final training)
    #print("Training complete, validating...")
    #vaError = model.computeError(datasets['validating'], True)
    #print("Validation error: ", vaError)

    print("The End. Thank you for using this program!")