# Owner(s): ["module: inductor"]
import math
import sys
import unittest
import torch
import torch._dynamo.config as dynamo_config
from torch import nn
from torch._dynamo.debug_utils import same_two_models
from torch._dynamo.testing import rand_strided
from torch._dynamo.utils import same
from torch._inductor import config
from torch._inductor.compile_fx import compile_fx_inner
from torch.fx.experimental.proxy_tensor import make_fx
from torch.testing._internal.common_utils import (
DeterministicGuard,
IS_FBCODE,
TEST_WITH_ASAN,
)
try:
try:
import triton
from triton import language as tl
except ImportError:
raise unittest.SkipTest("requires triton")
try:
from . import test_torchinductor
except ImportError:
import test_torchinductor
except unittest.SkipTest:
if __name__ == "__main__":
sys.exit(0)
raise
TestCase = test_torchinductor.TestCase
ToTuple = test_torchinductor.ToTuple
check_model_cuda = test_torchinductor.check_model_cuda
aten = torch.ops.aten
class CudaReproTests(TestCase):
common = check_model_cuda
def test_index_put_issue(self):
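        # Repro: index_put_ with accumulate=True on a traced graph; make sure
        # compile_fx_inner can compile and run it on CUDA.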
def forward(
self,
arg76_1,
expand_default,
full_like_default,
_to_copy_default_67,
zeros,
):
sum_sym_int_19 = torch.ops.aten.sum(_to_copy_default_67, [0], True)
view_default_57 = torch.ops.aten.view.default(sum_sym_int_19, [512, 768])
where_self = torch.ops.aten.where.self(
expand_default, view_default_57, full_like_default
)
clone_default_12 = torch.ops.aten.clone.default(zeros)
index_put__default = torch.ops.aten.index_put_.default(
clone_default_12, [arg76_1], where_self, True
)
return (index_put__default,)
inps = [
(torch.Size([512]), torch.int64),
(torch.Size([512, 768]), torch.bool),
(torch.Size([512, 768]), torch.float16),
(torch.Size([4, 512, 768]), torch.float16),
(torch.Size([512, 768]), torch.float16),
]
inps = [torch.zeros(())] + [
torch.ones(shape, dtype=dtype, device="cuda") for (shape, dtype) in inps
]
mod = make_fx(forward)(*inps)
compiled = compile_fx_inner(mod, inps)
compiled(inps)
def test_input_channels_last(self):
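        # A channels-last input through Conv2d should produce a channels-last
        # output from the compiled model.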
m = torch.nn.Sequential(
torch.nn.Conv2d(3, 3, 1, 1),
ToTuple(),
).cuda()
inp = torch.randn([2, 3, 16, 16]).to(memory_format=torch.channels_last).cuda()
self.common(
m,
(inp,),
check_lowp=False,
)
@torch._dynamo.optimize()
def foo(m, inp):
return m(inp)
self.assertTrue(foo(m, inp)[0].is_contiguous(memory_format=torch.channels_last))
# https://github.com/pytorch/torchdynamo/issues/1681#issuecomment-1283433527
def test_unspec_inputs_interop(self):
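        # Graph taking both a CUDA tensor and a CPU scalar tensor as inputs;
        # it should compile and run with compile_fx_inner.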
class Repro(torch.nn.Module):
def forward(self, x, y):
unsqueeze = torch.ops.aten.unsqueeze.default(x, 4)
permute = torch.ops.aten.permute.default(unsqueeze, [0, 1, 2, 4, 3])
add = torch.ops.aten.add.Tensor(y, 1)
return [permute, add]
inps = [
rand_strided((12, 3, 512, 64), (64, 196608, 768, 1), torch.float32, "cuda"),
rand_strided((), (), torch.int64, "cpu"),
]
mod = make_fx(Repro().to(device="cuda"))(*inps)
compiled = compile_fx_inner(mod, inps)
compiled(inps)
@unittest.skipIf(
IS_FBCODE, "RuntimeError: Triton Error [CUDA]: invalid device context"
)
def test_backward_context(self):
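        # Autograd backward through a torch.compile'd forward should run on CUDA.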
def fn(x):
return x * 3
x = torch.randn(4, device="cuda", requires_grad=True)
gO = torch.rand_like(x)
opt_fn = torch.compile(fn)
out = opt_fn(x)
out.backward(gO)
@config.patch(fallback_random=True)
def test_dtype_factory_issue(self):
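        # Graph whose only op is a CUDA randn factory call; the compiled output
        # must land on the CUDA device.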
def forward():
randn = torch.ops.aten.randn.default(
[12, 64, 1, 64],
dtype=torch.float32,
device=torch.device(type="cuda", index=0),
pin_memory=False,
)
unsqueeze_default_2 = torch.ops.aten.unsqueeze.default(randn, -1)
return (unsqueeze_default_2,)
mod = make_fx(forward)()
compiled = compile_fx_inner(mod, ())
assert compiled([])[0].device.type == "cuda"
@config.patch({"triton.cudagraphs": True})
@dynamo_config.patch(automatic_dynamic_shapes=True)
def test_no_device_idx_repro_cudagraphs(self):
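        # Module with no tensor inputs whose graph only contains full() calls
        # pinned to cuda:0, run under cudagraphs.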
class Repro(torch.nn.Module):
def __init__(self):
super().__init__()
def forward(self):
full = torch.ops.aten.full.default(
[8, 512],
1,
dtype=torch.float32,
layout=torch.strided,
device=torch.device(type="cuda", index=0),
pin_memory=False,
)
full_1 = torch.ops.aten.full.default(
[8, 512],
0,
dtype=torch.int64,
layout=torch.strided,
device=torch.device(type="cuda", index=0),
pin_memory=False,
)
return (full_1, full)
self.common(Repro(), ())
@config.patch({"triton.cudagraphs": True})
@dynamo_config.patch(automatic_dynamic_shapes=True)
def test_expanded_inputs_cudagraphs(self):
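        # Inputs with zero strides (expanded/broadcast views) under cudagraphs.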
@torch._dynamo.optimize("inductor")
def fn(x, y):
return x + y
inputs = (
rand_strided((5, 5, 5, 5), (0, 5, 0, 1), device="cuda"),
rand_strided((5, 5, 5, 5), (0, 5, 0, 1), device="cuda"),
)
self.assertTrue(same(fn(*inputs), inputs[0] + inputs[1]))
@config.patch({"triton.cudagraphs": True})
@dynamo_config.patch(
automatic_dynamic_shapes=True,
assume_static_by_default=False,
)
def test_dynamic_to_static_cudagraphs(self):
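        # Run the same compiled fn at two different sizes under cudagraphs (with
        # and without cudagraph trees); the returned size must track the input.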
for b in [False, True]:
with config.patch({"triton.cudagraph_trees": b}):
@torch._dynamo.optimize("inductor")
def fn(x, y):
r = x + y
return r, r.size(0)
inputs = (
torch.randn((5, 5), device="cuda"),
torch.randn((5, 5), device="cuda"),
)
self.assertTrue(same(fn(*inputs), (inputs[0] + inputs[1], 5)))
inputs = (
torch.randn((6, 6), device="cuda"),
torch.randn((6, 6), device="cuda"),
)
self.assertTrue(same(fn(*inputs), (inputs[0] + inputs[1], 6)))
# TODO: Abstract this out, test more extensively
@torch._dynamo.config.patch(assume_static_by_default=False)
def test_dynamic_shapes(self):
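        # x.cos().view(x.shape).sin() compiled with dynamic shapes; a second,
        # differently sized input should not trigger a recompile (frame_count == 1).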
torch._dynamo.reset() # Needed since everywhere else uses "inductor"
def f(x):
return x.cos().view(x.shape).sin()
cnts = torch._dynamo.testing.CompileCounterWithBackend("inductor")
f2 = torch._dynamo.optimize(cnts)(f)
f2(torch.randn(32))
inp = torch.randn(16)
real_out = f(inp)
compiled_out = f2(inp)
self.assertEqual(cnts.frame_count, 1)
self.assertEqual(real_out, compiled_out)
torch._dynamo.reset()
@config.patch({"triton.cudagraphs": True, "size_asserts": False})
@dynamo_config.patch(automatic_dynamic_shapes=True)
def test_expanded_inputs_cudagraphs_no_size_asserts(self):
@torch._dynamo.optimize("inductor")
def fn(x, y):
return x + y
inputs = (
rand_strided((5, 5, 5, 5), (0, 5, 0, 1), device="cuda"),
rand_strided((5, 5, 5, 5), (0, 5, 0, 1), device="cuda"),
)
self.assertTrue(same(fn(*inputs), inputs[0] + inputs[1]))
@config.patch({"triton.cudagraph_trees": False})
@config.patch({"triton.cudagraphs": True})
@dynamo_config.patch(automatic_dynamic_shapes=True)
def test_inplace_updates_cudagraphs(self):
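        # In-place parameter updates between iterations must be picked up by the
        # compiled model under cudagraphs; gradients should match the eager copy.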
class Repro(torch.nn.Module):
def __init__(self):
super().__init__()
self.weight1 = torch.nn.Parameter(
torch.randn(10, 20, requires_grad=True)
)
def forward(self, x):
x = torch.matmul(x, self.weight1)
return x
from copy import deepcopy
model = Repro().cuda()
model_ref = deepcopy(model)
model_opt = torch._dynamo.optimize("inductor")(model)
input = torch.randn(10, 10, device="cuda", requires_grad=True)
for i in range(2):
output_ref = model_ref(input)
output_res = model_opt(input)
output_ref.sum().backward()
output_res.sum().backward()
for p_ref, p_res in zip(model_ref.parameters(), model_opt.parameters()):
self.assertEqual(p_ref.grad, p_res.grad)
with torch.no_grad():
for param in model_ref.parameters():
param.add_(1.0)
for param in model_opt.parameters():
param.add_(1.0)
# https://github.com/pytorch/torchdynamo/issues/1850
def test_inductor_output_aliases_intermediate(self):
def foo(x):
out = x + x
return out.t()
foo_opt = torch._dynamo.optimize("inductor")(foo)
inpt = torch.randn(10, 10, device="cuda", requires_grad=True)
# TODO: this is broken, fix later
# out = foo_opt(inpt)
# out.add_(2)
out_ref = foo(inpt)
out_ref.add_(2)
# self.assertEqual(out_ref, out)
def test_accuracy_issue1(self):
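        # Accuracy repro: linear -> split -> squeeze -> clamp -> cross_entropy;
        # the compiled model must match eager (same_two_models).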
class Repro(torch.nn.Module):
def __init__(self):
super().__init__()
self.linear = torch.nn.Linear(
in_features=768, out_features=2, bias=True
)
def forward(self, start_positions: torch.Tensor, x: torch.Tensor):
linear = self.linear(x)
split = linear.split(1, dim=-1)
getitem = split[0]
squeeze = getitem.squeeze(-1)
clamp = start_positions.clamp(0, 128)
cross_entropy = torch.nn.functional.cross_entropy(
squeeze, clamp, None, None, 128, None, "mean", 0.0
)
return cross_entropy
mod = Repro().cuda()
opt_mod = torch._dynamo.optimize("inductor")(mod)
mod.eval()
opt_mod.eval()
args = [
((1,), (1,), torch.int64, "cuda", False),
((1, 128, 768), (98304, 768, 1), torch.float32, "cuda", True),
]
args = [
rand_strided(sh, st, dt, dev).requires_grad_(rg)
for (sh, st, dt, dev, rg) in args
]
with torch.cuda.amp.autocast(enabled=False):
assert same_two_models(mod, opt_mod, args), "Dynamo failed"
@config.patch(allow_buffer_reuse=False)
def test_issue103461(self):
def forward(add_1):
var_mean = torch.ops.aten.var_mean.correction(
add_1, [2], correction=0, keepdim=True
)
getitem_1 = var_mean[1]
return getitem_1
x = torch.randn(1, 8, 768, device="cuda")
correct = forward(x)
actual = torch.compile(forward, fullgraph=True)(x)
self.assertEqual(actual, correct)
def test_autotune_inplace_kernel(self):
"""
This UT tests autotune on an inplace kernel. The autotune should not contaminate
the input buffers when tuning with multiple configs. For more details, refer to
https://github.com/openai/triton/issues/781
https://github.com/pytorch/torchdynamo/issues/1670
"""
from torch._C import _cuda_getCurrentRawStream as get_cuda_stream
from torch._inductor.triton_heuristics import (
CachingAutotuner,
grid,
HeuristicType,
)
from torch._inductor.utils import instance_descriptor
def autotune(configs, meta):
def decorator(fn):
return CachingAutotuner(
# force autotune by setting save_cache_hook to False
fn,
meta=meta,
configs=configs,
save_cache_hook=False,
mutated_arg_names=["in_out_ptr0"],
heuristic_type=HeuristicType.POINTWISE,
)
return decorator
@autotune(
configs=[
triton.Config({"XBLOCK": 1}),
triton.Config({"XBLOCK": 2}),
],
meta={
"signature": {0: "*fp32", 1: "*fp32", 2: "i32"},
"device": 0,
"configs": [instance_descriptor(divisible_by_16=(0, 1), equal_to_1=())],
"constants": {},
},
)
@triton.jit
def kernel(in_out_ptr0, in_ptr0, xnumel, XBLOCK: tl.constexpr):
pid = tl.program_id(0)
block_start = pid * XBLOCK
offsets = block_start + tl.arange(0, XBLOCK)
mask = offsets < xnumel
x = tl.load(in_out_ptr0 + offsets, mask=mask)
y = tl.load(in_ptr0 + offsets, mask=mask)
output = x + y
tl.store(in_out_ptr0 + offsets, output, mask=mask)
xnumel = 384
in0 = rand_strided((xnumel,), (1,), device="cuda", dtype=torch.float32)
inout1 = rand_strided((xnumel,), (1,), device="cuda", dtype=torch.float32)
inout2 = inout1.clone()
stream0 = get_cuda_stream(0)
kernel.run(inout1, in0, xnumel, grid=grid(xnumel), stream=stream0)
kernel.run(inout2, in0, xnumel, grid=grid(xnumel), stream=stream0)
assert same(
inout1, inout2, tol=0.001, equal_nan=True
), "failed autotune with inplace kernel"
def test_sort_stride_issue(self):
        # Minified test case from detectron2_maskrcnn_r_50_fpn; the size_assert
        # code used to raise a false error here.
@torch._dynamo.optimize(nopython=True)
def forward(pred_objectness_logits_3_: torch.Tensor):
sort_3 = pred_objectness_logits_3_.sort(descending=True, dim=1)
getitem_12 = sort_3[0]
return getitem_12
args = [((1, 100), (0, 1), torch.float16, "cuda", False)]
args = [
rand_strided(sh, st, dt, dev).requires_grad_(rg)
for (sh, st, dt, dev, rg) in args
]
result = forward(*args)
assert same(result, torch.sort(args[0], descending=True, dim=1)[0])
def test_scalar_triton_index(self):
        # Indirect indexing via a scalar index like the one below used to produce
        # bad Triton code that made Triton segfault during compilation.
        # See https://github.com/pytorch/torchdynamo/issues/1515
def fn(a):
zero = torch.zeros((16,), device=a.device, dtype=torch.int64)
return (a[zero],)
a = torch.randn((8,), dtype=torch.float32, device="cuda")
fn_optimized = torch._dynamo.optimize("inductor")(fn)
assert same(fn(a), fn_optimized(a))
def test_indirect_indexing_dense_mask(self):
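        # gather with data-dependent (indirect) indices; exercises dense mask
        # handling for indirect indexing.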
def fn(x, y):
ne = torch.ops.aten.ne.Scalar(x, 1)
sum_1 = torch.ops.aten.sum.dim_IntList(ne, [1])
sub = torch.ops.aten.sub.Tensor(sum_1, 1)
unsqueeze = torch.ops.aten.unsqueeze.default(sub, -1)
gather = torch.ops.aten.gather.default(x, 1, unsqueeze)
squeeze = torch.ops.aten.squeeze.default(gather)
out = torch.ops.aten.multiply(y, squeeze)
return (out,)
a = torch.zeros((1, 128), dtype=torch.int64, device="cuda")
b = torch.zeros((1, 128), dtype=torch.int64, device="cuda")
fn_optimized = torch._dynamo.optimize("inductor")(fn)
assert same(fn(a, b), fn_optimized(a, b))
def test_simplify_dims(self):
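        # Non-contiguous sliced 5-D input; exercises Inductor's dim simplification.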
def fn(a):
return (a + 1,)
self.common(fn, (torch.randn(2, 3, 10, 5, 6, device="cuda")[:, :, 2::2, :, :],))
@config.patch(permute_fusion=True)
def test_permute_fusion(self):
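        # With permute_fusion enabled, a permute feeding bmm must still match eager.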
class Repro(torch.nn.Module):
def forward(self, view, reshape_2):
permute = view.permute(0, 2, 1)
view = None
reshape = torch.reshape(permute, (-1, 642))
bmm = torch.bmm(permute, reshape_2)
return (bmm,)
args = [
((1024, 642, 160), (102720, 160, 1), torch.float32, "cuda", True),
((1024, 642, 20), (12840, 20, 1), torch.float32, "cuda", True),
]
args = [
rand_strided(sh, st, dt, dev).requires_grad_(rg)
for (sh, st, dt, dev, rg) in args
]
mod = Repro()
opt_mod = torch._dynamo.optimize("inductor")(mod)
ref = mod(*args)
res = opt_mod(*args)
self.assertTrue(same(ref, res))
@config.patch({"triton.autotune_pointwise": True})
def test_inplace_add_alpha_autotune(self):
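        # In-place add with alpha and a channels-last rhs under pointwise
        # autotuning; the mutated input from the compiled kernel must match eager.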
def fn(x, y):
aten.add_.Tensor(x, y, alpha=0.55)
return (x,)
x1 = torch.zeros(2, 3, 4, 10, device="cuda")
x2 = torch.zeros(2, 3, 4, 10, device="cuda")
x3 = torch.zeros(2, 3, 4, 10, device="cuda")
y = torch.randn(2, 3, 4, 10, device="cuda").to(
memory_format=torch.channels_last
)
fn_fx = make_fx(fn)(x1, y)
fn_compiled = compile_fx_inner(fn_fx, [x1, y])
fn(x2, y)
fn_compiled([x3, y])
assert same(x2, x3)
@config.patch({"triton.autotune_pointwise": True})
def test_inplace_buffer_autotune(self):
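        # Matmul followed by a broadcasted pointwise add onto a channels-last
        # tensor under pointwise autotuning.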
def foo(x, y, z):
a = x @ y
return a.unsqueeze(0).unsqueeze(0) + z
x = torch.zeros(5, 5, device="cuda")
y = torch.zeros(5, 5, device="cuda")
z = torch.zeros(1, 1, 5, 5, device="cuda").to(memory_format=torch.channels_last)
self.common(
foo,
(x, y, z),
check_lowp=False,
)
def test_memory_history_inductor(self):
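        # Memory history recording should still attribute allocations to the
        # original Python frames (called_inside_compile) after compilation.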
def called_inside_compile(x, w, b):
a = x @ w + b
return torch.sigmoid(a)
@torch.compile
def fn(x, w, b):
x = called_inside_compile(x, w, b)
return called_inside_compile(x, w, b)
w = torch.rand(3, 3, device="cuda")
b = torch.rand(3, device="cuda")
x = torch.rand(3, device="cuda")
try:
torch.cuda.memory.empty_cache()
torch.cuda.memory._record_memory_history(True)
r = fn(x, w, b)
finally:
torch.cuda.memory._record_memory_history(False)
snapshot = str(torch.cuda.memory._snapshot())
self.assertTrue("called_inside_compile" in snapshot)
def test_negative_arange_dynamic_shapes(self):
        # Repro from ALiBi relative position encodings
def sign(x):
return (x > 0) - (x < 0)
class Repro(torch.nn.Module):
def __init__(self):
super().__init__()
nheads = 16
start = math.log2(0.5)
end = math.log2(1 / (2**8))
self.scales = nn.Buffer(
2
** torch.arange(
start,
end + 1e-6 * sign(end - start),
(end - start) / (nheads - 1),
).view(1, nheads, 1, 1),
)
self.emb = nn.Embedding(1024, 256)
self.dec_layer = nn.TransformerDecoderLayer(
256, 16, 512, batch_first=True, norm_first=True
)
self.head = nn.Linear(256, 1024)
def forward(self, enc_out: torch.Tensor, dec_in: torch.Tensor):
padmask = dec_in == 0
dec_mask = padmask.unsqueeze(-1) == padmask.unsqueeze(-2)
dec_mask = dec_mask.to(dtype=torch.float32)
dec_mask = dec_mask.tril(diagonal=0).cuda()
q_pos = torch.arange(dec_in.size(1), dtype=torch.long, device="cuda")
k_pos = torch.arange(dec_in.size(1), dtype=torch.long, device="cuda")
rel_pos = k_pos[None, :] - q_pos[:, None]
values = rel_pos.abs().neg().unsqueeze(0).unsqueeze(0)
dec_bias = values * self.scales
dec_bias.tril_(diagonal=0)
dec_mask = dec_mask + dec_bias[0]
out = self.emb(dec_in)
out = self.dec_layer(out, enc_out, tgt_mask=dec_mask)
return self.head(out)
mod = Repro().cuda()
opt_mod = torch._dynamo.optimize("inductor", dynamic=True)(mod)
mod.eval()
opt_mod.eval()
enc_out = torch.rand(1, 512, 256).cuda()
dec_inputs = [
torch.randint(0, 512, (1, i + 1), dtype=torch.long).cuda() for i in range(8)
]
for dec_inp in dec_inputs:
assert same_two_models(
mod, opt_mod, [enc_out, dec_inp], only_fwd=True
), "Inductor with dynamic shapes failed"
def test_issue97695_1input(self):
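        # cat of a single addmm output; compile via both compile_fx_inner and
        # torch.compile and compare against eager.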
def fn(arg3_1, relu, permute_1):
addmm_1 = torch.ops.aten.addmm.default(arg3_1, relu, permute_1)
cat_2 = torch.ops.aten.cat.default([addmm_1], 1)
return (cat_2,)
args = [
((96,), (1,), torch.float32, "cuda"),
((10, 256), (256, 1), torch.float32, "cuda"),
((256, 96), (1, 256), torch.float32, "cuda"),
]
args = [rand_strided(sh, st, dt, dev) for (sh, st, dt, dev) in args]
correct = fn(*args)
mod = make_fx(fn, tracing_mode="real")(*args)
compiled = compile_fx_inner(mod, args)
ref = compiled(list(args))
assert same(ref, correct)
ref = torch.compile(fn, fullgraph=True)(*args)
assert same(ref, correct)
def test_issue_103924(self):
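        # Division by a freshly created ones tensor followed by Softmax; the
        # compiled output must match eager.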
class MyModule(torch.nn.Module):
def __init__(self):
super().__init__()
self.temperature = 1
self.layer = torch.nn.Softmax(dim=1)
def forward(self, x):
n_samples, _ = x.shape
y = 1.0 * torch.ones(n_samples, dtype=x.dtype, device=x.device)
inp = x / y[..., None]
return self.layer(inp)
x = torch.rand([4, 4], device="cuda")
m = MyModule()
opt_m = torch.compile(backend="inductor")(m)
self.assertEqual(opt_m(x), m(x))
def test_issue97695_2input(self):
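        # Same as the 1-input variant above, but concatenating two addmm outputs.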
def fn(arg3_1, arg3_2, relu, permute_1):
addmm_1 = torch.ops.aten.addmm.default(arg3_1, relu, permute_1)
addmm_2 = torch.ops.aten.addmm.default(arg3_2, relu, permute_1)
cat_2 = torch.ops.aten.cat.default([addmm_1, addmm_2], 1)
return (cat_2,)
args = [
((96,), (1,), torch.float32, "cuda"),
((96,), (1,), torch.float32, "cuda"),
((10, 256), (256, 1), torch.float32, "cuda"),
((256, 96), (1, 256), torch.float32, "cuda"),
]
args = [rand_strided(sh, st, dt, dev) for (sh, st, dt, dev) in args]
correct = fn(*args)
ref = torch.compile(fn, fullgraph=True)(*args)
assert same(ref, correct)
def test_embedding_var_mean(self):
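        # cumsum-derived indices fed to embedding, then var_mean; just check that
        # the graph compiles and runs.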
def forward(arg0_1):
full = torch.ops.aten.full.default(
[1, 2048],
1,
dtype=torch.float32,
layout=torch.strided,
device=torch.device(type="cuda", index=0),
pin_memory=False,
)
convert_element_type_1 = torch.ops.prims.convert_element_type.default(
full, torch.int64
)
cumsum = torch.ops.aten.cumsum.default(convert_element_type_1, 1)
mul = torch.ops.aten.mul.Tensor(cumsum, convert_element_type_1)
sub_1 = torch.ops.aten.sub.Tensor(mul, 1)
slice_5 = torch.ops.aten.slice.Tensor(sub_1, 0, 0, 9223372036854775807)
slice_6 = torch.ops.aten.slice.Tensor(slice_5, 1, 0, 9223372036854775807)
add_2 = torch.ops.aten.add.Tensor(slice_6, 2)
embedding_1 = torch.ops.aten.embedding.default(arg0_1, add_2)
var_mean = torch.ops.aten.var_mean.correction(
embedding_1, [2], correction=0, keepdim=True
)
return [var_mean[0], var_mean[1], add_2]
emb = torch.randn([2050, 768], device="cuda")
gm = make_fx(forward)(emb)
opt = torch._inductor.compile_fx.compile_fx_inner(gm, [emb])
opt([emb])
torch.cuda.synchronize()
def test_deterministic_algorithms(self):
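        # With deterministic algorithms enabled, repeated index accumulation must
        # be bitwise reproducible across runs.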
N = 10000
@torch.compile
def fn(idx, values):
x = torch.zeros(1, device="cuda")
x[idx] += values
return x
idx = torch.zeros(N, dtype=torch.int64, device="cuda")
values = torch.randn(N, device="cuda")
r0 = fn(idx, values)
with DeterministicGuard(True):
r1 = fn(idx, values)
for _ in range(10):
rn = fn(idx, values)
self.assertEqual(r1, rn, atol=0, rtol=0)
# https://github.com/pytorch/pytorch/issues/96406
def test_linear_cpu_input(self):
class Model(nn.Module):
def __init__(self):
super().__init__()
self.linear = nn.Linear(4, 4)
def forward(self, data):
data = data.to("cuda")
return self.linear(data)
mod = Model().cuda().eval()
with torch.no_grad():
self.common(mod, (torch.randn(4, 4),))
@config.patch({"fallback_random": True, "triton.cudagraphs": True})
def test_xlnet_lm_stride_repro(self):
class Repro(nn.Module):
def __init__(self):
super().__init__()
self.dropout = nn.Dropout(p=0.1, inplace=False)
def forward(self, x):
y = torch._C._nn.gelu(x)
return self.dropout(y)
mod = Repro()
x = torch.randn((512, 1, 4096), requires_grad=True, device="cuda")
y = torch.compile(mod)(x)
# Inductor claims the output layout of gelu's saved variable for
# backwards will be (4096, 4096, 1) but in actuality it is (4096,
# 2097152, 1). Fortunately this doesn't actually matter in practice.
y.sum().backward()
def test_lookup_seed_backward(self):
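        # Graph using the inductor_lookup_seed / inductor_random prims (a
        # dropout-style mask); it should compile with fullgraph=True and run.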
@torch.compile(fullgraph=True)
def forward(inductor_seeds, mul_4, view_15):
inductor_lookup_seed_2 = torch.ops.prims.inductor_lookup_seed.default(
inductor_seeds, 2
)
inductor_random_2 = torch.ops.prims.inductor_random.default(
[2, 512, 768], inductor_lookup_seed_2, "rand"
)
gt_2 = torch.ops.aten.gt.Scalar(inductor_random_2, 0.1)
mul_7 = torch.ops.aten.mul.Tensor(gt_2, view_15)
mul_8 = torch.ops.aten.mul.Tensor(mul_7, 1.1111111111111112)
add_5 = torch.ops.aten.add.Tensor(mul_8, mul_4)
var_mean_1 = torch.ops.aten.var_mean.correction(
add_5, [2], correction=0, keepdim=True
)
getitem_3 = var_mean_1[1]
sub_3 = torch.ops.aten.sub.Tensor(add_5, getitem_3)
return (sub_3,)
buf0 = torch.zeros((37,), dtype=torch.int64, device="cuda")
buf1 = torch.zeros((2, 512, 768), device="cuda")
buf2 = torch.zeros((2, 512, 768), device="cuda")
forward(buf0, buf1, buf2)
def test_issue100806(self):
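        # linear -> cat -> view -> slice -> relu under no_grad; compiled output
        # must match eager.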
class Model(torch.nn.Module):
def __init__(self):
super().__init__()
self.linear1 = torch.nn.Linear(10, 20)
self.linear2 = torch.nn.Linear(20, 30)
self.relu = torch.nn.ReLU()
def forward(self, x):
x = self.linear1(x)
x = self.linear2(x)
x = torch.cat((x, x), dim=1)
x = x.view(-1, 2, 30)
x = x[:, 1, :]
x = self.relu(x)
return x
device = "cuda"
batch_size = 2
x = torch.randn(batch_size, 10).to(device)
func = Model().to(device)
with torch.no_grad():
func.train(False)
jit_func = torch.compile(func)
res1 = func(x)
res2 = jit_func(x)
self.assertEqual(res1, res2)
def test_issue103481(self):
def fn(x, y):
            # NOTE: 6 dimensions is important! The failure does not reproduce
            # with 5 dimensions.
mean = torch.mean(x, [2, 3, 4, 5], keepdim=True)
add = mean + y
return add
x = torch.rand(4, 4, 4, 4, 4, 4, device="cuda")
y = torch.rand((), device="cuda")
expect = fn(x, y)
opt_fn = torch.compile(fn)
actual = opt_fn(x, y)
self.assertEqual(expect, actual)
@config.patch({"triton.dense_indexing": True})
@dynamo_config.patch(automatic_dynamic_shapes=True)
def test_bucketize_dynamic_dense(self):
"""
Make sure that ops.bucketize() can handle dense_indexing, which previously
caused issues due to incorrect handling of the size of offsets.
"""
def fn(values, offsets):
return torch.ops.prims._inductor_bucketize(values, offsets)
values = torch.rand((64, 64), device="cuda")
offsets = torch.tensor([0.05, 0.1, 0.5, 0.8, 0.85, 0.95], device="cuda")
expect = fn(values, offsets)
opt_fn = torch.compile(fn, dynamic=True)
actual = opt_fn(values, offsets)
self.assertEqual(expect, actual)
def test_float64_constants(self):
def fn():
# NOTE: tensors of all the same value are constant folded, so we
# need a tensor with two distinct values
a = torch.tensor([1 / 10, 2 / 10], dtype=torch.float64, device="cuda")
return a * 2e50
cfn = torch.compile(fn)
expect = fn()
actual = cfn()
self.assertEqual(expect, actual, atol=0, rtol=0)
def test_issue104759(self):
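        # slice_scatter / permute / view / bmm graph repro; compile via
        # compile_fx_inner and compare against eager.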
def fn(arg7_1, add_1, permute_2, select_scatter, slice_8):
slice_scatter_4 = torch.ops.aten.slice_scatter.default(
permute_2, select_scatter, 0, 1, 9223372036854775807
)
permute_3 = torch.ops.aten.permute.default(slice_scatter_4, [1, 3, 0, 2, 4])
view_6 = torch.ops.aten.view.default(permute_3, [1, 1000, 48])
view_7 = torch.ops.aten.view.default(view_6, [1000, 48])
view_8 = torch.ops.aten.view.default(view_7, [1, 1000, 48])
view_9 = torch.ops.aten.view.default(view_8, [1, 1000, 3, 4, 4])
permute_4 = torch.ops.aten.permute.default(view_9, [2, 0, 3, 1, 4])
slice_7 = torch.ops.aten.slice.Tensor(permute_4, 0, 1, 9223372036854775807)
slice_scatter_5 = torch.ops.aten.slice_scatter.default(
slice_8, slice_7, 4, 0, 9223372036854775807
)
slice_scatter_6 = torch.ops.aten.slice_scatter.default(
arg7_1, slice_scatter_5, 3, 0, 1000
)
mul_8 = torch.ops.aten.mul.Scalar(add_1, 0.7071067811865476)
slice_9 = torch.ops.aten.slice.Tensor(slice_scatter_6, 3, 0, 1000)
slice_10 = torch.ops.aten.slice.Tensor(slice_9, 4, 0, 9223372036854775807)
select_2 = torch.ops.aten.select.int(slice_10, 0, 0)
permute_5 = torch.ops.aten.permute.default(select_2, [0, 1, 3, 2])
mul_9 = torch.ops.aten.mul.Scalar(permute_5, 0.7071067811865476)
expand = torch.ops.aten.expand.default(mul_8, [1, 4, 1000, 4])
view_10 = torch.ops.aten.view.default(expand, [4, 1000, 4])
expand_1 = torch.ops.aten.expand.default(mul_9, [1, 4, 4, 1000])
view_11 = torch.ops.aten.view.default(expand_1, [4, 4, 1000])
bmm = torch.ops.aten.bmm.default(view_10, view_11)
return (bmm,)
args = []
args.append(torch.randn((2, 1, 4, 1200, 4), dtype=torch.float16, device="cuda"))
args.append(
rand_strided(
(1, 4, 1000, 4), (16000, 4, 16, 1), dtype=torch.float16, device="cuda"
)
)
args.append(
rand_strided(
(3, 1, 4, 1000, 4),
(16, 48000, 4, 48, 1),
dtype=torch.float16,
device="cuda",
)
)
args.append(
rand_strided(
(2, 1, 4, 1000, 4),
(16, 48000, 4, 48, 1),
dtype=torch.float16,
device="cuda",
)
)
args.append(
rand_strided(
(2, 1, 4, 1000, 4),
(19200, 19200, 4800, 4, 1),
dtype=torch.float16,
device="cuda",
)
)
correct = fn(*args)
mod = make_fx(fn, tracing_mode="real")(*args)
compiled = compile_fx_inner(mod, args)
ref = compiled(list(args))
assert same(ref, correct)
@config.patch({"triton.cudagraphs": True})
def test_index_put_inplace_cudagraph(self):
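        # In-place index_put_ with accumulate=True on bool tensors under cudagraphs.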
def fn(x, y, z):
x = torch.zeros_like(x)
return x.index_put_([y], z, True)
x = torch.zeros((512, 512), device="cuda", dtype=torch.bool)
y = torch.zeros((512,), device="cuda", dtype=torch.int64)
z = torch.ones((512, 512), device="cuda", dtype=torch.bool)
opt_fn = torch._dynamo.optimize("inductor")(fn)
ref = fn(x, y, z)
        # run it twice to check for CUDA graph issues
res = opt_fn(x, y, z)
res = opt_fn(x, y, z)
self.assertEqual(ref, res)
@config.patch({"triton.cudagraphs": True})
def test_index_put_cudagraph(self):
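        # Out-of-place index_put with accumulate=True on bool tensors under cudagraphs.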
def fn(x, y, z):
x = torch.zeros_like(x)
return x.index_put([y], z, True)
x = torch.zeros((512, 512), device="cuda", dtype=torch.bool)
y = torch.zeros((512,), device="cuda", dtype=torch.int64)
z = torch.ones((512, 512), device="cuda", dtype=torch.bool)
opt_fn = torch._dynamo.optimize("inductor")(fn)
ref = fn(x, y, z)
        # run it twice to check for CUDA graph issues
res = opt_fn(x, y, z)
res = opt_fn(x, y, z)
self.assertEqual(ref, res)
@config.patch({"triton.cudagraphs": True})
def test_index_put_no_fallback_cudagraph(self):
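        # Same as above but with int32 tensors; per the test name, this path
        # should not require the index_put fallback.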
def fn(x, y, z):
x = torch.zeros_like(x)
return x.index_put([y], z, True)
x = torch.zeros((512, 512), device="cuda", dtype=torch.int32)
y = torch.zeros((512,), device="cuda", dtype=torch.int64)
z = torch.ones((512, 512), device="cuda", dtype=torch.int32)
opt_fn = torch._dynamo.optimize("inductor")(fn)
ref = fn(x, y, z)
        # run it twice to check for CUDA graph issues
res = opt_fn(x, y, z)
res = opt_fn(x, y, z)
self.assertEqual(ref, res)
# https://github.com/pytorch/pytorch/issues/104937
def test_linear_with_zero_infeature_size(self):
m = nn.Linear(in_features=0, out_features=0, bias=True).to("cuda")
x = torch.rand(1, 1, 0, device="cuda")
expect = m(x)
opt_fn = torch.compile(m)
actual = opt_fn(x)
self.assertEqual(expect, actual)
if __name__ == "__main__":
from torch._dynamo.test_case import run_tests
from torch.testing._internal.inductor_utils import HAS_CUDA
if HAS_CUDA and not TEST_WITH_ASAN:
run_tests(needs="filelock")