Add full support for serialization of MPS Tensors (#79465)
Fix https://github.com/pytorch/pytorch/issues/79384
Pull Request resolved: https://github.com/pytorch/pytorch/pull/79465
Approved by: https://github.com/kulinseth, https://github.com/malfet
diff --git a/aten/src/ATen/mps/MPSDevice.h b/aten/src/ATen/mps/MPSDevice.h
index 44f72b1..d957c54 100644
--- a/aten/src/ATen/mps/MPSDevice.h
+++ b/aten/src/ATen/mps/MPSDevice.h
@@ -58,7 +58,7 @@
TORCH_API bool is_available();
-at::Allocator* GetMPSAllocator(bool useSharedAllocator = false);
+TORCH_API at::Allocator* GetMPSAllocator(bool useSharedAllocator = false);
} // namespace mps
} // namespace at
diff --git a/test/test_mps.py b/test/test_mps.py
index d7b0f9e..e5dc637 100644
--- a/test/test_mps.py
+++ b/test/test_mps.py
@@ -7,6 +7,7 @@
import unittest
import warnings
import subprocess
+import tempfile
import os
import torch
import torch.nn as nn
@@ -4536,6 +4537,42 @@
b = a.new(1)
+ def test_serialization_map_location(self):
+
+ # Ensures that cpu Tensor can be loaded on mps
+ with tempfile.NamedTemporaryFile() as f:
+ x = torch.rand(2)
+ torch.save(x, f)
+
+ f.seek(0)
+ x2 = torch.load(f, map_location="mps")
+
+ self.assertEqual(x, x2)
+ self.assertEqual(x2.device.type, "mps")
+
+ # Ensures that mps Tensors can be loaded on mps
+ with tempfile.NamedTemporaryFile() as f:
+ x = torch.rand(2, device="mps")
+ torch.save(x, f)
+
+ f.seek(0)
+ x2 = torch.load(f)
+
+ self.assertEqual(x, x2)
+ self.assertEqual(x2.device.type, "mps")
+
+ # Ensures that mps Tensors can be loaded on cpu
+ with tempfile.NamedTemporaryFile() as f:
+ x = torch.rand(2, device="mps")
+ torch.save(x, f)
+
+ f.seek(0)
+ x2 = torch.load(f, map_location="cpu")
+
+ self.assertEqual(x, x2)
+ self.assertEqual(x2.device.type, "cpu")
+
+
if __name__ == "__main__":
diff --git a/torch/_tensor.py b/torch/_tensor.py
index bb490c5..158bce2 100644
--- a/torch/_tensor.py
+++ b/torch/_tensor.py
@@ -219,7 +219,7 @@
# 2. Python list is not a good fit due to performance reason.
# `tolist()` converts every single element in the tensor into python objects
# and serialize them one by one.
- if self.device.type in ['xla', 'ort', 'mps', 'hpu']:
+ if self.device.type in ['xla', 'ort', 'hpu']:
# Convert BFloat16 tesors to Float32 before conversion to numpy, as numpy doesn't
# support BFloat16. The rebuild tensor from numpy takes in the original self.dtype,
# this would reconstruct the BFloat16 tensor from numpy.
diff --git a/torch/csrc/DynamicTypes.cpp b/torch/csrc/DynamicTypes.cpp
index b3cb810..674c170 100644
--- a/torch/csrc/DynamicTypes.cpp
+++ b/torch/csrc/DynamicTypes.cpp
@@ -38,6 +38,8 @@
backend = at::Backend::CPU;
} else if (device_type == at::kCUDA) {
backend = at::Backend::CUDA;
+ } else if (device_type == at::kMPS) {
+ backend = at::Backend::MPS;
} else if (device_type == at::DeviceType::Meta) {
backend = at::Backend::Undefined;
} else {
diff --git a/torch/csrc/Storage.cpp b/torch/csrc/Storage.cpp
index 5170fe7..4e1a74b 100644
--- a/torch/csrc/Storage.cpp
+++ b/torch/csrc/Storage.cpp
@@ -4,6 +4,7 @@
#endif
#include <structmember.h>
+#include <ATen/mps/MPSDevice.h>
#include <c10/core/CPUAllocator.h>
#include <libshm.h>
#include <torch/csrc/CudaIPCTypes.h>
@@ -94,6 +95,10 @@
at::globalContext().lazyInitCUDA();
allocator = c10::cuda::CUDACachingAllocator::get();
#endif
+#ifdef USE_MPS
+ } else if (device.type() == at::kMPS) {
+ allocator = at::mps::GetMPSAllocator();
+#endif
} else if (device.type() == at::DeviceType::Meta) {
allocator = c10::GetAllocator(device.type());
} else {
diff --git a/torch/serialization.py b/torch/serialization.py
index 4072c17..8262b96 100644
--- a/torch/serialization.py
+++ b/torch/serialization.py
@@ -124,6 +124,11 @@
return 'cuda:' + str(obj.device.index)
+def _mps_tag(obj):
+ if obj.device.type == 'mps':
+ return 'mps'
+
+
def _cpu_deserialize(obj, location):
if location == 'cpu':
return obj
@@ -156,9 +161,14 @@
else:
return obj.cuda(device)
+def _mps_deserialize(obj, location):
+ if location == 'mps':
+ return obj.mps()
+
register_package(10, _cpu_tag, _cpu_deserialize)
register_package(20, _cuda_tag, _cuda_deserialize)
+register_package(21, _mps_tag, _mps_deserialize)
def location_tag(storage: Union[Storage, torch.storage._TypedStorage, torch._UntypedStorage]):
diff --git a/torch/storage.py b/torch/storage.py
index 6397bb0..a6bef20 100644
--- a/torch/storage.py
+++ b/torch/storage.py
@@ -116,6 +116,13 @@
else:
return self
+ def mps(self):
+ """Returns an MPS copy of this storage if it's not already on the MPS device"""
+ if self.device.type != 'mps':
+ return torch._UntypedStorage(self.size(), device="mps").copy_(self, False)
+ else:
+ return self
+
def _to(self, dtype):
if not isinstance(dtype, torch.dtype):
raise TypeError(f"Argument 'dtype' must be torch.dtype, not {type(dtype)}")