我们从 Python 开源项目中提取了以下 5 个代码示例，用于说明如何使用 tensorflow.contrib.slim.get_variables()。
def save_model_for_prediction(self, save_ckpt_fn, vars_to_save=None):
    """Save model data only needed for prediction.

    Variables that fall under any scope listed in
    ``self.dm_model.restore_scope_exclude`` are removed from the saved set
    before ``base_model.save_model`` writes the checkpoint using
    ``self.sess``.

    Args:
      save_ckpt_fn: checkpoint file to save.
      vars_to_save: a list of variables to save. When None, defaults to
        ``slim.get_model_variables()``.
    """
    if vars_to_save is None:
        vars_to_save = slim.get_model_variables()
    # Track excluded variables by identity. A list membership test
    # (`v not in some_list`) is O(len(vars_to_save) * len(excluded)) and
    # goes through Variable.__eq__, which compares element-wise (instead of
    # by identity) when TensorFlow's tensor equality is enabled.
    excluded_ids = {
        id(var)
        for scope in self.dm_model.restore_scope_exclude
        for var in slim.get_variables(scope)
    }
    # Drop variables that belong to the restore-exclusion scopes.
    vars_to_save = [v for v in vars_to_save if id(v) not in excluded_ids]
    base_model.save_model(save_ckpt_fn, self.sess, vars_to_save)
def load_pretrained_model(self):
    """Copy matching checkpoint weights into the graph's trainable variables.

    Opens the checkpoint at ``self.pre_trained_model_cpkt``. Each variable
    in the TRAINABLE_VARIABLES collection is overwritten in ``self.sess``
    with the checkpoint tensor of the same name, unless the name is listed
    in ``self.skip_layer`` or is absent from the checkpoint. Skipped
    variables keep their current values.

    NOTE: despite the progress message, the variables are taken from the
    trainable collection, and their trainability is not changed here.
    """
    print('Load the pretrained weights into the non-trainable layer...')
    trainable_variables = slim.get_variables(
        None, None, tf.GraphKeys.TRAINABLE_VARIABLES)
    reader = pywrap_tensorflow.NewCheckpointReader(self.pre_trained_model_cpkt)
    pretrained_model_variables = reader.get_variable_to_shape_map()
    for variable in trainable_variables:
        # Variable.name carries an output-index suffix (e.g. ':0') that the
        # checkpoint keys do not have.
        variable_name = variable.name.split(':')[0]
        if variable_name in self.skip_layer:
            continue
        if variable_name not in pretrained_model_variables:
            continue
        print('load ' + variable_name)
        data = reader.get_tensor(variable_name)
        # Write directly into the variable we already hold. Variable.load
        # feeds the value through the existing initializer instead of
        # creating a new assign op that embeds the weights as a graph
        # constant, so repeated loads do not grow the graph.
        variable.load(data, self.sess)
def load_model_from_checkpoint_fn(self, model_fn):
    """Restore variables from a checkpoint file into ``self.sess``.

    On the first call, ``self.vars_to_restore`` is filled with every
    variable returned by ``slim.get_variables()``. Later calls reuse that
    cached list.

    Args:
      model_fn: saved model file.
    """
    # print(...) with a single argument behaves the same on Python 2 and 3;
    # the former print statements were a syntax error on Python 3.
    print("start loading from checkpoint file...")
    if self.vars_to_restore is None:
        self.vars_to_restore = slim.get_variables()
    restore_fn = slim.assign_from_checkpoint_fn(model_fn, self.vars_to_restore)
    print("restoring model from {}".format(model_fn))
    restore_fn(self.sess)
    print("model restored.")
def load_ckpt(self, sess, ckpt='ckpts/vgg_16.ckpt', scope='vgg_16'):
    """Initialize the variables under ``scope`` from a checkpoint.

    Args:
      sess: session in which the assignment op is run.
      ckpt: path of the checkpoint to read.
      scope: variable scope whose variables are loaded. It is matched at a
        scope boundary, so a sibling scope sharing the prefix (e.g.
        'vgg_16_1') is not included.
    """
    # The scope filter is applied as a regex prefix match on variable names.
    # Without a trailing '/', 'vgg_16' would also select 'vgg_16_1/...',
    # and those names are then looked up in the checkpoint.
    scope_prefix = scope if scope.endswith('/') else scope + '/'
    variables = slim.get_variables(scope=scope_prefix)
    init_assign_op, init_feed_dict = slim.assign_from_checkpoint(ckpt, variables)
    sess.run(init_assign_op, init_feed_dict)
def build_network(self):
    """Build the network graph under the ``self.name`` scope prefix.

    Creates:
      - a float32 input placeholder of shape [None, 84, 84, 4];
      - two ReLU conv layers and a 256-unit ReLU fully connected layer;
      - a softmax head with ``self.nb_actions`` outputs;
      - a single-unit linear head.

    Each layer also gets a scalar summary of its variables' global norm.
    The input placeholder and both head outputs are stored on ``self``.
    """
    prefix = self.name
    inputs = tf.placeholder(tf.float32, [None, 84, 84, 4])

    conv_a = slim.conv2d(inputs, 16, [8, 8], stride=4,
                         scope=prefix + '/cnn_1', activation_fn=nn.relu)
    conv_b = slim.conv2d(conv_a, 32, [4, 4], stride=2,
                         scope=prefix + '/cnn_2', activation_fn=nn.relu)
    hidden = slim.fully_connected(slim.flatten(conv_b), 256,
                                  scope=prefix + '/fcc_1',
                                  activation_fn=nn.relu)

    policy = slim.fully_connected(hidden, self.nb_actions,
                                  scope=prefix + '/adv_probas',
                                  activation_fn=nn.softmax)
    value = slim.fully_connected(hidden, 1,
                                 scope=prefix + '/value_state',
                                 activation_fn=None)

    # (summary tag, layer scope suffix): one global-norm summary per layer,
    # emitted in the same order as the layers above.
    norm_summaries = (
        ("model/cnn1_global_norm", '/cnn_1'),
        ("model/cnn2_global_norm", '/cnn_2'),
        ("model/fcc1_global_norm", '/fcc_1'),
        ("model/adv_probas_global_norm", '/adv_probas'),
        ("model/value_state_global_norm", '/value_state'),
    )
    for tag, suffix in norm_summaries:
        layer_vars = slim.get_variables(scope=prefix + suffix)
        tf.summary.scalar(tag, tf.global_norm(layer_vars))

    # Input
    self._tf_state = inputs
    # Outputs
    self._tf_adv_probas = policy
    self._tf_value_state = value