我们从Python开源项目中，提取了以下4个代码示例，用于说明如何使用model.model_fn()。
def _run_export(self):
    """Export the latest checkpoint as a SavedModel for serving.

    The export directory is ``export_ckpt_<N>`` under
    ``self._checkpoint_dir``, where ``<N>`` is the last run of digits found
    in ``self._latest_checkpoint``. When the builder cannot be created for
    that directory, the export is skipped and the method returns early.

    Raises:
        IndexError: If ``self._latest_checkpoint`` contains no digits.
    """
    # Raw string: '\d' in a plain literal is an invalid escape sequence.
    checkpoint_number = re.findall(r'\d+', self._latest_checkpoint)[-1]
    export_dir = 'export_ckpt_' + checkpoint_number
    tf.logging.info('Exporting model from checkpoint {0}'.format(self._latest_checkpoint))

    try:
        exporter = tf.saved_model.builder.SavedModelBuilder(
            os.path.join(self._checkpoint_dir, export_dir))
    except (IOError, AssertionError):
        # NOTE(review): TF 1.x SavedModelBuilder is understood to raise
        # AssertionError for an existing export directory -- confirm for the
        # TF version in use. IOError is kept for backward compatibility.
        tf.logging.info('Checkpoint {0} already exported, continuing...'.format(self._latest_checkpoint))
        return

    # Build the prediction graph only once we know the export will proceed.
    prediction_graph = tf.Graph()
    with prediction_graph.as_default():
        image, name, inputs_dict = model.serving_input_fn()
        # The literal 6 sits in the position where the training call passes
        # its channel count; presumably the two must agree -- verify.
        prediction_dict = model.model_fn(model.PREDICT, name, image, None, 6, None)
        saver = tf.train.Saver()

        # .items() works on both Python 2 and 3 (.iteritems() is Python-2-only).
        inputs_info = {
            key: tf.saved_model.utils.build_tensor_info(tensor)
            for key, tensor in inputs_dict.items()
        }
        output_info = {
            key: tf.saved_model.utils.build_tensor_info(tensor)
            for key, tensor in prediction_dict.items()
        }
        signature_def = tf.saved_model.signature_def_utils.build_signature_def(
            inputs=inputs_info,
            outputs=output_info,
            method_name=sig_constants.PREDICT_METHOD_NAME
        )

    with tf.Session(graph=prediction_graph) as session:
        # Load the trained weights before they are written to the SavedModel.
        saver.restore(session, self._latest_checkpoint)
        exporter.add_meta_graph_and_variables(
            session,
            tags=[tf.saved_model.tag_constants.SERVING],
            signature_def_map={
                sig_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY: signature_def
            },
            legacy_init_op=my_main_op()
        )

    exporter.save()
def run(target, is_chief, train_steps, job_dir, train_files, eval_files,
        num_epochs, learning_rate):
    """Build the training (and, on the chief, evaluation) graph and train.

    Args:
        target: Passed as ``master`` to ``MonitoredTrainingSession``.
        is_chief: Whether this worker is the chief; only the chief gets the
            evaluation and checkpoint-export hooks.
        train_steps: Step count handed to ``StopAtStepHook``.
        job_dir: Checkpoint directory; also handed to the evaluation,
            export and summary hooks.
        train_files: Input passed to ``model.input_fn`` for training.
        eval_files: Input passed to ``model.input_fn`` for evaluation.
        num_epochs: Number of epochs for the training input.
        learning_rate: Forwarded to ``model.model_fn``.
    """
    num_channels = 6

    # The step hook does not work well in distributed mode because it only
    # counts local steps (unverified, per the original author).
    hooks = [tf.train.StopAtStepHook(train_steps)]

    if is_chief:
        evaluation_graph = tf.Graph()
        with evaluation_graph.as_default():
            # Feature and label tensors for a single pass over the eval data.
            eval_image, eval_truth, eval_name = model.input_fn(
                eval_files, 1, shuffle=False, shared_name=None)
            # Dictionary of tensors to be evaluated.
            metric_dict = model.model_fn(
                model.EVAL, eval_name, eval_image, eval_truth,
                num_channels, learning_rate)

        # Evaluation runs from its own graph, separately from training.
        hooks.extend([
            EvaluationRunHook(job_dir, metric_dict, evaluation_graph),
            CheckpointExporterHook(job_dir),
        ])

    # Training gets a fresh graph that becomes the default one.
    with tf.Graph().as_default():
        with tf.device(tf.train.replica_device_setter()):
            # Feature and label tensors read through the filename queue.
            train_image, train_truth, train_name = model.input_fn(
                train_files, num_epochs, shuffle=True,
                shared_name='train_queue')
            train_op, log_hook, train_summaries = model.model_fn(
                model.TRAIN, train_name, train_image, train_truth,
                num_channels, learning_rate)

        # Console logging first, then per-step summary writing.
        hooks.extend([
            log_hook,
            tf.train.SummarySaverHook(
                save_steps=1,
                output_dir=get_summary_dir(job_dir),
                summary_op=train_summaries),
        ])

        # MonitoredTrainingSession is a Session-like object that handles
        # initialization, recovery and hooks:
        # https://www.tensorflow.org/api_docs/python/tf/train/MonitoredTrainingSession
        session_kwargs = dict(
            master=target,
            is_chief=is_chief,
            checkpoint_dir=job_dir,
            hooks=hooks,
            save_checkpoint_secs=60 * 3,
            save_summaries_steps=1,
            log_step_count_steps=5,
        )
        with tf.train.MonitoredTrainingSession(**session_kwargs) as session:
            # Keep stepping until the session reports that it should stop.
            while not session.should_stop():
                session.run(train_op)
def build_and_run_exports(latest, job_dir, serving_input_fn, hidden_units):
    """Given the latest checkpoint file, export the saved model.

    The SavedModel is written to ``<job_dir>/export``.

    Args:
        latest (string): Latest checkpoint file to restore variables from.
        job_dir (string): Location of checkpoints and model files.
        serving_input_fn (callable): Called with no arguments; returns a
            ``(features, inputs_dict)`` pair, where ``inputs_dict`` maps
            signature input names to tensors.
        hidden_units (list): Number of hidden units, forwarded to
            ``model.model_fn``.
    """
    prediction_graph = tf.Graph()
    exporter = tf.saved_model.builder.SavedModelBuilder(
        os.path.join(job_dir, 'export'))

    with prediction_graph.as_default():
        features, inputs_dict = serving_input_fn()
        prediction_dict = model.model_fn(
            model.PREDICT,
            features.copy(),  # pass a copy rather than the original object
            None,  # labels
            hidden_units=hidden_units,
            learning_rate=None  # learning_rate unused in prediction mode
        )
        saver = tf.train.Saver()

        # .items() works on both Python 2 and 3 (.iteritems() is Python-2-only).
        inputs_info = {
            key: tf.saved_model.utils.build_tensor_info(tensor)
            for key, tensor in inputs_dict.items()
        }
        output_info = {
            key: tf.saved_model.utils.build_tensor_info(tensor)
            for key, tensor in prediction_dict.items()
        }
        signature_def = tf.saved_model.signature_def_utils.build_signature_def(
            inputs=inputs_info,
            outputs=output_info,
            method_name=sig_constants.PREDICT_METHOD_NAME
        )

    with tf.Session(graph=prediction_graph) as session:
        # Run the local-variable and table initializers before restoring.
        session.run([tf.local_variables_initializer(), tf.tables_initializer()])
        saver.restore(session, latest)
        exporter.add_meta_graph_and_variables(
            session,
            tags=[tf.saved_model.tag_constants.SERVING],
            signature_def_map={
                sig_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY: signature_def
            },
            legacy_init_op=main_op()
        )

    exporter.save()
def build_and_run_exports(latest, job_dir, name, serving_input_fn, hidden_units):
    """Given the latest checkpoint file, export the saved model.

    The SavedModel is written to ``<job_dir>/export/<name>``.

    Args:
        latest (string): Latest checkpoint file to restore variables from.
        job_dir (string): Location of checkpoints and model files.
        name (string): Name of the checkpoint to be exported. Used in
            building the export path.
        serving_input_fn (callable): Called with no arguments; returns a
            ``(features, inputs_dict)`` pair, where ``inputs_dict`` maps
            signature input names to tensors.
        hidden_units (list): Number of hidden units, forwarded to
            ``model.model_fn``.
    """
    prediction_graph = tf.Graph()
    exporter = tf.saved_model.builder.SavedModelBuilder(
        os.path.join(job_dir, 'export', name))

    with prediction_graph.as_default():
        features, inputs_dict = serving_input_fn()
        prediction_dict = model.model_fn(
            model.PREDICT,
            features,
            None,  # labels
            hidden_units=hidden_units,
            learning_rate=None  # learning_rate unused in prediction mode
        )
        saver = tf.train.Saver()

        # Loop variable is ``key`` so it does not shadow the ``name``
        # parameter; .items() works on both Python 2 and 3.
        inputs_info = {
            key: tf.saved_model.utils.build_tensor_info(tensor)
            for key, tensor in inputs_dict.items()
        }
        output_info = {
            key: tf.saved_model.utils.build_tensor_info(tensor)
            for key, tensor in prediction_dict.items()
        }
        signature_def = tf.saved_model.signature_def_utils.build_signature_def(
            inputs=inputs_info,
            outputs=output_info,
            method_name=tf.saved_model.signature_constants.PREDICT_METHOD_NAME
        )

    with tf.Session(graph=prediction_graph) as session:
        # Run the local-variable and table initializers before restoring.
        session.run([tf.local_variables_initializer(), tf.tables_initializer()])
        saver.restore(session, latest)
        exporter.add_meta_graph_and_variables(
            session,
            tags=[tf.saved_model.tag_constants.SERVING],
            signature_def_map={
                tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY: signature_def
            },
        )

    exporter.save()