| # Owner(s): ["oncall: jit"] |
| |
| from threading import Event |
| from time import sleep |
| |
| import torch._lazy |
| import torch._lazy.ts_backend |
| from torch.testing._internal.common_utils import run_tests, TestCase |
| |
# Initialize the lazy-tensor `ts_backend` once, at import time, before any
# test below calls torch._lazy.add_step_closure / torch._lazy.mark_step.
torch._lazy.ts_backend.init()
| |
| |
class ClosuresTest(TestCase):
    """Tests for closures registered with ``torch._lazy.add_step_closure``.

    Each test registers a closure and then calls ``torch._lazy.mark_step()``.
    The tests check that:

    * a default (synchronous) closure finishes before ``mark_step`` returns;
    * a ``run_async=True`` closure is still running when ``mark_step`` returns;
    * exceptions propagate, directly for synchronous closures and on the next
      ``mark_step`` for asynchronous ones.
    """

    def test_synchronous(self):
        """A synchronous closure must complete before mark_step returns."""
        flag = Event()
        self.assertFalse(flag.is_set())

        def closure():
            # Sleep so that, if the closure were not run inline, mark_step would
            # return first and the final check below would fail.
            sleep(1)
            self.assertFalse(flag.is_set())
            flag.set()

        torch._lazy.add_step_closure(closure)
        torch._lazy.mark_step()

        # Should not get to this part before the closure is finished running.
        self.assertTrue(flag.is_set())

    def test_asynchronous(self):
        """An async closure must still be running when mark_step returns."""
        flag = Event()
        closure_done = Event()
        # What the closure observed. An assert inside an async closure does not
        # fail the mark_step that scheduled it (see test_asynchronous_exception:
        # the error only surfaces on a later mark_step, possibly in another
        # test), so the closure records the value and this test checks it.
        seen_flag = []
        self.assertFalse(flag.is_set())

        def closure():
            sleep(1)
            seen_flag.append(flag.is_set())
            closure_done.set()

        torch._lazy.add_step_closure(closure, run_async=True)
        torch._lazy.mark_step()

        # Should get to this part, and set the flag, while the closure sleeps.
        self.assertFalse(flag.is_set())
        flag.set()

        # Wait for the closure so its observation is verified here, inside
        # this test, instead of leaking into a later one.
        self.assertTrue(
            closure_done.wait(timeout=10), "asynchronous closure never finished"
        )
        self.assertEqual(seen_flag, [True])

    def test_synchronous_exception(self):
        """An exception raised by a synchronous closure escapes mark_step."""
        flag = Event()
        self.assertFalse(flag.is_set())

        def closure():
            flag.set()
            raise RuntimeError("Simulating exception in closure")

        torch._lazy.add_step_closure(closure)
        with self.assertRaisesRegex(RuntimeError, "Simulating exception in closure"):
            torch._lazy.mark_step()

        self.assertTrue(flag.is_set(), "closure should have run before mark_step raised")

    def test_asynchronous_exception(self):
        """A failed async closure makes the next async mark_step raise."""
        flag = Event()
        self.assertFalse(flag.is_set())

        def closure1():
            flag.set()
            raise RuntimeError("Simulating exception in closure1")

        torch._lazy.add_step_closure(closure1, run_async=True)
        torch._lazy.mark_step()

        self.assertTrue(flag.wait(timeout=5), "closure1 never ran")
        # closure1 sets the flag *before* it raises, so the wait above can
        # return while its exception is still being handled in the background.
        # NOTE(review): heuristic pause so the failure is recorded before the
        # next mark_step checks for it; replace with a real signal if the
        # closure handler ever exposes one.
        sleep(1)

        def closure2():  # Should never execute
            flag.clear()

        torch._lazy.add_step_closure(closure2, run_async=True)
        # Should raise because of the exception from closure1.
        with self.assertRaises(RuntimeError):
            torch._lazy.mark_step()

        self.assertTrue(flag.is_set())
| |
| |
# When executed as a script, hand off to the shared PyTorch test runner
# imported from torch.testing._internal.common_utils.
if __name__ == "__main__":
    run_tests()