[MPS] Fix out-of-bounds fill to sliced tensor (#114838)

This fixes a regression introduced by https://github.com/pytorch/pytorch/pull/81951 that caused an out-of-bounds access when a sliced tensor is filled with zeros.

Remove the bogus `TORCH_INTERNAL_ASSERT(length >= offset)`, as the arguments of [NSMakeRange](https://developer.apple.com/documentation/foundation/1417188-nsmakerange?language=objc) are a location and a length rather than start and end offsets.

In `fill_mps_tensor_`:
- Pass the `value` argument to `MPSStream::fill`
- Pass `self.nbytes()` rather than `self.storage().nbytes()` as the length of the buffer to fill, as the latter always results in an out-of-bounds write if the offset within the storage is non-zero
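
The byte arithmetic behind both changes can be sketched in plain Python. This is only an illustrative model, not the MPS code: `fill_range` is a hypothetical helper that mimics a (location, length) range check, and the sizes assume the float32 `[1, 10]` tensor used in the regression test.

```python
def fill_range(buffer_nbytes, location, length):
    """Hypothetical model of a (location, length) fill over a buffer.

    Returns the half-open byte range [start, end) that would be written,
    or raises IndexError if the range runs past the end of the buffer.
    """
    end = location + length
    if end > buffer_nbytes:
        raise IndexError(f"fill of [{location}, {end}) exceeds {buffer_nbytes} bytes")
    return (location, end)


ITEMSIZE = 4                      # assumed float32 elements
storage_nbytes = 10 * ITEMSIZE    # storage backing a [1, 10] tensor
offset = 5 * ITEMSIZE             # byte offset of the slice tensor[:, 5]
view_nbytes = 1 * ITEMSIZE        # the slice holds a single element

# Using the size of the slice as the length stays inside the storage.
in_bounds = fill_range(storage_nbytes, offset, view_nbytes)

# Using the size of the whole storage as the length overruns it
# whenever the offset is non-zero.
try:
    fill_range(storage_nbytes, offset, storage_nbytes)
    overflowed = False
except IndexError:
    overflowed = True

# A valid range may have length < location (here 4 < 36), which is why
# asserting `length >= offset` rejects legitimate fills.
last = fill_range(storage_nbytes, 9 * ITEMSIZE, view_nbytes)
```

With a zero offset both lengths are in bounds, which would explain why the problem only shows up for slices that start partway into the storage.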

Add regression test

Fixes https://github.com/pytorch/pytorch/issues/114692

Pull Request resolved: https://github.com/pytorch/pytorch/pull/114838
Approved by: https://github.com/atalman, https://github.com/kulinseth
diff --git a/test/test_mps.py b/test/test_mps.py
index a1f3f52..4754d25 100644
--- a/test/test_mps.py
+++ b/test/test_mps.py
@@ -1447,17 +1447,22 @@
         tensor_cpu = tensor_0[:][1].fill_(val)
 
         self.assertEqual(tensor_mps, tensor_cpu)
+        self.assertEqual(tensor, tensor_0)
 
         shape = [1, 10]
         val = 0.0
         tensor = torch.ones(shape, device="mps")
         val_tensor_mps = torch.tensor(val, device="mps")
         tensor_mps = tensor[:, 9].fill_(val_tensor_mps)
+        # Regression test for https://github.com/pytorch/pytorch/issues/114692
+        tensor[:, 5].fill_(val_tensor_mps)
         tensor_0 = torch.ones(shape, device="cpu")
         val_tensor_cpu = torch.tensor(val, device="cpu")
         tensor_cpu = tensor_0[:, 9].fill_(val_tensor_cpu)
+        tensor_0[:, 5].fill_(val_tensor_cpu)
 
-        self.assertEqual(tensor_mps, tensor_cpu)
+        self.assertEqual(tensor_mps.to(device="cpu"), tensor_cpu)
+        self.assertEqual(tensor.to(device="cpu"), tensor_0)
 
     def test_cdist_large(self, device="mps"):
         for cm in ['use_mm_for_euclid_dist_if_necessary', 'use_mm_for_euclid_dist', 'donot_use_mm_for_euclid_dist']: