[MPS] Enable caching for random ops with Philox engine (#85833)

Also fix a type-cast issue in Bernoulli (Fixes #85611)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/85833
Approved by: https://github.com/kulinseth, https://github.com/malfet
diff --git a/test/test_mps.py b/test/test_mps.py
index fa5f663..6ccc6b1 100644
--- a/test/test_mps.py
+++ b/test/test_mps.py
@@ -4697,27 +4697,25 @@
         helper((100, 100), 2.5, 1.2)
 
     def test_bernoulli(self):
-        def helper(shape, prob=0.5):
-            prob_array = np.ones(shape)
-            prob_array *= prob
-            cpu_prob_tensor = torch.tensor(prob_array, device='cpu', dtype=torch.float, requires_grad=False)
-            prob_tensor = cpu_prob_tensor.detach().clone().to('mps')
+        shape = (10, 10)
+        all_ones = torch.ones(shape, device='mps')
+        all_zeros = torch.zeros(shape, device='mps')
 
-            mps_out = torch.bernoulli(prob_tensor)
-            # We can't check reliably the mean and std.
-            # Just make sure we don't return constant values
-            self.assertNotEqual(mps_out.to('cpu').mean(), 0.)
-            self.assertNotEqual(mps_out.to('cpu').std() ** 2, 0.)
+        prob_tensor = all_ones * 0.5
+        # p = 0.5 everywhere: each element is drawn as 1 with probability 0.5
+        mps_out = torch.bernoulli(prob_tensor)
+        # The exact mean/std of a random draw can't be checked reliably;
+        # just make sure the output is neither all zeros nor constant.
+        self.assertNotEqual(mps_out.to('cpu').mean(), 0.)
+        self.assertNotEqual(mps_out.to('cpu').std() ** 2, 0.)
 
-            mps_out = torch.zeros(shape, device='mps')
-            mps_out = torch.bernoulli(mps_out, prob)
+        # p = 0 everywhere: every draw must be exactly 0
+        mps_out = torch.bernoulli(all_zeros)
+        self.assertEqual(mps_out, all_zeros)
 
-            self.assertNotEqual(mps_out.to('cpu').mean(), 0.)
-            self.assertNotEqual(mps_out.to('cpu').std(), 0.)
-
-        helper((100, 100), 0.50)
-        helper((100, 100), 0.76)
-        helper((100, 100), 0.23)
+        # p = 1 everywhere: every draw must be exactly 1
+        mps_out = torch.bernoulli(all_ones)
+        self.assertEqual(mps_out, all_ones)
 
     # Test random_.to and random_.from
     def test_random(self):