[MPS] Fix torch.full for uint8 (#83697)
By creating uint32 tensor and then downcasting it to uint8
Workaround https://github.com/pytorch/pytorch/issues/83692
Pull Request resolved: https://github.com/pytorch/pytorch/pull/83697
Approved by: https://github.com/albanD
diff --git a/aten/src/ATen/native/mps/operations/ConstantOps.mm b/aten/src/ATen/native/mps/operations/ConstantOps.mm
index 0cfd7cc..a5ddd82 100644
--- a/aten/src/ATen/native/mps/operations/ConstantOps.mm
+++ b/aten/src/ATen/native/mps/operations/ConstantOps.mm
@@ -35,11 +35,15 @@
MPSGraph *mpsGraph = make_mps_graph();
newCachedGraph = new CachedGraph(mpsGraph);
auto isBool = self.scalar_type() == c10::ScalarType::Bool;
- auto dataType = (!isBool) ? getMPSScalarType(self.scalar_type()) : MPSDataTypeInt8;
+ auto isUInt8 = self.scalar_type() == c10::ScalarType::Byte;
+ auto dataType = !isUInt8 ? !isBool ? getMPSScalarType(self.scalar_type()) : MPSDataTypeInt8 : MPSDataTypeUInt32;
// constantWithScalar does not work for boolTypes on MacOS-12.[34]
// workaround by filing it as int8 tensor and than casting to bool
// See https://github.com/pytorch/pytorch/issues/82427
- MPSGraphTensor* inputTensor = [mpsGraph constantWithScalar:value.toDouble()
+ // constantWithScalar does not work for UInt8 Types on MacOS-12.[34]/Ventura preview
+ // workaround by filing it as uint32 tensor and than casting to uint8
+ // See https://github.com/pytorch/pytorch/issues/83692
+ MPSGraphTensor* inputTensor = [mpsGraph constantWithScalar: value.toDouble()
shape:getMPSShape(self)
dataType:dataType];
MPSGraphTensor* outputTensor = [mpsGraph identityWithTensor:inputTensor
@@ -49,6 +53,11 @@
toType:MPSDataTypeBool
name:@"constWithBool-workaround"];
}
+ if (isUInt8) {
+ outputTensor = [mpsGraph castTensor:outputTensor
+ toType:MPSDataTypeUInt8
+ name:@"constWithUInt8-workaround"];
+ }
newCachedGraph->outputTensor_ = outputTensor;
}
diff --git a/test/test_mps.py b/test/test_mps.py
index 32aa2c1..d1403fc 100644
--- a/test/test_mps.py
+++ b/test/test_mps.py
@@ -1545,9 +1545,14 @@
self.assertEqual(t, t_mps.cpu())
# See https://github.com/pytorch/pytorch/issues/82427
- # Test should not crash
- def test_bool_full(self):
+ # and https://github.com/pytorch/pytorch/issues/83692
+ def test_full_bugs(self):
+ # Test should not crash
x = torch.full((3, 3), True, device='mps')
+ # torch.full should work for uint8
+ y_mps = torch.full((2, 2), 247, device='mps', dtype=torch.uint8)
+ y_cpu = torch.full((2, 2), 247, device='cpu', dtype=torch.uint8)
+ self.assertEqual(y_mps, y_cpu)
# See https://github.com/pytorch/pytorch/issues/82663
def test_bool_expand(self):