| # Owner(s): ["module: inductor"] |
| import math |
| import sys |
| import unittest |
| |
| import torch |
| import torch._dynamo.config as dynamo_config |
| from torch import nn |
| from torch._dynamo.debug_utils import same_two_models |
| from torch._dynamo.testing import rand_strided |
| from torch._dynamo.utils import same |
| from torch._inductor import config |
| from torch._inductor.compile_fx import compile_fx_inner |
| from torch.fx.experimental.proxy_tensor import make_fx |
| from torch.testing._internal.common_utils import ( |
| DeterministicGuard, |
| IS_FBCODE, |
| TEST_WITH_ASAN, |
| ) |
| |
| try: |
| try: |
| import triton |
| from triton import language as tl |
| except ImportError: |
| raise unittest.SkipTest("requires triton") |
| |
| try: |
| from . import test_torchinductor |
| except ImportError: |
| import test_torchinductor |
| except unittest.SkipTest: |
| if __name__ == "__main__": |
| sys.exit(0) |
| raise |
| |
| |
| TestCase = test_torchinductor.TestCase |
| ToTuple = test_torchinductor.ToTuple |
| check_model_cuda = test_torchinductor.check_model_cuda |
| aten = torch.ops.aten |
| |
| |
| class CudaReproTests(TestCase): |
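    # check_model_cuda compiles each model with Inductor, runs it against eager on
    # CUDA, and compares the results (also in reduced precision unless
    # check_lowp=False is passed).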
| common = check_model_cuda |
| |
| def test_index_put_issue(self): |
| def forward( |
| self, |
| arg76_1, |
| expand_default, |
| full_like_default, |
| _to_copy_default_67, |
| zeros, |
| ): |
| sum_sym_int_19 = torch.ops.aten.sum(_to_copy_default_67, [0], True) |
| view_default_57 = torch.ops.aten.view.default(sum_sym_int_19, [512, 768]) |
| where_self = torch.ops.aten.where.self( |
| expand_default, view_default_57, full_like_default |
| ) |
| clone_default_12 = torch.ops.aten.clone.default(zeros) |
| index_put__default = torch.ops.aten.index_put_.default( |
| clone_default_12, [arg76_1], where_self, True |
| ) |
| return (index_put__default,) |
| |
| inps = [ |
| (torch.Size([512]), torch.int64), |
| (torch.Size([512, 768]), torch.bool), |
| (torch.Size([512, 768]), torch.float16), |
| (torch.Size([4, 512, 768]), torch.float16), |
| (torch.Size([512, 768]), torch.float16), |
| ] |
| inps = [torch.zeros(())] + [ |
| torch.ones(shape, dtype=dtype, device="cuda") for (shape, dtype) in inps |
| ] |
| mod = make_fx(forward)(*inps) |
| compiled = compile_fx_inner(mod, inps) |
| compiled(inps) |
| |
| def test_input_channels_last(self): |
| m = torch.nn.Sequential( |
| torch.nn.Conv2d(3, 3, 1, 1), |
| ToTuple(), |
| ).cuda() |
| inp = torch.randn([2, 3, 16, 16]).to(memory_format=torch.channels_last).cuda() |
| |
| self.common( |
| m, |
| (inp,), |
| check_lowp=False, |
| ) |
| |
| @torch._dynamo.optimize() |
| def foo(m, inp): |
| return m(inp) |
| |
        self.assertTrue(
            foo(m, inp)[0].is_contiguous(memory_format=torch.channels_last)
        )
| |
| # https://github.com/pytorch/torchdynamo/issues/1681#issuecomment-1283433527 |
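    # Mixing a CPU scalar ("unspec" input) with CUDA tensors in a single graph used
    # to trip up Inductor's input handling; see the issue above.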
| def test_unspec_inputs_interop(self): |
| class Repro(torch.nn.Module): |
| def forward(self, x, y): |
| unsqueeze = torch.ops.aten.unsqueeze.default(x, 4) |
| permute = torch.ops.aten.permute.default(unsqueeze, [0, 1, 2, 4, 3]) |
| add = torch.ops.aten.add.Tensor(y, 1) |
| return [permute, add] |
| |
| inps = [ |
| rand_strided((12, 3, 512, 64), (64, 196608, 768, 1), torch.float32, "cuda"), |
| rand_strided((), (), torch.int64, "cpu"), |
| ] |
| mod = make_fx(Repro().to(device="cuda"))(*inps) |
| compiled = compile_fx_inner(mod, inps) |
| compiled(inps) |
| |
| @unittest.skipIf( |
| IS_FBCODE, "RuntimeError: Triton Error [CUDA]: invalid device context" |
| ) |
| def test_backward_context(self): |
| def fn(x): |
| return x * 3 |
| |
| x = torch.randn(4, device="cuda", requires_grad=True) |
| gO = torch.rand_like(x) |
| opt_fn = torch.compile(fn) |
| out = opt_fn(x) |
| out.backward(gO) |
| |
| @config.patch(fallback_random=True) |
| def test_dtype_factory_issue(self): |
| def forward(): |
| randn = torch.ops.aten.randn.default( |
| [12, 64, 1, 64], |
| dtype=torch.float32, |
| device=torch.device(type="cuda", index=0), |
| pin_memory=False, |
| ) |
| unsqueeze_default_2 = torch.ops.aten.unsqueeze.default(randn, -1) |
| return (unsqueeze_default_2,) |
| |
| mod = make_fx(forward)() |
| compiled = compile_fx_inner(mod, ()) |
| assert compiled([])[0].device.type == "cuda" |
| |
| @config.patch({"triton.cudagraphs": True}) |
| @dynamo_config.patch(automatic_dynamic_shapes=True) |
| def test_no_device_idx_repro_cudagraphs(self): |
| class Repro(torch.nn.Module): |
| def __init__(self): |
| super().__init__() |
| |
| def forward(self): |
| full = torch.ops.aten.full.default( |
| [8, 512], |
| 1, |
| dtype=torch.float32, |
| layout=torch.strided, |
| device=torch.device(type="cuda", index=0), |
| pin_memory=False, |
| ) |
| full_1 = torch.ops.aten.full.default( |
| [8, 512], |
| 0, |
| dtype=torch.int64, |
| layout=torch.strided, |
| device=torch.device(type="cuda", index=0), |
| pin_memory=False, |
| ) |
| return (full_1, full) |
| |
| self.common(Repro(), ()) |
| |
| @config.patch({"triton.cudagraphs": True}) |
| @dynamo_config.patch(automatic_dynamic_shapes=True) |
| def test_expanded_inputs_cudagraphs(self): |
| @torch._dynamo.optimize("inductor") |
| def fn(x, y): |
| return x + y |
| |
| inputs = ( |
| rand_strided((5, 5, 5, 5), (0, 5, 0, 1), device="cuda"), |
| rand_strided((5, 5, 5, 5), (0, 5, 0, 1), device="cuda"), |
| ) |
| self.assertTrue(same(fn(*inputs), inputs[0] + inputs[1])) |
| |
| @config.patch({"triton.cudagraphs": True}) |
| @dynamo_config.patch( |
| automatic_dynamic_shapes=True, |
| assume_static_by_default=False, |
| ) |
| def test_dynamic_to_static_cudagraphs(self): |
| for b in [False, True]: |
| with config.patch({"triton.cudagraph_trees": b}): |
| |
| @torch._dynamo.optimize("inductor") |
| def fn(x, y): |
| r = x + y |
| return r, r.size(0) |
| |
| inputs = ( |
| torch.randn((5, 5), device="cuda"), |
| torch.randn((5, 5), device="cuda"), |
| ) |
| self.assertTrue(same(fn(*inputs), (inputs[0] + inputs[1], 5))) |
| |
| inputs = ( |
| torch.randn((6, 6), device="cuda"), |
| torch.randn((6, 6), device="cuda"), |
| ) |
| self.assertTrue(same(fn(*inputs), (inputs[0] + inputs[1], 6))) |
| |
| # TODO: Abstract this out, test more extensively |
| @torch._dynamo.config.patch(assume_static_by_default=False) |
| def test_dynamic_shapes(self): |
| torch._dynamo.reset() # Needed since everywhere else uses "inductor" |
| |
| def f(x): |
| return x.cos().view(x.shape).sin() |
| |
| cnts = torch._dynamo.testing.CompileCounterWithBackend("inductor") |
| |
| f2 = torch._dynamo.optimize(cnts)(f) |
| |
| f2(torch.randn(32)) |
| |
| inp = torch.randn(16) |
| real_out = f(inp) |
| compiled_out = f2(inp) |
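        # frame_count == 1 means the 16-element call reused the graph compiled for
        # the 32-element call instead of triggering a recompile.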
| |
| self.assertEqual(cnts.frame_count, 1) |
| self.assertEqual(real_out, compiled_out) |
| torch._dynamo.reset() |
| |
| @config.patch({"triton.cudagraphs": True, "size_asserts": False}) |
| @dynamo_config.patch(automatic_dynamic_shapes=True) |
| def test_expanded_inputs_cudagraphs_no_size_asserts(self): |
| @torch._dynamo.optimize("inductor") |
| def fn(x, y): |
| return x + y |
| |
| inputs = ( |
| rand_strided((5, 5, 5, 5), (0, 5, 0, 1), device="cuda"), |
| rand_strided((5, 5, 5, 5), (0, 5, 0, 1), device="cuda"), |
| ) |
| self.assertTrue(same(fn(*inputs), inputs[0] + inputs[1])) |
| |
| @config.patch({"triton.cudagraph_trees": False}) |
| @config.patch({"triton.cudagraphs": True}) |
| @dynamo_config.patch(automatic_dynamic_shapes=True) |
| def test_inplace_updates_cudagraphs(self): |
| class Repro(torch.nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.weight1 = torch.nn.Parameter( |
| torch.randn(10, 20, requires_grad=True) |
| ) |
| |
| def forward(self, x): |
| x = torch.matmul(x, self.weight1) |
| return x |
| |
| from copy import deepcopy |
| |
| model = Repro().cuda() |
| model_ref = deepcopy(model) |
| model_opt = torch._dynamo.optimize("inductor")(model) |
| |
| input = torch.randn(10, 10, device="cuda", requires_grad=True) |
| |
| for i in range(2): |
| output_ref = model_ref(input) |
| output_res = model_opt(input) |
| output_ref.sum().backward() |
| output_res.sum().backward() |
| for p_ref, p_res in zip(model_ref.parameters(), model_opt.parameters()): |
| self.assertEqual(p_ref.grad, p_res.grad) |
| with torch.no_grad(): |
| for param in model_ref.parameters(): |
| param.add_(1.0) |
| for param in model_opt.parameters(): |
| param.add_(1.0) |
| |
| # https://github.com/pytorch/torchdynamo/issues/1850 |
| def test_inductor_output_aliases_intermediate(self): |
| def foo(x): |
| out = x + x |
| return out.t() |
| |
| foo_opt = torch._dynamo.optimize("inductor")(foo) |
| |
| inpt = torch.randn(10, 10, device="cuda", requires_grad=True) |
| # TODO: this is broken, fix later |
| # out = foo_opt(inpt) |
| # out.add_(2) |
| |
| out_ref = foo(inpt) |
| out_ref.add_(2) |
| # self.assertEqual(out_ref, out) |
| |
| def test_accuracy_issue1(self): |
| class Repro(torch.nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.linear = torch.nn.Linear( |
| in_features=768, out_features=2, bias=True |
| ) |
| |
| def forward(self, start_positions: torch.Tensor, x: torch.Tensor): |
| linear = self.linear(x) |
| split = linear.split(1, dim=-1) |
| getitem = split[0] |
| squeeze = getitem.squeeze(-1) |
| clamp = start_positions.clamp(0, 128) |
| cross_entropy = torch.nn.functional.cross_entropy( |
| squeeze, clamp, None, None, 128, None, "mean", 0.0 |
| ) |
| return cross_entropy |
| |
| mod = Repro().cuda() |
| opt_mod = torch._dynamo.optimize("inductor")(mod) |
| mod.eval() |
| opt_mod.eval() |
| |
| args = [ |
| ((1,), (1,), torch.int64, "cuda", False), |
| ((1, 128, 768), (98304, 768, 1), torch.float32, "cuda", True), |
| ] |
| args = [ |
| rand_strided(sh, st, dt, dev).requires_grad_(rg) |
| for (sh, st, dt, dev, rg) in args |
| ] |
| with torch.cuda.amp.autocast(enabled=False): |
| assert same_two_models(mod, opt_mod, args), "Dynamo failed" |
| |
| @config.patch(allow_buffer_reuse=False) |
| def test_issue103461(self): |
| def forward(add_1): |
| var_mean = torch.ops.aten.var_mean.correction( |
| add_1, [2], correction=0, keepdim=True |
| ) |
| getitem_1 = var_mean[1] |
| return getitem_1 |
| |
| x = torch.randn(1, 8, 768, device="cuda") |
| correct = forward(x) |
| actual = torch.compile(forward, fullgraph=True)(x) |
| self.assertEqual(actual, correct) |
| |
| def test_autotune_inplace_kernel(self): |
| """ |
| This UT tests autotune on an inplace kernel. The autotune should not contaminate |
| the input buffers when tuning with multiple configs. For more details, refer to |
| https://github.com/openai/triton/issues/781 |
| https://github.com/pytorch/torchdynamo/issues/1670 |
| """ |
| from torch._C import _cuda_getCurrentRawStream as get_cuda_stream |
| from torch._inductor.triton_heuristics import ( |
| CachingAutotuner, |
| grid, |
| HeuristicType, |
| ) |
| from torch._inductor.utils import instance_descriptor |
| |
| def autotune(configs, meta): |
| def decorator(fn): |
                return CachingAutotuner(
                    fn,
                    meta=meta,
                    configs=configs,
                    # force autotuning by disabling the save-cache hook
                    save_cache_hook=False,
                    mutated_arg_names=["in_out_ptr0"],
                    heuristic_type=HeuristicType.POINTWISE,
                )
| |
| return decorator |
| |
| @autotune( |
| configs=[ |
| triton.Config({"XBLOCK": 1}), |
| triton.Config({"XBLOCK": 2}), |
| ], |
| meta={ |
| "signature": {0: "*fp32", 1: "*fp32", 2: "i32"}, |
| "device": 0, |
| "configs": [instance_descriptor(divisible_by_16=(0, 1), equal_to_1=())], |
| "constants": {}, |
| }, |
| ) |
| @triton.jit |
| def kernel(in_out_ptr0, in_ptr0, xnumel, XBLOCK: tl.constexpr): |
| pid = tl.program_id(0) |
| block_start = pid * XBLOCK |
| offsets = block_start + tl.arange(0, XBLOCK) |
| mask = offsets < xnumel |
| x = tl.load(in_out_ptr0 + offsets, mask=mask) |
| y = tl.load(in_ptr0 + offsets, mask=mask) |
| output = x + y |
| tl.store(in_out_ptr0 + offsets, output, mask=mask) |
| |
| xnumel = 384 |
| in0 = rand_strided((xnumel,), (1,), device="cuda", dtype=torch.float32) |
| inout1 = rand_strided((xnumel,), (1,), device="cuda", dtype=torch.float32) |
| inout2 = inout1.clone() |
| |
| stream0 = get_cuda_stream(0) |
| kernel.run(inout1, in0, xnumel, grid=grid(xnumel), stream=stream0) |
| kernel.run(inout2, in0, xnumel, grid=grid(xnumel), stream=stream0) |
| |
| assert same( |
| inout1, inout2, tol=0.001, equal_nan=True |
| ), "failed autotune with inplace kernel" |
| |
| def test_sort_stride_issue(self): |
        # This minified test case comes from detectron2_maskrcnn_r_50_fpn;
        # our size-assert code used to raise a spurious error on it.
| @torch._dynamo.optimize(nopython=True) |
| def forward(pred_objectness_logits_3_: torch.Tensor): |
| sort_3 = pred_objectness_logits_3_.sort(descending=True, dim=1) |
| getitem_12 = sort_3[0] |
| return getitem_12 |
| |
| args = [((1, 100), (0, 1), torch.float16, "cuda", False)] |
| args = [ |
| rand_strided(sh, st, dt, dev).requires_grad_(rg) |
| for (sh, st, dt, dev, rg) in args |
| ] |
| result = forward(*args) |
| assert same(result, torch.sort(args[0], descending=True, dim=1)[0]) |
| |
| def test_scalar_triton_index(self): |
        # Indirect indexing via a scalar tensor like the one below used to generate
        # bad Triton code that made the Triton compiler segfault.
| # See https://github.com/pytorch/torchdynamo/issues/1515 |
| def fn(a): |
| zero = torch.zeros((16,), device=a.device, dtype=torch.int64) |
| return (a[zero],) |
| |
| a = torch.randn((8,), dtype=torch.float32, device="cuda") |
| |
| fn_optimized = torch._dynamo.optimize("inductor")(fn) |
| assert same(fn(a), fn_optimized(a)) |
| |
| def test_indirect_indexing_dense_mask(self): |
| def fn(x, y): |
| ne = torch.ops.aten.ne.Scalar(x, 1) |
| sum_1 = torch.ops.aten.sum.dim_IntList(ne, [1]) |
| sub = torch.ops.aten.sub.Tensor(sum_1, 1) |
| unsqueeze = torch.ops.aten.unsqueeze.default(sub, -1) |
| gather = torch.ops.aten.gather.default(x, 1, unsqueeze) |
| squeeze = torch.ops.aten.squeeze.default(gather) |
| out = torch.ops.aten.multiply(y, squeeze) |
| return (out,) |
| |
| a = torch.zeros((1, 128), dtype=torch.int64, device="cuda") |
| b = torch.zeros((1, 128), dtype=torch.int64, device="cuda") |
| |
| fn_optimized = torch._dynamo.optimize("inductor")(fn) |
| assert same(fn(a, b), fn_optimized(a, b)) |
| |
| def test_simplify_dims(self): |
| def fn(a): |
| return (a + 1,) |
| |
        self.common(
            fn, (torch.randn(2, 3, 10, 5, 6, device="cuda")[:, :, 2::2, :, :],)
        )
| |
| @config.patch(permute_fusion=True) |
| def test_permute_fusion(self): |
| class Repro(torch.nn.Module): |
| def forward(self, view, reshape_2): |
| permute = view.permute(0, 2, 1) |
| view = None |
| reshape = torch.reshape(permute, (-1, 642)) |
| bmm = torch.bmm(permute, reshape_2) |
| return (bmm,) |
| |
| args = [ |
| ((1024, 642, 160), (102720, 160, 1), torch.float32, "cuda", True), |
| ((1024, 642, 20), (12840, 20, 1), torch.float32, "cuda", True), |
| ] |
| args = [ |
| rand_strided(sh, st, dt, dev).requires_grad_(rg) |
| for (sh, st, dt, dev, rg) in args |
| ] |
| |
| mod = Repro() |
| opt_mod = torch._dynamo.optimize("inductor")(mod) |
| |
| ref = mod(*args) |
| res = opt_mod(*args) |
| self.assertTrue(same(ref, res)) |
| |
| @config.patch({"triton.autotune_pointwise": True}) |
| def test_inplace_add_alpha_autotune(self): |
| def fn(x, y): |
| aten.add_.Tensor(x, y, alpha=0.55) |
| return (x,) |
| |
| x1 = torch.zeros(2, 3, 4, 10, device="cuda") |
| x2 = torch.zeros(2, 3, 4, 10, device="cuda") |
| x3 = torch.zeros(2, 3, 4, 10, device="cuda") |
| y = torch.randn(2, 3, 4, 10, device="cuda").to( |
| memory_format=torch.channels_last |
| ) |
| fn_fx = make_fx(fn)(x1, y) |
| fn_compiled = compile_fx_inner(fn_fx, [x1, y]) |
| fn(x2, y) |
| fn_compiled([x3, y]) |
| assert same(x2, x3) |
| |
| @config.patch({"triton.autotune_pointwise": True}) |
| def test_inplace_buffer_autotune(self): |
| def foo(x, y, z): |
| a = x @ y |
| return a.unsqueeze(0).unsqueeze(0) + z |
| |
| x = torch.zeros(5, 5, device="cuda") |
| y = torch.zeros(5, 5, device="cuda") |
        z = torch.zeros(1, 1, 5, 5, device="cuda").to(
            memory_format=torch.channels_last
        )
| self.common( |
| foo, |
| (x, y, z), |
| check_lowp=False, |
| ) |
| |
| def test_memory_history_inductor(self): |
| def called_inside_compile(x, w, b): |
| a = x @ w + b |
| return torch.sigmoid(a) |
| |
| @torch.compile |
| def fn(x, w, b): |
| x = called_inside_compile(x, w, b) |
| return called_inside_compile(x, w, b) |
| |
| w = torch.rand(3, 3, device="cuda") |
| b = torch.rand(3, device="cuda") |
| x = torch.rand(3, device="cuda") |
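        # Record allocation history while the compiled fn runs, then check that the
        # name of the Python function called inside the compiled region still shows
        # up in the snapshot's stack traces.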
| try: |
| torch.cuda.memory.empty_cache() |
| torch.cuda.memory._record_memory_history(True) |
| r = fn(x, w, b) |
| finally: |
| torch.cuda.memory._record_memory_history(False) |
| snapshot = str(torch.cuda.memory._snapshot()) |
| self.assertTrue("called_inside_compile" in snapshot) |
| |
| def test_negative_arange_dynamic_shapes(self): |
| # Repro from alibi relative encodings |
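        # The Repro module below builds the ALiBi slopes via torch.arange with a
        # negative step (start > end); that arange under dynamic shapes is what
        # this repro exercises.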
| def sign(x): |
| return (x > 0) - (x < 0) |
| |
| class Repro(torch.nn.Module): |
| def __init__(self): |
| super().__init__() |
| nheads = 16 |
| start = math.log2(0.5) |
| end = math.log2(1 / (2**8)) |
| |
| self.scales = nn.Buffer( |
| 2 |
| ** torch.arange( |
| start, |
| end + 1e-6 * sign(end - start), |
| (end - start) / (nheads - 1), |
| ).view(1, nheads, 1, 1), |
| ) |
| self.emb = nn.Embedding(1024, 256) |
| self.dec_layer = nn.TransformerDecoderLayer( |
| 256, 16, 512, batch_first=True, norm_first=True |
| ) |
| self.head = nn.Linear(256, 1024) |
| |
| def forward(self, enc_out: torch.Tensor, dec_in: torch.Tensor): |
| padmask = dec_in == 0 |
| dec_mask = padmask.unsqueeze(-1) == padmask.unsqueeze(-2) |
| dec_mask = dec_mask.to(dtype=torch.float32) |
| dec_mask = dec_mask.tril(diagonal=0).cuda() |
| |
| q_pos = torch.arange(dec_in.size(1), dtype=torch.long, device="cuda") |
| k_pos = torch.arange(dec_in.size(1), dtype=torch.long, device="cuda") |
| rel_pos = k_pos[None, :] - q_pos[:, None] |
| values = rel_pos.abs().neg().unsqueeze(0).unsqueeze(0) |
| dec_bias = values * self.scales |
| dec_bias.tril_(diagonal=0) |
| |
| dec_mask = dec_mask + dec_bias[0] |
| out = self.emb(dec_in) |
| out = self.dec_layer(out, enc_out, tgt_mask=dec_mask) |
| return self.head(out) |
| |
| mod = Repro().cuda() |
| opt_mod = torch._dynamo.optimize("inductor", dynamic=True)(mod) |
| mod.eval() |
| opt_mod.eval() |
| |
| enc_out = torch.rand(1, 512, 256).cuda() |
        dec_inputs = [
            torch.randint(0, 512, (1, i + 1), dtype=torch.long).cuda()
            for i in range(8)
        ]
| |
| for dec_inp in dec_inputs: |
| assert same_two_models( |
| mod, opt_mod, [enc_out, dec_inp], only_fwd=True |
| ), "Inductor with dynamic shapes failed" |
| |
| def test_issue97695_1input(self): |
| def fn(arg3_1, relu, permute_1): |
| addmm_1 = torch.ops.aten.addmm.default(arg3_1, relu, permute_1) |
| cat_2 = torch.ops.aten.cat.default([addmm_1], 1) |
| return (cat_2,) |
| |
| args = [ |
| ((96,), (1,), torch.float32, "cuda"), |
| ((10, 256), (256, 1), torch.float32, "cuda"), |
| ((256, 96), (1, 256), torch.float32, "cuda"), |
| ] |
| args = [rand_strided(sh, st, dt, dev) for (sh, st, dt, dev) in args] |
| correct = fn(*args) |
| |
| mod = make_fx(fn, tracing_mode="real")(*args) |
| compiled = compile_fx_inner(mod, args) |
| ref = compiled(list(args)) |
| assert same(ref, correct) |
| |
| ref = torch.compile(fn, fullgraph=True)(*args) |
| assert same(ref, correct) |
| |
| def test_issue_103924(self): |
| class MyModule(torch.nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.temperature = 1 |
| self.layer = torch.nn.Softmax(dim=1) |
| |
| def forward(self, x): |
| n_samples, _ = x.shape |
| y = 1.0 * torch.ones(n_samples, dtype=x.dtype, device=x.device) |
| inp = x / y[..., None] |
| return self.layer(inp) |
| |
| x = torch.rand([4, 4], device="cuda") |
| m = MyModule() |
| opt_m = torch.compile(backend="inductor")(m) |
| self.assertEqual(opt_m(x), m(x)) |
| |
| def test_issue97695_2input(self): |
| def fn(arg3_1, arg3_2, relu, permute_1): |
| addmm_1 = torch.ops.aten.addmm.default(arg3_1, relu, permute_1) |
| addmm_2 = torch.ops.aten.addmm.default(arg3_2, relu, permute_1) |
| cat_2 = torch.ops.aten.cat.default([addmm_1, addmm_2], 1) |
| return (cat_2,) |
| |
| args = [ |
| ((96,), (1,), torch.float32, "cuda"), |
| ((96,), (1,), torch.float32, "cuda"), |
| ((10, 256), (256, 1), torch.float32, "cuda"), |
| ((256, 96), (1, 256), torch.float32, "cuda"), |
| ] |
| args = [rand_strided(sh, st, dt, dev) for (sh, st, dt, dev) in args] |
| correct = fn(*args) |
| |
| ref = torch.compile(fn, fullgraph=True)(*args) |
| assert same(ref, correct) |
| |
| def test_embedding_var_mean(self): |
| def forward(arg0_1): |
| full = torch.ops.aten.full.default( |
| [1, 2048], |
| 1, |
| dtype=torch.float32, |
| layout=torch.strided, |
| device=torch.device(type="cuda", index=0), |
| pin_memory=False, |
| ) |
| convert_element_type_1 = torch.ops.prims.convert_element_type.default( |
| full, torch.int64 |
| ) |
| cumsum = torch.ops.aten.cumsum.default(convert_element_type_1, 1) |
| mul = torch.ops.aten.mul.Tensor(cumsum, convert_element_type_1) |
| sub_1 = torch.ops.aten.sub.Tensor(mul, 1) |
| slice_5 = torch.ops.aten.slice.Tensor(sub_1, 0, 0, 9223372036854775807) |
| slice_6 = torch.ops.aten.slice.Tensor(slice_5, 1, 0, 9223372036854775807) |
| add_2 = torch.ops.aten.add.Tensor(slice_6, 2) |
| embedding_1 = torch.ops.aten.embedding.default(arg0_1, add_2) |
| var_mean = torch.ops.aten.var_mean.correction( |
| embedding_1, [2], correction=0, keepdim=True |
| ) |
| return [var_mean[0], var_mean[1], add_2] |
| |
| emb = torch.randn([2050, 768], device="cuda") |
| gm = make_fx(forward)(emb) |
| opt = torch._inductor.compile_fx.compile_fx_inner(gm, [emb]) |
| opt([emb]) |
| torch.cuda.synchronize() |
| |
| def test_deterministic_algorithms(self): |
| N = 10000 |
| |
| @torch.compile |
| def fn(idx, values): |
| x = torch.zeros(1, device="cuda") |
| x[idx] += values |
| return x |
| |
| idx = torch.zeros(N, dtype=torch.int64, device="cuda") |
| values = torch.randn(N, device="cuda") |
| |
| r0 = fn(idx, values) |
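        # Scatter-add with duplicate indices is nondeterministic on CUDA by
        # default; with deterministic algorithms enabled, every rerun must match
        # r1 exactly.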
| with DeterministicGuard(True): |
| r1 = fn(idx, values) |
| for _ in range(10): |
| rn = fn(idx, values) |
| self.assertEqual(r1, rn, atol=0, rtol=0) |
| |
| # https://github.com/pytorch/pytorch/issues/96406 |
| def test_linear_cpu_input(self): |
| class Model(nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.linear = nn.Linear(4, 4) |
| |
| def forward(self, data): |
| data = data.to("cuda") |
| return self.linear(data) |
| |
| mod = Model().cuda().eval() |
| with torch.no_grad(): |
| self.common(mod, (torch.randn(4, 4),)) |
| |
| @config.patch({"fallback_random": True, "triton.cudagraphs": True}) |
| def test_xlnet_lm_stride_repro(self): |
| class Repro(nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.dropout = nn.Dropout(p=0.1, inplace=False) |
| |
| def forward(self, x): |
| y = torch._C._nn.gelu(x) |
| return self.dropout(y) |
| |
| mod = Repro() |
| x = torch.randn((512, 1, 4096), requires_grad=True, device="cuda") |
| y = torch.compile(mod)(x) |
        # Inductor claims the output layout of gelu's saved variable for the
        # backward pass will be (4096, 4096, 1), but it is actually
        # (4096, 2097152, 1). Fortunately, this does not matter in practice.
| y.sum().backward() |
| |
| def test_lookup_seed_backward(self): |
| @torch.compile(fullgraph=True) |
| def forward(inductor_seeds, mul_4, view_15): |
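            # Minified graph fragment: the first argument is Inductor's seed
            # buffer, consumed via inductor_lookup_seed/inductor_random to build a
            # dropout mask, followed by a layer-norm style var_mean.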
| inductor_lookup_seed_2 = torch.ops.prims.inductor_lookup_seed.default( |
| inductor_seeds, 2 |
| ) |
| inductor_random_2 = torch.ops.prims.inductor_random.default( |
| [2, 512, 768], inductor_lookup_seed_2, "rand" |
| ) |
| gt_2 = torch.ops.aten.gt.Scalar(inductor_random_2, 0.1) |
| mul_7 = torch.ops.aten.mul.Tensor(gt_2, view_15) |
| mul_8 = torch.ops.aten.mul.Tensor(mul_7, 1.1111111111111112) |
| add_5 = torch.ops.aten.add.Tensor(mul_8, mul_4) |
| var_mean_1 = torch.ops.aten.var_mean.correction( |
| add_5, [2], correction=0, keepdim=True |
| ) |
| getitem_3 = var_mean_1[1] |
| sub_3 = torch.ops.aten.sub.Tensor(add_5, getitem_3) |
| return (sub_3,) |
| |
| buf0 = torch.zeros((37,), dtype=torch.int64, device="cuda") |
| buf1 = torch.zeros((2, 512, 768), device="cuda") |
| buf2 = torch.zeros((2, 512, 768), device="cuda") |
| forward(buf0, buf1, buf2) |
| |
| def test_issue100806(self): |
| class Model(torch.nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.linear1 = torch.nn.Linear(10, 20) |
| self.linear2 = torch.nn.Linear(20, 30) |
| self.relu = torch.nn.ReLU() |
| |
| def forward(self, x): |
| x = self.linear1(x) |
| x = self.linear2(x) |
| x = torch.cat((x, x), dim=1) |
| x = x.view(-1, 2, 30) |
| x = x[:, 1, :] |
| x = self.relu(x) |
| return x |
| |
| device = "cuda" |
| batch_size = 2 |
| x = torch.randn(batch_size, 10).to(device) |
| func = Model().to(device) |
| |
| with torch.no_grad(): |
| func.train(False) |
| jit_func = torch.compile(func) |
| |
| res1 = func(x) |
| res2 = jit_func(x) |
| self.assertEqual(res1, res2) |
| |
| def test_issue103481(self): |
| def fn(x, y): |
            # NOTE: six dimensions are important here; the issue does not reproduce
            # with five dimensions
| mean = torch.mean(x, [2, 3, 4, 5], keepdim=True) |
| add = mean + y |
| return add |
| |
| x = torch.rand(4, 4, 4, 4, 4, 4, device="cuda") |
| y = torch.rand((), device="cuda") |
| expect = fn(x, y) |
| |
| opt_fn = torch.compile(fn) |
| actual = opt_fn(x, y) |
| |
| self.assertEqual(expect, actual) |
| |
| @config.patch({"triton.dense_indexing": True}) |
| @dynamo_config.patch(automatic_dynamic_shapes=True) |
| def test_bucketize_dynamic_dense(self): |
| """ |
        Make sure that ops.bucketize() can handle dense_indexing, which previously
        caused issues because the size of the offsets tensor was handled incorrectly.
| """ |
| |
| def fn(values, offsets): |
| return torch.ops.prims._inductor_bucketize(values, offsets) |
| |
| values = torch.rand((64, 64), device="cuda") |
| offsets = torch.tensor([0.05, 0.1, 0.5, 0.8, 0.85, 0.95], device="cuda") |
| |
| expect = fn(values, offsets) |
| |
| opt_fn = torch.compile(fn, dynamic=True) |
| actual = opt_fn(values, offsets) |
| |
| self.assertEqual(expect, actual) |
| |
| def test_float64_constants(self): |
| def fn(): |
            # NOTE: tensors whose elements are all equal get constant-folded, so we
            # need a tensor with two distinct values
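            # 2e50 overflows float32 (max ~3.4e38), so the product stays finite
            # only if the constants are kept in float64 throughout.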
| a = torch.tensor([1 / 10, 2 / 10], dtype=torch.float64, device="cuda") |
| return a * 2e50 |
| |
| cfn = torch.compile(fn) |
| expect = fn() |
| actual = cfn() |
| self.assertEqual(expect, actual, atol=0, rtol=0) |
| |
| def test_issue104759(self): |
| def fn(arg7_1, add_1, permute_2, select_scatter, slice_8): |
| slice_scatter_4 = torch.ops.aten.slice_scatter.default( |
| permute_2, select_scatter, 0, 1, 9223372036854775807 |
| ) |
| permute_3 = torch.ops.aten.permute.default(slice_scatter_4, [1, 3, 0, 2, 4]) |
| view_6 = torch.ops.aten.view.default(permute_3, [1, 1000, 48]) |
| view_7 = torch.ops.aten.view.default(view_6, [1000, 48]) |
| view_8 = torch.ops.aten.view.default(view_7, [1, 1000, 48]) |
| view_9 = torch.ops.aten.view.default(view_8, [1, 1000, 3, 4, 4]) |
| permute_4 = torch.ops.aten.permute.default(view_9, [2, 0, 3, 1, 4]) |
| slice_7 = torch.ops.aten.slice.Tensor(permute_4, 0, 1, 9223372036854775807) |
| slice_scatter_5 = torch.ops.aten.slice_scatter.default( |
| slice_8, slice_7, 4, 0, 9223372036854775807 |
| ) |
| slice_scatter_6 = torch.ops.aten.slice_scatter.default( |
| arg7_1, slice_scatter_5, 3, 0, 1000 |
| ) |
| mul_8 = torch.ops.aten.mul.Scalar(add_1, 0.7071067811865476) |
| slice_9 = torch.ops.aten.slice.Tensor(slice_scatter_6, 3, 0, 1000) |
| slice_10 = torch.ops.aten.slice.Tensor(slice_9, 4, 0, 9223372036854775807) |
| select_2 = torch.ops.aten.select.int(slice_10, 0, 0) |
| permute_5 = torch.ops.aten.permute.default(select_2, [0, 1, 3, 2]) |
| mul_9 = torch.ops.aten.mul.Scalar(permute_5, 0.7071067811865476) |
| expand = torch.ops.aten.expand.default(mul_8, [1, 4, 1000, 4]) |
| view_10 = torch.ops.aten.view.default(expand, [4, 1000, 4]) |
| expand_1 = torch.ops.aten.expand.default(mul_9, [1, 4, 4, 1000]) |
| view_11 = torch.ops.aten.view.default(expand_1, [4, 4, 1000]) |
| bmm = torch.ops.aten.bmm.default(view_10, view_11) |
| return (bmm,) |
| |
| args = [] |
| args.append(torch.randn((2, 1, 4, 1200, 4), dtype=torch.float16, device="cuda")) |
| args.append( |
| rand_strided( |
| (1, 4, 1000, 4), (16000, 4, 16, 1), dtype=torch.float16, device="cuda" |
| ) |
| ) |
| args.append( |
| rand_strided( |
| (3, 1, 4, 1000, 4), |
| (16, 48000, 4, 48, 1), |
| dtype=torch.float16, |
| device="cuda", |
| ) |
| ) |
| args.append( |
| rand_strided( |
| (2, 1, 4, 1000, 4), |
| (16, 48000, 4, 48, 1), |
| dtype=torch.float16, |
| device="cuda", |
| ) |
| ) |
| args.append( |
| rand_strided( |
| (2, 1, 4, 1000, 4), |
| (19200, 19200, 4800, 4, 1), |
| dtype=torch.float16, |
| device="cuda", |
| ) |
| ) |
| |
| correct = fn(*args) |
| mod = make_fx(fn, tracing_mode="real")(*args) |
| compiled = compile_fx_inner(mod, args) |
| ref = compiled(list(args)) |
| assert same(ref, correct) |
| |
| @config.patch({"triton.cudagraphs": True}) |
| def test_index_put_inplace_cudagraph(self): |
| def fn(x, y, z): |
| x = torch.zeros_like(x) |
| return x.index_put_([y], z, True) |
| |
| x = torch.zeros((512, 512), device="cuda", dtype=torch.bool) |
| y = torch.zeros((512,), device="cuda", dtype=torch.int64) |
| z = torch.ones((512, 512), device="cuda", dtype=torch.bool) |
| |
| opt_fn = torch._dynamo.optimize("inductor")(fn) |
| |
| ref = fn(x, y, z) |
| |
        # run it twice so the second call exercises the CUDA graph path
| res = opt_fn(x, y, z) |
| res = opt_fn(x, y, z) |
| |
| self.assertEqual(ref, res) |
| |
| @config.patch({"triton.cudagraphs": True}) |
| def test_index_put_cudagraph(self): |
| def fn(x, y, z): |
| x = torch.zeros_like(x) |
| return x.index_put([y], z, True) |
| |
| x = torch.zeros((512, 512), device="cuda", dtype=torch.bool) |
| y = torch.zeros((512,), device="cuda", dtype=torch.int64) |
| z = torch.ones((512, 512), device="cuda", dtype=torch.bool) |
| |
| opt_fn = torch._dynamo.optimize("inductor")(fn) |
| |
| ref = fn(x, y, z) |
| |
        # run it twice so the second call exercises the CUDA graph path
| res = opt_fn(x, y, z) |
| res = opt_fn(x, y, z) |
| |
| self.assertEqual(ref, res) |
| |
| @config.patch({"triton.cudagraphs": True}) |
| def test_index_put_no_fallback_cudagraph(self): |
| def fn(x, y, z): |
| x = torch.zeros_like(x) |
| return x.index_put([y], z, True) |
| |
| x = torch.zeros((512, 512), device="cuda", dtype=torch.int32) |
| y = torch.zeros((512,), device="cuda", dtype=torch.int64) |
| z = torch.ones((512, 512), device="cuda", dtype=torch.int32) |
| |
| opt_fn = torch._dynamo.optimize("inductor")(fn) |
| |
| ref = fn(x, y, z) |
| |
        # run it twice so the second call exercises the CUDA graph path
| res = opt_fn(x, y, z) |
| res = opt_fn(x, y, z) |
| |
| self.assertEqual(ref, res) |
| |
| # https://github.com/pytorch/pytorch/issues/104937 |
| def test_linear_with_zero_infeature_size(self): |
| m = nn.Linear(in_features=0, out_features=0, bias=True).to("cuda") |
| x = torch.rand(1, 1, 0, device="cuda") |
| expect = m(x) |
| opt_fn = torch.compile(m) |
| actual = opt_fn(x) |
| self.assertEqual(expect, actual) |
| |
| |
| if __name__ == "__main__": |
| from torch._dynamo.test_case import run_tests |
| from torch.testing._internal.inductor_utils import HAS_CUDA |
| |
| if HAS_CUDA and not TEST_WITH_ASAN: |
| run_tests(needs="filelock") |