blob: dd1e79ff83b3b7d160416d34a3be5af06161b3e4 [file] [log] [blame] [edit]
# mypy: allow-untyped-defs
# Owner(s): ["module: unknown"]
import threading
import time
import torch
import unittest
from torch.futures import Future
from torch.testing._internal.common_utils import IS_WINDOWS, TestCase, TemporaryFileName, run_tests
from typing import TypeVar
# Unconstrained type variable; used only to spell the generic type
# ``Future[T]`` when constructing futures in the tests.
T = TypeVar("T")
def add_one(fut):
    """Return the value of ``fut`` (obtained via ``fut.wait()``) plus one.

    Passed to ``Future.then`` in the chaining test, so ``fut`` is the parent
    future handed to the callback.
    """
    resolved = fut.wait()
    return resolved + 1
class TestFuture(TestCase):
    """Tests for ``torch.futures.Future``.

    Covers completion state, error propagation through ``wait``/``value``,
    ``then`` and ``add_done_callback`` chaining and ordering, pickling, and
    the ``collect_all``/``wait_all`` helpers.
    """

    @staticmethod
    def _start_checked_thread(target, *args):
        """Start ``target(*args)`` on a new thread, recording what it raises.

        An exception raised inside a ``threading.Thread`` target -- including
        the ``AssertionError`` from a failed ``assertRaisesRegex`` -- does not
        propagate to the thread that joins it, so a failing check there would
        otherwise not fail the test.

        Returns:
            A ``(thread, errors)`` pair; pass it to ``_join_checked_thread``.
        """
        errors = []

        def run():
            try:
                target(*args)
            except Exception as e:
                errors.append(e)

        thread = threading.Thread(target=run)
        thread.start()
        return thread, errors

    @staticmethod
    def _join_checked_thread(thread, errors):
        """Join ``thread`` and re-raise the first exception its target raised."""
        thread.join()
        if errors:
            raise errors[0]

    def test_set_exception(self) -> None:
        # This test is to ensure errors can propagate across futures.
        error_msg = "Intentional Value Error"
        value_error = ValueError(error_msg)

        f = Future[T]()  # type: ignore[valid-type]
        # Set exception
        f.set_exception(value_error)
        # Exception should throw on wait
        with self.assertRaisesRegex(ValueError, "Intentional"):
            f.wait()

        # Exception should also throw on value
        f = Future[T]()  # type: ignore[valid-type]
        f.set_exception(value_error)
        with self.assertRaisesRegex(ValueError, "Intentional"):
            f.value()

        def cb(fut):
            fut.value()

        # A callback that re-raises the parent's error surfaces it on the
        # chained future as a RuntimeError.
        f = Future[T]()  # type: ignore[valid-type]
        f.set_exception(value_error)
        with self.assertRaisesRegex(RuntimeError, "Got the following error"):
            cb_fut = f.then(cb)
            cb_fut.wait()

    def test_set_exception_multithreading(self) -> None:
        # Ensure errors can propagate when one thread waits on future result
        # and the other sets it with an error.
        error_msg = "Intentional Value Error"
        value_error = ValueError(error_msg)

        def wait_future(f):
            with self.assertRaisesRegex(ValueError, "Intentional"):
                f.wait()

        f = Future[T]()  # type: ignore[valid-type]
        # The assertion runs on the worker thread; the checked-thread helpers
        # re-raise its failure here so the test cannot pass silently.
        t, errors = self._start_checked_thread(wait_future, f)
        f.set_exception(value_error)
        self._join_checked_thread(t, errors)

        def cb(fut):
            fut.value()

        def then_future(f):
            fut = f.then(cb)
            with self.assertRaisesRegex(RuntimeError, "Got the following error"):
                fut.wait()

        f = Future[T]()  # type: ignore[valid-type]
        t, errors = self._start_checked_thread(then_future, f)
        f.set_exception(value_error)
        self._join_checked_thread(t, errors)

    def test_done(self) -> None:
        f = Future[torch.Tensor]()
        self.assertFalse(f.done())

        f.set_result(torch.ones(2, 2))
        self.assertTrue(f.done())

    def test_done_exception(self) -> None:
        err_msg = "Intentional Value Error"

        def raise_exception(unused_future):
            raise RuntimeError(err_msg)

        f1 = Future[torch.Tensor]()
        self.assertFalse(f1.done())
        f1.set_result(torch.ones(2, 2))
        self.assertTrue(f1.done())

        # f1 is already complete, so the chained future is done right away,
        # carrying the callback's error.
        f2 = f1.then(raise_exception)
        self.assertTrue(f2.done())
        with self.assertRaisesRegex(RuntimeError, err_msg):
            f2.wait()

    def test_wait(self) -> None:
        f = Future[torch.Tensor]()
        f.set_result(torch.ones(2, 2))

        self.assertEqual(f.wait(), torch.ones(2, 2))

    def test_wait_multi_thread(self) -> None:
        def slow_set_future(fut, value):
            time.sleep(0.5)
            fut.set_result(value)

        f = Future[torch.Tensor]()

        t = threading.Thread(target=slow_set_future, args=(f, torch.ones(2, 2)))
        t.start()

        # wait() must block until the other thread sets the result.
        self.assertEqual(f.wait(), torch.ones(2, 2))
        t.join()

    def test_mark_future_twice(self) -> None:
        fut = Future[int]()
        fut.set_result(1)
        with self.assertRaisesRegex(
            RuntimeError,
            "Future can only be marked completed once"
        ):
            fut.set_result(1)

    def test_pickle_future(self):
        fut = Future[int]()
        errMsg = "Can not pickle torch.futures.Future"
        with TemporaryFileName() as fname:
            with self.assertRaisesRegex(RuntimeError, errMsg):
                torch.save(fut, fname)

    def test_then(self):
        fut = Future[torch.Tensor]()
        then_fut = fut.then(lambda x: x.wait() + 1)

        fut.set_result(torch.ones(2, 2))
        self.assertEqual(fut.wait(), torch.ones(2, 2))
        self.assertEqual(then_fut.wait(), torch.ones(2, 2) + 1)

    def test_chained_then(self):
        fut = Future[torch.Tensor]()
        futs = []
        last_fut = fut
        for _ in range(20):
            last_fut = last_fut.then(add_one)
            futs.append(last_fut)

        fut.set_result(torch.ones(2, 2))

        # The i-th chained future has had add_one applied i + 1 times.
        for i, chained in enumerate(futs):
            self.assertEqual(chained.wait(), torch.ones(2, 2) + i + 1)

    def _test_then_error(self, cb, errMsg):
        # Chain ``cb`` with ``then`` and check that its failure is reported
        # on the chained future while the parent keeps its value.
        fut = Future[int]()
        then_fut = fut.then(cb)

        fut.set_result(5)
        self.assertEqual(5, fut.wait())
        with self.assertRaisesRegex(RuntimeError, errMsg):
            then_fut.wait()

    def test_then_wrong_arg(self):

        def wrong_arg(tensor):
            return tensor + 1

        self._test_then_error(wrong_arg, "unsupported operand type.*Future.*int")

    def test_then_no_arg(self):

        def no_arg():
            return True

        self._test_then_error(no_arg, "takes 0 positional arguments but 1 was given")

    def test_then_raise(self):

        def raise_value_error(fut):
            raise ValueError("Expected error")

        self._test_then_error(raise_value_error, "Expected error")

    def test_add_done_callback_simple(self):
        callback_result = False

        def callback(fut):
            nonlocal callback_result
            fut.wait()
            callback_result = True

        fut = Future[torch.Tensor]()
        fut.add_done_callback(callback)

        # The callback must not run before the future completes.
        self.assertFalse(callback_result)
        fut.set_result(torch.ones(2, 2))
        self.assertEqual(fut.wait(), torch.ones(2, 2))
        self.assertTrue(callback_result)

    def test_add_done_callback_maintains_callback_order(self):
        callback_result = 0

        def callback_set1(fut):
            nonlocal callback_result
            fut.wait()
            callback_result = 1

        def callback_set2(fut):
            nonlocal callback_result
            fut.wait()
            callback_result = 2

        fut = Future[torch.Tensor]()
        fut.add_done_callback(callback_set1)
        fut.add_done_callback(callback_set2)

        fut.set_result(torch.ones(2, 2))
        self.assertEqual(fut.wait(), torch.ones(2, 2))
        # set2 called last, callback_result = 2
        self.assertEqual(callback_result, 2)

    def _test_add_done_callback_error_ignored(self, cb):
        fut = Future[int]()
        fut.add_done_callback(cb)

        fut.set_result(5)
        # error msg logged to stdout; it must not affect the future's result
        self.assertEqual(5, fut.wait())

    def test_add_done_callback_error_is_ignored(self):

        def raise_value_error(fut):
            raise ValueError("Expected error")

        self._test_add_done_callback_error_ignored(raise_value_error)

    def test_add_done_callback_no_arg_error_is_ignored(self):

        def no_arg():
            return True

        # Adding another level of function indirection here on purpose.
        # Otherwise mypy will pick up on no_arg having an incompatible type and fail CI
        self._test_add_done_callback_error_ignored(no_arg)

    def test_interleaving_then_and_add_done_callback_maintains_callback_order(self):
        callback_result = 0

        def callback_set1(fut):
            nonlocal callback_result
            fut.wait()
            callback_result = 1

        def callback_set2(fut):
            nonlocal callback_result
            fut.wait()
            callback_result = 2

        def callback_then(fut):
            nonlocal callback_result
            return fut.wait() + callback_result

        fut = Future[torch.Tensor]()
        fut.add_done_callback(callback_set1)
        then_fut = fut.then(callback_then)
        fut.add_done_callback(callback_set2)

        fut.set_result(torch.ones(2, 2))
        self.assertEqual(fut.wait(), torch.ones(2, 2))
        # then_fut's callback is called with callback_result = 1
        self.assertEqual(then_fut.wait(), torch.ones(2, 2) + 1)
        # set2 called last, callback_result = 2
        self.assertEqual(callback_result, 2)

    def test_interleaving_then_and_add_done_callback_propagates_error(self):
        def raise_value_error(fut):
            raise ValueError("Expected error")

        fut = Future[torch.Tensor]()
        then_fut = fut.then(raise_value_error)
        fut.add_done_callback(raise_value_error)
        fut.set_result(torch.ones(2, 2))

        # error from add_done_callback's callback is swallowed
        # error from then's callback is not
        self.assertEqual(fut.wait(), torch.ones(2, 2))
        with self.assertRaisesRegex(RuntimeError, "Expected error"):
            then_fut.wait()

    def test_collect_all(self):
        fut1 = Future[int]()
        fut2 = Future[int]()
        fut_all = torch.futures.collect_all([fut1, fut2])

        def slow_in_thread(fut, value):
            time.sleep(0.1)
            fut.set_result(value)

        t = threading.Thread(target=slow_in_thread, args=(fut1, 1))
        fut2.set_result(2)
        t.start()

        # collect_all resolves to the list of input futures, in input order.
        res = fut_all.wait()
        self.assertEqual(res[0].wait(), 1)
        self.assertEqual(res[1].wait(), 2)
        t.join()

    @unittest.skipIf(IS_WINDOWS, "TODO: need to fix this testcase for Windows")
    def test_wait_all(self):
        fut1 = Future[int]()
        fut2 = Future[int]()

        # No error version
        fut1.set_result(1)
        fut2.set_result(2)
        res = torch.futures.wait_all([fut1, fut2])
        self.assertEqual(res, [1, 2])

        # Version with an exception
        def raise_in_fut(fut):
            raise ValueError("Expected error")

        fut3 = fut1.then(raise_in_fut)
        with self.assertRaisesRegex(RuntimeError, "Expected error"):
            torch.futures.wait_all([fut3, fut2])

    def test_wait_none(self):
        fut1 = Future[int]()
        with self.assertRaisesRegex(RuntimeError, "Future can't be None"):
            torch.jit.wait(None)
        with self.assertRaisesRegex(RuntimeError, "Future can't be None"):
            torch.futures.wait_all((None,))  # type: ignore[arg-type]
        with self.assertRaisesRegex(RuntimeError, "Future can't be None"):
            torch.futures.collect_all((fut1, None,))  # type: ignore[arg-type]
# Script entry point: hand control to PyTorch's test runner.
if __name__ == '__main__':
    run_tests()