# blob: acad97827ec423bf7f051462224d591221b4317c [file] [log] [blame]
# Owner(s): ["module: multiprocessing"]
import os
import pickle
import random
import signal
import sys
import time
import unittest
import torch.multiprocessing as mp
from torch.testing._internal.common_utils import (
IS_WINDOWS,
NO_MULTIPROCESSING_SPAWN,
run_tests,
TestCase,
)
def _test_success_func(i):
pass
def _test_success_single_arg_func(i, arg):
if arg:
arg.put(i)
def _test_exception_single_func(i, arg):
if i == arg:
raise ValueError("legitimate exception from process %d" % i)
time.sleep(1.0)
def _test_exception_all_func(i):
time.sleep(random.random() / 10)
raise ValueError("legitimate exception from process %d" % i)
def _test_terminate_signal_func(i):
    """Worker 0 aborts itself with SIGABRT; every other worker idles."""
    if i == 0:
        # Deliver SIGABRT to our own pid; the parent should observe a
        # signal-terminated child.
        os.kill(os.getpid(), signal.SIGABRT)
    # Keep the surviving workers around long enough for the failure to surface.
    time.sleep(1.0)
def _test_terminate_exit_func(i, arg):
if i == 0:
sys.exit(arg)
time.sleep(1.0)
def _test_success_first_then_exception_func(i, arg):
if i == 0:
return
time.sleep(0.1)
raise ValueError("legitimate exception")
def _test_nested_child_body(i, ready_queue, nested_child_sleep):
ready_queue.put(None)
time.sleep(nested_child_sleep)
def _test_infinite_task(i):
    """Never returns: sleeps in a loop until killed externally (see test_signal_raises)."""
    while True:
        time.sleep(1)
def _test_process_exit(idx):
sys.exit(12)
def _test_nested(i, pids_queue, nested_child_sleep, start_method):
    """Spawn two grandchildren, publish their pids, then kill ourselves.

    Used to verify that when a parent process dies, its children are taken
    down with it (on Linux the children register a parent-death signal via
    prctl(2) during startup).
    """
    context = mp.get_context(start_method)
    nested_child_ready_queue = context.Queue()
    nprocs = 2
    # join=False so this worker regains control immediately and can publish
    # the grandchildren's pids before dying.
    mp_context = mp.start_processes(
        fn=_test_nested_child_body,
        args=(nested_child_ready_queue, nested_child_sleep),
        nprocs=nprocs,
        join=False,
        daemon=False,
        start_method=start_method,
    )
    pids_queue.put(mp_context.pids())
    # Wait for both children to have started, to ensure that they
    # have called prctl(2) to register a parent death signal.
    for _ in range(nprocs):
        nested_child_ready_queue.get()
    # Kill self. This should take down the child processes as well.
    os.kill(os.getpid(), signal.SIGTERM)
class _TestMultiProcessing:
    """Shared ``mp.start_processes`` test suite, mixed into per-method TestCases.

    Concrete subclasses set ``start_method`` ('spawn', 'fork', ...). When it
    is left as None, mp.get_context / start_processes use their defaults.
    """

    # Overridden by concrete subclasses; None means "use the default method".
    start_method = None

    def test_success(self):
        # All workers return normally; start_processes joins without error.
        mp.start_processes(_test_success_func, nprocs=2, start_method=self.start_method)

    def test_success_non_blocking(self):
        # With join=False a context is returned whose join() may be called
        # repeatedly; it reports True once all workers have exited.
        mp_context = mp.start_processes(_test_success_func, nprocs=2, join=False, start_method=self.start_method)
        # After all processes (nproc=2) have joined it must return True
        mp_context.join(timeout=None)
        mp_context.join(timeout=None)
        self.assertTrue(mp_context.join(timeout=None))

    def test_first_argument_index(self):
        # Each worker receives its process index (0..nprocs-1) as the first
        # argument and reports it back through the queue.
        context = mp.get_context(self.start_method)
        queue = context.SimpleQueue()
        mp.start_processes(_test_success_single_arg_func, args=(queue,), nprocs=2, start_method=self.start_method)
        self.assertEqual([0, 1], sorted([queue.get(), queue.get()]))

    def test_exception_single(self):
        # A failure in any single worker must be re-raised in the parent with
        # that worker's original message.
        nprocs = 2
        for i in range(nprocs):
            with self.assertRaisesRegex(
                Exception,
                "\nValueError: legitimate exception from process %d$" % i,
            ):
                mp.start_processes(_test_exception_single_func, args=(i,), nprocs=nprocs, start_method=self.start_method)

    def test_exception_all(self):
        # When every worker fails, one of the failures is reported.
        with self.assertRaisesRegex(
            Exception,
            "\nValueError: legitimate exception from process (0|1)$",
        ):
            mp.start_processes(_test_exception_all_func, nprocs=2, start_method=self.start_method)

    def test_terminate_signal(self):
        # SIGABRT is aliased with SIGIOT
        message = "process 0 terminated with signal (SIGABRT|SIGIOT)"
        # Termination through with signal is expressed as a negative exit code
        # in multiprocessing, so we know it was a signal that caused the exit.
        # This doesn't appear to exist on Windows, where the exit code is always
        # positive, and therefore results in a different exception message.
        # Exit code 22 means "ERROR_BAD_COMMAND".
        if IS_WINDOWS:
            message = "process 0 terminated with exit code 22"
        with self.assertRaisesRegex(Exception, message):
            mp.start_processes(_test_terminate_signal_func, nprocs=2, start_method=self.start_method)

    def test_terminate_exit(self):
        # A worker calling sys.exit(code) must be reported with that code.
        exitcode = 123
        with self.assertRaisesRegex(
            Exception,
            "process 0 terminated with exit code %d" % exitcode,
        ):
            mp.start_processes(_test_terminate_exit_func, args=(exitcode,), nprocs=2, start_method=self.start_method)

    def test_success_first_then_exception(self):
        # A worker that already returned must not mask a later failure in a
        # sibling worker.
        exitcode = 123
        with self.assertRaisesRegex(
            Exception,
            "ValueError: legitimate exception",
        ):
            mp.start_processes(_test_success_first_then_exception_func, args=(exitcode,), nprocs=2, start_method=self.start_method)

    @unittest.skipIf(
        sys.platform != "linux",
        "Only runs on Linux; requires prctl(2)",
    )
    def _test_nested(self):
        # NOTE(review): the leading underscore keeps unittest from collecting
        # this test; presumably disabled on purpose — confirm before renaming.
        context = mp.get_context(self.start_method)
        pids_queue = context.Queue()
        nested_child_sleep = 20.0
        # fn here resolves to the module-level _test_nested worker function,
        # not this method (class-body names are not in method scope).
        mp_context = mp.start_processes(
            fn=_test_nested,
            args=(pids_queue, nested_child_sleep, self.start_method),
            nprocs=1,
            join=False,
            daemon=False,
            start_method=self.start_method,
        )
        # Wait for nested children to terminate in time
        pids = pids_queue.get()
        start = time.time()
        while len(pids) > 0:
            for pid in pids:
                try:
                    # Signal 0 performs no delivery; it only checks whether
                    # the process still exists.
                    os.kill(pid, 0)
                except ProcessLookupError:
                    pids.remove(pid)
                    break
            # This assert fails if any nested child process is still
            # alive after (nested_child_sleep / 2) seconds. By
            # extension, this test times out with an assertion error
            # after (nested_child_sleep / 2) seconds.
            self.assertLess(time.time() - start, nested_child_sleep / 2)
            time.sleep(0.1)
@unittest.skipIf(
    NO_MULTIPROCESSING_SPAWN,
    "Disabled for environments that don't support the spawn start method")
class SpawnTest(TestCase, _TestMultiProcessing):
    """Runs the shared suite with 'spawn', plus spawn-only exception cases."""

    start_method = 'spawn'

    def test_exception_raises(self):
        # A worker exception must surface as ProcessRaisedException. Note the
        # worker takes (i, arg) but args=() supplies no arg, so the worker
        # fails with a TypeError — which still exercises the raise path.
        with self.assertRaises(mp.ProcessRaisedException):
            mp.spawn(_test_success_first_then_exception_func, args=(), nprocs=1)

    def test_signal_raises(self):
        # Killing a worker with SIGTERM must surface as ProcessExitedException
        # when the parent joins.
        context = mp.spawn(_test_infinite_task, args=(), nprocs=1, join=False)
        for pid in context.pids():
            os.kill(pid, signal.SIGTERM)
        with self.assertRaises(mp.ProcessExitedException):
            context.join()

    def _test_process_exited(self):
        # NOTE(review): the leading underscore keeps unittest from collecting
        # this test; presumably disabled on purpose — confirm before renaming.
        with self.assertRaises(mp.ProcessExitedException) as e:
            mp.spawn(_test_process_exit, args=(), nprocs=1)
        # Bug fix: assertRaises stores the raised exception on the context
        # manager's ``.exception`` attribute; the context manager itself has
        # no ``exit_code``, so the old ``e.exit_code`` raised AttributeError.
        self.assertEqual(12, e.exception.exit_code)
# fork(2) is a POSIX facility, hence the Windows skip.
@unittest.skipIf(
    IS_WINDOWS,
    "Fork is only available on Unix",
)
class ForkTest(TestCase, _TestMultiProcessing):
    """Runs the shared start_processes suite with the 'fork' start method."""

    start_method = "fork"
@unittest.skipIf(
    IS_WINDOWS,
    "Fork is only available on Unix",
)
class ParallelForkServerShouldWorkTest(TestCase, _TestMultiProcessing):
    """Runs the shared suite with parallel process start forced on via env var.

    NOTE(review): ``start_method`` is inherited as None from the mixin, so the
    default multiprocessing context is used — confirm that is intended.
    """

    # Pre-test value of the parallel-start env var (None when it was unset).
    orig_paralell_env_val = None

    def setUp(self):
        super().setUp()
        # Remember whatever the environment had, then force parallel start on.
        self.orig_paralell_env_val = os.environ.get(mp.ENV_VAR_PARALLEL_START)
        os.environ[mp.ENV_VAR_PARALLEL_START] = "1"

    def tearDown(self):
        super().tearDown()
        previous = self.orig_paralell_env_val
        if previous is not None:
            os.environ[mp.ENV_VAR_PARALLEL_START] = previous
        else:
            # The variable was absent before the test; remove it again.
            del os.environ[mp.ENV_VAR_PARALLEL_START]
@unittest.skipIf(
    IS_WINDOWS,
    "Fork is only available on Unix",
)
class ParallelForkServerPerfTest(TestCase):
    """Checks that parallel forkserver start actually overlaps process startup."""

    def test_forkserver_perf(self):
        """Serial start should cost >= nprocs * SLEEP_SECS; parallel start less."""
        start_method = 'forkserver'
        expensive = Expensive()
        nprocs = 4
        orig_paralell_env_val = os.environ.get(mp.ENV_VAR_PARALLEL_START)
        # Bug fix: restore the environment variable even when one of the
        # timing assertions below fails; previously a failure leaked the
        # modified value into every subsequent test in this process.
        try:
            # test the non parallel case
            os.environ[mp.ENV_VAR_PARALLEL_START] = "0"
            start = time.perf_counter()
            mp.start_processes(expensive.my_call, nprocs=nprocs, start_method=start_method)
            elapsed = time.perf_counter() - start
            # the elapsed time should be at least {nprocs}x the sleep time
            self.assertGreaterEqual(elapsed, Expensive.SLEEP_SECS * nprocs)

            # test the parallel case
            os.environ[mp.ENV_VAR_PARALLEL_START] = "1"
            start = time.perf_counter()
            mp.start_processes(expensive.my_call, nprocs=nprocs, start_method=start_method)
            elapsed = time.perf_counter() - start
            # the elapsed time should be less than {nprocs}x the sleep time
            self.assertLess(elapsed, Expensive.SLEEP_SECS * nprocs)
        finally:
            if orig_paralell_env_val is None:
                # pop() with a default cannot raise if the key is already gone.
                os.environ.pop(mp.ENV_VAR_PARALLEL_START, None)
            else:
                os.environ[mp.ENV_VAR_PARALLEL_START] = orig_paralell_env_val
class Expensive:
    """Deliberately slow-to-define class used by the forkserver perf test."""

    # Per-start cost (seconds) that the perf test measures against.
    SLEEP_SECS = 5
    # Simulate startup overhead such as large imports.
    # NOTE: this sleep executes at class-creation (module import) time on
    # purpose; the perf test above relies on each process start paying it.
    time.sleep(SLEEP_SECS)

    def __init__(self):
        # Large attribute; presumably makes shipping the bound-method target
        # to workers non-trivial — confirm against start_processes internals.
        self.config: str = "*" * 1000000

    def my_call(self, *args):
        # Intentionally empty: the measured cost is process startup, not work.
        pass
class ErrorTest(TestCase):
    """Sanity checks on torch.multiprocessing's exception types."""

    def test_errors_pickleable(self):
        """Both custom exceptions must survive a pickle round trip."""
        samples = [
            mp.ProcessRaisedException("Oh no!", 1, 1),
            mp.ProcessExitedException("Oh no!", 1, 1, 1),
        ]
        for error in samples:
            pickle.loads(pickle.dumps(error))
# Standard PyTorch test entry point: run every TestCase defined in this file.
if __name__ == '__main__':
    run_tests()