blob: f785f56f937f9b89ee839c3087c0b95a629fa323 [file] [log] [blame]
# Owner(s): ["oncall: distributed"]
import random
import sys
import unittest
from collections import OrderedDict
from dataclasses import dataclass
from typing import List
import torch
import torch.nn as nn
from torch import distributed as dist
from torch.distributed.utils import _apply_to_tensors, _replace_by_prefix
from torch.testing._internal.common_utils import (
instantiate_parametrized_tests,
parametrize,
run_tests,
subtest,
TEST_WITH_DEV_DBG_ASAN,
TestCase,
)
# Skip the whole module (exit status 0, so it is reported as skipped rather
# than failed) when the environment cannot run these tests. The checks are
# evaluated in order and only the first applicable reason is reported.
_skip_reason = None
if not dist.is_available():
    _skip_reason = "Distributed not available, skipping tests"
elif TEST_WITH_DEV_DBG_ASAN:
    _skip_reason = (
        "Skip dev-asan as torch + multiprocessing spawn have known issues"
    )

if _skip_reason is not None:
    print(_skip_reason, file=sys.stderr)
    sys.exit(0)
class TestUtils(TestCase):
    """Unit tests for the private helpers in ``torch.distributed.utils``."""

    @parametrize(
        "devices", [["cpu"], ["cuda"], subtest(["cpu", "cuda"], name="cpu_cuda")]
    )
    def test_apply_to_tensors(self, devices):
        """``_apply_to_tensors`` visits every tensor in a nested container.

        Builds a heterogeneous structure (lists, tuples, sets, dicts, an
        ``OrderedDict`` and a dataclass) holding tensors placed on the devices
        listed in ``devices``. Checks that the callback observes every element
        of every tensor and that the top-level layout and item types of the
        result match the input.
        """
        if "cuda" in devices and (
            not torch.cuda.is_available() or torch.cuda.device_count() < 1
        ):
            raise unittest.SkipTest("Skipped due to lack of GPU")

        # Running count of tensor elements created below; the callback passed
        # to _apply_to_tensors must observe exactly this many.
        expected = 0

        def get_a_tensor():
            """Return a random tensor on a randomly chosen device from ``devices``."""
            nonlocal expected
            dev = random.choice(devices)
            # ``(1,)`` is a 1-element shape tuple (a bare ``(1)`` is just an int).
            shape = random.choice(((1,), (2, 3), (4, 5, 6), (7, 8, 9, 10)))
            t = torch.rand(shape).to(dev)
            expected += t.numel()
            return t

        @dataclass
        class SomeDataClass:
            some_key: str
            some_float: float
            some_tensor: List[torch.Tensor]

        # Create a mixed bag of data: plain scalars/strings, nested dicts, a
        # set (note the commas, not colons), a tuple mixing lists/ints/sets,
        # a dataclass instance inside a dict, and an OrderedDict.
        data = [1, "str"]
        data.append({"key1": get_a_tensor(), "key2": {1: get_a_tensor()}, "key3": 3})
        data.insert(0, {"x", get_a_tensor(), get_a_tensor()})
        data.append(([1], get_a_tensor(), (1), [get_a_tensor()], {1, 2}))
        data.append({"abc": SomeDataClass("some_key", 1.0, [get_a_tensor()])})
        od = OrderedDict()
        od["k"] = "value"
        data.append(od)

        total = 0

        def fn(t):
            # Identity callback that tallies how many elements it has seen.
            nonlocal total
            total += t.numel()
            return t

        new_data = _apply_to_tensors(fn, data)
        self.assertEqual(total, expected)
        # The result must mirror the input's top-level layout and item types.
        self.assertEqual(len(new_data), len(data))
        for new_item, old_item in zip(new_data, data):
            self.assertEqual(type(new_item), type(old_item))

    def test_replace_by_prefix(self):
        """``_replace_by_prefix`` renames only keys that start with the prefix.

        A key that merely contains the prefix elsewhere (``abc.layer.def``)
        must be left alone, and applying the inverse rename must restore the
        original mapping.
        """
        state_dict = {
            "layer.a": torch.tensor(1),
            "abc.layer.def": torch.tensor(2),
            "layer.b": torch.tensor(3),
        }
        original_state_dict = state_dict.copy()
        _replace_by_prefix(state_dict, "layer.", "module.layer.")
        # Use assertEqual rather than a bare ``assert``: bare asserts are
        # stripped under ``python -O`` (making the test vacuous) and give no
        # diff on failure.
        self.assertEqual(
            state_dict,
            {
                "module.layer.a": torch.tensor(1),
                "abc.layer.def": torch.tensor(2),
                "module.layer.b": torch.tensor(3),
            },
        )
        _replace_by_prefix(state_dict, "module.layer.", "layer.")
        self.assertEqual(state_dict, original_state_dict)

    def test_packed_sequence(self):
        """Test to ensure RNN packed sequences are modified correctly.

        Zeroes the data of a ``PackedSequence`` produced by an RNN through
        ``_apply_to_tensors`` and checks that the padded result sums to zero.
        """
        rnn = nn.RNN(5, 5)
        x = torch.rand((5, 1, 5), dtype=torch.float)
        seq_length = torch.tensor([4], dtype=torch.int)

        def fill_fn(x):
            # Zero the tensor in place and return it (``fill_`` returns
            # ``self``), so this is a proper tensor -> tensor callback rather
            # than one that yields ``None``.
            return x.fill_(0)

        x = nn.utils.rnn.pack_padded_sequence(x, seq_length)
        # Only the output sequence is needed; the final hidden state is unused.
        x, _ = rnn(x)
        x = _apply_to_tensors(fill_fn, x)
        x, _ = nn.utils.rnn.pad_packed_sequence(x)
        self.assertEqual(torch.sum(x), 0)
# Expand the @parametrize-decorated tests in TestUtils (the ``devices``
# variants of test_apply_to_tensors) into concrete test methods; presumably
# required for the runner to discover them — see
# torch.testing._internal.common_utils.
instantiate_parametrized_tests(TestUtils)
if __name__ == "__main__":
    # Hand control to PyTorch's shared test runner.
    run_tests()