我们从 Python 开源项目中提取了以下 22 个代码示例，用于说明如何使用 tensorflow.logging 模块。
def sequential(x, net, defaults=None, name='', reuse=None, var=None, layers=None):
    """Builds a stack of layers applied one after another.

    Args:
        x: Input to the first layer.
        net: Sequence of layer specs. `spec[0]` is the layer type (a key into
            the layer registry, also used to build the default layer name).
            A spec with exactly two elements supplies `spec[1]`, a dict of
            keyword arguments for the layer; otherwise the config is empty.
        defaults: Optional dict mapping a layer type to default keyword
            arguments. Per-layer config entries override these.
        name: Name of the enclosing `tf.variable_scope`.
        reuse: Passed through to `tf.variable_scope`.
        var: Optional dict used to resolve string config values of the form
            `'$key'` to `var['key']` (KeyError if `key` is missing).
        layers: Optional dict of extra layer constructors. It is merged with
            `predefined_layers`; predefined entries win on name clashes.

    Returns:
        The output of the last layer, or `x` itself if `net` is empty.
    """
    # None sentinels instead of shared mutable `{}` defaults.
    defaults = {} if defaults is None else defaults
    var = {} if var is None else var

    # Later entries win, so predefined layers take precedence over
    # user-supplied ones with the same name (same as the original merge).
    registry = dict(layers or {})
    registry.update(predefined_layers)

    try:
        string_types = basestring  # Python 2: covers both str and unicode.
    except NameError:
        string_types = str  # Python 3.

    y = x
    logging.info('Building Sequential Network : %s', name)
    with tf.variable_scope(name, reuse=reuse):
        for i, spec in enumerate(net):
            ltype = spec[0]
            lcfg = spec[1] if len(spec) == 2 else {}
            lname = lcfg.get('name', ltype + str(i))

            # Per-layer config overrides the type-level defaults.
            merged = dict(defaults.get(ltype, {}))
            merged.update(lcfg)

            # Resolve '$key' placeholders against `var`. startswith() avoids
            # the IndexError that indexing v[0] raised on empty strings.
            for k, v in list(merged.items()):
                if isinstance(v, string_types) and v.startswith('$'):
                    merged[k] = var[v[1:]]

            y = registry[ltype](y, lname, **merged)
            logging.info('\t %s \t %s', lname, y.get_shape().as_list())
    return y
def _module_info_from_proto_safe(module_info_def, import_scope=None):
    """Deserializes `module_info_def`, logging instead of raising on failure.

    `to_proto` is already wrapped in a try/except externally, but
    `from_proto` is not. To minimize disruption, every exception raised while
    deserializing is caught here and only logged as a warning.

    Args:
        module_info_def: An instance of `module_pb2.SonnetModule`.
        import_scope: Optional `string`. Name scope to use.

    Returns:
        The `ModuleInfo` produced by `_module_info_from_proto`, or `None` if
        deserialization raised.
    """
    try:
        module_info = _module_info_from_proto(module_info_def, import_scope)
    except Exception as err:  # pylint: disable=broad-except
        logging.warning(
            "Error encountered when deserializing sonnet ModuleInfo:\n%s",
            str(err))
        return None
    return module_info
def main(unused_args):
    """Builds the eval model on CPU and calls `inspect_tensors` repeatedly.

    `inspect_tensors(sess)` is run FLAGS.total_steps times while queue runners
    feed the graph. Any exception is logged and handed to the coordinator via
    `request_stop`, after which the queue-runner threads are joined.
    """
    graph = tf.Graph()
    with graph.as_default(), tf.device('/cpu:0'):
        # Build the model for evaluation.
        model = create_model(FLAGS, 'eval')
        model.build()

        with tf.Session() as session:
            # Start the queue runners.
            coordinator = tf.train.Coordinator()
            runner_threads = tf.train.start_queue_runners(coord=coordinator)

            try:
                for _ in range(FLAGS.total_steps):
                    inspect_tensors(session)
            except Exception as e:  # pylint: disable=broad-except
                tf.logging.error("Evaluation failed.")
                coordinator.request_stop(e)

            coordinator.request_stop()
            coordinator.join(runner_threads, stop_grace_period_secs=1)
def get_priming_melodies(self):
    """Collects LSTM states and next notes used by the 'random_midi' mode.

    One batch of training data is run through the internal MelodyRNN via
    `self.q_network.run_training_batch()`. The resulting LSTM states are kept
    in `self.priming_states`. For every melody in the batch, the argmax of the
    softmax row at that melody's last real (non-padding) step is kept in
    `self.priming_notes`. Priming with a random melody then only requires a
    random index from 0 to batch_size-1 into these two containers.
    """
    softmax, self.priming_states, lengths = (
        self.q_network.run_training_batch())

    # Melody k owns TRAIN_SEQUENCE_LENGTH consecutive softmax rows starting at
    # k * TRAIN_SEQUENCE_LENGTH; only the first `length` of them are real
    # outputs, the rest is padding, so its last note is at offset length - 1.
    self.priming_notes = [
        np.argmax(softmax[k * TRAIN_SEQUENCE_LENGTH + length - 1, :])
        for k, length in enumerate(lengths)
    ]
    tf.logging.info('Stored priming notes: %s', self.priming_notes)
def prime_internal_model(self, model):
    """Primes an internal model (e.g. the q_network) based on priming_mode.

    The model state is first reset to `model.get_zero_state()`, then:
      * 'random_midi': state and first note come from a uniformly random
        entry of `self.priming_states` / `self.priming_notes`.
      * 'single_midi': `model.prime_model()` is called and
        `model.priming_note` is the first observation.
      * 'random_note': `self.get_random_note()` is the first observation.
      * anything else: a warning is logged and a random note is used.

    Args:
        model: The internal model that should be primed.

    Returns:
        The first observation to feed into the model.
    """
    model.state_value = model.get_zero_state()
    mode = self.priming_mode

    if mode == 'single_midi':
        model.prime_model()
        return model.priming_note

    if mode == 'random_note':
        return self.get_random_note()

    if mode != 'random_midi':
        tf.logging.warn('Error! Invalid priming mode. Priming with random note')
        return self.get_random_note()

    # 'random_midi': restore one of the stored priming states.
    idx = np.random.randint(0, len(self.priming_states))
    model.state_value = np.reshape(
        self.priming_states[idx, :], (1, model.cell.state_size))
    note = self.priming_notes[idx]
    first_obs = np.array(
        rl_tuner_ops.make_onehot([note], self.num_actions)).flatten()
    tf.logging.debug(
        'Feeding priming state for midi file %s and corresponding note %s',
        idx, note)
    return first_obs
def reward_leap_up_back(self, action, resolving_leap_bonus=5.0,
                        leaping_twice_punishment=-5.0):
    """Applies reward and punishment for the "leap up, leap back" principle.

    A large interval jump (more than a fifth) should be followed by a move in
    the opposite direction. Leaping twice in the same direction is punished.

    Args:
        action: One-hot encoding of the chosen action.
        resolving_leap_bonus: Reward returned when `detect_leap_up_back`
            reports that a previous leap was resolved.
        leaping_twice_punishment: Reward (negative by default) returned when
            it reports a second leap in the same direction.

    Returns:
        Float reward value; 0.0 for any other outcome.
    """
    outcome = self.detect_leap_up_back(action)

    if outcome == rl_tuner_ops.LEAP_RESOLVED:
        tf.logging.debug('Leap resolved, awarding %s', resolving_leap_bonus)
        return resolving_leap_bonus

    if outcome == rl_tuner_ops.LEAP_DOUBLED:
        tf.logging.debug('Leap doubled, awarding %s', leaping_twice_punishment)
        return leaping_twice_punishment

    return 0.0
def testFinalCoreHasNoSizeWarning(self):
    """Reading output_size warns when the final core has no output_size."""
    cores = [snt.LSTM(hidden_size=10), snt.Linear(output_size=42), tf.nn.relu]
    rnn = snt.DeepRNN(cores, skip_connections=False)

    with mock.patch.object(tf.logging, "warning") as warn_mock:
        # Accessing the property is what produces the warning.
        unused_output_size = rnn.output_size

    self.assertTrue(warn_mock.called)
    positional = warn_mock.call_args[0]
    self.assertIn("final core %s does not have the "
                  ".output_size field", positional[0])
    # The second format argument is expected to be the Linear core's size.
    self.assertEqual(positional[2], 42)
def testNoSizeButAlreadyConnected(self):
    """Once connected, DeepRNN infers output_size and warns about it."""
    batch_size = 16
    cores = [snt.LSTM(hidden_size=10), snt.Linear(output_size=42), tf.nn.relu]
    rnn = snt.DeepRNN(cores, skip_connections=False)
    # Connect the module into the graph first.
    unused_output = rnn(tf.zeros((batch_size, 128)),
                        rnn.initial_state(batch_size=batch_size))

    with mock.patch.object(tf.logging, "warning") as warn_mock:
        output_size = rnn.output_size

    # Correct size is automatically inferred.
    self.assertEqual(output_size, tf.TensorShape([42]))
    self.assertTrue(warn_mock.called)
    message = warn_mock.call_args[0][0]
    self.assertIn("DeepRNN has been connected into the graph, "
                  "so inferred output size", message)
def testWarning(self):
    """get_variables() on a connected Sequential warns and returns ()."""
    seq = snt.Sequential([snt.Linear(output_size=23),
                          snt.Linear(output_size=42)])
    seq(tf.placeholder(dtype=tf.float32, shape=[2, 3]))

    with mock.patch.object(tf.logging, "warning") as warn_mock:
        self.assertEqual((), seq.get_variables())

    self.assertTrue(warn_mock.called)
    message = warn_mock.call_args[0][0]
    self.assertIn("will always return an empty tuple", message)
def testExpiredDataDiscardedAfterRestartForFileVersionLessThan2(self):
    """Tests that tensor events are discarded after a restart is detected.

    When a step lower than one already seen arrives, the previously stored
    items of that tag that are now outdated are dropped (here: steps 200 and
    300 once step 101 arrives).

    Only file versions < 2 use this out-of-order discard logic. Later versions
    discard events based on the step value of SessionLog.START.
    """
    captured_warnings = []
    # Replace tf.logging.warn so its messages land in a list.
    self.stubs.Set(tf.logging, 'warn', captured_warnings.append)
    gen = _EventGenerator(self)
    acc = ea.EventAccumulator(gen)
    gen.AddEvent(tf.Event(wall_time=0, step=0, file_version='brain.Event:1'))

    for step in (100, 200, 300):
        gen.AddScalarTensor('s1', wall_time=1, step=step, value=20)
    acc.Reload()
    # All three items are present before the restart.
    self.assertEqual([t.step for t in acc.Tensors('s1')], [100, 200, 300])

    for step in (101, 201, 301):
        gen.AddScalarTensor('s1', wall_time=1, step=step, value=20)
    acc.Reload()
    # 200 and 300 were discarded from s1.
    self.assertEqual([t.step for t in acc.Tensors('s1')],
                     [100, 101, 201, 301])
def testEventsDiscardedPerTagAfterRestartForFileVersionLessThan2(self):
    """Tests that discards after a restart only affect the misordered tag.

    When a step lower than one already seen arrives, outdated items are
    dropped, but only for the out-of-order tag. Other tags are unaffected.

    Only file versions < 2 use this out-of-order discard logic. Later versions
    discard events based on the step value of SessionLog.START.
    """
    captured_warnings = []
    # Replace tf.logging.warn so its messages land in a list.
    self.stubs.Set(tf.logging, 'warn', captured_warnings.append)
    gen = _EventGenerator(self)
    acc = ea.EventAccumulator(gen)
    gen.AddEvent(tf.Event(wall_time=0, step=0, file_version='brain.Event:1'))

    # (tag, step) in write order; s1 goes out of order at step 101.
    writes = [('s1', 100), ('s2', 101), ('s1', 200), ('s2', 201),
              ('s1', 300), ('s2', 301), ('s1', 101), ('s3', 101),
              ('s1', 201), ('s1', 301)]
    for tag, step in writes:
        gen.AddScalarTensor(tag, wall_time=1, step=step, value=20)
    acc.Reload()

    # 200 and 300 were discarded for s1.
    self.assertEqual([t.step for t in acc.Tensors('s1')],
                     [100, 101, 201, 301])
    # s2 (written before the out-of-order step) and s3 (written after it)
    # keep every item: only the out-of-order tag is affected.
    self.assertEqual([t.step for t in acc.Tensors('s2')], [101, 201, 301])
    self.assertEqual([t.step for t in acc.Tensors('s3')], [101])
def testExpiredDataDiscardedAfterRestartForFileVersionLessThan2(self):
    """Tests that scalar events are discarded after a restart is detected.

    When a step lower than one already seen arrives, the previously stored
    items of that tag that are now outdated are dropped (here: steps 200 and
    300 once step 101 arrives).

    Only file versions < 2 use this out-of-order discard logic. Later versions
    discard events based on the step value of SessionLog.START.
    """
    captured_warnings = []
    # Replace tf.logging.warn so its messages land in a list.
    self.stubs.Set(tf.logging, 'warn', captured_warnings.append)
    gen = _EventGenerator(self)
    acc = ea.EventAccumulator(gen)
    gen.AddEvent(tf.Event(wall_time=0, step=0, file_version='brain.Event:1'))

    for step in (100, 200, 300):
        gen.AddScalar('s1', wall_time=1, step=step, value=20)
    acc.Reload()
    # All three items are present before the restart.
    self.assertEqual([s.step for s in acc.Scalars('s1')], [100, 200, 300])

    for step in (101, 201, 301):
        gen.AddScalar('s1', wall_time=1, step=step, value=20)
    acc.Reload()
    # 200 and 300 were discarded from s1.
    self.assertEqual([s.step for s in acc.Scalars('s1')],
                     [100, 101, 201, 301])
def testEventsDiscardedPerTagAfterRestartForFileVersionLessThan2(self):
    """Tests that discards after a restart only affect the misordered tag.

    When a step lower than one already seen arrives, outdated items are
    dropped, but only for the out-of-order tag. Other tags are unaffected.

    Only file versions < 2 use this out-of-order discard logic. Later versions
    discard events based on the step value of SessionLog.START.
    """
    captured_warnings = []
    # Replace tf.logging.warn so its messages land in a list.
    self.stubs.Set(tf.logging, 'warn', captured_warnings.append)
    gen = _EventGenerator(self)
    acc = ea.EventAccumulator(gen)
    gen.AddEvent(tf.Event(wall_time=0, step=0, file_version='brain.Event:1'))

    # s1 goes out of order at step 101; s2 is written afterwards in order.
    for step in (100, 200, 300, 101, 201, 301):
        gen.AddScalar('s1', wall_time=1, step=step, value=20)
    for step in (101, 201, 301):
        gen.AddScalar('s2', wall_time=1, step=step, value=20)
    acc.Reload()

    # 200 and 300 were discarded for s1.
    self.assertEqual([s.step for s in acc.Scalars('s1')],
                     [100, 101, 201, 301])
    # s2 is untouched: only events from the out-of-order tag are discarded.
    self.assertEqual([s.step for s in acc.Scalars('s2')], [101, 201, 301])
def reward_music_theory(self, action):
    """Computes the cumulative reward over all music theory functions.

    The key reward is always logged at debug level. After that, each
    component is added in turn, and the running total is logged only when
    that component changed it.

    Args:
        action: A one-hot encoding of the chosen action.

    Returns:
        Float reward value.
    """
    # (method name, debug format), in the order the components are applied.
    # The last three are based on Gauldin's book, "A Practical Approach to
    # Eighteenth Century Counterpoint".
    components = (
        ('reward_tonic', 'Tonic: %s'),
        ('reward_penalize_repeating', 'Penalize repeating: %s'),
        ('reward_penalize_autocorrelation', 'Penalize autocorr: %s'),
        ('reward_motif', 'Reward motif: %s'),
        ('reward_repeated_motif', 'Reward repeated motif: %s'),
        ('reward_preferred_intervals', 'Reward preferred_intervals: %s'),
        ('reward_leap_up_back', 'Reward leap up back: %s'),
        ('reward_high_low_unique', 'Reward high low unique: %s'),
    )

    total = self.reward_key(action)
    tf.logging.debug('Key: %s', total)
    for method_name, message in components:
        before = total
        # Look the method up lazily, in the same order as the calls happen.
        total += getattr(self, method_name)(action)
        if total != before:
            tf.logging.debug(message, total)
    return total
def restore_from_directory(self, directory=None, checkpoint_name=None,
                           reward_file_name=None):
    """Restores this model from a saved checkpoint.

    Args:
        directory: Path to directory where checkpoint is located. If None,
            defaults to self.output_dir.
        checkpoint_name: The name of the checkpoint within the directory. If
            None, `tf.train.latest_checkpoint(directory)` is used, and the
            method logs a fatal error and returns if none is found.
        reward_file_name: The name of the .npz file where the stored rewards
            are saved. If None, will not attempt to load stored rewards.
    """
    if directory is None:
        directory = self.output_dir

    if checkpoint_name is not None:
        checkpoint_file = os.path.join(directory, checkpoint_name)
    else:
        tf.logging.info('Directory %s.', directory)
        checkpoint_file = tf.train.latest_checkpoint(directory)
        if checkpoint_file is None:
            tf.logging.fatal(
                'Error! Cannot locate checkpoint in the directory %s',
                directory)
            return

    # TODO(natashamjaques): Remove print statement once tf.logging outputs
    # to Jupyter notebooks (once the following issue is resolved:
    # https://github.com/tensorflow/tensorflow/issues/3047)
    print('Attempting to restore from checkpoint', checkpoint_file)
    tf.logging.info('Attempting to restore from checkpoint %s', checkpoint_file)
    self.saver.restore(self.session, checkpoint_file)

    if reward_file_name is None:
        return

    npz_file_name = os.path.join(directory, reward_file_name)
    # TODO(natashamjaques): Remove print statement once tf.logging outputs
    # to Jupyter notebooks (once the following issue is resolved:
    # https://github.com/tensorflow/tensorflow/issues/3047)
    print('Attempting to load saved reward values from file', npz_file_name)
    tf.logging.info('Attempting to load saved reward values from file %s',
                    npz_file_name)
    # np.load on an .npz returns an NpzFile that holds the file open; the
    # context manager closes it once every array has been read out.
    with np.load(npz_file_name) as npz_file:
        self.rewards_batched = npz_file['train_rewards']
        self.music_theory_rewards_batched = npz_file['train_music_theory_rewards']
        self.note_rnn_rewards_batched = npz_file['train_note_rnn_rewards']
        self.eval_avg_reward = npz_file['eval_rewards']
        self.eval_avg_music_theory_reward = npz_file['eval_music_theory_rewards']
        self.eval_avg_note_rnn_reward = npz_file['eval_note_rnn_rewards']
        self.target_val_list = npz_file['target_val_list']
def run_epoch(session, config, graph, iterator, ops=None, summary_writer=None,
              summary_prefix=None, saver=None):
    """Runs the model on the given data.

    Args:
        session: Session used to run the fetches.
        config: Dict; reads 'monitoring_frequency', 'saving_frequency',
            'next_worker_delay' and 'batch_size'.
        graph: Dict of graph endpoints; reads 'inputs', 'input_lengths',
            'observed', 'total_neg_loglikelihood', 'total_correct',
            'step_number' and 'learning_rate'.
        iterator: Iterable of `(inputs, lengths)` batches.
        ops: Optional list of extra ops run with every batch.
        summary_writer: Passed to `Monitor`.
        summary_prefix: Passed to `Monitor`.
        saver: Optional saver. If given, the model is saved to
            FLAGS.train_path every 'saving_frequency' steps and at the end.

    Returns:
        The value of the last `Monitor.monitor()` call, or None if `iterator`
        produced no batches.
    """
    import time  # Local import: only needed for the start-up poll below.

    if not ops:
        ops = []

    # Shortcuts, ugly but still increase readability
    c = config
    g = graph
    m = Monitor(summary_writer, summary_prefix)

    def should_monitor(step):
        return (step and c['monitoring_frequency']
                and (step + 1) % c['monitoring_frequency'] == 0)

    def should_save(step):
        return (step and c['saving_frequency']
                and (step + 1) % c['saving_frequency'] == 0)

    # Stagger worker start-up: wait until the global step reaches
    # FLAGS.task * next_worker_delay (FLAGS.task is presumably this worker's
    # task index). Sleep between polls instead of busy-spinning on eval().
    while g['step_number'].eval() < FLAGS.task * c['next_worker_delay']:
        time.sleep(1)

    result = None
    step = None  # Stays None if `iterator` yields no batches.
    for step, (inputs, lengths) in enumerate(iterator):
        # Define what we feed
        feed_dict = {g['inputs']: inputs, g['input_lengths']: lengths}

        # Define what we fetch
        fetch = dict(g['observed'])
        fetch['total_neg_loglikelihood'] = g['total_neg_loglikelihood']
        fetch['total_correct'] = g['total_correct']
        fetch['_ops'] = ops

        r = session.run(fetch, feed_dict)

        # Update the monitor accumulators
        m.total_neg_loglikelihood += r['total_neg_loglikelihood']
        m.total_correct += r['total_correct']
        # We do not predict the first words, that's why batch_size has to be
        # subtracted from the total.
        m.steps += 1
        m.words += sum(lengths) - c['batch_size']
        m.sentences += c['batch_size']
        m.words_including_padding += c['batch_size'] * len(inputs[0])
        m.step_number = g['step_number'].eval()
        m.learning_rate = float(g['learning_rate'].eval())
        for key in g['observed']:
            m.observed[key] += r[key]

        if should_monitor(step):
            tf.logging.info('monitor')
            result = m.monitor()
        if saver and should_save(step):
            saver.save(session, os.path.join(FLAGS.train_path, 'model'))
            print("saved")

    if step is None:
        # Previously this path crashed with NameError on the unbound `step`.
        tf.logging.warning('run_epoch: iterator produced no batches.')
    elif not should_monitor(step):
        # Report the final interval not covered by the periodic check.
        result = m.monitor()

    if saver:
        saver.save(session, os.path.join(FLAGS.train_path, 'model'))
    return result
def main(_):
    """Evaluates an optimizer (Adam or L2L) on FLAGS.problem and prints stats.

    Runs FLAGS.num_epochs epochs of `util.run_epoch` with the selected
    optimizer and prints the accumulated cost and time via `util.print_stats`.

    Raises:
        ValueError: If FLAGS.optimizer is neither "Adam" nor "L2L".
    """
    # Configuration.
    num_unrolls = FLAGS.num_steps
    if FLAGS.seed:
        tf.set_random_seed(FLAGS.seed)

    # Problem.
    problem, net_config, net_assignments = util.get_config(FLAGS.problem,
                                                           FLAGS.path)

    # Optimizer setup.
    if FLAGS.optimizer == "Adam":
        cost_op = problem()
        problem_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
        problem_reset = tf.variables_initializer(problem_vars)

        optimizer = tf.train.AdamOptimizer(FLAGS.learning_rate)
        # The optimizer's slot variables only exist after minimize(), so the
        # reset op must be built afterwards. The old code passed
        # get_slot_names() (slot *names*, and empty before minimize), so the
        # Adam state was never actually reset between epochs.
        update = optimizer.minimize(cost_op)
        if hasattr(optimizer, "variables"):
            # Slot variables plus non-slot state (e.g. beta power accumulators).
            optimizer_vars = optimizer.variables()
        else:
            # Older TF without Optimizer.variables(): collect slot variables.
            optimizer_vars = [
                slot
                for slot in (optimizer.get_slot(var, slot_name)
                             for slot_name in optimizer.get_slot_names()
                             for var in problem_vars)
                if slot is not None
            ]
        optimizer_reset = tf.variables_initializer(optimizer_vars)
        reset = [problem_reset, optimizer_reset]
    elif FLAGS.optimizer == "L2L":
        if FLAGS.path is None:
            logging.warning("Evaluating untrained L2L optimizer")
        optimizer = meta.MetaOptimizer(**net_config)
        meta_loss = optimizer.meta_loss(problem, 1,
                                        net_assignments=net_assignments)
        _, update, reset, cost_op, _ = meta_loss
    else:
        raise ValueError("{} is not a valid optimizer".format(FLAGS.optimizer))

    with ms.MonitoredSession() as sess:
        # Prevent accidental changes to the graph.
        tf.get_default_graph().finalize()

        total_time = 0
        total_cost = 0
        for _ in range(FLAGS.num_epochs):
            # Training. (Named epoch_time so the `time` module is not shadowed.)
            epoch_time, epoch_cost = util.run_epoch(sess, cost_op, [update],
                                                    reset, num_unrolls)
            total_time += epoch_time
            total_cost += epoch_cost

        # Results.
        util.print_stats("Epoch {}".format(FLAGS.num_epochs), total_cost,
                         total_time, FLAGS.num_epochs)
def evaluate_model(sess, target_cross_entropy_losses,
                   target_cross_entropy_loss_weights, global_step,
                   summary_writer, summary_op):
    """Computes perplexity-per-word over the evaluation dataset.

    Summaries and perplexity-per-word are written out to the eval directory.
    Perplexity is exp(sum(losses * weights) / sum(weights)) accumulated over
    ceil(num_eval_examples / batch_size) batches.

    Args:
        sess: Session object.
        target_cross_entropy_losses: Tensor of cross-entropy losses for one
            eval batch.
        target_cross_entropy_loss_weights: Tensor multiplied element-wise with
            the losses (presumably 0 for padding positions — TODO confirm).
        global_step: Integer; global step of the model checkpoint.
        summary_writer: Instance of SummaryWriter.
        summary_op: Op for generating model summaries.
    """
    # Log model summaries on a single batch.
    summary_str = sess.run(summary_op)
    summary_writer.add_summary(summary_str, global_step)

    # Compute perplexity over the entire dataset. float() forces true division
    # so a final partial batch is still evaluated: under Python 2, int / int
    # floors before math.ceil ever sees the fraction.
    num_eval_batches = int(math.ceil(float(num_eval_examples) / batch_size))

    start_time = time.time()
    sum_losses = 0.
    sum_weights = 0.
    for i in range(num_eval_batches):
        cross_entropy_losses, weights = sess.run([
            target_cross_entropy_losses,
            target_cross_entropy_loss_weights,
        ])
        sum_losses += np.sum(cross_entropy_losses * weights)
        sum_weights += np.sum(weights)
        if not i % 100:
            tf.logging.info("Computed losses for %d of %d batches.", i + 1,
                            num_eval_batches)
    eval_time = time.time() - start_time

    perplexity = math.exp(sum_losses / sum_weights)
    tf.logging.info("Perplexity = %f (%.2g sec)", perplexity, eval_time)

    # Log perplexity to the SummaryWriter.
    summary = tf.Summary()
    value = summary.value.add()
    value.simple_value = perplexity
    value.tag = "Perplexity"
    summary_writer.add_summary(summary, global_step)

    # Write the Events file to the eval directory.
    summary_writer.flush()
    tf.logging.info("Finished processing evaluation at global step %d.",
                    global_step)
def run_once(global_step, target_cross_entropy_losses,
             target_cross_entropy_loss_weights, saver, summary_writer,
             summary_op):
    """Evaluates the latest model checkpoint.

    Evaluation is skipped if `checkpoint_dir` has no checkpoint or if the
    restored global step is below `min_global_step`.

    Args:
        global_step: The global-step variable; its value is read after the
            checkpoint is restored.
        target_cross_entropy_losses: Passed through to `evaluate_model`.
        target_cross_entropy_loss_weights: Passed through to `evaluate_model`.
        saver: Instance of tf.train.Saver for restoring model Variables.
        summary_writer: Instance of SummaryWriter.
        summary_op: Op for generating model summaries.
    """
    # The latest checkpoint, or a falsy value if there is none.
    model_path = tf.train.latest_checkpoint(checkpoint_dir)
    if not model_path:
        tf.logging.info("Skipping evaluation. No checkpoint found in: %s",
                        checkpoint_dir)
        return

    with tf.Session() as sess:
        # Load model from checkpoint.
        tf.logging.info("Loading model from checkpoint: %s", model_path)
        saver.restore(sess, model_path)
        step = tf.train.global_step(sess, global_step.name)
        tf.logging.info("Successfully loaded %s at global step = %d.",
                        os.path.basename(model_path), step)
        if step < min_global_step:
            tf.logging.info("Skipping evaluation. Global step = %d < %d", step,
                            min_global_step)
            return

        # Start the queue runners.
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(coord=coord)

        # Run evaluation on the latest checkpoint.
        try:
            evaluate_model(
                sess=sess,
                target_cross_entropy_losses=target_cross_entropy_losses,
                target_cross_entropy_loss_weights=target_cross_entropy_loss_weights,
                global_step=step,
                summary_writer=summary_writer,
                summary_op=summary_op)
        # `except X as e` works on Python 2.6+ and 3; `except X, e` is a
        # SyntaxError on Python 3.
        except Exception as e:  # pylint: disable=broad-except
            tf.logging.error("Evaluation failed.")
            coord.request_stop(e)

        coord.request_stop()
        coord.join(threads, stop_grace_period_secs=10)
def run():
    """Runs evaluation in a loop, and logs summaries to TensorBoard.

    Builds the evaluation graph once, then loops forever calling `run_once`,
    starting a new run every `eval_interval_secs` seconds (measured from the
    start of the previous run).
    """
    # Create the evaluation directory if it doesn't exist.
    if not tf.gfile.IsDirectory(eval_dir):
        tf.logging.info("Creating eval directory: %s", eval_dir)
        tf.gfile.MakeDirs(eval_dir)

    eval_graph = tf.Graph()
    with eval_graph.as_default():
        images, input_seqs, target_seqs, input_mask = Build_Inputs(
            mode, input_file_pattern)
        image_embeddings = Build_Image_Embeddings(mode, images,
                                                  train_inception)
        seq_embeddings = Build_Seq_Embeddings(input_seqs)
        _, losses, loss_weights, _network = Build_Model(
            mode, image_embeddings, seq_embeddings, target_seqs, input_mask)

        step_var = tf.Variable(
            initial_value=0,
            dtype=tf.int32,
            name="global_step",
            trainable=False,
            collections=[tf.GraphKeys.GLOBAL_STEP, tf.GraphKeys.VARIABLES])

        # Saver restores model Variables; summaries go to eval_dir.
        restorer = tf.train.Saver()
        merged_summaries = tf.merge_all_summaries()
        writer = tf.train.SummaryWriter(eval_dir)

        eval_graph.finalize()

        # Kept inside the graph context so the Session that run_once opens
        # uses eval_graph as its default graph.
        while True:
            started = time.time()
            tf.logging.info("Starting evaluation at " + time.strftime(
                "%Y-%m-%d-%H:%M:%S", time.localtime()))
            run_once(step_var, losses, loss_weights, restorer, writer,
                     merged_summaries)
            remaining = started + eval_interval_secs - time.time()
            if remaining > 0:
                time.sleep(remaining)
def _build_input_fn(input_file_pattern, batch_size, mode):
    """Build input function.

    Args:
        input_file_pattern: The file pattern for examples.
        batch_size: Batch size.
        mode: The execution mode, as defined in tf.contrib.learn.ModeKeys.

    Returns:
        A zero-argument `_input_fn` which, when called, returns a tuple of
        (dictionary of feature column name to tensor, labels).
    """

    def _input_fn():
        """Supplies the input to the model.

        Returns:
            A tuple consisting of 1) a dictionary of tensors whose keys are
            the feature names, and 2) a tensor of target labels if the mode
            is not INFER (and None, otherwise).

        Raises:
            ValueError: If `input_file_pattern` matches no files.
        """
        logging.info("Reading files from %s", input_file_pattern)
        # Sorted so the file order is deterministic across runs.
        input_files = sorted(tf.gfile.Glob(input_file_pattern))
        if not input_files:
            # Fail early with a clear message instead of deep in the reader.
            raise ValueError(
                "No input files match pattern: %s" % input_file_pattern)
        logging.info("Found %d input files: %s", len(input_files), input_files)

        include_target_column = (mode != tf.contrib.learn.ModeKeys.INFER)
        features_spec = tf.contrib.layers.create_feature_spec_for_parsing(
            feature_columns=_get_feature_columns(include_target_column))

        if FLAGS.use_gzip:
            def gzip_reader():
                return tf.TFRecordReader(
                    options=tf.python_io.TFRecordOptions(
                        compression_type=TFRecordCompressionType.GZIP))
            reader_fn = gzip_reader
        else:
            reader_fn = tf.TFRecordReader

        features = tf.contrib.learn.io.read_batch_features(
            file_pattern=input_files,
            batch_size=batch_size,
            queue_capacity=3 * batch_size,
            randomize_input=(mode == tf.contrib.learn.ModeKeys.TRAIN),
            feature_queue_capacity=FLAGS.feature_queue_capacity,
            reader=reader_fn,
            features=features_spec)

        target = None
        if include_target_column:
            target = features.pop(FLAGS.target_field)
        return features, target

    return _input_fn