blob: 0db9533e9aa5ac5cf048824682c18b0b4b3d1255 [file] [log] [blame]
# Owner(s): ["module: primTorch"]
from functools import partial
from itertools import product
import unittest
import torch
from torch.testing import make_tensor
from torch.testing._internal.common_utils import parametrize, run_tests, TestCase, TEST_SCIPY
from torch.testing._internal.common_device_type import (
instantiate_device_type_tests,
onlyCUDA,
skipCUDAIfRocm,
dtypes,
)
from torch.testing._internal.logging_tensor import LoggingTensor, capture_logs, log_input
import torch._prims as prims
from torch._prims.executor import make_traced
import torch._refs as refs
if TEST_SCIPY:
import scipy.special
class TestPrims(TestCase):
    """Tests for torch._prims / torch._refs.

    Several tests trace a small function with ``make_traced`` and run it under
    both the 'aten' and 'nvfuser' executors, comparing the results.
    """

    @onlyCUDA
    @skipCUDAIfRocm
    @dtypes(torch.float32)
    def test_broadcast_in_dim(self, device, dtype):
        """Check prims.broadcast_in_dim for several broadcast patterns under both executors."""
        # nvfuser is not currently capable of realizing a broadcasted tensor
        # when the broadcast is the only operation. Another op is needed, so
        # the broadcast result is added to ``b`` (built with low=high=0.0).
        def _wrapper(a, b, broadcast_dimensions):
            a_bc = prims.broadcast_in_dim(a, b.shape, broadcast_dimensions)
            return prims.add(a_bc, b)

        traced = make_traced(_wrapper)
        make_arg = partial(make_tensor, device=device, dtype=dtype)

        for executor in ('aten', 'nvfuser'):
            fn = partial(traced, executor=executor)

            # Same shape
            shape = (5, 5)
            a = make_arg(shape)
            b = make_arg(shape, low=0.0, high=0.0)
            result = fn(a, b, (0, 1))

            self.assertEqual(result.shape, a.shape)
            # is_contiguous is a method: without the call the bound method
            # object is always truthy and the assertion checks nothing.
            self.assertTrue(result.is_contiguous())
            self.assertEqual(a, result)

            # Error input: reordering dims
            with self.assertRaises(Exception):
                fn(a, b, (1, 0))

            # Adding outermost dimensions
            a = make_arg((5, 5))
            b = make_arg((3, 3, 5, 5), low=0.0, high=0.0)
            result = fn(a, b, (2, 3))

            self.assertEqual(result.shape, b.shape)
            self.assertEqual(a.broadcast_to(b.shape), result)

            # Expands
            a = make_arg((1, 5, 1))
            b = make_arg((3, 5, 7), low=0.0, high=0.0)
            result = fn(a, b, (0, 1, 2))

            self.assertEqual(result.shape, b.shape)
            self.assertEqual(a.expand_as(result), result)

            # Unsqueezes
            a = make_arg((1, 2, 3))
            b = make_arg((1, 2, 1, 3), low=0.0, high=0.0)
            result = fn(a, b, (0, 1, 3))

            self.assertEqual(result.shape, b.shape)
            self.assertEqual(a.unsqueeze(2), result)

            # FIXME: This test exposes an issue in nvfuser
            # Adds outermost, expands, and unsqueezes
            #
            # a = make_arg((1, 2, 3))
            # b = make_arg((4, 1, 7, 2, 3, 3), low=0.0, high=0.0)
            # result = fn(a, b, (1, 3, 4))
            #
            # self.assertEqual(result.shape, b.shape)
            # a.unsqueeze_(3)
            # a.unsqueeze_(1)
            # a.unsqueeze_(0)
            # self.assertEqual(a.expand_as(result), result)

    @onlyCUDA
    @skipCUDAIfRocm
    @dtypes(torch.float32)
    def test_broadcast_in_dim_sum(self, device, dtype):
        """Broadcast the result of prims.sum over dims [0, 1] to shape [] under both executors."""
        def _wrapper(a):
            a_sum = prims.sum(a, [0, 1])
            a_bc = prims.broadcast_in_dim(a_sum, [], [])
            return a_bc

        traced = make_traced(_wrapper)
        make_arg = partial(make_tensor, device=device, dtype=dtype)

        for executor in ('aten', 'nvfuser'):
            fn = partial(traced, executor=executor)
            shape = (5, 5)
            a = make_arg(shape)
            result = fn(a)

            self.assertEqual(result.shape, ())
            self.assertTrue(result.is_contiguous())
            # The untraced call serves as the reference result.
            self.assertEqual(_wrapper(a), result)

    @unittest.skipIf(not TEST_SCIPY, "SciPy not found")
    @dtypes(torch.float64, torch.long)
    def test_cbrt_prim(self, device, dtype):
        """Compare prims.cbrt with scipy.special.cbrt over batch/shape combinations."""
        make_arg = partial(make_tensor, device=device, dtype=dtype)
        batches = [(), (1,), (2,), (0, 1), (1, 1), (2, 2)]
        shapes = [(), (0,), (1,), (5,)]

        # Capture the current default before ``try`` so ``finally`` can
        # always restore it (it is never an unbound name there).
        cur_default = torch.get_default_dtype()
        try:
            # Sets the default dtype to NumPy's default dtype of double
            torch.set_default_dtype(torch.double)

            # Tested here, as this OP is not currently exposed or tested in ATen
            for b, s in product(batches, shapes):
                x = make_arg(b + s)
                y = prims.cbrt(x)

                x_np = x.cpu().numpy()
                y_np = scipy.special.cbrt(x_np)

                self.assertEqual(y, y_np, exact_device=False)
        finally:
            torch.set_default_dtype(cur_default)

    @onlyCUDA
    @skipCUDAIfRocm
    def test_nvfuser_impl_is_used(self, device):
        """Every prim that has a same-named nvfuser op must define ``impl_nvfuser``."""
        # This test is to ensure that when the nvfuser implementation exists it is used
        # Assuming one-to-one mapping between prims and nvfuser implementations
        # This test is not intended to test the correctness of the nvfuser implementation
        from torch._C._nvfuser import FusionDefinition as fd

        prim_nvfuser_ops = set(torch._prims.__all__).intersection(dir(fd.Ops))
        ops_without_nvfuser_impl = {
            name
            for name in prim_nvfuser_ops
            if getattr(torch.ops.prims, name).default.impl_nvfuser is None
        }
        # Use a TestCase assertion (not a bare ``assert``, which -O strips) and
        # build the message by implicit string concatenation; the original
        # stray comma turned it into a tuple.
        self.assertFalse(
            ops_without_nvfuser_impl,
            f"The following prims do not have 'impl_nvfuser' defined: {ops_without_nvfuser_impl} "
            "while there exists nvfuser implementations for them.",
        )

    @onlyCUDA
    @skipCUDAIfRocm
    def test_nvfuser_executor_cached_noncontiguous(self, device):
        """A cached nvfuser execution must give correct results for noncontiguous inputs."""
        # This test is to ensure that nvfuser computes correct results for noncontiguous tensors
        from torch.fx.experimental.proxy_tensor import make_fx
        from torch._prims.context import TorchRefsMode
        from torch._prims.executor import execute

        a = torch.randn(3, 3, device=device)

        def func(a):
            return torch.sigmoid(a)

        with TorchRefsMode.push():
            gm = make_fx(func)(a)

        # First run to create the cache
        execute(gm, a, executor="nvfuser")

        # a.mT is noncontiguous, but it shouldn't affect correctness
        expected = execute(gm, a.mT, executor="aten")
        actual = execute(gm, a.mT, executor="nvfuser")
        self.assertEqual(expected, actual)

    @onlyCUDA
    @skipCUDAIfRocm
    @dtypes(torch.float32)
    @parametrize("correction", [0, 1])
    def test_var(self, device, dtype, correction):
        """Run prims.var over dims [0, 1] with the given ``correction`` under both executors."""
        def _wrapper(a):
            return prims.var(a, [0, 1], correction=correction)

        traced = make_traced(_wrapper)
        make_arg = partial(make_tensor, device=device, dtype=dtype)

        for executor in ('aten', 'nvfuser'):
            fn = partial(traced, executor=executor)
            shape = (5, 5)
            a = make_arg(shape)
            result = fn(a)

            self.assertEqual(result.shape, ())
            self.assertTrue(result.is_contiguous())
            self.assertEqual(_wrapper(a), result)

    @onlyCUDA
    @skipCUDAIfRocm
    @dtypes(torch.float32)
    def test_pytree_input_output(self, device, dtype):
        """A traced function taking a dict and returning (dict, tensor) agrees across executors."""
        @make_traced
        def fn(a, b_dict):
            b = b_dict["b"]
            d = {}
            d["c"] = torch.add(a, b)
            return (d, torch.add(a, d["c"]))

        make_arg = partial(make_tensor, device=device, dtype=dtype)
        a = make_arg((5, 5))
        b = make_arg((1, 5))
        b_dict = {"b": b}

        result_aten = fn(a, b_dict, executor="aten")
        result_nvfuser = fn(a, b_dict, executor="nvfuser")
        self.assertEqual(result_aten, result_nvfuser)

    @dtypes(torch.float32)
    def test_memory_format_strides(self, device, dtype):
        """refs.empty / refs.clone / refs.contiguous must match eager strides per memory format."""
        shapes = (
            (),
            (0,),
            (1,),
            (5,),  # a 1-tuple; a bare ``(5)`` would be the int 5
            (1, 0),
            (1, 1),
            (3, 7),
            (3, 0, 2),
            (1, 1, 2),
            (4, 1, 1),
            (7, 8, 9),
        )

        channels_last_shapes = (
            (0, 0, 0, 0),
            (1, 0, 3, 0),
            (0, 2, 3, 5),
            (2, 2, 2, 0),
            (5, 4, 3, 2),
            (8, 8, 7, 2),
            (9, 1, 3, 1),
            (4, 5, 8, 7),
        )

        channels_last_3d_shapes = (
            (0, 8, 7, 9, 2),
            (5, 0, 7, 9, 2),
            (5, 0, 7, 9, 0),
            (5, 8, 7, 9, 2),
            (5, 1, 7, 9, 2),
            (5, 1, 7, 9, 1),
        )

        pairs = (
            (shapes, torch.contiguous_format),
            (channels_last_shapes, torch.contiguous_format),
            (channels_last_3d_shapes, torch.contiguous_format),
            (channels_last_shapes, torch.channels_last),
            (channels_last_3d_shapes, torch.channels_last_3d),
        )

        # ``shape_group`` avoids rebinding the outer ``shapes`` tuple.
        for shape_group, memory_format in pairs:
            for shape in shape_group:
                # tests empty
                expected = torch.empty(shape, device=device, dtype=dtype, memory_format=memory_format)
                actual = refs.empty(shape, device=device, dtype=dtype, memory_format=memory_format)
                self.assertEqual(expected.stride(), actual.stride())

                # tests clone: eager torch.clone vs. refs.clone (comparing
                # torch.clone with itself could never fail)
                a = torch.testing.make_tensor(shape, device=device, dtype=dtype)
                expected = torch.clone(a, memory_format=memory_format)
                actual = refs.clone(a, memory_format=memory_format)
                self.assertEqual(expected.stride(), actual.stride())

                # tests contiguous
                a = torch.testing.make_tensor(shape, device=device, dtype=dtype, noncontiguous=True)
                expected = a.contiguous(memory_format=memory_format)
                actual = refs.contiguous(a, memory_format=memory_format)
                self.assertEqual(expected.stride(), actual.stride())
class TestPrimsBasic(TestCase):
    """Device-independent smoke tests for calling prims directly."""

    def test_torch_ops(self):
        # The prim registered under torch.ops.prims should agree with eager torch.sin.
        base = make_tensor((2,), device='cpu', dtype=torch.float)
        self.assertEqual(torch.ops.prims.sin(base), torch.sin(base))

        # Calling the prim on a LoggingTensor records the dispatched op, which
        # the expected log below shows as torch._ops.prims.sin.default.
        wrapped = LoggingTensor(base)
        with capture_logs() as logs:
            log_input("input", wrapped)
            prims.sin(wrapped)
        joined = "\n".join(logs)
        self.assertExpectedInline(joined, """\
$0 = input('input')
$1 = torch._ops.prims.sin.default($0)""")

    def test_mul_complex(self):
        # Smoke test with no assertion: prims.mul of a real tensor and a
        # Python complex scalar only needs to run without raising.
        real = torch.randn(2)
        prims.mul(real, 1 + 1j)
# Generate the device-specific variants of TestPrims into this module's
# global namespace so the test runner can discover them.
instantiate_device_type_tests(TestPrims, globals())


if __name__ == "__main__":
    # Entry point when this file is executed directly.
    run_tests()