# Owner(s): ["module: primTorch"]

from functools import partial

import torch
from torch.testing import make_tensor
from torch.testing._internal.common_utils import run_tests, TestCase
from torch.testing._internal.common_device_type import (
    instantiate_device_type_tests,
    onlyCUDA,
    dtypes,
)
from torch.testing._internal.logging_tensor import LoggingTensor, capture_logs, log_input
import torch._prims as prims
from torch._prims.executor import make_traced


class TestPrims(TestCase):
    @onlyCUDA
    @dtypes(torch.float32)
    def test_broadcast_in_dim(self, device, dtype):
        # nvfuser is not currently capable of realizing a broadcasted tensor
        # when the broadcast is the only operation. Another op is needed.
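        # broadcast_in_dim(a, shape, broadcast_dimensions) maps dimension i of
        # a to position broadcast_dimensions[i] of the target shape; every
        # position not listed becomes a new size-one dimension, and size-one
        # dimensions are expanded to the target size. For example, a (5, 5)
        # tensor broadcast to (3, 3, 5, 5) with broadcast_dimensions=(2, 3)
        # keeps its original dims innermost and gains two outer dims.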
        def _wrapper(a, b, broadcast_dimensions):
            a_bc = prims.broadcast_in_dim(a, b.shape, broadcast_dimensions)
            return prims.add(a_bc, b)

        traced = make_traced(_wrapper)
        make_arg = partial(make_tensor, device=device, dtype=dtype)

        for executor in ('aten', 'nvfuser'):
            fn = partial(traced, executor=executor)
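            # In every case below b is filled with zeros (low=high=0.0), so
            # prims.add contributes nothing and the result is exactly the
            # broadcast of a.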

            # Same shape
            shape = (5, 5)
            a = make_arg(shape)
            b = make_arg(shape, low=0.0, high=0.0)
            result = fn(a, b, (0, 1))

            self.assertEqual(result.shape, a.shape)
            self.assertTrue(result.is_contiguous())
            self.assertEqual(a, result)

            # Error input: reordering dims
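            # (broadcast_dimensions is expected to be in increasing order;
            # (1, 0) would reorder, i.e. transpose, the dims, which
            # broadcast_in_dim does not support.)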
            with self.assertRaises(Exception):
                result = fn(a, b, (1, 0))

            # Adding outermost dimensions
            a = make_arg((5, 5))
            b = make_arg((3, 3, 5, 5), low=0.0, high=0.0)
            result = fn(a, b, (2, 3))

            self.assertEqual(result.shape, b.shape)
            self.assertEqual(a.broadcast_to(b.shape), result)

            # Expands
            a = make_arg((1, 5, 1))
            b = make_arg((3, 5, 7), low=0.0, high=0.0)
            result = fn(a, b, (0, 1, 2))

            self.assertEqual(result.shape, b.shape)
            self.assertEqual(a.expand_as(result), result)

            # Unsqueezes
            a = make_arg((1, 2, 3))
            b = make_arg((1, 2, 1, 3), low=0.0, high=0.0)
            result = fn(a, b, (0, 1, 3))

            self.assertEqual(result.shape, b.shape)
            self.assertEqual(a.unsqueeze(2), result)

            # FIXME: This test exposes an issue in nvfuser
            # Adds outermost, expands, and unsqueezes
            """
            a = make_arg((1, 2, 3))
            b = make_arg((4, 1, 7, 2, 3, 3), low=0.0, high=0.0)
            result = fn(a, b, (1, 3, 4))

            self.assertEqual(result.shape, b.shape)
            a.unsqueeze_(3)
            a.unsqueeze_(1)
            a.unsqueeze_(0)
            self.assertEqual(a.expand_as(result), result)
            """


class TestPrimsBasic(TestCase):
    def test_torch_ops(self):
        r = make_tensor((2,), device='cpu', dtype=torch.float)
        self.assertEqual(torch.ops.prims.sin(r), torch.sin(r))
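
        # prims are registered as real torch.ops overloads, so a tensor
        # subclass like LoggingTensor observes them via __torch_dispatch__
        # and can record each call.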
        r = LoggingTensor(r)
        with capture_logs() as logs:
            log_input("input", r)
            prims.sin(r)
        self.assertExpectedInline('\n'.join(logs), """\
$0 = input('input')
$1 = torch._ops.prims.sin.default($0)""")

    def test_mul_complex(self):
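        # Smoke test: prims.mul should accept a Python complex scalar
        # alongside a float tensor without raising.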
        prims.mul(torch.randn(2), 1 + 1j)


instantiate_device_type_tests(TestPrims, globals())

if __name__ == "__main__":
    run_tests()