# Owner(s): ["module: inductor"]

import torch

from torch._dynamo import config as dynamo_config
from torch._inductor import config as inductor_config
from torch.testing._internal.common_utils import IS_LINUX, TestCase as TorchTestCase
from torch.testing._internal.inductor_utils import HAS_CUDA


class TestUnbackedSymints(TorchTestCase):
    def test_expand(self):
        def fn(x, y):
            nz = torch.nonzero(x)
            # unbacked symint in nz.size
            x_exp = nz.expand([-1, 128])
            # unbacked symint in target sizes
            y_exp = y.expand([-1, nz.size(0)])
            return x_exp, y_exp

        example_inputs = (
            torch.randn((32), device="cuda"),
            torch.randn((32, 1), device="cuda"),
        )

        # nonzero() has a data-dependent output shape; this flag lets dynamo
        # capture it as an unbacked symint instead of graph-breaking.
        with dynamo_config.patch({"capture_dynamic_output_shape_ops": True}):
            actual = torch.compile(fn, fullgraph=True)(*example_inputs)
            expected = fn(*example_inputs)

        torch.testing.assert_close(actual, expected)
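
    # Autotuned (max-autotune) Triton GEMM where one input dimension of the
    # matmul is an unbacked symint produced by nonzero().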
    def test_autotuning(self):
        def fn(x, y):
            nz = torch.nonzero(x)
            # unbacked symint in the GEMM input shape
            a = x.new_ones([nz.size(0), y.size(0)])
            return a @ y

        example_inputs = (
            torch.randn((64), device="cuda"),
            torch.randn((32, 16), device="cuda"),
        )

        with dynamo_config.patch({"capture_dynamic_output_shape_ops": True}):
            with inductor_config.patch(
                {
                    "max_autotune_gemm": True,
                    "max_autotune_gemm_backends": "TRITON",
                }
            ):
                actual = torch.compile(fn, fullgraph=True)(*example_inputs)
                expected = fn(*example_inputs)

        torch.testing.assert_close(actual, expected)
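

# The autotuning test exercises Triton GEMM templates on CUDA, so only run on
# Linux with CUDA and a sufficiently large GPU (is_big_gpu).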
if __name__ == "__main__":
    from torch._dynamo.test_case import run_tests
    from torch._inductor.utils import is_big_gpu

    if IS_LINUX and HAS_CUDA and is_big_gpu(0):
        run_tests()