[MPS] Fix `copy_kernel_mps` (#78428)
By passing the `storage_offset` of the source and destination Tensors to the Metal blit encoder.
This fixes the following simple use case:
```
python3 -c "import torch;x=torch.zeros(3, 3, device='mps'); x[1, 1]=1;print(x)"
```
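For context, `x[1, 1] = 1` copies a scalar into a zero-dim view whose data lives at a nonzero `storage_offset` within `x`'s storage; with both blit offsets hardcoded to 0, the copy read from and wrote to the start of each buffer instead. A minimal sketch of the offset involved (plain tensor metadata, no MPS device required):
```
import torch

x = torch.zeros(3, 3)
view = x[1, 1]                # zero-dim view sharing x's storage
print(view.storage_offset())  # 4 == 1 * 3 + 1 elements into the storage
# With destinationOffset:0 the blit wrote to the start of the buffer,
# not to the bytes backing element (1, 1).
```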
Add a test to validate that this does not regress in the future.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/78428
Approved by: https://github.com/kulinseth
diff --git a/aten/src/ATen/native/mps/operations/Copy.mm b/aten/src/ATen/native/mps/operations/Copy.mm
index e1e2a9e..b950bff 100644
--- a/aten/src/ATen/native/mps/operations/Copy.mm
+++ b/aten/src/ATen/native/mps/operations/Copy.mm
@@ -483,9 +483,9 @@
id<MTLCommandBuffer> commandBuffer = stream->commandBuffer();
id<MTLBlitCommandEncoder> blitEncoder = [commandBuffer blitCommandEncoder];
[blitEncoder copyFromBuffer:sourceBuffer
- sourceOffset:0
+ sourceOffset:src_byte_offset
toBuffer:destBuffer
- destinationOffset:0
+ destinationOffset:dst_byte_offset
size:size];
[blitEncoder endEncoding];
stream->commitAndWait();
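The `src_byte_offset` and `dst_byte_offset` above are presumably each tensor's storage offset scaled to bytes. A Python sketch of that arithmetic (the `byte_offset` helper is hypothetical, for illustration only):
```
import torch

def byte_offset(t: torch.Tensor) -> int:
    # Element offset into the underlying storage, converted to bytes,
    # which is what the Metal blit encoder expects.
    return t.storage_offset() * t.element_size()

x = torch.zeros(3, 3)
assert byte_offset(x[1, 1]) == 4 * x.element_size()  # 16 bytes for float32
```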
diff --git a/test/test_mps.py b/test/test_mps.py
index f4e9568..37ed5cc 100644
--- a/test/test_mps.py
+++ b/test/test_mps.py
@@ -15,6 +15,7 @@
from torch._six import inf
from torch.nn import Parameter
from torch.testing._internal.common_utils import run_tests, TestCase, download_file, TEST_WITH_UBSAN
+from torch.testing._internal.common_device_type import dtypes
import torch.backends.mps
from torch.distributions import Uniform
@@ -1300,6 +1301,20 @@
self.assertEqual(torch.tensor(-8.34, device='cpu').to('mps', torch.int),
torch.tensor(-8.34, device='cpu').to('mps').to(torch.int))
+    @dtypes(torch.int32, torch.float32, torch.int64, device_type="mps")
+    def test_setitem_scalar(self, device, dtype) -> None:
+        for i in range(3, 6):
+            for j in range(3, 6):
+                t = torch.zeros(i, j, dtype=dtype, device=device)
+                self.assertEqual(t.sum(), 0)
+                t[1, 1] = 1
+                t[2, 1] = i
+                t[1, 2] = j
+                self.assertEqual(t[1, 1], 1)
+                self.assertEqual(t[1, 2], j)
+                self.assertEqual(t[2, 1], i)
+                self.assertEqual(t.sum(), 1 + i + j)
+
class TestSmoothL1Loss(TestCase):
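For reference, a minimal standalone check of the behavior the new test guards, assuming an MPS-enabled build (illustrative snippet, not part of the PR):
```
import torch

if torch.backends.mps.is_available():
    x = torch.zeros(3, 3, device="mps")
    x[1, 1] = 1
    # With the fix, the scalar lands at element (1, 1) and nowhere else.
    assert x[1, 1].item() == 1
    assert x.sum().item() == 1
```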