[MPS] Enable caching for random ops with Philox engine (#85833)
Also fix a type cast issue in Bernoulli (Fixes #85611)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/85833
Approved by: https://github.com/kulinseth, https://github.com/malfet
diff --git a/test/test_mps.py b/test/test_mps.py
index fa5f663..6ccc6b1 100644
--- a/test/test_mps.py
+++ b/test/test_mps.py
@@ -4697,27 +4697,25 @@
helper((100, 100), 2.5, 1.2)
def test_bernoulli(self):
- def helper(shape, prob=0.5):
- prob_array = np.ones(shape)
- prob_array *= prob
- cpu_prob_tensor = torch.tensor(prob_array, device='cpu', dtype=torch.float, requires_grad=False)
- prob_tensor = cpu_prob_tensor.detach().clone().to('mps')
+ shape = (10, 10)
+ all_ones = torch.ones(shape, device='mps')
+ all_zeros = torch.zeros(shape, device='mps')
- mps_out = torch.bernoulli(prob_tensor)
- # We can't check reliably the mean and std.
- # Just make sure we don't return constant values
- self.assertNotEqual(mps_out.to('cpu').mean(), 0.)
- self.assertNotEqual(mps_out.to('cpu').std() ** 2, 0.)
+ prob_tensor = all_ones * 0.5
+ # probability of drawing "1" is 0.5
+ mps_out = torch.bernoulli(prob_tensor)
+ # We can't check reliably the mean and std.
+ # Just make sure we don't return constant values
+ self.assertNotEqual(mps_out.to('cpu').mean(), 0.)
+ self.assertNotEqual(mps_out.to('cpu').std() ** 2, 0.)
- mps_out = torch.zeros(shape, device='mps')
- mps_out = torch.bernoulli(mps_out, prob)
+ # probability of drawing "1" is 0
+ mps_out = torch.bernoulli(all_zeros)
+ self.assertEqual(mps_out, all_zeros)
- self.assertNotEqual(mps_out.to('cpu').mean(), 0.)
- self.assertNotEqual(mps_out.to('cpu').std(), 0.)
-
- helper((100, 100), 0.50)
- helper((100, 100), 0.76)
- helper((100, 100), 0.23)
+ # probability of drawing "1" is 1
+ mps_out = torch.bernoulli(all_ones)
+ self.assertEqual(mps_out, all_ones)
# Test random_.to and random_.from
def test_random(self):