# Owner(s): ["oncall: jit"]

from typing import List

import torch
from torch.testing._internal.jit_utils import JitTestCase


class TestAutodiffJit(JitTestCase):
    def test_undefined_tensor_lists(self):
        def fn(tensor_list: List[torch.Tensor], add_tensor):
            cat = torch.cat(tensor_list, dim=1)
            r = torch.sin(cat + add_tensor)
            return r

        fn_s = torch.jit.script(fn)

        a = torch.rand((3, 6), requires_grad=True)
        b = torch.rand((3, 10), requires_grad=True)
        x = [a, b]
        y = torch.rand((3, 16), requires_grad=True)

        # run a few times so the profiling executor records shapes and
        # builds the differentiable (autodiff) graph before we poke at it
        ret = fn_s(x, y)
        ret.sum().backward()
        ret = fn_s(x, y)
        ret.sum().backward()
        ret = fn_s(x, y)
        s = ret.sum()

        # backward_fn expects 2 inputs: (grad_output, current_grad_r)
        # current_grad_r is provided because we need to add this contribution
        # to grad_r when we return it.
        backward_fn = s.grad_fn.next_functions[0][0]

        # check behavior with defined tensor
        grad_out = torch.rand((3, 16))
        grad_inputs = backward_fn(grad_out, None)

        # expect 3 tensors: grad_y, grad_a, grad_b
        self.assertEqual(3, len(grad_inputs))
        for x in grad_inputs:
            self.assertTrue(isinstance(x, torch.Tensor))

        # now test with undefined grad_out
        grad_inputs = backward_fn(None, None)

        # expect 3 results; each gradient should be undefined (None) or all zeros
        self.assertEqual(3, len(grad_inputs))
        for x in grad_inputs:
            if x is not None:
                self.assertEqual(0, torch.max(torch.abs(x)).item())
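
        # Aside (a sketch of the convention, not asserted above): in eager-mode
        # autograd an "undefined" gradient also surfaces as None, e.g.
        # torch.autograd.grad(out, inp, allow_unused=True) returns None for an
        # input that did not contribute to the output. Passing None as
        # grad_output to backward_fn exercises the same undefined-tensor path.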

    def test_requires_grad_outputs(self):
        # outputs should require_grad only if eager outputs would require_grad.
        def fn(a, b, c):
            return a.relu() + b.relu(), c.relu()

        a = torch.rand((10, 10), requires_grad=False)
        b = torch.rand((10, 10), requires_grad=False)
        c = torch.rand((10, 10), requires_grad=True)

        fn_s = torch.jit.script(fn)

        for i in range(4):
            x, y = fn_s(a, b, c)
            self.assertFalse(x.requires_grad)
            self.assertTrue(y.requires_grad)

    def test_requires_grad_outputs_profiled_twice(self):
        # the value "r" is used twice, by gammaln and by entr, so it is profiled twice.
        # During autodiff graph formation the profile nodes are therefore left
        # unmerged because they alias the same value, and the resulting
        # DifferentiableGraph has no profile node on the output. The requires_grad
        # info should then be added onto the output value (otherwise autodiff will
        # make the output require_grad).
        # Note: this relies on gammaln and entr not having autodiff implementations.
        def fn(a, b, c):
            r = a.relu().relu()
            return torch.special.gammaln(r), torch.special.entr(r), c.cos().relu()

        fn_s = torch.jit.script(fn)

        a = torch.rand((10, 10), requires_grad=False)
        b = torch.rand((10, 10), requires_grad=False)
        c = torch.rand((10, 10), requires_grad=True)

        for i in range(4):
            x_s, y_s, z_s = fn_s(a, b, c)
            x, y, z = fn(a, b, c)

            self.assertEqual(x_s.requires_grad, x.requires_grad)
            self.assertEqual(y_s.requires_grad, y.requires_grad)
            self.assertEqual(z_s.requires_grad, z.requires_grad)
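
        # Debugging sketch (optional, not asserted above): after the warm-up
        # iterations, the optimized graph -- including any DifferentiableGraph
        # and profile nodes -- can be inspected with
        # torch.jit.last_executed_optimized_graph(), e.g.:
        #   print(torch.jit.last_executed_optimized_graph())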

    def test_requires_grad_outputs_side_effects(self):
        # same as above, but also add a CallFunction in between.
        @torch.jit.ignore
        def python_fn(x):
            return x.relu()

        def fn(a, b, c):
            r = a.relu().relu()
            z = python_fn(r)
            return torch.relu(r), torch.nn.functional.gelu(r), c.cos().relu()

        fn_s = torch.jit.script(fn)

        a = torch.rand((10, 10), requires_grad=False)
        b = torch.rand((10, 10), requires_grad=False)
        c = torch.rand((10, 10), requires_grad=True)

        for i in range(4):
            x_s, y_s, z_s = fn_s(a, b, c)
            x, y, z = fn(a, b, c)

            self.assertEqual(x_s.requires_grad, x.requires_grad)
            self.assertEqual(y_s.requires_grad, y.requires_grad)
            self.assertEqual(z_s.requires_grad, z.requires_grad)

    def test_autodiff_requires_grad_nograd(self):
        @torch.jit.ignore
        def python_fn(x):
            return x.relu()

        def fn(a, b, c):
            x = a.sin().relu()
            y = python_fn(b)
            with torch.no_grad():
                z = x + c
            return x, y, z

        fn_s = torch.jit.script(fn)

        a = torch.rand((10, 10), requires_grad=True)
        b = torch.rand((10, 10), requires_grad=True)
        c = torch.rand((10, 10), requires_grad=True)

        for i in range(4):
            x_s, y_s, z_s = fn_s(a, b, c)
            x, y, z = fn(a, b, c)

            self.assertEqual(x_s.requires_grad, x.requires_grad)
            self.assertEqual(y_s.requires_grad, y.requires_grad)
            self.assertEqual(z_s.requires_grad, z.requires_grad)
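

if __name__ == "__main__":
    # Minimal standalone entry point; this assumes the common PyTorch test
    # harness helper run_tests (torch.testing._internal.common_utils). In the
    # repository these cases are normally collected by the top-level JIT test
    # runner rather than executed directly.
    from torch.testing._internal.common_utils import run_tests

    run_tests()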