[MPS] Fix `torch.full` for boolean types (#82575)

By creating an int8 tensor and casting it to bool afterwards

Workaround for an MPSGraph deficiency reported in https://github.com/pytorch/pytorch/issues/82427
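
For reference, a minimal repro of the failure this patch works around (assumes an MPS-enabled build running on macOS 12.3/12.4):

    import torch

    # Before this change, filling a bool tensor on MPS crashed inside
    # MPSGraph's constantWithScalar; with the int8-then-cast workaround
    # it returns a bool tensor of True values.
    x = torch.full((3, 3), True, device="mps")
    print(x.dtype)  # torch.bool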

Pull Request resolved: https://github.com/pytorch/pytorch/pull/82575
Approved by: https://github.com/kulinseth
diff --git a/aten/src/ATen/native/mps/operations/ConstantOps.mm b/aten/src/ATen/native/mps/operations/ConstantOps.mm
index a7f145f..0cfd7cc 100644
--- a/aten/src/ATen/native/mps/operations/ConstantOps.mm
+++ b/aten/src/ATen/native/mps/operations/ConstantOps.mm
@@ -34,12 +34,21 @@
         @autoreleasepool{
           MPSGraph *mpsGraph = make_mps_graph();
           newCachedGraph = new CachedGraph(mpsGraph);
-
+          auto isBool = self.scalar_type() == c10::ScalarType::Bool;
+          auto dataType = isBool ? MPSDataTypeInt8 : getMPSScalarType(self.scalar_type());
+          // constantWithScalar does not work for bool types on macOS 12.3/12.4
+          // work around this by filling the tensor as int8 and then casting it to bool
+          // See https://github.com/pytorch/pytorch/issues/82427
           MPSGraphTensor* inputTensor = [mpsGraph constantWithScalar:value.toDouble()
                                                                shape:getMPSShape(self)
-                                                            dataType:getMPSScalarType(self.scalar_type())];
+                                                            dataType:dataType];
           MPSGraphTensor* outputTensor = [mpsGraph identityWithTensor:inputTensor
                                                                  name:nil];
+          if (isBool) {
+            outputTensor = [mpsGraph castTensor:outputTensor
+                                         toType:MPSDataTypeBool
+                                           name:@"constWithBool-workaround"];
+          }
 
           newCachedGraph->outputTensor_ = outputTensor;
         }
diff --git a/test/test_mps.py b/test/test_mps.py
index c95165b..e05b055 100644
--- a/test/test_mps.py
+++ b/test/test_mps.py
@@ -1543,6 +1543,12 @@
             t_mps = t.to("mps")
             self.assertEqual(t, t_mps.cpu())
 
+    # See https://github.com/pytorch/pytorch/issues/82427
+    # Test should not crash
+    def test_bool_full(self):
+        x = torch.full((3, 3), True, device='mps')
+
+
 class TestLogical(TestCase):
     def _wrap_tensor(self, x, device="cpu", dtype=None, requires_grad=False):
         return torch.tensor(x, device=device, dtype=dtype, requires_grad=requires_grad)