The following 3 code examples, extracted from open-source Python projects, illustrate how to use model.ActorCritic().
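All three examples construct the model as ActorCritic(num_inputs, action_space) and call it as model((state, (hx, cx))), getting back value, logit, (hx, cx). The sketch below is a minimal guess at what such a model looks like, modeled on common pytorch-a3c implementations; the layer sizes, the lstm_size attribute, and the 42x42 input resolution are assumptions, not code from the extracted projects.

import torch.nn as nn
import torch.nn.functional as F


class ActorCritic(nn.Module):
    """Sketch of the interface the test() functions below rely on."""

    def __init__(self, num_inputs, action_space):
        super(ActorCritic, self).__init__()
        # four stride-2 convolutions take a num_inputs x 42 x 42 frame down to 32 x 3 x 3
        self.conv1 = nn.Conv2d(num_inputs, 32, 3, stride=2, padding=1)
        self.conv2 = nn.Conv2d(32, 32, 3, stride=2, padding=1)
        self.conv3 = nn.Conv2d(32, 32, 3, stride=2, padding=1)
        self.conv4 = nn.Conv2d(32, 32, 3, stride=2, padding=1)
        self.lstm_size = 256                                   # read by example 2 below
        self.lstm = nn.LSTMCell(32 * 3 * 3, self.lstm_size)
        self.critic_linear = nn.Linear(self.lstm_size, 1)      # value head
        self.actor_linear = nn.Linear(self.lstm_size, action_space.n)  # policy logits

    def forward(self, inputs):
        # inputs is a tuple: (batched observation, previous LSTM state)
        x, (hx, cx) = inputs
        x = F.elu(self.conv1(x))
        x = F.elu(self.conv2(x))
        x = F.elu(self.conv3(x))
        x = F.elu(self.conv4(x))
        hx, cx = self.lstm(x.view(x.size(0), -1), (hx, cx))
        return self.critic_linear(hx), self.actor_linear(hx), (hx, cx)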
def test(rank, args, shared_model):
    torch.manual_seed(args.seed + rank)

    env = gym.make(args.env_name)
    env.seed(args.seed + rank)

    model = ActorCritic(1, env.action_space)
    model.eval()

    state = env.reset()
    state = E.process_frame42(state)
    state = torch.from_numpy(state)
    reward_sum = 0
    done = True

    start_time = time.time()

    # a quick hack to prevent the agent from getting stuck
    actions = deque(maxlen=100)
    episode_length = 0
    while True:
        episode_length += 1
        # Sync with the shared model
        if done:
            model.load_state_dict(shared_model.state_dict())
            cx = Variable(torch.zeros(1, 256), volatile=True)
            hx = Variable(torch.zeros(1, 256), volatile=True)
        else:
            cx = Variable(cx.data, volatile=True)
            hx = Variable(hx.data, volatile=True)

        value, logit, (hx, cx) = model(
            (Variable(state.unsqueeze(0), volatile=True), (hx, cx)))
        prob = F.softmax(logit)
        action = prob.max(1)[1].data.numpy()

        state, reward, done, _ = env.step(action[0, 0])
        state = E.process_frame42(state)
        done = done or episode_length >= args.max_episode_length
        reward_sum += reward

        # a quick hack to prevent the agent from getting stuck
        actions.append(action[0, 0])
        if actions.count(actions[0]) == actions.maxlen:
            done = True

        if done:
            print("Time {}, episode reward {}, episode length {}".format(
                time.strftime("%Hh %Mm %Ss",
                              time.gmtime(time.time() - start_time)),
                reward_sum, episode_length))
            reward_sum = 0
            episode_length = 0
            actions.clear()
            state = env.reset()
            state = E.process_frame42(state)
            time.sleep(60)

        state = torch.from_numpy(state)
def test(args, model, env):
    torch.manual_seed(args.seed)

    # env = create_atari_env(args.env_name)
    # env = create_car_racing_env()
    env.seed(args.seed)

    model = ActorCritic(env.observation_space.shape[0], env.action_space)
    model.eval()

    state = env.reset()
    state = torch.from_numpy(state)
    reward_sum = 0
    done = True

    start_time = time.time()

    # a quick hack to prevent the agent from getting stuck
    actions = deque(maxlen=100)
    episode_length = 0
    while True:
        # env.render()
        episode_length += 1
        # Sync with the shared model
        if done:
            # model.load_state_dict(shared_model.state_dict())
            cx = Variable(torch.zeros(1, model.lstm_size), volatile=True)
            hx = Variable(torch.zeros(1, model.lstm_size), volatile=True)
        else:
            cx = Variable(cx.data, volatile=True)
            hx = Variable(hx.data, volatile=True)

        value, logit, (hx, cx) = model(
            (Variable(state.unsqueeze(0), volatile=True), (hx, cx)))
        prob = F.softmax(logit)
        action = prob.max(1)[1].data.numpy()

        state, reward, done, _ = env.step(action[0, 0])
        done = done or episode_length >= args.max_episode_length
        reward_sum += reward

        # a quick hack to prevent the agent from getting stuck
        actions.append(action[0, 0])
        if actions.count(actions[0]) == actions.maxlen:
            done = True

        if done:
            print("Time {}, episode reward {}, episode length {}".format(
                time.strftime("%Hh %Mm %Ss",
                              time.gmtime(time.time() - start_time)),
                reward_sum, episode_length))
            reward_sum = 0
            episode_length = 0
            actions.clear()
            state = env.reset()
            return
            # time.sleep(60)

        state = torch.from_numpy(state)
def test(rank, args, shared_model, counter):
    torch.manual_seed(args.seed + rank)

    env = create_atari_env(args.env_name)
    env.seed(args.seed + rank)

    model = ActorCritic(env.observation_space.shape[0], env.action_space)
    model.eval()

    state = env.reset()
    state = torch.from_numpy(state)
    reward_sum = 0
    done = True

    start_time = time.time()

    # a quick hack to prevent the agent from getting stuck
    actions = deque(maxlen=100)
    episode_length = 0
    while True:
        episode_length += 1
        # Sync with the shared model
        if done:
            model.load_state_dict(shared_model.state_dict())
            cx = Variable(torch.zeros(1, 256), volatile=True)
            hx = Variable(torch.zeros(1, 256), volatile=True)
        else:
            cx = Variable(cx.data, volatile=True)
            hx = Variable(hx.data, volatile=True)

        value, logit, (hx, cx) = model((Variable(
            state.unsqueeze(0), volatile=True), (hx, cx)))
        prob = F.softmax(logit)
        action = prob.max(1, keepdim=True)[1].data.numpy()

        state, reward, done, _ = env.step(action[0, 0])
        done = done or episode_length >= args.max_episode_length
        reward_sum += reward

        # a quick hack to prevent the agent from getting stuck
        actions.append(action[0, 0])
        if actions.count(actions[0]) == actions.maxlen:
            done = True

        if done:
            print("Time {}, num steps {}, FPS {:.0f}, episode reward {}, episode length {}".format(
                time.strftime("%Hh %Mm %Ss",
                              time.gmtime(time.time() - start_time)),
                counter.value, counter.value / (time.time() - start_time),
                reward_sum, episode_length))
            reward_sum = 0
            episode_length = 0
            actions.clear()
            state = env.reset()
            time.sleep(60)

        state = torch.from_numpy(state)
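Examples 1 and 3 expect to run in their own process alongside the A3C training workers, reloading a shared_model kept in shared memory whenever an episode ends. A hypothetical launcher might look like the sketch below; the argument names and the create_atari_env helper are taken from the examples above, while the default values and process layout are assumptions.

import argparse

import torch.multiprocessing as mp

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--env-name', default='PongDeterministic-v4')
    parser.add_argument('--seed', type=int, default=1)
    parser.add_argument('--max-episode-length', type=int, default=10000)
    args = parser.parse_args()

    env = create_atari_env(args.env_name)     # preprocessing helper used in example 3
    shared_model = ActorCritic(env.observation_space.shape[0], env.action_space)
    shared_model.share_memory()               # make the parameters visible to child processes
    counter = mp.Value('i', 0)                # global step counter incremented by the trainers

    # rank 0 is used for the evaluator here; training workers would get ranks 1..N
    p = mp.Process(target=test, args=(0, args, shared_model, counter))
    p.start()
    p.join()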