[MPS] Add Python Module Bindings for the MPS backend (#94417)
- This PR is a prerequisite for the upcoming Memory Leak Detection PR.
- Enable global manual seeding of the default MPS generator via `torch.manual_seed()`, with a test case
- Add `torch.mps.synchronize()` to wait for the MPS stream to finish, with a test case
- Enable the following Python interfaces for MPS:
`torch.mps.[get_rng_state(), set_rng_state(), synchronize(), manual_seed(), seed()]`
- Add test cases to `test_mps.py`
- Add `mps.rst` to document the `torch.mps` module.
- Fix the failure in `test_public_bindings.py`
New files added:
- `torch/csrc/mps/Module.cpp`: implements `torch._C` module functions for `torch.mps` and `torch.backends.mps`.
- `torch/mps/__init__.py`: implements Python bindings for `torch.mps` module.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/94417
Approved by: https://github.com/albanD
diff --git a/test/test_mps.py b/test/test_mps.py
index 34ecb2e..2ee068c 100644
--- a/test/test_mps.py
+++ b/test/test_mps.py
@@ -5972,6 +5972,45 @@
mps_x = torch.randn(5, device='mps', generator=g_mps)
self.assertEqual(mps_x, mps_y)
+ def test_default_mps_generator(self):
+ # manual seeding on the "default" MPS generator using
+ # the global torch.manual_seed()
+ torch.manual_seed(230)
+ mps_x = torch.randn(5, device='mps')
+ # manual seeding using torch.mps.manual_seed()
+ # which should set the "default" MPS generator
+ # like the global torch.manual_seed()
+ torch.mps.manual_seed(230)
+ mps_y = torch.randn(5, device='mps')
+ # seed values were the same, so the random tensor contents should match
+ self.assertEqual(mps_x, mps_y)
+
+ # save the default generator's state to restore it later
+ g_state = torch.mps.get_rng_state()
+
+ # generate random numbers without seeding
+ mps_x = torch.randn(5, device='mps')
+ # in this case, the random results must differ from the last generated random results
+ self.assertNotEqual(mps_x, mps_y)
+
+ # restore the previously saved state, and the results should match again
+ torch.mps.set_rng_state(g_state)
+ mps_x = torch.randn(5, device='mps')
+ self.assertEqual(mps_x, mps_y)
+
+ def test_device_synchronize(self):
+ # just running some ops each followed by a synchronize to wait for
+ # MPS stream to finish running each of them
+ net1 = torch.nn.ConvTranspose2d(128, 64, kernel_size=3, stride=2, padding=1, output_padding=1)\
+ .to(device='mps', dtype=torch.float)
+
+ x = torch.rand(1, 128, 6, 6, device='mps', dtype=torch.float, requires_grad=True)
+ torch.mps.synchronize()
+ x = net1(x)
+ torch.mps.synchronize()
+ x.backward(torch.randn_like(x))
+ torch.mps.synchronize()
+
# Test random_.to and random_.from
def test_random(self):
def helper(shape, low, high, dtype=torch.int32):