| # Owner(s): ["oncall: jit"] |
| |
| import os |
| import sys |
| import io |
| |
| import torch |
| import warnings |
| from contextlib import redirect_stderr |
| from torch.testing import FileCheck |
| |
| # Make the helper files in test/ importable |
| pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) |
| sys.path.append(pytorch_test_dir) |
| from torch.testing._internal.jit_utils import JitTestCase |
| |
# These tests are meant to be run through test/test_jit.py (see the message
# below); executing this file directly is refused with instructions instead.
if __name__ == "__main__":
    message = (
        "This test file is not meant to be run directly, use:\n\n"
        "\tpython test/test_jit.py TESTNAME\n\n"
        "instead."
    )
    raise RuntimeError(message)
| |
| |
| class TestWarn(JitTestCase): |
| def test_warn(self): |
| @torch.jit.script |
| def fn(): |
| warnings.warn("I am warning you") |
| |
| f = io.StringIO() |
| with redirect_stderr(f): |
| fn() |
| |
| FileCheck() \ |
| .check_count( |
| str="UserWarning: I am warning you", |
| count=1, |
| exactly=True) \ |
| .run(f.getvalue()) |
| |
| def test_warn_only_once(self): |
| @torch.jit.script |
| def fn(): |
| for _ in range(10): |
| warnings.warn("I am warning you") |
| |
| f = io.StringIO() |
| with redirect_stderr(f): |
| fn() |
| |
| FileCheck() \ |
| .check_count( |
| str="UserWarning: I am warning you", |
| count=1, |
| exactly=True) \ |
| .run(f.getvalue()) |
| |
| def test_warn_only_once_in_loop_func(self): |
| def w(): |
| warnings.warn("I am warning you") |
| |
| @torch.jit.script |
| def fn(): |
| for _ in range(10): |
| w() |
| |
| f = io.StringIO() |
| with redirect_stderr(f): |
| fn() |
| |
| FileCheck() \ |
| .check_count( |
| str="UserWarning: I am warning you", |
| count=1, |
| exactly=True) \ |
| .run(f.getvalue()) |
| |
| def test_warn_once_per_func(self): |
| def w1(): |
| warnings.warn("I am warning you") |
| |
| def w2(): |
| warnings.warn("I am warning you") |
| |
| @torch.jit.script |
| def fn(): |
| w1() |
| w2() |
| |
| f = io.StringIO() |
| with redirect_stderr(f): |
| fn() |
| |
| FileCheck() \ |
| .check_count( |
| str="UserWarning: I am warning you", |
| count=2, |
| exactly=True) \ |
| .run(f.getvalue()) |
| |
| def test_warn_once_per_func_in_loop(self): |
| def w1(): |
| warnings.warn("I am warning you") |
| |
| def w2(): |
| warnings.warn("I am warning you") |
| |
| @torch.jit.script |
| def fn(): |
| for _ in range(10): |
| w1() |
| w2() |
| |
| f = io.StringIO() |
| with redirect_stderr(f): |
| fn() |
| |
| FileCheck() \ |
| .check_count( |
| str="UserWarning: I am warning you", |
| count=2, |
| exactly=True) \ |
| .run(f.getvalue()) |
| |
| def test_warn_multiple_calls_multiple_warnings(self): |
| @torch.jit.script |
| def fn(): |
| warnings.warn("I am warning you") |
| |
| f = io.StringIO() |
| with redirect_stderr(f): |
| fn() |
| fn() |
| |
| FileCheck() \ |
| .check_count( |
| str="UserWarning: I am warning you", |
| count=2, |
| exactly=True) \ |
| .run(f.getvalue()) |
| |
| def test_warn_multiple_calls_same_func_diff_stack(self): |
| def warn(caller: str): |
| warnings.warn("I am warning you from " + caller) |
| |
| @torch.jit.script |
| def foo(): |
| warn("foo") |
| |
| @torch.jit.script |
| def bar(): |
| warn("bar") |
| |
| f = io.StringIO() |
| with redirect_stderr(f): |
| foo() |
| bar() |
| |
| FileCheck() \ |
| .check_count( |
| str="UserWarning: I am warning you from foo", |
| count=1, |
| exactly=True) \ |
| .check_count( |
| str="UserWarning: I am warning you from bar", |
| count=1, |
| exactly=True) \ |
| .run(f.getvalue()) |