| # Owner(s): ["module: dynamo"] |
| |
| import unittest |
| |
| import torch |
| from functorch import make_fx |
| from torch._dynamo import debug_utils |
| from torch._dynamo.debug_utils import aot_graph_input_parser |
| from torch._dynamo.test_case import TestCase |
| from torch.testing._internal.inductor_utils import HAS_CUDA |
| |
| requires_cuda = unittest.skipUnless(HAS_CUDA, "requires cuda") |
| |
| f32 = torch.float32 |
| i64 = torch.int64 |
| i32 = torch.int32 |
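# Shorthand aliases matching the dtype prefixes used in AOT graph annotations
# below (e.g. "f32[1001, 6]"); defining them keeps the string annotations from
# tripping undefined-name lint checks.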
| |
| |
| class TestDebugUtils(TestCase): |
| def test_cast_model_to_fp64_dtype_args(self): |
        # Test that dtype arguments embedded in the traced graph are converted to fp64
| |
| def fn(x): |
| return ( |
| torch.ops.prims.convert_element_type(x, torch.float16), |
| x.to(torch.float16), |
| torch.full(x.shape, 2, dtype=torch.float32, device=x.device), |
| x.new_empty(x.shape), |
| ) |
| |
| x = torch.randn(32, device="cpu") |
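        # Trace fn into an FX graph with core ATen decompositions applied, so
        # the dtype arguments appear as explicit, rewritable node arguments.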
| decomps = torch._decomp.core_aten_decompositions() |
| fx = make_fx(fn, decomposition_table=decomps)(x) |
| |
| self.assertExpectedInline( |
| fx.code.lstrip(), |
| """\ |
| def forward(self, x_1): |
| convert_element_type = torch.ops.prims.convert_element_type.default(x_1, torch.float16) |
| _to_copy = torch.ops.aten._to_copy.default(x_1, dtype = torch.float16); x_1 = None |
| full = torch.ops.aten.full.default([32], 2, dtype = torch.float32, device = device(type='cpu'), pin_memory = False) |
| empty = torch.ops.aten.empty.memory_format([32], dtype = torch.float32, layout = torch.strided, device = device(type='cpu'), pin_memory = False) |
| return (convert_element_type, _to_copy, full, empty) |
| """, # NOQA: B950 |
| ) |
| |
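        # cast_to_fp64 casts the example inputs to fp64 and updates the traced
        # graph in place, so re-reading fx.code below reflects the rewrite.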
| fp64_model, fp64_examples = debug_utils.cast_to_fp64(fx, (x,)) |
| self.assertEqual(fp64_examples, (x.to(torch.float64),)) |
| |
| self.assertExpectedInline( |
| fx.code.lstrip(), |
| """\ |
| def forward(self, x_1): |
| convert_element_type = torch.ops.prims.convert_element_type.default(x_1, torch.float64) |
| _to_copy = torch.ops.aten._to_copy.default(x_1, dtype = torch.float64); x_1 = None |
| full = torch.ops.aten.full.default([32], 2, dtype = torch.float64, device = device(type='cpu'), pin_memory = False) |
| empty = torch.ops.aten.empty.memory_format([32], dtype = torch.float64, layout = torch.strided, device = device(type='cpu'), pin_memory = False) |
| return (convert_element_type, _to_copy, full, empty) |
| """, # NOQA: B950 |
| ) |
| |
| @requires_cuda |
| def test_aot_graph_parser(self): |
| from torch import device |
| |
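        # An AOT Autograd graph dump pasted in as a plain function; its string
        # annotations carry the dtype and shape information the parser needs.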
| def forward( |
| self, |
| primals_1: "f32[1001, 6]", |
| primals_2: "f32[1001]", |
| primals_3: "f32[1001, 64]", |
| primals_4: "f32[4190]", |
| primals_5: "f32[4190]", |
| primals_6: "f32[1739, 4190]", |
| primals_48: "f32[6144, 4191]", |
| ): |
| _tensor_constant0: "i64[4190]" = self._tensor_constant0 |
| lift_fresh_copy: "i64[4190]" = torch.ops.aten.lift_fresh_copy.default( |
| _tensor_constant0 |
| ) |
| _tensor_constant0 = None |
| index: "f32[6144, 4190]" = torch.ops.aten.index.Tensor( |
| primals_48, [None, lift_fresh_copy] |
| ) |
| lift_fresh_copy = None |
| |
| _tensor_constant1: "i64[6]" = self._tensor_constant1 |
| lift_fresh_copy_1: "i64[6]" = torch.ops.aten.lift_fresh_copy.default( |
| _tensor_constant1 |
| ) |
| _tensor_constant1 = None |
| index_1: "f32[6144, 6]" = torch.ops.aten.index.Tensor( |
| primals_48, [None, lift_fresh_copy_1] |
| ) |
| primals_48 = lift_fresh_copy_1 = None |
| permute: "f32[6, 1001]" = torch.ops.aten.permute.default(primals_1, [1, 0]) |
| primals_1 = None |
| addmm: "f32[6144, 1001]" = torch.ops.aten.addmm.default( |
| primals_2, index_1, permute |
| ) |
| primals_2 = permute = None |
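            # Numerically stable softmax over the last dim: subtract the row
            # max, exponentiate, then normalize by the row sum.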
| amax: "f32[6144, 1]" = torch.ops.aten.amax.default(addmm, [-1], True) |
| sub: "f32[6144, 1001]" = torch.ops.aten.sub.Tensor(addmm, amax) |
| exp: "f32[6144, 1001]" = torch.ops.aten.exp.default(sub) |
| sub = None |
| sum_1: "f32[6144, 1]" = torch.ops.aten.sum.dim_IntList(exp, [-1], True) |
| div: "f32[6144, 1001]" = torch.ops.aten.div.Tensor(exp, sum_1) |
| exp = None |
| |
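            # Multiplying a matrix of ones by iota [0..1000] broadcasts the
            # column indices across all 6144 rows.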
| full_default: "i32[6144, 1001]" = torch.ops.aten.full.default( |
| [6144, 1001], |
| 1, |
| dtype=torch.int32, |
| layout=torch.strided, |
| device=device(type="cuda", index=0), |
| pin_memory=False, |
| ) |
| |
| iota: "i32[1001]" = torch.ops.prims.iota.default( |
| 1001, |
| start=0, |
| step=1, |
| dtype=torch.int32, |
| device=device(type="cuda"), |
| requires_grad=False, |
| ) |
| |
| mul: "i32[6144, 1001]" = torch.ops.aten.mul.Tensor(full_default, iota) |
| full_default = iota = None |
| |
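            # Bag offsets [0, 1001, 2002, ...]: 6144 bags of 1001 entries each
            # (6144 * 1001 = 6150144, the length of the flattened views below).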
| iota_1: "i32[6144]" = torch.ops.prims.iota.default( |
| 6144, |
| start=0, |
| step=1001, |
| dtype=torch.int32, |
| device=device(type="cuda", index=0), |
| requires_grad=False, |
| ) |
| view: "i32[6150144]" = torch.ops.aten.reshape.default(mul, [-1]) |
| mul = None |
| view_1: "f32[6150144]" = torch.ops.aten.reshape.default(div, [-1]) |
| div = None |
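            # aten._embedding_bag(weight, indices, offsets, scale_grad_by_freq,
            # mode, sparse, per_sample_weights); mode 0 selects "sum".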
| _embedding_bag = torch.ops.aten._embedding_bag.default( |
| primals_3, view, iota_1, False, 0, False, view_1 |
| ) |
| |
| return _embedding_bag |
| |
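        # aot_graph_input_parser reads the dtype/shape annotations (including
        # the self._tensor_constant* accesses in the body) and fabricates
        # matching tensors so the dumped graph can run standalone.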
| kwargs = aot_graph_input_parser(forward, device="cuda") |
        # The fabricated inputs should let the dumped graph run end to end
| forward(**kwargs) |
| |
| @requires_cuda |
| def test_sym_aot_graph_parser(self): |
| def forward( |
| self, |
| primals_1: "f32[1001, 6]", # noqa: F821 |
| primals_2: "f32[s0]", # noqa: F821 |
| primals_3: "Sym(s0)", # noqa: F821, |
| primals_4: "f32[s1]", # noqa: F821, |
| primals_5: "Sym(s1)", # noqa: F821, |
| ): |
| _tensor_constant0: "i64[4190]" = self._tensor_constant0 |
| |
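        # sym_shapes pins named symbolic dims (s0 -> 10); symbols not listed,
        # like s1, fall back to default_sym_shape (5).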
| kwargs = aot_graph_input_parser( |
| forward, device="cuda", sym_shapes={"s0": 10}, default_sym_shape=5 |
| ) |
| |
| self.assertEqual(list(kwargs["primals_2"].shape), [10]) |
| self.assertEqual(kwargs["primals_3"], 10) |
| |
| self.assertEqual(list(kwargs["primals_4"].shape), [5]) |
| self.assertEqual(kwargs["primals_5"], 5) |
| |
| |
| if __name__ == "__main__": |
| from torch._dynamo.test_case import run_tests |
| |
| run_tests() |