[MPS] Fix memory error in var (#85571)
* Fix memory corruption and incorrect handling of negative dims
* Use std::vector for the output shape instead of a raw malloc'd buffer
Pull Request resolved: https://github.com/pytorch/pytorch/pull/85571
Approved by: https://github.com/malfet
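
For context, a minimal standalone sketch of the two changes above: wrapping negative dims before indexing into the input shape, and building the output shape with std::vector rather than a raw malloc'd buffer. `wrap_dim` and `reduced_shape` below are illustrative helpers written for this note, not the ATen functions; ATen's `maybe_wrap_dim` has the same wrapping semantics as `wrap_dim`.

```cpp
#include <cstdint>
#include <stdexcept>
#include <vector>

// A negative dim d over ndim axes maps to d + ndim (same idea as maybe_wrap_dim).
int64_t wrap_dim(int64_t dim, int64_t ndim) {
  if (dim < -ndim || dim >= ndim) {
    throw std::out_of_range("dim out of range");
  }
  return dim < 0 ? dim + ndim : dim;  // e.g. dim = -1, ndim = 4  ->  3
}

// Build the reduced output shape with std::vector so the buffer can never be
// overrun and no manual free() is needed on any return path.
std::vector<int64_t> reduced_shape(const std::vector<int64_t>& input_shape,
                                   const std::vector<int64_t>& dims,
                                   bool keepdim) {
  const int64_t ndim = static_cast<int64_t>(input_shape.size());
  std::vector<bool> reduce(ndim, false);
  for (const auto d : dims) {
    reduce[wrap_dim(d, ndim)] = true;  // wrap instead of indexing with a negative value
  }
  std::vector<int64_t> out;
  for (int64_t i = 0; i < ndim; ++i) {
    if (!reduce[i]) {
      out.push_back(input_shape[i]);
    } else if (keepdim) {
      out.push_back(1);
    }
  }
  return out;
}
```

For example, `reduced_shape({2, 3, 4, 5}, {0, -1}, /*keepdim=*/false)` yields `{3, 4}`, matching the `dim=[0, -1]` case exercised by the updated test below.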
diff --git a/aten/src/ATen/native/mps/operations/ReduceOps.mm b/aten/src/ATen/native/mps/operations/ReduceOps.mm
index 44802e7..56d1e0f 100644
--- a/aten/src/ATen/native/mps/operations/ReduceOps.mm
+++ b/aten/src/ATen/native/mps/operations/ReduceOps.mm
@@ -589,7 +589,7 @@
NSMutableArray<NSNumber *> *axes = nil;
NSMutableArray<NSNumber*> *apparent_output_shape = nil;
NSMutableArray<NSNumber*> *apparent_input_shape = nil;
- int64_t* output_shape = nil;
+ std::vector<int64_t> output_shape;
if ((!keepdim && !use_dim) || (!keepdim && use_dim && dim_value.size() <= 0))
{
@@ -629,7 +629,6 @@
axes);
num_output_dims = (num_input_dims >= num_reduce_dims) ? (num_input_dims - num_reduce_dims) : 0; //num_input_dims;
- output_shape = (int64_t *)malloc(num_output_dims * sizeof(int64_t));
unsigned int curr_i = 0;
for (int i = 0; i < num_input_dims; i++)
@@ -644,13 +643,17 @@
}
}
if (found) continue;
- output_shape[curr_i] = input_shape[i];
+ output_shape.push_back(input_shape[i]);
curr_i += 1;
+ // End loop when output shape is filled
+ if (curr_i == num_output_dims)
+ break;
}
for(int i = 0; i < num_reduce_dims; i++)
{
- correction_n *= input_shape[dim_value[i]];
+ auto wrap_dim = maybe_wrap_dim(dim_value[i], input_shape.size());
+ correction_n *= input_shape[wrap_dim];
}
// (3, 4, 5) --> (3, 5)
}
@@ -667,10 +670,9 @@
input_shape,
axes);
num_output_dims = num_input_dims;
- output_shape = (int64_t *)malloc(num_output_dims * sizeof(int64_t));
for (int i = 0; i < num_input_dims; i++)
{
- output_shape[i] = (int64_t) 1;
+ output_shape.push_back((int64_t) 1);
correction_n *= input_shape[i];
}
// scalar --> vector case [[1.0034567]]
@@ -690,21 +692,22 @@
axes);
num_output_dims = num_input_dims;//(num_input_dims >= num_reduce_dims) ? (num_input_dims - num_reduce_dims) : 0;
- output_shape = (int64_t *)malloc(num_output_dims * sizeof(int64_t));
for(int i = 0; i < num_reduce_dims; i++)
{
- correction_n *= input_shape[dim_value[i]];
+ auto wrap_dim = maybe_wrap_dim(dim_value[i], input_shape.size());
+ correction_n *= input_shape[wrap_dim];
}
for (int i = 0; i < num_input_dims; i++)
{
- output_shape[i] = [apparent_output_shape[i] longValue];
+ output_shape.push_back([apparent_output_shape[i] longValue]);
}
}
+
Tensor output_t = at::native::empty_mps(
- IntArrayRef(output_shape, num_output_dims),
+ IntArrayRef(output_shape.data(), num_output_dims),
input_t.scalar_type(),
c10::nullopt,
kMPS,
@@ -789,7 +792,7 @@
};
native_mps::runMPSGraph(stream, cachedGraph->graph(), feeds, results);
}
- free(output_shape);
+
return output_t;
}
diff --git a/test/test_mps.py b/test/test_mps.py
index 2a5c2d0..ddddb8a 100644
--- a/test/test_mps.py
+++ b/test/test_mps.py
@@ -2961,104 +2961,48 @@
helper((9, 5, 6, 7))
# Test var
- def test_var(self):
- def helper(shape):
+ def test_var_simple(self):
+ def helper():
+
+ shape = [2, 3, 4, 5]
+
cpu_x = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=False)
x = cpu_x.detach().clone().to('mps')
- all_var = torch.var(x, unbiased=False)
- all_var_cpu = torch.var(cpu_x, unbiased=False)
+ for unbiased in [False, True]:
+ for keepdim in [False, True]:
- self.assertEqual(all_var, all_var_cpu)
+ zero_dim_var = x.var(-1, keepdim=keepdim, unbiased=unbiased)
+ zero_dim_var_cpu = cpu_x.var(-1, keepdim=keepdim, unbiased=unbiased)
- nil_dim_var = torch.var(x, dim=[], unbiased=False)
- nil_dim_var_cpu = torch.var(cpu_x, dim=[], unbiased=False)
+ self.assertEqual(zero_dim_var, zero_dim_var_cpu)
- self.assertEqual(nil_dim_var, nil_dim_var_cpu)
+ all_var = torch.var(x, unbiased=unbiased)
+ all_var_cpu = torch.var(cpu_x, unbiased=unbiased)
- nil_dim_var_keepdim = torch.var(x, dim=[], keepdim=True, unbiased=False)
- nil_dim_var_cpu_keepdim = torch.var(cpu_x, dim=[], keepdim=True, unbiased=False)
+ self.assertEqual(all_var, all_var_cpu)
- self.assertEqual(nil_dim_var_keepdim, nil_dim_var_cpu_keepdim)
+ nil_dim_var = torch.var(x, dim=[], keepdim=keepdim, unbiased=unbiased)
+ nil_dim_var_cpu = torch.var(cpu_x, dim=[], keepdim=keepdim, unbiased=unbiased)
- zero_dim_var = torch.var(x, dim=[0], unbiased=False)
- zero_dim_var_cpu = torch.var(cpu_x, dim=[0], unbiased=False)
+ self.assertEqual(nil_dim_var, nil_dim_var_cpu)
- self.assertEqual(zero_dim_var, zero_dim_var_cpu)
+ zero_dim_var = torch.var(x, dim=[0], keepdim=keepdim, unbiased=unbiased)
+ zero_dim_var_cpu = torch.var(cpu_x, dim=[0], keepdim=keepdim, unbiased=unbiased)
- zero_dim_var_keepdim = torch.var(x, dim=[0], keepdim=True, unbiased=False)
- zero_dim_var_cpu_keepdim = torch.var(cpu_x, dim=[0], keepdim=True, unbiased=False)
+ self.assertEqual(zero_dim_var, zero_dim_var_cpu)
- self.assertEqual(zero_dim_var_keepdim, zero_dim_var_cpu_keepdim)
+ zero_one_dim_var = torch.var(x, dim=[0, -1], keepdim=keepdim, unbiased=unbiased)
+ zero_one_dim_var_cpu = torch.var(cpu_x, dim=[0, -1], keepdim=keepdim, unbiased=unbiased)
- zero_one_dim_var = torch.var(x, dim=[0, 1], unbiased=False)
- zero_one_dim_var_cpu = torch.var(cpu_x, dim=[0, 1], unbiased=False)
+ self.assertEqual(zero_one_dim_var, zero_one_dim_var_cpu)
- self.assertEqual(zero_one_dim_var, zero_one_dim_var_cpu)
+ two_three_dim_var = torch.var(x, dim=[2, 3], keepdim=keepdim, unbiased=unbiased)
+ two_three_dim_var_cpu = torch.var(cpu_x, dim=[2, 3], keepdim=keepdim, unbiased=unbiased)
- zero_one_dim_var_keepdim = torch.var(x, dim=[0, 1], keepdim=True, unbiased=False)
- zero_one_dim_var_cpu_keepdim = torch.var(cpu_x, dim=[0, 1], keepdim=True, unbiased=False)
+ self.assertEqual(two_three_dim_var, two_three_dim_var_cpu)
- self.assertEqual(zero_one_dim_var_keepdim, zero_one_dim_var_cpu_keepdim)
-
- two_three_dim_var = torch.var(x, dim=[2, 3], unbiased=False)
- two_three_dim_var_cpu = torch.var(cpu_x, dim=[2, 3], unbiased=False)
-
- self.assertEqual(two_three_dim_var, two_three_dim_var_cpu)
-
- two_three_keepdim_var = torch.var(x, dim=[2, 3], keepdim=True, unbiased=False)
- two_three_dim_keepvar_cpu = torch.var(cpu_x, dim=[2, 3], keepdim=True, unbiased=False)
-
- self.assertEqual(two_three_keepdim_var, two_three_dim_keepvar_cpu)
-
- all_var = torch.var(x, unbiased=True)
- all_var_cpu = torch.var(cpu_x, unbiased=True)
-
- self.assertEqual(all_var, all_var_cpu)
-
- nil_dim_var = torch.var(x, dim=[], unbiased=True)
- nil_dim_var_cpu = torch.var(cpu_x, dim=[], unbiased=True)
-
- self.assertEqual(nil_dim_var, nil_dim_var_cpu)
-
- nil_dim_var_keepdim = torch.var(x, dim=[], keepdim=True, unbiased=True)
- nil_dim_var_cpu_keepdim = torch.var(cpu_x, dim=[], keepdim=True, unbiased=True)
-
- self.assertEqual(nil_dim_var_keepdim, nil_dim_var_cpu_keepdim)
-
- zero_dim_var = torch.var(x, dim=[0], unbiased=True)
- zero_dim_var_cpu = torch.var(cpu_x, dim=[0], unbiased=True)
-
- self.assertEqual(zero_dim_var, zero_dim_var_cpu)
-
- zero_dim_var_keepdim = torch.var(x, dim=[0], keepdim=True, unbiased=True)
- zero_dim_var_cpu_keepdim = torch.var(cpu_x, dim=[0], keepdim=True, unbiased=True)
-
- self.assertEqual(zero_dim_var_keepdim, zero_dim_var_cpu_keepdim)
-
- zero_one_dim_var = torch.var(x, dim=[0, 1], unbiased=True)
- zero_one_dim_var_cpu = torch.var(cpu_x, dim=[0, 1], unbiased=True)
-
- self.assertEqual(zero_one_dim_var, zero_one_dim_var_cpu)
-
- zero_one_dim_var_keepdim = torch.var(x, dim=[0, 1], keepdim=True, unbiased=True)
- zero_one_dim_var_cpu_keepdim = torch.var(cpu_x, dim=[0, 1], keepdim=True, unbiased=True)
-
- self.assertEqual(zero_one_dim_var_keepdim, zero_one_dim_var_cpu_keepdim)
-
- two_three_dim_var = torch.var(x, dim=[2, 3], unbiased=True)
- two_three_dim_var_cpu = torch.var(cpu_x, dim=[2, 3], unbiased=True)
-
- self.assertEqual(two_three_dim_var, two_three_dim_var_cpu)
-
- two_three_keepdim_var = torch.var(x, dim=[2, 3], keepdim=True, unbiased=True)
- two_three_dim_keepvar_cpu = torch.var(cpu_x, dim=[2, 3], keepdim=True, unbiased=True)
-
- self.assertEqual(two_three_keepdim_var, two_three_dim_keepvar_cpu)
-
- helper((4, 5, 6, 7))
- # verify if a change in shape of input would cause problems with graph caching
- helper((9, 5, 6, 7))
+ helper()
# Test forward amax
def test_amax(self):