| # Owner(s): ["module: unknown"] |
| |
| from typing import Optional, List |
| import torch |
| from torch.testing._internal.common_utils import TestCase, run_tests, skipIfTorchDynamo |
| |
| # End-to-end tests of features in native_functions.yaml |
| |
| |
class FloatListWrapperModule(torch.nn.Module):
    """Module wrapper around ``torch._C._nn._test_optional_floatlist``.

    Exists so the op can be exercised both eagerly and through
    ``torch.jit.script``; the ``Optional[List[float]]`` annotation on ``incr``
    is what gives the scripted ``forward`` its argument type.
    """

    def forward(self, values, incr: Optional[List[float]]):
        result = torch._C._nn._test_optional_floatlist(values, incr)
        return result
| |
| |
class IntListWrapperModule(torch.nn.Module):
    """Module wrapper around ``torch._C._nn._test_optional_intlist``.

    Exists so the op can be exercised both eagerly and through
    ``torch.jit.script``; the ``Optional[List[int]]`` annotation on ``incr``
    is what gives the scripted ``forward`` its argument type.
    """

    def forward(self, values, incr: Optional[List[int]]):
        result = torch._C._nn._test_optional_intlist(values, incr)
        return result
| |
| |
class TestNativeFunctions(TestCase):
    """End-to-end tests of ops declared in native_functions.yaml.

    Covers the Python argument parser's ``TypeError`` reporting for
    (Sym)Int-list arguments, and eager / scripted / traced handling of
    optional list and default-string parameters on the ``_test_*`` ops.
    """

    def _lists_with_str(self):
        """Return list-like arguments that each contain (or are) a ``str``."""
        return [
            ("foo",),
            (2, "foo"),
            ("foo", 3),
            ["foo"],
            [2, "foo"],
            ["foo", 3],
            "foo",
        ]

    def _test_raises_str_typeerror(self, fn):
        """Assert ``fn(arg)`` raises a ``TypeError`` mentioning ``str`` for each bad ``arg``.

        Each argument is tried exactly once; the offending ``arg`` is attached
        to the assertion's failure message so a regression names its input.
        """
        for arg in self._lists_with_str():
            with self.assertRaisesRegex(TypeError, "str", msg=f"argument {arg!r}"):
                fn(arg)

    @staticmethod
    def _make_fake_module(traced_none, expected_const, traced_const):
        """Build a ``(values, const)`` callable that dispatches to two traced functions.

        Each traced function has its ``const`` captured by closure at trace
        time (only ``values`` is a traced input), so every supported ``const``
        needs its own trace. The returned callable routes ``None`` to
        ``traced_none`` and ``expected_const`` to ``traced_const``.

        Raises:
            ValueError: if called with any other ``const``.
        """
        def fake_module(values, const):
            if const is None:
                return traced_none(values)
            if const == expected_const:
                return traced_const(values)
            raise ValueError(f"Invalid argument: {const!r}")

        return fake_module

    def test_symintlist_error(self):
        x = torch.randn(1)
        self._test_raises_str_typeerror(lambda arg: torch._C._nn.pad(x, arg))

    def test_vararg_symintlist_error(self):
        self._test_raises_str_typeerror(lambda arg: torch.rand(arg))
        self._test_raises_str_typeerror(lambda arg: torch.rand(*arg))

    def test_symintlist_error_with_overload_but_is_unique(self):
        x = torch.randn(1)
        y = torch.randn(1)
        self._test_raises_str_typeerror(lambda arg: x.set_(y, 0, arg))

    def test_symintlist_error_with_overload(self):
        x = torch.randn(1)
        self._test_raises_str_typeerror(lambda arg: x.view(arg))

    def test_intlist_error_with_overload(self):
        # NOTE(review): this makes exactly the same call as
        # test_symintlist_error (torch._C._nn.pad); an overloaded op taking a
        # plain int[] was presumably intended here -- confirm.
        x = torch.randn(1)
        self._test_raises_str_typeerror(lambda arg: torch._C._nn.pad(x, arg))

    #
    # optional float list
    #

    def do_test_optional_floatlist_with_module(self, module):
        """Check ``module(values, incr)`` for ``incr=None`` and ``incr=[5.1, 4.1]``."""
        values = torch.tensor([1.5, 2.5], dtype=torch.float)

        returned = module(values, None)
        self.assertEqual(values, returned)
        # Make sure that it's an alias, indicating that the operator saw a nullopt.
        values[0] = 3.5
        self.assertEqual(values, returned)

        returned = module(values, [5.1, 4.1])
        self.assertEqual(values, torch.tensor([3.5, 2.5], dtype=torch.float))
        self.assertEqual(returned, torch.tensor([8.6, 6.6], dtype=torch.float))

    def trace_optional_floatlist(self, const):
        """Trace ``_test_optional_floatlist`` with ``const`` captured as a constant."""
        def wrapper(values):
            return torch._C._nn._test_optional_floatlist(values, const)
        return torch.jit.trace(wrapper, torch.tensor([1.5, 2.5], dtype=torch.float))

    @skipIfTorchDynamo("Not a suitable test for TorchDynamo")
    def test_optional_floatlist(self):
        self.do_test_optional_floatlist_with_module(FloatListWrapperModule())
        self.do_test_optional_floatlist_with_module(torch.jit.script(FloatListWrapperModule()))

        traced_none = self.trace_optional_floatlist(None)
        traced_list = self.trace_optional_floatlist([5.1, 4.1])

        self.do_test_optional_floatlist_with_module(
            self._make_fake_module(traced_none, [5.1, 4.1], traced_list)
        )

    def test_optional_floatlist_invalid(self):
        with self.assertRaisesRegex(TypeError, "must be tuple of floats, not list"):
            FloatListWrapperModule()(torch.zeros(1), ["hi"])

        with self.assertRaisesRegex(RuntimeError, "value of type .* instead found type"):
            torch.jit.script(FloatListWrapperModule())(torch.zeros(1), ["hi"])

        with self.assertRaisesRegex(TypeError, "must be .* Tensor"):
            FloatListWrapperModule()(torch.zeros(1), torch.zeros(1))

        with self.assertRaisesRegex(RuntimeError, "value of type .* instead found type"):
            torch.jit.script(FloatListWrapperModule())(torch.zeros(1), torch.zeros(1))

    #
    # optional int list
    #

    def do_test_optional_intlist_with_module(self, module):
        """Check ``module(values, incr)`` for ``incr=None`` and ``incr=[5, 4]``."""
        values = torch.tensor([1, 2], dtype=torch.int)

        returned = module(values, None)
        self.assertEqual(values, returned)
        # Make sure that it's an alias, indicating that the operator saw a nullopt.
        values[0] = 3
        self.assertEqual(values, returned)

        returned = module(values, [5, 4])
        self.assertEqual(values, torch.tensor([3, 2], dtype=torch.int))
        self.assertEqual(returned, torch.tensor([8, 6], dtype=torch.int))

    def trace_optional_intlist(self, const):
        """Trace ``_test_optional_intlist`` with ``const`` captured as a constant."""
        def wrapper(values):
            return torch._C._nn._test_optional_intlist(values, const)
        return torch.jit.trace(wrapper, torch.tensor([1, 2], dtype=torch.int))

    @skipIfTorchDynamo("Not a suitable test for TorchDynamo")
    def test_optional_intlist(self):
        self.do_test_optional_intlist_with_module(IntListWrapperModule())
        self.do_test_optional_intlist_with_module(torch.jit.script(IntListWrapperModule()))

        traced_none = self.trace_optional_intlist(None)
        traced_list = self.trace_optional_intlist([5, 4])

        self.do_test_optional_intlist_with_module(
            self._make_fake_module(traced_none, [5, 4], traced_list)
        )

    def test_optional_intlist_invalid(self):
        with self.assertRaisesRegex(TypeError, "must be .* but found"):
            IntListWrapperModule()(torch.zeros(1), [0.5])

        with self.assertRaisesRegex(RuntimeError, "value of type .* instead found type"):
            torch.jit.script(IntListWrapperModule())(torch.zeros(1), [0.5])

        with self.assertRaisesRegex(TypeError, "must be .* Tensor"):
            IntListWrapperModule()(torch.zeros(1), torch.zeros(1))

        with self.assertRaisesRegex(RuntimeError, "value of type .* instead found type"):
            torch.jit.script(IntListWrapperModule())(torch.zeros(1), torch.zeros(1))

    #
    # optional filled int list
    #

    def do_test_optional_filled_intlist_with_module(self, module):
        """Check ``module(values, incr)`` for ``incr=None`` and scalar ``incr=10``."""
        values = torch.tensor([1, 2], dtype=torch.int)

        returned = module(values, None)
        self.assertEqual(values, returned)
        # Make sure that it's an alias, indicating that the operator saw a nullopt.
        values[0] = 3
        self.assertEqual(values, returned)

        returned = module(values, 10)
        self.assertEqual(values, torch.tensor([3, 2], dtype=torch.int))
        self.assertEqual(returned, torch.tensor([13, 12], dtype=torch.int))

    def trace_optional_filled_intlist(self, const):
        """Trace ``_test_optional_filled_intlist`` with ``const`` captured as a constant."""
        def wrapper(values):
            return torch._C._nn._test_optional_filled_intlist(values, const)
        return torch.jit.trace(wrapper, torch.tensor([1, 2], dtype=torch.int))

    @skipIfTorchDynamo("Not a suitable test for TorchDynamo")
    def test_optional_filled_intlist(self):

        # A scalar n must behave the same as the explicit list (n, n).
        def f(n: int):
            x = torch._C._nn._test_optional_filled_intlist(torch.tensor([1, 1], dtype=torch.int), (n, n))
            y = torch._C._nn._test_optional_filled_intlist(torch.tensor([1, 1], dtype=torch.int), n)
            return x, y

        # eager
        returned = f(10)
        self.assertEqual(returned[0], returned[1])

        # scripted
        s = torch.jit.script(f)
        returned = s(10)
        self.assertEqual(returned[0], returned[1])

        # traced
        traced_none = self.trace_optional_filled_intlist(None)
        traced_int = self.trace_optional_filled_intlist(10)

        self.do_test_optional_filled_intlist_with_module(
            self._make_fake_module(traced_none, 10, traced_int)
        )

    def test_string_defaults(self):
        dummy = torch.rand(1)
        fn = torch._C._nn._test_string_default
        fn(dummy)

        with self.assertRaisesRegex(RuntimeError, "A"):
            fn(dummy, a="")

        with self.assertRaisesRegex(RuntimeError, "B"):
            fn(dummy, b="")

        def f(x):
            torch._C._nn._test_string_default(x)
        scripted_fn = torch.jit.script(f)
        scripted_fn(dummy)
| |
| |
if __name__ == '__main__':
    # Script entry point: hand off to run_tests, imported at the top of this
    # file from torch.testing._internal.common_utils.
    run_tests()