blob: 6d1fd07fe6c8fb61b967dffee9de02695572adb0 [file] [log] [blame]
# Owner(s): ["oncall: jit"]
import os
import sys
import unittest
import torch
# Make the helper files in test/ importable
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
sys.path.append(pytorch_test_dir)
from torch.testing._internal.jit_utils import JitTestCase
if __name__ == '__main__':
raise RuntimeError("This test file is not meant to be run directly, use:\n\n"
"\tpython test/test_jit.py TESTNAME\n\n"
"instead.")
def canonical(graph):
return torch._C._jit_pass_canonicalize(graph).str(False)
class TestCustomOperators(JitTestCase):
def test_dynamic_op_registry(self):
from torch._ops import _OpNamespace
self.assertTrue(hasattr(torch, 'ops'))
if '_test' in torch.ops.__dict__:
torch.ops.__dict__.pop('_test')
# Don't use `hasattr()` because it will call `__getattr__`.
self.assertNotIn('_test', torch.ops.__dict__)
torch.ops._test
self.assertIn('_test', torch.ops.__dict__)
self.assertEqual(type(torch.ops._test), _OpNamespace)
self.assertNotIn('leaky_relu', torch.ops._test.__dict__)
op = torch.ops._test.leaky_relu
self.assertTrue(callable(op))
self.assertIn('leaky_relu', torch.ops._test.__dict__)
op2 = torch.ops._test.leaky_relu
self.assertEqual(op, op2)
def test_simply_calling_an_operator(self):
input = torch.randn(100)
output = torch.ops.aten.relu(input)
self.assertEqual(output, input.relu())
def test_default_arguments_are_used(self):
output = torch.ops._test.leaky_relu(torch.tensor([-1.0, 1.0]))
self.assertEqual(output, torch.tensor([-0.01, 1]))
def test_passing_too_many_args(self):
with self.assertRaisesRegexWithHighlight(
RuntimeError,
r"aten::relu\(\) expected at most 1 argument\(s\) but received 2 argument\(s\)",
""
):
torch.ops.aten.relu(1, 2)
def test_passing_too_few_args(self):
with self.assertRaisesRegexWithHighlight(
RuntimeError,
r"aten::relu\(\) is missing value for argument 'self'.",
""
):
torch.ops.aten.relu()
def test_passing_one_positional_but_not_the_second(self):
with self.assertRaisesRegexWithHighlight(
RuntimeError,
r"aten::type_as\(\) is missing value for argument 'other'.",
""
):
torch.ops.aten.type_as(torch.ones(5, 5))
def test_passing_unknown_kwargs(self):
with self.assertRaisesRegexWithHighlight(
RuntimeError,
"Unknown keyword argument 'foo' for operator '_test::leaky_relu'",
""
):
torch.ops._test.leaky_relu(torch.ones(5), foo=torch.ones(5))
def test_passing_and_returning_lists(self):
# Replace with actual test once we support lists.
a, b = torch.rand(5), torch.rand(5)
output = torch.ops._test.cat([a, b])
output_ref = torch.cat([a, b])
self.assertEqual(output, output_ref)
def test_calling_scripted_custom_op(self):
@torch.jit.script
def func(x):
return torch.ops.aten.relu(x)
input = torch.ones(5, 5)
self.assertEqual(func(input), input.relu())
def test_calling_traced_custom_op(self):
input = torch.ones(5, 5)
func = torch.jit.trace(torch.ops.aten.relu, [input])
self.assertEqual(func(input), input.relu())
@unittest.skip("Need to figure out default dtype differences between fbcode and oss")
def test_script_graph_for_custom_ops_matches_traced_graph(self):
input = torch.ones(5, 5)
trace = torch.jit.trace(torch.ops.aten.relu, [input])
self.assertExpectedInline(canonical(trace.graph), '''\
graph(%0 : Float(5, 5)):
%1 : Float(5, 5) = aten::relu(%0)
return (%1)
''')
def test_script_graph_contains_custom_op(self):
@torch.jit.script
def func(x):
return torch.ops.aten.relu(x)
self.assertExpectedInline(canonical(func.graph), '''\
graph(%x.1 : Tensor):
%1 : Tensor = aten::relu(%x.1)
return (%1)
''')
def test_generic_list(self):
self.assertEqual(torch.ops._test.get_first([['hello']]), 'hello')
# https://github.com/pytorch/pytorch/issues/80508
def test_where_no_scalar(self):
x = torch.rand(1, 3, 224, 224)
torch.ops.aten.where(x > 0.5, -1.5, 1.5) # does not raise