Add Half for atan2, logaddexp, logaddexp2, hypot, and nextafter on CPU (#112138)
Add Half support for atan2, logaddexp, logaddexp2, hypot, and nextafter on CPU.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/112138
Approved by: https://github.com/cpuhrsch
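
With this change, all five ops accept torch.float16 tensors on CPU, where
previously they raised a "not implemented for 'Half'" RuntimeError. A minimal
usage sketch, assuming a build that includes this patch:

    import torch

    a = torch.tensor([1.0, 2.0, 3.0], dtype=torch.float16)
    b = torch.tensor([4.0, 5.0, 6.0], dtype=torch.float16)

    print(torch.atan2(a, b))       # elementwise arctangent of a/b
    print(torch.logaddexp(a, b))   # log(exp(a) + exp(b)), computed stably
    print(torch.logaddexp2(a, b))  # log2(2**a + 2**b)
    print(torch.hypot(a, b))       # sqrt(a**2 + b**2)
    print(torch.nextafter(a, b))   # next representable half after a, toward b
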
diff --git a/test/onnx/test_op_consistency.py b/test/onnx/test_op_consistency.py
index d64179a..83e4232 100644
--- a/test/onnx/test_op_consistency.py
+++ b/test/onnx/test_op_consistency.py
@@ -100,7 +100,10 @@
"atan2", dtypes=onnx_test_common.BOOL_TYPES + onnx_test_common.INT_TYPES,
reason=onnx_test_common.reason_onnx_does_not_support("Atan")
),
- xfail("atan2", dtypes=[torch.float64], reason=onnx_test_common.reason_onnx_runtime_does_not_support("Atan", ["f64"])),
+ xfail(
+ "atan2", dtypes=[torch.float64, torch.float16],
+ reason=onnx_test_common.reason_onnx_runtime_does_not_support("Atan", ["f64", "f16"])
+ ),
xfail(
"ceil", dtypes=onnx_test_common.BOOL_TYPES + onnx_test_common.INT_TYPES,
reason=onnx_test_common.reason_onnx_does_not_support("Ceil")
diff --git a/test/test_binary_ufuncs.py b/test/test_binary_ufuncs.py
index 53cecb3..2f22569 100644
--- a/test/test_binary_ufuncs.py
+++ b/test/test_binary_ufuncs.py
@@ -2801,7 +2801,7 @@
abs(c[0] - d[0]) == abs(b[0])
) # differ by one divisor
- @dtypesIfCPU(torch.bfloat16, torch.float32, torch.float64)
+ @dtypesIfCPU(torch.bfloat16, torch.half, torch.float32, torch.float64)
@dtypes(torch.float32, torch.float64)
def test_hypot(self, device, dtype):
inputs = [
@@ -2824,7 +2824,7 @@
]
for input in inputs:
actual = torch.hypot(input[0], input[1])
- if dtype == torch.bfloat16:
+ if dtype in [torch.bfloat16, torch.half]:
expected = torch.sqrt(input[0] * input[0] + input[1] * input[1])
else:
expected = np.hypot(input[0].cpu().numpy(), input[1].cpu().numpy())
@@ -2873,6 +2873,7 @@
self.assertEqual(actual, expected, exact_dtype=False)
@onlyNativeDeviceTypes
+ @dtypesIfCPU(torch.float32, torch.float64, torch.float16)
@dtypes(torch.float32, torch.float64)
def test_nextafter(self, device, dtype):
# Test special cases
@@ -3804,12 +3805,19 @@
)
self.assertEqual(expected, actual.view(-1), rtol=0, atol=0.02)
- # bfloat16
- a_bf16 = a.bfloat16()
- b_bf16 = b.bfloat16()
- actual_bf16 = a_bf16.atan2(b_bf16)
- self.assertEqual(actual_bf16, actual.bfloat16())
- self.assertEqual(expected, actual_bf16.view(-1), exact_dtype=False, rtol=0, atol=0.02)
+ # bfloat16/float16
+ for lowp_dtype in [torch.bfloat16, torch.float16]:
+ if lowp_dtype == torch.bfloat16:
+ rtol = 0
+ atol = 0.02
+ else:
+ rtol = 0
+ atol = 0.001
+ a_16 = a.to(dtype=lowp_dtype)
+ b_16 = b.to(dtype=lowp_dtype)
+ actual_16 = a_16.atan2(b_16)
+ self.assertEqual(actual_16, actual.to(dtype=lowp_dtype))
+ self.assertEqual(expected, actual_16.view(-1), exact_dtype=False, rtol=rtol, atol=atol)
_test_atan2_with_size((2, 2), device)
_test_atan2_with_size((3, 3), device)
diff --git a/test/test_mps.py b/test/test_mps.py
index 817e11e..0f2be10 100644
--- a/test/test_mps.py
+++ b/test/test_mps.py
@@ -417,7 +417,7 @@
'cdist': [torch.float32],
# CPU Error: cpu not giving nan for x/0.0
- 'atan2': [torch.bool, torch.float16, torch.int16, torch.int32, torch.int64, torch.uint8, torch.int8],
+ 'atan2': [torch.bool, torch.int16, torch.int32, torch.int64, torch.uint8, torch.int8],
        # tests below pass on macOS 12 as they fall back to cpu
# Argsort case using duplicate indices (undefined behaviour):
@@ -10946,6 +10946,7 @@
'nn.functional.softmin',
'cross', 'linalg.cross',
'prod', 'masked.prod',
+ 'nextafter',
# for macOS 12
'masked.normalize', 'masked.sum', 'masked.var',
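
Side note on the nextafter coverage enabled above: in float16 the step is one
half-precision ULP, which is 2**-10 at 1.0. A minimal sketch checking that,
again assuming a CPU build with this patch:

    import torch

    # The ULP of 1.0 in float16 is 2**-10 = 0.0009765625, so nextafter
    # advances by exactly that amount.
    one = torch.tensor([1.0], dtype=torch.float16)
    two = torch.tensor([2.0], dtype=torch.float16)
    print(torch.nextafter(one, two).item() - 1.0)  # 0.0009765625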