[MPS] Add histogram ops (#96652)
Adds MPS backend support for `torch.histc`, `torch.histogram`, and `torch.histogramdd`.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/96652
Approved by: https://github.com/kulinseth, https://github.com/malfet
diff --git a/test/test_mps.py b/test/test_mps.py
index d8ff872..d6e0564 100644
--- a/test/test_mps.py
+++ b/test/test_mps.py
@@ -89,6 +89,10 @@
'floor_divide': [torch.float16, torch.float32],
# derivative for aten::narrow_copy is not implemented on CPU
'narrow_copy': [torch.float16, torch.float32],
+ # derivative for aten::_histogramdd_from_bin_cts is not implemented on CPU
+ 'histogramdd': [torch.float16, torch.float32],
+ # derivative for aten::histogram is not implemented
+ 'histogram': [torch.float16, torch.float32],
# 'bool' object is not iterable
'allclose': [torch.float16, torch.float32],
'equal': [torch.float16, torch.float32],
@@ -409,9 +413,6 @@
'geqrf': None,
'nn.functional.grid_sample': None, # Unsupported Border padding mode
'heaviside': None,
- 'histc': None,
- 'histogram': None,
- 'histogramdd': None,
'i0': None,
'igamma': None,
'igammac': None,
@@ -10244,11 +10245,11 @@
self.assertEqual(out, "")
def _get_not_implemented_op(self):
- # This can be changed once we actually implement `torch.histc`
+ # This can be changed once we actually implement `torch.lgamma`
# Should return fn, args, kwargs, string_version
- return (torch.histc,
+ return (torch.lgamma,
torch.tensor([100], device='mps'), {},
- "torch.histc(torch.tensor([4], device='mps', dtype=torch.float))")
+ "torch.lgamma(torch.tensor([4], device='mps', dtype=torch.float))")
def test_error_on_not_implemented(self):
fn, args, kwargs, _ = self._get_not_implemented_op()