[MPS] Fix memory error in var (#85571)

* Fix memory corruption (the malloc'ed output-shape buffer could be overrun) and wrong handling of negative dims when computing the correction factor
* Use std::vector for the output shape instead of a manually managed malloc/free pair

Pull Request resolved: https://github.com/pytorch/pytorch/pull/85571
Approved by: https://github.com/malfet
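
For reference, a minimal standalone sketch (plain C++, not the ATen kernel itself) of both fixes. wrap_dim below is a hypothetical stand-in that mirrors what c10::maybe_wrap_dim does, and growing a std::vector replaces the fixed-size malloc'ed buffer that could be overrun; shapes and dims are illustrative:

    #include <cstdint>
    #include <iostream>
    #include <stdexcept>
    #include <vector>

    // Stand-in for c10::maybe_wrap_dim: maps a negative dim into [0, ndim).
    int64_t wrap_dim(int64_t dim, int64_t ndim) {
      if (dim < -ndim || dim >= ndim)
        throw std::out_of_range("dim out of range");
      return dim < 0 ? dim + ndim : dim;  // e.g. -1 on a rank-4 tensor -> 3
    }

    int main() {
      const std::vector<int64_t> input_shape{2, 3, 4, 5};
      const std::vector<int64_t> dims{0, -1};  // dims as the caller passes them
      const int64_t ndim = static_cast<int64_t>(input_shape.size());

      // Before the fix, indexing input_shape with a raw -1 read out of
      // bounds; wrapping first makes the correction factor use the
      // intended sizes (2 * 5 = 10).
      int64_t correction_n = 1;
      for (int64_t d : dims)
        correction_n *= input_shape[wrap_dim(d, ndim)];

      // Growing a vector cannot write past the end of a preallocated
      // buffer the way the old malloc'ed int64_t array could, and needs
      // no matching free().
      std::vector<int64_t> output_shape;
      for (int64_t i = 0; i < ndim; i++) {
        bool reduced = false;
        for (int64_t d : dims)
          reduced = reduced || (wrap_dim(d, ndim) == i);
        if (!reduced)
          output_shape.push_back(input_shape[i]);
      }

      std::cout << "correction_n = " << correction_n << "\n";  // 10
      for (int64_t s : output_shape)
        std::cout << s << " ";  // 3 4
      std::cout << "\n";
      return 0;
    }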
diff --git a/aten/src/ATen/native/mps/operations/ReduceOps.mm b/aten/src/ATen/native/mps/operations/ReduceOps.mm
index 44802e7..56d1e0f 100644
--- a/aten/src/ATen/native/mps/operations/ReduceOps.mm
+++ b/aten/src/ATen/native/mps/operations/ReduceOps.mm
@@ -589,7 +589,7 @@
   NSMutableArray<NSNumber *> *axes = nil;
   NSMutableArray<NSNumber*> *apparent_output_shape = nil;
   NSMutableArray<NSNumber*> *apparent_input_shape = nil;
-  int64_t* output_shape = nil;
+  std::vector<int64_t> output_shape;
 
   if ((!keepdim && !use_dim) || (!keepdim && use_dim && dim_value.size() <= 0))
   {
@@ -629,7 +629,6 @@
                            axes);
 
       num_output_dims = (num_input_dims >= num_reduce_dims) ? (num_input_dims - num_reduce_dims) : 0; //num_input_dims;
-      output_shape = (int64_t *)malloc(num_output_dims  * sizeof(int64_t));
 
       unsigned int curr_i = 0;
       for (int i = 0; i < num_input_dims; i++)
@@ -644,13 +643,17 @@
               }
           }
           if (found) continue;
-          output_shape[curr_i] = input_shape[i];
+          output_shape.push_back(input_shape[i]);
           curr_i += 1;
+          // End loop when output shape is filled
+          if (curr_i == num_output_dims)
+            break;
       }
 
       for(int i = 0; i < num_reduce_dims; i++)
       {
-          correction_n *= input_shape[dim_value[i]];
+          auto wrap_dim = maybe_wrap_dim(dim_value[i], input_shape.size());
+          correction_n *= input_shape[wrap_dim];
       }
       // (3, 4, 5) --> (3, 5)
   }
@@ -667,10 +670,9 @@
                            input_shape,
                            axes);
       num_output_dims = num_input_dims;
-      output_shape = (int64_t *)malloc(num_output_dims  * sizeof(int64_t));
       for (int i = 0; i < num_input_dims; i++)
       {
-          output_shape[i] = (int64_t) 1;
+          output_shape.push_back((int64_t) 1);
           correction_n *= input_shape[i];
       }
       // scalar --> vector case [[1.0034567]]
@@ -690,21 +692,22 @@
                            axes);
 
       num_output_dims = num_input_dims;//(num_input_dims >= num_reduce_dims) ? (num_input_dims - num_reduce_dims) : 0;
-      output_shape = (int64_t *)malloc(num_output_dims  * sizeof(int64_t));
 
       for(int i = 0; i < num_reduce_dims; i++)
       {
-          correction_n *= input_shape[dim_value[i]];
+          auto wrap_dim = maybe_wrap_dim(dim_value[i], input_shape.size());
+          correction_n *= input_shape[wrap_dim];
       }
 
       for (int i = 0; i < num_input_dims; i++)
       {
-          output_shape[i] = [apparent_output_shape[i] longValue];
+          output_shape.push_back([apparent_output_shape[i] longValue]);
       }
   }
 
+
   Tensor output_t = at::native::empty_mps(
-                      IntArrayRef(output_shape, num_output_dims),
+                      IntArrayRef(output_shape.data(), num_output_dims),
                       input_t.scalar_type(),
                       c10::nullopt,
                       kMPS,
@@ -789,7 +792,7 @@
   };
   native_mps::runMPSGraph(stream, cachedGraph->graph(), feeds, results);
   }
-  free(output_shape);
+
   return output_t;
 }
 
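A note on why output_shape can live on the stack now: IntArrayRef (c10::ArrayRef<int64_t>) is a non-owning (pointer, length) view, so it only has to outlive the empty_mps() call, which the function-local vector does. A minimal stand-in for the view type, under that assumption (IntView is hypothetical, not a c10 type):

    #include <cstddef>
    #include <cstdint>
    #include <iostream>
    #include <vector>

    // Hypothetical stand-in for c10::IntArrayRef: a view that never owns
    // or frees the buffer it points at.
    struct IntView {
      const int64_t* data;
      size_t len;
    };

    int64_t numel(IntView sizes) {
      int64_t n = 1;
      for (size_t i = 0; i < sizes.len; i++)
        n *= sizes.data[i];
      return n;
    }

    int main() {
      std::vector<int64_t> output_shape{3, 5};
      // Valid only while output_shape is alive -- which it is here, since
      // both the vector and the view are locals of the same scope.
      IntView view{output_shape.data(), output_shape.size()};
      std::cout << numel(view) << "\n";  // 15
      return 0;
    }
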
diff --git a/test/test_mps.py b/test/test_mps.py
index 2a5c2d0..ddddb8a 100644
--- a/test/test_mps.py
+++ b/test/test_mps.py
@@ -2961,104 +2961,48 @@
         helper((9, 5, 6, 7))
 
     # Test var
-    def test_var(self):
-        def helper(shape):
+    def test_var_simple(self):
+        def helper():
+
+            shape = [2, 3, 4, 5]
+
             cpu_x = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=False)
             x = cpu_x.detach().clone().to('mps')
 
-            all_var = torch.var(x, unbiased=False)
-            all_var_cpu = torch.var(cpu_x, unbiased=False)
+            for unbiased in [False, True]:
+                for keepdim in [False, True]:
 
-            self.assertEqual(all_var, all_var_cpu)
+                    neg_dim_var = x.var(-1, keepdim=keepdim, unbiased=unbiased)
+                    neg_dim_var_cpu = cpu_x.var(-1, keepdim=keepdim, unbiased=unbiased)
 
-            nil_dim_var = torch.var(x, dim=[], unbiased=False)
-            nil_dim_var_cpu = torch.var(cpu_x, dim=[], unbiased=False)
+                    self.assertEqual(neg_dim_var, neg_dim_var_cpu)
 
-            self.assertEqual(nil_dim_var, nil_dim_var_cpu)
+                    all_var = torch.var(x, unbiased=unbiased)
+                    all_var_cpu = torch.var(cpu_x, unbiased=unbiased)
 
-            nil_dim_var_keepdim = torch.var(x, dim=[], keepdim=True, unbiased=False)
-            nil_dim_var_cpu_keepdim = torch.var(cpu_x, dim=[], keepdim=True, unbiased=False)
+                    self.assertEqual(all_var, all_var_cpu)
 
-            self.assertEqual(nil_dim_var_keepdim, nil_dim_var_cpu_keepdim)
+                    nil_dim_var = torch.var(x, dim=[], keepdim=keepdim, unbiased=unbiased)
+                    nil_dim_var_cpu = torch.var(cpu_x, dim=[], keepdim=keepdim, unbiased=unbiased)
 
-            zero_dim_var = torch.var(x, dim=[0], unbiased=False)
-            zero_dim_var_cpu = torch.var(cpu_x, dim=[0], unbiased=False)
+                    self.assertEqual(nil_dim_var, nil_dim_var_cpu)
 
-            self.assertEqual(zero_dim_var, zero_dim_var_cpu)
+                    zero_dim_var = torch.var(x, dim=[0], keepdim=keepdim, unbiased=unbiased)
+                    zero_dim_var_cpu = torch.var(cpu_x, dim=[0], keepdim=keepdim, unbiased=unbiased)
 
-            zero_dim_var_keepdim = torch.var(x, dim=[0], keepdim=True, unbiased=False)
-            zero_dim_var_cpu_keepdim = torch.var(cpu_x, dim=[0], keepdim=True, unbiased=False)
+                    self.assertEqual(zero_dim_var, zero_dim_var_cpu)
 
-            self.assertEqual(zero_dim_var_keepdim, zero_dim_var_cpu_keepdim)
+                    zero_neg_dim_var = torch.var(x, dim=[0, -1], keepdim=keepdim, unbiased=unbiased)
+                    zero_neg_dim_var_cpu = torch.var(cpu_x, dim=[0, -1], keepdim=keepdim, unbiased=unbiased)
 
-            zero_one_dim_var = torch.var(x, dim=[0, 1], unbiased=False)
-            zero_one_dim_var_cpu = torch.var(cpu_x, dim=[0, 1], unbiased=False)
+                    self.assertEqual(zero_neg_dim_var, zero_neg_dim_var_cpu)
 
-            self.assertEqual(zero_one_dim_var, zero_one_dim_var_cpu)
+                    two_three_dim_var = torch.var(x, dim=[2, 3], keepdim=keepdim, unbiased=unbiased)
+                    two_three_dim_var_cpu = torch.var(cpu_x, dim=[2, 3], keepdim=keepdim, unbiased=unbiased)
 
-            zero_one_dim_var_keepdim = torch.var(x, dim=[0, 1], keepdim=True, unbiased=False)
-            zero_one_dim_var_cpu_keepdim = torch.var(cpu_x, dim=[0, 1], keepdim=True, unbiased=False)
+                    self.assertEqual(two_three_dim_var, two_three_dim_var_cpu)
 
-            self.assertEqual(zero_one_dim_var_keepdim, zero_one_dim_var_cpu_keepdim)
-
-            two_three_dim_var = torch.var(x, dim=[2, 3], unbiased=False)
-            two_three_dim_var_cpu = torch.var(cpu_x, dim=[2, 3], unbiased=False)
-
-            self.assertEqual(two_three_dim_var, two_three_dim_var_cpu)
-
-            two_three_keepdim_var = torch.var(x, dim=[2, 3], keepdim=True, unbiased=False)
-            two_three_dim_keepvar_cpu = torch.var(cpu_x, dim=[2, 3], keepdim=True, unbiased=False)
-
-            self.assertEqual(two_three_keepdim_var, two_three_dim_keepvar_cpu)
-
-            all_var = torch.var(x, unbiased=True)
-            all_var_cpu = torch.var(cpu_x, unbiased=True)
-
-            self.assertEqual(all_var, all_var_cpu)
-
-            nil_dim_var = torch.var(x, dim=[], unbiased=True)
-            nil_dim_var_cpu = torch.var(cpu_x, dim=[], unbiased=True)
-
-            self.assertEqual(nil_dim_var, nil_dim_var_cpu)
-
-            nil_dim_var_keepdim = torch.var(x, dim=[], keepdim=True, unbiased=True)
-            nil_dim_var_cpu_keepdim = torch.var(cpu_x, dim=[], keepdim=True, unbiased=True)
-
-            self.assertEqual(nil_dim_var_keepdim, nil_dim_var_cpu_keepdim)
-
-            zero_dim_var = torch.var(x, dim=[0], unbiased=True)
-            zero_dim_var_cpu = torch.var(cpu_x, dim=[0], unbiased=True)
-
-            self.assertEqual(zero_dim_var, zero_dim_var_cpu)
-
-            zero_dim_var_keepdim = torch.var(x, dim=[0], keepdim=True, unbiased=True)
-            zero_dim_var_cpu_keepdim = torch.var(cpu_x, dim=[0], keepdim=True, unbiased=True)
-
-            self.assertEqual(zero_dim_var_keepdim, zero_dim_var_cpu_keepdim)
-
-            zero_one_dim_var = torch.var(x, dim=[0, 1], unbiased=True)
-            zero_one_dim_var_cpu = torch.var(cpu_x, dim=[0, 1], unbiased=True)
-
-            self.assertEqual(zero_one_dim_var, zero_one_dim_var_cpu)
-
-            zero_one_dim_var_keepdim = torch.var(x, dim=[0, 1], keepdim=True, unbiased=True)
-            zero_one_dim_var_cpu_keepdim = torch.var(cpu_x, dim=[0, 1], keepdim=True, unbiased=True)
-
-            self.assertEqual(zero_one_dim_var_keepdim, zero_one_dim_var_cpu_keepdim)
-
-            two_three_dim_var = torch.var(x, dim=[2, 3], unbiased=True)
-            two_three_dim_var_cpu = torch.var(cpu_x, dim=[2, 3], unbiased=True)
-
-            self.assertEqual(two_three_dim_var, two_three_dim_var_cpu)
-
-            two_three_keepdim_var = torch.var(x, dim=[2, 3], keepdim=True, unbiased=True)
-            two_three_dim_keepvar_cpu = torch.var(cpu_x, dim=[2, 3], keepdim=True, unbiased=True)
-
-            self.assertEqual(two_three_keepdim_var, two_three_dim_keepvar_cpu)
-
-        helper((4, 5, 6, 7))
-        # verify if a change in shape of input would cause problems with graph caching
-        helper((9, 5, 6, 7))
+        helper()
 
     # Test forward amax
     def test_amax(self):