| # mypy: allow-untyped-defs |
| # Owner(s): ["module: unknown"] |
| |
| import os |
| import random |
| import re |
| import shutil |
| import subprocess |
| import sys |
| import tempfile |
| import textwrap |
| import traceback |
| import unittest |
| import warnings |
| from typing import Any, Dict, List |
| |
| import torch |
| import torch.cuda |
| import torch.nn as nn |
| import torch.utils.cpp_extension |
| import torch.utils.data |
| from torch.autograd._functions.utils import check_onnx_broadcast |
| from torch.onnx.symbolic_opset9 import _prepare_onnx_paddings |
| from torch.testing._internal.common_cuda import TEST_MULTIGPU |
| from torch.testing._internal.common_device_type import ( |
| instantiate_device_type_tests, |
| onlyCPU, |
| ops, |
| ) |
| from torch.testing._internal.common_methods_invocations import op_db |
| from torch.testing._internal.common_utils import ( # type: ignore[attr-defined] |
| IS_FBCODE, |
| IS_SANDCASTLE, |
| IS_WINDOWS, |
| load_tests, |
| ) |
| from torch.utils._device import set_device |
| from torch.utils._pytree import tree_all_only, tree_any |
| from torch.utils._traceback import ( |
| CapturedTraceback, |
| format_traceback_short, |
| report_compile_source_on_error, |
| ) |
| from torch.utils.checkpoint import ( |
| _infer_device_type, |
| checkpoint, |
| checkpoint_sequential, |
| get_device_states, |
| ) |
| from torch.utils.data import DataLoader |
| |
| |
# load_tests from torch.testing._internal.common_utils is used to automatically filter tests for
# sharding on Sandcastle. The assignment below silences flake8 warnings about an unused import.
| load_tests = load_tests |
| |
| HAS_CUDA = torch.cuda.is_available() |
| |
| |
| from torch.testing._internal.common_utils import run_tests, TestCase |
| |
| |
| class RandomDatasetMock(torch.utils.data.Dataset): |
| def __getitem__(self, index): |
| return torch.tensor([torch.rand(1).item(), random.uniform(0, 1)]) |
| |
| def __len__(self): |
| return 1000 |
| |
| |
| class TestCheckpoint(TestCase): |
    # Runs checkpoint_sequential on each of the nets in module_lists_to_compare
    # and compares them against the uncheckpointed model, checking outputs as
    # well as input gradients and parameter gradients.
| def _check_checkpoint_sequential( |
| self, |
| model, |
| module_lists_to_compare, |
| num_chunks, |
| input, |
| use_reentrant, |
| ): |
| # not checkpointed |
| out = model(input) |
| out_not_checkpointed = out.detach().clone() |
| model.zero_grad() |
| out.sum().backward() |
| grad_not_checkpointed = { |
| name: param.grad.detach().clone() |
| for name, param in model.named_parameters() |
| } |
| input_grad_not_checkpointed = input.grad.detach().clone() |
| for model_to_compare in module_lists_to_compare: |
            # Checkpointed version of the model
| detached = input.detach() |
| detached.requires_grad = True |
| |
| # pass list of modules to checkpoint |
| out = checkpoint_sequential( |
| model_to_compare, num_chunks, detached, use_reentrant=use_reentrant |
| ) |
| out_checkpointed = out.detach().clone() |
| model.zero_grad() |
| out.sum().backward() |
| grad_checkpointed = { |
| name: param.grad.detach().clone() |
| for name, param in model.named_parameters() |
| } |
| input_grad_checkpointed = detached.grad.detach().clone() |
| # compare outputs as well as the gradients of input and parameters |
| self.assertEqual(out_checkpointed, out_not_checkpointed) |
| self.assertEqual(input_grad_not_checkpointed, input_grad_checkpointed) |
| for name in grad_checkpointed: |
| self.assertEqual(grad_checkpointed[name], grad_not_checkpointed[name]) |
| |
    # Test whether checkpointing is actually triggered by counting how many
    # times each module's forward pass runs.
| def test_checkpoint_trigger(self): |
| class Net(nn.Module): |
| def __init__(self) -> None: |
| super().__init__() |
| self.counter = 0 |
| |
| def forward(self, input_var): |
| self.counter += 1 |
                # For the reentrant variant, autograd needs to actually
                # save a tensor in order to trigger recomputation
| ret = input_var * torch.tensor(2.0) |
| return ret |
| |
| # checkpointed |
| for use_reentrant in [True, False]: |
| with self.subTest(use_reentrant=use_reentrant): |
| modules = [Net() for _ in range(10)] |
| for m in modules: |
| self.assertEqual(m.counter, 0) |
| input_var = torch.randn(3, 4, requires_grad=True) |
| out = checkpoint_sequential( |
| modules, 2, input_var, use_reentrant=use_reentrant |
| ) |
| for m in modules: |
| self.assertEqual(m.counter, 1) |
| out.sum().backward() |
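                # checkpoint_sequential recomputes every segment except the
                # last one during backward, so only the first half of the
                # modules run a second forward pass.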
| for m in modules[: (len(modules) // 2)]: |
| self.assertEqual(m.counter, 2) |
| for m in modules[(len(modules) // 2) :]: |
| self.assertEqual(m.counter, 1) |
| |
| def test_checkpoint_valid(self): |
| model = nn.Sequential( |
| nn.Linear(100, 50), |
| nn.ReLU(), |
| nn.Linear(50, 20), |
| nn.ReLU(), |
| nn.Linear(20, 5), |
| nn.ReLU(), |
| ) |
| |
| input_var = torch.randn(1, 100, requires_grad=True) |
| |
| # checkpointed |
| chunks = 2 |
| modules = list(model.children()) |
| out = checkpoint_sequential(modules, chunks, input_var, use_reentrant=True) |
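        # Reentrant checkpointing only supports .backward(); calling
        # torch.autograd.grad() on its output is expected to raise.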
| with self.assertRaisesRegex( |
| RuntimeError, "torch.utils.checkpoint is incompatible" |
| ): |
| torch.autograd.grad( |
| outputs=[out], |
| grad_outputs=[torch.ones(1, 5)], |
| inputs=[input_var], |
| create_graph=True, |
| ) |
| # works with use_reentrant=False, and grads are the same |
| out = model(input_var) |
| grads_no_checkpoint = torch.autograd.grad( |
| outputs=[out], |
| grad_outputs=[torch.ones(1, 5)], |
| inputs=[input_var], |
| create_graph=True, |
| ) |
| out_checkpoint = checkpoint_sequential( |
| modules, chunks, input_var, use_reentrant=False |
| ) |
| # check outputs are the same |
| self.assertEqual(out_checkpoint, out) |
| grads_checkpoint = torch.autograd.grad( |
| outputs=[out_checkpoint], |
| grad_outputs=[torch.ones(1, 5)], |
| inputs=[input_var], |
| create_graph=True, |
| ) |
| self.assertEqual(grads_no_checkpoint, grads_checkpoint) |
| |
| def test_checkpoint(self): |
| for use_reentrant in [True, False]: |
| with self.subTest(use_reentrant=use_reentrant): |
| model = nn.Sequential( |
| nn.Linear(100, 50), |
| nn.ReLU(), |
| nn.Linear(50, 20), |
| nn.ReLU(), |
| nn.Linear(20, 5), |
| nn.ReLU(), |
| ) |
| |
                # Compare the uncheckpointed model with its checkpointed
                # counterparts. In addition to running checkpoint_sequential on
                # the nn.Sequential instance, we also run it on the list of
                # modules the instance contains.
| self._check_checkpoint_sequential( |
| model, |
| [list(model.children()), model], |
| 2, |
| torch.randn(1, 100, requires_grad=True), |
| use_reentrant=use_reentrant, |
| ) |
| |
| def test_checkpoint_module_list(self): |
| class ModuleListNet(nn.Module): |
| def __init__(self) -> None: |
| super().__init__() |
| module_list = [ |
| nn.Linear(100, 50), |
| nn.ReLU(), |
| nn.Linear(50, 20), |
| nn.ReLU(), |
| nn.Linear(20, 5), |
| nn.ReLU(), |
| ] |
| self.module_list = nn.ModuleList(module_list) |
| |
| def forward(self, input): |
| for layer in self.module_list: |
| input = layer(input) |
| return input |
| |
| for use_reentrant in [True, False]: |
| with self.subTest(use_reentrant=use_reentrant): |
| model = ModuleListNet() |
| |
| # Compare uncheckpointed model with its checkpointed counterparts. |
| self._check_checkpoint_sequential( |
| model, |
| [list(model.module_list.children()), model.module_list], |
| 2, |
| torch.randn(1, 100, requires_grad=True), |
| use_reentrant=use_reentrant, |
| ) |
| |
| def test_checkpoint_sequential_deprecated_multiple_args(self): |
| class Two(nn.Module): |
| def forward(self, a, b): |
| return a, b |
| |
| model = nn.Sequential(Two()) |
| a = torch.randn(1, 100, requires_grad=True) |
| b = torch.randn(1, 100, requires_grad=True) |
| |
| for use_reentrant in [True, False]: |
| with self.subTest(use_reentrant=use_reentrant): |
| with self.assertRaises(TypeError): |
| checkpoint_sequential(model, 1, a, b) # type: ignore[call-arg] |
| |
| def test_checkpoint_sequential_deprecated_no_args(self): |
| class Noop(nn.Module): |
| def forward(self): |
| pass |
| |
| model = nn.Sequential(Noop()) |
| for use_reentrant in [True, False]: |
| with self.subTest(use_reentrant=use_reentrant): |
| with self.assertRaises(TypeError): |
| checkpoint_sequential(model, 1) # type: ignore[call-arg] |
| |
| def test_checkpoint_rng_cpu(self): |
| for _ in range(5): |
| inp = torch.randn(20000, device="cpu").requires_grad_() |
| phase1 = torch.nn.Dropout() |
| phase2 = torch.nn.Dropout() |
| |
| def run_fn(input): |
| return phase2(input) |
| |
| state = torch.get_rng_state() |
| |
| out = phase1(inp) |
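            # checkpoint preserves the RNG state by default, so the dropout
            # mask seen during recomputation matches the original forward and
            # the gradients below should agree exactly.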
| out = checkpoint(run_fn, out, use_reentrant=True) |
| out.sum().backward() |
| grad_with_checkpointing = inp.grad |
| |
| torch.set_rng_state(state) |
| |
| inp.grad = None |
| |
| out = phase1(inp) |
| out = run_fn(out) |
| out.sum().backward() |
| grad_no_checkpointing = inp.grad |
| |
| self.assertEqual(grad_with_checkpointing, grad_no_checkpointing) |
| |
| @unittest.skipIf(not HAS_CUDA, "No CUDA") |
| def test_checkpoint_rng_cuda(self): |
| for _ in range(5): |
| inp = torch.randn(20000, device="cuda").requires_grad_() |
| phase1 = torch.nn.Dropout() |
| phase2 = torch.nn.Dropout() |
| |
| def run_fn(input): |
| return phase2(input) |
| |
| state = torch.cuda.get_rng_state() |
| |
| out = phase1(inp) |
| out = checkpoint(run_fn, out, use_reentrant=True) |
| out.sum().backward() |
| grad_with_checkpointing = inp.grad |
| |
| torch.cuda.set_rng_state(state) |
| |
| inp.grad = None |
| |
| out = phase1(inp) |
| out = run_fn(out) |
| out.sum().backward() |
| grad_no_checkpointing = inp.grad |
| |
| self.assertEqual(grad_with_checkpointing, grad_no_checkpointing) |
| |
| @unittest.skipIf(not HAS_CUDA, "No CUDA") |
| def test_checkpoint_not_preserve_rng_state_and_without_reentrant(self): |
| inp = torch.randn(2, device="cuda").requires_grad_() |
| layer = torch.nn.Dropout() |
| |
| def run_fn(input): |
| return layer(input) |
| |
| out = checkpoint(run_fn, inp, use_reentrant=False, preserve_rng_state=False) |
| out.sum().backward() |
| # This should run without error |
| |
| def test_checkpoint_non_tensor(self): |
| def run_fn(tensor1, tensor2): |
| if tensor2 is None: |
| return tensor1 |
| return tensor1 + tensor2 |
| |
| input_var = torch.randn(1, 100, requires_grad=True) |
| out = checkpoint(run_fn, input_var, None, use_reentrant=True) |
| out.sum().backward() |
| |
| def test_checkpoint_non_tensor_inputs_outputs(self): |
| def foo(t1, t2, scale, t3): |
| t4 = t1 + t2 * t3 |
| t5 = t1 * t2 + t3 |
| t4 *= scale |
| t5 *= scale |
| return scale, t4, None, True, t5, "bar", t1 |
| |
| t1 = torch.rand(10, requires_grad=True) |
| t2 = torch.rand(10, requires_grad=True) |
| t3 = torch.rand(10) |
| scale = random.randint(0, 10) |
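        # Non-tensor inputs and outputs (scale, None, True, "bar") are passed
        # through checkpoint unchanged; only the tensor outputs carry grad.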
| res = checkpoint(foo, t1, t2, scale, t3, use_reentrant=True) |
| self.assertEqual(scale, res[0]) |
| self.assertEqual((t1 + t2 * t3) * scale, res[1]) |
| self.assertEqual(None, res[2]) |
| self.assertEqual(True, res[3]) |
| self.assertEqual((t1 * t2 + t3) * scale, res[4]) |
| self.assertEqual("bar", res[5]) |
| self.assertEqual(t1, res[6]) |
| |
| # Validate running backward. |
| res[1].sum().backward(retain_graph=True) |
| res[4].sum().backward(retain_graph=True) |
| res[6].sum().backward() |
| with self.assertRaisesRegex( |
| RuntimeError, "Trying to backward through the graph a second time" |
| ): |
| res[6].sum().backward() |
| t1_grad = t1.grad |
| t2_grad = t2.grad |
| |
        # Reset grads, run without checkpoint, and validate that we get the same grads.
| t1.grad = None |
| t2.grad = None |
| res = foo(t1, t2, scale, t3) |
| torch.autograd.backward([res[1].sum(), res[4].sum(), res[6].sum()]) |
| self.assertEqual(t1.grad, t1_grad) |
| self.assertEqual(t2.grad, t2_grad) |
| |
| def test_checkpoint_no_tensors(self): |
| def foo(t1, t2, scale, t3): |
| t4 = t1 + t2 * t3 |
| t5 = t1 * t2 + t3 |
| t4 *= scale |
| t5 *= scale |
| return scale, t4, None, True, t5, "bar", t1 |
| |
| t1 = random.random() |
| t2 = random.random() |
| t3 = random.random() |
| scale = random.randint(0, 10) |
| res = checkpoint(foo, t1, t2, scale, t3, use_reentrant=True) |
| self.assertEqual(scale, res[0]) |
| self.assertEqual((t1 + t2 * t3) * scale, res[1]) |
| self.assertEqual(None, res[2]) |
| self.assertEqual(True, res[3]) |
| self.assertEqual((t1 * t2 + t3) * scale, res[4]) |
| self.assertEqual("bar", res[5]) |
| self.assertEqual(t1, res[6]) |
| |
| def test_checkpoint_partial_grad(self): |
| def run_fn(tensor1, tensor2): |
            # tensor2 is used for other application logic
| return tensor1, tensor2 |
| |
| input_var = torch.randn(1, 4, requires_grad=True) |
| input_var2 = torch.randn(1, 4, requires_grad=False) |
| out = checkpoint(run_fn, input_var, input_var2, use_reentrant=True) |
| out[0].sum().backward() |
| |
| def run_fn2(tensor1, tensor2): |
| return tensor1 |
| |
| input_var = torch.randn(1, 4, requires_grad=False) |
| input_var2 = torch.randn(1, 4, requires_grad=True) |
| with self.assertRaisesRegex( |
| RuntimeError, |
| r"none of output has requires_grad=True, this checkpoint\(\) is not necessary", |
| ): |
| out = checkpoint(run_fn2, input_var, input_var2, use_reentrant=True) |
| out.sum().backward() |
| |
| @unittest.skipIf(not torch.cuda.is_available(), "Test requires CUDA") |
| def test_checkpointing_without_reentrant_early_free(self): |
        # There is no direct way to check whether the temporary saved-variable
        # buffers get de-allocated, so CUDA memory usage is used as a proxy.
| |
| def _do_test(fn, should_free): |
| stats: List[int] = [] |
| |
| def track(x, idx): |
                # Track that at each step of the backward, some Tensors were
                # de-allocated (which corresponds to the checkpoint storage
                # being emptied at each step)
| def hook(_unused): |
| self.assertEqual(len(stats), idx) |
| torch.cuda.synchronize() |
| stats.append(torch.cuda.memory_allocated()) |
| if idx > 0: |
| if should_free: |
| self.assertLess(stats[idx], stats[idx - 1]) |
| else: |
| self.assertEqual(stats[idx], stats[idx - 1]) |
| |
| x.register_hook(hook) |
| |
| def test_fn(x): |
                # The main property of this function is that it contains
                # multiple operations that save intermediate tensors for
                # backward in a chain.
| x = x**2 |
| track(x, 2) |
| x = x**2 |
| track(x, 1) |
| x = x**2 |
| track(x, 0) |
| x = x**2 |
| return x.sum() |
| |
| fn(test_fn) |
| |
| return stats |
| |
| x = torch.zeros(10, device="cuda", requires_grad=True) |
| x.grad = torch.zeros_like(x) |
| |
| # In a regular backward, buffers get eagerly freed |
| non_retain_stats = _do_test(lambda fn: fn(x).backward(), True) |
| |
        # In a retain_graph backward, buffers get preserved
| _unused_retain_stats = _do_test( |
| lambda fn: fn(x).backward(retain_graph=True), False |
| ) |
| |
| # In a regular backward with checkpoint, buffers get eagerly freed |
| checkpoint_non_retain_stats = _do_test( |
| lambda fn: checkpoint(fn, x, use_reentrant=False).backward(), True |
| ) |
| |
        # In a retain_graph backward with checkpoint, buffers still get eagerly freed
| checkpoint_retain_stats = _do_test( |
| lambda fn: checkpoint(fn, x, use_reentrant=False).backward( |
| retain_graph=True |
| ), |
| True, |
| ) |
| |
| self.assertEqual(non_retain_stats, checkpoint_non_retain_stats) |
| self.assertEqual(non_retain_stats, checkpoint_retain_stats) |
| |
| @unittest.skipIf(not TEST_MULTIGPU, "multi-GPU not supported") |
| def test_get_device_states_recursive(self): |
| inp = { |
| "foo": torch.rand(10, device="cuda:0"), |
| "bar": [torch.rand(10, device="cuda:1")], |
| } |
| device_ids, device_states = get_device_states(inp) |
| self.assertEqual(2, len(device_ids)) |
| self.assertEqual(2, len(device_states)) |
| self.assertEqual(0, device_ids[0]) |
| self.assertEqual(1, device_ids[1]) |
| self.assertTrue(isinstance(device_states[0], torch.Tensor)) |
| self.assertTrue(isinstance(device_states[1], torch.Tensor)) |
| |
| def test_infer_device_state_recursive_meta(self): |
| inp = {"foo": torch.rand(10, device="meta")} |
| device_type = _infer_device_type(inp) |
| self.assertEqual("meta", device_type) |
| |
| @unittest.skipIf(not TEST_MULTIGPU, "multi-GPU not supported") |
| def test_infer_device_state_recursive_multi_cuda(self): |
        # Check that no warning is issued for either the (cuda:0, cuda:1) case
        # or the (cuda:0, cuda:0) case, since all tensors share the same device type
| inp = { |
| "foo": torch.rand(10, device="cuda:0"), |
| "bar": [torch.rand(10, device="cuda:1")], |
| } |
| with warnings.catch_warnings(): |
| warnings.simplefilter("error") |
| device_type = _infer_device_type(inp) |
| self.assertEqual("cuda", device_type) |
| inp = { |
| "foo": torch.rand(10, device="cuda:0"), |
| "bar": [torch.rand(10, device="cuda:0")], |
| } |
| with warnings.catch_warnings(): |
| warnings.simplefilter("error") |
| device_type = _infer_device_type(inp) |
| self.assertEqual("cuda", device_type) |
| # Check that a warning is issued for cuda:0, meta and that it includes |
| # device type information |
| inp = { |
| "foo": torch.rand(10, device="cuda:0"), |
| "bar": [torch.rand(10, device="meta")], |
| } |
| with warnings.catch_warnings(record=True) as w: |
| device_type = _infer_device_type(inp) |
| self.assertEqual("cuda", device_type) |
| self.assertEqual(len(w), 1) |
| warning_msg = str(w[-1].message) |
| self.assertTrue( |
| "Tensor arguments, excluding CPU tensors, are detected on at least two types of devices" |
| in warning_msg |
| ) |
| self.assertTrue("Device types: ['cuda', 'meta']" in warning_msg) |
| self.assertTrue("first device type: cuda" in warning_msg) |
| |
| |
| class TestDataLoaderUtils(TestCase): |
| MAX_TIMEOUT_IN_SECOND = 300 |
| |
| def setUp(self): |
| super().setUp() |
| self.dataset = torch.randn(5, 3, 3, 2) |
| self.batch_size = 3 |
| |
| def test_random_seed(self): |
| def run(): |
| dataloader = torch.utils.data.DataLoader( |
| RandomDatasetMock(), |
| batch_size=2, |
| num_workers=4, |
| shuffle=True, |
| timeout=self.MAX_TIMEOUT_IN_SECOND, |
| ) |
| return next(iter(dataloader)) |
| |
| torch.manual_seed(2018) |
| x1 = run() |
| torch.manual_seed(2018) |
| x2 = run() |
| self.assertEqual(x1, x2) |
| |
| def test_single_keep(self): |
        # self.dataset is a Tensor here; technically not a valid input because
        # it is not a Dataset subclass, but it needs to keep working, so add
        # type: ignore comments for mypy
| dataloader: DataLoader = DataLoader( |
| self.dataset, # type: ignore[arg-type] |
| batch_size=self.batch_size, |
| num_workers=0, |
| drop_last=False, |
| ) |
| dataiter = iter(dataloader) |
| self.assertEqual(len(list(dataiter)), 2) |
| |
| def test_single_drop(self): |
| dataloader: DataLoader = DataLoader( |
| self.dataset, # type: ignore[arg-type] |
| batch_size=self.batch_size, |
| num_workers=0, |
| drop_last=True, |
| ) |
| dataiter = iter(dataloader) |
| self.assertEqual(len(list(dataiter)), 1) |
| |
| @unittest.skip( |
| "FIXME: Intermittent CUDA out-of-memory error on Windows and time-out under ASAN" |
| ) |
| def test_multi_keep(self): |
| dataloader: DataLoader = DataLoader( |
| self.dataset, # type: ignore[arg-type] |
| batch_size=self.batch_size, |
| num_workers=2, |
| drop_last=False, |
| timeout=self.MAX_TIMEOUT_IN_SECOND, |
| ) |
| dataiter = iter(dataloader) |
| self.assertEqual(len(list(dataiter)), 2) |
| |
| def test_multi_drop(self): |
| dataloader: DataLoader = DataLoader( |
| self.dataset, # type: ignore[arg-type] |
| batch_size=self.batch_size, |
| num_workers=2, |
| drop_last=True, |
| timeout=self.MAX_TIMEOUT_IN_SECOND, |
| ) |
| dataiter = iter(dataloader) |
| self.assertEqual(len(list(dataiter)), 1) |
| |
| |
| test_dir = os.path.abspath(os.path.dirname(str(__file__))) |
| |
| |
| @unittest.skipIf( |
| "SKIP_TEST_BOTTLENECK" in os.environ.keys(), "SKIP_TEST_BOTTLENECK is set" |
| ) |
| class TestBottleneck(TestCase): |
| def _run(self, command, timeout=30): |
| """Returns (return-code, stdout, stderr)""" |
| |
| p = subprocess.Popen( |
| command, |
| stdout=subprocess.PIPE, |
| stderr=subprocess.PIPE, |
| shell=True, |
| ) |
| try: |
| output, err = p.communicate(timeout=timeout) |
| except subprocess.TimeoutExpired: |
| p.kill() |
| output, err = p.communicate() |
| rc = p.returncode |
| output_str = output.decode("ascii") |
| err_str = err.decode("ascii") |
| return (rc, output_str, err_str) |
| |
| def _run_bottleneck(self, test_file, scriptargs=""): |
| curdir = os.path.dirname(os.path.abspath(__file__)) |
| filepath = f"{curdir}/{test_file}" |
| if scriptargs != "": |
| scriptargs = f" {scriptargs}" |
| rc, out, err = self._run( |
| f"{sys.executable} -m torch.utils.bottleneck {filepath}{scriptargs}" |
| ) |
| return rc, out, err |
| |
| def _check_run_args(self): |
| # Check that this fails due to missing args |
| rc, out, err = self._run_bottleneck("bottleneck_test/test_args.py") |
| self.assertEqual( |
| rc, |
| 2, |
| atol=0, |
| rtol=0, |
| msg=self._fail_msg("Missing args should error", out + err), |
| ) |
| |
| # This should succeed |
| rc, out, err = self._run_bottleneck( |
| "bottleneck_test/test_args.py", "--foo foo --bar bar" |
| ) |
| self.assertEqual( |
| rc, |
| 0, |
| atol=0, |
| rtol=0, |
| msg=self._fail_msg("Should pass args to script", out + err), |
| ) |
| |
| def _fail_msg(self, msg, output): |
| return f"{msg}, output was:\n{output}" |
| |
| def _check_environment_summary(self, output): |
| results = re.search("Environment Summary", output) |
| self.assertIsNotNone( |
| results, self._fail_msg("Should have Environment Summary", output) |
| ) |
| |
        # The PyTorch version number should appear within five lines of the heading
| results = re.search( |
| r"Environment Summary.*(\n.*){,5}\nPyTorch \d+\.\d+", output |
| ) |
| self.assertIsNotNone( |
| results, self._fail_msg("Should have PyTorch version", output) |
| ) |
| |
| def _check_cprof_summary(self, output): |
| results = re.search("cProfile output", output) |
| self.assertIsNotNone( |
| results, self._fail_msg("Should have cProfile output", output) |
| ) |
| |
| # This assumes that after the cProfile output section we have |
| # the autograd profiler output |
| results = re.search( |
| r"cProfile output.*(\n.*){6,50}\n.*autograd profiler output", output |
| ) |
| self.assertIsNotNone( |
| results, |
| self._fail_msg( |
| "Distance between cProfile and autograd prof out not in [6, 50] lines", |
| output, |
| ), |
| ) |
| |
| def _check_autograd_summary(self, output): |
| results = re.search("autograd profiler output", output) |
| self.assertIsNotNone( |
| results, self._fail_msg("Should have autograd profiler output", output) |
| ) |
| |
        # This assumes that the autograd profiler output comes at the end of
        # the overall output.
| results = re.search(r"autograd profiler output.*(\n.*){6,100}", output) |
| self.assertIsNotNone( |
| results, |
| self._fail_msg( |
| "Distance between autograd prof output and end of output not in [6, 100] lines", |
| output, |
| ), |
| ) |
| |
| def _check_cuda(self, output): |
| if HAS_CUDA: |
| results = re.search("CUDA mode", output) |
| self.assertIsNotNone( |
| results, self._fail_msg("Should tell users CUDA", output) |
| ) |
| else: |
| results = re.search("CUDA mode", output) |
| self.assertIsNone( |
| results, self._fail_msg("Should not tell users about CUDA", output) |
| ) |
| |
| @unittest.skipIf(HAS_CUDA, "CPU-only test") |
| def test_bottleneck_cpu_only(self): |
| rc, out, err = self._run_bottleneck("bottleneck_test/test.py") |
| self.assertEqual(rc, 0, msg=f"Run failed with\n{err}") |
| |
| self._check_run_args() |
| self._check_environment_summary(out) |
| self._check_autograd_summary(out) |
| self._check_cprof_summary(out) |
| self._check_cuda(out) |
| |
| @unittest.skipIf(not HAS_CUDA, "No CUDA") |
| def test_bottleneck_cuda(self): |
| rc, out, err = self._run_bottleneck("bottleneck_test/test_cuda.py") |
| self.assertEqual(rc, 0, msg=f"Run failed with\n{err}") |
| |
| self._check_run_args() |
| self._check_environment_summary(out) |
| self._check_autograd_summary(out) |
| self._check_cprof_summary(out) |
| self._check_cuda(out) |
| |
| |
| from torch.utils.collect_env import get_pretty_env_info |
| |
| |
| @unittest.skipIf(IS_FBCODE, "runs pip which is not available internally") |
| class TestCollectEnv(TestCase): |
| def test_smoke(self): |
| info_output = get_pretty_env_info() |
| self.assertTrue(info_output.count("\n") >= 17) |
| |
| |
| class TestONNXUtils(TestCase): |
| def test_prepare_onnx_paddings(self): |
| sizes = [2, 3, 4] |
| pad = [1, 2, 3, 4] |
| paddings = _prepare_onnx_paddings(len(sizes), pad) |
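        # torch.nn.functional.pad orders padding from the last dimension inward
        # as (begin, end) pairs; ONNX Pad expects all begin values followed by
        # all end values in dimension order, with missing dimensions set to 0.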
| self.assertEqual(paddings, [0, 3, 1, 0, 4, 2]) |
| |
| def test_check_onnx_broadcast(self): |
| def try_check_onnx_broadcast(dims1, dims2, expect_broadcast, expect_fail): |
| broadcast = True |
| fail = False |
| try: |
| broadcast = check_onnx_broadcast(dims1, dims2) |
| except ValueError: |
| fail = True |
| self.assertEqual(broadcast, expect_broadcast) |
| self.assertEqual(fail, expect_fail) |
| |
| # Case 1, check the case when len(dims1) < len(dims2) and numel(dims2) > 1 |
| dims1 = [3, 4] |
| dims2 = [2, 3, 4] |
| try_check_onnx_broadcast(dims1, dims2, True, True) |
| |
| # Case 2, check the case when len(dims1) < len(dims2) and numel(dims2) == 1 |
| dims1 = [3, 4] |
| dims2 = [1, 1, 1] |
| try_check_onnx_broadcast(dims1, dims2, True, False) |
| |
| # Case 3, check the case when len(dims1) > len(dims2) and numel(dims2) == 1 |
| dims1 = [1, 1] |
| dims2 = [1] |
| try_check_onnx_broadcast(dims1, dims2, True, False) |
| |
| # Case 4, check the case when len(dims1) > len(dims2) and dims1[x:] == dims2 |
| dims1 = [2, 3, 4] |
| dims2 = [3, 4] |
| try_check_onnx_broadcast(dims1, dims2, True, False) |
| |
| # Case 5, check the case when len(dims1) > len(dims2), but dims1[x:] != dims2 |
| dims1 = [2, 3, 4] |
| dims2 = [1, 4] |
| try_check_onnx_broadcast(dims1, dims2, True, True) |
| |
| # Case 6, check the equal case, no broadcast |
| dims1 = [3, 4] |
| dims2 = [3, 4] |
| try_check_onnx_broadcast(dims1, dims2, False, False) |
| |
| # Case 7, check the case when len(dims1) == len(dims2), but dims1 != dims2 |
| dims1 = [3, 4] |
| dims2 = [1, 4] |
| try_check_onnx_broadcast(dims1, dims2, True, True) |
| |
| # Case 8, check the case when len(dims1) == len(dims2) and numel(s2) == 1 |
| dims1 = [3, 4] |
| dims2 = [1, 1] |
| try_check_onnx_broadcast(dims1, dims2, True, False) |
| |
| |
| class TestHipify(TestCase): |
| def test_import_hipify(self): |
| from torch.utils.hipify import hipify_python # noqa: F401 |
| |
| |
| class TestHipifyTrie(TestCase): |
| def setUp(self): |
| self.trie = torch.utils.hipify.hipify_python.Trie() |
| |
| def test_add_and_search_trie(self): |
| self.trie.add("banana") |
| self.assertTrue(self.trie.search("banana")) |
| self.assertFalse(self.trie.search("ban")) |
| self.assertFalse(self.trie.search("dog")) |
| |
| def test_add_multiple_and_search_trie(self): |
| words_to_add = ["banana", "apple", "orange"] |
| for word in words_to_add: |
| self.trie.add(word) |
| |
| for word in words_to_add: |
| self.assertTrue(self.trie.search(word)) |
| |
| for word in ["ban", "dog", "okay", "app"]: |
| self.assertFalse(self.trie.search(word)) |
| |
| def test_quote_escape(self): |
| orig_chars = ["*", "[", ".", "+", "a", "z", "-"] |
| quoted_strs = ["\\*", "\\[", "\\.", "\\+", "a", "z", "\\-"] |
| for i in range(len(orig_chars)): |
| self.assertEqual(self.trie.quote(orig_chars[i]), quoted_strs[i]) |
| |
| def test_export_trie_to_regex(self): |
| words_to_add = [ |
| "__CUDACC__", |
| "CUDA_ERROR_CONTEXT_ALREADY_CURRENT", |
| "CUDA_ERROR_ARRAY_IS_MAPPED", |
| "CUDA_ERROR_NOT_MAPPED", |
| "CUDA_ERROR_INVALID_SOURCE", |
| ] |
| for word in words_to_add: |
| self.trie.add(word) |
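        # The exported regex factors the shared "CUDA_ERROR_" prefix into a
        # nested non-capturing group.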
| regex = self.trie.export_to_regex() |
| expected_regex = r"(?:CUDA_ERROR_(?:ARRAY_IS_MAPPED|CONTEXT_ALREADY_CURRENT|INVALID_SOURCE|NOT_MAPPED)|__CUDACC__)" |
| self.assertEqual(regex, expected_regex) |
| |
| def test_prefix_words_export_trie_to_regex(self): |
        # Test the case where some nodes both have children and are leaf nodes themselves.
| words_to_add = ["apple", "app", "ban", "banana"] |
| for word in words_to_add: |
| self.trie.add(word) |
| regex = self.trie.export_to_regex() |
| expected_regex = r"(?:app(?:le)?|ban(?:ana)?)" |
| self.assertEqual(regex, expected_regex) |
| |
| def test_single_export_trie_to_regex(self): |
| words_to_add = ["cudaErrorInvalidMemcpyDirection"] |
| for word in words_to_add: |
| self.trie.add(word) |
| regex = self.trie.export_to_regex() |
| expected_regex = "cudaErrorInvalidMemcpyDirection" |
| self.assertEqual(regex, expected_regex) |
| |
| def test_char_export_trie_to_regex(self): |
| self.trie.add("a") |
| self.assertEqual(self.trie.export_to_regex(), "a") |
| self.trie.add("b") |
| self.assertEqual(self.trie.export_to_regex(), "[ab]") |
| |
| def test_special_char_export_trie_to_regex(self): |
| self.trie.add(r"c*") |
| self.assertEqual(self.trie.export_to_regex(), r"c\*") |
| |
| |
| class TestAssert(TestCase): |
| def test_assert_true(self): |
| # verify assertions work as expected |
| # bool argument |
| torch._assert(True, "foo") |
| with self.assertRaisesRegex(AssertionError, "bar"): |
| torch._assert(False, "bar") |
| # tensor argument |
| torch._assert(torch.tensor([True], dtype=torch.bool), "foo") |
| with self.assertRaisesRegex(AssertionError, "bar"): |
| torch._assert(torch.tensor([False], dtype=torch.bool), "bar") |
| |
| def test_assert_scriptable(self): |
| class M(torch.nn.Module): |
| def forward(self, x): |
| torch._assert(x.sum() > 0, "foo") |
| return x |
| |
| m = M() |
| # scriptable |
| ms = torch.jit.script(m) |
| # data can be passed without errors |
| x = torch.randn(4, 4).fill_(1.0) |
| ms(x) |
| with self.assertRaisesRegex(torch.jit.Error, "foo"): |
| ms(torch.tensor([False], dtype=torch.bool)) |
| |
| |
| @unittest.skipIf(IS_SANDCASTLE, "cpp_extension is OSS only") |
| class TestStandaloneCPPJIT(TestCase): |
| def test_load_standalone(self): |
| build_dir = tempfile.mkdtemp() |
| try: |
| src_path = os.path.join(build_dir, "main.cpp") |
| src = textwrap.dedent( |
| """\ |
| #include <iostream> |
| #include <torch/torch.h> |
| int main() { |
| auto x = torch::eye(3); |
| std::cout << x << std::endl; |
| } |
| """ |
| ) |
| with open(src_path, "w") as f: |
| f.write(src) |
| |
| exec_path = torch.utils.cpp_extension.load( |
| "standalone_load_test", |
| src_path, |
| build_directory=build_dir, |
| is_python_module=False, |
| is_standalone=True, |
| ) |
| |
| ext = ".exe" if IS_WINDOWS else "" |
| self.assertEqual( |
| exec_path, os.path.join(build_dir, f"standalone_load_test{ext}") |
| ) |
| |
| for shell in [True, False]: |
| r = subprocess.run( |
| [exec_path], |
| shell=shell, |
| stdout=subprocess.PIPE, |
| ) |
| self.assertEqual(r.returncode, 0) |
| self.assertEqual( |
| # Windows prints "\r\n" for newlines. |
| textwrap.dedent(r.stdout.decode("utf-8")).replace("\r\n", "\n"), |
| textwrap.dedent( |
| """\ |
| 1 0 0 |
| 0 1 0 |
| 0 0 1 |
| [ CPUFloatType{3,3} ] |
| """ |
| ), |
| ) |
| |
| finally: |
| shutil.rmtree(build_dir) |
| |
| |
| class DummyPrivateUse1Module: |
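    # Minimal backend stub exposing only the hooks that
    # torch._register_device_module and torch.autocast need for a
    # "privateuseone" device in the tests below.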
| @staticmethod |
| def is_available(): |
| return True |
| |
| @staticmethod |
| def is_autocast_enabled(): |
| return True |
| |
| @staticmethod |
| def get_autocast_dtype(): |
| return torch.float16 |
| |
| @staticmethod |
| def set_autocast_enabled(enable): |
| pass |
| |
| @staticmethod |
| def set_autocast_dtype(dtype): |
| pass |
| |
| @staticmethod |
| def get_amp_supported_dtype(): |
| return [torch.float16] |
| |
| |
| class TestExtensionUtils(TestCase): |
| def tearDown(self): |
| # Clean up |
| backend_name = torch._C._get_privateuse1_backend_name() |
| if hasattr(torch, backend_name): |
| delattr(torch, backend_name) |
| if f"torch.{backend_name}" in sys.modules: |
| del sys.modules[f"torch.{backend_name}"] |
| |
| def test_external_module_register(self): |
| # Built-in module |
| with self.assertRaisesRegex(RuntimeError, "The runtime module of"): |
| torch._register_device_module("cuda", torch.cuda) |
| |
| # Wrong device type |
| with self.assertRaisesRegex(RuntimeError, "Expected one of cpu"): |
| torch._register_device_module("dummmy", DummyPrivateUse1Module) |
| |
| with self.assertRaises(AttributeError): |
| torch.privateuseone.is_available() # type: ignore[attr-defined] |
| |
| torch._register_device_module("privateuseone", DummyPrivateUse1Module) |
| |
| torch.privateuseone.is_available() # type: ignore[attr-defined] |
| |
        # Overriding an already-registered module is not supported
| with self.assertRaisesRegex(RuntimeError, "The runtime module of"): |
| torch._register_device_module("privateuseone", DummyPrivateUse1Module) |
| |
| def test_external_module_register_with_renamed_backend(self): |
| torch.utils.rename_privateuse1_backend("foo") |
| with self.assertRaisesRegex(RuntimeError, "has already been set"): |
| torch.utils.rename_privateuse1_backend("dummmy") |
| |
| custom_backend_name = torch._C._get_privateuse1_backend_name() |
| self.assertEqual(custom_backend_name, "foo") |
| |
| with self.assertRaises(AttributeError): |
| torch.foo.is_available() # type: ignore[attr-defined] |
| |
| with self.assertRaisesRegex(AssertionError, "Tried to use AMP with the"): |
| with torch.autocast(device_type=custom_backend_name): |
| pass |
| torch._register_device_module("foo", DummyPrivateUse1Module) |
| |
| torch.foo.is_available() # type: ignore[attr-defined] |
| with torch.autocast(device_type=custom_backend_name): |
| pass |
| |
| self.assertEqual(torch._utils._get_device_index("foo:1"), 1) |
| self.assertEqual(torch._utils._get_device_index(torch.device("foo:2")), 2) |
| |
| |
| class TestRenderUtils(TestCase): |
| def test_basic(self): |
| self.assertExpectedInline( |
| torch._utils.render_call(torch.sum, [torch.randn(100)], {"dim": 0}), |
| """torch.sum(tensor([...], size=(100,)), dim=0)""", |
| ) |
| self.assertExpectedInline( |
| torch._utils.render_call(torch.sum, [torch.randn(100, 100)], {"dim": 0}), |
| """torch.sum(tensor([...], size=(100, 100)), dim=0)""", |
| ) |
| |
| |
| class TestDeviceUtils(TestCase): |
| def test_basic(self): |
| with torch.device("meta") as dev: |
| x = torch.empty(3, 3) |
| self.assertEqual(x.device.type, "meta") |
| self.assertEqual(dev, torch.device("meta")) |
| |
| def test_decorator(self): |
| @set_device("meta") |
| def f(): |
| return torch.empty(3, 3) |
| |
| self.assertEqual(f().device.type, "meta") |
| |
| def test_decorator_generator(self): |
| @set_device("meta") |
| def f(): |
| yield torch.empty(3, 3) |
| yield torch.empty(3, 3) |
| |
| r1, r2 = list(f()) |
| self.assertEqual(r1.device.type, "meta") |
| self.assertEqual(r2.device.type, "meta") |
| |
| def test_nn_module(self): |
| with torch.device("meta"): |
| m = nn.Linear(40, 50) |
| self.assertEqual(m.weight.device.type, "meta") |
| |
| def test_set_default_device(self): |
| try: |
| torch.set_default_device("meta") |
| r = torch.empty(2, 2) |
| finally: |
| torch.set_default_device(None) |
| |
| self.assertEqual(r.device.type, "meta") |
| |
| def test_get_default_device(self): |
| torch.set_default_device("meta") |
| self.assertEqual(torch.get_default_device().type, "meta") |
| torch.set_default_device(None) |
| |
| @unittest.skipIf(not TEST_MULTIGPU, "multi-GPU not supported") |
| def test_get_default_device_more(self): |
| torch.set_default_device("cuda") |
| self.assertEqual(torch.get_default_device(), torch.tensor([]).device) |
| torch.set_default_device(None) |
| |
| torch.set_default_device("cuda") |
| torch.cuda.set_device("cuda:1") |
| self.assertEqual(torch.get_default_device(), torch.tensor([]).device) |
| torch.set_default_device(None) |
| |
| torch.set_default_device("cuda:1") |
| self.assertEqual(torch.get_default_device(), torch.tensor([]).device) |
| torch.set_default_device(None) |
| |
| @onlyCPU |
| @ops(op_db) |
| def test_device_mode_ops(self, device, dtype, op): |
| func = op.get_op() |
| samples = op.sample_inputs(device, dtype, requires_grad=False) |
| for sample in samples: |
            # Only test samples that don't have Tensor inputs. We don't rely on
            # the factory property on OpInfo here because it is very incomplete.
| if tree_any( |
| lambda x: isinstance(x, torch.Tensor), |
| (sample.input, sample.args, sample.kwargs), |
| ): |
| continue |
| # Many OpInfos will explicitly pass in a device. DeviceContext |
| # will respect device if it is explicitly specified. To test |
| # DeviceContext, we have to remove the device kwarg in this case. |
| # NB: Can't pass None to sample_inputs, the function can't |
| # handle it. |
| kwargs = sample.kwargs.copy() |
| kwargs.pop("device", None) |
| with torch.device("meta"): |
| r = func(sample.input, *sample.args, **kwargs) |
| |
| def is_meta_device(x: torch.Tensor) -> bool: |
| return x.device.type == "meta" |
| |
| self.assertTrue(tree_all_only(torch.Tensor, is_meta_device, r)) |
| |
| |
| instantiate_device_type_tests(TestDeviceUtils, globals()) |
| |
| |
| class TestCppExtensionUtils(TestCase): |
| def test_cpp_compiler_is_ok(self): |
| self.assertTrue(torch.utils.cpp_extension.check_compiler_ok_for_platform("c++")) |
| |
| def test_cc_compiler_is_ok(self): |
| self.assertTrue(torch.utils.cpp_extension.check_compiler_ok_for_platform("cc")) |
| |
| |
| class TestTraceback(TestCase): |
| def test_basic(self): |
| source = """\ |
| def f(x): |
| def g(x): |
| raise RuntimeError # HEYA |
| |
| x = x * 3 |
| return g(x) + 1 |
| """ |
| |
| out: Dict[str, Any] = {} |
| scope = {"__compile_source__": source} |
| exec(source, scope, out) |
| |
| try: |
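            # report_compile_source_on_error re-renders frames whose globals
            # carry __compile_source__, so the original source (including the
            # HEYA comment) shows up in the formatted traceback.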
| with report_compile_source_on_error(): |
| out["f"](1) |
| except RuntimeError as e: |
| self.assertIn("HEYA", "".join(traceback.format_tb(e.__traceback__))) |
| |
| def test_format_traceback_short(self): |
| try: |
| raise RuntimeError |
| except RuntimeError as e: |
| self.assertRegex( |
| format_traceback_short(e.__traceback__), |
| r".*test_utils.py:\d+ in test_format_traceback_short", |
| ) |
| |
| def test_captured_traceback(self): |
| self.assertIn( |
| "test_captured_traceback", "".join(CapturedTraceback.extract().format()) |
| ) |
| |
| def test_captured_traceback_format_all(self): |
| rs = CapturedTraceback.format_all( |
| [CapturedTraceback.extract(), CapturedTraceback.extract()] |
| ) |
| self.assertEqual(len(rs), 2) |
| self.assertIn("test_captured_traceback_format_all", "".join(rs[0])) |
| |
| def test_captured_traceback_format_all_cached(self): |
| tb = CapturedTraceback.extract() |
| tb.format() # cached |
| rs = CapturedTraceback.format_all([tb, CapturedTraceback.extract()]) |
| self.assertEqual(len(rs), 2) |
| self.assertIn("test_captured_traceback_format_all", "".join(rs[0])) |
| |
| |
| if __name__ == "__main__": |
| run_tests() |