我们从Python开源项目中,提取了以下32个代码示例,用于说明如何使用tensorflow.python.framework.graph_util.convert_variables_to_constants()。
def setUp(self):
    """Export a variable-free meta graph into a temporary "no_vars" directory.

    Builds ``y = w * x - 7`` with a single variable, folds that variable into
    a constant, and writes the resulting meta graph (with the "meta"
    collection) under ``self.base_path``.
    """
    export_dir = os.path.join(tf.test.get_temp_dir(), "no_vars")
    self.base_path = export_dir
    if not os.path.exists(export_dir):
        os.mkdir(export_dir)

    with tf.Graph().as_default() as g:
        inp = tf.placeholder(tf.float32, name="x")
        weight = tf.Variable(3.0)
        # Only the node name "y" is used later; the Python handle is not.
        tf.sub(weight * inp, 7.0, name="y")
        tf.add_to_collection("meta", "this is meta")

        with self.test_session(graph=g) as session:
            tf.initialize_all_variables().run()
            frozen_graph_def = graph_util.convert_variables_to_constants(
                session, g.as_graph_def(), ["y"])

        meta_path = os.path.join(export_dir,
                                 constants.META_GRAPH_DEF_FILENAME)
        tf.train.export_meta_graph(
            meta_path, graph_def=frozen_graph_def, collection_list=["meta"])
def setUp(self):
    """Export a variable-free meta graph into a temporary "no_vars" directory.

    Builds ``y = w * x - 7`` with a single variable, folds that variable into
    a constant, and writes the resulting meta graph (with the "meta"
    collection) under ``self.base_path``.
    """
    export_dir = os.path.join(tf.test.get_temp_dir(), "no_vars")
    self.base_path = export_dir
    if not os.path.exists(export_dir):
        os.mkdir(export_dir)

    with tf.Graph().as_default() as g:
        inp = tf.placeholder(tf.float32, name="x")
        weight = tf.Variable(3.0)
        # Only the node name "y" is used later; the Python handle is not.
        tf.sub(weight * inp, 7.0, name="y")
        tf.add_to_collection("meta", "this is meta")

        with self.test_session(graph=g) as session:
            tf.global_variables_initializer().run()
            frozen_graph_def = graph_util.convert_variables_to_constants(
                session, g.as_graph_def(), ["y"])

        meta_path = os.path.join(export_dir,
                                 constants.META_GRAPH_DEF_FILENAME)
        tf.train.export_meta_graph(
            meta_path, graph_def=frozen_graph_def, collection_list=["meta"])
def freeze_graph(model_folder):
    """Freeze the checkpoint found in ``model_folder`` into a single .pb file.

    The frozen graph keeps only what is needed to compute the
    "Accuracy/predictions" node and is written as ``frozen_model_2.pb`` next
    to the checkpoint files.

    Args:
        model_folder: directory containing the TensorFlow checkpoint state.
    """
    from tensorflow.python.framework import graph_util

    # Fix: use the function argument. The original read the module-level
    # ``args.model_folder`` and silently ignored the value passed in.
    checkpoint = tf.train.get_checkpoint_state(model_folder)
    input_checkpoint = checkpoint.model_checkpoint_path

    # The output file lives in the same directory as the checkpoint.
    absolute_model_folder = "/".join(input_checkpoint.split('/')[:-1])
    output_graph = absolute_model_folder + "/frozen_model_2.pb"

    output_node_names = "Accuracy/predictions"

    # Device placements are cleared when importing the meta graph.
    saver = tf.train.import_meta_graph(input_checkpoint + '.meta',
                                       clear_devices=True)
    input_graph_def = tf.get_default_graph().as_graph_def()

    with tf.Session() as sess:
        saver.restore(sess, input_checkpoint)
        output_graph_def = graph_util.convert_variables_to_constants(
            sess, input_graph_def, output_node_names.split(","))

        with tf.gfile.GFile(output_graph, "wb") as f:
            f.write(output_graph_def.SerializeToString())
        print("%d ops in the final graph." % len(output_graph_def.node))
def freeze_graph(
        model_dir, output_nodes_list, output_graph_name='frozen_model.pb'
):
    """Reduce a saved model and its metadata down to a deployable file.

    Args:
        model_dir: directory holding the checkpoint state.
        output_nodes_list: list of node names to keep as graph outputs.
        output_graph_name: file name of the frozen graph, written next to
            the checkpoint.

    Returns:
        Path of the written frozen graph, or ``None`` when no checkpoint
        could be located in ``model_dir``.
    """
    from tensorflow.python.framework import graph_util

    LOGGER.info('Attempting to freeze graph at {}'.format(model_dir))
    checkpoint = tf.train.get_checkpoint_state(model_dir)
    # Fix: get_checkpoint_state() returns None when there is no checkpoint, so
    # guard before touching its attributes. The original dereferenced it first
    # and crashed with AttributeError instead of reaching its error branch.
    if checkpoint is None or not checkpoint.model_checkpoint_path:
        LOGGER.error('Cannot load checkpoint at {}'.format(model_dir))
        return None
    input_checkpoint = checkpoint.model_checkpoint_path

    absolute_model_dir = '/'.join(input_checkpoint.split('/')[:-1])
    output_graph = absolute_model_dir + '/' + output_graph_name
    saver = tf.train.import_meta_graph(input_checkpoint + '.meta',
                                       clear_devices=True)
    input_graph_def = tf.get_default_graph().as_graph_def()

    with tf.Session() as sess:
        saver.restore(sess, input_checkpoint)
        output_graph_def = graph_util.convert_variables_to_constants(
            sess, input_graph_def, output_nodes_list
        )
        with tf.gfile.GFile(output_graph, 'wb') as f:
            f.write(output_graph_def.SerializeToString())
        LOGGER.info('Froze graph with {} ops'.format(
            len(output_graph_def.node)
        ))

    return output_graph
def freeze_graph(
        model_dir, output_nodes_list, output_graph_name='frozen_model.pb'
):
    """Reduce a saved model and its metadata down to a deployable file.

    Following
    https://blog.metaflow.fr/tensorflow-how-to-freeze-a-model-and-serve-it-with-a-python-api-d4f3596b3adc

    Args:
        model_dir: directory holding the checkpoint state.
        output_nodes_list: list of node names to keep as graph outputs,
            e.g. ``['softmax_linear/logits']``.
        output_graph_name: file name of the frozen graph, written next to
            the checkpoint.

    Returns:
        Path of the written frozen graph, or ``None`` when no checkpoint
        could be located in ``model_dir``.
    """
    from tensorflow.python.framework import graph_util

    LOGGER.info('Attempting to freeze graph at {}'.format(model_dir))
    checkpoint = tf.train.get_checkpoint_state(model_dir)
    # Fix: get_checkpoint_state() returns None when there is no checkpoint, so
    # guard before touching its attributes. The original dereferenced it first
    # and crashed with AttributeError instead of reaching its error branch.
    if checkpoint is None or not checkpoint.model_checkpoint_path:
        LOGGER.error('Cannot load checkpoint at {}'.format(model_dir))
        return None
    input_checkpoint = checkpoint.model_checkpoint_path

    absolute_model_dir = '/'.join(input_checkpoint.split('/')[:-1])
    output_graph = absolute_model_dir + '/' + output_graph_name
    saver = tf.train.import_meta_graph(input_checkpoint + '.meta',
                                       clear_devices=True)
    input_graph_def = tf.get_default_graph().as_graph_def()

    with tf.Session() as sess:
        saver.restore(sess, input_checkpoint)
        output_graph_def = graph_util.convert_variables_to_constants(
            sess, input_graph_def, output_nodes_list
        )
        with tf.gfile.GFile(output_graph, 'wb') as f:
            f.write(output_graph_def.SerializeToString())
        LOGGER.info('Froze graph with {} ops'.format(
            len(output_graph_def.node)
        ))

    return output_graph
def freeze_graph_def(sess, input_graph_def, output_node_names):
    """Patch batch-norm nodes, then fold whitelisted variables into constants.

    ``input_graph_def`` is modified in place: ``RefSwitch`` nodes become
    ``Switch`` (with ``/read`` appended to their ``moving_*`` inputs), and
    ``AssignSub`` / ``AssignAdd`` become ``Sub`` / ``Add`` with the
    ``use_locking`` attribute removed.

    Args:
        sess: session holding the current variable values.
        input_graph_def: GraphDef to patch and freeze.
        output_node_names: comma-separated names of the output nodes.

    Returns:
        The GraphDef returned by ``convert_variables_to_constants``.
    """
    # Assign-style ops and the plain arithmetic op each one is rewritten to.
    plain_ops = {'AssignSub': 'Sub', 'AssignAdd': 'Add'}

    for node in input_graph_def.node:
        if node.op == 'RefSwitch':
            node.op = 'Switch'
            # ``range`` instead of the Python-2-only ``xrange``.
            for index in range(len(node.input)):
                if 'moving_' in node.input[index]:
                    node.input[index] = node.input[index] + '/read'
        elif node.op in plain_ops:
            node.op = plain_ops[node.op]
            if 'use_locking' in node.attr:
                del node.attr['use_locking']

    # Only variables under these name prefixes are converted to constants.
    important_prefixes = ('InceptionResnetV1', 'embeddings', 'phase_train',
                          'Bottleneck', 'Logits')
    whitelist_names = [node.name for node in input_graph_def.node
                       if node.name.startswith(important_prefixes)]

    # Replace all the whitelisted variables with constants of the same values.
    output_graph_def = graph_util.convert_variables_to_constants(
        sess, input_graph_def, output_node_names.split(","),
        variable_names_whitelist=whitelist_names)
    return output_graph_def
def save_graph_to_file(sess, graph, graph_file_name):
    """Freeze ``graph`` up to ``FLAGS.final_tensor_name`` and write it to disk.

    Args:
        sess: session holding the current variable values.
        graph: graph whose definition is frozen.
        graph_file_name: destination path of the serialized GraphDef.
    """
    source_def = graph.as_graph_def()
    wanted_outputs = [FLAGS.final_tensor_name]
    frozen_def = graph_util.convert_variables_to_constants(
        sess, source_def, wanted_outputs)
    serialized = frozen_def.SerializeToString()
    with gfile.FastGFile(graph_file_name, 'wb') as out_file:
        out_file.write(serialized)
def generatePB(pb_dest = "model_mnist_bnn.pb"):
    """Freeze the module-level session's graph at node 'output' into a .pb file.

    Args:
        pb_dest: destination path of the serialized frozen GraphDef.
    """
    live_graph_def = sess.graph.as_graph_def()
    frozen_graph_def = graph_util.convert_variables_to_constants(
        sess, live_graph_def, ['output'])
    payload = frozen_graph_def.SerializeToString()
    with gfile.FastGFile(pb_dest, 'wb') as pb_file:
        pb_file.write(payload)
    print('pb saved')
def generatePB(pb_dest = "model.pb"):
    """Freeze the module-level session's graph at node 'output' into a .pb file.

    Args:
        pb_dest: destination path of the serialized frozen GraphDef.
    """
    live_graph_def = sess.graph.as_graph_def()
    frozen_graph_def = graph_util.convert_variables_to_constants(
        sess, live_graph_def, ['output'])
    payload = frozen_graph_def.SerializeToString()
    with gfile.FastGFile(pb_dest, 'wb') as pb_file:
        pb_file.write(payload)
    print('pb saved')
def generatePB(pb_dest = "cifar_bnn_new.pb"):
    """Freeze the module-level session's graph at node 'output' into a .pb file.

    Args:
        pb_dest: destination path of the serialized frozen GraphDef.
    """
    live_graph_def = sess.graph.as_graph_def()
    frozen_graph_def = graph_util.convert_variables_to_constants(
        sess, live_graph_def, ['output'])
    payload = frozen_graph_def.SerializeToString()
    with gfile.FastGFile(pb_dest, 'wb') as pb_file:
        pb_file.write(payload)
    print('pb saved')
def freeze_graph(model_folder):
    """Freeze the checkpoint in ``model_folder`` into ``frozen_model.pb``.

    The frozen graph keeps the sub-graph needed for the
    "Accuracy/predictions" node and is written next to the checkpoint.

    Args:
        model_folder: directory containing the checkpoint state.
    """
    # Locate the checkpoint and derive the output path from its directory.
    ckpt_state = tf.train.get_checkpoint_state(model_folder)
    ckpt_path = ckpt_state.model_checkpoint_path
    path_parts = ckpt_path.split('/')
    pb_path = "/".join(path_parts[:-1]) + "/frozen_model.pb"

    # Comma-separated so that several output nodes can be listed.
    output_node_names = "Accuracy/predictions"
    wanted_nodes = output_node_names.split(",")

    # Import the meta graph with device placements cleared.
    saver = tf.train.import_meta_graph(ckpt_path + '.meta',
                                       clear_devices=True)
    source_graph_def = tf.get_default_graph().as_graph_def()

    with tf.Session() as sess:
        # Restore the weights, then turn every needed variable into a constant.
        saver.restore(sess, ckpt_path)
        frozen_graph_def = graph_util.convert_variables_to_constants(
            sess, source_graph_def, wanted_nodes)

        with tf.gfile.GFile(pb_path, "wb") as pb_file:
            pb_file.write(frozen_graph_def.SerializeToString())
        print("%d ops in the final graph." % len(frozen_graph_def.node))
def run():
    """Freeze variable 'y1', re-import it, and add the restored value to it.

    Prints the graph nodes before and after ``restore_graph`` is called, then
    prints the value of the variable after adding 'restore/y1:0' to it.
    """
    log.info('Run freeze restore')
    var_y = tf.Variable([float(88.8), float(5)], name='y1')
    init_op = tf.global_variables_initializer()

    with tf.Session() as sess:
        sess.run(init_op)
        graph = sess.graph

        # Fold the variable into a constant and serialize the frozen graph.
        frozen_def = graph_util.convert_variables_to_constants(
            sess, graph.as_graph_def(), ['y1'])
        serialized = frozen_def.SerializeToString()

        print_nodes(graph.as_graph_def(), 'before restore:')
        restore_graph(serialized)
        print_nodes(graph.as_graph_def(), 'after restore:')

        restored = graph.get_tensor_by_name('restore/y1:0')
        sess.run(var_y.assign(var_y + restored))
        print(sess.run(var_y))
def _set_variable_and_publish(self, sess, iteration_id, transaction_id, group_id):
    """Store the frozen values of ``self.variables`` and announce them.

    The session graph is frozen at the variables' op names, the serialized
    result is stored via ``self.rc`` under ``transaction_id``, and a JSON
    'set_variable' message is published via ``self.r``.

    Returns:
        Length of the serialized frozen graph.
    """
    variable_names = [variable.op.name for variable in self.variables]
    frozen_def = graph_util.convert_variables_to_constants(
        sess, sess.graph.as_graph_def(), variable_names)
    payload = frozen_def.SerializeToString()

    parallel_count = self.infra_info['parallel_count']
    self.rc.set(transaction_id, payload)

    announcement = {
        'key': 'set_variable',
        'transaction_id': transaction_id,
        'group_id': group_id,
        'variables': variable_names,
        'worker_id': self.worker_id,
        'train_id': self.train_id,
        'iteration_id': iteration_id,
        'parallel_count': parallel_count
    }
    # ``channel`` is not defined in this method; it comes from outer scope.
    self.r.publish(channel=channel, message=json.dumps(announcement))
    log.debug('pub %s' % iteration_id)
    return len(payload)
def _set_variable_and_publish(self, sess, iteration_id, variables, transaction_id, group_id):
    """Store the frozen values of ``variables`` and announce them.

    The session graph is frozen at the variables' op names, the serialized
    result is stored via ``self.rc`` under ``transaction_id``, and a JSON
    'set_variable' message is published via ``self.r``.

    Returns:
        Length of the serialized frozen graph.
    """
    variable_names = [variable.op.name for variable in variables]
    frozen_def = graph_util.convert_variables_to_constants(
        sess, sess.graph.as_graph_def(), variable_names)
    payload = frozen_def.SerializeToString()

    parallel_count = self.infra_info['parallel_count']
    self.rc.set(transaction_id, payload)

    announcement = {
        'key': 'set_variable',
        'transaction_id': transaction_id,
        'group_id': group_id,
        'variables': variable_names,
        'worker_id': self.worker_id,
        'train_id': self.train_id,
        'iteration_id': iteration_id,
        'parallel_count': parallel_count
    }
    # ``channel`` is not defined in this method; it comes from outer scope.
    self.r.publish(channel=channel, message=json.dumps(announcement))
    self._log('pub %s' % iteration_id)
    return len(payload)
def freeze_graph(model_folder, net_name):
    """Freeze the checkpoint in ``model_folder`` into ``<net_name>.pb``.

    Every node name of the restored graph is printed, then the graph is
    frozen at the nodes "Placeholder", "Placeholder_1", "Placeholder_2" and
    "out/add" and written next to the checkpoint.

    Args:
        model_folder: directory containing the checkpoint state.
        net_name: base name (without extension) of the output .pb file.
    """
    # Retrieve the full checkpoint path.
    checkpoint = tf.train.get_checkpoint_state(model_folder)
    input_checkpoint = checkpoint.model_checkpoint_path

    # The frozen graph is written next to the checkpoint.
    absolute_model_folder = "/".join(input_checkpoint.split('/')[:-1])
    output_graph = absolute_model_folder + "/%s.pb" % net_name

    # The output nodes determine which part of the graph is kept.
    output_node_names = "Placeholder,Placeholder_1,Placeholder_2,out/add"

    # Import the meta graph with device placements cleared.
    saver = tf.train.import_meta_graph(input_checkpoint + '.meta',
                                       clear_devices=True)
    input_graph_def = tf.get_default_graph().as_graph_def()

    with tf.Session() as sess:
        saver.restore(sess, input_checkpoint)

        for node in sess.graph.as_graph_def().node:
            # Fix: function-call form prints the same on Python 2 and is not
            # a syntax error on Python 3 (the original was ``print node.name``).
            print(node.name)

        output_graph_def = graph_util.convert_variables_to_constants(
            sess,                          # supplies the weights
            input_graph_def,               # supplies the nodes
            output_node_names.split(",")   # selects the nodes to keep
        )

        with tf.gfile.GFile(output_graph, "wb") as f:
            f.write(output_graph_def.SerializeToString())
        print("%d ops in the final graph." % len(output_graph_def.node))
def export_graph(input_path, output_path, output_nodes, debug=False):
    """Freeze the latest checkpoint in ``input_path`` and write it as binary.

    Batch-norm related nodes are patched before freezing
    (https://github.com/tensorflow/tensorflow/issues/3628).

    Args:
        input_path: directory searched for the latest checkpoint.
        output_path: destination file of the frozen graph.
        output_nodes: list of output node names to keep.
        debug: when true, print the operation names and a progress message.
    """
    # todo: might want to look at http://stackoverflow.com/a/39578062/195651
    checkpoint = tf.train.latest_checkpoint(input_path)
    importer = tf.train.import_meta_graph(checkpoint + '.meta',
                                          clear_devices=True)

    graph = tf.get_default_graph()  # type: tf.Graph
    gd = graph.as_graph_def()  # type: tf.GraphDef

    if debug:
        print([op.name for op in graph.get_operations()])

    # Assign-style ops and the plain arithmetic op each one is rewritten to.
    plain_ops = {'AssignSub': 'Sub', 'AssignAdd': 'Add'}
    for node in gd.node:
        if node.op == 'RefSwitch':
            node.op = 'Switch'
            for index in range(len(node.input)):
                if 'moving_' in node.input[index]:
                    node.input[index] = node.input[index] + '/read'
        elif node.op in plain_ops:
            node.op = plain_ops[node.op]
            if 'use_locking' in node.attr:
                del node.attr['use_locking']

    if debug:
        print('Freezing the graph ...')

    with tf.Session() as sess:
        importer.restore(sess, checkpoint)
        frozen_def = graph_util.convert_variables_to_constants(
            sess, gd, output_nodes)
        target_dir = path.dirname(output_path)
        target_name = path.basename(output_path)
        tf.train.write_graph(frozen_def, target_dir, target_name,
                             as_text=False)
def freeze_graph(model_folder):
    """Freeze the checkpoint in ``model_folder`` into ``frozen_model.pb``.

    The frozen graph keeps the sub-graph needed for the
    'generate_output/output' node and is written next to the checkpoint.

    Args:
        model_folder: directory containing the checkpoint state.
    """
    # Locate the checkpoint and derive the output path from its directory.
    ckpt_state = tf.train.get_checkpoint_state(model_folder)
    ckpt_path = ckpt_state.model_checkpoint_path
    path_parts = ckpt_path.split('/')
    pb_path = '/'.join(path_parts[:-1]) + '/frozen_model.pb'

    # Comma-separated so that several output nodes can be listed; these
    # nodes determine which part of the graph is kept.
    output_node_names = 'generate_output/output'
    wanted_nodes = output_node_names.split(",")

    # Import the meta graph with device placements cleared.
    saver = tf.train.import_meta_graph(ckpt_path + '.meta',
                                       clear_devices=True)
    source_graph_def = tf.get_default_graph().as_graph_def()

    with tf.Session() as sess:
        # Restore the weights, then turn every needed variable into a constant.
        saver.restore(sess, ckpt_path)
        frozen_graph_def = graph_util.convert_variables_to_constants(
            sess, source_graph_def, wanted_nodes)

        with tf.gfile.GFile(pb_path, 'wb') as pb_file:
            pb_file.write(frozen_graph_def.SerializeToString())
        print('%d ops in the final graph.' % len(frozen_graph_def.node))
def main(_):
    """Freeze the latest checkpoint in ``FLAGS.checkpoint_dir`` to ``model.pb``.

    The graph is frozen at the "output_prob" node after patching batch-norm
    related nodes, and written to ``FLAGS.model_dir``.

    Raises:
        ValueError: if no checkpoint is found in ``FLAGS.checkpoint_dir``.
    """
    output_node_names = "output_prob"

    session_config = tf.ConfigProto()
    session_config.gpu_options.per_process_gpu_memory_fraction = FLAGS.gpu_fraction

    with tf.Session(config=session_config) as sess:
        ckpt = tf.train.latest_checkpoint(FLAGS.checkpoint_dir)
        # Fix: check before use. The original evaluated ``ckpt + '.meta'``
        # ahead of its ``if ckpt:`` guard, so a missing checkpoint surfaced
        # as an opaque TypeError.
        if not ckpt:
            raise ValueError(
                'No checkpoint found in %s' % FLAGS.checkpoint_dir)
        saver = tf.train.import_meta_graph(ckpt + '.meta')
        saver.restore(sess, ckpt)

        # Retrieve the graph definition and fix the batch norm nodes.
        # Ref 1 Solution: https://github.com/davidsandberg/facenet/issues/161
        # Ref 2 Official Issue: https://github.com/tensorflow/tensorflow/issues/3628
        gd = sess.graph.as_graph_def()
        for node in gd.node:
            if node.op == 'RefSwitch':
                node.op = 'Switch'
                for index in range(len(node.input)):
                    if 'moving_' in node.input[index]:
                        node.input[index] = node.input[index] + '/read'
            elif node.op == 'AssignSub':
                node.op = 'Sub'
                if 'use_locking' in node.attr:
                    del node.attr['use_locking']
            elif node.op == 'AssignAdd':
                node.op = 'Add'
                if 'use_locking' in node.attr:
                    del node.attr['use_locking']

        output_graph_def = graph_util.convert_variables_to_constants(
            sess,                          # supplies the weights
            gd,                            # supplies the nodes
            output_node_names.split(",")   # selects the nodes to keep
        )

        with tf.gfile.GFile(os.path.join(FLAGS.model_dir, 'model.pb'), "wb") as f:
            f.write(output_graph_def.SerializeToString())
        print("%d ops in the final graph." % len(output_graph_def.node))
def convertGraph( modelPath, outdir, numoutputs, prefix, name):
    '''
    Converts an HDF5 Keras model file to a .pb file for use with Tensorflow.

    Besides the binary frozen graph, a human-readable copy of the unfrozen
    graph is written to ``graph_def_for_reference.pb.ascii`` in ``outdir``.

    Args:
        modelPath (str): path to the .h5 file
        outdir (str): path to the output directory (created if missing)
        numoutputs (int): number of model outputs to alias and export
        prefix (str): the prefix of the output aliasing; output ``i`` is
            named ``<prefix>_<i>``
        name (str): file name of the binary frozen graph inside ``outdir``

    Returns:
        None
    '''
    # makedirs also creates missing parent directories, which mkdir does not.
    if not os.path.isdir(outdir):
        os.makedirs(outdir)

    K.set_learning_phase(0)
    net_model = load_model(modelPath)

    # Alias the outputs in the model - this sometimes makes them easier to
    # access in TF.
    # Fix: index the ``outputs`` list. ``output`` is a bare tensor for
    # single-output models, so ``output[0]`` sliced the tensor instead of
    # selecting the first output.
    pred_node_names = [prefix + '_' + str(i) for i in range(numoutputs)]
    for i, node_name in enumerate(pred_node_names):
        tf.identity(net_model.outputs[i], name=node_name)
    print('Output nodes names are: ', pred_node_names)

    sess = K.get_session()

    # Write the graph in human readable form.
    f = 'graph_def_for_reference.pb.ascii'
    tf.train.write_graph(sess.graph.as_graph_def(), outdir, f, as_text=True)
    print('Saved the graph definition in ascii format at: ', osp.join(outdir, f))

    # Write the graph in binary .pb file.
    from tensorflow.python.framework import graph_util
    from tensorflow.python.framework import graph_io
    constant_graph = graph_util.convert_variables_to_constants(
        sess, sess.graph.as_graph_def(), pred_node_names)
    graph_io.write_graph(constant_graph, outdir, name, as_text=False)
    print('Saved the constant graph (ready for inference) at: ', osp.join(outdir, name))
def train_network(graph, batch_size, num_epochs, pb_file_path):
    """Train ``graph`` and periodically export it as a frozen .pb file.

    Each epoch runs ``graph['optimize']`` on the first 12 training samples,
    one sample per step. Every second epoch the training samples and the
    first 3 validation samples are evaluated, a summary line is printed and
    the graph is frozen at node "output" into ``pb_file_path``.

    Args:
        graph: dict with the entries 'x', 'y', 'optimize', 'cost' and
            'correct_times_in_batch'.
        batch_size: factor used when aggregating cost and sample counts.
        num_epochs: number of epochs to run.
        pb_file_path: destination of the frozen graph.
    """
    def feed_for(image, label):
        # One 224x224x3 image per step; label 0 maps to [1, 0], else [0, 1].
        return {
            graph['x']: np.reshape(image, (1, 224, 224, 3)),
            graph['y']: ([[1, 0]] if label == 0 else [[0, 1]])
        }

    def evaluate(sess, images, labels, sample_count):
        # Returns (batches, correct predictions, accumulated cost).
        batches, correct, cost = 0, 0, 0.
        for idx in range(sample_count):
            correct_in_batch = sess.run(
                graph['correct_times_in_batch'],
                feed_dict=feed_for(images[idx], labels[idx]))
            mean_cost = sess.run(
                graph['cost'],
                feed_dict=feed_for(images[idx], labels[idx]))
            batches += 1
            correct += correct_in_batch
            cost += (mean_cost * batch_size)
        return batches, correct, cost

    init = tf.global_variables_initializer()
    with tf.Session() as sess:
        sess.run(init)
        epoch_delta = 2
        for epoch_index in range(num_epochs):
            for i in range(12):
                sess.run([graph['optimize']],
                         feed_dict=feed_for(x_train[i], y_train[i]))

            if epoch_index % epoch_delta != 0:
                continue

            (total_batches_in_train_set,
             total_correct_times_in_train_set,
             total_cost_in_train_set) = evaluate(sess, x_train, y_train, 12)
            (total_batches_in_test_set,
             total_correct_times_in_test_set,
             total_cost_in_test_set) = evaluate(sess, x_val, y_val, 3)

            acy_on_test = total_correct_times_in_test_set / float(total_batches_in_test_set * batch_size)
            acy_on_train = total_correct_times_in_train_set / float(total_batches_in_train_set * batch_size)
            print('Epoch - {:2d}, acy_on_test:{:6.2f}%({}/{}),loss_on_test:{:6.2f}, acy_on_train:{:6.2f}%({}/{}),loss_on_train:{:6.2f}'.format(
                epoch_index,
                acy_on_test * 100.0,
                total_correct_times_in_test_set,
                total_batches_in_test_set * batch_size,
                total_cost_in_test_set,
                acy_on_train * 100.0,
                total_correct_times_in_train_set,
                total_batches_in_train_set * batch_size,
                total_cost_in_train_set))

            # NOTE(review): the collapsed source does not show whether this
            # export ran on every evaluated epoch or only once after training;
            # it is kept directly after the report line - confirm.
            constant_graph = graph_util.convert_variables_to_constants(
                sess, sess.graph_def, ["output"])
            with tf.gfile.FastGFile(pb_file_path, mode='wb') as f:
                f.write(constant_graph.SerializeToString())
def freeze_graph(input_graph, input_saver, input_binary, input_checkpoint,
                 output_node_names, restore_op_name, filename_tensor_name,
                 output_graph, clear_devices, initializer_nodes):
    """Converts all variables in a graph and checkpoint into constants.

    Args:
        input_graph: path of the GraphDef file to freeze.
        input_saver: optional path of a SaverDef file used for restoring.
        input_binary: true when the input files are binary protobufs,
            false when they are text protobufs.
        input_checkpoint: checkpoint path (may be a prefix for Saver V2).
        output_node_names: comma-separated names of the output nodes.
        restore_op_name: restore op run when no ``input_saver`` is given.
        filename_tensor_name: tensor fed with the checkpoint path when no
            ``input_saver`` is given.
        output_graph: destination path of the frozen GraphDef.
        clear_devices: when true, strip device placements from all nodes.
        initializer_nodes: optional nodes to run after restoring.

    Returns:
        -1 when an input is missing or invalid, otherwise ``None``.
    """
    if not tf.gfile.Exists(input_graph):
        print("Input graph file '" + input_graph + "' does not exist!")
        return -1

    if input_saver and not tf.gfile.Exists(input_saver):
        print("Input saver file '" + input_saver + "' does not exist!")
        return -1

    # 'input_checkpoint' may be a prefix if we're using Saver V2 format
    if not tf.train.checkpoint_exists(input_checkpoint):
        print("Input checkpoint '" + input_checkpoint + "' doesn't exist!")
        return -1

    if not output_node_names:
        print("You need to supply the name of a node to --output_node_names.")
        return -1

    mode = "rb" if input_binary else "r"
    input_graph_def = tf.GraphDef()
    with tf.gfile.FastGFile(input_graph, mode) as graph_file:
        graph_bytes = graph_file.read()
        if input_binary:
            input_graph_def.ParseFromString(graph_bytes)
        else:
            text_format.Merge(graph_bytes.decode("utf-8"), input_graph_def)

    # Dropping explicit device specifications makes the graph more portable.
    if clear_devices:
        for graph_node in input_graph_def.node:
            graph_node.device = ""

    tf.import_graph_def(input_graph_def, name="")

    with tf.Session() as sess:
        if input_saver:
            with tf.gfile.FastGFile(input_saver, mode) as saver_file:
                saver_def = tf.train.SaverDef()
                saver_bytes = saver_file.read()
                if input_binary:
                    saver_def.ParseFromString(saver_bytes)
                else:
                    text_format.Merge(saver_bytes, saver_def)
                saver = tf.train.Saver(saver_def=saver_def)
                saver.restore(sess, input_checkpoint)
        else:
            sess.run([restore_op_name],
                     {filename_tensor_name: input_checkpoint})
            if initializer_nodes:
                sess.run(initializer_nodes)

        wanted_nodes = output_node_names.split(",")
        output_graph_def = graph_util.convert_variables_to_constants(
            sess, input_graph_def, wanted_nodes)

    with tf.gfile.GFile(output_graph, "wb") as out_file:
        out_file.write(output_graph_def.SerializeToString())
    print("%d ops in the final graph." % len(output_graph_def.node))
def _calculate_average_and_put(self, group_id, item, m):
    """Average each variable over all stored graphs and store the result.

    For every name in ``item['variables']`` the graphs stored under
    ``item['keys']`` are fetched via ``self.rc`` and restored, the matching
    tensors are averaged into a new variable ``average_<name>``, and finally
    the frozen graph of all averages is stored under ``group_id``.

    Args:
        group_id: key under which the serialized averages are stored.
        item: dict with 'keys' (storage keys of the source graphs) and
            'variables' (variable names to average).
        m: passed through to every ``SimpleMeasurement``.
    """
    keys = item['keys']
    tf.reset_default_graph()
    sess = tf.Session()
    # Fix: close the session even if a step below raises; the original only
    # closed it on the success path.
    try:
        new_vars = []
        m_cal_and_put = SimpleMeasurement('cal_and_put', m)

        m_init = SimpleMeasurement('init', m)
        init_op = tf.global_variables_initializer()
        sess.run(init_op)
        m_init.end_measure()

        for v in item['variables']:
            name = 'average_%s' % v
            ts = []
            for key in keys:
                raw = self.rc.get(key)
                # TODO: check raw is not None
                util.restore_graph(key, raw)
                ts.append(sess.graph.get_tensor_by_name('%s/%s:0' % (key, v)))

            m_cal = SimpleMeasurement('cal', m)
            avg = tf.foldl(tf.add, ts) / len(ts)
            new_var = tf.Variable(avg, name=name)
            sess.run(new_var.initializer)
            sess.run(new_var)
            new_vars.append(name)
            m_cal.end_measure()

        # Freeze only the averaged variables and store the serialized graph.
        constants = graph_util.convert_variables_to_constants(
            sess, sess.graph.as_graph_def(), new_vars)
        self.rc.set(group_id, constants.SerializeToString())
    finally:
        sess.close()
    m_cal_and_put.end_measure()
def freeze_graph(model_folder):
    """Freeze the checkpoint in ``model_folder`` into ``frozen_model.pb``.

    The frozen graph keeps the sub-graph needed for the
    "softmaxLayer/Softmax" node and is written next to the checkpoint.

    Args:
        model_folder: directory containing the checkpoint state.
    """
    # Locate the checkpoint and derive the output path from its directory.
    ckpt_state = tf.train.get_checkpoint_state(model_folder)
    ckpt_path = ckpt_state.model_checkpoint_path
    path_parts = ckpt_path.split('/')
    pb_path = "/".join(path_parts[:-1]) + "/frozen_model.pb"

    # Comma-separated so that several output nodes can be listed.
    output_node_names = "softmaxLayer/Softmax"
    wanted_nodes = output_node_names.split(",")

    # Import the meta graph with device placements cleared.
    saver = tf.train.import_meta_graph(ckpt_path + '.meta',
                                       clear_devices=True)
    source_graph_def = tf.get_default_graph().as_graph_def()

    with tf.Session() as sess:
        # Restore the weights, then turn every needed variable into a constant.
        saver.restore(sess, ckpt_path)
        frozen_graph_def = graph_util.convert_variables_to_constants(
            sess, source_graph_def, wanted_nodes)

        with tf.gfile.GFile(pb_path, "wb") as pb_file:
            pb_file.write(frozen_graph_def.SerializeToString())
        print("%d ops in the final graph." % len(frozen_graph_def.node))
def freeze_graph(input_graph, input_saver, input_binary, input_checkpoint,
                 output_node_names, restore_op_name, filename_tensor_name,
                 output_graph, clear_devices, initializer_nodes, verbose=True):
    """Converts all variables in a graph and checkpoint into constants.

    Args:
        input_graph: path of the GraphDef file to freeze.
        input_saver: optional path of a SaverDef file used for restoring.
        input_binary: true when the input files are binary protobufs,
            false when they are text protobufs.
        input_checkpoint: checkpoint path or glob pattern.
        output_node_names: comma-separated names of the output nodes.
        restore_op_name: restore op run when no ``input_saver`` is given.
        filename_tensor_name: tensor fed with the checkpoint path when no
            ``input_saver`` is given.
        output_graph: destination path of the frozen GraphDef.
        clear_devices: when true, strip device placements from all nodes.
        initializer_nodes: optional nodes to run after restoring.
        verbose: when equal to ``True``, print the final op count.

    Returns:
        -1 when an input is missing or invalid, otherwise ``None``.
    """
    if not tf.gfile.Exists(input_graph):
        print("Input graph file '" + input_graph + "' does not exist!")
        return -1

    if input_saver and not tf.gfile.Exists(input_saver):
        print("Input saver file '" + input_saver + "' does not exist!")
        return -1

    if not tf.gfile.Glob(input_checkpoint):
        print("Input checkpoint '" + input_checkpoint + "' doesn't exist!")
        return -1

    if not output_node_names:
        print("You need to supply the name of a node to --output_node_names.")
        return -1

    mode = "rb" if input_binary else "r"
    input_graph_def = tf.GraphDef()
    with tf.gfile.FastGFile(input_graph, mode) as graph_file:
        graph_bytes = graph_file.read()
        if input_binary:
            input_graph_def.ParseFromString(graph_bytes)
        else:
            text_format.Merge(graph_bytes, input_graph_def)

    # Dropping explicit device specifications makes the graph more portable.
    if clear_devices:
        for graph_node in input_graph_def.node:
            graph_node.device = ""

    tf.import_graph_def(input_graph_def, name="")

    with tf.Session() as sess:
        if input_saver:
            with tf.gfile.FastGFile(input_saver, mode) as saver_file:
                saver_def = tf.train.SaverDef()
                saver_bytes = saver_file.read()
                if input_binary:
                    saver_def.ParseFromString(saver_bytes)
                else:
                    text_format.Merge(saver_bytes, saver_def)
                saver = tf.train.Saver(saver_def=saver_def)
                saver.restore(sess, input_checkpoint)
        else:
            sess.run([restore_op_name],
                     {filename_tensor_name: input_checkpoint})
            if initializer_nodes:
                sess.run(initializer_nodes)

        wanted_nodes = output_node_names.split(",")
        output_graph_def = graph_util.convert_variables_to_constants(
            sess, input_graph_def, wanted_nodes)

    with tf.gfile.GFile(output_graph, "wb") as out_file:
        out_file.write(output_graph_def.SerializeToString())
    # Equality comparison kept so that only values equal to True enable
    # the message, exactly as before.
    if verbose == True:
        print("%d ops in the final graph." % len(output_graph_def.node))
def freeze_graph(input_graph, input_saver, input_binary, input_checkpoint,
                 output_node_names, restore_op_name, filename_tensor_name,
                 output_graph, clear_devices, initializer_nodes):
    """Converts all variables in a graph and checkpoint into constants.

    Args:
        input_graph: path of the GraphDef file to freeze.
        input_saver: optional path of a SaverDef file used for restoring.
        input_binary: true when the input files are binary protobufs,
            false when they are text protobufs.
        input_checkpoint: checkpoint path or glob pattern.
        output_node_names: comma-separated names of the output nodes.
        restore_op_name: restore op run when no ``input_saver`` is given.
        filename_tensor_name: tensor fed with the checkpoint path when no
            ``input_saver`` is given.
        output_graph: destination path of the frozen GraphDef.
        clear_devices: when true, strip device placements from all nodes.
        initializer_nodes: optional nodes to run after restoring.

    Returns:
        -1 when an input is missing or invalid, otherwise ``None``.
    """
    if not tf.gfile.Exists(input_graph):
        print("Input graph file '" + input_graph + "' does not exist!")
        return -1

    if input_saver and not tf.gfile.Exists(input_saver):
        print("Input saver file '" + input_saver + "' does not exist!")
        return -1

    if not tf.gfile.Glob(input_checkpoint):
        print("Input checkpoint '" + input_checkpoint + "' doesn't exist!")
        return -1

    if not output_node_names:
        print("You need to supply the name of a node to --output_node_names.")
        return -1

    mode = "rb" if input_binary else "r"
    input_graph_def = tf.GraphDef()
    with tf.gfile.FastGFile(input_graph, mode) as graph_file:
        graph_bytes = graph_file.read()
        if input_binary:
            input_graph_def.ParseFromString(graph_bytes)
        else:
            text_format.Merge(graph_bytes, input_graph_def)

    # Dropping explicit device specifications makes the graph more portable.
    if clear_devices:
        for graph_node in input_graph_def.node:
            graph_node.device = ""

    tf.import_graph_def(input_graph_def, name="")

    with tf.Session() as sess:
        if input_saver:
            with tf.gfile.FastGFile(input_saver, mode) as saver_file:
                saver_def = tf.train.SaverDef()
                saver_bytes = saver_file.read()
                if input_binary:
                    saver_def.ParseFromString(saver_bytes)
                else:
                    text_format.Merge(saver_bytes, saver_def)
                saver = tf.train.Saver(saver_def=saver_def)
                saver.restore(sess, input_checkpoint)
        else:
            sess.run([restore_op_name],
                     {filename_tensor_name: input_checkpoint})
            if initializer_nodes:
                sess.run(initializer_nodes)

        wanted_nodes = output_node_names.split(",")
        output_graph_def = graph_util.convert_variables_to_constants(
            sess, input_graph_def, wanted_nodes)

    with tf.gfile.GFile(output_graph, "wb") as out_file:
        out_file.write(output_graph_def.SerializeToString())
    print("%d ops in the final graph." % len(output_graph_def.node))
def freeze_graph(model_folder):
    """Freeze the checkpoint in ``model_folder`` into ``frozen_model.pb``.

    The frozen graph keeps the sub-graph needed for the
    "29_fully_connected" node and is written next to the checkpoint.

    Args:
        model_folder: directory containing the checkpoint state.
    """
    # Locate the checkpoint and derive the output path from its directory.
    ckpt_state = tf.train.get_checkpoint_state(model_folder)
    ckpt_path = ckpt_state.model_checkpoint_path
    path_parts = ckpt_path.split('/')
    pb_path = "/".join(path_parts[:-1]) + "/frozen_model.pb"

    # Comma-separated so that several output nodes can be listed; these
    # nodes determine which part of the graph is kept.
    output_node_names = "29_fully_connected"
    wanted_nodes = output_node_names.split(",")

    # Import the meta graph with device placements cleared.
    saver = tf.train.import_meta_graph(ckpt_path + '.meta',
                                       clear_devices=True)
    source_graph_def = tf.get_default_graph().as_graph_def()

    with tf.Session() as sess:
        # Restore the weights, then turn every needed variable into a constant.
        saver.restore(sess, ckpt_path)
        frozen_graph_def = graph_util.convert_variables_to_constants(
            sess, source_graph_def, wanted_nodes)

        with tf.gfile.GFile(pb_path, "wb") as pb_file:
            pb_file.write(frozen_graph_def.SerializeToString())
        print("%d ops in the final graph." % len(frozen_graph_def.node))