# Owner(s): ["oncall: pt2"]
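"""Tests for functionalization of rng ops under aot_autograd
(torch._functorch.config.functionalize_rng_ops): eager rand/rand_like/dropout
calls should be traced into functional torch.ops.rngprims.philox_rand nodes
instead of ops that mutate global rng state."""
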
import functools
import sys
import unittest
from unittest.mock import patch

import torch
import torch.utils.checkpoint
from functorch.compile import aot_function, min_cut_rematerialization_partition, nop
from torch.testing._internal.common_device_type import (
    dtypes,
    instantiate_device_type_tests,
)
from torch.testing._internal.common_utils import IS_CI, IS_WINDOWS, run_tests, TestCase

if IS_WINDOWS and IS_CI:
    sys.stderr.write("torch.compile not supported on windows")
    if __name__ == "__main__":
        sys.exit(0)
    raise unittest.SkipTest("torch.compile not supported on windows")


def count_philox_rand(gm, args, freq):
    """Graph "compiler" for aot_function that asserts how many functionalized
    philox_rand nodes the traced graph contains, then returns the graph to be
    run as-is."""
    assert [node.target for node in gm.graph.nodes].count(torch.ops.rngprims.philox_rand.default) == freq
    return gm
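
# A minimal usage sketch (assumption: a CUDA device is available). The helper
# above doubles as the aot_function "compiler": it only inspects the traced
# graph, so the assert fires at trace time if the number of functionalized
# rng ops is unexpected:
#
#     checked = aot_function(fn, functools.partial(count_philox_rand, freq=2))
#     checked(torch.rand(10, device="cuda"))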


class TestFunctionalizationRngOps(TestCase):
    @dtypes(torch.float32)
    @patch.object(torch._functorch.config, "functionalize_rng_ops", True)
    def test_rand_like(self, dtype, device):
        def fn(x):
            a = torch.rand_like(x) * x
            a = torch.rand_like(x) * a
            return a

        x = torch.rand(10, device=device, dtype=dtype)
        for seed in range(10):
            torch.cuda.manual_seed(seed)
            ref = fn(x)

            torch.cuda.manual_seed(seed)
            aot_fn = aot_function(fn, functools.partial(count_philox_rand, freq=2))
            res = aot_fn(x)

            self.assertEqual(ref, res)
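
    # Under functionalize_rng_ops, aot_autograd is expected to replace each
    # eager rand/rand_like call with torch.ops.rngprims.philox_rand, which
    # takes the Philox seed/offset as explicit inputs instead of mutating
    # global rng state; freq=2 above corresponds to the two rand_like calls.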

    @dtypes(torch.float32)
    @patch.object(torch._functorch.config, "functionalize_rng_ops", True)
    def test_rand_like_dynamic(self, dtype, device):
        def fn(x):
            a = torch.rand_like(x) * x
            a = torch.rand_like(x) * a
            return a

        for seed in range(1, 10):
            shape = (seed, seed)
            x = torch.rand(shape, device=device, dtype=dtype)
            torch.cuda.manual_seed(seed)
            ref = fn(x)

            torch.cuda.manual_seed(seed)
            opt_fn = torch.compile(fn, backend="aot_eager", dynamic=True)
            res = opt_fn(x)

            self.assertEqual(ref, res)

    @dtypes(torch.float32)
    @patch.object(torch._functorch.config, "functionalize_rng_ops", True)
    def test_rand_like_dynamic_bwd(self, dtype, device):
        def fn(x):
            a = torch.rand_like(x) * x
            a = torch.rand_like(x) * a
            return a

        for seed in range(1, 10):
            shape = (seed, seed)
            x = torch.rand(shape, device=device, dtype=dtype, requires_grad=True)
            torch.cuda.manual_seed(seed)
            ref = fn(x)
            ref.sum().backward()

            torch.cuda.manual_seed(seed)
            opt_fn = torch.compile(fn, backend="aot_eager", dynamic=True)
            res = opt_fn(x)
            res.sum().backward()

            self.assertEqual(ref, res)

    @dtypes(torch.float32)
    @patch.object(torch._functorch.config, "functionalize_rng_ops", True)
    def test_rand(self, dtype, device):
        shape = (10,)

        def fn(x):
            a = torch.rand(*shape, device=device, dtype=dtype) * x
            a = torch.rand(*shape, device=device, dtype=dtype) * a
            return a

        x = torch.rand(*shape, device=device, dtype=dtype)
        for seed in range(10):
            torch.cuda.manual_seed(seed)
            ref = fn(x)

            torch.cuda.manual_seed(seed)
            aot_fn = aot_function(fn, functools.partial(count_philox_rand, freq=2))
            res = aot_fn(x)

            self.assertEqual(ref, res)

    @dtypes(torch.float32)
    @patch.object(torch._functorch.config, "functionalize_rng_ops", True)
    def test_autograd_function(self, dtype, device):
        shape = (16, 16)

        class Custom(torch.autograd.Function):
            @staticmethod
            def forward(ctx, x):
                ctx.save_for_backward(x)
                a = torch.rand_like(x) * x
                a = torch.rand_like(x) * a
                return a

            @staticmethod
            def backward(ctx, grad_out):
                x, = ctx.saved_tensors
                return grad_out * torch.rand_like(grad_out) * torch.cos(x)

        custom = Custom.apply

        x = torch.rand(*shape, device=device, dtype=dtype, requires_grad=True)
        x_clone = x.clone().detach().requires_grad_(True)

        torch.cuda.manual_seed(123)
        ref = custom(x)
        ref.sum().backward()

        torch.cuda.manual_seed(123)
        fwd_compiler = functools.partial(count_philox_rand, freq=2)
        bwd_compiler = functools.partial(count_philox_rand, freq=1)
        aot_custom = aot_function(custom, fwd_compiler, bwd_compiler)
        res = aot_custom(x_clone)
        res.sum().backward()

        self.assertEqual(ref, res)
        self.assertEqual(x.grad, x_clone.grad)

    @dtypes(torch.float32)
    @patch.object(torch._functorch.config, "functionalize_rng_ops", True)
    def test_multiple_subgraphs(self, dtype, device):
        # Checks that rng state is maintained when there are multiple
        # aot-traced graphs: each subgraph must pick up the Philox offset
        # where the previous one left off.
        shape = (16, 16)

        class CustomOp1(torch.autograd.Function):
            @staticmethod
            def forward(ctx, x):
                ctx.save_for_backward(x)
                a = torch.rand_like(x) * x
                a = torch.rand_like(x) * a
                return a

            @staticmethod
            def backward(ctx, grad_out):
                x, = ctx.saved_tensors
                return grad_out * torch.rand_like(grad_out) * torch.cos(x)

        class CustomOp2(torch.autograd.Function):
            @staticmethod
            def forward(ctx, x):
                ctx.save_for_backward(x)
                a = torch.rand_like(x) * x
                return a

            @staticmethod
            def backward(ctx, grad_out):
                x, = ctx.saved_tensors
                return grad_out * torch.rand_like(grad_out) * torch.rand_like(x)

        custom_op1 = CustomOp1.apply
        custom_op2 = CustomOp2.apply

        def fn(x):
            a = custom_op1(x)
            b = a.sin()
            return custom_op2(b)

        fwd_compiler = functools.partial(count_philox_rand, freq=2)
        bwd_compiler = functools.partial(count_philox_rand, freq=1)
        aot_custom_op1 = aot_function(custom_op1, fwd_compiler, bwd_compiler)
        fwd_compiler = functools.partial(count_philox_rand, freq=1)
        bwd_compiler = functools.partial(count_philox_rand, freq=2)
        aot_custom_op2 = aot_function(custom_op2, fwd_compiler, bwd_compiler)

        def aot_fn(x):
            a = aot_custom_op1(x)
            b = a.sin()
            return aot_custom_op2(b)

        for seed in range(10):
            torch.cuda.manual_seed(seed)
            x = torch.rand(*shape, device=device, dtype=dtype, requires_grad=True)
            x_clone = x.clone().detach().requires_grad_(True)

            torch.cuda.manual_seed(seed)
            ref = fn(x)
            ref.sum().backward()

            torch.cuda.manual_seed(seed)
            res = aot_fn(x_clone)
            res.sum().backward()

            self.assertEqual(ref, res)
            self.assertEqual(x.grad, x_clone.grad)

    @dtypes(torch.float32)
    @patch.object(torch._functorch.config, "functionalize_rng_ops", True)
    def test_set_get_rng_state(self, dtype, device):
        def fn(x):
            a = torch.rand_like(x) * x
            state = torch.cuda.get_rng_state()
            a = torch.rand_like(x) * a
            torch.cuda.set_rng_state(state)
            a = torch.rand_like(x) * a
            return a

        x = torch.rand(10, device=device, dtype=dtype)
        for seed in range(10):
            torch.cuda.manual_seed(seed)
            ref = fn(x)

            torch.cuda.manual_seed(seed)
            fwd_compiler = functools.partial(count_philox_rand, freq=3)
            aot_fn = aot_function(fn, fwd_compiler)
            res = aot_fn(x)

            self.assertEqual(ref, res)
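
    # Note: freq=3 above assumes all three rand_like calls become philox_rand
    # nodes, i.e. the get/set_rng_state pair is traced into Philox offset
    # bookkeeping rather than surviving as graph nodes; restoring the saved
    # state makes the third rand_like replay the second one's offsets.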

    @dtypes(torch.float32)
    @patch.object(torch._functorch.config, "functionalize_rng_ops", True)
    def test_min_cut_partitioner(self, dtype, device):
        # Checks that the calling convention is maintained.
        shape = (16, 16)

        def fn(x):
            a = torch.rand_like(x) * x
            a = torch.rand_like(x) * a
            a = torch.sin(a)
            a = torch.sin(a)
            a = torch.sin(a)
            return a

        x = torch.rand(*shape, device=device, dtype=dtype, requires_grad=True)
        x_clone = x.clone().detach().requires_grad_(True)

        torch.cuda.manual_seed(123)
        ref = fn(x)
        ref.sum().backward()

        torch.cuda.manual_seed(123)
        fwd_compiler = functools.partial(count_philox_rand, freq=2)
        bwd_compiler = functools.partial(count_philox_rand, freq=0)
        aot_custom = aot_function(
            fn, fwd_compiler, bwd_compiler, partition_fn=min_cut_rematerialization_partition
        )
        res = aot_custom(x_clone)
        res.sum().backward()

        self.assertEqual(ref, res)
        self.assertEqual(x.grad, x_clone.grad)

    # TODO - Dropout needs more work because of offset calculation
    @patch.object(torch._functorch.config, "functionalize_rng_ops", True)
    @dtypes(torch.float32)
    def test_checkpoint(self, dtype, device):
        def g(x, y):
            return torch.nn.functional.dropout(x, 0.6)

        def fn(x, y):
            return torch.utils.checkpoint.checkpoint(g, x, y, use_reentrant=False)

        x = torch.ones(2, 2, device=device, requires_grad=True)
        y = torch.rand(2, 2, device=device, requires_grad=True)
        torch.cuda.manual_seed(123)
        ref = fn(x, y)

        # With checkpointing, dropout is recomputed in the backward pass, so
        # philox_rand should show up once in each graph.
        fwd_compiler = functools.partial(count_philox_rand, freq=1)
        bwd_compiler = functools.partial(count_philox_rand, freq=1)
        aot_fn = aot_function(fn, fwd_compiler, bwd_compiler)
        # We can't check accuracy here because rand_like generates different
        # random numbers than dropout.
        res = aot_fn(x, y)
        res.sum().backward()

    @dtypes(torch.float32)
    @patch.object(torch._functorch.config, "functionalize_rng_ops", True)
    def test_dropout_decomp(self, dtype, device):
        def fn(x):
            return torch.nn.functional.dropout(x, 0.6) * x

        x = torch.rand(10, device=device, dtype=dtype)

        # Ensure the decomposition to philox_rand is happening.
        aot_fn = aot_function(fn, functools.partial(count_philox_rand, freq=1))
        # We can't check accuracy here because rand_like generates different
        # random numbers than dropout.
        aot_fn(x)

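
# Rng functionalization currently targets CUDA's Philox generator, so the
# tests above are instantiated for CUDA only; NegativeTest below checks that
# the same pattern fails loudly on CPU.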
only_for = ("cuda",)
instantiate_device_type_tests(TestFunctionalizationRngOps, globals(), only_for=only_for)


class NegativeTest(TestCase):
    @dtypes(torch.float32)
    @patch.object(torch._functorch.config, "functionalize_rng_ops", True)
    def test_on_cpu(self, dtype, device):
        def fn(x):
            a = torch.rand_like(x) * x
            a = torch.rand_like(x) * a
            return a

        x = torch.rand(10, device=device, dtype=dtype)
        aot_fn = aot_function(fn, nop)
        with self.assertRaises(RuntimeError):
            aot_fn(x)

only_for = ("cpu",)
instantiate_device_type_tests(NegativeTest, globals(), only_for=only_for)

if __name__ == "__main__":
    run_tests()