The following code examples, extracted from open-source Python projects, illustrate how utils.DataLoader() is used.
Example 1:

# Python 2 / pre-1.0 TensorFlow code from the original project.
import os
import time
import cPickle

import tensorflow as tf

from utils import DataLoader
from model import Model  # module name assumed from the surrounding project


def train(args):
    data_loader = DataLoader(args.data_dir, args.batch_size, args.seq_length)
    with open(os.path.join(args.save_dir, 'config.pkl'), 'w') as f:
        cPickle.dump(args, f)

    model = Model(args)

    with tf.Session() as sess:
        tf.initialize_all_variables().run()
        saver = tf.train.Saver(tf.all_variables())
        for e in xrange(args.num_epochs):
            sess.run(tf.assign(model.lr, args.learning_rate * (args.decay_rate ** e)))
            data_loader.reset_batch_pointer()
            state = model.initial_state.eval()
            for b in xrange(data_loader.num_batches):
                start = time.time()
                x, y = data_loader.next_batch()
                # print(x, '->', y)
                # import sys; sys.exit()
                feed = {model.input_data: x, model.targets: y, model.initial_state: state}
                train_loss, state, _ = sess.run(
                    [model.cost, model.final_state, model.train_op], feed)
                end = time.time()
                print "{}/{} (epoch {}), train_loss = {:.3f}, time/batch = {:.3f}" \
                    .format(e * data_loader.num_batches + b,
                            args.num_epochs * data_loader.num_batches,
                            e, train_loss, end - start)
                if (e * data_loader.num_batches + b) % args.save_every == 0:
                    checkpoint_path = os.path.join(args.save_dir, 'model.ckpt')
                    saver.save(sess, checkpoint_path,
                               global_step=e * data_loader.num_batches + b)
                    print "model saved to {}".format(checkpoint_path)
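The training loop above only relies on three pieces of the DataLoader interface: the num_batches attribute, reset_batch_pointer(), and next_batch() returning an (x, y) pair. A minimal sketch of such a loader is shown below, written as a character-level text loader for concreteness; the 'input.txt' file name, the vocabulary construction, and the character-level granularity are illustrative assumptions, not details taken from the project above.

import os
import numpy as np

class DataLoader(object):
    """Minimal sketch of the interface used above:
    num_batches, reset_batch_pointer(), next_batch() -> (x, y)."""

    def __init__(self, data_dir, batch_size, seq_length):
        self.batch_size = batch_size
        self.seq_length = seq_length

        # Read the raw text ('input.txt' is an assumed file name).
        with open(os.path.join(data_dir, 'input.txt')) as f:
            text = f.read()

        # Map each character to an integer id.
        chars = sorted(set(text))
        self.vocab = {c: i for i, c in enumerate(chars)}
        data = np.array([self.vocab[c] for c in text], dtype=np.int64)

        # Trim the tail so the data splits evenly into whole batches.
        self.num_batches = len(data) // (batch_size * seq_length)
        data = data[:self.num_batches * batch_size * seq_length]

        # Targets are the inputs shifted one step to the left.
        xdata = data.reshape(batch_size, -1)
        ydata = np.roll(data, -1).reshape(batch_size, -1)
        self.x_batches = np.split(xdata, self.num_batches, axis=1)
        self.y_batches = np.split(ydata, self.num_batches, axis=1)
        self.pointer = 0

    def reset_batch_pointer(self):
        self.pointer = 0

    def next_batch(self):
        x = self.x_batches[self.pointer]
        y = self.y_batches[self.pointer]
        self.pointer += 1
        return x, y

Under this contract, the x and y fed to the model in the loop above are integer arrays of shape (batch_size, seq_length), with y simply being x shifted by one time step.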
Example 2:

import os
import time
import pickle

import tensorflow as tf

from utils import DataLoader
from model import Model  # module name assumed from the surrounding project


def train(args):
    datasets = range(4)  # Python 2: range() returns a list here
    # Remove the leaveDataset from datasets
    datasets.remove(args.leaveDataset)

    # Create the data loader object. This object would preprocess the data in terms of
    # batches each of size args.batch_size, of length args.seq_length
    data_loader = DataLoader(args.batch_size, args.seq_length, datasets, forcePreProcess=True)

    # Save the arguments in the config file
    with open(os.path.join('save_lstm', 'config.pkl'), 'wb') as f:
        pickle.dump(args, f)

    # Create a Vanilla LSTM model with the arguments
    model = Model(args)

    # Initialize a TensorFlow session
    with tf.Session() as sess:
        # Initialize all the variables in the graph
        sess.run(tf.initialize_all_variables())
        # Add all the variables to the list of variables to be saved
        saver = tf.train.Saver(tf.all_variables())

        # For each epoch
        for e in range(args.num_epochs):
            # Assign the learning rate (decayed acc. to the epoch number)
            sess.run(tf.assign(model.lr, args.learning_rate * (args.decay_rate ** e)))
            # Reset the pointers in the data loader object
            data_loader.reset_batch_pointer()
            # Get the initial cell state of the LSTM
            state = sess.run(model.initial_state)

            # For each batch in this epoch
            for b in range(data_loader.num_batches):
                # Tic
                start = time.time()
                # Get the source and target data of the current batch
                # x has the source data, y has the target data
                x, y = data_loader.next_batch()

                # Feed the source, target data and the initial LSTM state to the model
                feed = {model.input_data: x, model.target_data: y, model.initial_state: state}
                # Fetch the loss of the model on this batch and the final LSTM state from the session
                train_loss, state, _ = sess.run([model.cost, model.final_state, model.train_op], feed)
                # Toc
                end = time.time()

                # Print epoch, batch, loss and time taken
                print(
                    "{}/{} (epoch {}), train_loss = {:.3f}, time/batch = {:.3f}"
                    .format(
                        e * data_loader.num_batches + b,
                        args.num_epochs * data_loader.num_batches,
                        e, train_loss, end - start))

                # Save the model if the current epoch and batch number match the frequency
                if (e * data_loader.num_batches + b) % args.save_every == 0 and ((e * data_loader.num_batches + b) > 0):
                    checkpoint_path = os.path.join('save_lstm', 'model.ckpt')
                    saver.save(sess, checkpoint_path, global_step=e * data_loader.num_batches + b)
                    print("model saved to {}".format(checkpoint_path))
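Both train() functions take a single args namespace, typically produced by argparse. A minimal driver for the second example might look like the sketch below, assuming it lives in the same file as train(); the default values are placeholders rather than the project's real settings, and the real Model(args) will usually also read model hyperparameters (e.g. RNN size) that this snippet does not define.

import argparse

def main():
    parser = argparse.ArgumentParser()
    # Only the fields read by the training loop above are declared here;
    # the defaults are illustrative, not the project's real settings.
    parser.add_argument('--batch_size', type=int, default=50)
    parser.add_argument('--seq_length', type=int, default=10)
    parser.add_argument('--num_epochs', type=int, default=50)
    parser.add_argument('--learning_rate', type=float, default=0.003)
    parser.add_argument('--decay_rate', type=float, default=0.95)
    parser.add_argument('--save_every', type=int, default=400)
    parser.add_argument('--leaveDataset', type=int, default=3,
                        help='index of the dataset to hold out from training')
    args = parser.parse_args()
    train(args)

if __name__ == '__main__':
    main()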