Add Half support for kthvalue, cross, histc, and logit on CPU (#112135)
Add Half (float16) support for kthvalue, cross (and linalg.cross), histc, and logit on CPU, and update the corresponding test skip and expected-failure lists.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/112135
Approved by: https://github.com/cpuhrsch
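
A minimal sketch of the half-precision CPU usage this change enables (values and shapes below are illustrative, not taken from the test suite), assuming a build that includes this patch:

    import torch

    x = torch.tensor([3.0, 1.0, 2.0], dtype=torch.half)  # Half on CPU

    # kthvalue: k-th smallest value along a dimension.
    val, idx = torch.kthvalue(x, k=2)            # val == 2.0, dtype float16

    # cross / linalg.cross: 3-element vector cross product.
    a = torch.tensor([1.0, 0.0, 0.0], dtype=torch.half)
    b = torch.tensor([0.0, 1.0, 0.0], dtype=torch.half)
    c = torch.linalg.cross(a, b)                 # [0., 0., 1.] in float16

    # histc: fixed-bin histogram over [min, max].
    h = torch.histc(x, bins=4, min=0, max=3)     # float16 bin counts

    # logit: inverse sigmoid, with optional eps clamping of the input.
    p = torch.tensor([0.25, 0.5, 0.75], dtype=torch.half)
    out = torch.logit(p, eps=1e-3)               # stays in float16 on CPU
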
diff --git a/test/inductor/test_torchinductor_opinfo.py b/test/inductor/test_torchinductor_opinfo.py
index 4a91a4e..a7ad96f 100644
--- a/test/inductor/test_torchinductor_opinfo.py
+++ b/test/inductor/test_torchinductor_opinfo.py
@@ -205,10 +205,13 @@
"cauchy": {f16},
"cholesky": {f32, f64},
"complex": {f16},
+ "cross": {f16},
"exponential": {f16},
"resize_": {b8, f16, f32, f64, i32, i64},
"resize_as_": {b8, f16, f32, f64, i32, i64},
"geometric": {f16},
+ "histc": {f16},
+ "linalg.cross": {f16},
"log_normal": {f16},
"masked_scatter": {f16, f32, f64},
"multinomial": {f16, f32, f64},
diff --git a/test/onnx/test_fx_op_consistency.py b/test/onnx/test_fx_op_consistency.py
index c23c4af..5a862e7 100644
--- a/test/onnx/test_fx_op_consistency.py
+++ b/test/onnx/test_fx_op_consistency.py
@@ -769,6 +769,7 @@
"nn.functional.batch_norm",
"native_batch_norm",
"dot",
+ "logit",
]
@common_device_type.ops(
diff --git a/test/test_meta.py b/test/test_meta.py
index 7eab797..ea845e4 100644
--- a/test/test_meta.py
+++ b/test/test_meta.py
@@ -628,10 +628,10 @@
torch.frexp : {f64, f16, bf16, f32},
torch.functional.unique : {f64, i32, i64, u8, i16, f16, bf16, b8, i8, f32},
torch.functional.unique_consecutive : {f64, i32, i64, u8, i16, f16, bf16, b8, i8, f32},
- torch.histc : {f64, bf16, f32},
+ torch.histc : {f64, f16, bf16, f32},
torch.histogram : {f64, f32},
torch.histogramdd : {f64, f32},
- torch.kthvalue : {f64, i32, i64, u8, i16, bf16, i8, f32},
+ torch.kthvalue : {f64, i32, i64, u8, i16, f16, bf16, i8, f32},
torch.nn.functional.ctc_loss : {f64, f32},
torch.nn.functional.gaussian_nll_loss : {f16, f64, bf16, f32},
torch.nn.functional.one_hot : {i64},
@@ -805,7 +805,7 @@
aten.histc.out : {bf16, f32, f64},
aten.histogram.bin_ct : {f32, f64},
aten.histogram.bins_tensor : {f32, f64},
- aten.kthvalue.default : {i8, f64, i64, bf16, f32, i32, i16, u8},
+ aten.kthvalue.default : {i8, f64, i64, f16, bf16, f32, i32, i16, u8},
aten.unique_consecutive.default : {i8, f64, i64, f16, bf16, f32, i32, b8, i16, u8},
aten.unique_dim.default : {i8, f64, i64, f16, bf16, f32, i32, b8, i16, u8},
aten.upsample_nearest3d.vec : {bf16, f32, f64, u8},
@@ -842,6 +842,8 @@
aten._native_batch_norm_legit.default: {bf16, f16},
aten._native_batch_norm_legit.no_stats: {bf16, f16},
aten.native_layer_norm.default: {bf16},
+ aten.histc.default: {f16},
+ aten.histc.out: {f16},
}
meta_dispatch_device_expected_failures['cuda'] = {
diff --git a/test/test_mps.py b/test/test_mps.py
index c89cc68..17954c8 100644
--- a/test/test_mps.py
+++ b/test/test_mps.py
@@ -700,6 +700,7 @@
# Unsupported dtypes
'dot': [torch.int64],
+ 'histc': [torch.float16],
'index_add': [torch.int64],
'log1p': [torch.int64],
'sigmoid': [torch.int64],
@@ -793,6 +794,8 @@
# Failures due to casting negative float to uint8 is undefined
'byte': [torch.float16, torch.float32],
+ # float output for float16 input on MPS
+ 'logit': [torch.float16],
}
EMPTY_OPS_SKIPLIST = {
@@ -10940,6 +10943,7 @@
'masked.softmin',
'nn.functional.kl_div',
'nn.functional.softmin',
+ 'cross', 'linalg.cross',
# for macOS 12
'masked.normalize', 'masked.sum', 'masked.var',
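
The 'logit' entry added in the hunk above records that, per the skip comment, MPS reportedly produces a float32 output for a float16 input. A hypothetical repro sketch (guarded on MPS availability; the expected dtype is taken from the skip comment, not verified here):

    import torch

    if torch.backends.mps.is_available():
        p = torch.rand(4, dtype=torch.half, device="mps")
        out = torch.logit(p)
        # Per the skip comment: torch.float32 rather than torch.float16,
        # which is why the f16 case is excluded on MPS.
        print(out.dtype)
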
diff --git a/test/test_reductions.py b/test/test_reductions.py
index 1206ab8..f755267 100644
--- a/test/test_reductions.py
+++ b/test/test_reductions.py
@@ -2982,13 +2982,14 @@
test_against_np(linear, bins=20, min=0, max=0.99)
@onlyCPU
- def test_histc_bfloat16(self, device):
+ @dtypes(torch.bfloat16, torch.half)
+ def test_histc_lowp(self, device, dtype):
actual = torch.histc(
- torch.tensor([1, 2, 1], dtype=torch.bfloat16, device=device), bins=4, min=0, max=3)
+ torch.tensor([1, 2, 1], dtype=dtype, device=device), bins=4, min=0, max=3)
self.assertEqual(
- torch.tensor([0, 2, 1, 0], dtype=torch.bfloat16, device=device),
+ torch.tensor([0, 2, 1, 0], dtype=dtype, device=device),
actual)
- self.assertEqual(actual.dtype, torch.bfloat16)
+ self.assertEqual(actual.dtype, dtype)
"""
Runs torch.histogram and numpy.histogram on the specified input parameters
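
To exercise the renamed reductions test locally, one option (assuming a source checkout with pytest available) is:

    pytest test/test_reductions.py -k test_histc_lowp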