我们从 Python 开源项目中提取了以下 2 个代码示例，用于说明如何使用 torch.Size（PyTorch 中表示张量形状的类型）。
def _batch2torch(self, batch, batch_size):
    """Turn a list of transitions into PyTorch tensors for one training step.

    The list ``(t1, t2, ..., tn)`` is transposed (as in the PyTorch DQN
    tutorial) into a single ``BatchTransition`` whose fields hold tuples,
    ``t((s1, ..., sn), (a1, ..., an), ...)``, and each field is then
    converted into a tensor.

    Args:
        batch: Sequence of transitions, each unpackable into the fields of
            ``BatchTransition``.
        batch_size: Not used in the conversion; returned unchanged as the
            first element of the result.

    Returns:
        ``[batch_size, states, actions, rewards, next_states, mask]``.
        According to the original author's notes, ``states`` has size
        ``torch.Size([batch_size, hist_len, w, h])`` and the action, reward
        and mask tensors have size ``torch.Size([batch_size, 1])``.
    """
    columns = BatchTransition(*zip(*batch))
    types = self.dtype

    def frames_to_float(frames):
        # Concatenate along dim 0, cast to the float tensor type and divide
        # by 255 (presumably rescaling 8-bit pixel values to [0, 1] - confirm).
        return torch.cat(frames, 0).type(types.FT) / 255

    state_batch = frames_to_float(columns.state)
    # unsqueeze(1) turns a 1-D batch into an (n, 1) column
    # (assumes scalar actions/rewards per transition).
    action_batch = types.LT(columns.action).unsqueeze(1)
    reward_batch = types.FT(columns.reward).unsqueeze(1)
    next_state_batch = frames_to_float(columns.state_)

    # 1 - done flips the flags: [False, False, True, False] -> [1, 1, 0, 1]
    # (BT is a ByteTensor per the original note), so done transitions get 0.
    done_flags = types.BT(columns.done).unsqueeze(1)
    mask = 1 - done_flags

    return [
        batch_size,
        state_batch,
        action_batch,
        reward_batch,
        next_state_batch,
        mask,
    ]
def __init__(self, env, env_type, hist_len, state_dims, cuda=None):
    """Wrap ``env`` with frame preprocessing and a per-step frame history.

    Args:
        env: Environment to wrap; passed to the base wrapper's ``__init__``.
            The first two entries of ``observation_space.shape`` are stored
            as ``env_wh`` and the third as ``env_ch``.
        env_type: ``"atari"`` or ``"catch"``; selects which preprocessing
            method is bound to ``self._preprocess``. Any other value leaves
            ``self._preprocess`` unset by this constructor.
        hist_len: Number of history slots created in ``self.d``.
        state_dims: Per-frame dimensions; each history slot is a zero-filled
            tensor of shape ``(1, 1, *state_dims)``.
        cuda: Forwarded to ``TorchTypes``; ``None`` is treated as ``False``.
    """
    super(PreprocessFrames, self).__init__(env)
    self.env_type = env_type
    self.state_dims = state_dims
    self.hist_len = hist_len

    obs_shape = self.env.observation_space.shape
    self.env_wh = obs_shape[0:2]
    self.env_ch = obs_shape[2]
    self.wxh = self.env_wh[0] * self.env_wh[1]

    # TODO: need to find a better way to select the preprocessing routine.
    if self.env_type == "atari":
        self._preprocess = self._atari_preprocess
    elif self.env_type == "catch":
        self._preprocess = self._catch_preprocess

    print("[Preprocess Wrapper] for %s with state history of %d frames."
          % (self.env_type, hist_len))

    self.cuda = False if cuda is None else cuda
    self.dtype = dtype = TorchTypes(self.cuda)
    # .2126/.7152/.0722 are the ITU-R BT.709 luminance coefficients;
    # presumably used by the preprocess methods for RGB -> grayscale (confirm).
    self.rgb = dtype.FT([.2126, .7152, .0722])

    # One zero-filled (1, 1, *state_dims) slot per history step, keyed
    # 0..hist_len-1. Built from (key, value) pairs so insertion order is
    # guaranteed on every Python version (a plain dict only preserves order
    # from 3.7 on).
    # NOTE(review): these are always CPU FloatTensors, even when cuda=True,
    # unlike self.rgb - confirm this is intended.
    self.d = OrderedDict(
        (i, torch.FloatTensor(1, 1, *state_dims).fill_(0))
        for i in range(hist_len)
    )