Fix copy_ broadcast behavior on mps (#105617)
Fixes #105277
Pull Request resolved: https://github.com/pytorch/pytorch/pull/105617
Approved by: https://github.com/malfet
diff --git a/test/test_mps.py b/test/test_mps.py
index 2804ee0..a672de2 100644
--- a/test/test_mps.py
+++ b/test/test_mps.py
@@ -3545,6 +3545,30 @@
x_mps[2:4] = update_mps # implicit type casting and copy
self.assertEqual(x_cpu, x_mps)
+    def test_copy_broadcasting(self):  # regression test for #105277
+        def helper(src_shape, dst_shape, src_dtype, dst_dtype):
+            cpu_src = torch.randint(0, 127, src_shape).to(src_dtype)  # values in [0, 126] fit int8
+            cpu_dst = torch.randint(0, 127, dst_shape).to(dst_dtype)
+            # Move to MPS before the in-place CPU copy_, else mps_dst would already hold the expected values.
+            mps_src, mps_dst = cpu_src.to("mps"), cpu_dst.to("mps")
+            cpu_result = cpu_dst.copy_(cpu_src)
+            mps_result = mps_dst.copy_(mps_src)
+            self.assertEqual(cpu_result, mps_result)
+
+        test_dtypes = [torch.float32, torch.int32, torch.int16, torch.int8]
+        # (src_shape, dst_shape) pairs: size-1 dims, missing leading dims and zero-size dims.
+        for (src_dtype, dst_dtype) in itertools.product(test_dtypes, test_dtypes):
+            helper((2, 1), (2, 3), src_dtype, dst_dtype)
+            helper((2, 1), (2, 2), src_dtype, dst_dtype)
+            helper((3, 1, 4, 1), (3, 4, 4, 5), src_dtype, dst_dtype)
+            helper((3,), (2, 3), src_dtype, dst_dtype)
+            helper((2,), (2, 2), src_dtype, dst_dtype)
+            helper((4, 1, 5), (3, 4, 4, 5), src_dtype, dst_dtype)
+            helper((4, 1, 5), (4, 0, 5), src_dtype, dst_dtype)
+            helper((1, 5), (4, 0, 5), src_dtype, dst_dtype)
+            helper((3, 1, 0), (3, 5, 0), src_dtype, dst_dtype)
+            helper((0, 1, 0), (0, 5, 0), src_dtype, dst_dtype)
+
# See https://github.com/pytorch/pytorch/pull/84742
# and https://github.com/pytorch/pytorch/pull/78319
def test_binops_dtype_precedence(self):