blob: 8062782e06959d592277c3e93d72b97f7b8a8519 [file] [log] [blame]
# Owner(s): ["module: inductor"]
import torch
from torch._dynamo.test_case import run_tests, TestCase
from torch.testing._internal.common_utils import IS_LINUX
from torch.testing._internal.inductor_utils import HAS_CUDA
# Shorthand for the ``torch.ops.aten`` operator namespace (used for ``aten.alias``).
aten = torch.ops.aten
# 0-dim tensor holding 0.0, created without an explicit device argument; the
# tests in this file pass it as the second argument of ``index_put``/``index_put_``.
const = torch.tensor(0.0)
# Device on which ``_test`` allocates its input tensors.
device = "cuda"
class TestReinplacingPassCorrectness(TestCase):
    """Compare eager and ``torch.compile`` execution of small ``index_put`` programs.

    Every test defines a local function of ``(x, y)`` and hands it to
    ``_test``, which does the actual comparison.
    """

    def _test(self, f):
        """Run ``f`` eagerly and compiled on equal inputs and compare the outcomes.

        Args:
            f: Callable taking two tensors: a random float tensor with 4
                elements and an int tensor of 2 ones, both on ``device``.

        Both the returned values and the argument tensors after the call are
        compared, so a difference in how the arguments were mutated also
        fails the test.
        """
        compiled_f = torch.compile(f)
        eager_args = (
            torch.randn(4, device=device),
            torch.ones(2, device=device, dtype=torch.int),
        )
        # Separate copies, so a mutation made by one run is invisible to the other.
        compiled_args = tuple(arg.clone() for arg in eager_args)

        # Eager run first, compiled run second.
        expected = f(*eager_args)
        actual = compiled_f(*compiled_args)
        self.assertEqual(expected, actual)
        self.assertEqual(eager_args, compiled_args)

    def test_dont_modify_live(self):
        # The tensor fed to the out-of-place index_put is itself returned,
        # alongside the index_put result.
        def f(x, y):
            cos_x = x.cos()
            updated = cos_x.index_put((y,), const)
            return updated, cos_x

        self._test(f)

    def test_dont_modify_view_of_live(self):
        # index_put is applied to an alias of a tensor that is read again
        # afterwards (by the second ``cos``).
        def f(x, y):
            cos_x = x.cos()
            aliased = aten.alias(cos_x)
            updated = aliased.index_put((y,), const)
            return updated + cos_x.cos()

        self._test(f)

    def test_dont_modify_input(self):
        # Out-of-place index_put applied directly to the function argument.
        def f(x, y):
            updated = x.index_put((y,), const)
            return updated

        self._test(f)

    def test_should_modify_inner(self):
        # Out-of-place index_put on an intermediate that is not used elsewhere.
        def f(x, y):
            return x.cos().index_put((y,), const)

        self._test(f)

    def test_should_modify_input(self):
        # Explicit in-place index_put_ on the function argument.
        def f(x, y):
            return x.index_put_((y,), const)

        self._test(f)
# Launch the suite only when executed as a script and both IS_LINUX and
# HAS_CUDA are truthy; otherwise do nothing.
if __name__ == "__main__" and IS_LINUX and HAS_CUDA:
    run_tests()