# Owner(s): ["module: dynamo"]
import contextlib
import functools
import unittest
import torch
import torch._dynamo
import torch._dynamo.test_case
import torch._dynamo.testing
from functorch.compile import nop
from torch._dynamo import compiled_autograd
from torch._functorch.aot_autograd import aot_module_simplified
from torch.utils.hooks import RemovableHandle
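# Backward-graph compiler passed to compiled_autograd.enable() in the tests below.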
def compiler_fn(gm):
return torch._dynamo.optimize("inductor", nopython=True, dynamic=True)(gm)
def global_hook_0(grad):
return grad * 4
def global_hook_1(grad):
return grad / 2
def global_hook_2(grad):
return grad * 3
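# Module-level handle slot used by test_tensor_register_global_hooks_handles_in_list.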
h0 = None
class ClassWithVal:
def __init__(self, val):
self.val = val
class HooksTests(torch._dynamo.test_case.TestCase):
def test_tensor_only_register_hook_in_graph_lambda(self):
def fn(x):
x.register_hook(lambda grad: grad * 2)
return x
cnts = torch._dynamo.testing.CompileCounter()
fn = torch._dynamo.optimize(cnts)(fn)
v = torch.tensor([0.0, 0.0, 0.0], requires_grad=True)
v = fn(v)
v.backward(torch.tensor([1.0, 2.0, 3.0]))
self.assertEqual(v.grad, torch.tensor([2.0, 4.0, 6.0]))
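        # fn has no tensor ops to compile (only a hook registration on an input), so the backend never runs.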
self.assertEqual(cnts.frame_count, 0)
def test_tensor_register_hook_in_graph_lambda(self):
def fn(x, y, z):
x.register_hook(lambda grad: grad * 2)
return x, y * y, z * z
cnts = torch._dynamo.testing.CompileCounter()
fn = torch._dynamo.optimize(cnts)(fn)
v = torch.tensor([0.0, 0.0, 0.0], requires_grad=True)
v = fn(v, torch.randn([2, 2]), torch.randn([2, 2]))[0]
v.backward(torch.tensor([1.0, 2.0, 3.0]))
self.assertEqual(v.grad, torch.tensor([2.0, 4.0, 6.0]))
self.assertEqual(cnts.frame_count, 1)
def test_tensor_register_hook_in_graph_break_handle_lambda(self):
def fn(x, y, z):
handle = x.register_hook(lambda grad: grad * 2)
z = z * z
handle.remove()
x.register_hook(lambda grad: grad * 3)
return x, y * y, z
cnts = torch._dynamo.testing.CompileCounter()
fn = torch._dynamo.optimize(cnts)(fn)
v = torch.tensor([0.0, 0.0, 0.0], requires_grad=True)
v = fn(v, torch.randn([2, 2]), torch.randn([2, 2]))[0]
v.backward(torch.tensor([1.0, 2.0, 3.0]))
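        # Only the second hook (grad * 3) is active; the first one was removed before backward ran.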
self.assertEqual(v.grad, torch.tensor([3.0, 6.0, 9.0]))
self.assertEqual(cnts.frame_count, 1)
def test_tensor_register_hook_multi_handle_return(self):
def fn(x, y, z):
handle = x.register_hook(lambda grad: grad * 2)
h2 = handle
z = z * z
return x, y * y, z, handle, h2
cnts = torch._dynamo.testing.CompileCounter()
fn = torch._dynamo.optimize(cnts)(fn)
v = torch.tensor([0.0, 0.0, 0.0], requires_grad=True)
v, y, z, h, h2 = fn(v, torch.randn([2, 2]), torch.randn([2, 2]))
v.backward(torch.tensor([1.0, 2.0, 3.0]))
self.assertEqual(v.grad, torch.tensor([2.0, 4.0, 6.0]))
self.assertEqual(cnts.frame_count, 1)
self.assertNotEqual(h, None)
self.assertNotEqual(h2, None)
self.assertEqual(h2, h)
def test_tensor_register_hook_repeated_handle_return(self):
def fn(x, y, z):
handle = x.register_hook(lambda grad: grad * 2)
h2 = handle
z = z * z
return x, y * y, z, handle, handle
cnts = torch._dynamo.testing.CompileCounter()
fn = torch._dynamo.optimize(cnts)(fn)
v = torch.tensor([0.0, 0.0, 0.0], requires_grad=True)
v, y, z, h, h2 = fn(v, torch.randn([2, 2]), torch.randn([2, 2]))
v.backward(torch.tensor([1.0, 2.0, 3.0]))
self.assertEqual(v.grad, torch.tensor([2.0, 4.0, 6.0]))
self.assertEqual(cnts.frame_count, 1)
self.assertIsInstance(h, RemovableHandle)
self.assertIs(h2, h)
def test_removed_handle_return(self):
cnt = torch._dynamo.testing.CompileCounter()
@torch.compile(backend=cnt, fullgraph=True)
def fn(x, y, z):
handle = x.register_hook(lambda grad: grad * 2)
z = z * z
handle.remove()
handle.remove()
return x, y * y, z, handle, handle
v = torch.tensor([0.0, 0.0, 0.0], requires_grad=True)
v, y, z, h, h2 = fn(v, torch.randn([2, 2]), torch.randn([2, 2]))
v.backward(torch.tensor([1.0, 2.0, 3.0]))
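        # The hook was already removed (remove() is idempotent), so the grad passes through unmodified.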
self.assertEqual(v.grad, torch.tensor([1.0, 2.0, 3.0]))
self.assertEqual(cnt.frame_count, 1)
self.assertIsInstance(h, RemovableHandle)
self.assertIs(h2, h)
def test_tensor_register_hook_repeated_handle_not_local(self):
def fn(x, y, z, mod):
mod.handle = x.register_hook(lambda grad: grad * 2)
z = z * z
return x, y * y, z
cnts = torch._dynamo.testing.CompileCounter()
fn = torch._dynamo.optimize(cnts, nopython=True)(fn)
v = torch.tensor([0.0, 0.0, 0.0], requires_grad=True)
mod = torch.nn.Module()
mod.handle = None
v, y, z = fn(v, torch.randn([2, 2]), torch.randn([2, 2]), mod)
v.backward(torch.tensor([1.0, 2.0, 3.0]))
self.assertEqual(v.grad, torch.tensor([2.0, 4.0, 6.0]))
self.assertEqual(cnts.frame_count, 1)
self.assertNotEqual(mod.handle, None)
def test_tensor_only_register_hook_in_graph_local(self):
def local_hook(grad):
return grad * 2
def fn(x):
x.register_hook(local_hook)
return x
cnts = torch._dynamo.testing.CompileCounter()
fn = torch._dynamo.optimize(cnts)(fn)
v = torch.tensor([0.0, 0.0, 0.0], requires_grad=True)
v = fn(v)
v.backward(torch.tensor([1.0, 2.0, 3.0]))
self.assertEqual(v.grad, torch.tensor([2.0, 4.0, 6.0]))
self.assertEqual(cnts.frame_count, 0)
def test_tensor_only_register_hook_in_graph_local_inner(self):
def fn(x):
def local_hook(grad):
return grad * 2
z = x * x
x.register_hook(local_hook)
z.register_hook(local_hook)
return x, z
cnts = torch._dynamo.testing.CompileCounter()
fn = torch._dynamo.optimize(cnts)(fn)
v = torch.tensor([0.0, 0.0, 0.0], requires_grad=True)
v = fn(v)
v[0].backward(torch.tensor([1.0, 2.0, 3.0]))
self.assertEqual(v[0].grad, torch.tensor([2.0, 4.0, 6.0]))
self.assertEqual(cnts.frame_count, 1)
def test_tensor_register_hook_in_graph_local(self):
def local_hook(grad):
return grad * 2
def fn(x, y, z):
x.register_hook(local_hook)
return x, y * y, z * z
cnts = torch._dynamo.testing.CompileCounter()
fn = torch._dynamo.optimize(cnts)(fn)
v = torch.tensor([0.0, 0.0, 0.0], requires_grad=True)
v = fn(v, torch.randn([2, 2]), torch.randn([2, 2]))[0]
v.backward(torch.tensor([1.0, 2.0, 3.0]))
self.assertEqual(v.grad, torch.tensor([2.0, 4.0, 6.0]))
self.assertEqual(cnts.frame_count, 1)
def test_tensor_register_hook_in_graph_break_handle_local(self):
def local_hook(grad):
return grad * 2
def local_hook2(grad):
return grad * 3
def fn(x, y, z):
handle = x.register_hook(local_hook)
z = z * z
handle.remove()
x.register_hook(local_hook2)
return x, y * y, z
cnts = torch._dynamo.testing.CompileCounter()
fn = torch._dynamo.optimize(cnts)(fn)
v = torch.tensor([0.0, 0.0, 0.0], requires_grad=True)
v = fn(v, torch.randn([2, 2]), torch.randn([2, 2]))[0]
v.backward(torch.tensor([1.0, 2.0, 3.0]))
self.assertEqual(v.grad, torch.tensor([3.0, 6.0, 9.0]))
def test_tensor_register_global_hook(self):
def fn(x):
x.register_hook(global_hook_0)
return x, x * x
cnts = torch._dynamo.testing.CompileCounter()
fn = torch._dynamo.optimize(cnts)(fn)
v = torch.tensor([0.0, 0.0, 0.0], requires_grad=True)
v = fn(v)[0]
v.backward(torch.tensor([1.0, 2.0, 3.0]))
self.assertEqual(v.grad, torch.tensor([4.0, 8.0, 12.0]))
self.assertEqual(cnts.frame_count, 1)
def test_tensor_register_multiple_hooks(self):
def fn(x):
x.register_hook(global_hook_0) # * 4
x.register_hook(global_hook_1) # / 2
x.register_hook(global_hook_2) # * 3
return x, x * x
cnts = torch._dynamo.testing.CompileCounter()
fn = torch._dynamo.optimize(cnts)(fn)
v = torch.tensor([0.0, 0.0, 0.0], requires_grad=True)
v = fn(v)[0]
v.backward(torch.tensor([1.0, 2.0, 3.0]))
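        # The three hooks compose: grad * 4 / 2 * 3 == grad * 6.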
self.assertEqual(v.grad, torch.tensor([6.0, 12.0, 18.0]))
self.assertEqual(cnts.frame_count, 1)
def test_tensor_register_multiple_hooks_handles_in_list(self):
def fn(x):
h0 = x.register_hook(global_hook_0) # * 4
h1 = x.register_hook(global_hook_1) # / 2
h2 = x.register_hook(global_hook_2) # * 3
return x, x * x, h0, h1, h2
cnts = torch._dynamo.testing.CompileCounter()
fn = torch._dynamo.optimize(cnts)(fn)
v = torch.tensor([0.0, 0.0, 0.0], requires_grad=True)
v, r, handle_0, handle_1, handle_2 = fn(v)
v.backward(torch.tensor([1.0, 2.0, 3.0]))
self.assertEqual(v.grad, torch.tensor([6.0, 12.0, 18.0]))
handle_0.remove()
handle_1.remove()
handle_2.remove()
v.backward(torch.tensor([1.0, 2.0, 3.0]))
        # Handles removed: the raw grad [1., 2., 3.] accumulates onto the previous grad [6., 12., 18.]
self.assertEqual(v.grad, torch.tensor([7.0, 14.0, 21.0]))
self.assertEqual(cnts.frame_count, 1)
def test_tensor_register_global_hooks_handles_in_list(self):
def fn(x):
global h0
h0 = x.register_hook(global_hook_0) # * 4
return x, x * x
cnts = torch._dynamo.testing.CompileCounter()
fn = torch._dynamo.optimize(cnts)(fn)
v = torch.tensor([0.0, 0.0, 0.0], requires_grad=True)
v, r = fn(v)
self.assertIsNotNone(h0)
v.backward(torch.tensor([1.0, 2.0, 3.0]))
self.assertEqual(v.grad, torch.tensor([4.0, 8.0, 12.0]))
h0.remove()
v.backward(torch.tensor([1.0, 2.0, 3.0]))
        # Handle removed: the raw grad [1., 2., 3.] accumulates onto the previous grad [4., 8., 12.]
self.assertEqual(v.grad, torch.tensor([5.0, 10.0, 15.0]))
        # NYI: writing the handle to a global is not yet supported, so no frame is compiled.
self.assertEqual(cnts.frame_count, 0)
def test_intermediary_hooks(self):
# Graph breaks because compiled_autograd is not set
def simple_hook(g):
return g * 2
def f(x):
y = x + 1
y.register_hook(simple_hook)
z = y + 1
return z
out = torch.randn(1, requires_grad=True)
cnts = torch._dynamo.testing.CompileCounter()
fn = torch._dynamo.optimize(cnts, nopython=False)(f)
res = fn(out)
res.backward()
self.assertEqual(res, f(out))
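        # Registering a hook on an intermediate without compiled_autograd forces a graph break, hence two frames.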
self.assertEqual(cnts.frame_count, 2)
        self.assertEqual(out.grad, torch.tensor([2.0]))
def test_intermediary_hooks_same_on_aot_eager(self):
def my_hook(grad, *, k=0):
return grad + k
class MyMod(torch.nn.Module):
def forward(self, x):
y = x.mul(2)
hook1 = functools.partial(my_hook, k=3)
hook2 = functools.partial(my_hook, k=4)
y.register_hook(hook1)
y.register_hook(hook2)
z = y.mul(3)
return (z,)
mod = MyMod()
x0 = torch.ones(4, requires_grad=True)
eager_out = mod(x0)
eager_out[0].backward(torch.ones(4))
x1 = torch.ones(4, requires_grad=True)
mod_compiled = aot_module_simplified(mod, (x1,), nop)
aot_out = mod_compiled(x1)
aot_out[0].backward(torch.ones(4))
x2 = torch.ones(4, requires_grad=True)
with compiled_autograd.enable(compiler_fn):
dynamo_out = torch._dynamo.optimize("aot_eager", nopython=True)(mod)(x2)
dynamo_out[0].backward(torch.ones(4))
self.assertEqual(dynamo_out, aot_out)
self.assertEqual(dynamo_out, eager_out)
self.assertEqual(x0.grad, x1.grad)
self.assertEqual(x0.grad, x2.grad)
def test_input_hooks_same(self):
backends = ["eager", "aot_eager", "inductor"]
for backend in backends:
def my_hook(grad, *, k=0):
return grad + k
hook = functools.partial(my_hook, k=3)
class MyMod(torch.nn.Module):
def forward(self, x):
x.register_hook(hook)
y = x.mul(2)
z = y.mul(3)
return (z,)
mod = MyMod()
x0 = torch.ones(4, requires_grad=True)
eager_out = mod(x0)
eager_out[0].backward(torch.ones(4))
x1 = torch.ones(4, requires_grad=True)
mod_compiled = aot_module_simplified(mod, (x1,), nop)
aot_out = mod_compiled(x1)
aot_out[0].backward(torch.ones(4))
x2 = torch.ones(4, requires_grad=True)
dynamo_out = torch._dynamo.optimize(backend, nopython=True)(mod)(x2)
with compiled_autograd.enable(compiler_fn):
dynamo_out[0].backward(torch.ones(4))
self.assertEqual(dynamo_out, aot_out)
self.assertEqual(dynamo_out, eager_out)
self.assertEqual(x0.grad, x1.grad)
self.assertEqual(x0.grad, x2.grad)
def test_intermediary_hooks_same_on_inductor(self):
def my_hook(grad, *, k=0):
return grad + k
class MyMod(torch.nn.Module):
def forward(self, x):
y = x.mul(2)
hook1 = functools.partial(my_hook, k=3)
hook2 = functools.partial(my_hook, k=4)
y.register_hook(hook1)
y.register_hook(hook2)
z = y.mul(3)
return (z,)
mod = MyMod()
x0 = torch.ones(4, requires_grad=True)
eager_out = mod(x0)
eager_out[0].backward(torch.ones(4))
x1 = torch.ones(4, requires_grad=True)
mod_compiled = aot_module_simplified(mod, (x1,), nop)
aot_out = mod_compiled(x1)
aot_out[0].backward(torch.ones(4))
x2 = torch.ones(4, requires_grad=True)
with compiled_autograd.enable(compiler_fn):
dynamo_out = torch._dynamo.optimize("inductor", nopython=True)(mod)(x2)
dynamo_out[0].backward(torch.ones(4))
self.assertEqual(dynamo_out, aot_out)
self.assertEqual(dynamo_out, eager_out)
self.assertEqual(x0.grad, x1.grad)
self.assertEqual(x0.grad, x2.grad)
def test_complex_state_mutation_in_intermediary_hooks_same_on_inductor(self):
class SomePyClass:
count = 0
def do_stuff(self, grad):
if self.count % 2 == 0:
r = grad * grad
else:
r = grad + grad
self.count += 1
return r
def complex_state_touching_hook(grad, *, obj):
return obj.do_stuff(grad)
class MyMod(torch.nn.Module):
def forward(self, x, obj):
y = x.mul(2)
hook1 = functools.partial(complex_state_touching_hook, obj=obj)
hook2 = functools.partial(complex_state_touching_hook, obj=obj)
y.register_hook(hook1)
y.register_hook(hook2)
z = y.mul(3)
return (z,)
mod = MyMod()
obj = SomePyClass()
x0 = torch.ones(4, requires_grad=True)
eager_out = mod(x0, obj)
eager_out[0].backward(torch.ones(4))
# Eager 2
self.assertEqual(obj.count, 2)
x2 = torch.ones(4, requires_grad=True)
with compiled_autograd.enable(compiler_fn):
dynamo_out = torch._dynamo.optimize("inductor", nopython=True)(mod)(x2, obj)
dynamo_out[0].backward(torch.ones(4))
self.assertEqual(dynamo_out, eager_out)
# Eager 2 + compiled 2
self.assertEqual(obj.count, 4)
self.assertEqual(x0.grad, x2.grad)
def test_complex_state_mutation_in_intermediary_hooks_same_on_inductor_with_graph_break(
self,
):
class SomePyClass:
grad_as_str = "None"
count = 0
def write_grad_as_str_and_do_stuff(self, grad):
self.grad_as_str = str(grad)
if self.count % 2 == 0:
r = grad * grad
else:
r = grad + grad
print("Break!")
self.count += 1
return r
def complex_state_touching_hook(grad, *, obj):
return obj.write_grad_as_str_and_do_stuff(grad)
class MyMod(torch.nn.Module):
def forward(self, x, obj):
y = x.mul(2)
hook1 = functools.partial(complex_state_touching_hook, obj=obj)
hook2 = functools.partial(complex_state_touching_hook, obj=obj)
y.register_hook(hook1)
y.register_hook(hook2)
z = y.mul(3)
return (z,)
mod = MyMod()
obj = SomePyClass()
x0 = torch.ones(4, requires_grad=True)
eager_out = mod(x0, obj)
eager_out[0].backward(torch.ones(4))
x2 = torch.ones(4, requires_grad=True)
with compiled_autograd.enable(compiler_fn):
dynamo_out = torch._dynamo.optimize("inductor", nopython=True)(mod)(x2, obj)
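            # str(grad) inside the hook is unsupported during the compiled backward and should raise.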
with self.assertRaisesRegex(torch._dynamo.exc.Unsupported, "builtin: str"):
dynamo_out[0].backward(torch.ones(4))
self.assertEqual(obj.count, 2)
def test_register_hook_partial_guarding(
self,
):
def some_hook(grad, *, obj):
return grad + obj.val
class MyMod(torch.nn.Module):
def forward(self, x, obj):
y = x.mul(2)
hook1 = functools.partial(some_hook, obj=obj)
y.register_hook(hook1)
z = y.mul(3)
return (z,)
mod = MyMod()
obj1 = ClassWithVal(torch.tensor(88))
obj2 = ClassWithVal(torch.tensor(99))
obj3 = ClassWithVal(11)
cnt = torch._dynamo.testing.CompileCounter()
x0 = torch.ones(4, requires_grad=True)
x1 = torch.ones(4, requires_grad=True)
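        # All four calls should share one compiled frame: the hook's partially-bound obj is not guarded on.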
with compiled_autograd.enable(compiler_fn):
torch.compile(mod, backend=cnt, fullgraph=True)(x0, obj1)
torch.compile(mod, backend=cnt, fullgraph=True)(x1, obj1)
torch.compile(mod, backend=cnt, fullgraph=True)(x0, obj2)
torch.compile(mod, backend=cnt, fullgraph=True)(x0, obj3)
self.assertEqual(cnt.frame_count, 1)
def test_hook_with_closure(self):
def fn(x, obj):
y = x.sin()
x.register_hook(lambda grad: grad + obj.val)
z = y.sin()
return z
cnt_fw = torch._dynamo.testing.CompileCounter()
cnt_bw = torch._dynamo.testing.CompileCounter()
opt = torch.compile(fn, backend=cnt_fw, fullgraph=True)
obj1 = ClassWithVal(torch.tensor(88))
obj2 = ClassWithVal(torch.tensor(99))
x0 = torch.ones(4, requires_grad=True)
x1 = torch.ones(4, requires_grad=True)
x2 = torch.ones(4, requires_grad=True)
x3 = torch.ones(4, requires_grad=True)
fn(x0, obj1).sum().backward()
fn(x1, obj2).sum().backward()
with compiled_autograd.enable(
functools.partial(torch.compile, backend=cnt_bw, fullgraph=True)
):
opt(x2, obj1).sum().backward()
opt(x3, obj2).sum().backward()
self.assertEqual(cnt_fw.frame_count, 1)
self.assertEqual(cnt_bw.frame_count, 1)
self.assertEqual(x0.grad, x2.grad)
self.assertEqual(x1.grad, x3.grad)
def test_intermediate_hook_with_closure_eager(self):
def fn(x, obj):
y = x.sin()
y.register_hook(lambda grad: grad + obj.val)
z = y.sin()
return z
cnt_fw = torch._dynamo.testing.CompileCounter()
cnt_bw = torch._dynamo.testing.CompileCounter()
opt = torch.compile(fn, backend=cnt_fw, fullgraph=True)
obj1 = ClassWithVal(torch.tensor(88))
obj2 = ClassWithVal(torch.tensor(99))
x0 = torch.ones(4, requires_grad=True)
x1 = torch.ones(4, requires_grad=True)
x2 = torch.ones(4, requires_grad=True)
x3 = torch.ones(4, requires_grad=True)
fn(x0, obj1).sum().backward()
fn(x1, obj2).sum().backward()
with compiled_autograd.enable(
functools.partial(torch.compile, backend=cnt_bw, fullgraph=True)
):
opt(x2, obj1).sum().backward()
opt(x3, obj2).sum().backward()
self.assertEqual(cnt_fw.frame_count, 1)
self.assertEqual(cnt_bw.frame_count, 1)
self.assertEqual(x0.grad, x2.grad)
self.assertEqual(x1.grad, x3.grad)
def test_intermediate_hook_with_closure_aot(self):
def fn(x, obj):
y = x.sin()
y.register_hook(lambda grad: grad + obj.val)
z = y.sin()
return z
cnt_bw = torch._dynamo.testing.CompileCounter()
opt = torch.compile(fn, backend="aot_eager", fullgraph=True)
obj1 = ClassWithVal(torch.tensor(88))
obj2 = ClassWithVal(torch.tensor(99))
x0 = torch.ones(4, requires_grad=True)
x1 = torch.ones(4, requires_grad=True)
x2 = torch.ones(4, requires_grad=True)
x3 = torch.ones(4, requires_grad=True)
fn(x0, obj1).sum().backward()
fn(x1, obj2).sum().backward()
with compiled_autograd.enable(
functools.partial(torch.compile, backend=cnt_bw, fullgraph=True)
):
opt(x2, obj1).sum().backward()
opt(x3, obj2).sum().backward()
self.assertEqual(cnt_bw.frame_count, 1)
self.assertEqual(x0.grad, x2.grad)
self.assertEqual(x1.grad, x3.grad)
def test_no_recompile_on_hook_identity_change(self):
def my_hook(grad, k=0):
return grad + k
def my_hook2(grad):
return grad * 2
class MyMod(torch.nn.Module):
def forward(self, x):
y = x.mul(2)
y.register_hook(my_hook)
y.register_hook(my_hook)
z = y.mul(3)
return (z,)
mod = MyMod()
x0 = torch.ones(4, requires_grad=True)
eager_out = mod(x0)
eager_out[0].backward(torch.ones(4))
x1 = torch.ones(4, requires_grad=True)
with compiled_autograd.enable(compiler_fn):
cnts = torch._dynamo.testing.CompileCounterWithBackend("aot_eager")
comp_mod = torch._dynamo.optimize(cnts, nopython=True)(mod)
comp_out = comp_mod(x1)
comp_out[0].backward(torch.ones(4))
self.assertEqual(cnts.frame_count, 1)
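            # Rebinding the hook name must not trigger a recompile; the frame count stays at 1 below.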
my_hook = my_hook2 # noqa: F811
self.assertEqual(x0.grad, x1.grad)
eager_out = mod(x0)
eager_out[0].backward(torch.ones(4))
comp_out = comp_mod(x1)
self.assertEqual(cnts.frame_count, 1)
comp_out[0].backward(torch.ones(4))
self.assertEqual(x0.grad, x1.grad)
def test_functools_arg_vary(self):
def pre_hook(grad, *, k):
return grad * k
hook = functools.partial(pre_hook, k=1)
@torch.compile(backend="eager", fullgraph=True)
def h(x):
y = x.mul(2)
y.register_hook(hook)
return y.mul(3)
with compiled_autograd.enable(torch.compile(backend="eager", fullgraph=True)):
x = torch.randn(2, requires_grad=True)
h(x).sum().backward()
orig_grad = x.grad
x.grad = None
hook = functools.partial(pre_hook, k=2)
h(x).sum().backward()
self.assertEqual(orig_grad * 2, x.grad)
def test_post_acc_grad_hook(self):
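        # register_post_accumulate_grad_hook passes the tensor itself (with .grad already populated) to the hook.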
def hook(input_t):
input_t.mul_(input_t.grad)
input_t.grad.mul_(5)
def reg_and_mul(x, y):
x.register_post_accumulate_grad_hook(hook)
return x * y
cnts = None
def test_fn(fn):
fn(x, y)
b = torch.tensor([2.0, 2.0, 2.0], requires_grad=True)
x.backward(b)
if cnts:
self.assertEqual(cnts.frame_count, 1)
            # The same assertions run for both the eager and the compiled runs
            # x goes to x * 2 because of the in-place mul_ in the hook
self.assertEqual(x, torch.tensor([0.5, 0.5, 0.5]) * 2)
            # This also checks grad aliasing: the hook mutated x.grad in place
self.assertEqual(x.grad, b * 5)
# Eager values
x = torch.tensor([0.5, 0.5, 0.5], requires_grad=True)
y = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
test_fn(reg_and_mul)
# Compiled
for backend in ["eager", "aot_eager", "inductor"]:
for compiled_bwd in [False, True]:
torch._dynamo.reset()
x = torch.tensor([0.5, 0.5, 0.5], requires_grad=True)
y = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
cnts = torch._dynamo.testing.CompileCounterWithBackend(backend)
compiled_fn = torch._dynamo.optimize(cnts, nopython=True)(reg_and_mul)
compiled_bwd_ctx = (
compiled_autograd.enable(
torch.compile(backend=backend, fullgraph=True)
)
if compiled_bwd
else contextlib.nullcontext()
)
with compiled_bwd_ctx:
test_fn(compiled_fn)
def test_recompile(self):
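        # Repeated backward calls under error_on_recompile verify that compiled autograd reuses its graph.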
def hook(param):
param.grad *= 2
x = torch.ones(10)
x.requires_grad = True
def run(input):
return x * input
x.register_post_accumulate_grad_hook(hook)
with compiled_autograd.enable(compiler_fn):
for i in range(5):
with unittest.mock.patch(
"torch._dynamo.config.error_on_recompile", True
):
# Mimic optimizer.zero_grad() to clear the gradient
x.grad = None
run(i).sum().backward()
if __name__ == "__main__":
from torch._dynamo.test_case import run_tests
run_tests()