[MPS] Fix memory error in var (#85571)
* Fix memory corruption and incorrect handling of negative dims
* Use vector for shape
Pull Request resolved: https://github.com/pytorch/pytorch/pull/85571
Approved by: https://github.com/malfet
diff --git a/test/test_mps.py b/test/test_mps.py
index 2a5c2d0..ddddb8a 100644
--- a/test/test_mps.py
+++ b/test/test_mps.py
@@ -2961,104 +2961,48 @@
helper((9, 5, 6, 7))
# Test var
- def test_var(self):
- def helper(shape):
+ def test_var_simple(self):
+ def helper():
+
+ shape = [2, 3, 4, 5]
+
cpu_x = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=False)
x = cpu_x.detach().clone().to('mps')
- all_var = torch.var(x, unbiased=False)
- all_var_cpu = torch.var(cpu_x, unbiased=False)
+ for unbiased in [False, True]:
+ for keepdim in [False, True]:
- self.assertEqual(all_var, all_var_cpu)
+ zero_dim_var = x.var(-1, keepdim=keepdim, unbiased=unbiased)
+ zero_dim_var_cpu = cpu_x.var(-1, keepdim=keepdim, unbiased=unbiased)
- nil_dim_var = torch.var(x, dim=[], unbiased=False)
- nil_dim_var_cpu = torch.var(cpu_x, dim=[], unbiased=False)
+ self.assertEqual(zero_dim_var, zero_dim_var_cpu)
- self.assertEqual(nil_dim_var, nil_dim_var_cpu)
+ all_var = torch.var(x, unbiased=unbiased)
+ all_var_cpu = torch.var(cpu_x, unbiased=unbiased)
- nil_dim_var_keepdim = torch.var(x, dim=[], keepdim=True, unbiased=False)
- nil_dim_var_cpu_keepdim = torch.var(cpu_x, dim=[], keepdim=True, unbiased=False)
+ self.assertEqual(all_var, all_var_cpu)
- self.assertEqual(nil_dim_var_keepdim, nil_dim_var_cpu_keepdim)
+ nil_dim_var = torch.var(x, dim=[], keepdim=keepdim, unbiased=unbiased)
+ nil_dim_var_cpu = torch.var(cpu_x, dim=[], keepdim=keepdim, unbiased=unbiased)
- zero_dim_var = torch.var(x, dim=[0], unbiased=False)
- zero_dim_var_cpu = torch.var(cpu_x, dim=[0], unbiased=False)
+ self.assertEqual(nil_dim_var, nil_dim_var_cpu)
- self.assertEqual(zero_dim_var, zero_dim_var_cpu)
+ zero_dim_var = torch.var(x, dim=[0], keepdim=keepdim, unbiased=unbiased)
+ zero_dim_var_cpu = torch.var(cpu_x, dim=[0], keepdim=keepdim, unbiased=unbiased)
- zero_dim_var_keepdim = torch.var(x, dim=[0], keepdim=True, unbiased=False)
- zero_dim_var_cpu_keepdim = torch.var(cpu_x, dim=[0], keepdim=True, unbiased=False)
+ self.assertEqual(zero_dim_var, zero_dim_var_cpu)
- self.assertEqual(zero_dim_var_keepdim, zero_dim_var_cpu_keepdim)
+ zero_one_dim_var = torch.var(x, dim=[0, -1], keepdim=keepdim, unbiased=unbiased)
+ zero_one_dim_var_cpu = torch.var(cpu_x, dim=[0, -1], keepdim=keepdim, unbiased=unbiased)
- zero_one_dim_var = torch.var(x, dim=[0, 1], unbiased=False)
- zero_one_dim_var_cpu = torch.var(cpu_x, dim=[0, 1], unbiased=False)
+ self.assertEqual(zero_one_dim_var, zero_one_dim_var_cpu)
- self.assertEqual(zero_one_dim_var, zero_one_dim_var_cpu)
+ two_three_dim_var = torch.var(x, dim=[2, 3], keepdim=keepdim, unbiased=unbiased)
+ two_three_dim_var_cpu = torch.var(cpu_x, dim=[2, 3], keepdim=keepdim, unbiased=unbiased)
- zero_one_dim_var_keepdim = torch.var(x, dim=[0, 1], keepdim=True, unbiased=False)
- zero_one_dim_var_cpu_keepdim = torch.var(cpu_x, dim=[0, 1], keepdim=True, unbiased=False)
+ self.assertEqual(two_three_dim_var, two_three_dim_var_cpu)
- self.assertEqual(zero_one_dim_var_keepdim, zero_one_dim_var_cpu_keepdim)
-
- two_three_dim_var = torch.var(x, dim=[2, 3], unbiased=False)
- two_three_dim_var_cpu = torch.var(cpu_x, dim=[2, 3], unbiased=False)
-
- self.assertEqual(two_three_dim_var, two_three_dim_var_cpu)
-
- two_three_keepdim_var = torch.var(x, dim=[2, 3], keepdim=True, unbiased=False)
- two_three_dim_keepvar_cpu = torch.var(cpu_x, dim=[2, 3], keepdim=True, unbiased=False)
-
- self.assertEqual(two_three_keepdim_var, two_three_dim_keepvar_cpu)
-
- all_var = torch.var(x, unbiased=True)
- all_var_cpu = torch.var(cpu_x, unbiased=True)
-
- self.assertEqual(all_var, all_var_cpu)
-
- nil_dim_var = torch.var(x, dim=[], unbiased=True)
- nil_dim_var_cpu = torch.var(cpu_x, dim=[], unbiased=True)
-
- self.assertEqual(nil_dim_var, nil_dim_var_cpu)
-
- nil_dim_var_keepdim = torch.var(x, dim=[], keepdim=True, unbiased=True)
- nil_dim_var_cpu_keepdim = torch.var(cpu_x, dim=[], keepdim=True, unbiased=True)
-
- self.assertEqual(nil_dim_var_keepdim, nil_dim_var_cpu_keepdim)
-
- zero_dim_var = torch.var(x, dim=[0], unbiased=True)
- zero_dim_var_cpu = torch.var(cpu_x, dim=[0], unbiased=True)
-
- self.assertEqual(zero_dim_var, zero_dim_var_cpu)
-
- zero_dim_var_keepdim = torch.var(x, dim=[0], keepdim=True, unbiased=True)
- zero_dim_var_cpu_keepdim = torch.var(cpu_x, dim=[0], keepdim=True, unbiased=True)
-
- self.assertEqual(zero_dim_var_keepdim, zero_dim_var_cpu_keepdim)
-
- zero_one_dim_var = torch.var(x, dim=[0, 1], unbiased=True)
- zero_one_dim_var_cpu = torch.var(cpu_x, dim=[0, 1], unbiased=True)
-
- self.assertEqual(zero_one_dim_var, zero_one_dim_var_cpu)
-
- zero_one_dim_var_keepdim = torch.var(x, dim=[0, 1], keepdim=True, unbiased=True)
- zero_one_dim_var_cpu_keepdim = torch.var(cpu_x, dim=[0, 1], keepdim=True, unbiased=True)
-
- self.assertEqual(zero_one_dim_var_keepdim, zero_one_dim_var_cpu_keepdim)
-
- two_three_dim_var = torch.var(x, dim=[2, 3], unbiased=True)
- two_three_dim_var_cpu = torch.var(cpu_x, dim=[2, 3], unbiased=True)
-
- self.assertEqual(two_three_dim_var, two_three_dim_var_cpu)
-
- two_three_keepdim_var = torch.var(x, dim=[2, 3], keepdim=True, unbiased=True)
- two_three_dim_keepvar_cpu = torch.var(cpu_x, dim=[2, 3], keepdim=True, unbiased=True)
-
- self.assertEqual(two_three_keepdim_var, two_three_dim_keepvar_cpu)
-
- helper((4, 5, 6, 7))
- # verify if a change in shape of input would cause problems with graph caching
- helper((9, 5, 6, 7))
+ helper()
# Test forward amax
def test_amax(self):