We extracted the following 37 code examples from open-source Python projects to illustrate how to use torch.split().
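As a quick orientation before the project excerpts, here is a minimal sketch of the basic call (illustrative shapes only; it assumes a reasonably recent PyTorch, where the second argument may be either a chunk size or a list of section sizes):

import torch

x = torch.arange(12).view(6, 2)            # a (6, 2) tensor
halves = torch.split(x, 3, dim=0)          # two chunks of 3 rows each
uneven = torch.split(x, [1, 2, 3], dim=0)  # chunks of 1, 2 and 3 rows
columns = torch.split(x, 1, dim=1)         # two (6, 1) column tensors

Note that torch.split returns a tuple of tensors rather than a single tensor, which is why the examples below often wrap the result in list(...) or iterate over it directly.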
def forward(self, x):
    x_shape = x.size()  # (b, c, h, w)
    offset = self.offset_filter(x)  # (b, 2*c, h, w)
    offset_w, offset_h = torch.split(offset, self.regular_filter.in_channels, 1)  # (b, c, h, w)
    offset_w = offset_w.contiguous().view(-1, int(x_shape[2]), int(x_shape[3]))  # (b*c, h, w)
    offset_h = offset_h.contiguous().view(-1, int(x_shape[2]), int(x_shape[3]))  # (b*c, h, w)
    if not self.input_shape or self.input_shape != x_shape:
        self.input_shape = x_shape
        grid_w, grid_h = np.meshgrid(np.linspace(-1, 1, x_shape[3]), np.linspace(-1, 1, x_shape[2]))  # (h, w)
        grid_w = torch.Tensor(grid_w)
        grid_h = torch.Tensor(grid_h)
        if self.cuda:
            grid_w = grid_w.cuda()
            grid_h = grid_h.cuda()
        self.grid_w = nn.Parameter(grid_w)
        self.grid_h = nn.Parameter(grid_h)
    offset_w = offset_w + self.grid_w  # (b*c, h, w)
    offset_h = offset_h + self.grid_h  # (b*c, h, w)
    x = x.contiguous().view(-1, int(x_shape[2]), int(x_shape[3])).unsqueeze(1)  # (b*c, 1, h, w)
    x = F.grid_sample(x, torch.stack((offset_h, offset_w), 3))  # (b*c, h, w)
    x = x.contiguous().view(-1, int(x_shape[1]), int(x_shape[2]), int(x_shape[3]))  # (b, c, h, w)
    x = self.regular_filter(x)
    return x
def node_forward(self, inputs, child_c, child_h):
    child_h_sum = torch.sum(child_h, dim=0, keepdim=True)

    iou = self.ioux(inputs) + self.iouh(child_h_sum)
    i, o, u = torch.split(iou, iou.size(1) // 3, dim=1)
    i, o, u = F.sigmoid(i), F.sigmoid(o), F.tanh(u)

    f = F.sigmoid(
        self.fh(child_h) +
        self.fx(inputs).repeat(len(child_h), 1)
    )
    fc = torch.mul(f, child_c)

    c = torch.mul(i, u) + torch.sum(fc, dim=0, keepdim=True)
    h = torch.mul(o, F.tanh(c))
    return c, h
def get_distance_losses(self, A, AB, A_to_AB=True):
    As = torch.split(A, 1)
    ABs = torch.split(AB, 1)

    loss_distance_A = 0.0
    num_pairs = 0
    min_length = len(As)

    for i in xrange(min_length - 1):
        for j in xrange(i + 1, min_length):
            num_pairs += 1
            loss_distance_A_ij = \
                self.get_individual_distance_loss(As[i], As[j],
                                                  ABs[i], ABs[j], A_to_AB)
            loss_distance_A += loss_distance_A_ij

    loss_distance_A = loss_distance_A / num_pairs
    return loss_distance_A
def forward(self, inputs, batch_size, hidden_cell=None):
    if hidden_cell is None:
        # then must init with zeros
        if use_cuda:
            hidden = Variable(torch.zeros(2, batch_size, hp.enc_hidden_size).cuda())
            cell = Variable(torch.zeros(2, batch_size, hp.enc_hidden_size).cuda())
        else:
            hidden = Variable(torch.zeros(2, batch_size, hp.enc_hidden_size))
            cell = Variable(torch.zeros(2, batch_size, hp.enc_hidden_size))
        hidden_cell = (hidden, cell)
    _, (hidden, cell) = self.lstm(inputs.float(), hidden_cell)
    # hidden is (2, batch_size, hidden_size), we want (batch_size, 2*hidden_size):
    hidden_forward, hidden_backward = torch.split(hidden, 1, 0)
    hidden_cat = torch.cat([hidden_forward.squeeze(0), hidden_backward.squeeze(0)], 1)
    # mu and sigma:
    mu = self.fc_mu(hidden_cat)
    sigma_hat = self.fc_sigma(hidden_cat)
    sigma = torch.exp(sigma_hat / 2.)
    # N ~ N(0,1)
    z_size = mu.size()
    if use_cuda:
        N = Variable(torch.normal(torch.zeros(z_size), torch.ones(z_size)).cuda())
    else:
        N = Variable(torch.normal(torch.zeros(z_size), torch.ones(z_size)))
    z = mu + sigma * N
    # mu and sigma_hat are needed for LKL loss
    return z, mu, sigma_hat
def backward(self, outputs, targets, weights, normalizer, criterion, regression=False):
    outputs_split = torch.split(outputs, self.batch_size, self.dim)
    targets_split = torch.split(targets, self.batch_size, self.dim)
    weights_split = torch.split(weights, self.batch_size, self.dim)

    grad_output = []
    loss = 0
    for out_t, targ_t, w_t in zip(outputs_split, targets_split, weights_split):
        grad_output_t, loss_t = super(MemEfficientGenerator, self).backward(
            out_t, targ_t, w_t, normalizer, criterion, regression)
        grad_output.append(grad_output_t)
        loss += loss_t

    grad_output = torch.cat(grad_output, self.dim)
    return grad_output, loss
def forward(self, inp):
    # if inp.dim() > 2:
    #     inp = inp.permute(0, 2, 1)
    # inp = inp.contiguous().view(-1, self.L)
    if not (type(inp) == Variable):
        inp = Variable(inp[0])

    if hasattr(self.arguments, 'pack_num'):
        N = inp.size(0)
        Ncut = int(N / self.arguments.pack_num)
        split = torch.split(inp, Ncut, dim=0)
        inp = torch.cat(split, dim=1)

    h1 = F.tanh((self.l1(inp)))
    # h2 = F.tanh(self.l2_bn(self.l2(h1)))
    if self.arguments.tr_method == 'adversarial_wasserstein':
        output = (self.l3(h1))
    else:
        output = F.sigmoid(self.l3(h1))

    return output, h1
def __init__(self, root, single_spkr=False):
    self.root = root
    self.npzs = self.make_dataset(self.root)
    if len(self.npzs) == 0:
        raise(RuntimeError("Found 0 npz in subfolders of: " + root + "\n"
                           "Supported image extensions are: " + self.NPZ_EXTENSION))
    if single_spkr:
        self.speakers = defaultdict(lambda: 0)
    else:
        self.speakers = []
        for fname in self.npzs:
            self.speakers += [os.path.basename(fname).split('_')[0]]
        self.speakers = list(set(self.speakers))
        self.speakers.sort()
        self.speakers = {v: i for i, v in enumerate(self.speakers)}
    code2phone = np.load(self.npzs[0])['code2phone']
    self.dict = {v: k for k, v in enumerate(code2phone)}
def __init__(self, src, trgt, spkr, seq_len):
    self.seq_len = seq_len
    self.start = True

    self.speakers = spkr
    self.srcBatch = src[0]
    self.srcLenths = src[1]

    # split batch
    self.tgtBatch = list(torch.split(trgt[0], self.seq_len, 0))
    self.tgtBatch.reverse()
    self.len = len(self.tgtBatch)

    # split length list
    batch_seq_len = len(self.tgtBatch)
    self.tgtLenths = [self.split_length(l, batch_seq_len) for l in trgt[1]]
    self.tgtLenths = torch.stack(self.tgtLenths)
    self.tgtLenths = list(torch.split(self.tgtLenths, 1, 1))
    self.tgtLenths = [x.squeeze() for x in self.tgtLenths]
    self.tgtLenths.reverse()

    assert len(self.tgtLenths) == len(self.tgtBatch)
def get_distance_losses(self):
    As = torch.split(self.real_A, 1)
    Bs = torch.split(self.real_B, 1)
    ABs = torch.split(self.fake_B, 1)
    BAs = torch.split(self.fake_A, 1)

    loss_distance_A = 0.0
    loss_distance_B = 0.0
    num_pairs = 0
    min_length = min(len(As), len(Bs))

    for i in xrange(min_length - 1):
        for j in xrange(i + 1, min_length):
            num_pairs += 1
            loss_distance_A_ij, loss_distance_B_ij = \
                self.get_individual_distance_loss(As[i], As[j],
                                                  ABs[i], ABs[j],
                                                  Bs[i], Bs[j],
                                                  BAs[i], BAs[j])
            loss_distance_A += loss_distance_A_ij
            loss_distance_B += loss_distance_B_ij

    loss_distance_A = loss_distance_A / num_pairs
    loss_distance_B = loss_distance_B / num_pairs

    return loss_distance_A, loss_distance_B
def forward(self, inputs, z, hidden_cell=None):
    if hidden_cell is None:
        # then we must init from z
        hidden, cell = torch.split(F.tanh(self.fc_hc(z)), hp.dec_hidden_size, 1)
        hidden_cell = (hidden.unsqueeze(0).contiguous(), cell.unsqueeze(0).contiguous())
    outputs, (hidden, cell) = self.lstm(inputs, hidden_cell)
    # in training we feed the lstm with the whole input in one shot
    # and use all outputs contained in 'outputs', while in generate
    # mode we just feed with the last generated sample:
    if self.training:
        y = self.fc_params(outputs.view(-1, hp.dec_hidden_size))
    else:
        y = self.fc_params(hidden.view(-1, hp.dec_hidden_size))
    # separate pen and mixture params:
    params = torch.split(y, 6, 1)
    params_mixture = torch.stack(params[:-1])  # trajectory
    params_pen = params[-1]  # pen up/down
    # identify mixture params:
    pi, mu_x, mu_y, sigma_x, sigma_y, rho_xy = torch.split(params_mixture, 1, 2)
    # preprocess params:
    if self.training:
        len_out = Nmax + 1
    else:
        len_out = 1
    pi = F.softmax(pi.t().squeeze()).view(len_out, -1, hp.M)
    sigma_x = torch.exp(sigma_x.t().squeeze()).view(len_out, -1, hp.M)
    sigma_y = torch.exp(sigma_y.t().squeeze()).view(len_out, -1, hp.M)
    rho_xy = torch.tanh(rho_xy.t().squeeze()).view(len_out, -1, hp.M)
    mu_x = mu_x.t().squeeze().contiguous().view(len_out, -1, hp.M)
    mu_y = mu_y.t().squeeze().contiguous().view(len_out, -1, hp.M)
    q = F.softmax(params_pen).view(len_out, -1, 3)
    return pi, mu_x, mu_y, sigma_x, sigma_y, rho_xy, q, hidden, cell
def make_image(sequence, epoch, name='_output_'):
    """plot drawing with separated strokes"""
    strokes = np.split(sequence, np.where(sequence[:, 2] > 0)[0] + 1)
    fig = plt.figure()
    ax1 = fig.add_subplot(111)
    for s in strokes:
        plt.plot(s[:, 0], -s[:, 1])
    canvas = plt.get_current_fig_manager().canvas
    canvas.draw()
    pil_image = PIL.Image.frombytes('RGB', canvas.get_width_height(), canvas.tostring_rgb())
    name = str(epoch) + name + '.jpg'
    pil_image.save(name, "JPEG")
    plt.close("all")
def unbundle(state):
    if state is None:
        return itertools.repeat(None)
    return torch.split(torch.cat(state, 1), 1, 0)
def predict(self, outputs, targets, weights, criterion):
    outputs_split = torch.split(outputs, self.batch_size, self.dim)
    targets_split = torch.split(targets, self.batch_size, self.dim)
    weights_split = torch.split(weights, self.batch_size, self.dim)

    preds = []
    loss = 0
    for out_t, targ_t, w_t in zip(outputs_split, targets_split, weights_split):
        preds_t, loss_t = super(MemEfficientGenerator, self).predict(
            out_t, targ_t, w_t, criterion)
        preds.append(preds_t)
        loss += loss_t

    preds = torch.cat(preds, self.dim)
    return preds, loss
def forward(self, input_, hx):
    """
    Args:
        input_: A (batch, input_size) tensor containing input features.
        hx: A tuple (h_0, c_0), which contains the initial hidden
            and cell state, where the size of both states is
            (batch, hidden_size).

    Returns:
        h_1, c_1: Tensors containing the next hidden and cell state.
    """
    h_0, c_0 = hx
    batch_size = h_0.size(0)
    bias_batch = (self.bias.unsqueeze(0)
                  .expand(batch_size, *self.bias.size()))
    wh_b = torch.addmm(bias_batch, h_0, self.weight_hh)
    wi = torch.mm(input_, self.weight_ih)
    f, i, o, g = torch.split(wh_b + wi,
                             split_size=self.hidden_size, dim=1)
    c_1 = torch.sigmoid(f) * c_0 + torch.sigmoid(i) * torch.tanh(g)
    h_1 = torch.sigmoid(o) * torch.tanh(c_1)
    return h_1, c_1
def forward(self, input_, hx, time):
    """
    Args:
        input_: A (batch, input_size) tensor containing input features.
        hx: A tuple (h_0, c_0), which contains the initial hidden
            and cell state, where the size of both states is
            (batch, hidden_size).
        time: The current timestep value, which is used to
            get appropriate running statistics.

    Returns:
        h_1, c_1: Tensors containing the next hidden and cell state.
    """
    h_0, c_0 = hx
    batch_size = h_0.size(0)
    bias_batch = (self.bias.unsqueeze(0)
                  .expand(batch_size, *self.bias.size()))
    wh = torch.mm(h_0, self.weight_hh)
    wi = torch.mm(input_, self.weight_ih)
    bn_wh = self.bn_hh(wh, time=time)
    bn_wi = self.bn_ih(wi, time=time)
    f, i, o, g = torch.split(bn_wh + bn_wi + bias_batch,
                             split_size=self.hidden_size, dim=1)
    c_1 = torch.sigmoid(f) * c_0 + torch.sigmoid(i) * torch.tanh(g)
    h_1 = torch.sigmoid(o) * torch.tanh(self.bn_c(c_1, time=time))
    return h_1, c_1
def memoryEfficientLoss(outputs, targets, generator, crit, max_generator_batches, eval=False):
    """Memory efficient loss.

    :param outputs: seq_len x batch_size x logits_size
    :param targets: seq_len x batch_size
    :param generator:
    :param crit:
    :param max_generator_batches:
    :param eval:
    :return:
    """
    # compute generations one piece at a time
    num_correct, loss = 0, 0
    outputs = Variable(outputs.data, requires_grad=(not eval), volatile=eval)  # seq_len x batch_size x logits_size

    batch_size = outputs.size(1)
    outputs_split = torch.split(outputs, max_generator_batches)
    targets_split = torch.split(targets, max_generator_batches)
    for i, (out_t, targ_t) in enumerate(zip(outputs_split, targets_split)):
        # out_t = seq_len x batch_size x logits_size
        # targ_t = seq_len x batch_size
        out_t = out_t.view(-1, out_t.size(2))  # seq_len * batch_size x logits_size
        scores_t = generator(out_t)  # seq_len * batch_size x voc_size
        loss_t = crit(scores_t, targ_t.view(-1))  # scalar (1-d)
        pred_t = scores_t.max(1)[1]  # seq_len * batch_size x 1
        num_correct_t = pred_t.data.eq(targ_t.data).masked_select(targ_t.ne(Constants.PAD).data).sum()
        num_correct += num_correct_t
        loss += loss_t.data[0]
        if not eval:
            loss_t.div(batch_size).backward()

    grad_output = None if outputs.grad is None else outputs.grad.data
    return loss, grad_output, num_correct
def set_proposal_params(self, tensor_of_proposal_means_stds_coeffs):
    n_components = int(tensor_of_proposal_means_stds_coeffs.size(0) / 3)
    self.proposal_means, self.proposal_stds, self.proposal_coeffs = \
        torch.split(tensor_of_proposal_means_stds_coeffs, n_components)
def split(self, split_size, dim=0):
    """Splits this tensor into a tuple of tensors.

    See :func:`torch.split`.
    """
    return torch.split(self, split_size, dim)
def split(self, split_size, dim=0):
    return torch.split(self, split_size, dim)
def execute(self):
    maxLen = max([len(e) for e in self.progs])
    for s in range(maxLen):
        nodes = []
        for i in range(len(self.progs)):
            prog = self.progs[i]
            if len(prog) <= s:
                continue
            nodes += [prog[s]]

        groupedNodes = {}
        for node in nodes:
            groupedNodes.setdefault(node.cellInd, []).append(node)

        for cellInd, nodes in groupedNodes.items():
            arity = nodes[0].arity
            cell = self.cells[cellInd]
            outData = [node.inpData[0] for node in nodes]
            if arity == 1:
                arg = t.cat(outData, 0)
                outData = cell(arg)
                outData = t.split(outData, 1, 0)
            elif arity == 2:
                arg1 = t.cat(outData, 0)
                arg2 = t.cat([node.inpData[1] for node in nodes], 0)
                outData = cell(arg1, arg2)
                outData = t.split(outData, 1, 0)

            for node, outDat in zip(nodes, outData):
                if node.prev is None:
                    node.outData = outDat
                else:
                    node.prev.inpData += [outDat]

    outData = [prog[-1].outData for prog in self.progs]
    return t.cat(outData, 0)
def sample_outputs(generator, Nsamples, arguments):
    inp = torch.randn(Nsamples, arguments.L1)
    if arguments.cuda:
        inp = inp.cuda()

    out = generator.forward(Variable(inp))
    if arguments.task == 'images':
        out = out.contiguous().view(-1, arguments.nfts, arguments.T)
    return torch.split(out.data, split_size=1, dim=0)
def forward(self, x, ident, context, start=True):
    out, attns = [], []
    o_t = x[0]
    self.init_buffer(ident, start)

    for o_tm1 in torch.split(x, 1):
        if not self.training:
            o_tm1 = o_t.unsqueeze(0)

        # predict weighted context based on S
        c_t, mu_t, alpha_t = self.attn(self.S_t, context.transpose(0, 1), self.mu_t)

        # advance mu and update buffer
        self.S_t = self.update_buffer(self.S_t, c_t, o_tm1, ident)
        self.mu_t = mu_t

        # predict next time step based on buffer content
        ot_out = self.N_o(self.S_t.view(self.S_t.size(0), -1))
        sp_out = self.F_o(ident)
        o_t = self.output(ot_out + sp_out)

        out += [o_t]
        attns += [alpha_t.squeeze()]

    out_seq = torch.stack(out)
    attns_seq = torch.stack(attns)

    return out_seq, attns_seq
def loader(self, path):
    feat = np.load(path)
    txt = feat['phonemes'].astype('int64')
    txt = torch.from_numpy(txt)

    audio = feat['audio_features']
    audio = torch.from_numpy(audio)

    spkr = os.path.basename(path).split('_')[0]

    return txt, audio, spkr
def forward(self, x, lstm_hidden_vb=None):
    p = x.view(x.size(0), self.input_dims[0] * self.input_dims[1])
    p = self.rl1(self.fc1(p))
    p = self.rl2(self.fc2(p))
    p = self.rl3(self.fc3(p))
    p = self.rl4(self.fc4(p))
    p = p.view(-1, self.hidden_dim)
    if self.enable_lstm:
        p_, v_ = torch.split(lstm_hidden_vb[0], 1)
        c_p, c_v = torch.split(lstm_hidden_vb[1], 1)
        p, c_p = self.lstm(p, (p_, c_p))
    p_out = self.policy_5(p)
    sig = self.policy_sig(p)
    sig = self.softplus(sig)

    v = x.view(x.size(0), self.input_dims[0] * self.input_dims[1])
    v = self.rl1_v(self.fc1_v(v))
    v = self.rl2_v(self.fc2_v(v))
    v = self.rl3_v(self.fc3_v(v))
    v = self.rl4_v(self.fc4_v(v))
    v = v.view(-1, self.hidden_dim)
    if self.enable_lstm:
        v, c_v = self.lstm_v(v, (v_, c_v))
    v_out = self.value_5(v)

    if self.enable_lstm:
        return p_out, sig, v_out, (torch.cat((p, v), 0), torch.cat((c_p, c_v), 0))
    else:
        return p_out, sig, v_out
def forward(self, input_, hx):
    """
    Args:
        input_: A (batch, input_size) tensor containing input features.
        hx: initial hidden, where the size of the state is
            (batch, hidden_size).

    Returns:
        newh: Tensors containing the next hidden state.
    """
    batch_size = hx.size(0)
    bias_batch = (self.gate_bias.unsqueeze(0)
                  .expand(batch_size, *self.gate_bias.size()))
    gate_Wh = torch.addmm(bias_batch, hx, self.gate_W)
    gate_Ux = torch.mm(input_, self.gate_U)
    r, z = torch.split(gate_Ux + gate_Wh,
                       split_size=self.hidden_size, dim=1)

    Ux = torch.mm(input_, self.U)
    unitary = self._EUNN(hx=hx, thetaA=self.thetaA, thetaB=self.thetaB)
    unitary = unitary * r
    newh = Ux + unitary
    newh = self._modReLU(newh, self.bias)
    newh = hx * z + (1 - z) * newh
    return newh
def reader(self):
    with open(self.filepath, 'r') as f:
        if self.has_header:
            next(f)
        for line in f:
            w, *vec = line.split()
            yield w, vec
def shards(data, size=25, test=False):
    """
    Generator over variables that will be involved in a costly loss
    computation such as the softmax. It yields dictionaries of the same
    form as the input, where the variables have been split into smaller
    shards and detached from the graph. It expects the consumer to back
    propagate through them in shards of a given size. After all shards
    are consumed, the generator will take care of backprop further from
    the input using the accumulated gradients.
    """
    # Inspired by www.github.com/OpenNMT/OpenNMT-py/blob/master/onmt/Loss.py
    if test:
        yield data
        return

    detached = dict(detach_vars(data))
    splits = ((key, torch.split(v, size)) for key, v in detached.items())
    keys, splits = zip(*splits)
    for split in zip(*splits):
        yield dict(zip(keys, split))  # go and accumulate some loss

    inputs, grads = [], []
    for key, var in detached.items():
        if var.grad is not None:
            inputs.append(data[key]), grads.append(var.grad.data)
    torch.autograd.backward(inputs, grads, retain_graph=True)


# Initializers
def split(self, split_size, dim=0):
    r"""Splits this tensor into tensor chunks of :attr:`split_size` size.

    See :func:`torch.split`.
    """
    return torch.split(self, split_size, dim)
def forward(self, tensors: List[torch.Tensor],  # pylint: disable=arguments-differ
            mask: torch.Tensor = None) -> torch.Tensor:
    """
    Compute a weighted average of the ``tensors``. The input tensors can be any shape
    with at least two dimensions, but must all be the same shape.

    When ``do_layer_norm=True``, the ``mask`` is required input. If the ``tensors`` are
    dimensioned ``(dim_0, ..., dim_{n-1}, dim_n)``, then the ``mask`` is dimensioned
    ``(dim_0, ..., dim_{n-1})``, as in the typical case with ``tensors`` of shape
    ``(batch_size, timesteps, dim)`` and ``mask`` of shape ``(batch_size, timesteps)``.

    When ``do_layer_norm=False`` the ``mask`` is ignored.
    """
    if len(tensors) != self.mixture_size:
        raise ConfigurationError("{} tensors were passed, but the module was initialized to "
                                 "mix {} tensors.".format(len(tensors), self.mixture_size))

    def _do_layer_norm(tensor, broadcast_mask, num_elements_not_masked):
        tensor_masked = tensor * broadcast_mask
        mean = torch.sum(tensor_masked) / num_elements_not_masked
        variance = torch.sum(((tensor_masked - mean) * broadcast_mask)**2) / num_elements_not_masked
        return (tensor - mean) / torch.sqrt(variance + 1E-12)

    normed_weights = torch.nn.functional.softmax(torch.cat([parameter for parameter
                                                            in self.scalar_parameters]), dim=0)
    normed_weights = torch.split(normed_weights, split_size=1)

    if not self.do_layer_norm:
        pieces = []
        for weight, tensor in zip(normed_weights, tensors):
            pieces.append(weight * tensor)
        return self.gamma * sum(pieces)

    else:
        mask_float = mask.float()
        broadcast_mask = mask_float.unsqueeze(-1)
        input_dim = tensors[0].size(-1)
        num_elements_not_masked = torch.sum(mask_float) * input_dim

        pieces = []
        for weight, tensor in zip(normed_weights, tensors):
            pieces.append(weight * _do_layer_norm(tensor,
                                                  broadcast_mask, num_elements_not_masked))
        return self.gamma * sum(pieces)
def forward(self, buffers, transitions):
    buffers = [list(torch.split(b.squeeze(1), 1, 0))
               for b in torch.split(buffers, 1, 1)]
    stacks = [[buf[0], buf[0]] for buf in buffers]

    if hasattr(self, 'tracker'):
        self.tracker.reset_state()
    else:
        assert transitions is not None

    if transitions is not None:
        num_transitions = transitions.size(0)
        # trans_loss, trans_acc = 0, 0
    else:
        num_transitions = len(buffers[0]) * 2 - 3

    for i in range(num_transitions):
        if transitions is not None:
            trans = transitions[i]

        if hasattr(self, 'tracker'):
            tracker_states, trans_hyp = self.tracker(buffers, stacks)
            if trans_hyp is not None:
                trans = trans_hyp.max(1)[1]
                # if transitions is not None:
                #     trans_loss += F.cross_entropy(trans_hyp, trans)
                #     trans_acc += (trans_preds.data == trans.data).mean()
                # else:
                #     trans = trans_preds
        else:
            tracker_states = itertools.repeat(None)

        lefts, rights, trackings = [], [], []
        batch = zip(trans.data, buffers, stacks, tracker_states)
        for transition, buf, stack, tracking in batch:
            if transition == 3:  # shift
                stack.append(buf.pop())
            elif transition == 2:  # reduce
                rights.append(stack.pop())
                lefts.append(stack.pop())
                trackings.append(tracking)

        if rights:
            reduced = iter(self.reduce(lefts, rights, trackings))
            for transition, stack in zip(trans.data, stacks):
                if transition == 2:
                    stack.append(next(reduced))

    # if trans_loss is not 0:
    return bundle([stack.pop() for stack in stacks])[0]
def alpha_loss(outputs, targets, generator, crit, max_generator_batches, rewards, proposed_weights, tau, alpha, eval=False):
    """Loss function of proposed method.

    :param outputs: seq_len x batch_size x logits_size
    :param targets: seq_len x batch_size
    :param generator:
    :param crit:
    :param max_generator_batches:
    :param eval:
    :return:
    """
    # compute generations one piece at a time
    num_correct, loss = 0, 0
    outputs = Variable(outputs.data, requires_grad=(not eval), volatile=eval)  # seq_len x batch_size x logits_size

    batch_size = outputs.size(1)
    outputs_split = torch.split(outputs, max_generator_batches)
    targets_split = torch.split(targets, max_generator_batches)

    # TODO(sotetsuk): fix to calculate at once
    importance_list = []
    p_sample_efficiency_list = []
    q_sample_efficiency_list = []
    pq_sample_efficiency_list = []
    for i, (out_t, targ_t) in enumerate(zip(outputs_split, targets_split)):
        out_t = out_t.view(-1, out_t.size(2))  # seq_len * batch_size x logits_size
        scores_t = generator(out_t)  # seq_len * batch_size x voc_size

        proposed_weights = torch.FloatTensor(proposed_weights)
        log_q_weights = torch.FloatTensor(rewards) / tau
        loss_t, importance_t, p_sample_efficiency_t, q_sample_efficiency_t, pq_sample_efficiency_t = \
            crit(scores_t, targ_t.view(-1), proposed_weights, log_q_weights, alpha, rewards)  # scalar (1-d)

        pred_t = scores_t.max(1)[1]  # seq_len * batch_size x 1
        num_correct_t = pred_t.data.eq(targ_t.data).masked_select(targ_t.ne(Constants.PAD).data).sum()
        num_correct += num_correct_t
        loss += loss_t.data[0]

        importance_list += importance_t
        p_sample_efficiency_list += p_sample_efficiency_t
        q_sample_efficiency_list += q_sample_efficiency_t
        pq_sample_efficiency_list += pq_sample_efficiency_t

        if not eval:
            loss_t.div(batch_size).backward()

    grad_output = None if outputs.grad is None else outputs.grad.data
    return loss, grad_output, num_correct, importance_list, p_sample_efficiency_list, q_sample_efficiency_list, pq_sample_efficiency_list
def shards(state, shard_size, eval=False):
    """
    Args:
        state: A dictionary which corresponds to the output of
               *LossCompute.make_shard_state(). The values for
               those keys are Tensor-like or None.
        shard_size: The maximum size of the shards yielded by the model.
        eval: If True, only yield the state, nothing else.
              Otherwise, yield shards.

    Yields:
        Each yielded shard is a dict.

    Side effect:
        After the last shard, this function does back-propagation.
    """
    if eval:
        yield state
    else:
        # non_none: the subdict of the state dictionary where the values
        # are not None.
        non_none = dict(filter_shard_state(state))

        # Now, the iteration:
        # state is a dictionary of sequences of tensor-like but we
        # want a sequence of dictionaries of tensors.
        # First, unzip the dictionary into a sequence of keys and a
        # sequence of tensor-like sequences.
        keys, values = zip(*((k, torch.split(v, shard_size))
                             for k, v in non_none.items()))

        # Now, yield a dictionary for each shard. The keys are always
        # the same. values is a sequence of length #keys where each
        # element is a sequence of length #shards. We want to iterate
        # over the shards, not over the keys: therefore, the values need
        # to be re-zipped by shard and then each shard can be paired
        # with the keys.
        for shard_tensors in zip(*values):
            yield dict(zip(keys, shard_tensors))

        # Assumed backprop'd
        variables = ((state[k], v.grad.data) for k, v in non_none.items()
                     if isinstance(v, Variable) and v.grad is not None)
        inputs, grads = zip(*variables)
        torch.autograd.backward(inputs, grads)