# Owner(s): ["module: dynamo"]

import unittest

import torch
from functorch import make_fx
from torch._dynamo import debug_utils
from torch._dynamo.debug_utils import aot_graph_input_parser
from torch._dynamo.test_case import TestCase
from torch.testing._internal.inductor_utils import HAS_CUDA

requires_cuda = unittest.skipUnless(HAS_CUDA, "requires cuda")
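
# Shorthand dtype aliases matching the abbreviations used in AOT graph dump
# annotations (e.g. "f32[1001, 6]"). Defining them at module scope appears to
# be what keeps flake8 from flagging those annotation strings as undefined
# names (F821) in the pasted graphs below.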
f32 = torch.float32
i64 = torch.int64
i32 = torch.int32


class TestDebugUtils(TestCase):
    def test_cast_model_to_fp64_dtype_args(self):
        # Test that dtype arguments are converted to fp64
        def fn(x):
            return (
                torch.ops.prims.convert_element_type(x, torch.float16),
                x.to(torch.float16),
                torch.full(x.shape, 2, dtype=torch.float32, device=x.device),
                x.new_empty(x.shape),
            )

        x = torch.randn(32, device="cpu")
        decomps = torch._decomp.core_aten_decompositions()
        fx = make_fx(fn, decomposition_table=decomps)(x)

        self.assertExpectedInline(
            fx.code.lstrip(),
            """\
def forward(self, x_1):
    convert_element_type = torch.ops.prims.convert_element_type.default(x_1, torch.float16)
    _to_copy = torch.ops.aten._to_copy.default(x_1, dtype = torch.float16);  x_1 = None
    full = torch.ops.aten.full.default([32], 2, dtype = torch.float32, device = device(type='cpu'), pin_memory = False)
    empty = torch.ops.aten.empty.memory_format([32], dtype = torch.float32, layout = torch.strided, device = device(type='cpu'), pin_memory = False)
    return (convert_element_type, _to_copy, full, empty)
""",  # NOQA: B950
        )
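
        # cast_to_fp64 rewrites the graph in place (note that the assertion
        # below re-checks fx.code rather than fp64_model.code): dtype arguments
        # on ops like convert_element_type, _to_copy, full, and empty are
        # promoted to float64, and the example inputs are cast to match. The
        # accuracy minifier uses this to build an fp64 reference run.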
        fp64_model, fp64_examples = debug_utils.cast_to_fp64(fx, (x,))
        self.assertEqual(fp64_examples, (x.to(torch.float64),))

        self.assertExpectedInline(
            fx.code.lstrip(),
            """\
def forward(self, x_1):
    convert_element_type = torch.ops.prims.convert_element_type.default(x_1, torch.float64)
    _to_copy = torch.ops.aten._to_copy.default(x_1, dtype = torch.float64);  x_1 = None
    full = torch.ops.aten.full.default([32], 2, dtype = torch.float64, device = device(type='cpu'), pin_memory = False)
    empty = torch.ops.aten.empty.memory_format([32], dtype = torch.float64, layout = torch.strided, device = device(type='cpu'), pin_memory = False)
    return (convert_element_type, _to_copy, full, empty)
""",  # NOQA: B950
        )

    @requires_cuda
    def test_aot_graph_parser(self):
        from torch import device
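
        # The body of `forward` below is an AOT autograd graph dump pasted in
        # verbatim; every value is annotated with a dtype/shape string such as
        # "f32[6144, 4190]", which is what aot_graph_input_parser consumes.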
        def forward(
            self,
            primals_1: "f32[1001, 6]",
            primals_2: "f32[1001]",
            primals_3: "f32[1001, 64]",
            primals_4: "f32[4190]",
            primals_5: "f32[4190]",
            primals_6: "f32[1739, 4190]",
            primals_48: "f32[6144, 4191]",
        ):
            _tensor_constant0: "i64[4190]" = self._tensor_constant0
            lift_fresh_copy: "i64[4190]" = torch.ops.aten.lift_fresh_copy.default(
                _tensor_constant0
            )
            _tensor_constant0 = None
            index: "f32[6144, 4190]" = torch.ops.aten.index.Tensor(
                primals_48, [None, lift_fresh_copy]
            )
            lift_fresh_copy = None
            _tensor_constant1: "i64[6]" = self._tensor_constant1
            lift_fresh_copy_1: "i64[6]" = torch.ops.aten.lift_fresh_copy.default(
                _tensor_constant1
            )
            _tensor_constant1 = None
            index_1: "f32[6144, 6]" = torch.ops.aten.index.Tensor(
                primals_48, [None, lift_fresh_copy_1]
            )
            primals_48 = lift_fresh_copy_1 = None
            permute: "f32[6, 1001]" = torch.ops.aten.permute.default(primals_1, [1, 0])
            primals_1 = None
            addmm: "f32[6144, 1001]" = torch.ops.aten.addmm.default(
                primals_2, index_1, permute
            )
            primals_2 = permute = None
            amax: "f32[6144, 1]" = torch.ops.aten.amax.default(addmm, [-1], True)
            sub: "f32[6144, 1001]" = torch.ops.aten.sub.Tensor(addmm, amax)
            exp: "f32[6144, 1001]" = torch.ops.aten.exp.default(sub)
            sub = None
            sum_1: "f32[6144, 1]" = torch.ops.aten.sum.dim_IntList(exp, [-1], True)
            div: "f32[6144, 1001]" = torch.ops.aten.div.Tensor(exp, sum_1)
            exp = None
            full_default: "i32[6144, 1001]" = torch.ops.aten.full.default(
                [6144, 1001],
                1,
                dtype=torch.int32,
                layout=torch.strided,
                device=device(type="cuda", index=0),
                pin_memory=False,
            )
            iota: "i32[1001]" = torch.ops.prims.iota.default(
                1001,
                start=0,
                step=1,
                dtype=torch.int32,
                device=device(type="cuda"),
                requires_grad=False,
            )
            mul: "i32[6144, 1001]" = torch.ops.aten.mul.Tensor(full_default, iota)
            full_default = iota = None
            iota_1: "i32[6144]" = torch.ops.prims.iota.default(
                6144,
                start=0,
                step=1001,
                dtype=torch.int32,
                device=device(type="cuda", index=0),
                requires_grad=False,
            )
            view: "i32[6150144]" = torch.ops.aten.reshape.default(mul, [-1])
            mul = None
            view_1: "f32[6150144]" = torch.ops.aten.reshape.default(div, [-1])
            div = None
            _embedding_bag = torch.ops.aten._embedding_bag.default(
                primals_3, view, iota_1, False, 0, False, view_1
            )
            return _embedding_bag
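
        # aot_graph_input_parser reads the dtype/shape annotations on forward's
        # parameters and fabricates matching tensors on the requested device;
        # judging by this test it also satisfies the self._tensor_constant0 /
        # self._tensor_constant1 attribute reads, since forward(**kwargs)
        # executes without error.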
        kwargs = aot_graph_input_parser(forward, device="cuda")
        # runs successfully
        forward(**kwargs)

    @requires_cuda
    def test_sym_aot_graph_parser(self):
        def forward(
            self,
            primals_1: "f32[1001, 6]",  # noqa: F821
            primals_2: "f32[s0]",  # noqa: F821
            primals_3: "Sym(s0)",  # noqa: F821
            primals_4: "f32[s1]",  # noqa: F821
            primals_5: "Sym(s1)",  # noqa: F821
        ):
            _tensor_constant0: "i64[4190]" = self._tensor_constant0

        kwargs = aot_graph_input_parser(
            forward, device="cuda", sym_shapes={"s0": 10}, default_sym_shape=5
        )
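
        # sym_shapes pins named symbolic dimensions (here s0 -> 10); any symbol
        # without an entry, like s1, falls back to default_sym_shape (5).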
        self.assertEqual(list(kwargs["primals_2"].shape), [10])
        self.assertEqual(kwargs["primals_3"], 10)
        self.assertEqual(list(kwargs["primals_4"].shape), [5])
        self.assertEqual(kwargs["primals_5"], 5)


if __name__ == "__main__":
    from torch._dynamo.test_case import run_tests

    run_tests()