Python utils 模块,get_model_dir() 实例源码

我们从 Python 开源项目中提取了以下 2 个代码示例，用于说明如何使用 utils.get_model_dir()。

项目:gated-pixel-cnn    作者:jakebelew    | 项目源码 | 文件源码
def main(_):
  """Prepare directories and dataset, build the network, then train or generate.

  The single positional argument is ignored. All settings are read from the
  module-level `conf` object.
  """
  # Config keys handed to util.get_model_dir along with conf.
  model_dir_keys = [
      'data_dir', 'sample_dir', 'max_epoch', 'test_step', 'save_step',
      'is_train', 'random_seed', 'log_level', 'display', 'runtime_base_dir',
      'occlude_start_row', 'num_generated_images',
  ]
  model_dir = util.get_model_dir(conf, model_dir_keys)
  util.preprocess_conf(conf)
  validate_parameters(conf)

  # 'color-mnist' reads its data from the plain 'mnist' data directory.
  if conf.data == 'color-mnist':
    dataset_name = 'mnist'
  else:
    dataset_name = conf.data

  base_dir = conf.runtime_base_dir
  data_path = os.path.join(base_dir, conf.data_dir, dataset_name)
  # The sample directory keeps the unmapped conf.data name (not the alias).
  sample_path = os.path.join(base_dir, conf.sample_dir, conf.data, model_dir)

  for path in (data_path, sample_path):
    util.check_and_create_dir(path)

  dataset = get_dataset(data_path, conf.q_levels)

  with tf.Session() as sess:
    network = Network(
        sess, conf, dataset.height, dataset.width, dataset.channels)

    stat = Statistic(
        sess, conf.data, conf.runtime_base_dir, model_dir,
        tf.trainable_variables())
    stat.load_model()

    if not conf.is_train:
      generate(network, dataset.height, dataset.width, sample_path)
    else:
      train(dataset, network, stat, sample_path)
项目:NAF-tensorflow    作者:carpedm20    | 项目源码 | 文件源码
def main(_):
  """Build the environment, exploration strategy, networks and agent, then run.

  The single positional argument is ignored. All settings are read from the
  module-level `conf` object.

  Raises:
    ValueError: if conf.noise is not 'ou', 'brownian' or 'linear_decay'.
    AssertionError: if the observation or action space is not a gym Box.
  """
  model_dir_keys = ['is_train', 'random_seed', 'monitor', 'display', 'log_level']
  model_dir = get_model_dir(conf, model_dir_keys)
  preprocess_conf(conf)

  with tf.Session() as sess:
    # Environment setup and seeding.
    env = gym.make(conf.env_name)
    env._seed(conf.random_seed)

    obs_space = env.observation_space
    assert isinstance(obs_space, gym.spaces.Box), \
      "observation space must be continuous"
    act_space = env.action_space
    assert isinstance(act_space, gym.spaces.Box), \
      "action space must be continuous"

    # Exploration strategy: factories are lazy so only the selected
    # strategy is ever constructed.
    strategy_factories = {
      'ou': lambda: OUExploration(env, sigma=conf.noise_scale),
      'brownian': lambda: BrownianExploration(env, conf.noise_scale),
      'linear_decay': lambda: LinearDecayExploration(env),
    }
    make_strategy = strategy_factories.get(conf.noise)
    if make_strategy is None:
      raise ValueError('Unkown exploration strategy: %s' % conf.noise)
    strategy = make_strategy()

    # Keyword arguments common to the prediction and target networks.
    shared_args = dict(
      sess=sess,
      input_shape=env.observation_space.shape,
      action_size=env.action_space.shape[0],
      hidden_dims=conf.hidden_dims,
      use_batch_norm=conf.use_batch_norm,
      use_seperate_networks=conf.use_seperate_networks,
      hidden_w=conf.hidden_w,
      action_w=conf.action_w,
      hidden_fn=conf.hidden_fn,
      action_fn=conf.action_fn,
      w_reg=conf.w_reg,
    )

    logger.info("Creating prediction network...")
    pred_network = Network(scope='pred_network', **shared_args)

    logger.info("Creating target network...")
    target_network = Network(scope='target_network', **shared_args)
    target_network.make_soft_update_from(pred_network, conf.tau)

    # Statistic is built over the prediction network's variables.
    stat = Statistic(
      sess, conf.env_name, model_dir,
      pred_network.variables, conf.update_repeat)

    agent = NAF(
      sess, env, strategy, pred_network, target_network, stat,
      conf.discount, conf.batch_size, conf.learning_rate,
      conf.max_steps, conf.update_repeat, conf.max_episodes)

    agent.run(conf.monitor, conf.display, conf.is_train)