[MPS] Fix `torch.full` for boolean types (#82575)
By creating int8 tensor and casting it to bool later
Workaround for MPSGraph deficiency reported in https://github.com/pytorch/pytorch/issues/82427
Pull Request resolved: https://github.com/pytorch/pytorch/pull/82575
Approved by: https://github.com/kulinseth
diff --git a/aten/src/ATen/native/mps/operations/ConstantOps.mm b/aten/src/ATen/native/mps/operations/ConstantOps.mm
index a7f145f..0cfd7cc 100644
--- a/aten/src/ATen/native/mps/operations/ConstantOps.mm
+++ b/aten/src/ATen/native/mps/operations/ConstantOps.mm
@@ -34,12 +34,21 @@
@autoreleasepool{
MPSGraph *mpsGraph = make_mps_graph();
newCachedGraph = new CachedGraph(mpsGraph);
-
+ auto isBool = self.scalar_type() == c10::ScalarType::Bool;
+ auto dataType = (!isBool) ? getMPSScalarType(self.scalar_type()) : MPSDataTypeInt8;
+ // constantWithScalar does not work for boolTypes on MacOS-12.[34]
+ // workaround by filing it as int8 tensor and than casting to bool
+ // See https://github.com/pytorch/pytorch/issues/82427
MPSGraphTensor* inputTensor = [mpsGraph constantWithScalar:value.toDouble()
shape:getMPSShape(self)
- dataType:getMPSScalarType(self.scalar_type())];
+ dataType:dataType];
MPSGraphTensor* outputTensor = [mpsGraph identityWithTensor:inputTensor
name:nil];
+ if (isBool) {
+ outputTensor = [mpsGraph castTensor:outputTensor
+ toType:MPSDataTypeBool
+ name:@"constWithBool-workaround"];
+ }
newCachedGraph->outputTensor_ = outputTensor;
}
diff --git a/test/test_mps.py b/test/test_mps.py
index c95165b..e05b055 100644
--- a/test/test_mps.py
+++ b/test/test_mps.py
@@ -1543,6 +1543,12 @@
t_mps = t.to("mps")
self.assertEqual(t, t_mps.cpu())
+ # See https://github.com/pytorch/pytorch/issues/82427
+ # Test should not crash
+ def test_bool_full(self):
+ x = torch.full((3, 3), True, device='mps')
+
+
class TestLogical(TestCase):
def _wrap_tensor(self, x, device="cpu", dtype=None, requires_grad=False):
return torch.tensor(x, device=device, dtype=dtype, requires_grad=requires_grad)