blob: 262737f9dd75c73c87866bf03d5537dac36f7ddc [file] [log] [blame]
# Owner(s): ["oncall: distributed"]
import os
import sys
import torch
import torch.distributed as dist
# Disallow TF32 for CUDA matmuls. This is module-level code, so it takes
# effect in every process that imports this module.
torch.backends.cuda.matmul.allow_tf32 = False

# Checked before the distributed test helpers are imported below. Without
# torch.distributed there is nothing to run: report why on stderr and exit
# with status 0 (success).
if not dist.is_available():
    sys.stderr.write("Distributed not available, skipping tests\n")
    raise SystemExit(0)
from torch.testing._internal.common_utils import run_tests, TEST_WITH_DEV_DBG_ASAN, NO_MULTIPROCESSING_SPAWN
from torch.testing._internal.distributed.distributed_test import (
DistributedTest, TestDistBackend
)
# Configurations in which the spawn-based tests cannot run. The checks are
# made in order; the first one that applies prints its reason to stderr and
# exits with status 0.
for _should_skip, _skip_reason in (
    (TEST_WITH_DEV_DBG_ASAN, "Skip dev-asan as torch + multiprocessing spawn have known issues"),
    (NO_MULTIPROCESSING_SPAWN, "Spawn not available, skipping tests."),
):
    if _should_skip:
        print(_skip_reason, file=sys.stderr)
        sys.exit(0)
# Drop the loop variables so no extra names are left at module level.
del _should_skip, _skip_reason
# Backend under test, read from the environment at import time. Raises
# KeyError if BACKEND is not set.
BACKEND = os.environ["BACKEND"]

if BACKEND in ("gloo", "nccl"):
    class TestDistBackendWithSpawn(TestDistBackend, DistributedTest._DistTestBase):
        """Distributed backend tests whose workers are started by ``_spawn_processes``.

        Defined only when ``BACKEND`` is ``"gloo"`` or ``"nccl"``. For any
        other value, this module defines no test class.
        """

        def setUp(self):
            super().setUp()
            self._spawn_processes()
            # Keep a reference to the cuDNN flags context manager and undo it
            # via addCleanup. Calling __enter__() on a throw-away temporary
            # would tie the restore of the previous flags to whenever that
            # temporary is finalized (flags() appears to be a generator-based
            # context manager -- confirm), rather than to the end of this test.
            cudnn_flags = torch.backends.cudnn.flags(allow_tf32=False)
            cudnn_flags.__enter__()
            self.addCleanup(cudnn_flags.__exit__, None, None, None)
# Script entry point: when this file is executed directly, hand control to
# run_tests (imported from torch.testing._internal.common_utils).
if __name__ == "__main__":
    run_tests()