| # Owner(s): ["oncall: jit"] |
| |
| import sys |
| import os |
| import contextlib |
| import subprocess |
| from torch.testing._internal.common_utils import TestCase, run_tests, TemporaryFileName |
| |
| |
| @contextlib.contextmanager |
| def _jit_disabled(): |
| cur_env = os.environ.get("PYTORCH_JIT", "1") |
| os.environ["PYTORCH_JIT"] = "0" |
| try: |
| yield |
| finally: |
| os.environ["PYTORCH_JIT"] = cur_env |
| |
| |
| class TestJitDisabled(TestCase): |
| """ |
| These tests are separate from the rest of the JIT tests because we need |
| run a new subprocess and `import torch` with the correct environment |
| variables set. |
| """ |
| |
| def compare_enabled_disabled(self, src): |
| """ |
| Runs the script in `src` with PYTORCH_JIT enabled and disabled and |
| compares their stdout for equality. |
| """ |
| # Write `src` out to a temporary so our source inspection logic works |
| # correctly. |
| with TemporaryFileName() as fname: |
| with open(fname, 'w') as f: |
| f.write(src) |
| with _jit_disabled(): |
| out_disabled = subprocess.check_output([ |
| sys.executable, |
| fname]) |
| out_enabled = subprocess.check_output([ |
| sys.executable, |
| fname]) |
| self.assertEqual(out_disabled, out_enabled) |
| |
| def test_attribute(self): |
| _program_string = """ |
| import torch |
| |
| class Foo(torch.jit.ScriptModule): |
| def __init__(self, x): |
| super().__init__() |
| self.x = torch.jit.Attribute(x, torch.Tensor) |
| |
| def forward(self, input): |
| return input |
| |
| s = Foo(torch.ones(2, 3)) |
| print(s.x) |
| """ |
| self.compare_enabled_disabled(_program_string) |
| |
| def test_script_module_construction(self): |
| _program_string = """ |
| import torch |
| |
| class AModule(torch.jit.ScriptModule): |
| @torch.jit.script_method |
| def forward(self, input): |
| pass |
| |
| AModule() |
| print("Didn't throw exception") |
| """ |
| self.compare_enabled_disabled(_program_string) |
| |
| def test_recursive_script(self): |
| _program_string = """ |
| import torch |
| |
| class AModule(torch.nn.Module): |
| def forward(self, input): |
| pass |
| |
| sm = torch.jit.script(AModule()) |
| print("Didn't throw exception") |
| """ |
| self.compare_enabled_disabled(_program_string) |
| |
| if __name__ == '__main__': |
| run_tests() |