The following 50 code examples, extracted from open-source Python projects, illustrate how to use torch.autograd.Function().
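
Before the project-specific examples, here is a minimal sketch of the pattern most of them follow: subclass torch.autograd.Function, implement forward() and backward() as static methods that share state through the ctx object, and invoke the function via .apply(). The class name SquareFn and its computation are illustrative assumptions, not taken from any project below. Note that many of the examples target older PyTorch releases, so they wrap tensors in Variable and call legacy (non-static) Functions by instantiating them; current releases use plain tensors and require new-style Functions to be called through .apply().

import torch
from torch.autograd import Function

# Illustrative sketch only: SquareFn is a made-up example, not from the projects below.
class SquareFn(Function):
    @staticmethod
    def forward(ctx, input):
        # Save tensors needed by backward() on the ctx object.
        ctx.save_for_backward(input)
        return input * input

    @staticmethod
    def backward(ctx, grad_output):
        input, = ctx.saved_tensors
        # d(input**2)/d(input) = 2 * input, chained with the incoming gradient.
        return 2 * input * grad_output

# New-style Functions are invoked through .apply(), never by instantiating the class.
x = torch.randn(5, 5, requires_grad=True)
y = SquareFn.apply(x).sum()
y.backward()
print(torch.allclose(x.grad, 2 * x))  # True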

def test_save_none_for_backward(self):
    test_case = self

    class MyFn(Function):
        def forward(self, input):
            self.save_for_backward(None, input, None)
            return input * input

        def backward(self, grad_output):
            n1, input, n2 = self.saved_tensors
            test_case.assertIsNone(n1)
            test_case.assertIsNone(n2)
            return 2 * input * grad_output

    x = Variable(torch.randn(5, 5), requires_grad=True)
    y = MyFn()(x)
    y.sum().backward()
    self.assertEqual(x.grad.data, 2 * x.data)

def test_mark_non_differentiable(self):
    class MyFunction(Function):
        @staticmethod
        def forward(ctx, input):
            output = input > 0
            ctx.mark_non_differentiable(output)
            return output

        @staticmethod
        def backward(ctx, grad_output):
            return (grad_output * 0).type(torch.DoubleTensor)

    x = Variable(torch.randn(5, 5), requires_grad=True)
    mask = MyFunction.apply(x)
    self.assertFalse(mask.requires_grad)
    y = x.masked_fill(mask, 0)
    y.sum().backward()

def test_assign_traces(self):
    """Check that output Variables are assigned traces before they are saved."""
    @traceable
    class MyFn(Function):
        @staticmethod
        def forward(ctx, a):
            out = a * 2
            ctx.save_for_backward(out)
            return out

        @staticmethod
        def backward(ctx, grad_a):
            a, = ctx.saved_variables
            return a * grad_a

    x = Variable(torch.randn(10, 10), requires_grad=True)
    trace, out = torch.jit.trace(MyFn.apply, x, nderivs=1)
    out.sum().backward()
    torch._C._jit_pass_dce(trace)
    self.assertExpected(str(trace))

def test_backward_device(self):
    # check that current device matches the variable's device
    device = [None]

    class Identity(torch.autograd.Function):
        @staticmethod
        def forward(ctx, x):
            return x.clone()

        @staticmethod
        def backward(ctx, grad_output):
            device[0] = torch.cuda.current_device()
            return grad_output.clone()

    v = Variable(torch.randn(1).cuda(1), requires_grad=True)
    Identity.apply(v).backward()
    self.assertEqual(device[0], 1)

def test_reentrant(self):
    y_data = torch.randn(2, 2)

    class Reenter(Function):
        @staticmethod
        def forward(ctx, x_data):
            ctx.x = Variable(x_data, requires_grad=True)
            ctx.y = Variable(y_data, requires_grad=True)
            ctx.output_var = ctx.x * ctx.y
            return ctx.output_var.data

        @staticmethod
        def backward(ctx, grad_output):
            ctx.output_var.sum().backward()
            return ctx.x.grad * grad_output

    x = Variable(torch.randn(2, 2), requires_grad=True)
    out = Reenter.apply(x)
    out.sum().backward()
    self.assertEqual(x.grad.data, y_data)

def test_symbolic_mismatch(self):
    class MyFun(Function):
        @staticmethod
        def symbolic(g, x):
            # The inside of this function should never be invoked, because
            # we will fail due to an argument mismatch first.
            assert False

        @staticmethod
        def forward(ctx, x, y):
            return x + y

    x = Variable(torch.randn(2, 2).fill_(1.0))
    y = Variable(torch.randn(2, 2).fill_(1.0))
    with self.assertRaisesRegex(TypeError, "occurred when translating MyFun"):
        export_to_string(FuncModule(MyFun().apply), (x, y))

# TODO: Do an nn style test for these

def test_assign_traces(self):
    """Check that output Variables are assigned traces before they are saved."""
    @traceable
    class MyFn(Function):
        @staticmethod
        def forward(ctx, a):
            out = a * 2
            ctx.save_for_backward(out)
            return out

        @staticmethod
        def backward(ctx, grad_a):
            a, = ctx.saved_variables
            return a * grad_a

    x = Variable(torch.randn(10, 10), requires_grad=True)
    trace, out = torch.jit.trace(MyFn.apply, x, nderivs=1)
    out.sum().backward()
    torch._C._jit_pass_dce(trace)
    self.assertExpectedTrace(trace)

def test_inplace_check(self):
    class MyInplaceFn(Function):
        @staticmethod
        def forward(self, x):
            x.add_(1)
            self.mark_dirty(x)
            return x

        @staticmethod
        def backward(self, grad):
            return grad

    @torch.jit.compile(nderivs=0)
    def fn(x):
        return MyInplaceFn.apply(x)

    x = Variable(torch.randn(5, 5))
    fn(x)  # trace
    with self.assertRaisesRegex(RuntimeError, 'inplace MyInplaceFn'):
        fn(x)

def test_function_returns_input(self):
    class MyFunction(Function):
        @staticmethod
        def forward(ctx, x):
            return x

        @staticmethod
        def backward(ctx, grad):
            return grad * 2

    v = Variable(torch.ones(1), requires_grad=True)
    MyFunction.apply(v).backward()
    self.assertEqual(v.grad.data.tolist(), [2])

    v.grad.data.zero_()
    MyFunction.apply(v.clone()).backward()
    self.assertEqual(v.grad.data.tolist(), [2])

def test_mark_non_differentiable_mixed(self):
    class MyFunction(Function):
        @staticmethod
        def forward(ctx, input):
            a = input + 1
            b = input + 2
            ctx.mark_non_differentiable(a)
            return a, b

        @staticmethod
        def backward(ctx, grad_a, grad_b):
            self.assertTrue((grad_a == 0).all())
            self.assertTrue((grad_b == 1).all())
            return grad_b

    x = Variable(torch.randn(5, 5), requires_grad=True)
    a, b = MyFunction.apply(x)
    self.assertFalse(a.requires_grad)
    self.assertTrue(b.requires_grad)
    b.sum().backward()
    self.assertEqual(x.grad.data, torch.ones(5, 5))

def test_mark_non_differentiable_none(self):
    # This used to segfault because MyFunction would send back null
    # gradients to MulBackward, which is implemented in C++. C++
    # implemented functions expect incoming grad_outputs to be non-null.
    class MyFunction(Function):
        @staticmethod
        def forward(ctx, input):
            output = input.clone()
            ctx.mark_non_differentiable(output)
            return output

        @staticmethod
        def backward(ctx, grad_output):
            return None

    x = Variable(torch.randn(5, 5), requires_grad=True)
    r = MyFunction.apply(x * x)
    (r * x).sum().backward()

def test_reentrant(self):
    y_data = torch.randn(2, 2)

    class Reenter(Function):
        @staticmethod
        def forward(ctx, x_data):
            ctx.x = Variable(x_data, requires_grad=True)
            ctx.y = Variable(y_data, requires_grad=True)
            ctx.output_var = ctx.x * ctx.y
            return ctx.output_var.data

        @staticmethod
        def backward(ctx, grad_output):
            ctx.output_var.sum().backward()
            return ctx.x.grad * grad_output

    x = Variable(torch.randn(2, 2), requires_grad=True)
    out = Reenter.apply(x)
    out.sum().backward(create_graph=True)
    self.assertEqual(x.grad.data, y_data)

def test_inplace_view_python(self):
    # in-place modifications of Python-autograd created view
    a = Variable(torch.randn(4, 4), requires_grad=True)
    b = Variable(torch.randn(2, 2), requires_grad=True)

    class PyAdd(torch.autograd.Function):
        @staticmethod
        def forward(ctx, x, y):
            ctx.mark_dirty(x)
            x.add_(y)
            return x

        @staticmethod
        def backward(ctx, grad):
            return grad, grad

    def func(root, b):
        x = root.clone()
        PyAdd.apply(x.narrow(1, 2, 2).narrow(0, 1, 2), b)
        PyAdd.apply(x.narrow(1, 0, 2).narrow(0, 1, 2), b)
        return x

    gradcheck(func, [a, b], raise_exception=True)
    go = Variable(torch.randn(a.size()), requires_grad=True)
    gradgradcheck(func, (a, b), (go,))

def test_function(self):
    class MyFunction(Function):
        @staticmethod
        def forward(ctx, tensor1, scalar, tensor2):
            ctx.scalar = scalar
            ctx.save_for_backward(tensor1, tensor2)
            return tensor1 + scalar * tensor2 + tensor1 * tensor2

        @staticmethod
        def backward(ctx, grad_output):
            var1, var2 = ctx.saved_variables
            # NOTE: self is the test case here
            self.assertIsInstance(var1, Variable)
            self.assertIsInstance(var2, Variable)
            self.assertIsInstance(grad_output, Variable)
            return (grad_output + grad_output * var2, None,
                    grad_output * ctx.scalar + grad_output * var1)

    x, y = self._function_test(MyFunction)

    x_grad_desc = graph_desc(x.grad.grad_fn)
    y_grad_desc = graph_desc(y.grad.grad_fn)
    self.assertEqual(
        x_grad_desc,
        'Identity(AddBackward(ExpandBackward(AccumulateGrad()), '
        'MulBackward(ExpandBackward(AccumulateGrad()), AccumulateGrad())))')
    self.assertEqual(
        y_grad_desc,
        'Identity(AddBackward(MulConstantBackward(ExpandBackward(AccumulateGrad())), '
        'MulBackward(ExpandBackward(AccumulateGrad()), AccumulateGrad())))')

def test_once_differentiable(self):
    class MyFunction(Function):
        @staticmethod
        def forward(ctx, tensor1, scalar, tensor2):
            ctx.scalar = scalar
            ctx.save_for_backward(tensor1, tensor2)
            return tensor1 + scalar * tensor2 + tensor1 * tensor2

        @staticmethod
        @once_differentiable
        def backward(ctx, grad_output):
            t1, t2 = ctx.saved_tensors
            # NOTE: self is the test case here
            self.assertTrue(torch.is_tensor(t1))
            self.assertTrue(torch.is_tensor(t2))
            self.assertTrue(torch.is_tensor(grad_output))
            return (grad_output + grad_output * t2, None,
                    grad_output * ctx.scalar + grad_output * t1)

    x, y = self._function_test(MyFunction)

    x_grad_desc = graph_desc(x.grad.grad_fn)
    y_grad_desc = graph_desc(y.grad.grad_fn)
    self.assertEqual(graph_desc(x.grad.grad_fn),
                     'Identity(Error(AccumulateGrad(), None, AccumulateGrad()))')
    self.assertEqual(graph_desc(y.grad.grad_fn),
                     'Identity(Error(AccumulateGrad(), None, AccumulateGrad()))')

def test_hook_none(self):
    # WARNING: this is a test for autograd internals.
    # You should never have to use such things in your code.
    class NoneGradientFunction(Function):
        def forward(self, x, y):
            assert self.needs_input_grad[0]
            assert not self.needs_input_grad[1]
            return x, y

        def backward(self, grad_x, grad_y):
            return grad_x, None

    fn = NoneGradientFunction()
    was_called = [False]

    def hook(grad_input, grad_output):
        self.assertIsInstance(grad_input, tuple)
        self.assertIsInstance(grad_output, tuple)
        self.assertIsNotNone(grad_input[0])
        self.assertIsNone(grad_input[1])
        self.assertIsNotNone(grad_output[0])
        self.assertIsNotNone(grad_output[1])
        was_called[0] = True

    fn.register_hook(hook)
    x = Variable(torch.randn(5, 5), requires_grad=True)
    y = Variable(torch.randn(5, 5))
    sum(fn(x, y)).sum().backward()
    self.assertTrue(was_called[0])

def test_gc_in_destructor(self):
    """
    Previously, if a Function destructor triggered a garbage collection,
    the Variable's tp_dealloc handler would get called twice leading to
    a segfault.
    """
    class CollectOnDelete(Function):
        def __del__(self):
            gc.collect()

    for i in range(10):
        Variable(torch.randn(10, 10), _grad_fn=CollectOnDelete())

def test_too_many_grads(self):
    class MyFn(Function):
        def forward(self, input):
            return input

        def backward(self, grad_output):
            return grad_output, None, None

    x = Variable(torch.randn(5, 5), requires_grad=True)
    y = MyFn()(x)
    y.sum().backward()
    self.assertEqual(x.grad.data, x.data.clone().fill_(1))

def test_dep_nograd(self):
    class F1(Function):
        def forward(self, input):
            out = torch.randn(input.size())
            self.mark_non_differentiable(out)
            return input, out

        def backward(self, grad_output, ignored):
            return grad_output

    class F2(Function):
        def forward(self, input, ignored):
            return input

        def backward(self, grad_output):
            return grad_output, None

    x = Variable(torch.randn(5), requires_grad=True)
    a, b = F1()(x)
    b = b + 1  # separate F1 from F2 by another op
    self.assertTrue(a.requires_grad)
    self.assertFalse(b.requires_grad)
    c = F2()(a, b)
    c.backward(torch.ones(c.size()))
    self.assertEqual(x.grad.data, torch.ones(x.size()))

def test_return_leaf(self):
    class Identity(Function):
        def forward(self, a, b):
            return a, a + b

        def backward(self, grad_a, grad_b):
            return grad_a + grad_b, grad_b

    class Inplace(InplaceFunction):
        def forward(self, a, b):
            self.mark_dirty(a)
            return a.add_(b), b + 2

        def backward(self, grad_a, grad_b):
            return grad_a, grad_a + grad_b

    x = Variable(torch.randn(5, 5), requires_grad=True)
    y = Variable(torch.randn(5, 5), requires_grad=True)

    q, p = Identity()(x, y)
    # Make sure hooks only receive grad from usage of q, not x.
    q.register_hook(
        lambda grad: self.assertEqual(grad.data, torch.ones(5, 5)))
    (q + p + x).sum().backward()
    self.assertEqual(x.grad.data, torch.ones(5, 5) * 3)
    self.assertEqual(y.grad.data, torch.ones(5, 5))
    del q, p  # these need to be freed, or next part will raise an error

def test_legacy_fail(self):
    class MyLegacyFn(Function):
        def forward(self, x):
            return x

        def backward(self, grad_output):
            return grad_output

    x = Variable(torch.Tensor([0]), requires_grad=True)
    trace = torch._C._tracer_enter((x,), 0)
    self.assertRaisesRegex(RuntimeError, "MyLegacyFn", lambda: MyLegacyFn()(x))
    torch._C._tracer_exit((x,))

def test_once_differentiable(self):
    class MyFunction(Function):
        @staticmethod
        def forward(ctx, tensor1, scalar, tensor2):
            ctx.scalar = scalar
            ctx.save_for_backward(tensor1, tensor2)
            return tensor1 + scalar * tensor2 + tensor1 * tensor2

        @staticmethod
        @once_differentiable
        def backward(ctx, grad_output):
            t1, t2 = ctx.saved_tensors
            # NOTE: self is the test case here
            self.assertTrue(torch.is_tensor(t1))
            self.assertTrue(torch.is_tensor(t2))
            self.assertTrue(torch.is_tensor(grad_output))
            return (grad_output + grad_output * t2, None,
                    grad_output * ctx.scalar + grad_output * t1)

    x, y = self._function_test(MyFunction)
    self.assertEqual(graph_desc(x.grad.grad_fn),
                     'Identity(Error(AccumulateGrad(), None, AccumulateGrad()))')
    self.assertEqual(graph_desc(y.grad.grad_fn),
                     'Identity(Error(AccumulateGrad(), None, AccumulateGrad()))')

def test_save_output_nr(self):
    x = Variable(torch.randn(10), requires_grad=True)

    class MultiOutputFn(Function):
        @staticmethod
        def forward(ctx, x):
            return x[:5], x[5:]

        @staticmethod
        def backward(ctx, *grad):
            return torch.cat(grad)

    a, b = MultiOutputFn.apply(x)
    self.assertEqual(b.output_nr, 1)

    class TestFn(Function):
        @staticmethod
        def forward(ctx, b):
            ctx.save_for_backward(b)
            return b * 2

        @staticmethod
        def backward(ctx, grad_b):
            b, = ctx.saved_variables
            self.assertEqual(b.output_nr, 1)

    TestFn.apply(b).sum().backward()

def inv_matmul_factory(matmul_closure_factory=_default_matmul_closure_factor,
                       derivative_quadratic_form_factory=_default_derivative_quadratic_form_factory):
    class InvMatmul(Function):
        def __init__(self, *args):
            self.args = args

        def forward(self, *args):
            closure_args = self.args + args[:-1]
            rhs = args[-1]
            res = LinearCG().solve(matmul_closure_factory(*closure_args), rhs)
            self.save_for_backward(*(list(args) + [res]))
            return res

        def backward(self, grad_output):
            if derivative_quadratic_form_factory is None:
                raise NotImplementedError
            args = self.saved_tensors[:-2]
            closure_args = self.args + args
            res = self.saved_tensors[-1]

            arg_grads = [None] * len(args)
            rhs_grad = None

            # input_1 gradient
            if any(self.needs_input_grad[:-1]):
                lhs_matrix_grad = LinearCG().solve(matmul_closure_factory(*closure_args), grad_output)
                lhs_matrix_grad = lhs_matrix_grad.mul_(-1)
                if res.ndimension() == 1:
                    res = res.unsqueeze(1)
                if lhs_matrix_grad.ndimension() == 1:
                    lhs_matrix_grad = lhs_matrix_grad.unsqueeze(1)
                arg_grads = list(derivative_quadratic_form_factory(*args)(lhs_matrix_grad.t(), res.t()))

            # input_2 gradient
            if self.needs_input_grad[-1]:
                rhs_grad = LinearCG().solve(matmul_closure_factory(*closure_args), grad_output)

            return tuple(arg_grads + [rhs_grad])

    return InvMatmul

def test_inplace_flags(self):
    class InplaceFn(Function):
        @staticmethod
        def forward(ctx, x):
            ctx.mark_dirty(x)
            return x.add_(1)

        @staticmethod
        def backward(ctx, go):
            return go

    class RegularFn(Function):
        @staticmethod
        def forward(ctx, x):
            return x.add(1)

        @staticmethod
        def backward(ctx, go):
            return go

    x = Variable(torch.Tensor([0]), requires_grad=True)
    trace = torch._C._tracer_enter((x,), 0)
    y = RegularFn.apply(x)
    y = InplaceFn.apply(y)
    y = InplaceFn.apply(y)
    y = RegularFn.apply(y)
    torch._C._tracer_exit((y,))
    ops = [n for n in trace.graph().nodes()]
    for op in ops:
        self.assertTrue(op.hasAttribute('inplace'))
    inplace_flags = [False, True, True, False]
    for op, is_inplace in zip(ops, inplace_flags):
        self.assertEqual(op.i('inplace'), is_inplace)

def test_function(self):
    class MyFunction(Function):
        @staticmethod
        def forward(ctx, tensor1, scalar, tensor2):
            ctx.scalar = scalar
            ctx.save_for_backward(tensor1, tensor2)
            return tensor1 + scalar * tensor2 + tensor1 * tensor2

        @staticmethod
        def backward(ctx, grad_output):
            var1, var2 = ctx.saved_variables
            # NOTE: self is the test case here
            self.assertIsInstance(var1, Variable)
            self.assertIsInstance(var2, Variable)
            self.assertIsInstance(grad_output, Variable)
            return (grad_output + grad_output * var2, None,
                    grad_output * ctx.scalar + grad_output * var1)

    x, y = self._function_test(MyFunction)

    x_grad_desc = graph_desc(x.grad.grad_fn)
    y_grad_desc = graph_desc(y.grad.grad_fn)
    self.assertEqual(
        x_grad_desc,
        'CloneBackward(AddBackward1(ExpandBackward(AccumulateGrad()), '
        'MulBackward1(ExpandBackward(AccumulateGrad()), AccumulateGrad())))')
    self.assertEqual(
        y_grad_desc,
        'CloneBackward(AddBackward1(MulBackward0(ExpandBackward(AccumulateGrad())), '
        'MulBackward1(ExpandBackward(AccumulateGrad()), AccumulateGrad())))')

def test_legacy_function_none_grad(self):
    class MyFunction(Function):
        def forward(self, x):
            return torch.zeros(2, 2, 2)

        def backward(self, grad_output):
            return None

    shape = (2, 3)
    v = Variable(torch.ones(shape), requires_grad=True)
    y = v[0, 0].expand(3, 5).t().sum()
    MyFunction()(y).sum().backward()
    self.assertEqual(v.grad.data, torch.zeros(shape))

def test_hook_none(self):
    # WARNING: this is a test for autograd internals.
    # You should never have to use such things in your code.
    class NoneGradientFunction(Function):
        def forward(self, x, y):
            assert self.needs_input_grad[0]
            assert not self.needs_input_grad[1]
            return x, y

        def backward(self, grad_x, grad_y):
            return grad_x, None

    fn = NoneGradientFunction()
    was_called = [False]

    def hook(grad_input, grad_output):
        self.assertIsInstance(grad_input, tuple)
        self.assertIsInstance(grad_output, tuple)
        self.assertIsNotNone(grad_input[0])
        self.assertIsNotNone(grad_input[1])
        self.assertIsNotNone(grad_output[0])
        self.assertIsNotNone(grad_output[1])
        was_called[0] = True

    fn.register_hook(hook)
    x = Variable(torch.randn(5, 5), requires_grad=True)
    y = Variable(torch.randn(5, 5))
    sum(fn(x, y)).sum().backward()
    self.assertTrue(was_called[0])

def test_sparse_backward(self):
    class FixedGradientFunction(Function):
        def __init__(self, grad):
            self.grad = grad

        def forward(self, x):
            return x

        def backward(self, grad_x):
            return self.grad

    size = torch.Size([6, 3, 2])
    i1 = torch.LongTensor([
        [0, 3, 4],
        [0, 2, 2],
    ])
    v1 = torch.DoubleTensor([[1, 2], [4, 5], [7, 8]])
    sparse_grad1 = torch.sparse.DoubleTensor(i1, v1, size)
    i2 = torch.LongTensor([
        [0, 1, 3, 4],
        [0, 1, 2, 2],
    ])
    v2 = torch.DoubleTensor([[1, 2], [4, 3], [4, 5], [7, 8]])
    sparse_grad2 = torch.sparse.DoubleTensor(i2, v2, size)

    dense_grad = torch.rand(size).double()
    sparse_fn1 = FixedGradientFunction(sparse_grad1)
    sparse_fn2 = FixedGradientFunction(sparse_grad2)
    dense_fn = FixedGradientFunction(dense_grad)

    # sparse first
    x = Variable(torch.randn(5, 5), requires_grad=True)
    (sparse_fn1(x) + dense_fn(x) + sparse_fn2(x)).sum().backward()
    self.assertEqual(x.grad.data, dense_grad + sparse_grad1 + sparse_grad2)
    # dense first
    x = Variable(torch.randn(5, 5), requires_grad=True)
    (dense_fn(x) + sparse_fn1(x) + sparse_fn2(x)).sum().backward()
    self.assertEqual(x.grad.data, dense_grad + sparse_grad1 + sparse_grad2)
    # sparse only
    x = Variable(torch.randn(5, 5), requires_grad=True)
    (sparse_fn1(x) + sparse_fn2(x)).sum().backward()
    self.assertEqual(x.grad.data, sparse_grad1 + sparse_grad2)