[BE] Do not test deprecated `torch.nn.utils.weight_norm` (#128727)
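Test `torch.nn.utils.parametrizations.weight_norm` instead. Because the parametrized API exposes the magnitude and direction tensors as `parametrizations.weight.original0` and `parametrizations.weight.original1` (formerly `weight_g` and `weight_v`), the gradient comparisons in `test_weight_norm` are updated to read those attributes.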
Pull Request resolved: https://github.com/pytorch/pytorch/pull/128727
Approved by: https://github.com/kit1980
ghstack dependencies: #128726
diff --git a/test/test_mps.py b/test/test_mps.py
index 9ca0cba..275013f 100644
--- a/test/test_mps.py
+++ b/test/test_mps.py
@@ -2856,8 +2856,8 @@
def test_weight_norm(self):
def validate_weight_norm_equality(model, cpu_model, x, cpu_x, dim):
- cpu_norm = torch.nn.utils.weight_norm(cpu_model, dim=dim)
- norm = torch.nn.utils.weight_norm(model, dim=dim)
+ cpu_norm = torch.nn.utils.parametrizations.weight_norm(cpu_model, dim=dim)
+ norm = torch.nn.utils.parametrizations.weight_norm(model, dim=dim)
cpu_out = cpu_norm(cpu_x)
out = norm(x)
@@ -2869,8 +2869,8 @@
cpu_out.backward(gradient=cpu_grad)
out.backward(gradient=grad)
- self.assertEqual(cpu_model.weight_g.grad, model.weight_g.grad)
- self.assertEqual(cpu_model.weight_v.grad, model.weight_v.grad)
+ self.assertEqual(cpu_model.parametrizations.weight.original0.grad, model.parametrizations.weight.original0.grad)
+ self.assertEqual(cpu_model.parametrizations.weight.original1.grad, model.parametrizations.weight.original1.grad)
self.assertEqual(x.grad, cpu_x.grad)