# Owner(s): ["oncall: jit"]

| import torch |
| from torch.testing import FileCheck |
| from torch.testing._internal.jit_utils import JitTestCase |
| |
| if __name__ == "__main__": |
| raise RuntimeError( |
| "This test file is not meant to be run directly, use:\n\n" |
| "\tpython test/test_jit.py TESTNAME\n\n" |
| "instead." |
| ) |
| |
| |
| class TestBatchMM(JitTestCase): |
| @staticmethod |
| def _get_test_tensors(n: int): |
| return [ |
| torch.tensor([[1 + x, 2 + x, 3 + x], [4 + x, 5 + x, 6 + x]]) |
| if x % 2 == 0 |
| else torch.tensor([[1 + x, 2 + x], [3 + x, 4 + x], [5 + x, 6 + x]]) |
| for x in range(n) |
| ] |
| |
| def test_batch_mm_no_mutation(self): |
| def test_batch_mm( |
| T1: torch.Tensor, |
| T2: torch.Tensor, |
| T3: torch.Tensor, |
| T4: torch.Tensor, |
| T5: torch.Tensor, |
| T6: torch.Tensor, |
| T7: torch.Tensor, |
| T8: torch.Tensor, |
| ): |
| return ( |
| torch.mm(T1, T2) |
| + torch.mm(T3, T4) |
| + torch.mm(T5, T6) |
| + torch.mm(T7, T8) |
| ) |
| |
| test_batch_mm_scripted = torch.jit.script(test_batch_mm) |
| |
| tensors = TestBatchMM._get_test_tensors(8) |
| expected = test_batch_mm(*tensors) |
| |
| FileCheck().check_count("aten::mm", 4, exactly=True).run( |
| test_batch_mm_scripted.graph |
| ) |
| self.run_pass("batch_mm", test_batch_mm_scripted.graph) |
| FileCheck().check_count("prim::MMTreeReduce", 1, exactly=True).run( |
| test_batch_mm_scripted.graph |
| ) |
| |
| actual = test_batch_mm_scripted(*tensors) |
| self.assertEqual(expected, actual, atol=1e-9, rtol=1e-9) |
| |
| |
| def test_batch_mm_permitted_mutation(self): |
| def test_batch_mm( |
| T1: torch.Tensor, |
| T2: torch.Tensor, |
| T3: torch.Tensor, |
| T4: torch.Tensor, |
| T5: torch.Tensor, |
| T6: torch.Tensor, |
| T7: torch.Tensor, |
| T8: torch.Tensor, |
| ): |
| result = {} |
| result["product"] = ( |
| torch.mm(T1, T2) |
| + torch.mm(T3, T4) |
| + torch.mm(T5, T6) |
| + torch.mm(T7, T8) |
| ) |
| result["constant"] = torch.tensor([42.0]) |
| return result |
| |
| test_batch_mm_scripted = torch.jit.script(test_batch_mm) |
| |
| tensors = TestBatchMM._get_test_tensors(8) |
| expected = test_batch_mm(*tensors) |
| |
| FileCheck().check_count("aten::mm", 4, exactly=True).run( |
| test_batch_mm_scripted.graph |
| ) |
| self.run_pass("batch_mm", test_batch_mm_scripted.graph) |
| FileCheck().check_count("prim::MMTreeReduce", 1, exactly=True).run( |
| test_batch_mm_scripted.graph |
| ) |
| |
| actual = test_batch_mm_scripted(*tensors) |
| self.assertEqual(expected, actual, atol=1e-9, rtol=1e-9) |
| |
| def test_batch_mm_prohibited_mutation(self): |
| @torch.jit.script |
| def test_batch_mm(n: int): |
| T1 = torch.zeros((n, n)) |
| T2 = torch.zeros((n, n)) |
| T3 = torch.zeros((n, n)) |
| T4 = torch.zeros((n, n)) |
| T5 = torch.zeros((n, n)) |
| T6 = torch.zeros((n, n)) |
| T7 = torch.zeros((n, n)) |
| T8 = torch.zeros((n, n)) |
| torch.relu_(T1) |
| result = ( |
| torch.mm(T1, T2) |
| + torch.mm(T3, T4) |
| + torch.mm(T5, T6) |
| + torch.mm(T7, T8) |
| ) |
| return result |
| |
| FileCheck().check_count("aten::mm", 4, exactly=True).run(test_batch_mm.graph) |
| self.run_pass("batch_mm", test_batch_mm.graph) |
| FileCheck().check_count("aten::mm", 4, exactly=True).check_not( |
| "prim::MMTreeReduce" |
| ).run(test_batch_mm.graph) |
| |
| def test_batch_mm_prohibited_mutation_multiple_adds(self): |
| @torch.jit.script |
| def test_batch_mm(n: int): |
| T1 = torch.zeros((n, n)) |
| T2 = torch.zeros((n, n)) |
| T3 = torch.zeros((n, n)) |
| T4 = torch.zeros((n, n)) |
| T5 = torch.zeros((n, n)) |
| T6 = torch.zeros((n, n)) |
| T7 = torch.zeros((n, n)) |
| T8 = torch.zeros((n, n)) |
| T9 = torch.zeros((n, n)) |
| T10 = torch.zeros((n, n)) |
| torch.relu_(T1) |
| result = {} |
| result["no_mutated_parameters"] = ( |
| torch.mm(T2, T3) |
| + torch.mm(T4, T5) |
| + torch.mm(T6, T7) |
| + torch.mm(T8, T9) |
| ) |
| result["all_parameters"] = ( |
| torch.mm(T1, T2) |
| + torch.mm(T3, T4) |
| + torch.mm(T5, T6) |
| + torch.mm(T7, T8) |
| + torch.mm(T9, T10) |
| ) |
| return result |
| |
| self.run_pass("batch_mm", test_batch_mm.graph) |
| FileCheck().check_count("prim::MMTreeReduce", 1, exactly=True).check_count( |
| "aten::mm", 5, exactly=True |
| ).run(test_batch_mm.graph) |
| |
| def test_batch_mm_prohibited_mutation_if_node(self): |
| @torch.jit.script |
| def test_batch_mm(n: int, use_t1: bool): |
| T1 = torch.zeros((n, n)) |
| T2 = torch.zeros((n, n)) |
| T3 = torch.zeros((n, n)) |
| T4 = torch.zeros((n, n)) |
| T5 = torch.zeros((n, n)) |
| T6 = torch.zeros((n, n)) |
| T7 = torch.zeros((n, n)) |
| T8 = torch.zeros((n, n)) |
| T9 = torch.zeros((n, n)) |
| T10 = torch.zeros((n, n)) |
| if use_t1: |
| torch.relu_(T1) |
| return ( |
| torch.mm(T1, T2) |
| + torch.mm(T3, T4) |
| + torch.mm(T5, T6) |
| + torch.mm(T7, T8) |
| + torch.mm(T9, T10) |
| ) |
| else: |
| return ( |
| torch.mm(T2, T3) |
| + torch.mm(T4, T5) |
| + torch.mm(T6, T7) |
| + torch.mm(T8, T9) |
| ) |
| |
| self.run_pass("batch_mm", test_batch_mm.graph) |
| FileCheck().check_count("aten::mm", 5, exactly=True).check_count( |
| "prim::MMTreeReduce", 1, exactly=True |
| ).run(test_batch_mm.graph) |
| |
| def test_batch_mm_side_permitted_mutation(self): |
| @torch.jit.script |
| def test_batch_mm(n: int): |
| result = {} |
| A = torch.zeros((n, n)) |
| T1 = torch.zeros((n, n)) |
| T2 = torch.zeros((n, n)) |
| T3 = torch.zeros((n, n)) |
| T4 = torch.zeros((n, n)) |
| T5 = torch.zeros((n, n)) |
| T6 = torch.zeros((n, n)) |
| T7 = torch.zeros((n, n)) |
| T8 = torch.zeros((n, n)) |
| result["T1"] = torch.mm(A, T1) |
| result["T2"] = torch.mm(A, T2) |
| result["T3"] = torch.mm(A, T3) |
| result["T4"] = torch.mm(A, T4) |
| result["T5"] = torch.mm(A, T5) |
| result["T6"] = torch.mm(A, T6) |
| result["T7"] = torch.mm(A, T7) |
| result["T8"] = torch.mm(A, T8) |
| return result |
| |
| FileCheck().check_count("aten::mm", 8, exactly=True).run(test_batch_mm.graph) |
| self.run_pass("batch_mm", test_batch_mm.graph) |
| FileCheck().check_count("prim::MMBatchSide", 1, exactly=True).check_not( |
| "aten::mm" |
| ).run(test_batch_mm.graph) |
| |
| def test_batch_mm_side_prohibited_mutation_uncommon_side(self): |
| @torch.jit.script |
| def test_batch_mm(n: int): |
| A = torch.zeros((n, n)) |
| T1 = torch.zeros((n, n)) |
| T2 = torch.zeros((n, n)) |
| T3 = torch.zeros((n, n)) |
| T4 = torch.zeros((n, n)) |
| T5 = torch.zeros((n, n)) |
| T6 = torch.zeros((n, n)) |
| T7 = torch.zeros((n, n)) |
| T8 = torch.zeros((n, n)) |
| T9 = torch.zeros((n, n)) |
| T10 = torch.zeros((n, n)) |
| torch.relu_(T1) |
| result = {} |
| result["T1"] = torch.mm(A, T1) |
| result["T2"] = torch.mm(A, T2) |
| result["T3"] = torch.mm(A, T3) |
| result["T4"] = torch.mm(A, T4) |
| result["T5"] = torch.mm(A, T5) |
| result["T6"] = torch.mm(A, T6) |
| result["T7"] = torch.mm(A, T7) |
| result["T8"] = torch.mm(A, T8) |
| result["T9"] = torch.mm(A, T9) |
| result["T10"] = torch.mm(A, T10) |
| return result |
| |
| FileCheck().check_count("aten::mm", 10, exactly=True).run(test_batch_mm.graph) |
| self.run_pass("batch_mm", test_batch_mm.graph) |
| |
| FileCheck().check_count("aten::mm", 1, exactly=True).run(test_batch_mm.graph) |
| FileCheck().check_count("prim::MMBatchSide", 1, exactly=True).run( |
| test_batch_mm.graph |
| ) |
| |
| def test_batch_mm_side_prohibited_mutation_common_side(self): |
| @torch.jit.script |
| def test_batch_mm(n: int): |
| A = torch.zeros((n, n)) |
| T1 = torch.zeros((n, n)) |
| T2 = torch.zeros((n, n)) |
| T3 = torch.zeros((n, n)) |
| T4 = torch.zeros((n, n)) |
| T5 = torch.zeros((n, n)) |
| T6 = torch.zeros((n, n)) |
| T7 = torch.zeros((n, n)) |
| T8 = torch.zeros((n, n)) |
| T9 = torch.zeros((n, n)) |
| T10 = torch.zeros((n, n)) |
| torch.relu_(A) |
| result = {} |
| result["T1"] = torch.mm(A, T1) |
| result["T2"] = torch.mm(A, T2) |
| result["T3"] = torch.mm(A, T3) |
| result["T4"] = torch.mm(A, T4) |
| result["T5"] = torch.mm(A, T5) |
| result["T6"] = torch.mm(A, T6) |
| result["T7"] = torch.mm(A, T7) |
| result["T8"] = torch.mm(A, T8) |
| result["T9"] = torch.mm(A, T9) |
| result["T10"] = torch.mm(A, T10) |
| return result |
| |
| FileCheck().check_count("aten::mm", 10, exactly=True).run(test_batch_mm.graph) |
| self.run_pass("batch_mm", test_batch_mm.graph) |
| FileCheck().check_count("aten::mm", 10, exactly=True).check_not( |
| "prim::MMBatchSide" |
| ).run(test_batch_mm.graph) |