Python utils module: load_model() example source code

The following 9 code examples, extracted from open-source Python projects, illustrate how to use utils.load_model().
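
Note that utils here is each project's own helper module, not a package from PyPI, so the exact signature of load_model() varies between projects. As a minimal illustration of the common pattern only (a hypothetical helper; the torch.load call is an assumption, not code from any project below):

import torch

def load_model(path):
    # deserialize a checkpoint previously written with torch.save(model, path)
    return torch.load(path)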

Project: mxnet_tk1 | Author: starimpact | Project source | File source
def main():
    # load the pretrained model, decompose its convolution layers, and save the result
    model = utils.load_model(args)
    new_model = conv_vh_decomposition(model, args)
    new_model.save(args.save_model)
Project: mxnet_tk1 | Author: starimpact | Project source | File source
def main():
    # load the pretrained model, decompose its fully connected layers, and save the result
    model = utils.load_model(args)
    new_model = fc_decomposition(model, args)
    new_model.save(args.save_model)
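
utils.load_model(args) itself is not shown in the two snippets above. For an MXNet project of this era, a plausible sketch would wrap the classic checkpoint loader; the attribute names args.model and args.load_epoch are guesses, not taken from the repository:

import mxnet as mx

def load_model(args):
    # load a network saved as '<prefix>-symbol.json' plus '<prefix>-<epoch>.params'
    return mx.model.FeedForward.load(args.model, args.load_epoch)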
Project: end-to-end-negotiator | Author: facebookresearch | Project source | File source
def main():
    parser = argparse.ArgumentParser(description='chat utility')
    parser.add_argument('--model_file', type=str,
        help='model file')
    parser.add_argument('--domain', type=str, default='object_division',
        help='domain for the dialogue')
    parser.add_argument('--context_file', type=str, default='',
        help='context file')
    parser.add_argument('--temperature', type=float, default=1.0,
        help='temperature')
    parser.add_argument('--num_types', type=int, default=3,
        help='number of object types')
    parser.add_argument('--num_objects', type=int, default=6,
        help='total number of objects')
    parser.add_argument('--max_score', type=int, default=10,
        help='max score per object')
    parser.add_argument('--score_threshold', type=int, default=6,
        help='a dialogue is successful if its score exceeds score_threshold')
    parser.add_argument('--seed', type=int, default=1,
        help='random seed')
    parser.add_argument('--smart_ai', action='store_true', default=False,
        help='make AI smart again')
    parser.add_argument('--ai_starts', action='store_true', default=False,
        help='allow AI to start the dialog')
    parser.add_argument('--ref_text', type=str,
        help='file with the reference text')
    args = parser.parse_args()

    utils.set_seed(args.seed)

    human = HumanAgent(domain.get_domain(args.domain))

    alice_ty = LstmRolloutAgent if args.smart_ai else LstmAgent
    ai = alice_ty(utils.load_model(args.model_file), args)

    agents = [ai, human] if args.ai_starts else [human, ai]

    dialog = Dialog(agents, args)
    logger = DialogLogger(verbose=True)
    # either take manually produced contexts, or rely on the ones from the dataset
    if args.context_file == '':
        ctx_gen = ManualContextGenerator(args.num_types, args.num_objects, args.max_score)
    else:
        ctx_gen = ContextGenerator(args.context_file)

    chat = Chat(dialog, ctx_gen, logger)
    chat.run()
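
For reference, a chat session with this script could be launched roughly as follows; the script name (chat.py) and the model filename are illustrative, not taken from the repository:

python chat.py --model_file sv_model.th --smart_ai --ai_starts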
Project: end-to-end-negotiator | Author: facebookresearch | Project source | File source
def main():
    parser = argparse.ArgumentParser(description='selfplaying script')
    parser.add_argument('--alice_model_file', type=str,
        help='Alice model file')
    parser.add_argument('--bob_model_file', type=str,
        help='Bob model file')
    parser.add_argument('--context_file', type=str,
        help='context file')
    parser.add_argument('--temperature', type=float, default=1.0,
        help='temperature')
    parser.add_argument('--verbose', action='store_true', default=False,
        help='print out conversations')
    parser.add_argument('--seed', type=int, default=1,
        help='random seed')
    parser.add_argument('--score_threshold', type=int, default=6,
        help='a dialogue is successful if its score exceeds score_threshold')
    parser.add_argument('--max_turns', type=int, default=20,
        help='maximum number of turns in a dialog')
    parser.add_argument('--log_file', type=str, default='',
        help='log successful dialogs to file for training')
    parser.add_argument('--smart_alice', action='store_true', default=False,
        help='make Alice smart again')
    parser.add_argument('--fast_rollout', action='store_true', default=False,
        help='to use faster rollouts')
    parser.add_argument('--rollout_bsz', type=int, default=100,
        help='rollout batch size')
    parser.add_argument('--rollout_count_threshold', type=int, default=3,
        help='rollout count threshold')
    parser.add_argument('--smart_bob', action='store_true', default=False,
        help='make Bob smart again')
    parser.add_argument('--ref_text', type=str,
        help='file with the reference text')
    parser.add_argument('--domain', type=str, default='object_division',
        help='domain for the dialogue')
    args = parser.parse_args()

    utils.set_seed(args.seed)

    alice_model = utils.load_model(args.alice_model_file)
    alice_ty = get_agent_type(alice_model, args.smart_alice, args.fast_rollout)
    alice = alice_ty(alice_model, args, name='Alice')

    bob_model = utils.load_model(args.bob_model_file)
    bob_ty = get_agent_type(bob_model, args.smart_bob, args.fast_rollout)
    bob = bob_ty(bob_model, args, name='Bob')

    dialog = Dialog([alice, bob], args)
    logger = DialogLogger(verbose=args.verbose, log_file=args.log_file)
    ctx_gen = ContextGenerator(args.context_file)

    selfplay = SelfPlay(dialog, ctx_gen, args, logger)
    selfplay.run()
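
An illustrative self-play run (script and file names are assumptions, not taken from the repository):

python selfplay.py --alice_model_file alice.th --bob_model_file bob.th \
    --context_file data/negotiate/selfplay.txt --verbose --log_file selfplay.log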
Project: end-to-end-negotiator | Author: facebookresearch | Project source | File source
def main():
    parser = argparse.ArgumentParser(description='Negotiator')
    parser.add_argument('--dataset', type=str, default='./data/negotiate/val.txt',
        help='location of the dataset')
    parser.add_argument('--model_file', type=str,
        help='model file')
    parser.add_argument('--smart_ai', action='store_true', default=False,
        help='to use rollouts')
    parser.add_argument('--seed', type=int, default=1,
        help='random seed')
    parser.add_argument('--temperature', type=float, default=1.0,
        help='temperature')
    parser.add_argument('--domain', type=str, default='object_division',
        help='domain for the dialogue')
    parser.add_argument('--log_file', type=str, default='',
        help='log file')
    args = parser.parse_args()

    utils.set_seed(args.seed)

    model = utils.load_model(args.model_file)
    ai = LstmAgent(model, args)
    logger = DialogLogger(verbose=True, log_file=args.log_file)
    domain = get_domain(args.domain)

    score_func = rollout if args.smart_ai else likelihood

    dataset, sents = read_dataset(args.dataset)
    ranks, n, k = 0, 0, 0
    for ctx, dialog in dataset:
        start_time = time.time()
        # start new conversation
        ai.feed_context(ctx)
        for sent, you in dialog:
            if you:
                # on your turn, compute the rank of the ground-truth sentence among the candidates
                rank = compute_rank(sent, sents, ai, domain, args.temperature, score_func)
                # compute lang_h for the groundtruth sentence
                enc = ai._encode(sent, ai.model.word_dict)
                _, ai.lang_h, lang_hs = ai.model.score_sent(enc, ai.lang_h, ai.ctx_h, args.temperature)
                # save hidden states and the utterance
                ai.lang_hs.append(lang_hs)
                ai.words.append(ai.model.word2var('YOU:'))
                ai.words.append(Variable(enc))
                ranks += rank
                n += 1
            else:
                ai.read(sent)
        k += 1
        time_elapsed = time.time() - start_time
        logger.dump('dialogue %d | avg rank %.3f | raw %d/%d | time %.3f' % (k, 1. * ranks / n, ranks, n, time_elapsed))

    logger.dump('final avg rank %.3f' % (1. * ranks / n))
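
compute_rank, likelihood and rollout are defined elsewhere in the repository. A minimal sketch of the ranking step, assuming score_func(ai, domain, sent, temperature) returns a higher value for more likely sentences (the argument order is a guess):

def compute_rank(sent, sents, ai, domain, temperature, score_func):
    # rank of the ground-truth sentence among all candidates: count how
    # many candidates score strictly higher than it
    target_score = score_func(ai, domain, sent, temperature)
    return sum(1 for cand in sents
               if score_func(ai, domain, cand, temperature) > target_score)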
Project: end-to-end-negotiator | Author: facebookresearch | Project source | File source
def main():
    parser = argparse.ArgumentParser(description='testing script')
    parser.add_argument('--data', type=str, default='data/negotiate',
        help='location of the data corpus')
    parser.add_argument('--unk_threshold', type=int, default=20,
        help='minimum word frequency to be in dictionary')
    parser.add_argument('--model_file', type=str,
        help='pretrained model file')
    parser.add_argument('--seed', type=int, default=1,
        help='random seed')
    parser.add_argument('--hierarchical', action='store_true', default=False,
        help='use hierarchical model')
    parser.add_argument('--bsz', type=int, default=16,
        help='batch size')
    parser.add_argument('--cuda', action='store_true', default=False,
        help='use CUDA')
    args = parser.parse_args()

    device_id = utils.use_cuda(args.cuda)
    utils.set_seed(args.seed)

    corpus = data.WordCorpus(args.data, freq_cutoff=args.unk_threshold, verbose=True)
    model = utils.load_model(args.model_file)

    crit = Criterion(model.word_dict, device_id=device_id)
    sel_crit = Criterion(model.item_dict, device_id=device_id,
        bad_toks=['<disconnect>', '<disagree>'])

    testset, testset_stats = corpus.test_dataset(args.bsz, device_id=device_id)
    test_loss, test_select_loss = 0, 0

    N = len(corpus.word_dict)
    for batch in testset:
        # run forward on the batch, produces output, hidden, target,
        # selection output and selection target
        out, hid, tgt, sel_out, sel_tgt = Engine.forward(model, batch, volatile=False)

        # compute LM and selection losses
        test_loss += tgt.size(0) * crit(out.view(-1, N), tgt).data[0]
        test_select_loss += sel_crit(sel_out, sel_tgt).data[0]

    test_loss /= testset_stats['nonpadn']
    test_select_loss /= len(testset)
    print('testloss %.3f | testppl %.3f' % (test_loss, np.exp(test_loss)))
    print('testselectloss %.3f | testselectppl %.3f' % (test_select_loss, np.exp(test_select_loss)))
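
Since test_loss is an average negative log-likelihood per token (in nats), the printed perplexity is simply its exponential. A quick sanity check of that relationship:

import numpy as np

test_loss = 2.3026        # average NLL per token, roughly ln(10)
ppl = np.exp(test_loss)   # ~10.0: as uncertain as a uniform choice among 10 tokens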
Project: python-lstm-punctuation | Author: kaituoxu | Project source | File source
def _train(net, training_data, validation_data, model_name, learning_rate,
           max_epochs, min_improvement):
    min_learning_rate = 1e-6
    best_validation_ppl = np.inf
    divide = False

    for epoch in range(1, max_epochs + 1):
        epoch_start = time()

        print "\n======= EPOCH %s =======" % epoch
        print "\tLearning rate is %s" % learning_rate

        train_ppl = _process_corpus(net, training_data, mode='train',
                                    learning_rate=learning_rate)
        print "\tTrain PPL is %.3f" % train_ppl

        validation_ppl = _process_corpus(net, validation_data, mode='test')
        print "\tValidation PPL is %.3f" % validation_ppl

        print "\tTime taken: %ds" % (time() - epoch_start)

        # newbob-style schedule: require the validation log-PPL to shrink by at
        # least a factor of min_improvement; otherwise begin halving the learning rate
        if np.log(validation_ppl) * min_improvement > np.log(best_validation_ppl):
            if not divide:
                divide = True
                print "\tStarting to reduce the learning rate..."
                if validation_ppl > best_validation_ppl:
                    print "\tLoading best model."
                    net = utils.load_model("../out/" + model_name)
            else:
                if validation_ppl < best_validation_ppl:
                    print "\tSaving model."
                    net.save("../out/" + model_name, final=True)
                break
        else:
            print "\tNew best model! Saving..."
            best_validation_ppl = validation_ppl
            final = learning_rate / 2. < min_learning_rate or epoch == max_epochs
            net.save("../out/" + model_name, final)

        if divide:
            learning_rate /= 2.

        if learning_rate < min_learning_rate:
            break

    print "-"*30
    print "Finished training."
    print "Best validation PPL is %.3f\n\n" % best_validation_ppl
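
To make the stopping rule above concrete: with min_improvement = 1.003 (a common value for this kind of schedule; the number is illustrative, not from the project), a drop in validation PPL from 100 to 99.8 already counts as too small an improvement, so learning-rate halving begins:

import numpy as np

# log(99.8) * 1.003 ~ 4.617 > log(100) ~ 4.605, so the condition fires
assert np.log(99.8) * 1.003 > np.log(100.0)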
Project: kaggle-quora-solution-8th | Author: qqgeogor | Project source | File source
def train():
    log.info('loading dataset...')
    train_data = TextIterator(train_file, n_batch=batch_size, maxlen=maxlen)
    valid_data = TextIterator(valid_file, n_batch=batch_size, maxlen=maxlen)
    test_data = TextIterator(test_file, n_batch=batch_size, maxlen=maxlen, mode=2)
    log.info('building models....')
    model = RCNNModel(n_input=n_input, n_vocab=VOCABULARY_SIZE, n_hidden=n_hidden,
                      cell='gru', optimizer=optimizer, dropout=dropout,
                      sim=sim, maxlen=maxlen, batch_size=batch_size)
    start=time.time()

    if os.path.isfile(model_dir):
        print 'loading checkpoint parameters....', model_dir
        model = load_model(model_dir, model)
    if goto_line != 0:
        train_data.goto_line(goto_line)
        print 'goto line:', goto_line

    log.info('training start...')
    for epoch in xrange(NEPOCH):
        costs=0
        idx=0
        error_rate_list=[]
        try:
            for (x,xmask),(y,ymask),label in train_data:
                idx+=1
                if x.shape[-1]!=batch_size:
                    continue
                cost,error_rate=model.train(x,xmask,y,ymask,label,lr)
                #print cost,error_rate
                #projected_output,cost= model.test(x, xmask, y, ymask,label)
                #print "projected_output shape:", projected_output.shape
                ##print "cnn_output shape:",cnn_output.shape
                #print "cost:",cost
                costs+=cost
                error_rate_list.append(error_rate)
                if np.isnan(cost) or np.isinf(cost):
                    print 'NaN or Inf detected!'
                    print x.shape, y.shape
                    print cost, error_rate
                    return -1
                if idx % disp_freq == 0:
                    log.info('epoch: %d, idx: %d, cost: %.3f, Accuracy: %.3f' % (
                        epoch, idx, costs / idx,
                        np.mean(list(itertools.chain.from_iterable(error_rate_list)))))

                if idx%dump_freq==0:
                    save_model('./model/parameters_%.2f.pkl'%(time.time()-start),model)
        except Exception as e:
            # report the failing batch instead of swallowing the error silently
            print 'exception:', e
            print np.max(x), np.max(y)
            print x.shape, y.shape

        evaluate(train_data, valid_data, test_data, model)

    log.info("Finished. Time = " +str(time.time()-start))
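
save_model and load_model come from the project's own utilities and are not shown. A minimal Python 2 sketch consistent with the call signatures above, assuming the Theano model exposes its shared variables as model.params (an assumption):

import cPickle as pickle

def save_model(path, model):
    # dump the current parameter values to disk
    with open(path, 'wb') as f:
        pickle.dump([p.get_value() for p in model.params], f, protocol=2)

def load_model(path, model):
    # restore parameter values into an already-constructed model
    with open(path, 'rb') as f:
        for p, value in zip(model.params, pickle.load(f)):
            p.set_value(value)
    return model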
Project: kaggle-quora-solution-8th | Author: qqgeogor | Project source | File source
def test():
    log.info('loading dataset...')

    log.info('building models....')
    model = RCNNModel(n_input=n_input, n_vocab=VOCABULARY_SIZE, n_hidden=n_hidden,
                      cell='gru', optimizer=optimizer, dropout=dropout,
                      sim=sim, maxlen=maxlen, batch_size=batch_size)
    log.info('testing start....')
    start=time.time()

    if os.path.isfile(model_dir):
        print 'loading checkpoint parameters....', model_dir
        model = load_model(model_dir, model)

    for epoch in xrange(NEPOCH):
        costs=[]
        idx=0
        acc_list=[]
        train_data = TextIterator(train_file+".train."+str(epoch), n_batch=batch_size, maxlen=maxlen)
        valid_data = TextIterator(train_file+".valid."+str(epoch), n_batch=batch_size, maxlen=maxlen)
        for (x,xmask),(y,ymask),label in train_data:
            idx+=1
            if x.shape[-1]!=batch_size:
                continue
            #print x.shape
            cost,acc=model.predict(x,xmask,y,ymask,label)
            #print cost
            #projected_output,cost= model.test(x, xmask, y, ymask,label)
            #print "projected_output shape:", projected_output.shape
            ##print "cnn_output shape:",cnn_output.shape
            #print "cost:",cost
            costs.append(cost)
            acc_list.append(acc)
            if np.isnan(np.mean(cost)) or np.isinf(np.mean(cost)):
                print 'NaN or Inf detected!'
                print "x:", x
                print x.shape
                print 'y:', y
                print y.shape
                return -1
        #log.info('dumping parameters....')    
        #save_model('./model/parameters_%.2f.pkl'%(time.time()-start),model)

        log.info('epoch: %d, cost: %.3f, Accuracy: %.3f ' % (
            epoch, np.mean(list(itertools.chain.from_iterable(costs))),
            np.mean(list(itertools.chain.from_iterable(acc_list)))))
        loss, acc = evaluate(valid_data, model)
        log.info('validation cost: %.3f, Accuracy: %.3f' % (loss,acc))

    log.info("Finished. Time = " +str(time.time()-start))