The following 50 code examples, extracted from open-source Python projects, illustrate how to use torch.bmm().
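Before the examples, a minimal sketch of the call itself (shapes chosen arbitrarily for illustration): torch.bmm performs a batched matrix multiply, so both arguments must be 3-D tensors with the same batch size, and the inner matrix dimensions must match as (b, n, m) x (b, m, p) -> (b, n, p).

import torch

# torch.bmm multiplies two batches of matrices:
# (b, n, m) @ (b, m, p) -> (b, n, p)
a = torch.randn(4, 2, 3)  # batch of 4 matrices, each 2 x 3
b = torch.randn(4, 3, 5)  # batch of 4 matrices, each 3 x 5
out = torch.bmm(a, b)
print(out.size())  # torch.Size([4, 2, 5])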
def forward(self, inp, hidden):
    outp = self.bilstm.forward(inp, hidden)[0]
    size = outp.size()  # [bsz, len, nhid]
    compressed_embeddings = outp.view(-1, size[2])  # [bsz*len, nhid*2]
    transformed_inp = torch.transpose(inp, 0, 1).contiguous()  # [bsz, len]
    transformed_inp = transformed_inp.view(size[0], 1, size[1])  # [bsz, 1, len]
    concatenated_inp = [transformed_inp for i in range(self.attention_hops)]
    concatenated_inp = torch.cat(concatenated_inp, 1)  # [bsz, hop, len]

    hbar = self.tanh(self.ws1(self.drop(compressed_embeddings)))  # [bsz*len, attention-unit]
    alphas = self.ws2(hbar).view(size[0], size[1], -1)  # [bsz, len, hop]
    alphas = torch.transpose(alphas, 1, 2).contiguous()  # [bsz, hop, len]
    penalized_alphas = alphas + (
        -10000 * (concatenated_inp == self.dictionary.word2idx['<pad>']).float())
    # [bsz, hop, len] + [bsz, hop, len]
    alphas = self.softmax(penalized_alphas.view(-1, size[1]))  # [bsz*hop, len]
    alphas = alphas.view(size[0], self.attention_hops, size[1])  # [bsz, hop, len]
    return torch.bmm(alphas, outp), alphas
def backward(self, grad_output):
    batch1, batch2 = self.saved_tensors
    grad_add_matrix = grad_batch1 = grad_batch2 = None

    if self.needs_input_grad[0]:
        grad_add_matrix = grad_output
        if self.alpha != 1:
            grad_add_matrix = grad_add_matrix.mul(self.alpha)

    if any(self.needs_input_grad[1:]):
        batch_grad_output = (grad_output
                             .unsqueeze(0)
                             .expand(batch1.size(0), batch1.size(1), batch2.size(2)))

    if self.needs_input_grad[1]:
        grad_batch1 = torch.bmm(batch_grad_output, batch2.transpose(1, 2))
        if self.beta != 1:
            grad_batch1 *= self.beta

    if self.needs_input_grad[2]:
        grad_batch2 = torch.bmm(batch1.transpose(1, 2), batch_grad_output)
        if self.beta != 1:
            grad_batch2 *= self.beta

    return grad_add_matrix, grad_batch1, grad_batch2
def backward(self, grad_output):
    batch1, batch2 = self.saved_tensors
    grad_add_batch = grad_batch1 = grad_batch2 = None

    if self.needs_input_grad[0]:
        grad_add_batch = grad_output
        if self.alpha != 1:
            grad_add_batch = grad_add_batch.mul(self.alpha)

    if self.needs_input_grad[1]:
        grad_batch1 = torch.bmm(grad_output, batch2.transpose(1, 2))
        if self.beta != 1:
            grad_batch1 *= self.beta

    if self.needs_input_grad[2]:
        grad_batch2 = torch.bmm(batch1.transpose(1, 2), grad_output)
        if self.beta != 1:
            grad_batch2 *= self.beta

    return grad_add_batch, grad_batch1, grad_batch2
def updateOutput(self, input):
    assert len(input) == 2
    a, b = input
    assert a.ndimension() == 2 or a.ndimension() == 3
    assert a.dim() == b.dim()

    if a.ndimension() == 2:
        if self.transA:
            a = a.t()
        if self.transB:
            b = b.t()
        self.output.resize_(a.size(0), b.size(1))
        # the legacy destination-first call torch.mm(self.output, a, b)
        # is written here with the out= keyword
        torch.mm(a, b, out=self.output)
    else:
        # matrix dims of a 3-D batch are 1 and 2 (0-indexed);
        # the original transpose(2, 3) would raise on a 3-D tensor
        if self.transA:
            a = a.transpose(1, 2)
        if self.transB:
            b = b.transpose(1, 2)
        self.output.resize_(a.size(0), a.size(1), b.size(2))
        torch.bmm(a, b, out=self.output)

    return self.output
def updateOutput(self, input):
    M, v = input
    assert M.ndimension() == 2 or M.ndimension() == 3

    if M.ndimension() == 2:
        assert v.ndimension() == 1
        if self.trans:
            M = M.transpose(0, 1)
        self.output.resize_(M.size(0))
        # legacy destination-first torch.mv(self.output, M, v), written with out=
        torch.mv(M, v, out=self.output)
    else:
        assert v.ndimension() == 2
        if self.trans:
            M = M.transpose(1, 2)
        self.output.resize_(M.size(0), M.size(1), 1)
        torch.bmm(M, v.view(v.size(0), v.size(1), 1), out=self.output).resize_(M.size(0), M.size(1))

    return self.output
def forward(self, q, k, v, attn_mask):
    d_k, d_v, n_head = self.d_k, self.d_v, self.n_head
    residual = q

    bsz, len_q, d_model = q.size()
    len_k, len_v = k.size(1), v.size(1)

    def reshape(x):
        """[bsz, len, d_*] -> [n_head x (bsz*len) x d_*]"""
        return x.repeat(n_head, 1, 1).view(n_head, -1, d_model)

    q_s, k_s, v_s = map(reshape, [q, k, v])

    q_s = torch.bmm(q_s, self.w_qs).view(-1, len_q, d_k)
    k_s = torch.bmm(k_s, self.w_ks).view(-1, len_k, d_k)
    v_s = torch.bmm(v_s, self.w_vs).view(-1, len_v, d_v)

    outputs = self.attention(q_s, k_s, v_s, attn_mask.repeat(n_head, 1, 1))
    outputs = torch.cat(torch.split(outputs, bsz, dim=0), dim=-1).view(-1, n_head * d_v)
    outputs = F.dropout(self.w_o(outputs), p=self.dropout).view(bsz, len_q, -1)
    return self.lm(outputs + residual)
def forward(self, x, target_embedding, encoder_out):
    residual = x

    # attention
    x = (self.in_projection(x) + target_embedding) * math.sqrt(0.5)
    x = self.bmm(x, encoder_out[0])

    # softmax over last dim
    sz = x.size()
    x = F.softmax(x.view(sz[0] * sz[1], sz[2]))
    x = x.view(sz)
    attn_scores = x

    x = self.bmm(x, encoder_out[1])

    # scale attention output
    s = encoder_out[1].size(1)
    x = x * (s * math.sqrt(1.0 / s))

    # project back
    x = (self.out_projection(x) + residual) * math.sqrt(0.5)
    return x, attn_scores
def forward(self, q, k, v, attn_mask=None):
    attn = torch.bmm(q, k.transpose(1, 2)) / self.temper

    if attn_mask is not None:
        assert attn_mask.size() == attn.size(), \
            'Attention mask shape {} mismatch ' \
            'with Attention logit tensor shape ' \
            '{}.'.format(attn_mask.size(), attn.size())
        attn.data.masked_fill_(attn_mask, -float('inf'))

    attn = self.softmax(attn)
    attn = self.dropout(attn)
    output = torch.bmm(attn, v)

    return output, attn
def forward(self, inputs, context):
    """
    inputs: batch x dim
    context: batch x sourceL x dim
    """
    targetT = self.linear_in(inputs).unsqueeze(2)  # batch x dim x 1

    # Get attention
    attn = torch.bmm(context, targetT).squeeze(2)  # batch x sourceL
    if self.mask is not None:
        attn.data.masked_fill_(self.mask, -_INF)
    attn = self.sm(attn)
    attn3 = attn.view(attn.size(0), 1, attn.size(1))  # batch x 1 x sourceL

    weightedContext = torch.bmm(attn3, context).squeeze(1)  # batch x dim
    contextCombined = torch.cat((weightedContext, inputs), 1)

    contextOutput = self.tanh(self.linear_out(contextCombined))

    return contextOutput, attn
def forward(self, input, context):
    """
    input: batch x dim
    context: batch x sourceL x dim
    """
    targetT = self.linear_in(input).unsqueeze(2)  # batch x dim x 1

    # Get attention
    attn = torch.bmm(context, targetT).squeeze(2)  # batch x sourceL
    if self.mask is not None:
        attn.data.masked_fill_(self.mask, -float('inf'))
    attn = self.sm(attn)
    attn3 = attn.view(attn.size(0), 1, attn.size(1))  # batch x 1 x sourceL

    weightedContext = torch.bmm(attn3, context).squeeze(1)  # batch x dim
    contextCombined = torch.cat((weightedContext, input), 1)

    contextOutput = self.tanh(self.linear_out(contextCombined))

    return contextOutput, attn
def forward(self, x):
    batchsize = x.size()[0]
    trans = self.stn(x)  # regressing the transforming parameters using STN
    x = x.transpose(2, 1)  # bz x 2048 x 3
    x = torch.bmm(x, trans)  # (bz x 2048 x 3) x (bz x 3 x 3)
    x = x.transpose(2, 1)  # bz x 3 x 2048
    x = F.relu(self.bn1(self.conv1(x)))
    pointfeat = x  # bz x 64 x 2048
    x = F.relu(self.bn2(self.conv2(x)))  # bz x 128 x 2048
    x = self.bn3(self.conv3(x))  # bz x 1024 x 2048
    x = self.mp1(x)
    x = x.view(-1, 1024)  # bz x 1024
    if self.global_feat:  # using global feats for classification
        return x, trans
    else:
        x = x.view(-1, 1024, 1).repeat(1, 1, self.num_points)
        return torch.cat([x, pointfeat], 1), trans
def backward(ctx, grad_output):
    batch1, batch2 = ctx.saved_variables
    grad_add_matrix = grad_batch1 = grad_batch2 = None

    if ctx.needs_input_grad[0]:
        grad_add_matrix = grad_output
        if ctx.alpha != 1:
            grad_add_matrix = grad_add_matrix.mul(ctx.alpha)

    if any(ctx.needs_input_grad[1:]):
        batch_grad_output = (grad_output
                             .unsqueeze(0)
                             .expand(batch1.size(0), batch1.size(1), batch2.size(2)))

    if ctx.needs_input_grad[1]:
        grad_batch1 = torch.bmm(batch_grad_output, batch2.transpose(1, 2))
        if ctx.beta != 1:
            grad_batch1 *= ctx.beta

    if ctx.needs_input_grad[2]:
        grad_batch2 = torch.bmm(batch1.transpose(1, 2), batch_grad_output)
        if ctx.beta != 1:
            grad_batch2 *= ctx.beta

    return grad_add_matrix, grad_batch1, grad_batch2, None, None, None
def backward(ctx, grad_output):
    batch1, batch2 = ctx.saved_variables
    grad_add_batch = grad_batch1 = grad_batch2 = None

    if ctx.needs_input_grad[0]:
        grad_add_batch = grad_output
        if ctx.alpha != 1:
            grad_add_batch = grad_add_batch.mul(ctx.alpha)

    if ctx.needs_input_grad[1]:
        grad_batch1 = torch.bmm(grad_output, batch2.transpose(1, 2))
        if ctx.beta != 1:
            grad_batch1 *= ctx.beta

    if ctx.needs_input_grad[2]:
        grad_batch2 = torch.bmm(batch1.transpose(1, 2), grad_output)
        if ctx.beta != 1:
            grad_batch2 *= ctx.beta

    return grad_add_batch, grad_batch1, grad_batch2, None, None, None
def updateOutput(self, input):
    assert len(input) == 2
    a, b = input
    assert a.ndimension() == 2 or a.ndimension() == 3
    assert a.dim() == b.dim()

    if a.ndimension() == 2:
        if self.transA:
            a = a.t()
        if self.transB:
            b = b.t()
        self.output.resize_(a.size(0), b.size(1))
        torch.mm(a, b, out=self.output)
    else:
        # matrix dims of a 3-D batch are 1 and 2 (0-indexed);
        # the original transpose(2, 3) would raise on a 3-D tensor
        if self.transA:
            a = a.transpose(1, 2)
        if self.transB:
            b = b.transpose(1, 2)
        self.output.resize_(a.size(0), a.size(1), b.size(2))
        torch.bmm(a, b, out=self.output)

    return self.output
def updateOutput(self, input):
    M, v = input
    assert M.ndimension() == 2 or M.ndimension() == 3

    if M.ndimension() == 2:
        assert v.ndimension() == 1
        if self.trans:
            M = M.transpose(0, 1)
        self.output.resize_(M.size(0))
        torch.mv(M, v, out=self.output)
    else:
        assert v.ndimension() == 2
        if self.trans:
            M = M.transpose(1, 2)
        self.output.resize_(M.size(0), M.size(1), 1)
        torch.bmm(M, v.view(v.size(0), v.size(1), 1), out=self.output).resize_(M.size(0), M.size(1))

    return self.output
def _test_btrisolve(self, cast):
    a = torch.FloatTensor((((1.3722, -0.9020),
                            (1.8849, 1.9169)),
                           ((0.7187, -1.1695),
                            (-0.0139, 1.3572)),
                           ((-1.6181, 0.7148),
                            (1.3728, 0.1319))))
    b = torch.FloatTensor(((4.02, 6.19),
                           (-1.56, 4.00),
                           (9.81, -4.09)))
    a, b = cast(a), cast(b)
    info = cast(torch.IntTensor())
    LU_data, pivots = a.btrifact(info=info)
    self.assertEqual(info.abs().sum(), 0)
    x = torch.btrisolve(b, LU_data, pivots)
    b_ = torch.bmm(a, x.unsqueeze(2)).squeeze()
    self.assertEqual(b_, b)
def forward(self, input1):
    self.input1 = input1
    output = torch.zeros(torch.Size([input1.size(0)]) + self.grid.size())
    self.batchgrid = torch.zeros(torch.Size([input1.size(0)]) + self.grid.size())

    for i in range(input1.size(0)):
        self.batchgrid[i] = self.grid

    if input1.is_cuda:
        self.batchgrid = self.batchgrid.cuda()
        output = output.cuda()

    # contiguous() is not in-place, so its result must be kept
    # (the original discarded the return values)
    batchgrid_temp = self.batchgrid.view(-1, self.height * self.width, 3)
    batchgrid_temp = batchgrid_temp.contiguous()
    input_temp = torch.transpose(input1, 1, 2)
    input_temp = input_temp.contiguous()
    output_temp = torch.bmm(batchgrid_temp, input_temp)
    output = output_temp.view(-1, self.height, self.width, 2)
    output = output.contiguous()
    return output
def forward(self, input, context):
    """Propagate input through the network.

    input: batch x dim
    context: batch x sourceL x dim
    """
    target = self.linear_in(input).unsqueeze(2)  # batch x dim x 1

    # Get attention
    attn = torch.bmm(context, target).squeeze(2)  # batch x sourceL
    attn = self.sm(attn)
    attn3 = attn.view(attn.size(0), 1, attn.size(1))  # batch x 1 x sourceL

    weighted_context = torch.bmm(attn3, context).squeeze(1)  # batch x dim
    h_tilde = torch.cat((weighted_context, input), 1)
    h_tilde = self.tanh(self.linear_out(h_tilde))

    return h_tilde, attn
def updateOutput(self, input):
    assert len(input) == 2
    a, b = input
    assert a.ndimension() == 2 or a.ndimension() == 3
    assert a.dim() == b.dim()

    if a.ndimension() == 2:
        if self.transA:
            a = a.t()
        if self.transB:
            b = b.t()
        self.output.resize_(a.size(0), b.size(1))
        torch.mm(a, b, out=self.output)
    else:
        if self.transA:
            a = a.transpose(1, 2)
        if self.transB:
            b = b.transpose(1, 2)
        self.output.resize_(a.size(0), a.size(1), b.size(2))
        torch.bmm(a, b, out=self.output)

    return self.output
def forward(self, x):
    batchsize = x.size()[0]
    trans = self.stn(x)
    x = x.transpose(2, 1)
    x = torch.bmm(x, trans)
    x = x.transpose(2, 1)
    x = F.relu(self.bn1(self.conv1(x)))
    pointfeat = x
    x = F.relu(self.bn2(self.conv2(x)))
    x = self.bn3(self.conv3(x))
    x = self.mp1(x)
    x = x.view(-1, 1024)
    if self.global_feat:
        return x, trans
    else:
        x = x.view(-1, 1024, 1).repeat(1, 1, self.num_points)
        return torch.cat([x, pointfeat], 1), trans
def calc_score(self, att_query, att_keys):
    """
    att_query is: b x t_q x n
    att_keys is: b x t_k x n
    return b x t_q x t_k scores
    """
    b, t_k, n = list(att_keys.size())
    t_q = att_query.size(1)

    if self.mode == 'bahdanau':
        att_query = att_query.unsqueeze(2).expand(b, t_q, t_k, n)
        att_keys = att_keys.unsqueeze(1).expand(b, t_q, t_k, n)
        sum_qk = att_query + att_keys
        sum_qk = sum_qk.view(b * t_k * t_q, n)
        out = self.linear_att(F.tanh(sum_qk)).view(b, t_q, t_k)
    elif self.mode == 'dot_prod':
        out = torch.bmm(att_query, att_keys.transpose(1, 2))
        if self.normalize:
            out.div_(n ** 0.5)
    return out
def forward(self, v, z):
    '''
    :param v: batch_size (B) x latent_size (L)
    :param z: batch_size (B) x latent_size (L)
    :return: z_new = z - 2 * v * v^T * z / ||v||^2
    '''
    # v * v_T : batch_dot( B x L x 1 * B x 1 x L ) = B x L x L
    vvT = torch.bmm(v.unsqueeze(2), v.unsqueeze(1))

    # v * v_T * z : batch_dot( B x L x L * B x L x 1 ).squeeze(2) = B x L
    vvTz = torch.bmm(vvT, z.unsqueeze(2)).squeeze(2)

    # calculate squared norm ||v||^2 for each row : B x 1
    norm_sq = torch.sum(v * v, 1)
    norm_sq = norm_sq.expand(norm_sq.size(0), v.size(1))  # expand sizes : B x L

    # calculate new z : z - 2 * v * v^T * z / ||v||^2
    z_new = z - 2 * vvTz / norm_sq
    return z_new
def forward(self, L, z):
    '''
    :param L: batch_size (B) x latent_size^2 (L^2)
    :param z: batch_size (B) x latent_size (L)
    :return: z_new = L * z
    '''
    # L -> tril(L): resize to get B x L x L
    L_matrix = L.view(-1, self.args.z1_size, self.args.z1_size)
    # lower-triangular mask matrix (1s strictly below the diagonal);
    # torch.tril takes `diagonal=`, not numpy's `k=`
    LTmask = torch.tril(torch.ones(self.args.z1_size, self.args.z1_size), diagonal=-1)
    I = Variable(torch.eye(self.args.z1_size, self.args.z1_size).expand(L_matrix.size(0), self.args.z1_size, self.args.z1_size))
    if self.args.cuda:
        LTmask = LTmask.cuda()
        I = I.cuda()
    LTmask = Variable(LTmask)
    LTmask = LTmask.unsqueeze(0).expand(L_matrix.size(0), self.args.z1_size, self.args.z1_size)  # 1 x L x L -> B x L x L
    # a batch of lower-triangular matrices with ones on the diagonal
    LT = torch.mul(L_matrix, LTmask) + I

    # z_new = L * z : B x L x L * B x L x 1 -> B x L
    z_new = torch.bmm(LT, z.unsqueeze(2)).squeeze(2)

    return z_new
def forward(ctx, theta, size):
    assert type(size) == torch.Size
    N, C, H, W = size
    ctx.size = size
    if theta.is_cuda:
        ctx.is_cuda = True
        AffineGridGenerator._enforce_cudnn(theta)
        grid = theta.new(N, H, W, 2)
        theta = theta.contiguous()
        torch._C._cudnn_affine_grid_generator_forward(theta, grid, N, C, H, W)
    else:
        ctx.is_cuda = False
        base_grid = theta.new(N, H, W, 3)
        linear_points = torch.linspace(-1, 1, W) if W > 1 else torch.Tensor([-1])
        base_grid[:, :, :, 0] = torch.ger(torch.ones(H), linear_points).expand_as(base_grid[:, :, :, 0])
        linear_points = torch.linspace(-1, 1, H) if H > 1 else torch.Tensor([-1])
        base_grid[:, :, :, 1] = torch.ger(linear_points, torch.ones(W)).expand_as(base_grid[:, :, :, 1])
        base_grid[:, :, :, 2] = 1
        ctx.base_grid = base_grid
        grid = torch.bmm(base_grid.view(N, H * W, 3), theta.transpose(1, 2))
        grid = grid.view(N, H, W, 2)
    return grid
def backward(ctx, grad_grid):
    N, C, H, W = ctx.size
    assert grad_grid.size() == torch.Size([N, H, W, 2])
    assert ctx.is_cuda == grad_grid.is_cuda
    if grad_grid.is_cuda:
        AffineGridGenerator._enforce_cudnn(grad_grid)
        grad_theta = grad_grid.new(N, 2, 3)
        grad_grid = grad_grid.contiguous()
        torch._C._cudnn_affine_grid_generator_backward(grad_theta, grad_grid, N, C, H, W)
    else:
        base_grid = ctx.base_grid
        grad_theta = torch.bmm(
            base_grid.view(N, H * W, 3).transpose(1, 2),
            grad_grid.view(N, H * W, 2))
        grad_theta = grad_theta.transpose(1, 2)
    return grad_theta, None
def backward(ctx, grad_output):
    batch1, batch2 = ctx.saved_variables
    grad_add_matrix = grad_batch1 = grad_batch2 = None

    if ctx.needs_input_grad[0]:
        grad_add_matrix = maybe_unexpand(grad_output, ctx.add_matrix_size)
        if ctx.alpha != 1:
            grad_add_matrix = grad_add_matrix.mul(ctx.alpha)

    if any(ctx.needs_input_grad[1:]):
        batch_grad_output = (grad_output
                             .unsqueeze(0)
                             .expand(batch1.size(0), batch1.size(1), batch2.size(2)))

    if ctx.needs_input_grad[1]:
        grad_batch1 = torch.bmm(batch_grad_output, batch2.transpose(1, 2))
        if ctx.beta != 1:
            grad_batch1 *= ctx.beta

    if ctx.needs_input_grad[2]:
        grad_batch2 = torch.bmm(batch1.transpose(1, 2), batch_grad_output)
        if ctx.beta != 1:
            grad_batch2 *= ctx.beta

    return grad_add_matrix, grad_batch1, grad_batch2, None, None, None
def backward(ctx, grad_output):
    batch1, batch2 = ctx.saved_variables
    grad_add_batch = grad_batch1 = grad_batch2 = None

    if ctx.needs_input_grad[0]:
        grad_add_batch = maybe_unexpand(grad_output, ctx.add_batch_size)
        if ctx.alpha != 1:
            grad_add_batch = grad_add_batch.mul(ctx.alpha)

    if ctx.needs_input_grad[1]:
        grad_batch1 = torch.bmm(grad_output, batch2.transpose(1, 2))
        if ctx.beta != 1:
            grad_batch1 *= ctx.beta

    if ctx.needs_input_grad[2]:
        grad_batch2 = torch.bmm(batch1.transpose(1, 2), grad_output)
        if ctx.beta != 1:
            grad_batch2 *= ctx.beta

    return grad_add_batch, grad_batch1, grad_batch2, None, None, None
def backward(self, gradE):
    A, X, C = self.saved_variables
    with torch.cuda.device_of(A):
        gradA = Variable(A.data.new().resize_as_(A.data))
        gradX = Variable(A.data.new().resize_as_(X.data))
        gradC = Variable(A.data.new().resize_as_(C.data))
    if isinstance(A.data, torch.cuda.FloatTensor):
        with torch.cuda.device_of(A.data):
            encoding_lib.Encoding_Float_aggregate_backward(gradA.data, gradE.data,
                                                           A.data, X.data, C.data)
    elif isinstance(A.data, torch.cuda.DoubleTensor):
        with torch.cuda.device_of(A.data):
            encoding_lib.Encoding_Double_aggregate_backward(gradA.data, gradE.data,
                                                            A.data, X.data, C.data)
    else:
        raise RuntimeError('Unimplemented data type!')
    gradX.data.copy_(torch.bmm(A, gradE).data)
    gradC.data.copy_((-gradE * A.sum(1).unsqueeze(2)).sum(0).data)
    return gradA, gradX, gradC
def backward(self, gradE):
    A, X, C = self.saved_tensors
    with torch.cuda.device_of(A):
        gradA = A.new().resize_as_(A)
        gradX = A.new().resize_as_(X)
        gradC = A.new().resize_as_(C)
    if isinstance(A, torch.cuda.FloatTensor):
        with torch.cuda.device_of(A):
            encoding_lib.Encoding_Float_aggregateE_backward(gradA, gradE, A, X, C)
    elif isinstance(A, torch.cuda.DoubleTensor):
        with torch.cuda.device_of(A):
            encoding_lib.Encoding_Double_aggregateE_backward(gradA, gradE, A, X, C)
    else:
        raise RuntimeError('Unimplemented data type!')
    gradX.copy_(torch.bmm(A, gradE))
    gradC.copy_((-gradE * A.sum(1).unsqueeze(2)).sum(0))
    return gradA, gradX, gradC
def solve_kkt(Q_LU, d, G, A, S_LU, rx, rs, rz, ry):
    """Solve KKT equations for the affine step."""
    nineq, nz, neq, nBatch = get_sizes(G, A)

    invQ_rx = rx.btrisolve(*Q_LU)
    if neq > 0:
        h = torch.cat((invQ_rx.unsqueeze(1).bmm(A.transpose(1, 2)).squeeze(1) - ry,
                       invQ_rx.unsqueeze(1).bmm(G.transpose(1, 2)).squeeze(1) + rs / d - rz), 1)
    else:
        h = invQ_rx.unsqueeze(1).bmm(G.transpose(1, 2)).squeeze(1) + rs / d - rz

    w = -(h.btrisolve(*S_LU))

    g1 = -rx - w[:, neq:].unsqueeze(1).bmm(G).squeeze(1)
    if neq > 0:
        g1 -= w[:, :neq].unsqueeze(1).bmm(A).squeeze(1)
    g2 = -rs - w[:, neq:]

    dx = g1.btrisolve(*Q_LU)
    ds = g2 / d
    dz = w[:, neq:]
    dy = w[:, :neq] if neq > 0 else None

    return dx, ds, dz, dy
def forward(self, x):
    batchsize = x.size()[0]
    if self.trans:
        trans = self.stn(x)
        x = x.transpose(2, 1)
        x = torch.bmm(x, trans)
        x = x.transpose(2, 1)
    x = F.relu(self.bn1(self.conv1(x)))
    pointfeat = x
    x = F.relu(self.bn2(self.conv2(x)))
    x = self.bn3(self.conv3(x))
    x, _ = torch.max(x, 2)
    x = x.view(-1, 1024)
    if self.trans:
        if self.global_feat:
            return x, trans
        else:
            x = x.view(-1, 1024, 1).repeat(1, 1, self.num_points)
            return torch.cat([x, pointfeat], 1), trans
    else:
        return x
def forward(self, output, context):
    batch_size = output.size(0)
    hidden_size = output.size(2)
    input_size = context.size(1)

    # (batch, out_len, dim) * (batch, in_len, dim) -> (batch, out_len, in_len)
    attn = torch.bmm(output, context.transpose(1, 2))
    if self.mask is not None:
        attn.data.masked_fill_(self.mask, -float('inf'))
    attn = F.softmax(attn.view(-1, input_size)).view(batch_size, -1, input_size)

    # (batch, out_len, in_len) * (batch, in_len, dim) -> (batch, out_len, dim)
    mix = torch.bmm(attn, context)

    # concat -> (batch, out_len, 2*dim)
    combined = torch.cat((mix, output), dim=2)

    # output -> (batch, out_len, dim)
    output = F.tanh(self.linear_out(combined.view(-1, 2 * hidden_size))).view(batch_size, -1, hidden_size)

    return output, attn
def forward(self, h, att_feats, p_att_feats):
    # The p_att_feats here is already projected
    att_size = att_feats.numel() // att_feats.size(0) // self.rnn_size
    att = p_att_feats.view(-1, att_size, self.att_hid_size)

    att_h = self.h2att(h)                      # batch * att_hid_size
    att_h = att_h.unsqueeze(1).expand_as(att)  # batch * att_size * att_hid_size
    dot = att + att_h                          # batch * att_size * att_hid_size
    dot = F.tanh(dot)                          # batch * att_size * att_hid_size
    dot = dot.view(-1, self.att_hid_size)      # (batch * att_size) * att_hid_size
    dot = self.alpha_net(dot)                  # (batch * att_size) * 1
    dot = dot.view(-1, att_size)               # batch * att_size

    weight = F.softmax(dot)                    # batch * att_size
    att_feats_ = att_feats.view(-1, att_size, self.rnn_size)  # batch * att_size * att_feat_size
    att_res = torch.bmm(weight.unsqueeze(1), att_feats_).squeeze(1)  # batch * att_feat_size

    return att_res
def forward(self, input, hidden, encoder_output, encoder_outputs):
    embedded = self.embedding(input).view(1, 1, -1)
    embedded = self.dropout(embedded)

    attn_weights = F.softmax(
        self.attn(torch.cat((embedded[0], hidden[0]), 1)))
    attn_weights = attn_weights.cuda() if use_cuda else attn_weights

    attn_applied = torch.bmm(attn_weights.unsqueeze(0),
                             encoder_outputs.unsqueeze(0))
    attn_applied = attn_applied.cuda() if use_cuda else attn_applied

    output = torch.cat((embedded[0], attn_applied[0]), 1)
    output = output.cuda() if use_cuda else output
    output = self.attn_combine(output).unsqueeze(0)

    for i in range(self.n_layers):
        output = F.relu(output)
        output = output.cuda() if use_cuda else output
        output, hidden = self.gru(output, hidden)

    output = F.log_softmax(self.out(output[0]))
    output = output.cuda() if use_cuda else output
    return output, hidden, attn_weights
def forward(self, query, ref):
    """
    Args:
        query: the hidden state of the decoder at the current time step.
            batch x dim
        ref: the set of hidden states from the encoder.
            sourceL x batch x hidden_dim
    """
    # ref is now [batch_size x hidden_dim x sourceL]
    ref = ref.permute(1, 2, 0)
    q = self.project_query(query).unsqueeze(2)  # batch x dim x 1
    e = self.project_ref(ref)                   # batch_size x hidden_dim x sourceL

    # expand the query by sourceL: batch x dim x sourceL
    expanded_q = q.repeat(1, 1, e.size(2))

    # batch x 1 x hidden_dim
    v_view = self.v.unsqueeze(0).expand(
        expanded_q.size(0), len(self.v)).unsqueeze(1)

    # [batch_size x 1 x hidden_dim] * [batch_size x hidden_dim x sourceL]
    u = torch.bmm(v_view, self.tanh(expanded_q + e)).squeeze(1)

    if self.use_tanh:
        logits = self.C * self.tanh(u)
    else:
        logits = u
    return e, logits
def forward(self, inputs):
    """
    Args:
        inputs: [embedding_dim x batch_size x sourceL] of embedded inputs
    """
    (encoder_hx, encoder_cx) = self.encoder.enc_init_state
    encoder_hx = encoder_hx.unsqueeze(0).repeat(inputs.size(1), 1).unsqueeze(0)
    encoder_cx = encoder_cx.unsqueeze(0).repeat(inputs.size(1), 1).unsqueeze(0)

    # encoder forward pass
    enc_outputs, (enc_h_t, enc_c_t) = self.encoder(inputs, (encoder_hx, encoder_cx))

    # grab the hidden state and process it via the process block
    process_block_state = enc_h_t[-1]
    for i in range(self.n_process_block_iters):
        ref, logits = self.process_block(process_block_state, enc_outputs)
        process_block_state = torch.bmm(ref, self.sm(logits).unsqueeze(2)).squeeze(2)

    # produce the final scalar output
    out = self.decoder(process_block_state)
    return out
def forward(self, dec_out, enc_outs, enc_att=None, mask=None):
    """
    Parameters:
    -----------

    - dec_out: torch.Tensor(batch_size x hid_dim)
    - enc_outs: torch.Tensor(seq_len x batch_size x hid_dim)
    - enc_att: (optional), torch.Tensor(seq_len x batch_size x att_dim)
    - mask: (optional), torch.ByteTensor(batch_size x seq_len)
    """
    # (batch x seq_len)
    weights = self.scorer(dec_out, enc_outs, enc_att=enc_att)

    if mask is not None:
        # weights = weights * mask.float()
        weights.data.masked_fill_(1 - mask.data, -float('inf'))

    weights = F.softmax(weights, dim=1)

    # (eq 7)
    context = weights.unsqueeze(1).bmm(enc_outs.transpose(0, 1)).squeeze(1)
    # (eq 5) linear out combining context and hidden
    context = F.tanh(self.linear_out(torch.cat([context, dec_out], 1)))

    return context, weights
def _access(self, memory_vb):  # write
    """
    variables needed:
        wl_curr_vb: [batch_size x num_heads x mem_hei]
        erase_vb:   [batch_size x num_heads x mem_wid] -> /in (0, 1)
        add_vb:     [batch_size x num_heads x mem_wid] -> w/ no restrictions in range
        memory_vb:  [batch_size x mem_hei x mem_wid]
    returns:
        memory_vb:  [batch_size x mem_hei x mem_wid]
    NOTE: IMPORTANT: https://github.com/deepmind/dnc/issues/10
    """
    # first let's do the erasure
    weighted_erase_vb = torch.bmm(self.wl_curr_vb.contiguous().view(-1, self.mem_hei, 1),
                                  self.erase_vb.contiguous().view(-1, 1, self.mem_wid)).view(-1, self.num_heads, self.mem_hei, self.mem_wid)
    keep_vb = torch.prod(1. - weighted_erase_vb, dim=1)
    memory_vb = memory_vb * keep_vb

    # finally let's write (do addition)
    return memory_vb + torch.bmm(self.wl_curr_vb.transpose(1, 2), self.add_vb)
def forward(ctx, theta, size):
    assert type(size) == torch.Size
    N, C, H, W = size
    ctx.size = size
    if theta.is_cuda:
        AffineGridGenerator._enforce_cudnn(theta)
        assert False
    ctx.is_cuda = False
    base_grid = theta.new(N, H, W, 3)
    linear_points = torch.linspace(-1, 1, W) if W > 1 else torch.Tensor([-1])
    base_grid[:, :, :, 0] = torch.ger(torch.ones(H), linear_points).expand_as(base_grid[:, :, :, 0])
    linear_points = torch.linspace(-1, 1, H) if H > 1 else torch.Tensor([-1])
    base_grid[:, :, :, 1] = torch.ger(linear_points, torch.ones(W)).expand_as(base_grid[:, :, :, 1])
    base_grid[:, :, :, 2] = 1
    ctx.base_grid = base_grid
    grid = torch.bmm(base_grid.view(N, H * W, 3), theta.transpose(1, 2))
    grid = grid.view(N, H, W, 2)
    return grid