# Owner(s): ["module: internals"]
import torch
from torch.testing._internal.common_utils import run_tests, TestCase
| class TestComparisonUtils(TestCase): |
| def test_all_equal_no_assert(self): |
| torch._assert_tensor_metadata(t, [1], [1], torch.float) |
| def test_all_equal_no_assert_nones(self): |
| torch._assert_tensor_metadata(t, None, None, None) |
| def test_assert_dtype(self): |
| with self.assertRaises(RuntimeError): |
| torch._assert_tensor_metadata(t, None, None, torch.int32) |
| def test_assert_strides(self): |
| with self.assertRaises(RuntimeError): |
| torch._assert_tensor_metadata(t, None, [3], torch.float) |
| def test_assert_sizes(self): |
| with self.assertRaises(RuntimeError): |
| torch._assert_tensor_metadata(t, [3], [1], torch.float) |
if __name__ == "__main__":
    # Run this module's tests through PyTorch's shared test runner when the
    # file is executed directly (the guard previously had no body, which is
    # a syntax error).
    run_tests()