We extracted the following 9 code examples from open-source Python projects to illustrate how to use `tf.train.Saver()`.
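Before the project examples, here is a minimal self-contained sketch of the basic save/restore cycle (TF 1.x graph-mode API; the variable name and checkpoint path are illustrative):

```python
import tensorflow as tf

# One variable is enough to demonstrate checkpointing.
w = tf.get_variable('w', shape=[2, 2], initializer=tf.zeros_initializer())
saver = tf.train.Saver()  # snapshots tf.global_variables() at construction time

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    ckpt_path = saver.save(sess, '/tmp/model.ckpt')  # returns the checkpoint path

with tf.Session() as sess:
    saver.restore(sess, ckpt_path)  # no initializer needed; values come from disk
    print(sess.run(w))
```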
Example 1

````python
def _build_tf_graph(self):
    """Build the TF graph, setup model saving and setup a TF session.

    Notes
    -----
    This method initializes a TF Saver and a TF Session via

    ```python
    self._saver = tf.train.Saver()
    self._session = tf.Session()
    ```

    These calls are made after `self._set_up_graph()` is called. See the
    main class docs for how to properly call this method from a child
    class.
    """
    self._set_up_graph()
    self._saver = tf.train.Saver()
    self._session = tf.Session()
````
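The docstring above assumes a child class that defines its variables in `_set_up_graph()` before the Saver is created. A minimal sketch of that contract (the class and variable names here are hypothetical):

```python
import tensorflow as tf

class BaseModel(object):
    """Hypothetical base class following the pattern above."""
    def _build_tf_graph(self):
        self._set_up_graph()            # child adds its variables first...
        self._saver = tf.train.Saver()  # ...so the Saver can capture them
        self._session = tf.Session()

class TinyModel(BaseModel):
    def _set_up_graph(self):
        self.w = tf.get_variable('tiny_w', shape=[1])

model = TinyModel()
model._build_tf_graph()
model._session.run(tf.global_variables_initializer())
```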
Example 2

```python
def saver(self):
    if self._saver is None:
        self._saver = tf.train.Saver()
    return self._saver
```
Example 3

```python
def tf_saver(self):
    if not hasattr(self, '_tf_saver'):
        self._tf_saver = tf.train.Saver(
            *self.tfsaver_args, **self.tfsaver_kwargs)
    return self._tf_saver
```
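Both properties above construct the Saver lazily. This matters because `tf.train.Saver()` records the set of variables that exist at the moment it is built, so deferring construction until first access ensures every variable added to the graph beforehand is covered.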
Example 4

```python
def __init__(self, session, saver, args):
    """
    Create a model session. Do not call this constructor directly. To
    instantiate a ModelSession object, use the create and restore class
    methods.

    :param session: the session in which this model is running
    :type session: tf.Session
    :param saver: object used to serialize this session
    :type saver: tf.train.Saver
    """
    self.session, self.saver, self.args = session, saver, args
```
Example 5

```python
def create(cls, **kwargs):
    """
    Create a new model session.

    :param kwargs: optional graph parameters
    :type kwargs: dict
    :return: new model session
    :rtype: ModelSession
    """
    session = tf.Session()
    with session.graph.as_default():
        cls.create_graph(**kwargs)
    session.run(tf.initialize_all_variables())
    # Pass kwargs through as `args` so the call matches the
    # three-argument constructor above.
    return cls(session, tf.train.Saver(), kwargs)
```
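A hedged usage sketch of this create-then-save pattern (the graph kwarg and checkpoint path are made up for illustration):

```python
# Build a fresh model session, then serialize it with its Saver.
model = ModelSession.create(hidden_units=16)        # hypothetical graph kwarg
model.saver.save(model.session, '/tmp/model.ckpt')  # serialize the session
```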
Example 6

```python
def saver(self, **kwargs):
    """Returns a Saver for all (trainable and model) variables used by the
    model. Model variables include e.g. the moving mean and variance used
    by BatchNorm.

    :return: tf.train.Saver
    """
    return tf.train.Saver(self.vars, **kwargs)
```
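For reference, `tf.train.Saver` accepts either a list of variables or a `{checkpoint_name: variable}` dict as its `var_list` argument, which is what makes scoped savers like the one above possible. A minimal sketch:

```python
import tensorflow as tf

w = tf.get_variable('w', shape=[3])
b = tf.get_variable('b', shape=[3])

# Save only `w`, stored under a custom checkpoint name.
partial_saver = tf.train.Saver({'weights': w})

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    partial_saver.save(sess, '/tmp/partial.ckpt')  # checkpoint holds only 'weights'
```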
Example 7

```python
def initialize(self, no_scratch=False):
    """Fetch record, then use tf's saver.restore."""
    if self.do_restore:

        # First, determine which checkpoint to use.
        if self.from_ckpt is not None:
            # Use a cached checkpoint file.
            ckpt_filename = self.from_ckpt
            log.info('Restoring variables from checkpoint %s ...' % ckpt_filename)
        else:
            # Otherwise, use a database checkpoint.
            self.load_rec() if self.load_data is None else None
            if self.load_data is not None:
                rec, ckpt_filename = self.load_data
                log.info('Restoring variables from record %s (step %d)...'
                         % (str(rec['_id']), rec['step']))
            else:
                # No db checkpoint to load.
                ckpt_filename = None

        if ckpt_filename is not None:
            all_vars = tf.global_variables() + tf.local_variables()  # get list of all variables
            self.all_vars = strip_prefix(self.params['model_params']['prefix'], all_vars)

            # Next, determine which vars should be restored from the specified checkpoint.
            restore_vars = self.get_restore_vars(ckpt_filename, self.all_vars)
            restore_stripped = strip_prefix(self.params['model_params']['prefix'],
                                            list(restore_vars.values()))
            restore_names = [name for name, var in restore_stripped.items()]

            # Actually load the vars.
            log.info('Restored Vars:\n' + str(restore_names))
            tf_saver_restore = tf.train.Saver(restore_vars)
            tf_saver_restore.restore(self.sess, ckpt_filename)
            log.info('... done restoring.')

            # Reinitialize all other, unrestored vars.
            unrestored_vars = [var for name, var in self.all_vars.items()
                               if name not in restore_names]
            unrestored_var_names = [name for name, var in self.all_vars.items()
                                    if name not in restore_names]
            log.info('Unrestored Vars:\n' + str(unrestored_var_names))
            self.sess.run(tf.variables_initializer(unrestored_vars))  # initialize variables not restored
            assert len(self.sess.run(tf.report_uninitialized_variables())) == 0, (
                self.sess.run(tf.report_uninitialized_variables()))

    if not self.do_restore or (self.load_data is None and self.from_ckpt is None):
        init_op_global = tf.global_variables_initializer()
        self.sess.run(init_op_global)
        init_op_local = tf.local_variables_initializer()
        self.sess.run(init_op_local)
```
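Stripped of the project-specific helpers (`strip_prefix`, `load_rec`, and so on, which belong to this codebase), the core restore-then-initialize pattern looks roughly like this:

```python
import tensorflow as tf

w = tf.get_variable('w', shape=[2])
b = tf.get_variable('b', shape=[2])

restore_map = {'w': w}  # suppose the checkpoint only contains 'w'
restored = set(restore_map.values())

with tf.Session() as sess:
    tf.train.Saver(restore_map).restore(sess, '/tmp/model.ckpt')  # illustrative path
    # Initialize only what the Saver did not restore.
    unrestored = [v for v in tf.global_variables() if v not in restored]
    sess.run(tf.variables_initializer(unrestored))
    # Sanity check: nothing should remain uninitialized.
    assert len(sess.run(tf.report_uninitialized_variables())) == 0
```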
Example 8

```python
def get_restore_vars(self, save_file, all_vars=None):
    """Create the `var_list` init argument to tf.train.Saver from save_file.

    Extracts the subset of variables from tf.global_variables that match the
    name and shape of variables saved in the checkpoint file, and returns
    these as a dict of variables to restore.

    To support multi-model training, a model prefix is prepended to all tf
    global_variable names, although this prefix is stripped from all
    variables before they are saved to a checkpoint. Thus, the prefix must
    also be stripped from the current variable names before they are
    matched against the saved names.

    Args:
        save_file: path of tf.train.Saver checkpoint.

    Returns:
        dict: checkpoint variables.
    """
    reader = tf.train.NewCheckpointReader(save_file)
    var_shapes = reader.get_variable_to_shape_map()
    log.info('Saved Vars:\n' + str(var_shapes.keys()))

    var_shapes = {  # Strip the prefix off saved var names.
        strip_prefix_from_name(self.params['model_params']['prefix'], name): shape
        for name, shape in var_shapes.items()}

    # Map old vars from checkpoint to new vars via load_param_dict.
    mapped_var_shapes = self.remap_var_list(var_shapes)
    log.info('Saved shapes:\n' + str(mapped_var_shapes))

    if all_vars is None:
        all_vars = tf.global_variables() + tf.local_variables()  # get list of all variables
        all_vars = strip_prefix(self.params['model_params']['prefix'], all_vars)

    # Specify which vars are to be restored vs. reinitialized.
    if self.load_param_dict is None:
        restore_vars = {name: var for name, var in all_vars.items()
                        if name in mapped_var_shapes}
    else:
        # Associate checkpoint names with actual variables.
        load_var_dict = {}
        for ckpt_var_name, curr_var_name in self.load_param_dict.items():
            for curr_name, curr_var in all_vars.items():
                if curr_name == curr_var_name:
                    load_var_dict[ckpt_var_name] = curr_var
                    break
        restore_vars = load_var_dict

    restore_vars = self.filter_var_list(restore_vars)

    # Ensure the vars to be restored have the correct shape.
    var_list = {}
    for name, var in restore_vars.items():
        var_shape = var.get_shape().as_list()
        if var_shape == mapped_var_shapes[name]:
            var_list[name] = var
    return var_list
```
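`get_restore_vars` leans on `tf.train.NewCheckpointReader`, which is also useful on its own for debugging name or shape mismatches between a graph and a checkpoint:

```python
import tensorflow as tf

reader = tf.train.NewCheckpointReader('/tmp/model.ckpt')  # illustrative path
for name, shape in sorted(reader.get_variable_to_shape_map().items()):
    print(name, shape)
# A single saved tensor can also be read back without building a graph:
# value = reader.get_tensor('w')
```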
Example 9

```python
def test(sess,
         queues,
         dbinterface,
         validation_targets,
         save_intermediate_freq=None):
    """Actually run the testing evaluation loop.

    Args:
        sess (tensorflow.Session): Object in which to run calculations.
        queues (list of CustomQueue): Objects containing asynchronously
            queued data iterators.
        dbinterface (DBInterface object): Saver through which to save results.
        validation_targets (dict of tensorflow objects): Objects on which
            validation will be computed.
        save_intermediate_freq (None or int): How frequently to save
            intermediate results captured during test. None means no
            intermediate saving.

    Returns:
        dict: Validation summary.
        dict: Results.
    """
    # Collect args in a dict of lists.
    test_args = {
        'queues': queues,
        'dbinterface': dbinterface,
        'validation_targets': validation_targets,
        'save_intermediate_freq': save_intermediate_freq}

    _ttargs = [{key: value[i] for (key, value) in test_args.items()}
               for i in range(len(queues))]

    for ttarg in _ttargs:
        ttarg['coord'], ttarg['threads'] = start_queues(sess)
        ttarg['dbinterface'].start_time_step = time.time()
        validation_summary = run_targets_dict(
            sess,
            ttarg['validation_targets'],
            save_intermediate_freq=ttarg['save_intermediate_freq'],
            dbinterface=ttarg['dbinterface'],
            validation_only=True)

    res = []
    for ttarg in _ttargs:
        ttarg['dbinterface'].sync_with_host()
        res.append(ttarg['dbinterface'].outrecs)
        stop_queues(sess, ttarg['queues'], ttarg['coord'], ttarg['threads'])

    return validation_summary, res
```
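`start_queues` and `stop_queues` are helpers of this project; in plain TF 1.x the same queue lifecycle is managed with a `tf.train.Coordinator`, roughly as follows:

```python
import tensorflow as tf

# A toy input pipeline: yields the numbers 0..9 for exactly one epoch.
queue = tf.train.input_producer(tf.range(10), num_epochs=1)
item = queue.dequeue()

with tf.Session() as sess:
    sess.run(tf.local_variables_initializer())  # num_epochs uses a local variable
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    try:
        while True:
            print(sess.run(item))
    except tf.errors.OutOfRangeError:
        pass  # the epoch is exhausted
    finally:
        coord.request_stop()
        coord.join(threads)
```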