Add full support for serialization of MPS Tensors (#79465)

Fix https://github.com/pytorch/pytorch/issues/79384
Pull Request resolved: https://github.com/pytorch/pytorch/pull/79465
Approved by: https://github.com/kulinseth, https://github.com/malfet
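
As a usage sketch of what this change enables (mirroring the new tests below;
the checkpoint path is illustrative and an MPS-capable build is assumed):

    import torch

    # Saving an MPS tensor and loading it back now keeps it on "mps"
    x = torch.rand(2, device="mps")
    torch.save(x, "checkpoint.pt")
    x2 = torch.load("checkpoint.pt")                       # x2.device.type == "mps"

    # An MPS checkpoint can also be brought back to the CPU
    x3 = torch.load("checkpoint.pt", map_location="cpu")   # x3.device.type == "cpu"

    # And a CPU checkpoint can be remapped onto MPS at load time
    torch.save(torch.rand(2), "checkpoint.pt")
    y = torch.load("checkpoint.pt", map_location="mps")    # y.device.type == "mps"
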
diff --git a/aten/src/ATen/mps/MPSDevice.h b/aten/src/ATen/mps/MPSDevice.h
index 44f72b1..d957c54 100644
--- a/aten/src/ATen/mps/MPSDevice.h
+++ b/aten/src/ATen/mps/MPSDevice.h
@@ -58,7 +58,7 @@
 
 TORCH_API bool is_available();
 
-at::Allocator* GetMPSAllocator(bool useSharedAllocator = false);
+TORCH_API at::Allocator* GetMPSAllocator(bool useSharedAllocator = false);
 
 } // namespace mps
 } // namespace at
diff --git a/test/test_mps.py b/test/test_mps.py
index d7b0f9e..e5dc637 100644
--- a/test/test_mps.py
+++ b/test/test_mps.py
@@ -7,6 +7,7 @@
 import unittest
 import warnings
 import subprocess
+import tempfile
 import os
 import torch
 import torch.nn as nn
@@ -4536,6 +4537,42 @@
 
         b = a.new(1)
 
+    def test_serialization_map_location(self):
+
+        # Ensures that a CPU tensor can be loaded on MPS
+        with tempfile.NamedTemporaryFile() as f:
+            x = torch.rand(2)
+            torch.save(x, f)
+
+            f.seek(0)
+            x2 = torch.load(f, map_location="mps")
+
+            self.assertEqual(x, x2)
+            self.assertEqual(x2.device.type, "mps")
+
+        # Ensures that MPS tensors can be loaded on MPS
+        with tempfile.NamedTemporaryFile() as f:
+            x = torch.rand(2, device="mps")
+            torch.save(x, f)
+
+            f.seek(0)
+            x2 = torch.load(f)
+
+            self.assertEqual(x, x2)
+            self.assertEqual(x2.device.type, "mps")
+
+        # Ensures that MPS tensors can be loaded on CPU
+        with tempfile.NamedTemporaryFile() as f:
+            x = torch.rand(2, device="mps")
+            torch.save(x, f)
+
+            f.seek(0)
+            x2 = torch.load(f, map_location="cpu")
+
+            self.assertEqual(x, x2)
+            self.assertEqual(x2.device.type, "cpu")
+
+
 
 
 if __name__ == "__main__":
diff --git a/torch/_tensor.py b/torch/_tensor.py
index bb490c5..158bce2 100644
--- a/torch/_tensor.py
+++ b/torch/_tensor.py
@@ -219,7 +219,7 @@
         # 2. Python list is not a good fit due to performance reason.
         #    `tolist()` converts every single element in the tensor into python objects
         #    and serialize them one by one.
-        if self.device.type in ['xla', 'ort', 'mps', 'hpu']:
+        if self.device.type in ['xla', 'ort', 'hpu']:
             # Convert BFloat16 tensors to Float32 before conversion to numpy, as numpy doesn't
             # support BFloat16. The tensor rebuilt from numpy takes in the original self.dtype;
             # this reconstructs the BFloat16 tensor from numpy.
diff --git a/torch/csrc/DynamicTypes.cpp b/torch/csrc/DynamicTypes.cpp
index b3cb810..674c170 100644
--- a/torch/csrc/DynamicTypes.cpp
+++ b/torch/csrc/DynamicTypes.cpp
@@ -38,6 +38,8 @@
     backend = at::Backend::CPU;
   } else if (device_type == at::kCUDA) {
     backend = at::Backend::CUDA;
+  } else if (device_type == at::kMPS) {
+    backend = at::Backend::MPS;
   } else if (device_type == at::DeviceType::Meta) {
     backend = at::Backend::Undefined;
   } else {
diff --git a/torch/csrc/Storage.cpp b/torch/csrc/Storage.cpp
index 5170fe7..4e1a74b 100644
--- a/torch/csrc/Storage.cpp
+++ b/torch/csrc/Storage.cpp
@@ -4,6 +4,7 @@
 #endif
 #include <structmember.h>
 
+#include <ATen/mps/MPSDevice.h>
 #include <c10/core/CPUAllocator.h>
 #include <libshm.h>
 #include <torch/csrc/CudaIPCTypes.h>
@@ -94,6 +95,10 @@
       at::globalContext().lazyInitCUDA();
       allocator = c10::cuda::CUDACachingAllocator::get();
 #endif
+#ifdef USE_MPS
+    } else if (device.type() == at::kMPS) {
+      allocator = at::mps::GetMPSAllocator();
+#endif
     } else if (device.type() == at::DeviceType::Meta) {
       allocator = c10::GetAllocator(device.type());
     } else {
diff --git a/torch/serialization.py b/torch/serialization.py
index 4072c17..8262b96 100644
--- a/torch/serialization.py
+++ b/torch/serialization.py
@@ -124,6 +124,11 @@
         return 'cuda:' + str(obj.device.index)
 
 
+def _mps_tag(obj):
+    if obj.device.type == 'mps':
+        return 'mps'
+
+
 def _cpu_deserialize(obj, location):
     if location == 'cpu':
         return obj
@@ -156,9 +161,14 @@
         else:
             return obj.cuda(device)
 
+def _mps_deserialize(obj, location):
+    if location == 'mps':
+        return obj.mps()
+
 
 register_package(10, _cpu_tag, _cpu_deserialize)
 register_package(20, _cuda_tag, _cuda_deserialize)
+register_package(21, _mps_tag, _mps_deserialize)
 
 
 def location_tag(storage: Union[Storage, torch.storage._TypedStorage, torch._UntypedStorage]):
diff --git a/torch/storage.py b/torch/storage.py
index 6397bb0..a6bef20 100644
--- a/torch/storage.py
+++ b/torch/storage.py
@@ -116,6 +116,13 @@
         else:
             return self
 
+    def mps(self):
+        """Returns a CPU copy of this storage if it's not already on the CPU"""
+        if self.device.type != 'mps':
+            return torch._UntypedStorage(self.size(), device="mps").copy_(self, False)
+        else:
+            return self
+
     def _to(self, dtype):
         if not isinstance(dtype, torch.dtype):
             raise TypeError(f"Argument 'dtype' must be torch.dtype, not {type(dtype)}")