| # Owner(s): ["oncall: pt2"] |
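# Tests for functionalization of RNG ops (rand, rand_like, dropout) under AOT Autograd:
# eager RNG calls are expected to lower to torch.ops.rngprims.philox_rand in the
# traced graphs.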
| import sys |
| import unittest |
| import torch |
from torch.testing._internal.common_utils import (
    IS_CI,
    IS_WINDOWS,
    TestCase,
    run_tests,
)
| |
| from torch.testing._internal.common_device_type import instantiate_device_type_tests, dtypes |
| from functorch.compile import aot_function, nop, min_cut_rematerialization_partition |
| from unittest.mock import patch |
| import functools |
| import torch.utils.checkpoint |
| |
| |
| if IS_WINDOWS and IS_CI: |
    sys.stderr.write("torch.compile not supported on windows")
| if __name__ == "__main__": |
| sys.exit(0) |
| raise unittest.SkipTest("torch.compile not supported on windows") |
| |
# An AOT Autograd compiler hook: checks that the traced graph contains exactly
# `freq` philox_rand ops, the functionalized form of eager RNG calls.
def count_philox_rand(gm, args, freq):
| assert [node.target for node in gm.graph.nodes].count(torch.ops.rngprims.philox_rand.default) == freq |
| return gm |
| |
| class TestFunctionalizationRngOps(TestCase): |
| @dtypes(torch.float32) |
| @patch.object(torch._functorch.config, "functionalize_rng_ops", True) |
| def test_rand_like(self, dtype, device): |
| def fn(x): |
| a = torch.rand_like(x) * x |
| a = torch.rand_like(x) * a |
| return a |
| |
| x = torch.rand(10, device=device, dtype=dtype) |
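        # Both rand_like calls in fn should be functionalized into philox_rand ops
        # (hence freq=2 below).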
| |
| for seed in range(10): |
| torch.cuda.manual_seed(seed) |
| ref = fn(x) |
| |
| torch.cuda.manual_seed(seed) |
| aot_fn = aot_function(fn, functools.partial(count_philox_rand, freq=2)) |
| res = aot_fn(x) |
| |
| self.assertEqual(ref, res) |
| |
| @dtypes(torch.float32) |
| @patch.object(torch._functorch.config, "functionalize_rng_ops", True) |
| def test_rand_like_dynamic(self, dtype, device): |
| def fn(x): |
| a = torch.rand_like(x) * x |
| a = torch.rand_like(x) * a |
| return a |
| |
| for seed in range(1, 10): |
| shape = (seed, seed) |
| x = torch.rand(shape, device=device, dtype=dtype) |
| torch.cuda.manual_seed(seed) |
| ref = fn(x) |
| |
| torch.cuda.manual_seed(seed) |
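            # dynamic=True exercises the functionalized RNG ops with symbolic shapes.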
| opt_fn = torch.compile(fn, backend="aot_eager", dynamic=True) |
| res = opt_fn(x) |
| |
| self.assertEqual(ref, res) |
| |
| |
| |
| @dtypes(torch.float32) |
| @patch.object(torch._functorch.config, "functionalize_rng_ops", True) |
| def test_rand_like_dynamic_bwd(self, dtype, device): |
| def fn(x): |
| a = torch.rand_like(x) * x |
| a = torch.rand_like(x) * a |
| return a |
| |
| for seed in range(1, 10): |
| shape = (seed, seed) |
| x = torch.rand(shape, device=device, dtype=dtype, requires_grad=True) |
| torch.cuda.manual_seed(seed) |
| ref = fn(x) |
| ref.sum().backward() |
| |
| torch.cuda.manual_seed(seed) |
| opt_fn = torch.compile(fn, backend="aot_eager", dynamic=True) |
| res = opt_fn(x) |
| res.sum().backward() |
| |
| self.assertEqual(ref, res) |
| |
| |
| @dtypes(torch.float32) |
| @patch.object(torch._functorch.config, "functionalize_rng_ops", True) |
| def test_rand(self, dtype, device): |
| shape = (10,) |
| |
| def fn(x): |
| a = torch.rand(*shape, device=device, dtype=dtype) * x |
| a = torch.rand(*shape, device=device, dtype=dtype) * a |
| return a |
| |
| x = torch.rand(*shape, device=device, dtype=dtype) |
| |
| for seed in range(10): |
| torch.cuda.manual_seed(seed) |
| ref = fn(x) |
| |
| torch.cuda.manual_seed(seed) |
| aot_fn = aot_function(fn, functools.partial(count_philox_rand, freq=2)) |
| res = aot_fn(x) |
| |
| self.assertEqual(ref, res) |
| |
| @dtypes(torch.float32) |
| @patch.object(torch._functorch.config, "functionalize_rng_ops", True) |
| def test_autograd_function(self, dtype, device): |
| shape = (16, 16) |
| |
| class Custom(torch.autograd.Function): |
| @staticmethod |
| def forward(ctx, x): |
| ctx.save_for_backward(x) |
| a = torch.rand_like(x) * x |
| a = torch.rand_like(x) * a |
| return a |
| |
| @staticmethod |
| def backward(ctx, grad_out): |
| x, = ctx.saved_tensors |
| return grad_out * torch.rand_like(grad_out) * torch.cos(x) |
| |
| custom = Custom.apply |
| |
| x = torch.rand(*shape, device=device, dtype=dtype, requires_grad=True) |
| |
| x_clone = x.clone().detach().requires_grad_(True) |
| |
| torch.cuda.manual_seed(123) |
| ref = custom(x) |
| ref.sum().backward() |
| |
| torch.cuda.manual_seed(123) |
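        # The custom forward has two rand_like calls and the backward has one; each
        # should lower to a philox_rand op.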
| fwd_compiler = functools.partial(count_philox_rand, freq=2) |
| bwd_compiler = functools.partial(count_philox_rand, freq=1) |
| aot_custom = aot_function(custom, fwd_compiler, bwd_compiler) |
| res = aot_custom(x_clone) |
| res.sum().backward() |
| |
| self.assertEqual(ref, res) |
| self.assertEqual(x.grad, x_clone.grad) |
| |
| @dtypes(torch.float32) |
| @patch.object(torch._functorch.config, "functionalize_rng_ops", True) |
| def test_multiple_subgraphs(self, dtype, device): |
        # Checks that the RNG state is maintained correctly when there are multiple
        # AOT-traced graphs.
| shape = (16, 16) |
| |
| class CustomOp1(torch.autograd.Function): |
| @staticmethod |
| def forward(ctx, x): |
| ctx.save_for_backward(x) |
| a = torch.rand_like(x) * x |
| a = torch.rand_like(x) * a |
| return a |
| |
| @staticmethod |
| def backward(ctx, grad_out): |
| x, = ctx.saved_tensors |
| return grad_out * torch.rand_like(grad_out) * torch.cos(x) |
| |
| class CustomOp2(torch.autograd.Function): |
| @staticmethod |
| def forward(ctx, x): |
| ctx.save_for_backward(x) |
| a = torch.rand_like(x) * x |
| return a |
| |
| @staticmethod |
| def backward(ctx, grad_out): |
| x, = ctx.saved_tensors |
| return grad_out * torch.rand_like(grad_out) * torch.rand_like(x) |
| |
| |
| custom_op1 = CustomOp1.apply |
| custom_op2 = CustomOp2.apply |
| |
| def fn(x): |
| a = custom_op1(x) |
| b = a.sin() |
| return custom_op2(b) |
| |
        # CustomOp1 has two rand_like calls in its forward and one in its backward.
        fwd_compiler = functools.partial(count_philox_rand, freq=2)
        bwd_compiler = functools.partial(count_philox_rand, freq=1)
        aot_custom_op1 = aot_function(custom_op1, fwd_compiler, bwd_compiler)
        # CustomOp2 has one rand_like call in its forward and two in its backward.
        fwd_compiler = functools.partial(count_philox_rand, freq=1)
        bwd_compiler = functools.partial(count_philox_rand, freq=2)
        aot_custom_op2 = aot_function(custom_op2, fwd_compiler, bwd_compiler)
| |
| def aot_fn(x): |
| a = aot_custom_op1(x) |
| b = a.sin() |
| return aot_custom_op2(b) |
| |
| |
| for seed in range(10): |
| torch.cuda.manual_seed(seed) |
| x = torch.rand(*shape, device=device, dtype=dtype, requires_grad=True) |
| x_clone = x.clone().detach().requires_grad_(True) |
| |
| torch.cuda.manual_seed(seed) |
| ref = fn(x) |
| ref.sum().backward() |
| |
| torch.cuda.manual_seed(seed) |
| res = aot_fn(x_clone) |
| res.sum().backward() |
| |
| self.assertEqual(ref, res) |
| self.assertEqual(x.grad, x_clone.grad) |
| |
| @dtypes(torch.float32) |
| @patch.object(torch._functorch.config, "functionalize_rng_ops", True) |
| def test_set_get_rng_state(self, dtype, device): |
| def fn(x): |
| a = torch.rand_like(x) * x |
| state = torch.cuda.get_rng_state() |
| a = torch.rand_like(x) * a |
| torch.cuda.set_rng_state(state) |
| a = torch.rand_like(x) * a |
| return a |
| |
| x = torch.rand(10, device=device, dtype=dtype) |
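        # All three rand_like calls, including the one issued after restoring the RNG
        # state, should lower to philox_rand ops (hence freq=3 below).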
| |
| for seed in range(10): |
| torch.cuda.manual_seed(seed) |
| ref = fn(x) |
| |
| torch.cuda.manual_seed(seed) |
| fwd_compiler = functools.partial(count_philox_rand, freq=3) |
| aot_fn = aot_function(fn, fwd_compiler) |
| res = aot_fn(x) |
| |
| self.assertEqual(ref, res) |
| |
| @dtypes(torch.float32) |
| @patch.object(torch._functorch.config, "functionalize_rng_ops", True) |
| def test_min_cut_partitioner(self, dtype, device): |
        # Checks that the calling convention is maintained and that the min-cut
        # partitioner does not recompute the random ops in the backward graph
        # (hence freq=0 below).
| shape = (16, 16) |
| |
| def fn(x): |
| a = torch.rand_like(x) * x |
| a = torch.rand_like(x) * a |
| a = torch.sin(a) |
| a = torch.sin(a) |
| a = torch.sin(a) |
| return a |
| |
| |
| x = torch.rand(*shape, device=device, dtype=dtype, requires_grad=True) |
| |
| x_clone = x.clone().detach().requires_grad_(True) |
| |
| torch.cuda.manual_seed(123) |
| ref = fn(x) |
| ref.sum().backward() |
| |
| torch.cuda.manual_seed(123) |
| fwd_compiler = functools.partial(count_philox_rand, freq=2) |
| bwd_compiler = functools.partial(count_philox_rand, freq=0) |
| aot_custom = aot_function(fn, fwd_compiler, bwd_compiler, partition_fn=min_cut_rematerialization_partition) |
| res = aot_custom(x_clone) |
| res.sum().backward() |
| |
| self.assertEqual(ref, res) |
| self.assertEqual(x.grad, x_clone.grad) |
| |
| # TODO - Dropout needs more work because of offset calculation |
| @patch.object(torch._functorch.config, "functionalize_rng_ops", True) |
| @dtypes(torch.float32) |
| def test_checkpoint(self, dtype, device): |
| def g(x, y): |
| return torch.nn.functional.dropout(x, 0.6) |
| |
| def fn(x, y): |
| return torch.utils.checkpoint.checkpoint(g, x, y, use_reentrant=False) |
| |
        x = torch.ones(2, 2, device=device, dtype=dtype, requires_grad=True)
        y = torch.rand(2, 2, device=device, dtype=dtype, requires_grad=True)
| torch.cuda.manual_seed(123) |
| ref = fn(x, y) |
| |
        # With checkpointing, dropout is recomputed in the backward pass, so a
        # philox_rand op should appear in both the forward and backward graphs.
| fwd_compiler = functools.partial(count_philox_rand, freq=1) |
| bwd_compiler = functools.partial(count_philox_rand, freq=1) |
| aot_fn = aot_function(fn, fwd_compiler, bwd_compiler) |
        # We can't check accuracy here because the rand_like-based decomposition
        # generates different random numbers than eager dropout.
| res = aot_fn(x, y) |
| res.sum().backward() |
| |
| @dtypes(torch.float32) |
| @patch.object(torch._functorch.config, "functionalize_rng_ops", True) |
| def test_dropout_decomp(self, dtype, device): |
| def fn(x): |
| return torch.nn.functional.dropout(x, 0.6) * x |
| |
| x = torch.rand(10, device=device, dtype=dtype) |
| |
| # Ensure the decomp is happening |
| aot_fn = aot_function(fn, functools.partial(count_philox_rand, freq=1)) |
        # We can't check accuracy here because the rand_like-based decomposition
        # generates different random numbers than eager dropout.
| aot_fn(x) |
| |
| |
| only_for = ("cuda",) |
| instantiate_device_type_tests(TestFunctionalizationRngOps, globals(), only_for=only_for) |
| |
| |
| class NegativeTest(TestCase): |
| @dtypes(torch.float32) |
| @patch.object(torch._functorch.config, "functionalize_rng_ops", True) |
| def test_on_cpu(self, dtype, device): |
| def fn(x): |
| a = torch.rand_like(x) * x |
| a = torch.rand_like(x) * a |
| return a |
| |
| x = torch.rand(10, device=device, dtype=dtype) |
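        # RNG functionalization is exercised only on CUDA above; on CPU, compiling
        # this function is expected to raise a RuntimeError.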
| |
| aot_fn = aot_function(fn, nop) |
| with self.assertRaises(RuntimeError): |
| aot_fn(x) |
| |
| |
| only_for = ("cpu",) |
| instantiate_device_type_tests(NegativeTest, globals(), only_for=only_for) |
| |
| if __name__ == "__main__": |
| run_tests() |