我们从 Python 开源项目中，提取了以下 4 个代码示例，用于说明如何使用 model.RNN。
def __init__(self, training_file='../res/trump_tweets.txt', model_file='../res/model.pt',
             n_epochs=1000000, hidden_size=256, n_layers=2, learning_rate=0.001, chunk_len=140):
    """Store the training configuration and prepare the decoder, optimizer and generator.

    Args:
        training_file: Path handed to ``read_file`` to obtain the training text.
        model_file: Path of a previously saved model; loaded with ``torch.load``
            when it exists on disk.
        n_epochs: Stored on the instance; not used further in this method.
        hidden_size: Hidden size passed to ``RNN`` when a new model is built.
        n_layers: Layer count passed to ``RNN`` when a new model is built.
        learning_rate: Learning rate given to the Adam optimizer.
        chunk_len: Stored on the instance; not used further in this method.
    """
    # Plain configuration values.
    self.chunk_len = chunk_len
    self.learning_rate = learning_rate
    self.n_layers = n_layers
    self.hidden_size = hidden_size
    self.n_epochs = n_epochs
    self.model_file = model_file
    self.training_file = training_file

    # read_file returns the training data together with its length.
    contents, contents_len = read_file(training_file)
    self.file = contents
    self.file_len = contents_len

    if not os.path.isfile(model_file):
        # No checkpoint on disk: start from a fresh network whose input and
        # output sizes are both n_characters.
        self.decoder = RNN(n_characters, hidden_size, n_characters, n_layers)
        print('Constructed new model!')
    else:
        # NOTE(review): torch.load deserializes the whole object (pickle-based
        # by default in older PyTorch releases) -- only point model_file at
        # trusted files.
        self.decoder = torch.load(model_file)
        print('Loaded old model!')

    self.decoder_optimizer = torch.optim.Adam(self.decoder.parameters(), lr=learning_rate)
    self.criterion = nn.CrossEntropyLoss()
    self.generator = Generator(self.decoder)
def __init__(self, input_size, hidden_size, output_size, n_layers=1, gpu=-1):
    """Build the decoder network together with its optimizer and loss.

    Args:
        input_size: Forwarded to ``RNN`` as its first argument.
        hidden_size: Forwarded to ``RNN`` as its second argument.
        output_size: Forwarded to ``RNN`` as its third argument.
        n_layers: Forwarded to ``RNN``; defaults to 1.
        gpu: Forwarded to ``RNN``; when it is >= 0 the decoder is also moved
            to CUDA.
    """
    decoder = RNN(input_size, hidden_size, output_size, n_layers, gpu)
    self.decoder = decoder

    wants_cuda = gpu >= 0
    if wants_cuda:
        # NOTE(review): the message reports the *current* CUDA device and
        # .cuda() is called without an index, so `gpu` itself does not select
        # the device here -- confirm the device is set elsewhere.
        active_device = torch.cuda.current_device()
        print("Use GPU %d" % active_device)
        decoder.cuda()

    # The optimizer is created after the optional move to CUDA.
    self.optimizer = torch.optim.Adam(decoder.parameters(), lr=0.01)
    self.criterion = nn.CrossEntropyLoss()
def main(_):
    """Run repeated validation experiments and append the results to a file.

    The model name is derived from the ``ensemble``, ``ngram`` and ``embed``
    flags. For each of ``valid_iteration`` rounds a parameter set is obtained
    from ``sample_parameters``, an ``RNN`` is built and passed to
    ``experiment``, and the returned top-1 / top-5 / epoch values are written
    as one tab-separated line to ``valid_result_path``.

    Args:
        _: Unused positional argument (presumably argv from the app runner --
            confirm against the caller).

    Raises:
        ValueError: If ``ensemble`` is falsy and ``ngram`` is not 1, 2 or 3.
    """
    # Save default params and set scope
    saved_params = FLAGS.__flags

    if saved_params['ensemble']:
        model_name = 'ensemble'
    elif saved_params['ngram'] == 1:
        model_name = 'unigram'
    elif saved_params['ngram'] == 2:
        model_name = 'bigram'
    elif saved_params['ngram'] == 3:
        model_name = 'trigram'
    else:
        # The original `assert True, ...` could never fail, so an unsupported
        # value fell through and crashed later on an unbound `model_name`.
        raise ValueError('Not supported ngram %s' % (saved_params['ngram'],))

    model_name += '_embedding' if saved_params['embed'] else '_no_embedding'
    saved_params['model_name'] = '%s' % model_name
    saved_params['checkpoint_dir'] += model_name
    pprint.PrettyPrinter().pprint(saved_params)
    saved_dataset = get_data(saved_params)

    # The context manager closes the results file even if an iteration raises.
    with open(saved_params['valid_result_path'], 'a') as validation_writer:
        validation_writer.write(model_name + "\n")
        validation_writer.write("[dim_hidden, dim_rnn_cell, learning_rate, lstm_dropout, lstm_layer, hidden_dropout, dim_embed]\n")
        validation_writer.write("combination\ttop1\ttop5\tepoch\n")

        # Run the model
        for _ in range(saved_params['valid_iteration']):
            # Sample parameter sets (on a copy, so the saved defaults are not
            # overwritten by the sampled values)
            params, combination = sample_parameters(saved_params.copy())
            dataset = saved_dataset[:]

            # Initialize embeddings
            uni_init = get_char2vec(dataset[0][0][:], params['dim_embed_unigram'], dataset[3][0])
            bi_init = get_char2vec(dataset[0][1][:], params['dim_embed_bigram'], dataset[3][4])
            tri_init = get_char2vec(dataset[0][2][:], params['dim_embed_trigram'], dataset[3][5])

            print(model_name, 'Parameter sets: ', end='')
            pprint.PrettyPrinter().pprint(combination)

            rnn_model = RNN(params, [uni_init, bi_init, tri_init])
            top1, top5, ep = experiment(rnn_model, dataset, params)

            validation_writer.write(str(combination) + '\t')
            validation_writer.write(str(top1) + '\t' + str(top5) + '\tEp:' + str(ep) + '\n')
def train_and_test(challenge, rnn_cell):
    """Train an RNN on one challenge and return the best test accuracy seen.

    The train/test data is extracted and vectorized through ``helper``, a
    ``model.RNN`` graph is built, and for ``FLAGS.num_epochs`` epochs the
    model is trained on all training batches and then evaluated on the test
    batches. A summary line is printed after every epoch.

    Args:
        challenge: Passed to ``helper.extract_file`` to obtain the train and
            test data.
        rnn_cell: Passed to ``model.RNN`` as its first argument.

    Returns:
        The highest per-epoch test accuracy (0 if no epoch improved on it).
    """
    train, test = helper.extract_file(challenge)
    vocab, word_idx, story_maxlen, query_maxlen = helper.get_vocab(train, test)
    vocab_size = len(vocab) + 1  # Reserve 0 for masking via pad_sequences

    x, xq, y = helper.vectorize_stories(train, word_idx, story_maxlen, query_maxlen)
    tx, txq, ty = helper.vectorize_stories(test, word_idx, story_maxlen, query_maxlen)

    with tf.Graph().as_default() as graph:
        story_pl, question_pl, answer_pl, dropout_pl = get_placeholder(vocab_size, story_maxlen, query_maxlen)
        rnn = model.RNN(rnn_cell, FLAGS.embed_dim, FLAGS.rnn_size, vocab_size)
        logits = rnn.inference(story_pl, question_pl, dropout_pl)
        loss = rnn.loss(logits, answer_pl)
        train_op = rnn.train(loss, FLAGS.init_learning_rate)
        correct = rnn.eval(logits, answer_pl)
        init = tf.global_variables_initializer()

    gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=FLAGS.gpu_fraction)
    # The graph is passed explicitly, so the session runs the graph built above.
    with tf.Session(config=tf.ConfigProto(gpu_options=gpu_options), graph=graph) as sess:
        # Initialize all variables before training.
        sess.run(init)

        max_test_acc = 0
        # Keep `cost` defined for the epoch summary even when the training
        # generator yields no batch at all.
        cost = float('nan')
        total = len(tx)

        for i in range(FLAGS.num_epochs):
            # Training pass over every batch.
            train_gen = helper.generate_data(FLAGS.batch_size, x, xq, y)
            for x_batch, xq_batch, y_batch in train_gen:
                feed_dict = {story_pl: x_batch, question_pl: xq_batch,
                             answer_pl: y_batch, dropout_pl: FLAGS.dropout}
                cost, _ = sess.run([loss, train_op], feed_dict=feed_dict)

            # Evaluate on the test set after each epoch; the dropout
            # placeholder is fed 1.0 here (presumably a keep probability,
            # i.e. dropout disabled -- confirm in model.RNN).
            test_gen = helper.generate_data(FLAGS.batch_size, tx, txq, ty)
            total_correct = 0
            for tx_batch, txq_batch, ty_batch in test_gen:
                feed_dict = {story_pl: tx_batch, question_pl: txq_batch,
                             answer_pl: ty_batch, dropout_pl: 1.0}
                cor = sess.run(correct, feed_dict=feed_dict)
                total_correct += int(cor)

            # Guard against an empty test set instead of dividing by zero.
            acc = float(total_correct) / total if total else 0.0

            # Track the best test accuracy so far.
            if acc > max_test_acc:
                max_test_acc = acc
            print('Epoch{:>3} train_loss = {:.3f} accuary = {:.3f} max_text_acc = {:.3f}'.format(i, cost, acc, max_test_acc))

    return max_test_acc