# blob: 517a05e132a4a4616c134a000bf7fbd46affe371 [file] [log] [blame]
# Owner(s): ["oncall: jit"]
import torch
from torch.testing import FileCheck
from torch.testing._internal.jit_utils import JitTestCase
if __name__ == "__main__":
    # Refuse direct execution and point the user at the test/test_jit.py
    # entry point named in the message below.
    _usage = (
        "This test file is not meant to be run directly, use:\n\n"
        + "\tpython test/test_jit.py TESTNAME\n\n"
        + "instead."
    )
    raise RuntimeError(_usage)
class TestBatchMM(JitTestCase):
@staticmethod
def _get_test_tensors(n: int):
return [
torch.tensor([[1 + x, 2 + x, 3 + x], [4 + x, 5 + x, 6 + x]])
if x % 2 == 0
else torch.tensor([[1 + x, 2 + x], [3 + x, 4 + x], [5 + x, 6 + x]])
for x in range(n)
]
def test_batch_mm_no_mutation(self):
def test_batch_mm(
T1: torch.Tensor,
T2: torch.Tensor,
T3: torch.Tensor,
T4: torch.Tensor,
T5: torch.Tensor,
T6: torch.Tensor,
T7: torch.Tensor,
T8: torch.Tensor,
):
return (
torch.mm(T1, T2)
+ torch.mm(T3, T4)
+ torch.mm(T5, T6)
+ torch.mm(T7, T8)
)
test_batch_mm_scripted = torch.jit.script(test_batch_mm)
tensors = TestBatchMM._get_test_tensors(8)
expected = test_batch_mm(*tensors)
FileCheck().check_count("aten::mm", 4, exactly=True).run(
test_batch_mm_scripted.graph
)
self.run_pass("batch_mm", test_batch_mm_scripted.graph)
FileCheck().check_count("prim::MMTreeReduce", 1, exactly=True).run(
test_batch_mm_scripted.graph
)
actual = test_batch_mm_scripted(*tensors)
self.assertEqual(expected, actual, atol=1e-9, rtol=1e-9)
def test_batch_mm_permitted_mutation(self):
def test_batch_mm(
T1: torch.Tensor,
T2: torch.Tensor,
T3: torch.Tensor,
T4: torch.Tensor,
T5: torch.Tensor,
T6: torch.Tensor,
T7: torch.Tensor,
T8: torch.Tensor,
):
result = {}
result["product"] = (
torch.mm(T1, T2)
+ torch.mm(T3, T4)
+ torch.mm(T5, T6)
+ torch.mm(T7, T8)
)
result["constant"] = torch.tensor([42.0])
return result
test_batch_mm_scripted = torch.jit.script(test_batch_mm)
tensors = TestBatchMM._get_test_tensors(8)
expected = test_batch_mm(*tensors)
FileCheck().check_count("aten::mm", 4, exactly=True).run(
test_batch_mm_scripted.graph
)
self.run_pass("batch_mm", test_batch_mm_scripted.graph)
FileCheck().check_count("prim::MMTreeReduce", 1, exactly=True).run(
test_batch_mm_scripted.graph
)
actual = test_batch_mm_scripted(*tensors)
self.assertEqual(expected, actual, atol=1e-9, rtol=1e-9)
def test_batch_mm_prohibited_mutation(self):
@torch.jit.script
def test_batch_mm(n: int):
T1 = torch.zeros((n, n))
T2 = torch.zeros((n, n))
T3 = torch.zeros((n, n))
T4 = torch.zeros((n, n))
T5 = torch.zeros((n, n))
T6 = torch.zeros((n, n))
T7 = torch.zeros((n, n))
T8 = torch.zeros((n, n))
torch.relu_(T1)
result = (
torch.mm(T1, T2)
+ torch.mm(T3, T4)
+ torch.mm(T5, T6)
+ torch.mm(T7, T8)
)
return result
FileCheck().check_count("aten::mm", 4, exactly=True).run(test_batch_mm.graph)
self.run_pass("batch_mm", test_batch_mm.graph)
FileCheck().check_count("aten::mm", 4, exactly=True).check_not(
"prim::MMTreeReduce"
).run(test_batch_mm.graph)
def test_batch_mm_prohibited_mutation_multiple_adds(self):
@torch.jit.script
def test_batch_mm(n: int):
T1 = torch.zeros((n, n))
T2 = torch.zeros((n, n))
T3 = torch.zeros((n, n))
T4 = torch.zeros((n, n))
T5 = torch.zeros((n, n))
T6 = torch.zeros((n, n))
T7 = torch.zeros((n, n))
T8 = torch.zeros((n, n))
T9 = torch.zeros((n, n))
T10 = torch.zeros((n, n))
torch.relu_(T1)
result = {}
result["no_mutated_parameters"] = (
torch.mm(T2, T3)
+ torch.mm(T4, T5)
+ torch.mm(T6, T7)
+ torch.mm(T8, T9)
)
result["all_parameters"] = (
torch.mm(T1, T2)
+ torch.mm(T3, T4)
+ torch.mm(T5, T6)
+ torch.mm(T7, T8)
+ torch.mm(T9, T10)
)
return result
self.run_pass("batch_mm", test_batch_mm.graph)
FileCheck().check_count("prim::MMTreeReduce", 1, exactly=True).check_count(
"aten::mm", 5, exactly=True
).run(test_batch_mm.graph)
def test_batch_mm_prohibited_mutation_if_node(self):
@torch.jit.script
def test_batch_mm(n: int, use_t1: bool):
T1 = torch.zeros((n, n))
T2 = torch.zeros((n, n))
T3 = torch.zeros((n, n))
T4 = torch.zeros((n, n))
T5 = torch.zeros((n, n))
T6 = torch.zeros((n, n))
T7 = torch.zeros((n, n))
T8 = torch.zeros((n, n))
T9 = torch.zeros((n, n))
T10 = torch.zeros((n, n))
if use_t1:
torch.relu_(T1)
return (
torch.mm(T1, T2)
+ torch.mm(T3, T4)
+ torch.mm(T5, T6)
+ torch.mm(T7, T8)
+ torch.mm(T9, T10)
)
else:
return (
torch.mm(T2, T3)
+ torch.mm(T4, T5)
+ torch.mm(T6, T7)
+ torch.mm(T8, T9)
)
self.run_pass("batch_mm", test_batch_mm.graph)
FileCheck().check_count("aten::mm", 5, exactly=True).check_count(
"prim::MMTreeReduce", 1, exactly=True
).run(test_batch_mm.graph)
def test_batch_mm_side_permitted_mutation(self):
@torch.jit.script
def test_batch_mm(n: int):
result = {}
A = torch.zeros((n, n))
T1 = torch.zeros((n, n))
T2 = torch.zeros((n, n))
T3 = torch.zeros((n, n))
T4 = torch.zeros((n, n))
T5 = torch.zeros((n, n))
T6 = torch.zeros((n, n))
T7 = torch.zeros((n, n))
T8 = torch.zeros((n, n))
result["T1"] = torch.mm(A, T1)
result["T2"] = torch.mm(A, T2)
result["T3"] = torch.mm(A, T3)
result["T4"] = torch.mm(A, T4)
result["T5"] = torch.mm(A, T5)
result["T6"] = torch.mm(A, T6)
result["T7"] = torch.mm(A, T7)
result["T8"] = torch.mm(A, T8)
return result
FileCheck().check_count("aten::mm", 8, exactly=True).run(test_batch_mm.graph)
self.run_pass("batch_mm", test_batch_mm.graph)
FileCheck().check_count("prim::MMBatchSide", 1, exactly=True).check_not(
"aten::mm"
).run(test_batch_mm.graph)
def test_batch_mm_side_prohibited_mutation_uncommon_side(self):
@torch.jit.script
def test_batch_mm(n: int):
A = torch.zeros((n, n))
T1 = torch.zeros((n, n))
T2 = torch.zeros((n, n))
T3 = torch.zeros((n, n))
T4 = torch.zeros((n, n))
T5 = torch.zeros((n, n))
T6 = torch.zeros((n, n))
T7 = torch.zeros((n, n))
T8 = torch.zeros((n, n))
T9 = torch.zeros((n, n))
T10 = torch.zeros((n, n))
torch.relu_(T1)
result = {}
result["T1"] = torch.mm(A, T1)
result["T2"] = torch.mm(A, T2)
result["T3"] = torch.mm(A, T3)
result["T4"] = torch.mm(A, T4)
result["T5"] = torch.mm(A, T5)
result["T6"] = torch.mm(A, T6)
result["T7"] = torch.mm(A, T7)
result["T8"] = torch.mm(A, T8)
result["T9"] = torch.mm(A, T9)
result["T10"] = torch.mm(A, T10)
return result
FileCheck().check_count("aten::mm", 10, exactly=True).run(test_batch_mm.graph)
self.run_pass("batch_mm", test_batch_mm.graph)
FileCheck().check_count("aten::mm", 1, exactly=True).run(test_batch_mm.graph)
FileCheck().check_count("prim::MMBatchSide", 1, exactly=True).run(
test_batch_mm.graph
)
def test_batch_mm_side_prohibited_mutation_common_side(self):
@torch.jit.script
def test_batch_mm(n: int):
A = torch.zeros((n, n))
T1 = torch.zeros((n, n))
T2 = torch.zeros((n, n))
T3 = torch.zeros((n, n))
T4 = torch.zeros((n, n))
T5 = torch.zeros((n, n))
T6 = torch.zeros((n, n))
T7 = torch.zeros((n, n))
T8 = torch.zeros((n, n))
T9 = torch.zeros((n, n))
T10 = torch.zeros((n, n))
torch.relu_(A)
result = {}
result["T1"] = torch.mm(A, T1)
result["T2"] = torch.mm(A, T2)
result["T3"] = torch.mm(A, T3)
result["T4"] = torch.mm(A, T4)
result["T5"] = torch.mm(A, T5)
result["T6"] = torch.mm(A, T6)
result["T7"] = torch.mm(A, T7)
result["T8"] = torch.mm(A, T8)
result["T9"] = torch.mm(A, T9)
result["T10"] = torch.mm(A, T10)
return result
FileCheck().check_count("aten::mm", 10, exactly=True).run(test_batch_mm.graph)
self.run_pass("batch_mm", test_batch_mm.graph)
FileCheck().check_count("aten::mm", 10, exactly=True).check_not(
"prim::MMBatchSide"
).run(test_batch_mm.graph)