# Owner(s): ["module: cuda graphs"]
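# Tests for the TorchDynamo "cudagraphs" backend: each test compiles a small
# function with torch._dynamo.optimize("cudagraphs") and runs it for several
# iterations against CUDA inputs.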

import functools
import unittest

import torch

import torch._dynamo
import torch._dynamo.config
import torch._dynamo.test_case
import torch._dynamo.testing
import torch._dynamo.utils
from torch._dynamo.testing import same
from torch.testing._internal.common_utils import TEST_CUDA_GRAPH


def composed(*decs):
    # Fold several decorators into one; reversed() preserves the order they
    # would apply in if written as a stack of @-decorators.
    def deco(f):
        for dec in reversed(decs):
            f = dec(f)
        return f

    return deco
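

# Illustrative sketch only (deco_a / deco_b are placeholder names, not used
# anywhere in this file): composed applies decorators in stacking order, so
#
#     @composed(deco_a, deco_b)
#     def f(): ...
#
# behaves like
#
#     @deco_a
#     @deco_b
#     def f(): ...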


def assert_aot_autograd_counter(ok=True):
    # Wrap a test method so that, after its body runs, the aot_autograd
    # counters are checked: with ok=True at least one compilation must have
    # succeeded and none failed; with ok=False the reverse.
    def deco(f):
        @functools.wraps(f)
        def wrap(self, *args, **kwargs):
            torch._dynamo.utils.counters.clear()
            r = f(self, *args, **kwargs)
            c_ok = torch._dynamo.utils.counters["aot_autograd"]["ok"]
            c_not_ok = torch._dynamo.utils.counters["aot_autograd"]["not_ok"]
            if ok:
                self.assertGreater(c_ok, 0)
                self.assertEqual(c_not_ok, 0)
            else:
                self.assertEqual(c_ok, 0)
                self.assertGreater(c_not_ok, 0)
            return r

        return wrap

    return deco
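

# A hypothetical standalone use of the counter check (the tests below always
# get it bundled via patch_all instead):
#
#     @assert_aot_autograd_counter(ok=True)
#     def test_something(self):
#         ...  # body expected to trigger at least one successful compile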


def patch_all(ok=True):
    # Standard decorator stack for these tests: check compiled output against
    # eager mode, enable automatic dynamic shapes, and verify the aot_autograd
    # counters afterwards.
    return composed(
        torch._dynamo.config.patch(
            verify_correctness=True, automatic_dynamic_shapes=True
        ),
        assert_aot_autograd_counter(ok),
    )
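

# Each test loops N_ITERS times so the cudagraphs backend runs well past its
# first iteration; whether a given iteration warms up, records, or replays a
# CUDA graph is a backend implementation detail these tests do not assert on.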
N_ITERS = 5


@unittest.skipIf(not torch.cuda.is_available(), "these tests require cuda")
class TestAotCudagraphs(torch._dynamo.test_case.TestCase):
    @patch_all()
    def test_basic(self):
        def model(x, y):
            return (x + y) * y

        @torch._dynamo.optimize("cudagraphs")
        def fn(x, y):
            for i in range(N_ITERS):
                loss = model(x, y).sum()
                loss.backward()

        x = torch.randn(3, device="cuda", requires_grad=True)
        y = torch.randn(3, device="cuda")
        fn(x, y)

    @patch_all()
    def test_dtoh(self):
        def model(x, y):
            a = x + y
            b = a.cpu() * 3
            return b

        @torch._dynamo.optimize("cudagraphs")
        def fn(x, y):
            for i in range(N_ITERS):
                loss = model(x, y).sum()
                loss.backward()

        x = torch.randn(3, device="cuda", requires_grad=True)
        y = torch.randn(3, device="cuda")
        fn(x, y)

    @patch_all()
    def test_htod(self):
        def model(x, y):
            a = x + y
            return a * 3

        @torch._dynamo.optimize("cudagraphs")
        def fn(x, y):
            for i in range(N_ITERS):
                loss = model(x, y).sum()
                loss.backward()

        x = torch.randn(3, device="cuda", requires_grad=True)
        y = torch.randn((), device="cpu")
        fn(x, y)

    def test_mutate_input(self):
        def model(x, y):
            y.add_(3)
            return x * y

        @torch._dynamo.optimize("cudagraphs")
        def fn(x, y):
            for i in range(N_ITERS):
                with self.subTest(i):
                    y_orig = y.clone()
                    loss = model(x, y).sum()
                    self.assertTrue(same(y, y_orig + 3))
                    loss.backward()

        x = torch.randn(3, device="cuda", requires_grad=True)
        y = torch.randn(3, device="cuda")
        fn(x, y)

    @patch_all()
    def test_mutate_constant(self):
        def model(x, y):
            c = torch.tensor(1)
            c.add_(2)
            return x * y * 0 + c

        @torch._dynamo.optimize("cudagraphs")
        def fn(x, y):
            for i in range(N_ITERS):
                with self.subTest(i):
                    loss = model(x, y).sum()
                    self.assertTrue(same(loss, torch.tensor(3.0, device="cuda")))
                    loss.backward()

        x = torch.randn(1, device="cuda", requires_grad=True)
        y = torch.randn(1, device="cuda")
        fn(x, y)

    @patch_all()
    def test_factory(self):
        def model(y):
            x = torch.zeros(3, device="cuda:0")
            x.add_(3)
            return x * y

        @torch._dynamo.optimize("cudagraphs")
        def fn(y):
            for i in range(N_ITERS):
                with self.subTest(i):
                    loss = model(y).sum()
                    loss.backward()

        y = torch.randn(3, device="cuda:0", requires_grad=True)
        fn(y)

    @patch_all()
    def test_mutated_metadata(self):
        # more tortured example at
        # https://github.com/pytorch/pytorch/issues/81385
        def model(x):
            x = x.clone()
            x.resize_(20)
            x.fill_(2)
            return x

        @torch._dynamo.optimize("cudagraphs")
        def fn(x):
            for i in range(N_ITERS):
                with self.subTest(i):
                    rx = model(x)
                    self.assertTrue(same(rx, torch.full((20,), 2.0, device="cuda:0")))

        x = torch.empty(0, device="cuda:0")
        fn(x)

    @patch_all()
    def test_dead_fill(self):
        def model(x):
            x = x.clone()
            y = x[0:0]
            x.fill_(2)
            y.fill_(3)
            return x, y

        @torch._dynamo.optimize("cudagraphs")
        def fn(x):
            for i in range(N_ITERS):
                with self.subTest(i):
                    rx, ry = model(x)
                    self.assertTrue(same(rx, torch.full((20,), 2.0, device="cuda:0")))
                    self.assertTrue(same(ry, torch.empty(0, device="cuda:0")))

        x = torch.empty(20, device="cuda:0")
        fn(x)


if __name__ == "__main__":
    from torch._dynamo.test_case import run_tests

    if not TEST_CUDA_GRAPH:
        # CUDA graphs are unavailable in this environment; skip the file.
        import sys

        sys.exit(0)

    run_tests()