The following 9 code examples, extracted from open-source Python projects, illustrate how to use utils.load_model().
def main():
    """Load a model, run conv_vh_decomposition on it, and save the result."""
    model = utils.load_model(args)
    decomposed = conv_vh_decomposition(model, args)
    decomposed.save(args.save_model)
def main():
    """Load a model, run fc_decomposition on it, and save the result."""
    model = utils.load_model(args)
    decomposed = fc_decomposition(model, args)
    decomposed.save(args.save_model)
def main():
    """Interactive chat between a human and a model-backed dialogue agent."""
    parser = argparse.ArgumentParser(description='chat utility')
    parser.add_argument('--model_file', type=str, help='model file')
    parser.add_argument('--domain', type=str, default='object_division', help='domain for the dialogue')
    parser.add_argument('--context_file', type=str, default='', help='context file')
    parser.add_argument('--temperature', type=float, default=1.0, help='temperature')
    parser.add_argument('--num_types', type=int, default=3, help='number of object types')
    parser.add_argument('--num_objects', type=int, default=6, help='total number of objects')
    parser.add_argument('--max_score', type=int, default=10, help='max score per object')
    parser.add_argument('--score_threshold', type=int, default=6, help='successful dialog should have more than score_threshold in score')
    parser.add_argument('--seed', type=int, default=1, help='random seed')
    parser.add_argument('--smart_ai', action='store_true', default=False, help='make AI smart again')
    parser.add_argument('--ai_starts', action='store_true', default=False, help='allow AI to start the dialog')
    parser.add_argument('--ref_text', type=str, help='file with the reference text')
    args = parser.parse_args()

    utils.set_seed(args.seed)

    human = HumanAgent(domain.get_domain(args.domain))
    # Choose the agent class: rollout-based when --smart_ai is set.
    agent_cls = LstmRolloutAgent if args.smart_ai else LstmAgent
    ai = agent_cls(utils.load_model(args.model_file), args)
    agents = [ai, human] if args.ai_starts else [human, ai]

    dialog = Dialog(agents, args)
    logger = DialogLogger(verbose=True)
    # Take manually produced contexts when no file is given, otherwise read
    # them from the provided dataset file.
    if args.context_file == '':
        ctx_gen = ManualContextGenerator(args.num_types, args.num_objects, args.max_score)
    else:
        ctx_gen = ContextGenerator(args.context_file)

    chat = Chat(dialog, ctx_gen, logger)
    chat.run()
def main():
    """Run a self-play dialogue between two model-backed agents, Alice and Bob."""
    parser = argparse.ArgumentParser(description='selfplaying script')
    parser.add_argument('--alice_model_file', type=str, help='Alice model file')
    parser.add_argument('--bob_model_file', type=str, help='Bob model file')
    parser.add_argument('--context_file', type=str, help='context file')
    parser.add_argument('--temperature', type=float, default=1.0, help='temperature')
    parser.add_argument('--verbose', action='store_true', default=False, help='print out converations')
    parser.add_argument('--seed', type=int, default=1, help='random seed')
    parser.add_argument('--score_threshold', type=int, default=6, help='successful dialog should have more than score_threshold in score')
    parser.add_argument('--max_turns', type=int, default=20, help='maximum number of turns in a dialog')
    parser.add_argument('--log_file', type=str, default='', help='log successful dialogs to file for training')
    parser.add_argument('--smart_alice', action='store_true', default=False, help='make Alice smart again')
    parser.add_argument('--fast_rollout', action='store_true', default=False, help='to use faster rollouts')
    parser.add_argument('--rollout_bsz', type=int, default=100, help='rollout batch size')
    parser.add_argument('--rollout_count_threshold', type=int, default=3, help='rollout count threshold')
    parser.add_argument('--smart_bob', action='store_true', default=False, help='make Bob smart again')
    parser.add_argument('--ref_text', type=str, help='file with the reference text')
    parser.add_argument('--domain', type=str, default='object_division', help='domain for the dialogue')
    args = parser.parse_args()

    utils.set_seed(args.seed)

    def build_agent(model_file, smart, name):
        # Load one side's model and wrap it in the matching agent class.
        model = utils.load_model(model_file)
        agent_cls = get_agent_type(model, smart, args.fast_rollout)
        return agent_cls(model, args, name=name)

    alice = build_agent(args.alice_model_file, args.smart_alice, 'Alice')
    bob = build_agent(args.bob_model_file, args.smart_bob, 'Bob')

    dialog = Dialog([alice, bob], args)
    logger = DialogLogger(verbose=args.verbose, log_file=args.log_file)
    ctx_gen = ContextGenerator(args.context_file)

    selfplay = SelfPlay(dialog, ctx_gen, args, logger)
    selfplay.run()
def main():
    """Rank each ground-truth utterance among candidate sentences and report average rank."""
    parser = argparse.ArgumentParser(description='Negotiator')
    parser.add_argument('--dataset', type=str, default='./data/negotiate/val.txt', help='location of the dataset')
    parser.add_argument('--model_file', type=str, help='model file')
    parser.add_argument('--smart_ai', action='store_true', default=False, help='to use rollouts')
    parser.add_argument('--seed', type=int, default=1, help='random seed')
    parser.add_argument('--temperature', type=float, default=1.0, help='temperature')
    parser.add_argument('--domain', type=str, default='object_division', help='domain for the dialogue')
    parser.add_argument('--log_file', type=str, default='', help='log file')
    args = parser.parse_args()

    utils.set_seed(args.seed)

    model = utils.load_model(args.model_file)
    ai = LstmAgent(model, args)
    logger = DialogLogger(verbose=True, log_file=args.log_file)
    domain = get_domain(args.domain)

    # Score candidates with rollouts when --smart_ai is set, else with likelihood.
    score_func = rollout if args.smart_ai else likelihood

    dataset, sents = read_dataset(args.dataset)
    ranks, n, k = 0, 0, 0
    for ctx, dialog in dataset:
        start_time = time.time()
        # Begin a new conversation for this context.
        ai.feed_context(ctx)
        for sent, you in dialog:
            if you:
                # Our turn: compute the rank of the ground-truth sentence.
                rank = compute_rank(sent, sents, ai, domain, args.temperature, score_func)
                # Advance lang_h over the ground-truth sentence.
                enc = ai._encode(sent, ai.model.word_dict)
                _, ai.lang_h, lang_hs = ai.model.score_sent(enc, ai.lang_h, ai.ctx_h, args.temperature)
                # Record the hidden states and the utterance itself.
                ai.lang_hs.append(lang_hs)
                ai.words.append(ai.model.word2var('YOU:'))
                ai.words.append(Variable(enc))
                ranks += rank
                n += 1
            else:
                ai.read(sent)
        k += 1
        elapsed = time.time() - start_time
        logger.dump('dialogue %d | avg rank %.3f | raw %d/%d | time %.3f' % (k, 1. * ranks / n, ranks, n, elapsed))

    logger.dump('final avg rank %.3f' % (1. * ranks / n))
def main():
    """Evaluate a pretrained model on the test set: LM loss/ppl and selection loss/ppl."""
    parser = argparse.ArgumentParser(description='testing script')
    parser.add_argument('--data', type=str, default='data/negotiate', help='location of the data corpus')
    parser.add_argument('--unk_threshold', type=int, default=20, help='minimum word frequency to be in dictionary')
    parser.add_argument('--model_file', type=str, help='pretrained model file')
    parser.add_argument('--seed', type=int, default=1, help='random seed')
    parser.add_argument('--hierarchical', action='store_true', default=False, help='use hierarchical model')
    parser.add_argument('--bsz', type=int, default=16, help='batch size')
    parser.add_argument('--cuda', action='store_true', default=False, help='use CUDA')
    args = parser.parse_args()

    device_id = utils.use_cuda(args.cuda)
    utils.set_seed(args.seed)

    corpus = data.WordCorpus(args.data, freq_cutoff=args.unk_threshold, verbose=True)
    model = utils.load_model(args.model_file)

    crit = Criterion(model.word_dict, device_id=device_id)
    sel_crit = Criterion(model.item_dict, device_id=device_id, bad_toks=['<disconnect>', '<disagree>'])

    testset, testset_stats = corpus.test_dataset(args.bsz, device_id=device_id)
    test_loss, test_select_loss = 0, 0

    vocab_size = len(corpus.word_dict)
    for batch in testset:
        # Forward pass produces LM output/target and selection output/target.
        out, hid, tgt, sel_out, sel_tgt = Engine.forward(model, batch, volatile=False)
        # Accumulate the token-weighted LM loss and the per-batch selection loss.
        test_loss += tgt.size(0) * crit(out.view(-1, vocab_size), tgt).data[0]
        test_select_loss += sel_crit(sel_out, sel_tgt).data[0]

    test_loss /= testset_stats['nonpadn']
    test_select_loss /= len(testset)

    print('testloss %.3f | testppl %.3f' % (test_loss, np.exp(test_loss)))
    print('testselectloss %.3f | testselectppl %.3f' % (test_select_loss, np.exp(test_select_loss)))
def _train(net, training_data, validation_data, model_name, learning_rate, max_epochs, min_improvement):
    """Train `net` with early stopping and learning-rate halving.

    Runs up to `max_epochs` epochs. While validation perplexity improves by
    at least the `min_improvement` log-factor, the model is saved as the new
    best. On the first stall the learning rate starts halving each epoch
    (reloading the best checkpoint if validation got worse); on a second
    stall during the halving phase training stops.

    Fix applied: converted Python 2 `print` statements to the `print()`
    function so the block is valid Python 3, consistent with the rest of
    the file.
    """
    min_learning_rate = 1e-6
    best_validation_ppl = np.inf
    divide = False  # becomes True once we start halving the learning rate
    for epoch in range(1, max_epochs + 1):
        epoch_start = time()
        print("\n======= EPOCH %s =======" % epoch)
        print("\tLearning rate is %s" % learning_rate)
        train_ppl = _process_corpus(net, training_data, mode='train', learning_rate=learning_rate)
        print("\tTrain PPL is %.3f" % train_ppl)
        validation_ppl = _process_corpus(net, validation_data, mode='test')
        print("\tValidation PPL is %.3f" % validation_ppl)
        print("\tTime taken: %ds" % (time() - epoch_start))
        # Improvement test in log-space: stall when the (scaled) new log-ppl
        # is not below the best log-ppl.
        if np.log(validation_ppl) * min_improvement > np.log(best_validation_ppl):
            if not divide:
                # First stall: enter the learning-rate-halving phase.
                divide = True
                print("\tStarting to reduce the learning rate...")
                if validation_ppl > best_validation_ppl:
                    print("\tLoading best model.")
                    net = utils.load_model("../out/" + model_name)
            else:
                # Second stall while already halving: save if still an
                # improvement, then stop training.
                if validation_ppl < best_validation_ppl:
                    print("\tSaving model.")
                    net.save("../out/" + model_name, final=True)
                break
        else:
            print("\tNew best model! Saving...")
            best_validation_ppl = validation_ppl
            # Mark the checkpoint final if this is clearly the last epoch.
            final = learning_rate / 2. < min_learning_rate or epoch == max_epochs
            net.save("../out/" + model_name, final)
        if divide:
            learning_rate /= 2.
            if learning_rate < min_learning_rate:
                break
    print("-" * 30)
    print("Finished training.")
    print("Best validation PPL is %.3f\n\n" % best_validation_ppl)
def train():
    """Train the RCNN model, checkpointing periodically and guarding against NaN/Inf.

    Returns -1 when a NaN/Inf cost is detected; otherwise trains for NEPOCH
    epochs over `train_data`, dumping parameters every `dump_freq` batches.

    Fix applied: converted Python 2-only syntax (`print` statements, `xrange`)
    to Python 3, and removed dead commented-out code.
    """
    log.info('loading dataset...')
    train_data = TextIterator(train_file, n_batch=batch_size, maxlen=maxlen)
    valid_data = TextIterator(valid_file, n_batch=batch_size, maxlen=maxlen)
    test_data = TextIterator(test_file, n_batch=batch_size, maxlen=maxlen, mode=2)
    log.info('building models....')
    model = RCNNModel(n_input=n_input, n_vocab=VOCABULARY_SIZE, n_hidden=n_hidden, cell='gru',
                      optimizer=optimizer, dropout=dropout, sim=sim, maxlen=maxlen, batch_size=batch_size)
    start = time.time()
    if os.path.isfile(model_dir):
        print('loading checkpoint parameters....', model_dir)
        model = load_model(model_dir, model)
    if goto_line != 0:
        # Resume mid-file: skip already-seen lines of the training data.
        train_data.goto_line(goto_line)
        print('goto line:', goto_line)
    log.info('training start...')
    for epoch in range(NEPOCH):
        costs = 0
        idx = 0
        error_rate_list = []
        try:
            for (x, xmask), (y, ymask), label in train_data:
                idx += 1
                # Drop ragged final batches; the model expects a fixed batch size.
                if x.shape[-1] != batch_size:
                    continue
                cost, error_rate = model.train(x, xmask, y, ymask, label, lr)
                costs += cost
                error_rate_list.append(error_rate)
                if np.isnan(cost) or np.isinf(cost):
                    print('Nan Or Inf detected!')
                    print(x.shape, y.shape)
                    print(cost, error_rate)
                    return -1
                if idx % disp_freq == 0:
                    log.info('epoch: %d, idx: %d cost: %.3f, Accuracy: %.3f ' % (
                        epoch, idx, costs / idx,
                        np.mean(list(itertools.chain.from_iterable(error_rate_list)))))
                if idx % dump_freq == 0:
                    save_model('./model/parameters_%.2f.pkl' % (time.time() - start), model)
        except Exception:
            # NOTE(review): broad catch that only logs the offending batch and
            # continues with the next epoch — kept as-is, but consider
            # re-raising or narrowing the exception type.
            print(np.max(x), np.max(y))
            print(x.shape, y.shape)
        # NOTE(review): per-epoch evaluation reconstructed from ambiguous
        # formatting — confirm evaluate() is meant to run once per epoch.
        evaluate(train_data, valid_data, test_data, model)
    log.info("Finished. Time = " + str(time.time() - start))
def test():
    """Evaluate the RCNN model on per-epoch train/valid splits.

    For each epoch index, reads `<train_file>.train.<epoch>` and
    `<train_file>.valid.<epoch>`, computes prediction cost/accuracy over the
    train split, then logs evaluate() results on the valid split. Returns -1
    on NaN/Inf cost.

    Fix applied: converted Python 2-only syntax (`print` statements, `xrange`)
    to Python 3, and removed dead commented-out code.
    """
    log.info('loading dataset...')
    log.info('building models....')
    model = RCNNModel(n_input=n_input, n_vocab=VOCABULARY_SIZE, n_hidden=n_hidden, cell='gru',
                      optimizer=optimizer, dropout=dropout, sim=sim, maxlen=maxlen, batch_size=batch_size)
    log.info('training start....')
    start = time.time()
    if os.path.isfile(model_dir):
        print('loading checkpoint parameters....', model_dir)
        model = load_model(model_dir, model)
    for epoch in range(NEPOCH):
        costs = []
        idx = 0
        acc_list = []
        train_data = TextIterator(train_file + ".train." + str(epoch), n_batch=batch_size, maxlen=maxlen)
        valid_data = TextIterator(train_file + ".valid." + str(epoch), n_batch=batch_size, maxlen=maxlen)
        for (x, xmask), (y, ymask), label in train_data:
            idx += 1
            # Drop ragged final batches; the model expects a fixed batch size.
            if x.shape[-1] != batch_size:
                continue
            cost, acc = model.predict(x, xmask, y, ymask, label)
            costs.append(cost)
            acc_list.append(acc)
            if np.isnan(np.mean(cost)) or np.isinf(np.mean(cost)):
                print('Nan Or Inf detected!')
                print("x:", x)
                print(x.shape)
                print('y:', y)
                print(y.shape)
                return -1
        log.info('epoch: %d, cost: %.3f, Accuracy: %.3f ' % (
            epoch, np.mean(list(itertools.chain.from_iterable(costs))),
            np.mean(list(itertools.chain.from_iterable(acc_list)))))
        loss, acc = evaluate(valid_data, model)
        log.info('validation cost: %.3f, Accuracy: %.3f' % (loss, acc))
    log.info("Finished. Time = " + str(time.time() - start))