Add Half (float16) support for multinomial on CPU (#104178)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/104178
Approved by: https://github.com/jgong5, https://github.com/kulinseth, https://github.com/cpuhrsch
diff --git a/test/inductor/test_torchinductor_opinfo.py b/test/inductor/test_torchinductor_opinfo.py
index 5205f87..ebd4a27 100644
--- a/test/inductor/test_torchinductor_opinfo.py
+++ b/test/inductor/test_torchinductor_opinfo.py
@@ -206,7 +206,7 @@
"geometric": {f16},
"log_normal": {f16},
"masked_scatter": {f16, f32, f64},
- "multinomial": {f32, f64},
+ "multinomial": {f16, f32, f64},
"nn.functional.avg_pool1d": {i64},
"nn.functional.avg_pool2d": {i64},
"nn.functional.local_response_norm": {i64},
diff --git a/test/test_mps.py b/test/test_mps.py
index efa6daf..e6da181 100644
--- a/test/test_mps.py
+++ b/test/test_mps.py
@@ -743,7 +743,7 @@
# Failures due to random output that they generate using
# Philox engine causing mismatch with CPU results
- 'multinomial': [torch.float32], # random results
+ 'multinomial': [torch.float16, torch.float32], # random results
'uniform': [torch.float16, torch.float32],
'rand_like': [torch.float16, torch.float32],
'randint_like': [torch.float16, torch.float32, torch.int16, torch.int32, torch.int64, torch.uint8, torch.int8],
diff --git a/test/test_torch.py b/test/test_torch.py
index 91b1cc4..f090b6b 100644
--- a/test/test_torch.py
+++ b/test/test_torch.py
@@ -5000,7 +5000,7 @@
# FIXME: move to test distributions
@skipIfMps
@dtypesIfCUDA(torch.float, torch.double, torch.half)
- @dtypes(torch.float, torch.double)
+ @dtypes(torch.float, torch.double, torch.half)
def test_multinomial(self, device, dtype):
def make_prob_dist(shape, is_contiguous):
if is_contiguous:
@@ -5370,7 +5370,7 @@
self._test_multinomial_empty(device, False, 2)
@dtypesIfCUDA(torch.float, torch.double, torch.half)
- @dtypesIfCPU(torch.float, torch.double, torch.bfloat16)
+ @dtypesIfCPU(torch.float, torch.double, torch.bfloat16, torch.half)
@dtypes(torch.float, torch.double)
def test_multinomial_cpu(self, device, dtype):
def make_prob_dist(shape, is_contiguous):