| # mypy: allow-untyped-defs |
| # Owner(s): ["module: unknown"] |
| |
| import threading |
| import time |
| import torch |
| import unittest |
| from torch.futures import Future |
| from torch.testing._internal.common_utils import IS_WINDOWS, TestCase, TemporaryFileName, run_tests |
| from typing import TypeVar |
| |
# Placeholder type parameter: the exception tests below instantiate
# ``Future[T]`` because those futures are only ever completed with an error.
T = TypeVar("T")
| |
| |
def add_one(fut):
    """Return the value obtained from ``fut.wait()`` incremented by one.

    Used as a ``then`` callback, so ``fut`` is the future being chained on.
    """
    resolved = fut.wait()
    return resolved + 1
| |
| |
class TestFuture(TestCase):
    """Tests for ``torch.futures.Future``, ``collect_all`` and ``wait_all``.

    The cases cover result and exception propagation, ``then`` chaining,
    ``add_done_callback`` ordering and error handling, completion from a
    second thread, and rejection of ``None`` in place of a future.
    """

    def _fail_future_while_running(self, target, fut, error):
        """Run ``target(fut)`` on a worker thread while this thread fails ``fut``.

        An exception raised inside a ``threading.Thread`` target is not
        propagated to the thread that started it, so an assertion failing in
        ``target`` would otherwise leave the test passing. Any exception
        raised by ``target`` is captured and re-raised here once the worker
        has been joined.

        Args:
            target: Callable invoked as ``target(fut)`` on the worker thread.
            fut: The not-yet-completed future shared with the worker.
            error: Exception handed to ``fut.set_exception`` from this thread.

        Raises:
            Exception: Whatever ``target`` raised on the worker thread.
        """
        failures = []

        def guarded():
            try:
                target(fut)
            except Exception as exc:  # re-raised below on the calling thread
                failures.append(exc)

        worker = threading.Thread(target=guarded)
        worker.start()
        fut.set_exception(error)
        worker.join()
        if failures:
            raise failures[0]

    def test_set_exception(self) -> None:
        # This test is to ensure errors can propagate across futures.
        error_msg = "Intentional Value Error"
        value_error = ValueError(error_msg)

        f = Future[T]()  # type: ignore[valid-type]
        # Set exception
        f.set_exception(value_error)
        # Exception should throw on wait
        with self.assertRaisesRegex(ValueError, "Intentional"):
            f.wait()

        # Exception should also throw on value
        f = Future[T]()  # type: ignore[valid-type]
        f.set_exception(value_error)
        with self.assertRaisesRegex(ValueError, "Intentional"):
            f.value()

        def cb(fut):
            fut.value()

        f = Future[T]()  # type: ignore[valid-type]
        f.set_exception(value_error)

        # A callback that reads the failed future is expected to surface as a
        # RuntimeError on the future returned by then().
        with self.assertRaisesRegex(RuntimeError, "Got the following error"):
            cb_fut = f.then(cb)
            cb_fut.wait()

    def test_set_exception_multithreading(self) -> None:
        # Ensure errors can propagate when one thread waits on future result
        # and the other sets it with an error.
        error_msg = "Intentional Value Error"
        value_error = ValueError(error_msg)

        def wait_future(f):
            with self.assertRaisesRegex(ValueError, "Intentional"):
                f.wait()

        # The helper re-raises worker-side assertion failures on this thread;
        # a bare Thread would silently drop them.
        f = Future[T]()  # type: ignore[valid-type]
        self._fail_future_while_running(wait_future, f, value_error)

        def cb(fut):
            fut.value()

        def then_future(f):
            fut = f.then(cb)
            with self.assertRaisesRegex(RuntimeError, "Got the following error"):
                fut.wait()

        f = Future[T]()  # type: ignore[valid-type]
        self._fail_future_while_running(then_future, f, value_error)

    def test_done(self) -> None:
        # done() is False before a result is set and True afterwards.
        f = Future[torch.Tensor]()
        self.assertFalse(f.done())

        f.set_result(torch.ones(2, 2))
        self.assertTrue(f.done())

    def test_done_exception(self) -> None:
        # A future produced by then() whose callback raised still reports
        # done(), and the error is raised on wait().
        err_msg = "Intentional Value Error"

        def raise_exception(unused_future):
            raise RuntimeError(err_msg)

        f1 = Future[torch.Tensor]()
        self.assertFalse(f1.done())
        f1.set_result(torch.ones(2, 2))
        self.assertTrue(f1.done())

        f2 = f1.then(raise_exception)
        self.assertTrue(f2.done())
        with self.assertRaisesRegex(RuntimeError, err_msg):
            f2.wait()

    def test_wait(self) -> None:
        # wait() on a completed future returns the stored result.
        f = Future[torch.Tensor]()
        f.set_result(torch.ones(2, 2))

        self.assertEqual(f.wait(), torch.ones(2, 2))

    def test_wait_multi_thread(self) -> None:
        # wait() blocks until another thread sets the result.

        def slow_set_future(fut, value):
            time.sleep(0.5)
            fut.set_result(value)

        f = Future[torch.Tensor]()

        t = threading.Thread(target=slow_set_future, args=(f, torch.ones(2, 2)))
        t.start()
        try:
            self.assertEqual(f.wait(), torch.ones(2, 2))
        finally:
            # The worker finishes on its own after the sleep, so always
            # joining it leaves no stray thread behind when the check fails.
            t.join()

    def test_mark_future_twice(self) -> None:
        # Completing an already-completed future is an error.
        fut = Future[int]()
        fut.set_result(1)
        with self.assertRaisesRegex(
            RuntimeError,
            "Future can only be marked completed once"
        ):
            fut.set_result(1)

    def test_pickle_future(self):
        # torch.save on a Future is expected to be rejected.
        fut = Future[int]()
        err_msg = "Can not pickle torch.futures.Future"
        with TemporaryFileName() as fname, self.assertRaisesRegex(RuntimeError, err_msg):
            torch.save(fut, fname)

    def test_then(self):
        # The then() callback receives the completed future and its return
        # value becomes the result of the returned future.
        fut = Future[torch.Tensor]()
        then_fut = fut.then(lambda x: x.wait() + 1)

        fut.set_result(torch.ones(2, 2))
        self.assertEqual(fut.wait(), torch.ones(2, 2))
        self.assertEqual(then_fut.wait(), torch.ones(2, 2) + 1)

    def test_chained_then(self):
        # Build a chain of 20 then() calls; link k must hold ones + k.
        fut = Future[torch.Tensor]()
        futs = []
        last_fut = fut
        for _ in range(20):
            last_fut = last_fut.then(add_one)
            futs.append(last_fut)

        fut.set_result(torch.ones(2, 2))

        for depth, chained in enumerate(futs, start=1):
            self.assertEqual(chained.wait(), torch.ones(2, 2) + depth)

    def _test_then_error(self, cb, errMsg):
        """Check that a failing ``then`` callback does not affect the source future.

        Completes a ``Future[int]`` with ``5`` after attaching ``cb`` via
        ``then`` and asserts that the source still yields ``5`` while waiting
        on the chained future raises ``RuntimeError`` matching ``errMsg``.

        Args:
            cb: Callback passed to ``then``; expected to fail when invoked.
            errMsg: Regex the raised ``RuntimeError`` message must match.
        """
        fut = Future[int]()
        then_fut = fut.then(cb)

        fut.set_result(5)
        self.assertEqual(5, fut.wait())
        with self.assertRaisesRegex(RuntimeError, errMsg):
            then_fut.wait()

    def test_then_wrong_arg(self):
        # The callback treats its argument as a value rather than a Future,
        # so the addition is expected to fail with an operand-type error.

        def wrong_arg(tensor):
            return tensor + 1

        self._test_then_error(wrong_arg, "unsupported operand type.*Future.*int")

    def test_then_no_arg(self):
        # A callback that accepts no arguments cannot be called with the future.

        def no_arg():
            return True

        self._test_then_error(no_arg, "takes 0 positional arguments but 1 was given")

    def test_then_raise(self):
        # An exception raised by the callback is reported by the chained future.

        def raise_value_error(fut):
            raise ValueError("Expected error")

        self._test_then_error(raise_value_error, "Expected error")

    def test_add_done_callback_simple(self):
        # The callback does not run before the result is set and has run by
        # the time wait() on the future has returned.
        callback_result = False

        def callback(fut):
            nonlocal callback_result
            fut.wait()
            callback_result = True

        fut = Future[torch.Tensor]()
        fut.add_done_callback(callback)

        self.assertFalse(callback_result)
        fut.set_result(torch.ones(2, 2))
        self.assertEqual(fut.wait(), torch.ones(2, 2))
        self.assertTrue(callback_result)

    def test_add_done_callback_maintains_callback_order(self):
        # Callbacks run in registration order, so the last one registered
        # determines the final value.
        callback_result = 0

        def callback_set1(fut):
            nonlocal callback_result
            fut.wait()
            callback_result = 1

        def callback_set2(fut):
            nonlocal callback_result
            fut.wait()
            callback_result = 2

        fut = Future[torch.Tensor]()
        fut.add_done_callback(callback_set1)
        fut.add_done_callback(callback_set2)

        fut.set_result(torch.ones(2, 2))
        self.assertEqual(fut.wait(), torch.ones(2, 2))
        # set2 called last, callback_result = 2
        self.assertEqual(callback_result, 2)

    def _test_add_done_callback_error_ignored(self, cb):
        """Check that a failing ``add_done_callback`` callback is not propagated.

        Attaches ``cb`` to a ``Future[int]``, completes it with ``5`` and
        asserts the future still yields ``5``.

        Args:
            cb: Callback passed to ``add_done_callback``; expected to fail.
        """
        fut = Future[int]()
        fut.add_done_callback(cb)

        fut.set_result(5)
        # error msg logged to stdout
        self.assertEqual(5, fut.wait())

    def test_add_done_callback_error_is_ignored(self):
        # A callback that raises must not break the future it is attached to.

        def raise_value_error(fut):
            raise ValueError("Expected error")

        self._test_add_done_callback_error_ignored(raise_value_error)

    def test_add_done_callback_no_arg_error_is_ignored(self):
        # A callback with the wrong arity must not break the future either.

        def no_arg():
            return True

        # Adding another level of function indirection here on purpose.
        # Otherwise mypy will pick up on no_arg having an incompatible type and fail CI
        self._test_add_done_callback_error_ignored(no_arg)

    def test_interleaving_then_and_add_done_callback_maintains_callback_order(self):
        # then() and add_done_callback() callbacks share one ordering: the
        # then() callback registered between set1 and set2 must observe 1.
        callback_result = 0

        def callback_set1(fut):
            nonlocal callback_result
            fut.wait()
            callback_result = 1

        def callback_set2(fut):
            nonlocal callback_result
            fut.wait()
            callback_result = 2

        def callback_then(fut):
            nonlocal callback_result
            return fut.wait() + callback_result

        fut = Future[torch.Tensor]()
        fut.add_done_callback(callback_set1)
        then_fut = fut.then(callback_then)
        fut.add_done_callback(callback_set2)

        fut.set_result(torch.ones(2, 2))
        self.assertEqual(fut.wait(), torch.ones(2, 2))
        # then_fut's callback is called with callback_result = 1
        self.assertEqual(then_fut.wait(), torch.ones(2, 2) + 1)
        # set2 called last, callback_result = 2
        self.assertEqual(callback_result, 2)

    def test_interleaving_then_and_add_done_callback_propagates_error(self):
        # The same raising callback is attached both ways to contrast how
        # then() and add_done_callback() treat callback errors.
        def raise_value_error(fut):
            raise ValueError("Expected error")

        fut = Future[torch.Tensor]()
        then_fut = fut.then(raise_value_error)
        fut.add_done_callback(raise_value_error)
        fut.set_result(torch.ones(2, 2))

        # error from add_done_callback's callback is swallowed
        # error from then's callback is not
        self.assertEqual(fut.wait(), torch.ones(2, 2))
        with self.assertRaisesRegex(RuntimeError, "Expected error"):
            then_fut.wait()

    def test_collect_all(self):
        # collect_all() returns a future yielding the input futures once all
        # are complete, including one completed later from another thread.
        fut1 = Future[int]()
        fut2 = Future[int]()
        fut_all = torch.futures.collect_all([fut1, fut2])

        def slow_in_thread(fut, value):
            time.sleep(0.1)
            fut.set_result(value)

        t = threading.Thread(target=slow_in_thread, args=(fut1, 1))
        fut2.set_result(2)
        t.start()
        try:
            res = fut_all.wait()
            self.assertEqual(res[0].wait(), 1)
            self.assertEqual(res[1].wait(), 2)
        finally:
            # The worker finishes on its own after the sleep, so always
            # joining it leaves no stray thread behind when a check fails.
            t.join()

    @unittest.skipIf(IS_WINDOWS, "TODO: need to fix this testcase for Windows")
    def test_wait_all(self):
        # wait_all() returns the results in input order and raises when any
        # of the futures holds an error.
        fut1 = Future[int]()
        fut2 = Future[int]()

        # No error version
        fut1.set_result(1)
        fut2.set_result(2)
        res = torch.futures.wait_all([fut1, fut2])
        self.assertEqual(res, [1, 2])

        # Version with an exception
        def raise_in_fut(fut):
            raise ValueError("Expected error")
        fut3 = fut1.then(raise_in_fut)
        with self.assertRaisesRegex(RuntimeError, "Expected error"):
            torch.futures.wait_all([fut3, fut2])

    def test_wait_none(self):
        # None in place of a future is rejected by each waiting entry point.
        fut1 = Future[int]()
        with self.assertRaisesRegex(RuntimeError, "Future can't be None"):
            torch.jit.wait(None)
        with self.assertRaisesRegex(RuntimeError, "Future can't be None"):
            torch.futures.wait_all((None,))  # type: ignore[arg-type]
        with self.assertRaisesRegex(RuntimeError, "Future can't be None"):
            torch.futures.collect_all((fut1, None,))  # type: ignore[arg-type]
| |
# When executed as a script, run this module's tests through ``run_tests``
# (imported from torch.testing._internal.common_utils).
if __name__ == '__main__':
    run_tests()