Revert "remove torch.equal usages (#89527)"
This reverts commit 4095ef8b809f922f2e0e09011afd00037d20a771.
Reverted https://github.com/pytorch/pytorch/pull/89527 on behalf of https://github.com/clee2000 because it broke the periodic multigpu tests; see https://hud.pytorch.org/pytorch/pytorch/commit/4095ef8b809f922f2e0e09011afd00037d20a771 and https://github.com/pytorch/pytorch/actions/runs/3592806602/jobs/6049368502
diff --git a/test/test_mps.py b/test/test_mps.py
index a90e6c3..77c2a34 100644
--- a/test/test_mps.py
+++ b/test/test_mps.py
@@ -1650,7 +1650,7 @@
def test_bool_expand(self):
x = torch.tensor([[1], [0]], dtype=torch.bool, device='mps')
y = torch.tensor([0, 1], dtype=torch.bool, device='mps')
- self.assertNotEqual(x.expand(2, 2), y.expand(2, 2), rtol=0, atol=0, exact_device=True)
+ self.assertFalse(torch.equal(x.expand(2, 2), y.expand(2, 2)))
# Empty unary op should return tensor of the same size
def test_empty_neg(self):
@@ -5043,7 +5043,7 @@
# see https://github.com/pytorch/pytorch/issues/79835#issuecomment-1164984534
x = torch.ones(4, dtype=torch.int32, device='mps')
self.assertEqual(x + 1, torch.full((4,), 2, dtype=torch.int32, device='mps'))
- self.assertEqual(x + 1.5, torch.full((4,), 2.5, device='mps'), rtol=0, atol=0, exact_device=True)
+ self.assertTrue(torch.equal(x + 1.5, torch.full((4,), 2.5, device='mps')))
def test_types_binary_op(self):
# Float * Bool