| # Owner(s): ["module: optimizer"] |
| |
| import itertools |
| import pickle |
| |
| import torch |
| from torch.optim.swa_utils import ( |
| AveragedModel, |
| get_ema_multi_avg_fn, |
| get_swa_multi_avg_fn, |
| update_bn, |
| ) |
| from torch.testing._internal.common_utils import ( |
| instantiate_parametrized_tests, |
| load_tests, |
| parametrize, |
| TestCase, |
| ) |
| |
| # load_tests from common_utils is used to automatically filter tests for |
| # sharding on sandcastle. This line silences flake warnings |
| load_tests = load_tests |
| |
| |
| class TestSWAUtils(TestCase): |
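    # Small helper modules used by the update_bn tests below; compute_preactivation()
    # exposes the input to the BatchNorm layer so the tests can recompute the
    # dataset-level statistics by hand.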
| class SWATestDNN(torch.nn.Module): |
| def __init__(self, input_features): |
| super().__init__() |
| self.n_features = 100 |
| self.fc1 = torch.nn.Linear(input_features, self.n_features) |
| self.bn = torch.nn.BatchNorm1d(self.n_features) |
| |
| def compute_preactivation(self, x): |
| return self.fc1(x) |
| |
| def forward(self, x): |
| x = self.fc1(x) |
| x = self.bn(x) |
| return x |
| |
| class SWATestCNN(torch.nn.Module): |
| def __init__(self, input_channels): |
| super().__init__() |
| self.n_features = 10 |
| self.conv1 = torch.nn.Conv2d( |
| input_channels, self.n_features, kernel_size=3, padding=1 |
| ) |
| self.bn = torch.nn.BatchNorm2d(self.n_features, momentum=0.3) |
| |
| def compute_preactivation(self, x): |
| return self.conv1(x) |
| |
| def forward(self, x): |
| x = self.conv1(x) |
| x = self.bn(x) |
| return x |
| |
| def _test_averaged_model(self, net_device, swa_device, ema): |
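        # Build a small model on net_device, run a few averaged steps, and check
        # that the AveragedModel parameters match a manually computed average and
        # live on swa_device while the source parameters stay on net_device.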
| dnn = torch.nn.Sequential( |
| torch.nn.Conv2d(1, 5, kernel_size=3), |
| torch.nn.ReLU(), |
| torch.nn.MaxPool2d(kernel_size=2), |
| torch.nn.BatchNorm2d(5, momentum=0.3), |
| torch.nn.Conv2d(5, 2, kernel_size=3), |
| torch.nn.ReLU(), |
| torch.nn.Linear(5, 5), |
| torch.nn.ReLU(), |
| torch.nn.Linear(5, 10), |
| ).to(net_device) |
| |
| averaged_params, averaged_dnn = self._run_averaged_steps(dnn, swa_device, ema) |
| |
| for p_avg, p_swa in zip(averaged_params, averaged_dnn.parameters()): |
| self.assertEqual(p_avg, p_swa) |
| # Check that AveragedModel is on the correct device |
| self.assertTrue(p_swa.device == swa_device) |
| self.assertTrue(p_avg.device == net_device) |
| self.assertTrue(averaged_dnn.n_averaged.device == swa_device) |
| |
| def _run_averaged_steps(self, dnn, swa_device, ema): |
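        # Wrap the model in AveragedModel (EMA or SWA), apply a number of random
        # parameter perturbations, and accumulate the expected average by hand so
        # the caller can compare it against the AveragedModel parameters.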
| ema_decay = 0.999 |
| if ema: |
| averaged_dnn = AveragedModel( |
| dnn, device=swa_device, multi_avg_fn=get_ema_multi_avg_fn(ema_decay) |
| ) |
| else: |
| averaged_dnn = AveragedModel( |
| dnn, device=swa_device, multi_avg_fn=get_swa_multi_avg_fn() |
| ) |
| |
| averaged_params = [torch.zeros_like(param) for param in dnn.parameters()] |
| |
| n_updates = 10 |
| for i in range(n_updates): |
| for p, p_avg in zip(dnn.parameters(), averaged_params): |
| p.detach().add_(torch.randn_like(p)) |
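                # Expected average computed by hand: the first update_parameters()
                # call copies the parameters; each later EMA call applies
                # avg = decay * avg + (1 - decay) * p, so step i contributes
                # p_i * decay**(n_updates - i - 1), scaled by (1 - decay) for i > 0.
                # For SWA every step simply contributes p_i / n_updates.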
| if ema: |
| p_avg += ( |
| p.detach() |
| * ema_decay ** (n_updates - i - 1) |
| * ((1 - ema_decay) if i > 0 else 1.0) |
| ) |
| else: |
| p_avg += p.detach() / n_updates |
| averaged_dnn.update_parameters(dnn) |
| |
| return averaged_params, averaged_dnn |
| |
| @parametrize("ema", [True, False]) |
| def test_averaged_model_all_devices(self, ema): |
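        # Run the averaging check for every available combination of model device
        # and averaging device.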
| cpu = torch.device("cpu") |
| self._test_averaged_model(cpu, cpu, ema) |
| if torch.cuda.is_available(): |
            cuda = torch.device("cuda:0")
| self._test_averaged_model(cuda, cpu, ema) |
| self._test_averaged_model(cpu, cuda, ema) |
| self._test_averaged_model(cuda, cuda, ema) |
| |
| @parametrize("ema", [True, False]) |
| def test_averaged_model_mixed_device(self, ema): |
        if not torch.cuda.is_available():
            self.skipTest("CUDA is required for the mixed-device test")
| dnn = torch.nn.Sequential( |
| torch.nn.Conv2d(1, 5, kernel_size=3), torch.nn.Linear(5, 10) |
| ) |
| dnn[0].cuda() |
| dnn[1].cpu() |
| |
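        # With device=None, AveragedModel is expected to keep each averaged
        # parameter on the same device as the corresponding source parameter
        # (here a mix of CUDA and CPU), which the device check below verifies.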
| averaged_params, averaged_dnn = self._run_averaged_steps(dnn, None, ema) |
| |
| for p_avg, p_swa in zip(averaged_params, averaged_dnn.parameters()): |
| self.assertEqual(p_avg, p_swa) |
| # Check that AveragedModel is on the correct device |
| self.assertTrue(p_avg.device == p_swa.device) |
| |
| def test_averaged_model_state_dict(self): |
| dnn = torch.nn.Sequential( |
| torch.nn.Conv2d(1, 5, kernel_size=3), torch.nn.Linear(5, 10) |
| ) |
| averaged_dnn = AveragedModel(dnn) |
| averaged_dnn2 = AveragedModel(dnn) |
| n_updates = 10 |
| for i in range(n_updates): |
| for p in dnn.parameters(): |
| p.detach().add_(torch.randn_like(p)) |
| averaged_dnn.update_parameters(dnn) |
| averaged_dnn2.load_state_dict(averaged_dnn.state_dict()) |
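        # Loading the state dict should reproduce both the averaged parameters and
        # the n_averaged counter.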
        for p_swa, p_swa2 in zip(
            averaged_dnn.parameters(), averaged_dnn2.parameters()
        ):
| self.assertEqual(p_swa, p_swa2) |
| self.assertTrue(averaged_dnn.n_averaged == averaged_dnn2.n_averaged) |
| |
| def test_averaged_model_default_avg_fn_picklable(self): |
| dnn = torch.nn.Sequential( |
| torch.nn.Conv2d(1, 5, kernel_size=3), |
| torch.nn.BatchNorm2d(5), |
| torch.nn.Linear(5, 5), |
| ) |
| averaged_dnn = AveragedModel(dnn) |
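        # Pickling would fail if the default averaging function were stored as a
        # lambda or local closure; this only checks that dumps() succeeds.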
| pickle.dumps(averaged_dnn) |
| |
| @parametrize("use_multi_avg_fn", [True, False]) |
| @parametrize("use_buffers", [True, False]) |
| def test_averaged_model_exponential(self, use_multi_avg_fn, use_buffers): |
        # Check that EMA averaging matches a manual computation, with EMA specified
        # either via multi_avg_fn or a custom avg_fn, and with buffers either
        # averaged or not.
| dnn = torch.nn.Sequential( |
| torch.nn.Conv2d(1, 5, kernel_size=3), |
| torch.nn.BatchNorm2d(5, momentum=0.3), |
| torch.nn.Linear(5, 10), |
| ) |
| decay = 0.9 |
| |
| if use_multi_avg_fn: |
| averaged_dnn = AveragedModel( |
| dnn, multi_avg_fn=get_ema_multi_avg_fn(decay), use_buffers=use_buffers |
| ) |
| else: |
| |
| def avg_fn(p_avg, p, n_avg): |
| return decay * p_avg + (1 - decay) * p |
| |
| averaged_dnn = AveragedModel(dnn, avg_fn=avg_fn, use_buffers=use_buffers) |
| |
| if use_buffers: |
| dnn_params = list(itertools.chain(dnn.parameters(), dnn.buffers())) |
| else: |
| dnn_params = list(dnn.parameters()) |
| |
| averaged_params = [ |
| torch.zeros_like(param) |
| for param in dnn_params |
| if param.size() != torch.Size([]) |
| ] |
| |
| n_updates = 10 |
| for i in range(n_updates): |
| updated_averaged_params = [] |
| for p, p_avg in zip(dnn_params, averaged_params): |
| if p.size() == torch.Size([]): |
| continue |
| p.detach().add_(torch.randn_like(p)) |
| if i == 0: |
| updated_averaged_params.append(p.clone()) |
| else: |
| updated_averaged_params.append( |
| (p_avg * decay + p * (1 - decay)).clone() |
| ) |
| averaged_dnn.update_parameters(dnn) |
| averaged_params = updated_averaged_params |
| |
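        # With use_buffers=True the buffers are averaged alongside the parameters;
        # with use_buffers=False only the parameters are averaged and the buffers
        # are expected to be kept in sync with (copied from) the source model.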
| if use_buffers: |
| for p_avg, p_swa in zip( |
| averaged_params, |
| itertools.chain( |
| averaged_dnn.module.parameters(), averaged_dnn.module.buffers() |
| ), |
| ): |
| self.assertEqual(p_avg, p_swa) |
| else: |
| for p_avg, p_swa in zip(averaged_params, averaged_dnn.parameters()): |
| self.assertEqual(p_avg, p_swa) |
| for b_avg, b_swa in zip(dnn.buffers(), averaged_dnn.module.buffers()): |
| self.assertEqual(b_avg, b_swa) |
| |
| def _test_update_bn(self, dnn, dl_x, dl_xy, cuda): |
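        # Accumulate the mean and variance of the BatchNorm pre-activations over
        # the whole dataset by hand, then check that update_bn() leaves
        # running_mean and running_var close to these dataset-level statistics.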
| preactivation_sum = torch.zeros(dnn.n_features) |
| preactivation_squared_sum = torch.zeros(dnn.n_features) |
| if cuda: |
| preactivation_sum = preactivation_sum.cuda() |
| preactivation_squared_sum = preactivation_squared_sum.cuda() |
| total_num = 0 |
| for x in dl_x: |
| x = x[0] |
| if cuda: |
| x = x.cuda() |
| |
            dnn(x)
| preactivations = dnn.compute_preactivation(x) |
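            # Conv pre-activations are (N, C, H, W); move channels last and flatten
            # so the statistics below are computed per channel.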
| if len(preactivations.shape) == 4: |
| preactivations = preactivations.transpose(1, 3) |
| preactivations = preactivations.contiguous().view(-1, dnn.n_features) |
| total_num += preactivations.shape[0] |
| |
| preactivation_sum += torch.sum(preactivations, dim=0) |
| preactivation_squared_sum += torch.sum(preactivations**2, dim=0) |
| |
| preactivation_mean = preactivation_sum / total_num |
| preactivation_var = preactivation_squared_sum / total_num |
| preactivation_var = preactivation_var - preactivation_mean**2 |
| |
| update_bn(dl_xy, dnn, device=x.device) |
| self.assertEqual(preactivation_mean, dnn.bn.running_mean) |
| self.assertEqual(preactivation_var, dnn.bn.running_var, atol=1e-1, rtol=0) |
| |
| def _reset_bn(module): |
            if isinstance(module, torch.nn.modules.batchnorm._BatchNorm):
| module.running_mean = torch.zeros_like(module.running_mean) |
| module.running_var = torch.ones_like(module.running_var) |
| |
| # reset batch norm and run update_bn again |
| dnn.apply(_reset_bn) |
| update_bn(dl_xy, dnn, device=x.device) |
| self.assertEqual(preactivation_mean, dnn.bn.running_mean) |
| self.assertEqual(preactivation_var, dnn.bn.running_var, atol=1e-1, rtol=0) |
        # reset again and use the dl_x loader (inputs only) instead of dl_xy
| dnn.apply(_reset_bn) |
| update_bn(dl_x, dnn, device=x.device) |
| self.assertEqual(preactivation_mean, dnn.bn.running_mean) |
| self.assertEqual(preactivation_var, dnn.bn.running_var, atol=1e-1, rtol=0) |
| |
| def test_update_bn_dnn(self): |
| # Test update_bn for a fully-connected network with BatchNorm1d |
| objects, input_features = 100, 5 |
| x = torch.rand(objects, input_features) |
| y = torch.rand(objects) |
| ds_x = torch.utils.data.TensorDataset(x) |
| ds_xy = torch.utils.data.TensorDataset(x, y) |
| dl_x = torch.utils.data.DataLoader(ds_x, batch_size=5, shuffle=True) |
| dl_xy = torch.utils.data.DataLoader(ds_xy, batch_size=5, shuffle=True) |
| dnn = self.SWATestDNN(input_features=input_features) |
| dnn.train() |
| self._test_update_bn(dnn, dl_x, dl_xy, False) |
| if torch.cuda.is_available(): |
| dnn = self.SWATestDNN(input_features=input_features) |
| dnn.train() |
| self._test_update_bn(dnn.cuda(), dl_x, dl_xy, True) |
| self.assertTrue(dnn.training) |
| |
| def test_update_bn_cnn(self): |
        # Test update_bn for a convolutional network with BatchNorm2d
| objects = 100 |
| input_channels = 3 |
| height, width = 5, 5 |
| x = torch.rand(objects, input_channels, height, width) |
| y = torch.rand(objects) |
| ds_x = torch.utils.data.TensorDataset(x) |
| ds_xy = torch.utils.data.TensorDataset(x, y) |
| dl_x = torch.utils.data.DataLoader(ds_x, batch_size=5, shuffle=True) |
| dl_xy = torch.utils.data.DataLoader(ds_xy, batch_size=5, shuffle=True) |
| cnn = self.SWATestCNN(input_channels=input_channels) |
| cnn.train() |
| self._test_update_bn(cnn, dl_x, dl_xy, False) |
| if torch.cuda.is_available(): |
| cnn = self.SWATestCNN(input_channels=input_channels) |
| cnn.train() |
| self._test_update_bn(cnn.cuda(), dl_x, dl_xy, True) |
| self.assertTrue(cnn.training) |
| |
| def test_bn_update_eval_momentum(self): |
| # check that update_bn preserves eval mode |
| objects = 100 |
| input_channels = 3 |
| height, width = 5, 5 |
| x = torch.rand(objects, input_channels, height, width) |
| ds_x = torch.utils.data.TensorDataset(x) |
| dl_x = torch.utils.data.DataLoader(ds_x, batch_size=5, shuffle=True) |
| cnn = self.SWATestCNN(input_channels=input_channels) |
| cnn.eval() |
| update_bn(dl_x, cnn) |
| self.assertFalse(cnn.training) |
| |
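        # update_bn temporarily sets BN momentum to None while re-estimating the
        # running statistics and is expected to restore the original value after.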
| # check that momentum is preserved |
| self.assertEqual(cnn.bn.momentum, 0.3) |
| |
| |
| instantiate_parametrized_tests(TestSWAUtils) |
| |
| |
| if __name__ == "__main__": |
| print("These tests should be run through test/test_optim.py instead") |