Fix copy_ broadcast behavior on mps (#105617)

Fixes #105277

Pull Request resolved: https://github.com/pytorch/pytorch/pull/105617
Approved by: https://github.com/malfet
diff --git a/test/test_mps.py b/test/test_mps.py
index 2804ee0..a672de2 100644
--- a/test/test_mps.py
+++ b/test/test_mps.py
@@ -3545,6 +3545,38 @@
         x_mps[2:4] = update_mps  # implicit type casting and copy
         self.assertEqual(x_cpu, x_mps)
 
+    def test_copy_broadcasting(self):
+        """Check MPS ``copy_`` broadcasts ``src`` into ``dst`` like CPU does,
+        for every ordered (src_dtype, dst_dtype) pair of ``test_dtypes``."""
+        def helper(src_shape, dst_shape, src_dtype, dst_dtype):
+            cpu_src = torch.randint(0, 127, src_shape).to(src_dtype)
+            cpu_dst = torch.randint(0, 127, dst_shape).to(dst_dtype)
+            # Move both tensors to MPS *before* the in-place CPU copy_: if
+            # mps_dst were cloned from the already-updated cpu_dst, it would
+            # hold the expected values up front and a no-op MPS copy_ would
+            # still pass the comparison below.
+            mps_src = cpu_src.to("mps")
+            mps_dst = cpu_dst.to("mps")
+            cpu_result = cpu_dst.copy_(cpu_src)
+            mps_result = mps_dst.copy_(mps_src)
+            self.assertEqual(cpu_result, mps_result)
+
+        test_dtypes = [torch.float32, torch.int32, torch.int16, torch.int8]
+
+        # Ordered pairs cover same-dtype copies and dtype-casting copies.
+        for (src_dtype, dst_dtype) in itertools.product(test_dtypes, test_dtypes):
+            helper((2, 1), (2, 3), src_dtype, dst_dtype)
+            helper((2, 1), (2, 2), src_dtype, dst_dtype)
+            helper((3, 1, 4, 1), (3, 4, 4, 5), src_dtype, dst_dtype)
+            helper((3,), (2, 3), src_dtype, dst_dtype)
+            helper((2,), (2, 2), src_dtype, dst_dtype)
+            helper((4, 1, 5), (3, 4, 4, 5), src_dtype, dst_dtype)
+            # Zero-element destinations (and, in the last two, sources).
+            helper((4, 1, 5), (4, 0, 5), src_dtype, dst_dtype)
+            helper((1, 5), (4, 0, 5), src_dtype, dst_dtype)
+            helper((3, 1, 0), (3, 5, 0), src_dtype, dst_dtype)
+            helper((0, 1, 0), (0, 5, 0), src_dtype, dst_dtype)
+
     # See https://github.com/pytorch/pytorch/pull/84742
     # and https://github.com/pytorch/pytorch/pull/78319
     def test_binops_dtype_precedence(self):