def __init__(self, input_args, dest_nodes=None):
    """Load a TensorFlow model and build the parser's internal graph.

    Args:
        input_args: Either a path string handed to
            ``TensorflowParser._load_meta``, or a tuple whose first element
            is that path and whose second element is handed to
            ``TensorflowParser._load_weights`` (the result is stored in
            ``self.ckpt_data`` and ``self.weight_loaded`` is set to True).
        dest_nodes: Optional output node names, either as a comma-separated
            string or as a list/tuple of names. When given, the loaded graph
            is pruned with ``extract_sub_graph`` to these nodes and the nodes
            they depend on. Surrounding whitespace and empty entries are
            ignored; if no names remain, no pruning is done.

    Raises:
        TypeError: If ``input_args`` is neither a string nor a tuple.
    """
    super(TensorflowParser, self).__init__()

    from six import string_types as _string_types

    # Load the model graph (it is later fed to extract_sub_graph, so it is
    # presumably a GraphDef) and, for the tuple form, the checkpoint data.
    if isinstance(input_args, _string_types):
        model = TensorflowParser._load_meta(input_args)
    elif isinstance(input_args, tuple):
        model = TensorflowParser._load_meta(input_args[0])
        self.ckpt_data = TensorflowParser._load_weights(input_args[1])
        self.weight_loaded = True
    else:
        # Without this branch `model` would be unbound and the failure would
        # surface later as an UnboundLocalError.
        raise TypeError(
            "input_args must be a model path string or a "
            "(meta_path, weights_path) tuple, got %r" % (type(input_args),))

    if dest_nodes is not None:
        from tensorflow.python.framework.graph_util import extract_sub_graph
        if isinstance(dest_nodes, _string_types):
            dest_nodes = dest_nodes.split(',')
        # Tolerate "a, b" and trailing commas; TF node names never contain
        # whitespace, so stripping cannot change a valid name.
        names = [name.strip() for name in dest_nodes if name.strip()]
        if names:
            model = extract_sub_graph(model, names)

    # Build network graph
    self.tf_graph = TensorflowGraph(model)
    self.tf_graph.build()
def save_graph_only(sess, output_file_path, output_node_names, as_text=False):
    """Save a small version of the graph based on a session and the output node names.

    Device assignments are cleared from every node and the graph is pruned
    with ``graph_util.extract_sub_graph`` to ``output_node_names`` and the
    nodes they depend on, then written with ``graph_io.write_graph``.

    Args:
        sess: Session whose graph is saved.
        output_file_path: Destination file path; it is split into a directory
            and a file name for ``graph_io.write_graph``.
        output_node_names: List of output node names to keep.
        as_text: Write a text proto instead of a binary one.
    """
    # `Session.graph_def` returns a newly built GraphDef on every access, so
    # take ONE snapshot and edit it. Clearing devices on `sess.graph_def`
    # directly would be lost before extract_sub_graph reads the graph.
    graph_def = sess.graph_def
    for node in graph_def.node:
        node.device = ''
    graph_def = graph_util.extract_sub_graph(graph_def, output_node_names)
    output_dir, output_filename = os.path.split(output_file_path)
    # A bare file name has an empty directory part; use the current
    # directory explicitly instead of passing '' as the log dir.
    graph_io.write_graph(graph_def, output_dir or '.', output_filename,
                         as_text=as_text)
def remove_dead_nodes(self, output_names):
    """Prune ``self.output_graph`` down to what ``output_names`` need.

    ``self.output_graph`` is replaced by the sub-graph that
    ``graph_util.extract_sub_graph`` returns for ``output_names``, dropping
    nodes that are no longer needed for inference.
    """
    pruned = graph_util.extract_sub_graph(self.output_graph, output_names)
    self.output_graph = pruned
def strip_unused(input_graph, input_binary, output_graph, input_node_names,
                 output_node_names, placeholder_type_enum):
    """Removes unused nodes from a graph.

    Reads a GraphDef from ``input_graph``, replaces every node named in
    ``input_node_names`` with a Placeholder, prunes everything the nodes in
    ``output_node_names`` do not depend on, and writes the result as a binary
    GraphDef to ``output_graph``.

    Args:
        input_graph: Path of the GraphDef file to read.
        input_binary: True if ``input_graph`` is a binary proto, False if it
            is a text proto.
        output_graph: Path the stripped binary GraphDef is written to.
        input_node_names: Comma-separated names of nodes to replace with
            placeholders.
        output_node_names: Comma-separated names of the output nodes to keep.
        placeholder_type_enum: DataType enum value used for every placeholder,
            or a list/tuple with one enum value per name in
            ``input_node_names`` (in the same order).

    Returns:
        -1 if the input file does not exist or no output node names were
        given; otherwise None.
    """
    if not tf.gfile.Exists(input_graph):
        print("Input graph file '" + input_graph + "' does not exist!")
        return -1

    # Split comma-separated names, tolerating spaces and trailing commas
    # (TF node names never contain whitespace or are empty).
    output_node_names_list = [name.strip() for name in (output_node_names or "").split(",")
                              if name.strip()]
    if not output_node_names_list:
        print("You need to supply the name of a node to --output_node_names.")
        return -1
    input_node_names_list = [name.strip() for name in input_node_names.split(",")
                             if name.strip()]

    input_graph_def = tf.GraphDef()
    mode = "rb" if input_binary else "r"
    with tf.gfile.FastGFile(input_graph, mode) as f:
        if input_binary:
            input_graph_def.ParseFromString(f.read())
        else:
            text_format.Merge(f.read(), input_graph_def)

    # Map each input name to its placeholder dtype (single value or per-input list).
    per_input_types = isinstance(placeholder_type_enum, (list, tuple))
    dtype_for_input = {}
    for index, name in enumerate(input_node_names_list):
        dtype_for_input[name] = (placeholder_type_enum[index] if per_input_types
                                 else placeholder_type_enum)

    # Here we replace the nodes we're going to override as inputs with
    # placeholders so that any unused nodes that are inputs to them are
    # automatically stripped out by extract_sub_graph().
    found_inputs = set()
    inputs_replaced_graph_def = tf.GraphDef()
    for node in input_graph_def.node:
        if node.name in dtype_for_input:
            found_inputs.add(node.name)
            placeholder_node = inputs_replaced_graph_def.node.add()
            placeholder_node.op = "Placeholder"
            placeholder_node.name = node.name
            placeholder_node.attr["dtype"].CopyFrom(
                tf.AttrValue(type=dtype_for_input[node.name]))
        else:
            # add() + CopyFrom copies the node once (deepcopy + extend copied twice).
            inputs_replaced_graph_def.node.add().CopyFrom(node)

    missing_inputs = [name for name in input_node_names_list
                      if name not in found_inputs]
    if missing_inputs:
        # Not fatal (preserves previous behaviour), but a misspelled input
        # name silently leaves the upstream nodes in place, so report it.
        print("Warning: input nodes not found in graph: %s"
              % ", ".join(missing_inputs))

    output_graph_def = graph_util.extract_sub_graph(inputs_replaced_graph_def,
                                                    output_node_names_list)

    with tf.gfile.GFile(output_graph, "wb") as f:
        f.write(output_graph_def.SerializeToString())

    print("%d ops in the final graph." % len(output_graph_def.node))