Add Half support for kthvalue, cross, hist, and logit on CPU (#112135)
Add Half support for kthvalue, cross, hist, and logit on CPU.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/112135
Approved by: https://github.com/cpuhrsch
diff --git a/test/test_mps.py b/test/test_mps.py
index c89cc68..17954c8 100644
--- a/test/test_mps.py
+++ b/test/test_mps.py
@@ -700,6 +700,7 @@
# Unsupported dtypes
'dot': [torch.int64],
+ 'histc': [torch.float16],
'index_add': [torch.int64],
'log1p': [torch.int64],
'sigmoid': [torch.int64],
@@ -793,6 +794,8 @@
# Failures due to casting negative float to uint8 is undefined
'byte': [torch.float16, torch.float32],
+ # On MPS, logit produces a float output for a float16 input
+ 'logit': [torch.float16],
}
EMPTY_OPS_SKIPLIST = {
@@ -10940,6 +10943,7 @@
'masked.softmin',
'nn.functional.kl_div',
'nn.functional.softmin',
+ 'cross', 'linalg.cross',
# for macOS 12
'masked.normalize', 'masked.sum', 'masked.var',