我们从Python开源项目中，提取了以下11个代码示例，用于说明如何使用torch.cross()。
def test_cross(self):
    """Check that ``torch.cross`` fills ``out`` with the values it returns.

    Two random ``(100, 3, 100)`` tensors are crossed along ``dim=1`` (the
    size-3 axis): once returning a fresh tensor, once writing into a
    pre-created result tensor. The two results are then compared with
    ``self.assertEqual``.
    """
    x = torch.rand(100, 3, 100)
    y = torch.rand(100, 3, 100)
    # Name the size-3 axis explicitly: relying on the implicit "first
    # dimension of size 3" lookup is deprecated.
    res1 = torch.cross(x, y, dim=1)
    res2 = torch.Tensor()
    # The destination has to be passed through ``out=``; the legacy
    # ``torch.cross(res2, x, y)`` form treats ``res2`` as an operand.
    torch.cross(x, y, dim=1, out=res2)
    self.assertEqual(res1, res2)
def forward(self, input, other):
    """Return the cross product of ``input`` and ``other`` along ``self.dim``.

    Both operands are handed to ``self.save_for_backward`` before the
    product is computed.
    """
    self.save_for_backward(input, other)
    axis = self.dim
    product = input.cross(other, axis)
    return product
def backward(self, grad_output):
    """Compute the gradients of a cross product for both operands.

    The two tensors in ``self.saved_tensors`` are unpacked as
    ``(input, other)``.

    Returns:
        A ``(grad_input, grad_other)`` tuple where ``grad_input`` is
        ``other x grad_output`` and ``grad_other`` is
        ``grad_output x input``, both taken along ``self.dim``.
    """
    lhs, rhs = self.saved_tensors
    axis = self.dim
    return (
        torch.cross(rhs, grad_output, axis),
        torch.cross(grad_output, lhs, axis),
    )
def test_cross(self):
    """Verify the ``out=`` variant of ``torch.cross`` matches its return value."""
    shape = (100, 3, 100)
    lhs = torch.rand(*shape)
    rhs = torch.rand(*shape)
    expected = torch.cross(lhs, rhs)
    # Empty destination tensor that the second call fills in place.
    filled = torch.Tensor()
    torch.cross(lhs, rhs, out=filled)
    self.assertEqual(expected, filled)
def forward(ctx, input, other, dim=-1):
    """Cross product of ``input`` and ``other`` along ``dim`` (default ``-1``).

    Records ``dim`` on ``ctx`` and passes both operands to
    ``ctx.save_for_backward`` before computing the result.
    """
    ctx.dim = dim
    ctx.save_for_backward(input, other)
    result = torch.cross(input, other, dim=ctx.dim)
    return result
def backward(ctx, grad_output):
    """Compute the gradients of a cross product for both operands.

    The two tensors saved on ``ctx`` are unpacked as ``(input, other)``.

    Returns:
        A ``(grad_input, grad_other, None)`` tuple where ``grad_input`` is
        ``other x grad_output`` and ``grad_other`` is
        ``grad_output x input``, both taken along ``ctx.dim``. The third
        element is always ``None`` (presumably the slot for a non-tensor
        ``dim`` argument of the matching forward; confirm against the
        enclosing Function).
    """
    # ``saved_tensors`` is the supported accessor; ``saved_variables`` is
    # deprecated. The local name also avoids shadowing the ``input`` builtin.
    saved_input, saved_other = ctx.saved_tensors
    grad_input = saved_other.cross(grad_output, ctx.dim)
    grad_other = grad_output.cross(saved_input, ctx.dim)
    return grad_input, grad_other, None