Remove debug prints and add proper asserts in MPS tests

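Removes leftover debug print statements throughout test_mps.py and
replaces the print-based spot checks with real assertions: the bernoulli
and randint tests now assert that the sampled output is not constant
(the exact mean/std can't be checked reliably), and the LSTM tests
compare MPS outputs and gradients against CPU results directly. The
backward comparison in test_lstm_2 is skipped for now because the LSTM
backward returns wrong results on MPS.

The pattern used for the stochastic ops, as a minimal sketch (the shape
and probability here are illustrative, and it assumes an MPS-capable
machine):

    import torch

    # Sample on MPS, then move to the CPU for inspection.
    out = torch.bernoulli(torch.full((100, 100), 0.5, device="mps"))
    # The exact mean/std of a random sample can't be pinned down, but a
    # constant output would have zero std, so assert it is non-zero.
    assert out.to("cpu").std() != 0.
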
Pull Request resolved: https://github.com/pytorch/pytorch/pull/78454

Approved by: https://github.com/kulinseth
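
The LSTM comparison follows this cross-device shape, sketched here under
the assumption of seeded, CPU-initialized weights so both devices start
from identical parameters (the seed and the use of
torch.testing.assert_close are illustrative, not the exact test code):

    import torch
    import torch.nn as nn

    def run(device):
        torch.manual_seed(0)               # identical init for both runs
        rnn = nn.LSTM(1, 4, 1).to(device)  # build on CPU, then move
        inp = torch.randn(2, 3, 1).to(device).requires_grad_()
        hx = torch.zeros(1, 3, 4, device=device)
        cx = torch.zeros(1, 3, 4, device=device)
        out, _ = rnn(inp, (hx, cx))
        out.sum().backward()
        return out, inp.grad, rnn.weight_ih_l0.grad

    for c, m in zip(run("cpu"), run("mps")):
        torch.testing.assert_close(c, m.cpu())
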
diff --git a/test/test_mps.py b/test/test_mps.py
index 94e084f..f49e1fb 100644
--- a/test/test_mps.py
+++ b/test/test_mps.py
@@ -28,7 +28,6 @@
     TestCase = object  # noqa: F811
     NNTestCase = object  # noqa: F811
 
-
 class MPSReluTest(TestCase):
     def _npRelu(self, np_features):
         return np.maximum(np_features, np.zeros(np_features.shape)).astype(np_features.dtype)
@@ -218,7 +217,6 @@
     def test_exp1(self, device="mps", dtype=torch.float):
         input = torch.tensor([-0.1, 3.0, -0.9]).to('mps')
         output = torch.exp(input).to('cpu')
-        print(output)
 
     def _testLeakyRelu(self, np_features, negative_slope, device):
         cpu_x = torch.from_numpy(np_features).requires_grad_()
@@ -508,7 +506,6 @@
             net = torch.nn.AdaptiveAvgPool2d((1, 1))
             out = net(x)
             ref_out = x.contiguous().mean((-1, -2)).view((x.size(0), x.size(1), 1, 1))
-            print(ref_out)
 
             out.sum().backward()    # make sure it doesn't crash
 
@@ -1196,8 +1193,6 @@
 
         cpu_x.transpose_(0, 1)
         mps_x.transpose_(0, 1)
-        print(cpu_x)
-        print(mps_x.to('cpu'))
         self.assertEqual(cpu_x, mps_x.to('cpu'))
 
     def test_slice(self):
@@ -1518,7 +1513,6 @@
 
             all_sum.backward()
             all_sum_cpu.backward()
-            print(torch.ones(1, device="mps").expand(10).clone())
             self.assertEqual(all_sum, all_sum_cpu)
             self.assertEqual(x.grad, cpu_x.grad)
 
@@ -3396,8 +3390,6 @@
             cpu_idx = torch.tensor(index, device='cpu', dtype=idx_dtype)
             idx = cpu_idx.detach().clone().to('mps')
 
-            print(cpu_idx.shape)
-
             idx_result = torch.index_select(x, dim=dim, index=idx)
             idx_result_cpu = torch.index_select(cpu_x, dim=dim, index=cpu_idx)
 
@@ -3826,15 +3818,16 @@
             prob_tensor = cpu_prob_tensor.detach().clone().to('mps')
 
             mps_out = torch.bernoulli(prob_tensor)
-            # Compare "real" with theoretical values
-            print(mps_out.to('cpu').mean(), prob)
-            print(mps_out.to('cpu').std() ** 2, prob * (1 - prob))
+            # We can't reliably check the mean and std against their
+            # theoretical values, so just assert the output isn't constant.
+            self.assertNotEqual(mps_out.to('cpu').mean(), 0.)
+            self.assertNotEqual(mps_out.to('cpu').std(), 0.)
 
             mps_out = torch.zeros(shape, device='mps')
             mps_out = torch.bernoulli(mps_out, prob)
 
-            print(mps_out.to('cpu').mean(), prob)
-            print(mps_out.to('cpu').std() ** 2, prob * (1 - prob))
+            self.assertNotEqual(mps_out.to('cpu').mean(), 0.)
+            self.assertNotEqual(mps_out.to('cpu').std(), 0.)
 
         helper((100, 100), 0.50)
         helper((100, 100), 0.76)
@@ -3844,11 +3837,12 @@
     def test_random(self):
         def helper(shape, low, high, dtype=torch.int32):
 
-            print(low, high)
             mps_out = torch.randint(low, high, shape, dtype=dtype, device='mps')
 
-            print(mps_out.to('cpu').float().mean(), (low + (high - 1)) / 2.)
-            print(mps_out.to('cpu').float().std() ** 2, ((high - low)**2 - 1) / 12.)
+            # We can't reliably check the mean and std against their
+            # theoretical values, so just assert the output isn't constant.
+            self.assertNotEqual(mps_out.to('cpu').float().mean(), 0.)
+            self.assertNotEqual(mps_out.to('cpu').float().std(), 0.)
 
         helper([100, 100], 0, 10)
         helper([100, 100], 23, 89)
@@ -3886,8 +3880,6 @@
             cpu_out = torch.add(cpu_x, cpu_y, alpha=alpha)
             out = torch.add(x, y, alpha=alpha)
 
-            print(out.to('cpu'))
-
             self.assertEqual(out, cpu_out)
 
         helper()
@@ -4166,35 +4158,42 @@
         input = torch.randn(2, 3, 1, device="cpu")
         hx = torch.zeros(2, 3, 4, device="cpu")
         cx = torch.zeros(2, 3, 4, device="cpu")
-        outputs = []
-        for device in [torch.device("cpu"), torch.device("mps")]:
-            rnn = rnn.to(device)
-            input = input.to(device)
-            hx = hx.to(device)
-            cx = cx.to(device)
-            weight_list = []
-            output, _ = rnn(input, (hx, cx))
-            print(output.to('cpu'))
 
+        cpu_output, _ = rnn(input, (hx, cx))
+
+        device = torch.device("mps")
+        rnn = rnn.to(device)
+        input = input.to(device)
+        hx = hx.to(device)
+        cx = cx.to(device)
+        output, _ = rnn(input, (hx, cx))
+        self.assertEqual(cpu_output, output)
+
+    @unittest.skip("Backward of LSTM returns wrong result on MPS")
     def test_lstm_2(self, device="mps", dtype=torch.float32):
-        rnn = nn.LSTM(1, 4, 1, device="cpu")
-        input = torch.randn(2, 3, 1, device="cpu", requires_grad=True)
-        hx = torch.zeros(1, 3, 4, device="cpu")
-        cx = torch.zeros(1, 3, 4, device="cpu")
-        outputs = []
-        for device in [torch.device("cpu"), torch.device("mps")]:
-            rnn = rnn.to(device)
-            input = input.to(device)
-            input.retain_grad()
-            hx = hx.to(device)
-            cx = cx.to(device)
+        def get_results(device):
+            # Seed so both runs start from identical weights and inputs;
+            # sample on the CPU and move, since CPU and MPS RNG streams differ.
+            torch.manual_seed(42)
+            rnn = nn.LSTM(1, 4, 1, device="cpu").to(device)
+            inp = torch.randn(2, 3, 1).to(device).requires_grad_()
+            hx = torch.zeros(1, 3, 4, device=device)
+            cx = torch.zeros(1, 3, 4, device=device)
 
-            output, _ = rnn(input, (hx, cx))
-            # Test by passing ones as the gradient from the loss.
-            output.backward(torch.ones_like(output))
+            output, _ = rnn(inp, (hx, cx))
+            output.sum().backward()
 
-            print(rnn.weight_ih_l0.grad)
-            # Gradient on GPU is 2x the CPU gradient???
+            weight_grad = rnn.weight_ih_l0.grad.clone()
+            input_grad = inp.grad.clone()
+
+            return output, weight_grad, input_grad
+
+        cpu_output, cpu_weight_grad, cpu_input_grad = get_results("cpu")
+        mps_output, mps_weight_grad, mps_input_grad = get_results("mps")
+
+        self.assertEqual(cpu_output, mps_output)
+        self.assertEqual(cpu_input_grad, mps_input_grad)
+        self.assertEqual(cpu_weight_grad, mps_weight_grad)
 
 class TestFallbackWarning(TestCase):
     def test_no_warning_on_import(self):