[MPS] Fix MPSGraph casting issue to MPSDataTypeBool in masked_fill op (#94263)
Fixes TestConsistency masked_fill for bool data type.
Casting a value greater than 1 to MPSDataTypeBool results in 0 (false) instead of 1 (true). This change therefore normalizes the fill scalar to 0 or 1 manually whenever the input tensor is boolean, instead of relying on MPSGraph's cast:
```
(inputDataType == MPSDataTypeBool) ? !!value.to<double>() : value.to<double>()
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/94263
Approved by: https://github.com/razarmehr
diff --git a/test/test_mps.py b/test/test_mps.py
index 65331e9..9ecaa30 100644
--- a/test/test_mps.py
+++ b/test/test_mps.py
@@ -8379,7 +8379,7 @@
'logical_xor': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
'logspace': ['f32', 'i16', 'i32', 'i64', 'u8'],
'logsumexp': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
- 'masked_fill': ['f16', 'i16', 'i32', 'i64'],
+ 'masked_fill': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
'masked_select': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
'matmul': ['f32'],
'mm': ['f32'],
@@ -8496,7 +8496,7 @@
'vsplit': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
'vstack': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
'zero_': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
- 'where': ['f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
+ 'where': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
'nonzero': ['f32', 'i16', 'i32', 'i64'],
'cross': ['f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
'linalg.cross': ['f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
@@ -8911,6 +8911,9 @@
op.name == "masked.sum" or op.name == "masked.std" or op.name == "masked.var") and dtype == torch.float16:
atol = 1e-2
rtol = 1e-2
+ elif (op.name == "masked.mean"):
+ atol = 7e-4
+ rtol = 2e-3
else:
atol = None
rtol = None