The following 40 code examples, extracted from open-source Python projects, illustrate how to use torch.matmul().
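Before the examples, here is a quick, self-contained sketch of torch.matmul()'s shape rules (not from any of the projects below; the shapes are arbitrary, and the snippet is written against modern PyTorch):

import torch

# 1D x 1D -> 0-dim dot product
print(torch.matmul(torch.randn(3), torch.randn(3)).shape)            # torch.Size([])

# 2D x 2D -> ordinary matrix product
print(torch.matmul(torch.randn(2, 3), torch.randn(3, 4)).shape)      # torch.Size([2, 4])

# >=3D -> batched matmul; leading (batch) dimensions broadcast
print(torch.matmul(torch.randn(10, 2, 3), torch.randn(3, 4)).shape)  # torch.Size([10, 2, 4])
print(torch.matmul(torch.randn(7, 1, 2, 3), torch.randn(5, 3, 4)).shape)  # torch.Size([7, 5, 2, 4])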
def forward(self, tensor_1: torch.Tensor, tensor_2: torch.Tensor) -> torch.Tensor:
    projected_tensor_1 = torch.matmul(tensor_1, self._tensor_1_projection)
    projected_tensor_2 = torch.matmul(tensor_2, self._tensor_2_projection)

    # Here we split the last dimension of the tensors from (..., projected_dim) to
    # (..., num_heads, projected_dim / num_heads), using tensor.view().
    last_dim_size = projected_tensor_1.size(-1) // self.num_heads
    new_shape = list(projected_tensor_1.size())[:-1] + [self.num_heads, last_dim_size]
    split_tensor_1 = projected_tensor_1.view(*new_shape)
    last_dim_size = projected_tensor_2.size(-1) // self.num_heads
    new_shape = list(projected_tensor_2.size())[:-1] + [self.num_heads, last_dim_size]
    split_tensor_2 = projected_tensor_2.view(*new_shape)

    # And then we pass this off to our internal similarity function.  Because the similarity
    # functions don't care what dimension their input has, and only look at the last dimension,
    # we don't need to do anything special here.  It will just compute similarity on the
    # projection dimension for each head, returning a tensor of shape (..., num_heads).
    return self._internal_similarity(split_tensor_1, split_tensor_2)
def backward(self, grad_output):
    means, = self.saved_tensors
    T = self.R.shape[0]
    dim = means.dim()

    # Add batch axis if necessary
    if dim == 2:
        T_, D = means.shape
        B = 1
        grad_output = grad_output.view(B, T, -1)
    else:
        B, T_, D = means.shape

    grad = torch.matmul(self.R.transpose(0, 1), grad_output)

    reshaped = not (T == T_)
    if not reshaped:
        grad = grad.view(B, self.num_windows, T, -1).transpose(
            1, 2).contiguous().view(B, T, D)

    if dim == 2:
        return grad.view(-1, D)
    return grad
def test_toeplitz_matmul_batch():
    cols = torch.Tensor([
        [1, 6, 4, 5],
        [2, 3, 1, 0],
        [1, 2, 3, 1],
    ])
    rows = torch.Tensor([
        [1, 2, 1, 1],
        [2, 0, 0, 1],
        [1, 5, 1, 0],
    ])

    rhs_mats = torch.randn(3, 4, 2)

    # Actual
    lhs_mats = torch.zeros(3, 4, 4)
    for i, (col, row) in enumerate(zip(cols, rows)):
        lhs_mats[i].copy_(utils.toeplitz.toeplitz(col, row))
    actual = torch.matmul(lhs_mats, rhs_mats)

    # Fast toeplitz
    res = utils.toeplitz.toeplitz_matmul(cols, rows, rhs_mats)
    assert utils.approx_equal(res, actual)
def diag(self):
    batch_size, n_data, n_interp = self.left_interp_indices.size()

    # Batch compute the non-zero values of the outer products w_left^k w_right^k^T
    left_interp_values = self.left_interp_values.unsqueeze(3)
    right_interp_values = self.right_interp_values.unsqueeze(2)
    interp_values = torch.matmul(left_interp_values, right_interp_values)

    # Batch compute Toeplitz values that will be non-zero for row k
    left_interp_indices = self.left_interp_indices.unsqueeze(3).expand(batch_size, n_data, n_interp, n_interp)
    left_interp_indices = left_interp_indices.contiguous()
    right_interp_indices = self.right_interp_indices.unsqueeze(2).expand(batch_size, n_data, n_interp, n_interp)
    right_interp_indices = right_interp_indices.contiguous()

    batch_interp_indices = Variable(left_interp_indices.data.new(batch_size))
    torch.arange(0, batch_size, out=batch_interp_indices.data)
    batch_interp_indices = batch_interp_indices.view(batch_size, 1, 1, 1)
    batch_interp_indices = batch_interp_indices.expand(batch_size, n_data, n_interp, n_interp).contiguous()

    base_var_vals = self.base_lazy_variable._batch_get_indices(batch_interp_indices.view(-1),
                                                               left_interp_indices.view(-1),
                                                               right_interp_indices.view(-1))
    base_var_vals = base_var_vals.view(left_interp_indices.size())
    diag = (interp_values * base_var_vals).sum(3).sum(2).sum(0)
    return diag
def diag(self):
    n_data, n_interp = self.left_interp_indices.size()

    # Batch compute the non-zero values of the outer products w_left^k w_right^k^T
    left_interp_values = self.left_interp_values.unsqueeze(2)
    right_interp_values = self.right_interp_values.unsqueeze(1)
    interp_values = torch.matmul(left_interp_values, right_interp_values)

    # Batch compute Toeplitz values that will be non-zero for row k
    left_interp_indices = self.left_interp_indices.unsqueeze(2).expand(n_data, n_interp, n_interp).contiguous()
    right_interp_indices = self.right_interp_indices.unsqueeze(1).expand(n_data, n_interp, n_interp).contiguous()

    base_var_vals = self.base_lazy_variable._get_indices(left_interp_indices.view(-1),
                                                         right_interp_indices.view(-1))
    base_var_vals = base_var_vals.view(left_interp_indices.size())
    diag = (interp_values * base_var_vals).sum(2).sum(1)
    return diag
def test_matmul_out(self):
    def check_matmul(size1, size2):
        a = torch.randn(size1)
        b = torch.randn(size2)
        expected = torch.matmul(a, b)

        out = torch.Tensor(expected.size()).zero_()
        # make output non-contiguous
        out = out.transpose(-1, -2).contiguous().transpose(-1, -2)
        self.assertFalse(out.is_contiguous())

        torch.matmul(a, b, out=out)
        self.assertEqual(expected, out)

    check_matmul((2, 3, 4), (2, 4, 5))
    check_matmul((2, 3, 4), (4, 5))
def predict(self, x):
    batch_size, dims = x.size()
    query = F.normalize(self.query_proj(x), dim=1)

    # Find the k-nearest neighbors of the query
    scores = torch.matmul(query, torch.t(self.keys_var))
    cosine_similarity, topk_indices_var = torch.topk(scores, self.top_k, dim=1)

    # softmax of cosine similarities - embedding
    softmax_score = F.softmax(self.softmax_temperature * cosine_similarity)

    # retrieve memory values - prediction
    y_hat_indices = topk_indices_var.data[:, 0]
    y_hat = self.values[y_hat_indices]

    return y_hat, softmax_score
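The predict() above scores a query against every memory key with a single matmul; because both sides are L2-normalized, the products are cosine similarities. A minimal standalone sketch of that retrieval step (the sizes and names here are illustrative assumptions, not from the source project):

import torch
import torch.nn.functional as F

keys = F.normalize(torch.randn(128, 64), dim=1)   # 128 memory keys, unit length
query = F.normalize(torch.randn(8, 64), dim=1)    # batch of 8 unit-length queries

scores = torch.matmul(query, keys.t())            # (8, 128) cosine similarities
cosine_similarity, topk_indices = torch.topk(scores, 4, dim=1)
print(topk_indices.shape)                         # torch.Size([8, 4])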
def forward(self, tensor_1: torch.Tensor, tensor_2: torch.Tensor) -> torch.Tensor:
    intermediate = torch.matmul(tensor_1, self._weight_matrix)
    result = (intermediate * tensor_2).sum(dim=-1)
    return self._activation(result + self._bias)
def forward(self, tensor_1: torch.Tensor, tensor_2: torch.Tensor) -> torch.Tensor:
    combined_tensors = util.combine_tensors(self._combination, [tensor_1, tensor_2])
    dot_product = torch.matmul(combined_tensors, self._weight_vector)
    return self._activation(dot_product + self._bias)
def forward(self, means):
    # TODO: remove this
    self.save_for_backward(means)
    T = self.R.shape[0]
    dim = means.dim()

    # Add batch axis if necessary
    if dim == 2:
        T_, D = means.shape
        B = 1
        means = means.view(B, T_, D)
    else:
        B, T_, D = means.shape

    # Check if means has proper shape
    reshaped = not (T == T_)
    if not reshaped:
        static_dim = means.shape[-1] // self.num_windows
        reshaped_means = means.contiguous().view(
            B, T, self.num_windows, -1).transpose(
                1, 2).contiguous().view(B, -1, static_dim)
    else:
        static_dim = means.shape[-1]
        reshaped_means = means

    out = torch.matmul(self.R, reshaped_means)

    if dim == 2:
        return out.view(-1, static_dim)
    return out
def pairwise_ranking_loss(margin, x, v):
    zero = torch.zeros(1)
    diag_margin = margin * torch.eye(x.size(0))
    if not args.no_cuda:
        zero, diag_margin = zero.cuda(), diag_margin.cuda()
    zero, diag_margin = Variable(zero), Variable(diag_margin)

    x = x / torch.norm(x, 2, 1, keepdim=True)
    v = v / torch.norm(v, 2, 1, keepdim=True)
    prod = torch.matmul(x, v.transpose(0, 1))
    diag = torch.diag(prod)
    for_x = torch.max(zero, margin - torch.unsqueeze(diag, 1) + prod) - diag_margin
    for_v = torch.max(zero, margin - torch.unsqueeze(diag, 0) + prod) - diag_margin

    return (torch.sum(for_x) + torch.sum(for_v)) / x.size(0)
def matmul(self, other):
    """Matrix product of two tensors. See :func:`torch.matmul`."""
    return torch.matmul(self, other)
def __matmul__(self, other):
    if not torch.is_tensor(other):
        return NotImplemented
    return self.matmul(other)
def matmul(self, other):
    return torch.matmul(self, other)
def __matmul__(self, other):
    if not isinstance(other, Variable):
        return NotImplemented
    return self.matmul(other)
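The matmul()/__matmul__() pairs above are what back Python's @ operator: when a class defines __matmul__, `a @ b` dispatches to it, and returning NotImplemented lets Python fall back to the other operand. A minimal sketch of the same pattern on a hypothetical wrapper class (MyMat is invented for illustration, not from the source projects):

import torch

class MyMat(object):
    def __init__(self, data):
        self.data = data  # a torch.Tensor

    def matmul(self, other):
        return MyMat(torch.matmul(self.data, other.data))

    def __matmul__(self, other):
        if not isinstance(other, MyMat):
            return NotImplemented  # lets Python try other.__rmatmul__
        return self.matmul(other)

a = MyMat(torch.randn(2, 3))
b = MyMat(torch.randn(3, 4))
print((a @ b).data.shape)  # torch.Size([2, 4])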
def test_toeplitz_matmul():
    col = torch.Tensor([1, 6, 4, 5])
    row = torch.Tensor([1, 2, 1, 1])
    rhs_mat = torch.randn(4, 2)

    # Actual
    lhs_mat = utils.toeplitz.toeplitz(col, row)
    actual = torch.matmul(lhs_mat, rhs_mat)

    # Fast toeplitz
    res = utils.toeplitz.toeplitz_matmul(col, row, rhs_mat)
    assert utils.approx_equal(res, actual)
def test_toeplitz_matmul_batchmat():
    col = torch.Tensor([1, 6, 4, 5])
    row = torch.Tensor([1, 2, 1, 1])
    rhs_mat = torch.randn(3, 4, 2)

    # Actual
    lhs_mat = utils.toeplitz.toeplitz(col, row)
    actual = torch.matmul(lhs_mat.unsqueeze(0), rhs_mat)

    # Fast toeplitz
    res = utils.toeplitz.toeplitz_matmul(col.unsqueeze(0), row.unsqueeze(0), rhs_mat)
    assert utils.approx_equal(res, actual)
def test_left_interp_on_a_vector():
    vector = torch.randn(6)

    res = left_interp(interp_indices, interp_values, Variable(vector)).data
    actual = torch.matmul(interp_matrix, vector)
    assert approx_equal(res, actual)
def test_batch_left_interp_on_a_vector():
    vector = torch.randn(6)

    actual = torch.matmul(batch_interp_matrix, vector.unsqueeze(-1).unsqueeze(0)).squeeze(0)
    res = left_interp(batch_interp_indices, batch_interp_values, Variable(vector)).data
    assert approx_equal(res, actual)
def test_batch_left_interp_on_a_matrix():
    batch_matrix = torch.randn(6, 3)

    res = left_interp(batch_interp_indices, batch_interp_values, Variable(batch_matrix)).data
    actual = torch.matmul(batch_interp_matrix, batch_matrix.unsqueeze(0))
    assert approx_equal(res, actual)
def test_batch_left_interp_on_a_batch_matrix():
    batch_matrix = torch.randn(2, 6, 3)

    res = left_interp(batch_interp_indices, batch_interp_values, Variable(batch_matrix)).data
    actual = torch.matmul(batch_interp_matrix, batch_matrix)
    assert approx_equal(res, actual)
def test_forward_batch():
    i = torch.LongTensor([
        [0, 0, 0, 1, 1, 1],
        [0, 1, 1, 0, 1, 1],
        [2, 0, 2, 2, 0, 2],
    ])
    v = torch.FloatTensor([3, 4, 5, 6, 7, 8])
    sparse = torch.sparse.FloatTensor(i, v, torch.Size([2, 2, 3]))
    dense = Variable(torch.randn(2, 3, 3))

    res = gpytorch.dsmm(Variable(sparse), dense)
    actual = torch.matmul(Variable(sparse.to_dense()), dense)
    assert(torch.norm(res.data - actual.data) < 1e-5)
def test_backward_batch():
    i = torch.LongTensor([
        [0, 0, 0, 1, 1, 1],
        [0, 1, 1, 0, 1, 1],
        [2, 0, 2, 2, 0, 2],
    ])
    v = torch.FloatTensor([3, 4, 5, 6, 7, 8])
    sparse = torch.sparse.FloatTensor(i, v, torch.Size([2, 2, 3]))
    dense = Variable(torch.randn(2, 3, 4), requires_grad=True)
    dense_copy = Variable(dense.data.clone(), requires_grad=True)
    grad_output = torch.randn(2, 2, 4)

    res = gpytorch.dsmm(Variable(sparse), dense)
    res.backward(grad_output)
    actual = torch.matmul(Variable(sparse.to_dense()), dense_copy)
    actual.backward(grad_output)
    assert(torch.norm(dense.grad.data - dense_copy.grad.data) < 1e-5)
def _derivative_quadratic_form_factory(self, lhs, rhs):
    def closure(left_factor, right_factor):
        left_grad = left_factor.transpose(-1, -2).matmul(right_factor.matmul(rhs.transpose(-1, -2)))
        right_grad = lhs.transpose(-1, -2).matmul(left_factor.transpose(-1, -2)).matmul(right_factor)
        return left_grad, right_grad
    return closure
def evaluate(self):
    return torch.matmul(self.lhs, self.rhs)
def _matmul_closure_factory(self, tensor):
    def closure(rhs_tensor):
        return torch.matmul(tensor, rhs_tensor)
    return closure
def diag(self):
    """
    Gets the diagonal of the Kronecker Product matrix wrapped by this object.
    """
    if len(self.J_lefts[0]) != len(self.J_rights[0]):
        raise RuntimeError('diag not supported for non-square interpolated Toeplitz matrices.')

    d, n_data, n_interp = self.J_lefts.size()
    n_grid = len(self.columns[0])

    left_interps_values = self.C_lefts.unsqueeze(3)
    right_interps_values = self.C_rights.unsqueeze(2)
    interps_values = torch.matmul(left_interps_values, right_interps_values)

    left_interps_indices = self.J_lefts.unsqueeze(3).expand(d, n_data, n_interp, n_interp)
    right_interps_indices = self.J_rights.unsqueeze(2).expand(d, n_data, n_interp, n_interp)

    toeplitz_indices = (left_interps_indices - right_interps_indices).fmod(n_grid).abs().long()

    toeplitz_vals = Variable(self.columns.data.new(d, n_data * n_interp * n_interp).zero_())
    mask = self.columns.data.new(d, n_data * n_interp * n_interp).zero_()
    for i in range(d):
        mask[i] += torch.ones(n_data * n_interp * n_interp)
        temp = self.columns.index_select(1, Variable(toeplitz_indices.view(d, -1)[i]))
        toeplitz_vals += Variable(mask) * temp.view(toeplitz_indices.size())
        mask[i] -= torch.ones(n_data * n_interp * n_interp)

    diag = (Variable(interps_values) * toeplitz_vals).sum(3).sum(2)
    diag = diag.prod(0)

    if self.added_diag is not None:
        diag += self.added_diag

    return diag
def forward(self, input_d, input_e, mask_d=None, mask_e=None):
    '''
    Args:
        input_d: Tensor
            the decoder input tensor with shape = [batch, length_decoder, input_size]
        input_e: Tensor
            the child input tensor with shape = [batch, length_encoder, input_size]
        mask_d: Tensor or None
            the mask tensor for decoder with shape = [batch, length_decoder]
        mask_e: Tensor or None
            the mask tensor for encoder with shape = [batch, length_encoder]

    Returns: Tensor
        the energy tensor with shape = [batch, num_label, length, length]
    '''
    assert input_d.size(0) == input_e.size(0), 'batch sizes of encoder and decoder are required to be equal.'
    batch, length_decoder, _ = input_d.size()
    _, length_encoder, _ = input_e.size()

    # compute decoder part: [batch, length_decoder, input_size_decoder] * [input_size_decoder, hidden_size]
    # the output shape is [batch, length_decoder, hidden_size]
    # then --> [batch, 1, length_decoder, hidden_size]
    out_d = torch.matmul(input_d, self.W_d).unsqueeze(1)
    # compute encoder part: [batch, length_encoder, input_size_encoder] * [input_size_encoder, hidden_size]
    # the output shape is [batch, length_encoder, hidden_size]
    # then --> [batch, length_encoder, 1, hidden_size]
    out_e = torch.matmul(input_e, self.W_e).unsqueeze(2)

    # add them together [batch, length_encoder, length_decoder, hidden_size]
    out = F.tanh(out_d + out_e + self.b)

    # product with v
    # [batch, length_encoder, length_decoder, hidden_size] * [hidden, num_label]
    # [batch, length_encoder, length_decoder, num_labels]
    # then --> [batch, num_labels, length_decoder, length_encoder]
    return torch.matmul(out, self.v).transpose(1, 3)
def matmul(self, other):
    r"""Matrix product of two tensors. See :func:`torch.matmul`."""
    return torch.matmul(self, other)
def _forward_age_cls(self, feat):
    '''
    Input:
        feat: CNN feature (ReLUed)
    Output:
        age_out: output age prediction (for evaluation)
        age_fc: final fc layer output (for computing loss)
    '''
    # fc_out = self.age_cls(feat)
    fc_out = self.age_cls(F.relu(feat))

    if self.opts.cls_type == 'dex':
        # Deep EXpectation
        age_scale = np.arange(self.opts.min_age, self.opts.max_age + 1, 1.0)
        age_scale = Variable(fc_out.data.new(age_scale)).unsqueeze(1)
        age_out = torch.matmul(F.softmax(fc_out), age_scale).view(-1)
    elif self.opts.cls_type == 'oh':
        # Ordinal Hyperplane
        fc_out = F.sigmoid(fc_out)
        age_out = fc_out.sum(dim=1) + self.opts.min_age
    elif self.opts.cls_type == 'reg':
        # Regression
        age_out = fc_out.view(-1)
        age_out = age_out + self.opts.min_age

    return age_out, fc_out
def _compute_age(self, feat_relu):
    '''
    input:
        feat: output of feat_embed layer (after relu)
    output:
        age_out
        age_fc_out
    '''
    age_fc_out = self.age_cls(feat_relu)

    if self.opts.cls_type == 'dex':
        # deep expectation
        age_scale = np.arange(self.opts.min_age, self.opts.max_age + 1, 1.0)
        age_scale = Variable(age_fc_out.data.new(age_scale)).unsqueeze(1)
        age_out = torch.matmul(F.softmax(age_fc_out), age_scale).view(-1)
    elif self.opts.cls_type == 'oh':
        # ordinal hyperplane
        age_fc_out = F.sigmoid(age_fc_out)
        age_out = age_fc_out.sum(dim=1) + self.opts.min_age
    elif self.opts.cls_type == 'reg':
        # regression
        age_out = age_fc_out.view(-1) + self.opts.min_age

    return age_out, age_fc_out
def forward(self, encoded_question, question_length, encoded_support, support_length,
            correct_start, answer2question, is_eval):
    # casting
    long_tensor = torch.cuda.LongTensor if encoded_question.is_cuda else torch.LongTensor
    answer2question = answer2question.type(long_tensor)

    # computing single time attention over question
    attention_scores = self._linear_question_attention(encoded_question)
    q_mask = misc.mask_for_lengths(question_length)
    attention_scores = attention_scores.squeeze(2) + q_mask
    question_attention_weights = F.softmax(attention_scores)
    question_state = torch.matmul(question_attention_weights.unsqueeze(1),
                                  encoded_question).squeeze(1)

    # Prediction
    # start
    start_input = torch.cat([question_state.unsqueeze(1) * encoded_support, encoded_support], 2)

    q_start_state = self._linear_q_start(start_input) + self._linear_q_start_q(question_state).unsqueeze(1)
    start_scores = self._linear_start_scores(F.relu(q_start_state)).squeeze(2)

    support_mask = misc.mask_for_lengths(support_length)
    start_scores = start_scores + support_mask
    _, predicted_start_pointer = start_scores.max(1)

    def align(t):
        return torch.index_select(t, 0, answer2question)

    if is_eval:
        start_pointer = predicted_start_pointer
    else:
        # use correct start during training, because p(end|start) should be optimized
        start_pointer = correct_start.type(long_tensor)
        predicted_start_pointer = align(predicted_start_pointer)
        start_scores = align(start_scores)
        start_input = align(start_input)
        encoded_support = align(encoded_support)
        question_state = align(question_state)
        support_mask = align(support_mask)

    # end
    u_s = []
    for b, p in enumerate(start_pointer):
        u_s.append(encoded_support[b, p.data[0]])
    u_s = torch.stack(u_s)

    end_input = torch.cat([encoded_support * u_s.unsqueeze(1), start_input], 2)

    q_end_state = self._linear_q_end(end_input) + self._linear_q_end_q(question_state).unsqueeze(1)
    end_scores = self._linear_end_scores(F.relu(q_end_state)).squeeze(2)
    end_scores = end_scores + support_mask

    max_support = support_length.max().data[0]

    if is_eval:
        end_scores += misc.mask_for_lengths(start_pointer, max_support, mask_right=False)

    _, predicted_end_pointer = end_scores.max(1)

    return start_scores, end_scores, predicted_start_pointer, predicted_end_pointer
def test_functional_blas(self):
    def compare(fn, *args):
        unpacked_args = tuple(arg.data if isinstance(arg, Variable) else arg
                              for arg in args)
        unpacked_result = fn(*unpacked_args)
        packed_result = fn(*args).data
        # if non-Variable torch function returns a scalar, compare to scalar
        if not torch.is_tensor(unpacked_result):
            assert packed_result.dim() == 1
            assert packed_result.nelement() == 1
            packed_result = packed_result[0]
        self.assertEqual(packed_result, unpacked_result)

    def test_blas_add(fn, x, y, z):
        # Checks all signatures
        compare(fn, x, y, z)
        compare(fn, 0.5, x, y, z)
        compare(fn, 0.5, x, 0.25, y, z)

    def test_blas(fn, x, y):
        compare(fn, x, y)

    test_blas(torch.mm, Variable(torch.randn(2, 10)), Variable(torch.randn(10, 4)))
    test_blas_add(torch.addmm, Variable(torch.randn(2, 4)),
                  Variable(torch.randn(2, 10)), Variable(torch.randn(10, 4)))
    test_blas(torch.bmm, Variable(torch.randn(4, 2, 10)), Variable(torch.randn(4, 10, 4)))
    test_blas_add(torch.addbmm, Variable(torch.randn(2, 4)),
                  Variable(torch.randn(4, 2, 10)), Variable(torch.randn(4, 10, 4)))
    test_blas_add(torch.baddbmm, Variable(torch.randn(4, 2, 4)),
                  Variable(torch.randn(4, 2, 10)), Variable(torch.randn(4, 10, 4)))
    test_blas(torch.mv, Variable(torch.randn(2, 10)), Variable(torch.randn(10)))
    test_blas_add(torch.addmv, Variable(torch.randn(2)),
                  Variable(torch.randn(2, 10)), Variable(torch.randn(10)))
    test_blas(torch.ger, Variable(torch.randn(5)), Variable(torch.randn(6)))
    test_blas_add(torch.addr, Variable(torch.randn(5, 6)),
                  Variable(torch.randn(5)), Variable(torch.randn(6)))
    test_blas(torch.matmul, Variable(torch.randn(6)), Variable(torch.randn(6)))
    test_blas(torch.matmul, Variable(torch.randn(10, 4)), Variable(torch.randn(4)))
    test_blas(torch.matmul, Variable(torch.randn(5)), Variable(torch.randn(5, 6)))
    test_blas(torch.matmul, Variable(torch.randn(2, 10)), Variable(torch.randn(10, 4)))
    test_blas(torch.matmul, Variable(torch.randn(5, 2, 10)), Variable(torch.randn(5, 10, 4)))
    test_blas(torch.matmul, Variable(torch.randn(3, 5, 2, 10)), Variable(torch.randn(3, 5, 10, 4)))
    test_blas(torch.matmul, Variable(torch.randn(3, 5, 2, 10)), Variable(torch.randn(10)))
    test_blas(torch.matmul, Variable(torch.randn(10)), Variable(torch.randn(3, 5, 10, 4)))
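The torch.matmul calls at the end of this test enumerate the shape combinations matmul accepts; each one reduces to a more specialized BLAS-style op. A quick sketch of those equivalences (written against modern PyTorch, where torch.allclose is available):

import torch

a, b = torch.randn(6), torch.randn(6)
assert torch.allclose(torch.matmul(a, b), torch.dot(a, b))      # 1D x 1D -> dot

m, v = torch.randn(10, 4), torch.randn(4)
assert torch.allclose(torch.matmul(m, v), torch.mv(m, v))       # 2D x 1D -> mv

x, y = torch.randn(2, 10), torch.randn(10, 4)
assert torch.allclose(torch.matmul(x, y), torch.mm(x, y))       # 2D x 2D -> mm

bx, by = torch.randn(5, 2, 10), torch.randn(5, 10, 4)
assert torch.allclose(torch.matmul(bx, by), torch.bmm(bx, by))  # 3D x 3D -> bmm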
def forward(self, input_d, input_e, mask_d=None, mask_e=None):
    '''
    Args:
        input_d: Tensor
            the decoder input tensor with shape = [batch, length_decoder, input_size]
        input_e: Tensor
            the child input tensor with shape = [batch, length_encoder, input_size]
        mask_d: Tensor or None
            the mask tensor for decoder with shape = [batch, length_decoder]
        mask_e: Tensor or None
            the mask tensor for encoder with shape = [batch, length_encoder]

    Returns: Tensor
        the energy tensor with shape = [batch, num_label, length, length]
    '''
    assert input_d.size(0) == input_e.size(0), 'batch sizes of encoder and decoder are required to be equal.'
    batch, length_decoder, _ = input_d.size()
    _, length_encoder, _ = input_e.size()

    # compute decoder part: [num_label, input_size_decoder] * [batch, input_size_decoder, length_decoder]
    # the output shape is [batch, num_label, length_decoder]
    out_d = torch.matmul(self.W_d, input_d.transpose(1, 2)).unsqueeze(3)
    # compute encoder part: [num_label, input_size_encoder] * [batch, input_size_encoder, length_encoder]
    # the output shape is [batch, num_label, length_encoder]
    out_e = torch.matmul(self.W_e, input_e.transpose(1, 2)).unsqueeze(2)

    # output shape [batch, num_label, length_decoder, length_encoder]
    if self.biaffine:
        # compute bi-affine part
        # [batch, 1, length_decoder, input_size_decoder] * [num_labels, input_size_decoder, input_size_encoder]
        # output shape [batch, num_label, length_decoder, input_size_encoder]
        output = torch.matmul(input_d.unsqueeze(1), self.U)
        # [batch, num_label, length_decoder, input_size_encoder] * [batch, 1, input_size_encoder, length_encoder]
        # output shape [batch, num_label, length_decoder, length_encoder]
        output = torch.matmul(output, input_e.unsqueeze(1).transpose(2, 3))
        output = output + out_d + out_e + self.b
    else:
        output = out_d + out_e + self.b

    if mask_d is not None:
        output = output * mask_d.unsqueeze(1).unsqueeze(3) * mask_e.unsqueeze(1).unsqueeze(2)

    return output
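The biaffine branch above leans on matmul's batch broadcasting: a [batch, 1, length_decoder, d] input against a [num_labels, d, d] weight broadcasts across the label dimension. A standalone shape check of that double matmul (all dimensions here are made up for illustration):

import torch

batch, num_labels = 2, 5
len_d, len_e, d_dec, d_enc = 7, 9, 16, 16

input_d = torch.randn(batch, len_d, d_dec)
input_e = torch.randn(batch, len_e, d_enc)
U = torch.randn(num_labels, d_dec, d_enc)

out = torch.matmul(input_d.unsqueeze(1), U)                    # (batch, num_labels, len_d, d_enc)
out = torch.matmul(out, input_e.unsqueeze(1).transpose(2, 3))  # (batch, num_labels, len_d, len_e)
print(out.shape)  # torch.Size([2, 5, 7, 9])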
def query(self, x, y, predict=False):
    """
    Compute the nearest neighbor of the input queries.

    Arguments:
        x: A normalized matrix of queries of size (batch_size x key_dim)
        y: A matrix of correct labels (batch_size x 1)
    Returns:
        y_hat, A (batch_size x 1) matrix - the nearest neighbor to the query in memory_size
        softmax_score, A (batch_size x 1) matrix - A normalized score measuring the similarity
            between query and nearest neighbor
        loss - average loss for memory module
    """
    batch_size, dims = x.size()
    query = F.normalize(self.query_proj(x), dim=1)
    #query = F.normalize(torch.matmul(x, self.query_proj), dim=1)

    # Find the k-nearest neighbors of the query
    scores = torch.matmul(query, torch.t(self.keys_var))
    cosine_similarity, topk_indices_var = torch.topk(scores, self.top_k, dim=1)

    # softmax of cosine similarities - embedding
    softmax_score = F.softmax(self.softmax_temperature * cosine_similarity)

    # retrieve memory values - prediction
    topk_indices = topk_indices_var.detach().data
    y_hat_indices = topk_indices[:, 0]
    y_hat = self.values[y_hat_indices]

    loss = None
    if not predict:
        # Loss Function
        # topk_indices = (batch_size x topk)
        # topk_values =  (batch_size x topk x value_size)

        # collect the memory values corresponding to the topk scores
        batch_size, topk_size = topk_indices.size()
        flat_topk = flatten(topk_indices)
        flat_topk_values = self.values[flat_topk]
        topk_values = flat_topk_values.resize_(batch_size, topk_size)

        correct_mask = torch.eq(topk_values, torch.unsqueeze(y.data, dim=1)).float()
        correct_mask_var = ag.Variable(correct_mask, requires_grad=False)

        pos_score, pos_idx = torch.topk(torch.mul(cosine_similarity, correct_mask_var), 1, dim=1)
        neg_score, neg_idx = torch.topk(torch.mul(cosine_similarity, 1 - correct_mask_var), 1, dim=1)

        # zero-out correct scores if there are no correct values in topk values
        mask = 1.0 - torch.eq(torch.sum(correct_mask_var, dim=1), 0.0).float()
        pos_score = torch.mul(pos_score, torch.unsqueeze(mask, dim=1))

        #print(pos_score, neg_score)
        loss = MemoryLoss(pos_score, neg_score, self.margin)

    # Update memory
    self.update(query, y, y_hat, y_hat_indices)

    return y_hat, softmax_score, loss
def _age_gradient(self, feat_age):
    '''
    compute age branch gradient direction in age_embed layer

    input:
        feat_age: output of age_embed layer (before relu)
    '''
    cls = self.age_cls

    feat = feat_age.detach()
    feat.requires_grad = True
    feat.volatile = False
    feat = feat.clone()
    feat.retain_grad()

    age_fc_out = cls.cls(cls.relu(feat))

    if self.opts.cls_type == 'dex':
        # deep expectation
        age_scale = np.arange(self.opts.min_age, self.opts.max_age + 1, 1.0)
        age_scale = Variable(age_fc_out.data.new(age_scale)).unsqueeze(1)
        age_out = torch.matmul(F.softmax(age_fc_out), age_scale).view(-1)
    elif self.opts.cls_type == 'oh':
        # ordinal hyperplane
        age_fc_out = F.sigmoid(age_fc_out)
        age_out = age_fc_out.sum(dim=1) + self.opts.min_age
    elif self.opts.cls_type == 'reg':
        # regression
        age_out = age_fc_out.view(-1) + self.opts.min_age

    age_out.sum().backward()
    age_grad = feat.grad

    # normalization
    age_grad = age_grad / age_grad.norm(p=2, dim=1, keepdim=True)
    age_grad.detach_()
    age_grad.volatile = False
    age_grad.requires_grad = False

    cls.cls.zero_grad()

    return age_grad
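The 'dex' branches in the last three age models all use the same Deep EXpectation trick: softmax the classifier logits over age bins, then take the expected age as a matmul with the bin values. A minimal sketch with plain tensors (the age range is an arbitrary assumption):

import torch
import torch.nn.functional as F

min_age, max_age = 0, 100
fc_out = torch.randn(4, max_age - min_age + 1)  # logits over 101 age bins for a batch of 4

age_scale = torch.arange(min_age, max_age + 1, dtype=torch.float).unsqueeze(1)  # (101, 1)
age_out = torch.matmul(F.softmax(fc_out, dim=1), age_scale).view(-1)            # expected ages, shape (4,)
print(age_out.shape)  # torch.Size([4])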