[MPS] Enable caching for random ops with Philox engine (#85833)

Also fix a type-cast issue in Bernoulli (fixes #85611).

Pull Request resolved: https://github.com/pytorch/pytorch/pull/85833
Approved by: https://github.com/kulinseth, https://github.com/malfet
diff --git a/test/test_mps.py b/test/test_mps.py index fa5f663..6ccc6b1 100644 --- a/test/test_mps.py +++ b/test/test_mps.py
@@ -4697,27 +4697,25 @@ helper((100, 100), 2.5, 1.2) def test_bernoulli(self): - def helper(shape, prob=0.5): - prob_array = np.ones(shape) - prob_array *= prob - cpu_prob_tensor = torch.tensor(prob_array, device='cpu', dtype=torch.float, requires_grad=False) - prob_tensor = cpu_prob_tensor.detach().clone().to('mps') + shape = (10, 10) + all_ones = torch.ones(shape, device='mps') + all_zeros = torch.zeros(shape, device='mps') - mps_out = torch.bernoulli(prob_tensor) - # We can't check reliably the mean and std. - # Just make sure we don't return constant values - self.assertNotEqual(mps_out.to('cpu').mean(), 0.) - self.assertNotEqual(mps_out.to('cpu').std() ** 2, 0.) + prob_tensor = all_ones * 0.5 + # probability of drawing "1" is 0.5 + mps_out = torch.bernoulli(prob_tensor) + # We can't check reliably the mean and std. + # Just make sure we don't return constant values + self.assertNotEqual(mps_out.to('cpu').mean(), 0.) + self.assertNotEqual(mps_out.to('cpu').std() ** 2, 0.) - mps_out = torch.zeros(shape, device='mps') - mps_out = torch.bernoulli(mps_out, prob) + # probability of drawing "1" is 0 + mps_out = torch.bernoulli(all_zeros) + self.assertEqual(mps_out, all_zeros) - self.assertNotEqual(mps_out.to('cpu').mean(), 0.) - self.assertNotEqual(mps_out.to('cpu').std(), 0.) - - helper((100, 100), 0.50) - helper((100, 100), 0.76) - helper((100, 100), 0.23) + # probability of drawing "1" is 1 + mps_out = torch.bernoulli(all_ones) + self.assertEqual(mps_out, all_ones) # Test random_.to and random_.from def test_random(self):