我们从Python开源项目中,提取了以下34个代码示例,用于说明如何使用tensorflow.NodeDef()。
def tf_obj_shape(input):
    """
    Return the shape of a TensorFlow object as a plain tuple of ints.

    Arguments:
        input: tf.TensorShape, tf.Tensor, tf.AttrValue or tf.NodeDef
            the object whose shape is wanted

    Returns:
        tuple: one int per dimension of the object

    Raises:
        TypeError: if `input` is none of the supported types
    """
    # A tensor is reduced to its TensorShape and a node to its 'shape'
    # attribute; the recursive call then lands in the matching branch below.
    if isinstance(input, tf.Tensor):
        return tf_obj_shape(input.get_shape())
    if isinstance(input, tf.NodeDef):
        return tf_obj_shape(input.attr['shape'])

    if isinstance(input, tf.TensorShape):
        return tuple(int(dim.value) for dim in input)
    if isinstance(input, tf.AttrValue):
        return tuple(int(dim.size) for dim in input.shape.dim)

    raise TypeError("Input to `tf_obj_shape` has the wrong type.")
def assign_to_gpu(gpu=0, ps_dev="/device:CPU:0"):
    """Build a device function that keeps variables on `ps_dev` and puts
    every other op on one GPU.

    Arguments:
        gpu: index formatted into the "/gpu:%d" device string returned for
            non-variable ops.
        ps_dev: device string returned for variable ops.

    Returns:
        A function that takes an op (either a tf.NodeDef, or an object
        exposing a `node_def` attribute) and returns a device string.
    """
    # Variables can show up under either op type name; checking only
    # "Variable" would send "VariableV2" nodes to the GPU instead of ps_dev.
    variable_ops = ("Variable", "VariableV2")

    def _assign(op):
        node_def = op if isinstance(op, tf.NodeDef) else op.node_def
        if node_def.op in variable_ops:
            return ps_dev
        return "/gpu:%d" % gpu

    return _assign
def variables_device(self):
    """Returns the device to use for variables created inside the clone.

    When `self._num_ps_tasks` is falsy the result is the plain device string
    '/device:CPU:0'. Otherwise the result is a callable that spreads variable
    ops over the parameter-server tasks in round-robin order.

    Returns:
      A value suitable for `tf.device()`.
    """
    device = ''
    if self._num_ps_tasks > 0:
        device += self._ps_device
    device += '/device:CPU:0'

    class _PSDeviceChooser(object):
        """Slim device chooser for variables when using PS."""

        def __init__(self, device, tasks):
            self._device = device
            self._tasks = tasks
            self._task = 0  # index of the next task to hand out

        def choose(self, op):
            # An op that already carries a device keeps it.
            if op.device:
                return op.device
            node_def = op if isinstance(op, tf.NodeDef) else op.node_def
            # Prefix match so that 'VariableV2' ops are treated as variables
            # too; an exact comparison with 'Variable' would skip them.
            if node_def.op.startswith('Variable'):
                t = self._task
                self._task = (self._task + 1) % self._tasks
                return '%s/task:%d' % (self._device, t)
            return op.device

    if not self._num_ps_tasks:
        return device
    chooser = _PSDeviceChooser(device, self._num_ps_tasks)
    return chooser.choose
def variable_device(device, name):
    """Resolve the device for a variable so its ops can be colocated.

    If `device` is callable it is invoked with a NodeDef of op type
    'Variable', named after the current variable scope plus `name`, and its
    result is used. A `None` device (given or returned) becomes ''.
    """
    if callable(device):
        scope_name = tf.get_variable_scope().name
        probe = tf.NodeDef(name=scope_name + '/' + name, op='Variable')
        device = device(probe)
    return '' if device is None else device
def choose(self, op):
    """Pick a device for `op`.

    An op that already has a device keeps it. Otherwise ops of type
    "Variable" or "VariableV2" get `self._var_device` and all other ops get
    `self._clone_device`. `op` may be a tf.NodeDef or an object exposing
    `node_def`.
    """
    preset = op.device
    if preset:
        return preset

    node_def = op.node_def if not isinstance(op, tf.NodeDef) else op
    is_variable = node_def.op in ["Variable", "VariableV2"]
    return self._var_device if is_variable else self._clone_device
def variables_device(self):
    """Returns the device to use for variables created inside the clone.

    The base device is '/device:CPU:0', prefixed with `self._ps_device` when
    there is at least one parameter-server task. With no PS tasks that string
    is returned directly; otherwise a chooser callable is returned.

    Returns:
      A value suitable for `tf.device()`.
    """
    ps_prefix = self._ps_device if self._num_ps_tasks > 0 else ''
    device = ps_prefix + '/device:CPU:0'

    class _PSDeviceChooser(object):
        """Slim device chooser for variables when using PS."""

        def __init__(self, device, tasks):
            self._device, self._tasks, self._task = device, tasks, 0

        def choose(self, op):
            # Only ops without a device are candidates for a PS task.
            if not op.device:
                node_def = op if isinstance(op, tf.NodeDef) else op.node_def
                if node_def.op.startswith('Variable'):
                    current = self._task
                    # Advance the round-robin counter for the next variable.
                    self._task = (current + 1) % self._tasks
                    return '%s/task:%d' % (self._device, current)
            return op.device

    if self._num_ps_tasks:
        return _PSDeviceChooser(device, self._num_ps_tasks).choose
    return device
def strip_unused(input_graph, input_binary, output_graph, input_node_names,
                 output_node_names, placeholder_type_enum):
    """Removes unused nodes from a graph.

    Arguments:
        input_graph: path of the GraphDef file to read.
        input_binary: truthy if `input_graph` is a binary proto, falsy if it
            is in text format.
        output_graph: path the stripped GraphDef is written to, serialized
            in binary form.
        input_node_names: comma-separated names of the nodes to replace with
            placeholders.
        output_node_names: comma-separated names of the nodes the extracted
            sub-graph must keep.
        placeholder_type_enum: dtype enum value stored in the "dtype"
            attribute of every placeholder that is created.

    Returns:
        -1 if the input file does not exist or no output node names were
        given; otherwise None.
    """
    if not tf.gfile.Exists(input_graph):
        print("Input graph file '" + input_graph + "' does not exist!")
        return -1

    if not output_node_names:
        print("You need to supply the name of a node to --output_node_names.")
        return -1

    input_graph_def = tf.GraphDef()
    mode = "rb" if input_binary else "r"
    # GFile is used for reading as well as writing; FastGFile is a
    # deprecated alias of it.
    with tf.gfile.GFile(input_graph, mode) as f:
        if input_binary:
            input_graph_def.ParseFromString(f.read())
        else:
            text_format.Merge(f.read(), input_graph_def)

    # Here we replace the nodes we're going to override as inputs with
    # placeholders so that any unused nodes that are inputs to them are
    # automatically stripped out by extract_sub_graph().
    # The names are put in a set once so the per-node membership test does
    # not rescan a list for every node in the graph.
    input_node_name_set = set(input_node_names.split(","))
    inputs_replaced_graph_def = tf.GraphDef()
    for node in input_graph_def.node:
        if node.name in input_node_name_set:
            placeholder_node = tf.NodeDef()
            placeholder_node.op = "Placeholder"
            placeholder_node.name = node.name
            placeholder_node.attr["dtype"].CopyFrom(
                tf.AttrValue(type=placeholder_type_enum))
            inputs_replaced_graph_def.node.extend([placeholder_node])
        else:
            inputs_replaced_graph_def.node.extend([copy.deepcopy(node)])

    output_graph_def = graph_util.extract_sub_graph(
        inputs_replaced_graph_def, output_node_names.split(","))

    with tf.gfile.GFile(output_graph, "wb") as f:
        f.write(output_graph_def.SerializeToString())
    print("%d ops in the final graph." % len(output_graph_def.node))