blob: 9909f2bd7b5da2774e2160c907de5197bbd411d0 [file] [log] [blame] [edit]
# Owner(s): ["module: cuda"]
import torch
from torch.cuda.jiterator import _create_jit_fn as create_jit_fn
from torch.cuda.jiterator import _create_multi_output_jit_fn as create_multi_output_jit_fn
import sys
from itertools import product
from torch.testing._internal.common_utils import TestCase, parametrize, run_tests, TEST_CUDA, NoTest
from torch.testing._internal.common_dtype import all_types_and_complex_and
from torch.testing._internal.common_device_type import (
skipCUDAIfVersionLessThan, instantiate_device_type_tests, dtypes, toleranceOverride, tol)
if not TEST_CUDA:
    # No CUDA device: report it on stderr and rebind TestCase to NoTest so
    # the test class defined below derives from NoTest instead.
    print('CUDA not available, skipping tests', file=sys.stderr)
    TestCase = NoTest  # noqa: F811

# C++ source of the element-wise kernel shared by the dtype and extra-args
# tests: it returns ``alpha * x + beta * y``.
code_string = (
    "template <typename T> T my_fused_kernel(T x, T y, T alpha, T beta) "
    "{ return alpha * x + beta * y; }"
)

# Callable built from the source above; ``alpha`` and ``beta`` are given the
# value 1 here and individual calls may pass their own values as keywords.
jitted_fn = create_jit_fn(code_string, alpha=1, beta=1)
def ref_fn(x, y, alpha=1, beta=1):
    """Eager reference implementation: return ``alpha * x + beta * y``.

    Args:
        x: first operand.
        y: second operand.
        alpha: multiplier applied to ``x`` (defaults to 1).
        beta: multiplier applied to ``y`` (defaults to 1).

    Returns:
        The sum of the two scaled operands.
    """
    scaled_x = alpha * x
    scaled_y = beta * y
    return scaled_x + scaled_y
class TestPythonJiterator(TestCase):
    """Tests for the Python jiterator entry points.

    Each test builds a callable from a C++ source string with
    ``create_jit_fn`` (or ``create_multi_output_jit_fn``), runs it on tensors
    and compares the result against an eager PyTorch reference.
    """

    def _check_fused_kernel(self, device, dtype_pair, shape_strides):
        """Compare the module-level ``jitted_fn`` with ``ref_fn`` on strided inputs.

        Args:
            device: device on which the two 9-element buffers are allocated.
            dtype_pair: ``(dtype_a, dtype_b)`` the buffers are converted to.
            shape_strides: one argument tuple per input, each unpacked into
                ``Tensor.as_strided`` to build the view under test.
        """
        a_buffer = torch.rand(9, device=device).mul(10).type(dtype_pair[0])
        b_buffer = torch.rand(9, device=device).mul(10).type(dtype_pair[1])

        a = a_buffer.as_strided(*shape_strides[0])
        b = b_buffer.as_strided(*shape_strides[1])

        expected = ref_fn(a, b)
        result = jitted_fn(a, b)

        self.assertEqual(expected, result)

    @parametrize("shape_strides", [
        (([3, 3], [3, 1]), ([3, 3], [3, 1])),  # contiguous
    ])
    @dtypes(*product(all_types_and_complex_and(torch.half, torch.bfloat16),
                     all_types_and_complex_and(torch.half, torch.bfloat16)))
    def test_all_dtype_contiguous(self, device, dtypes, shape_strides):
        """Every pair of input dtypes on contiguous 3x3 views."""
        self._check_fused_kernel(device, dtypes, shape_strides)

    # See https://github.com/pytorch/pytorch/pull/76394#issuecomment-1118018287 for details
    # On cuda 11.3, nvrtcCompileProgram is taking too long to
    # compile jiterator generated kernels for non-contiguous input that requires dynamic-casting.
    @skipCUDAIfVersionLessThan((11, 6))
    @parametrize("shape_strides", [
        (([3, 3], [1, 3]), ([3, 1], [1, 3])),  # non-contiguous
    ])
    @dtypes(*product(all_types_and_complex_and(torch.half, torch.bfloat16),
                     all_types_and_complex_and(torch.half, torch.bfloat16)))
    def test_all_dtype_noncontiguous(self, device, dtypes, shape_strides):
        """Every pair of input dtypes on non-contiguous views."""
        self._check_fused_kernel(device, dtypes, shape_strides)

    @dtypes(torch.float, torch.double, torch.float16, torch.bfloat16)
    @parametrize("alpha", [-1, 2.0, None])
    @parametrize("beta", [3, -4.2, None])
    @toleranceOverride({torch.float16 : tol(atol=1e-2, rtol=1e-3)})
    def test_extra_args(self, device, dtype, alpha, beta):
        """Pass ``alpha``/``beta`` as keywords; ``None`` means "omit the keyword"."""
        a = torch.rand(3, device=device).mul(10).type(dtype)
        b = torch.rand(3, device=device).mul(10).type(dtype)

        # Only forward the scalars that were actually parametrized, so the
        # omitted ones fall back to the defaults on both sides.
        extra_args = {
            name: value
            for name, value in (("alpha", alpha), ("beta", beta))
            if value is not None
        }

        expected = ref_fn(a, b, **extra_args)
        result = jitted_fn(a, b, **extra_args)

        self.assertEqual(expected, result)

    @parametrize("is_train", [True, False])
    def test_bool_extra_args(self, device, is_train):
        """A ``bool`` extra argument selects between ``x * mask`` and ``x``."""
        kernel_src = "template <typename T> T conditional(T x, T mask, bool is_train) { return is_train ? x * mask : x; }"
        conditional_fn = create_jit_fn(kernel_src, is_train=False)

        def reference(x, mask, is_train):
            return x * mask if is_train else x

        a = torch.rand(3, device=device)
        b = torch.rand(3, device=device)

        expected = reference(a, b, is_train=is_train)
        result = conditional_fn(a, b, is_train=is_train)
        self.assertEqual(expected, result)

    def test_multiple_functors(self, device):
        """A source string holding a helper function plus the main function."""
        kernel_src = (
            "\n"
            "template <typename T> T fn(T x, T mask) { return x * mask; }\n"
            "template <typename T> T main_fn(T x, T mask, T y) { return fn(x, mask) + y; }\n"
        )
        multi_fn = create_jit_fn(kernel_src)

        def reference(x, mask, y):
            return x * mask + y

        a = torch.rand(3, device=device)
        b = torch.rand(3, device=device)
        c = torch.rand(3, device=device)

        expected = reference(a, b, c)
        result = multi_fn(a, b, c)
        self.assertEqual(expected, result)

    @parametrize("num_inputs", [1, 5, 8])
    def test_various_num_inputs(self, num_inputs):
        """Kernels that sum ``num_inputs`` tensors element-wise."""
        inputs = [torch.rand(3, device='cuda').mul(10) for _ in range(num_inputs)]

        input_string = ",".join([f"T i{i}" for i in range(num_inputs)])
        function_body = "+".join([f"i{i}" for i in range(num_inputs)])
        kernel_src = f"template <typename T> T my_kernel({input_string}) {{ return {function_body}; }}"
        sum_fn = create_jit_fn(kernel_src)

        def reference(*tensors):
            return torch.sum(torch.stack(tensors), dim=0)

        expected = reference(*inputs)
        result = sum_fn(*inputs)

        self.assertEqual(expected, result)

    @parametrize("num_outputs", [1, 4, 8])
    def test_various_num_outputs(self, num_outputs):
        """Kernels that write ``input + i`` into each of ``num_outputs`` outputs."""
        x = torch.rand(3, device='cuda')

        output_string = ", ".join([f"T& out{i}" for i in range(num_outputs)])
        function_body = "".join(f"out{i} = input + {i};\n" for i in range(num_outputs))
        # NB: return type must be void, otherwise ROCm silently fails
        kernel_src = f"template <typename T> void my_kernel(T input, {output_string}) {{ {function_body} }}"

        multi_out_fn = create_multi_output_jit_fn(kernel_src, num_outputs)

        def reference(t):
            outputs = tuple(t + i for i in range(num_outputs))
            # A single output is represented as a bare tensor, not a 1-tuple.
            return outputs[0] if num_outputs == 1 else outputs

        expected = reference(x)
        result = multi_out_fn(x)

        if num_outputs == 1:
            # Compare the whole tensors: indexing a bare tensor with [0] would
            # only check its first element and leave the rest unverified.
            self.assertEqual(expected, result)
        else:
            self.assertEqual(num_outputs, len(result))
            for i in range(num_outputs):
                self.assertEqual(expected[i], result[i])

    @parametrize("code_string", [
        "template <typename T> T my _kernel(T x) { return x; }",
        "template <typename T> Tmy_kernel(T x) { return x; }",
    ])
    def test_invalid_function_name(self, code_string):
        """Source strings with a malformed function signature must be rejected."""
        with self.assertRaises(Exception):
            create_jit_fn(code_string)
# Register the device-specific variants of the generic test class in this
# module's namespace, restricted to the "cuda" device type.
instantiate_device_type_tests(TestPythonJiterator, globals(), only_for="cuda")

if __name__ == "__main__":
    # Run the test suite when the file is executed as a script.
    run_tests()