The following 12 code examples, extracted from open-source Python projects, illustrate how to use model.train().
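Before the examples, a minimal sketch (not taken from any of the projects below) of what PyTorch's model.train() actually does: it only flips the module's training flag, which layers such as Dropout and BatchNorm consult at forward time, while model.eval() flips it back for evaluation. Note that in the TensorFlow examples further down, model.train(loss, global_step) is instead a user-defined graph-building helper.

import torch
import torch.nn as nn

# Hypothetical two-layer module used only for illustration.
net = nn.Sequential(nn.Linear(8, 8), nn.Dropout(p=0.5))

net.train()                      # enable dropout
assert net.training is True

net.eval()                       # disable dropout for evaluation
assert net.training is False

with torch.no_grad():            # evaluation code usually also disables autograd
    _ = net(torch.randn(2, 8))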
def save_model(model, sess, log_path, step):
    """
    Save model using tensorflow checkpoint (also save hidden variables)
    Args:
        model : model to save variable from
        sess : tensorflow session
        log_path : where to save
        step : number of step at time of saving
    """
    path = log_path + '/' + model.name
    if tf.gfile.Exists(path):
        tf.gfile.DeleteRecursively(path)
    tf.gfile.MakeDirs(path)
    saver = tf.train.Saver()
    checkpoint_path = os.path.join(path, 'model.ckpt')
    saver.save(sess, checkpoint_path, global_step=step)
def restore_model(model, sess, log_path):
    """
    Restore model (including hidden variable)
    In practice use to resume the training of the same model
    Args:
        model : model to restore variable to
        sess : tensorflow session
        log_path : where to save
    Returns:
        step_b : the step number at which training ended
    """
    path = log_path + '/' + model.name
    saver = tf.train.Saver()
    ckpt = tf.train.get_checkpoint_state(path)
    if ckpt and ckpt.model_checkpoint_path:
        saver.restore(sess, ckpt.model_checkpoint_path)
        return ckpt.model_checkpoint_path.split('/')[-1].split('-')[-1]
    else:
        print('------------------------------------------------------')
        print('No checkpoint file found')
        print('------------------------------------------------------ \n')
        exit()
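The restore helper above recovers the step number by parsing the checkpoint path that tf.train.Saver produces when save() is called with global_step. A minimal sketch of that parsing with a hypothetical path:

# Hypothetical checkpoint path of the form produced by saver.save(..., global_step=1000).
ckpt_path = 'logs/my_model/model.ckpt-1000'
step = ckpt_path.split('/')[-1].split('-')[-1]
assert step == '1000'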
def save_weight_only(model, sess, log_path, step):
    """
    Save model but only weight (meaning no hidden variable)
    In practice use this to just transfer weights from one model to the other
    Args:
        model : model to save variable from
        sess : tensorflow session
        log_path : where to save
        step : number of step at time of saving
    """
    path = log_path + '/' + model.name + '_weight_only'
    if tf.gfile.Exists(path):
        tf.gfile.DeleteRecursively(path)
    tf.gfile.MakeDirs(path)

    variable_to_save = {}
    for i in range(30):
        name = 'conv_' + str(i)
        variable_to_save[name] = model.parameters_conv[i]
        if i in [2, 4] and model.concat:
            name = 'deconv_' + str(i)
            variable_to_save[name] = model.parameters_deconv[i][0]
            name = 'deconv_' + str(i) + '_bis'
            variable_to_save[name] = model.parameters_deconv[i][1]
        else:
            name = 'deconv_' + str(i)
            variable_to_save[name] = model.parameters_deconv[i]
        if i < 2:
            name = 'deconv_bis_' + str(i)
            variable_to_save[name] = model.deconv[i]
    saver = tf.train.Saver(variable_to_save)
    checkpoint_path = os.path.join(path, 'model.ckpt')
    saver.save(sess, checkpoint_path, global_step=step)
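tf.train.Saver accepts a {checkpoint_name: variable} dictionary, which is what lets the helper above store each weight under a stable key ('conv_0', 'deconv_0', ...) that a differently scoped model can later map onto. A minimal sketch with a hypothetical variable (TF 1.x API, as in the example):

import tensorflow as tf

with tf.variable_scope('encoder'):
    w = tf.get_variable('kernel', shape=[3, 3])   # graph name: 'encoder/kernel'

# Saved under the key 'conv_0' regardless of the variable's name in the graph.
saver = tf.train.Saver({'conv_0': w})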
def train():
    # Turn on training mode which enables dropout.
    model.train()
    total_loss = 0
    start_time = time.time()
    ntokens = len(corpus.dictionary)
    hidden = model.init_hidden(args.batch_size)
    for batch, i in enumerate(range(0, train_data.size(0) - 1, args.bptt)):
        data, targets = get_batch(train_data, i)
        # Starting each batch, we detach the hidden state from how it was previously produced.
        # If we didn't, the model would try backpropagating all the way to start of the dataset.
        hidden = repackage_hidden(hidden)
        model.zero_grad()
        output, hidden = model(data, hidden)
        loss = criterion(output.view(-1, ntokens), targets)
        loss.backward()

        # `clip_grad_norm` helps prevent the exploding gradient problem in RNNs / LSTMs.
        torch.nn.utils.clip_grad_norm(model.parameters(), args.clip)
        for p in model.parameters():
            p.data.add_(-lr, p.grad.data)

        total_loss += loss.data

        if batch % args.log_interval == 0 and batch > 0:
            cur_loss = total_loss[0] / args.log_interval
            elapsed = time.time() - start_time
            print('| epoch {:3d} | {:5d}/{:5d} batches | lr {:02.2f} | ms/batch {:5.2f} | '
                  'loss {:5.2f} | ppl {:8.2f}'.format(
                epoch, batch, len(train_data) // args.bptt, lr,
                elapsed * 1000 / args.log_interval, cur_loss, math.exp(cur_loss)))
            total_loss = 0
            start_time = time.time()

# Loop over epochs.
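The loop above applies plain SGD by hand (p.data.add_(-lr, p.grad.data)) and uses the older clip_grad_norm name. A minimal sketch of the same update with current PyTorch APIs, shown only for comparison and not part of the original project:

import torch

# Hypothetical parameter list standing in for model.parameters().
params = [torch.nn.Parameter(torch.randn(4, 4))]
optimizer = torch.optim.SGD(params, lr=0.1)

loss = params[0].sum()                                   # placeholder loss
optimizer.zero_grad()
loss.backward()
torch.nn.utils.clip_grad_norm_(params, max_norm=0.25)    # in-place variant of the clipping call
optimizer.step()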
def train():
    with tf.Graph().as_default():
        global_step = tf.Variable(0, trainable=False)

        image, label = input.get_input(LABEL_PATH, LABEL_FORMAT, IMAGE_PATH, IMAGE_FORMAT)

        logits = model.inference(image)
        loss = model.loss(logits, label)
        train_op = model.train(loss, global_step)

        saver = tf.train.Saver(tf.all_variables())
        summary_op = tf.merge_all_summaries()
        init = tf.initialize_all_variables()
        sess = tf.Session(config=tf.ConfigProto(log_device_placement=input.FLAGS.log_device_placement))
        sess.run(init)

        # Start the queue runners.
        tf.train.start_queue_runners(sess=sess)

        summary_writer = tf.train.SummaryWriter(input.FLAGS.train_dir, graph_def=sess.graph_def)

        for step in xrange(input.FLAGS.max_steps):
            start_time = time.time()
            _, loss_value = sess.run([train_op, loss])
            duration = time.time() - start_time

            assert not np.isnan(loss_value), 'Model diverged with loss = NaN'

            if step % 1 == 0:
                num_examples_per_step = input.FLAGS.batch_size
                examples_per_sec = num_examples_per_step / duration
                sec_per_batch = float(duration)

                format_str = ('%s: step %d, loss = %.2f (%.1f examples/sec; %.3f sec/batch)')
                print(format_str % (datetime.now(), step, loss_value, examples_per_sec, sec_per_batch))

            if step % 10 == 0:
                summary_str = sess.run(summary_op)
                summary_writer.add_summary(summary_str, step)

            # Save the model checkpoint periodically.
            if step % 25 == 0:
                checkpoint_path = os.path.join(input.FLAGS.train_dir, 'model.ckpt')
                saver.save(sess, checkpoint_path, global_step=step)
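In this example and the main() wrappers that follow, model.train(loss, global_step) is a user-defined graph-construction helper in the TF 1.x style, not torch.nn.Module.train(). A minimal sketch of what such a helper commonly looks like; the optimizer choice and learning rate are assumptions, not the project's actual code:

import tensorflow as tf

def train(total_loss, global_step, learning_rate=0.1):
    # Build and return the training op: one gradient-descent step that
    # also increments the global step counter.
    optimizer = tf.train.GradientDescentOptimizer(learning_rate)
    return optimizer.minimize(total_loss, global_step=global_step)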
def main(argv=None):
    train()
def init_variables(sess):
    init_op = tf.initialize_all_variables()
    sess.run(init_op)

    gen_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, 'generator')
    discrim_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, 'discriminator')
    variables_to_restore = gen_vars + discrim_vars
    saver = tf.train.Saver(variables_to_restore)

    if FLAGS.restore_model:
        ckpt = tf.train.get_checkpoint_state('./')
        if ckpt and ckpt.model_checkpoint_path:
            saver.restore(sess, ckpt.model_checkpoint_path)
        else:
            raise RuntimeError('No checkpoint file found.')
    return
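The scope-filtered collections above assume the generator and discriminator were built inside tf.variable_scope('generator') and tf.variable_scope('discriminator'). A minimal sketch with hypothetical shapes (TF 1.x API):

import tensorflow as tf

with tf.variable_scope('generator'):
    g_w = tf.get_variable('w', shape=[10, 10])
with tf.variable_scope('discriminator'):
    d_w = tf.get_variable('w', shape=[10, 1])

# The second argument is a scope prefix, so each collection only picks up
# the variables created under the matching scope.
gen_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, 'generator')
assert g_w in gen_vars and d_w not in gen_vars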
def main(argv=None):
    if tf.gfile.Exists(FLAGS.train_dir):
        tf.gfile.DeleteRecursively(FLAGS.train_dir)
    tf.gfile.MakeDirs(FLAGS.train_dir)
    train()
def restore_weight_from(model, name, sess, log_path, copy_concat=False):
    """
    Restore model (excluding hidden variable)
    In practice use to train a model with the weight from another model.
    As long as both models have the architecture from the original model.py, then it works
    Compatible w or w/o direct connections
    Args:
        model : model to restore variable to
        name : name of model to copy
        sess : tensorflow session
        log_path : where to restore
        copy_concat : specify if the model to copy from also had direct connections
    Returns:
        step_b : the step number at which training ended
    """
    path = log_path + '/' + name + '_weight_only'

    variable_to_save = {}
    for i in range(30):
        name = 'conv_' + str(i)
        variable_to_save[name] = model.parameters_conv[i]
        if i < 2:
            if copy_concat == model.concat:
                name = 'deconv_' + str(i)
                variable_to_save[name] = model.parameters_deconv[i]
                name = 'deconv_bis_' + str(i)
                variable_to_save[name] = model.deconv[i]
        else:
            if i in [2, 4] and model.concat:
                name = 'deconv_' + str(i)
                variable_to_save[name] = model.parameters_deconv[i][0]
                if copy_concat:
                    name = 'deconv_' + str(i) + '_bis'
                    variable_to_save[name] = model.parameters_deconv[i][1]
            elif i in [2, 4] and not model.concat:
                name = 'deconv_' + str(i)
                variable_to_save[name] = model.parameters_deconv[i]
            else:
                name = 'deconv_' + str(i)
                variable_to_save[name] = model.parameters_deconv[i]
    saver = tf.train.Saver(variable_to_save)
    ckpt = tf.train.get_checkpoint_state(path)
    if ckpt and ckpt.model_checkpoint_path:
        saver.restore(sess, ckpt.model_checkpoint_path)
        return ckpt.model_checkpoint_path.split('/')[-1].split('-')[-1]
    else:
        print('------------------------------------------------------')
        print('No checkpoint file found')
        print('------------------------------------------------------ \n')
        exit()
def train():
    # Turn on training mode which enables dropout.
    if args.model == 'QRNN':
        model.reset()
    total_loss = 0
    start_time = time.time()
    ntokens = len(corpus.dictionary)
    hidden = model.init_hidden(args.batch_size)
    batch, i = 0, 0
    while i < train_data.size(0) - 1 - 1:
        bptt = args.bptt if np.random.random() < 0.95 else args.bptt / 2.
        # Prevent excessively small or negative sequence lengths
        seq_len = max(5, int(np.random.normal(bptt, 5)))
        # There's a very small chance that it could select a very long sequence length resulting in OOM
        seq_len = min(seq_len, args.bptt + 10)

        lr2 = optimizer.param_groups[0]['lr']
        optimizer.param_groups[0]['lr'] = lr2 * seq_len / args.bptt
        model.train()
        data, targets = get_batch(train_data, i, args, seq_len=seq_len)

        # Starting each batch, we detach the hidden state from how it was previously produced.
        # If we didn't, the model would try backpropagating all the way to start of the dataset.
        hidden = repackage_hidden(hidden)
        optimizer.zero_grad()

        output, hidden, rnn_hs, dropped_rnn_hs = model(data, hidden, return_h=True)
        raw_loss = criterion(output.view(-1, ntokens), targets)

        loss = raw_loss
        # Activation Regularization
        loss = loss + sum(args.alpha * dropped_rnn_h.pow(2).mean() for dropped_rnn_h in dropped_rnn_hs[-1:])
        # Temporal Activation Regularization (slowness)
        loss = loss + sum(args.beta * (rnn_h[1:] - rnn_h[:-1]).pow(2).mean() for rnn_h in rnn_hs[-1:])
        loss.backward()

        # `clip_grad_norm` helps prevent the exploding gradient problem in RNNs / LSTMs.
        torch.nn.utils.clip_grad_norm(model.parameters(), args.clip)
        optimizer.step()

        total_loss += raw_loss.data
        optimizer.param_groups[0]['lr'] = lr2
        if batch % args.log_interval == 0 and batch > 0:
            cur_loss = total_loss[0] / args.log_interval
            elapsed = time.time() - start_time
            print('| epoch {:3d} | {:5d}/{:5d} batches | lr {:02.2f} | ms/batch {:5.2f} | '
                  'loss {:5.2f} | ppl {:8.2f}'.format(
                epoch, batch, len(train_data) // args.bptt, optimizer.param_groups[0]['lr'],
                elapsed * 1000 / args.log_interval, cur_loss, math.exp(cur_loss)))
            total_loss = 0
            start_time = time.time()
        ###
        batch += 1
        i += seq_len

# Load the best saved model.
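repackage_hidden is called in several of these loops but never shown. A common implementation, assumed here from the standard PyTorch word-language-model example rather than taken from these projects, detaches the hidden state so gradients stop at the batch boundary:

import torch

def repackage_hidden(h):
    """Wrap hidden states in new tensors detached from their history."""
    if isinstance(h, torch.Tensor):
        return h.detach()
    return tuple(repackage_hidden(v) for v in h)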
def train():
    # Turn on training mode which enables dropout.
    if args.model == 'QRNN':
        model.reset()
    total_loss = 0
    start_time = time.time()
    ntokens = len(corpus.dictionary)
    hidden = model.init_hidden(args.batch_size)
    batch, i = 0, 0
    while i < train_data.size(0) - 1 - 1:
        bptt = args.bptt if np.random.random() < 0.95 else args.bptt / 2.
        # Prevent excessively small or negative sequence lengths
        seq_len = max(5, int(np.random.normal(bptt, 5)))
        # There's a very small chance that it could select a very long sequence length resulting in OOM
        # seq_len = min(seq_len, args.bptt + 10)

        lr2 = optimizer.param_groups[0]['lr']
        optimizer.param_groups[0]['lr'] = lr2 * seq_len / args.bptt
        model.train()
        data, targets = get_batch(train_data, i, args, seq_len=seq_len)

        # Starting each batch, we detach the hidden state from how it was previously produced.
        # If we didn't, the model would try backpropagating all the way to start of the dataset.
        hidden = repackage_hidden(hidden)
        optimizer.zero_grad()

        output, hidden, rnn_hs, dropped_rnn_hs = model(data, hidden, return_h=True)
        raw_loss = criterion(output.view(-1, ntokens), targets)

        loss = raw_loss
        # Activation Regularization
        loss = loss + sum(args.alpha * dropped_rnn_h.pow(2).mean() for dropped_rnn_h in dropped_rnn_hs[-1:])
        # Temporal Activation Regularization (slowness)
        loss = loss + sum(args.beta * (rnn_h[1:] - rnn_h[:-1]).pow(2).mean() for rnn_h in rnn_hs[-1:])
        loss.backward()

        # `clip_grad_norm` helps prevent the exploding gradient problem in RNNs / LSTMs.
        torch.nn.utils.clip_grad_norm(model.parameters(), args.clip)
        optimizer.step()

        total_loss += raw_loss.data
        optimizer.param_groups[0]['lr'] = lr2
        if batch % args.log_interval == 0 and batch > 0:
            cur_loss = total_loss[0] / args.log_interval
            elapsed = time.time() - start_time
            print('| epoch {:3d} | {:5d}/{:5d} batches | lr {:02.2f} | ms/batch {:5.2f} | '
                  'loss {:5.2f} | ppl {:8.2f}'.format(
                epoch, batch, len(train_data) // args.bptt, optimizer.param_groups[0]['lr'],
                elapsed * 1000 / args.log_interval, cur_loss, math.exp(cur_loss)))
            total_loss = 0
            start_time = time.time()
        ###
        batch += 1
        i += seq_len

# Loop over epochs.
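The last two examples sample a variable BPTT length per step and rescale the learning rate in proportion to it, so shorter windows do not take an outsized update; the original rate is restored right after optimizer.step(). A minimal sketch of that arithmetic with hypothetical numbers:

import numpy as np

base_lr, bptt = 30.0, 70
bptt_draw = bptt if np.random.random() < 0.95 else bptt / 2.
seq_len = max(5, int(np.random.normal(bptt_draw, 5)))
scaled_lr = base_lr * seq_len / bptt    # used for this step only, then reset to base_lr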