我们从 Python 开源项目中提取了以下 2 个代码示例,用于说明如何使用 utils.get_model_dir()。
def main(_):
    """Prepare directories and data, then train or generate with the network.

    Steps, in order: derive the model directory name from ``conf``,
    preprocess and validate ``conf``, create the data and sample
    directories, load the dataset, build the network inside a TensorFlow
    session, load the model via ``Statistic``, and finally either train or
    generate depending on ``conf.is_train``.

    The single positional argument is ignored.
    """
    # NOTE(review): these look like conf keys that should not influence the
    # model directory name -- confirm against util.get_model_dir.
    model_dir_keys = [
        'data_dir', 'sample_dir', 'max_epoch', 'test_step', 'save_step',
        'is_train', 'random_seed', 'log_level', 'display',
        'runtime_base_dir', 'occlude_start_row', 'num_generated_images',
    ]
    model_dir = util.get_model_dir(conf, model_dir_keys)
    util.preprocess_conf(conf)
    validate_parameters(conf)

    # 'color-mnist' reads its input from the 'mnist' data directory, while
    # samples are still written under the unmodified conf.data name.
    if conf.data == 'color-mnist':
        data_name = 'mnist'
    else:
        data_name = conf.data

    data_path = os.path.join(conf.runtime_base_dir, conf.data_dir, data_name)
    sample_path = os.path.join(
        conf.runtime_base_dir, conf.sample_dir, conf.data, model_dir)

    for directory in (data_path, sample_path):
        util.check_and_create_dir(directory)

    dataset = get_dataset(data_path, conf.q_levels)

    with tf.Session() as sess:
        network = Network(
            sess, conf, dataset.height, dataset.width, dataset.channels)

        stat = Statistic(
            sess,
            conf.data,
            conf.runtime_base_dir,
            model_dir,
            tf.trainable_variables(),
        )
        stat.load_model()

        if not conf.is_train:
            generate(network, dataset.height, dataset.width, sample_path)
        else:
            train(dataset, network, stat, sample_path)
def main(_):
    """Build the environment, networks and NAF agent from ``conf`` and run it.

    Steps, in order: derive the model directory name from ``conf``,
    preprocess ``conf``, create and seed the gym environment, pick the
    exploration strategy named by ``conf.noise``, build the prediction and
    target networks with shared arguments, wire the target network's soft
    update, and run the agent.

    The single positional argument is ignored.

    Raises:
        ValueError: if the environment's observation or action space is not
            a ``gym.spaces.Box``, or if ``conf.noise`` is not one of
            ``'ou'``, ``'brownian'`` or ``'linear_decay'``.
    """
    # NOTE(review): these look like conf keys that should not influence the
    # model directory name -- confirm against get_model_dir.
    model_dir = get_model_dir(
        conf, ['is_train', 'random_seed', 'monitor', 'display', 'log_level'])
    preprocess_conf(conf)

    with tf.Session() as sess:
        # environment
        env = gym.make(conf.env_name)
        # Seed through the public API instead of the private ``_seed`` hook.
        env.seed(conf.random_seed)

        # Explicit raises (not ``assert``) so the checks survive ``python -O``.
        if not isinstance(env.observation_space, gym.spaces.Box):
            raise ValueError("observation space must be continuous")
        if not isinstance(env.action_space, gym.spaces.Box):
            raise ValueError("action space must be continuous")

        # exploration strategy
        if conf.noise == 'ou':
            strategy = OUExploration(env, sigma=conf.noise_scale)
        elif conf.noise == 'brownian':
            strategy = BrownianExploration(env, conf.noise_scale)
        elif conf.noise == 'linear_decay':
            strategy = LinearDecayExploration(env)
        else:
            raise ValueError('Unknown exploration strategy: %s' % conf.noise)

        # networks: both are built from the same arguments and differ only
        # in their scope name.
        shared_args = {
            'sess': sess,
            'input_shape': env.observation_space.shape,
            'action_size': env.action_space.shape[0],
            'hidden_dims': conf.hidden_dims,
            'use_batch_norm': conf.use_batch_norm,
            # Key spelling must match the Network keyword argument.
            'use_seperate_networks': conf.use_seperate_networks,
            'hidden_w': conf.hidden_w,
            'action_w': conf.action_w,
            'hidden_fn': conf.hidden_fn,
            'action_fn': conf.action_fn,
            'w_reg': conf.w_reg,
        }

        logger.info("Creating prediction network...")
        pred_network = Network(scope='pred_network', **shared_args)

        logger.info("Creating target network...")
        target_network = Network(scope='target_network', **shared_args)
        target_network.make_soft_update_from(pred_network, conf.tau)

        # statistic
        stat = Statistic(
            sess,
            conf.env_name,
            model_dir,
            pred_network.variables,
            conf.update_repeat,
        )

        agent = NAF(
            sess, env, strategy, pred_network, target_network, stat,
            conf.discount, conf.batch_size, conf.learning_rate,
            conf.max_steps, conf.update_repeat, conf.max_episodes,
        )
        agent.run(conf.monitor, conf.display, conf.is_train)