[MPS] Error on unsupported types (#95982)
I.e., attempt to create a tensor of every possible dtype and make sure that
it raises a structured error (`TypeError`) for dtypes that MPS does not support.
Also, rename `test_resize_as_all_dtypes_and_devices` to `test_resize_as_mps_dtypes` and `test_resize_all_dtypes_and_devices` to `test_resize_mps_dtypes`, and run both tests for all MPS dtypes (rather than only bool, float16, and bfloat16, as before).
Fixes https://github.com/pytorch/pytorch/issues/95976
Pull Request resolved: https://github.com/pytorch/pytorch/pull/95982
Approved by: https://github.com/kulinseth
diff --git a/test/test_mps.py b/test/test_mps.py
index 9949e9c..062a513 100644
--- a/test/test_mps.py
+++ b/test/test_mps.py
@@ -39,7 +39,7 @@
SpectralFuncInfo,
BinaryUfuncInfo,
)
-from torch.testing._internal.common_device_type import ops, instantiate_device_type_tests, onlyMPS
+from torch.testing._internal.common_device_type import ops, dtypes, instantiate_device_type_tests
from torch.testing._internal.common_nn import NNTestCase
import numpy as np
import torch
@@ -7988,15 +7988,15 @@
x.set_(x.storage(), 0, x.size(), stride)
self.assertTrue(x.is_contiguous())
- def test_resize_all_dtypes_and_devices(self, device="mps"):
+ def test_resize_mps_dtypes(self, device="mps"):
shape = (2, 2)
- for dt in (torch.half, torch.bfloat16, torch.bool):
+ for dt in MPS_DTYPES:
x = torch.tensor([[1, 2], [3, 4], [5, 6]], dtype=dt, device=device)
x.resize_(shape)
self.assertEqual(shape, x.shape)
- def test_resize_as_all_dtypes_and_devices(self, device="mps"):
- for dt in (torch.half, torch.bfloat16, torch.bool):
+ def test_resize_as_mps_dtypes(self, device="mps"):
+ for dt in MPS_DTYPES:
x = torch.tensor([[1, 2], [3, 4], [5, 6]], dtype=dt, device=device)
y = torch.tensor([[1, 2, 3], [4, 5, 6]], dtype=dt, device=device)
x.resize_as_(y)
@@ -10367,7 +10367,6 @@
# When MPS becomes more consistent, this can probably be merged with that test using
# `@dtypesIfMPS(torch.float32)`, but for now, the assertions themselves need to be loosened
@unittest.skipIf(TEST_WITH_ASAN, "Skipped under ASAN")
- @onlyMPS
@suppress_warnings
# MPS only supports float32
@ops(_ref_test_ops, allowed_dtypes=(torch.float32,))
@@ -10381,12 +10380,24 @@
for sample_input in inputs:
self.compare_with_reference(op, op.ref, sample_input)
+ @dtypes(*get_all_dtypes())
+ def test_tensor_creation(self, device, dtype):
+ def ones(device):
+ return torch.ones((2, 2), dtype=dtype, device=device)
+ if dtype not in MPS_DTYPES:
+ with self.assertRaises(TypeError):
+ ones(device)
+ else:
+ mps_tensor = ones(device)
+ cpu_tensor = ones("cpu")
+ self.assertEqual(mps_tensor.cpu(), cpu_tensor)
+
# TODO: Actually instantiate that test for the "mps" device to better reflect what it is doing.
# This requires mps to be properly registered in the device generic test framework which is not the
# case right now. We can probably use `allow_mps` introduced in https://github.com/pytorch/pytorch/pull/87342
# to achieve this.
instantiate_device_type_tests(TestConsistency, globals(), only_for="cpu")
-instantiate_device_type_tests(TestCommon, globals(), allow_mps=True)
+instantiate_device_type_tests(TestCommon, globals(), allow_mps=True, only_for="mps")
if __name__ == "__main__":
run_tests()