[MPS] Add histogram ops (#96652)

Adds MPS implementations of `torch.histc`, `torch.histogram`, and `torch.histogramdd`

Pull Request resolved: https://github.com/pytorch/pytorch/pull/96652
Approved by: https://github.com/kulinseth, https://github.com/malfet
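
For reference, a minimal usage sketch of the ops this enables on MPS (assumes a build that includes this change and an available MPS device; values are illustrative):

```python
import torch

x = torch.rand(1000, device="mps")

# histc: counts over equal-width bins spanning [min, max]
counts = torch.histc(x, bins=10, min=0.0, max=1.0)

# histogram: returns both the counts and the bin edges
hist, bin_edges = torch.histogram(x, bins=10, range=(0.0, 1.0))

# histogramdd: bins D-dimensional samples (innermost dim holds the coordinates)
samples = torch.rand(1000, 2, device="mps")
hist2d, edges = torch.histogramdd(samples, bins=[5, 5])
```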
diff --git a/aten/src/ATen/mps/IndexKernels.h b/aten/src/ATen/mps/IndexKernels.h
index 4be40ca..a845e5f 100644
--- a/aten/src/ATen/mps/IndexKernels.h
+++ b/aten/src/ATen/mps/IndexKernels.h
@@ -224,6 +224,22 @@
     }
 }
 
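+// Computes each thread's element offset into a strided tensor by decomposing its linear index over iter_shape.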
+kernel void kernel_index_offset(constant uint         * strides         [[buffer(0)]],
+                                device uint           * data_offsets    [[buffer(1)]],
+                                constant uint         * iter_shape      [[buffer(2)]],
+                                constant uint         & num_dimensions  [[buffer(3)]],
+                                uint thread_index [[thread_position_in_grid]]) {
+    uint32_t idx = thread_index;
+    for (uint32_t dim = 0; dim < num_dimensions; dim++) {
+        uint32_t reversed_dim = num_dimensions - dim - 1;
+        uint32_t remainder = idx % iter_shape[reversed_dim];
+        idx /= iter_shape[reversed_dim];
+
+        data_offsets[thread_index] += remainder * strides[reversed_dim];
+    }
+}
+
 template<typename T, typename E>
 kernel void index_put_accumulate_native_dtypes(
 #if __METAL_VERSION__ >= 300
diff --git a/aten/src/ATen/native/Histogram.cpp b/aten/src/ATen/native/Histogram.cpp
index 116e91a..2a22650 100644
--- a/aten/src/ATen/native/Histogram.cpp
+++ b/aten/src/ATen/native/Histogram.cpp
@@ -16,11 +16,13 @@
 #include <ATen/ops/_histogramdd_from_bin_tensors.h>
 #include <ATen/ops/_histogramdd_from_bin_tensors_native.h>
 #include <ATen/ops/aminmax.h>
+#include <ATen/ops/amin.h>
+#include <ATen/ops/amax.h>
 #include <ATen/ops/empty.h>
 #include <ATen/ops/histc_native.h>
 #include <ATen/ops/histogram_native.h>
 #include <ATen/ops/histogramdd_native.h>
-#include <ATen/ops/linspace_native.h>
+#include <ATen/ops/linspace.h>
 #endif
 
 #include <numeric>
@@ -193,7 +195,18 @@
     } else if (input.numel() > 0) {
         // non-empty input
         AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "histogramdd", [&]() {
-            infer_bin_edges_from_input<scalar_t>(input, N, leftmost_edges, rightmost_edges);
+            if (input.is_mps()) {
+                // aminmax is not yet implemented on MPS.
+                Tensor min = at::amin(input, 0);
+                Tensor max = at::amax(input, 0);
+
+                for (const auto i : c10::irange(N)) {
+                    leftmost_edges[i] = min[i].item().to<scalar_t>();
+                    rightmost_edges[i] = max[i].item().to<scalar_t>();
+                }
+            } else {
+                infer_bin_edges_from_input<scalar_t>(input, N, leftmost_edges, rightmost_edges);
+            }
         });
     }
 
@@ -226,9 +239,18 @@
     double rightmost_edge = max.to<double>();
 
     if (leftmost_edge == rightmost_edge && input.numel() > 0) {
-        auto extrema = aminmax(input);
-        leftmost_edge = std::get<0>(extrema).item<double>();
-        rightmost_edge = std::get<1>(extrema).item<double>();
+        if (input.is_mps()) {
+            // aminmax is not yet implemented on MPS.
+            Tensor min = at::amin(input);
+            Tensor max = at::amax(input);
+
+            leftmost_edge = min.item<double>();
+            rightmost_edge = max.item<double>();
+        } else {
+            auto extrema = aminmax(input);
+            leftmost_edge = std::get<0>(extrema).item<double>();
+            rightmost_edge = std::get<1>(extrema).item<double>();
+        }
     }
 
     if (leftmost_edge == rightmost_edge) {
@@ -259,7 +281,7 @@
 
 /* Versions of histogramdd in which bins is a Tensor[] defining the sequences of bin edges.
  */
-Tensor& histogramdd_out_cpu(const Tensor& self, TensorList bins,
+Tensor& histogramdd_out(const Tensor& self, TensorList bins,
         const c10::optional<Tensor>& weight, bool density,
         Tensor& hist, TensorList& bin_edges) {
     histogramdd_check_inputs(self, bins, weight);
@@ -273,20 +295,20 @@
     return hist;
 }
 
-Tensor histogramdd_cpu(const Tensor& self, TensorList bins,
+Tensor _histogramdd(const Tensor& self, TensorList bins,
         const c10::optional<Tensor>& weight, bool density) {
     Tensor hist = at::empty({0}, self.options(), MemoryFormat::Contiguous);
     std::vector<Tensor> bin_edges_out = allocate_bin_edges_tensors(self);
     TensorList bin_edges_out_tl(bin_edges_out);
 
-    histogramdd_out_cpu(self, bins, weight, density, hist, bin_edges_out_tl);
+    histogramdd_out(self, bins, weight, density, hist, bin_edges_out_tl);
     return hist;
 }
 
 /* Versions of histogramdd in which bins is an int[]
  * defining the number of bins in each dimension.
  */
-std::vector<Tensor>& histogramdd_bin_edges_out_cpu(const Tensor& self, IntArrayRef bin_ct,
+std::vector<Tensor>& histogramdd_bin_edges_out(const Tensor& self, IntArrayRef bin_ct,
         c10::optional<c10::ArrayRef<double>> range,
         const c10::optional<Tensor>& weight, bool density,
         std::vector<Tensor>& bin_edges_out) {
@@ -304,25 +326,25 @@
         N == bin_size,
         "histogramdd: The size of bins must be equal to the innermost dimension of the input.");
     for (const auto dim : c10::irange(N)) {
-        linspace_out(outer_bin_edges.first[dim], outer_bin_edges.second[dim],
-                bin_ct[dim] + 1, bin_edges_out[dim]);
+        at::linspace_out(bin_edges_out[dim], outer_bin_edges.first[dim], outer_bin_edges.second[dim],
+                bin_ct[dim] + 1);
     }
 
     return bin_edges_out;
 }
 
-std::vector<Tensor> histogramdd_bin_edges_cpu(const Tensor& self, IntArrayRef bin_ct,
+std::vector<Tensor> histogramdd_bin_edges(const Tensor& self, IntArrayRef bin_ct,
         c10::optional<c10::ArrayRef<double>> range,
         const c10::optional<Tensor>& weight, bool density) {
     std::vector<Tensor> bin_edges_out = allocate_bin_edges_tensors(self);
-    return histogramdd_bin_edges_out_cpu(self, bin_ct, range, weight, density, bin_edges_out);
+    return histogramdd_bin_edges_out(self, bin_ct, range, weight, density, bin_edges_out);
 }
 
-Tensor& histogramdd_out_cpu(const Tensor& self, IntArrayRef bin_ct,
+Tensor& histogramdd_out(const Tensor& self, IntArrayRef bin_ct,
         c10::optional<c10::ArrayRef<double>> range,
         const c10::optional<Tensor>& weight, bool density,
         Tensor& hist, TensorList& bin_edges) {
-    std::vector<Tensor> bins = histogramdd_bin_edges_cpu(self, bin_ct, range, weight, density);
+    std::vector<Tensor> bins = histogramdd_bin_edges(self, bin_ct, range, weight, density);
 
     histogramdd_check_inputs(self, bins, weight);
     histogramdd_prepare_out(self, bins, hist, bin_edges);
@@ -335,21 +357,21 @@
     return hist;
 }
 
-Tensor histogramdd_cpu(const Tensor& self, IntArrayRef bin_ct,
+Tensor _histogramdd(const Tensor& self, IntArrayRef bin_ct,
         c10::optional<c10::ArrayRef<double>> range,
         const c10::optional<Tensor>& weight, bool density) {
     Tensor hist = at::empty({0}, self.options(), MemoryFormat::Contiguous);
     std::vector<Tensor> bin_edges_out = allocate_bin_edges_tensors(self);
     TensorList bin_edges_out_tl(bin_edges_out);
 
-    histogramdd_out_cpu(self, bin_ct, range, weight, density, hist, bin_edges_out_tl);
+    histogramdd_out(self, bin_ct, range, weight, density, hist, bin_edges_out_tl);
     return hist;
 }
 
 /* Versions of histogram in which bins is a Tensor defining the sequence of bin edges.
  */
 std::tuple<Tensor&, Tensor&>
-histogram_out_cpu(const Tensor& self, const Tensor& bins,
+histogram_out(const Tensor& self, const Tensor& bins,
         const c10::optional<Tensor>& weight, bool density,
         Tensor& hist, Tensor& bin_edges) {
     Tensor reshaped_self = self.reshape({ self.numel(), 1 });
@@ -358,23 +380,23 @@
     TensorList bins_in = bins;
     TensorList bins_out = bin_edges;
 
-    histogramdd_out_cpu(reshaped_self, bins_in, reshaped_weight, density, hist, bins_out);
+    histogramdd_out(reshaped_self, bins_in, reshaped_weight, density, hist, bins_out);
 
     return std::forward_as_tuple(hist, bin_edges);
 }
 
 std::tuple<Tensor, Tensor>
-histogram_cpu(const Tensor& self, const Tensor& bins,
+histogram(const Tensor& self, const Tensor& bins,
         const c10::optional<Tensor>& weight, bool density) {
     Tensor hist = at::empty({0}, self.options(), MemoryFormat::Contiguous);
     Tensor bin_edges = at::empty({0}, bins.options(), MemoryFormat::Contiguous);
-    return histogram_out_cpu(self, bins, weight, density, hist, bin_edges);
+    return histogram_out(self, bins, weight, density, hist, bin_edges);
 }
 
 /* Versions of histogram in which bins is an integer specifying the number of equal-width bins.
  */
 std::tuple<Tensor&, Tensor&>
-histogram_out_cpu(const Tensor& self, int64_t bin_ct, c10::optional<c10::ArrayRef<double>> range,
+histogram_out(const Tensor& self, int64_t bin_ct, c10::optional<c10::ArrayRef<double>> range,
         const c10::optional<Tensor>& weight, bool density,
         Tensor& hist, Tensor& bin_edges) {
     Tensor reshaped_self = self.reshape({ self.numel(), 1 });
@@ -385,7 +407,7 @@
 
     histogramdd_prepare_out(reshaped_self, std::vector<int64_t>{bin_ct}, hist, bins_out);
     auto outer_bin_edges = select_outer_bin_edges(reshaped_self, range);
-    linspace_out(outer_bin_edges.first[0], outer_bin_edges.second[0], bin_ct + 1, bin_edges);
+    at::linspace_out(bin_edges, outer_bin_edges.first[0], outer_bin_edges.second[0], bin_ct + 1);
 
     histogramdd_check_inputs(reshaped_self, bins_in, reshaped_weight);
 
@@ -394,16 +416,16 @@
 }
 
 std::tuple<Tensor, Tensor>
-histogram_cpu(const Tensor& self, int64_t bin_ct, c10::optional<c10::ArrayRef<double>> range,
+histogram(const Tensor& self, int64_t bin_ct, c10::optional<c10::ArrayRef<double>> range,
         const c10::optional<Tensor>& weight, bool density) {
     Tensor hist = at::empty({0}, self.options(), MemoryFormat::Contiguous);
     Tensor bin_edges_out = at::empty({0}, self.options());
-    return histogram_out_cpu(self, bin_ct, range, weight, density, hist, bin_edges_out);
+    return histogram_out(self, bin_ct, range, weight, density, hist, bin_edges_out);
 }
 
 /* Narrowed interface for the legacy torch.histc function.
  */
-Tensor& histogram_histc_cpu_out(const Tensor& self, int64_t bin_ct,
+Tensor& histogram_histc_out(const Tensor& self, int64_t bin_ct,
         const Scalar& min, const Scalar& max, Tensor& hist) {
     Tensor bin_edges = at::empty({0}, self.options());
 
@@ -414,7 +436,7 @@
     histogramdd_prepare_out(reshaped, std::vector<int64_t>{bin_ct}, hist, bins_out);
 
     auto outer_bin_edges = histc_select_outer_bin_edges(self, min, max);
-    linspace_out(outer_bin_edges.first, outer_bin_edges.second, bin_ct + 1, bin_edges);
+    at::linspace_out(bin_edges, outer_bin_edges.first, outer_bin_edges.second, bin_ct + 1);
 
     histogramdd_check_inputs(reshaped, bins_in, {});
 
@@ -423,10 +445,10 @@
     return hist;
 }
 
-Tensor histogram_histc_cpu(const Tensor& self, int64_t bin_ct,
+Tensor histogram_histc(const Tensor& self, int64_t bin_ct,
         const Scalar& min, const Scalar& max) {
     Tensor hist = at::empty({0}, self.options(), MemoryFormat::Contiguous);
-    return histogram_histc_cpu_out(self, bin_ct, min, max, hist);
+    return histogram_histc_out(self, bin_ct, min, max, hist);
 }
 
 std::tuple<Tensor, std::vector<Tensor>> histogramdd(
diff --git a/aten/src/ATen/native/mps/operations/HistogramKernel.mm b/aten/src/ATen/native/mps/operations/HistogramKernel.mm
new file mode 100644
index 0000000..4a9beaf
--- /dev/null
+++ b/aten/src/ATen/native/mps/operations/HistogramKernel.mm
@@ -0,0 +1,397 @@
+#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
+#include <ATen/Dispatch.h>
+#include <ATen/native/Histogram.h>
+#include <ATen/native/mps/OperationUtils.h>
+
+#ifndef AT_PER_OPERATOR_HEADERS
+#include <ATen/Functions.h>
+#include <ATen/NativeFunctions.h>
+#else
+#include <ATen/ops/sum.h>
+#endif
+
+namespace at::native {
+namespace mps {
+
+enum BIN_SELECTION_ALGORITHM {
+  LINEAR_INTERPOLATION,
+  LINEAR_INTERPOLATION_WITH_LOCAL_SEARCH,
+  BINARY_SEARCH,
+};
+
+static const char* METAL_HISTOGRAM = R"HISTOGRAM_METAL(
+
+#include <metal_stdlib>
+using namespace metal;
+
+enum BIN_SELECTION_ALGORITHM {
+  LINEAR_INTERPOLATION,
+  LINEAR_INTERPOLATION_WITH_LOCAL_SEARCH,
+  BINARY_SEARCH,
+};
+
+// Re-implementation of std::upper_bound, adapted to search the sub-range [first, first + len) of a flat array; returns the index of the first element greater than val.
+template<typename T, typename U>
+U upper_bound(constant T * arr, U first, U len, T val) {
+  while (len > 0) {
+    U half_ = len >> 1;
+    U middle = first + half_;
+
+    if (val < arr[middle]) {
+      len = half_;
+    } else {
+      first = middle + 1;
+      len -= half_ + 1;
+    }
+  }
+  return first;
+}
+
+// This implementation mostly follows the CPU kernel, with some modifications.
+// Please see `aten/src/ATen/native/cpu/HistogramKernel.cpp` for more details.
+template<typename T>
+kernel void histogramdd(constant T  * input_            [[buffer(0)]],
+                  constant T        * weight            [[buffer(1)]],
+                  device   T        * local_out         [[buffer(2)]],
+                  constant uint     * offsets           [[buffer(3)]],
+                  constant size_t   & num_dims          [[buffer(4)]],
+                  constant T        * bin_seq           [[buffer(5)]],
+                  constant int64_t  * num_bin_edges     [[buffer(6)]],
+                  constant T        * leftmost_edge     [[buffer(7)]],
+                  constant T        * rightmost_edge    [[buffer(8)]],
+                  constant int64_t  * local_out_strides [[buffer(9)]],
+                  constant uint8_t  & algorithm         [[buffer(10)]],
+                  constant uint8_t  & has_weight        [[buffer(11)]],
+                  uint tid [[thread_position_in_grid]]) {
+
+  constexpr T eps = 4e-6;
+  bool skip_element = false;
+  int64_t hist_index = 0;
+  int64_t bin_seq_offset = 0;
+
+  for (size_t dim = 0; dim < num_dims; dim++) {
+    T element = input_[offsets[tid * num_dims + dim]];
+
+    // Skip elements that fall outside the specified bins, as well as NaN elements.
+    // An eps is added at the edges to absorb precision issues that could otherwise cause
+    // elements to be skipped accidentally, likely due to small differences between the CPU and MPS linspace implementations.
+    if (!(element >= (leftmost_edge[dim] - eps) && element <= (rightmost_edge[dim] + eps))) {
+        skip_element = true;
+        break;
+    }
+    int64_t pos = -1;
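+    // Select the bin: binary search over the edge sequence, or direct linear interpolation for uniform bins (optionally corrected by a local search).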
+
+    if (algorithm == BIN_SELECTION_ALGORITHM::BINARY_SEARCH) {
+      pos = upper_bound(
+        bin_seq,
+        bin_seq_offset,
+        num_bin_edges[dim],
+        element
+      ) - bin_seq_offset - 1;
+    } else if (
+      algorithm == BIN_SELECTION_ALGORITHM::LINEAR_INTERPOLATION ||
+      algorithm == BIN_SELECTION_ALGORITHM::LINEAR_INTERPOLATION_WITH_LOCAL_SEARCH) {
+      pos = static_cast<int64_t>((element - leftmost_edge[dim])
+                            * (num_bin_edges[dim] - 1)
+                            / (rightmost_edge[dim] - leftmost_edge[dim]));
+      if (algorithm == LINEAR_INTERPOLATION_WITH_LOCAL_SEARCH) {
+          int64_t pos_min = max(static_cast<int64_t>(0), pos - 1);
+          int64_t pos_max = min(pos + 2, num_bin_edges[dim]);
+          pos = upper_bound(
+            bin_seq,
+            bin_seq_offset + pos_min,
+            pos_max - pos_min,
+            element
+          ) - bin_seq_offset - 1;
+      }
+    }
+
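+    // An element equal to the rightmost edge belongs to the last bin, which is closed on the right.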
+    if (pos == (num_bin_edges[dim] - 1)) {
+      pos -= 1;
+    }
+    hist_index += local_out_strides[dim + 1] * pos;
+    bin_seq_offset += num_bin_edges[dim];
+  }
+  if (!skip_element) {
+    // In the unweighted case, the default weight is 1
+    local_out[local_out_strides[0] * tid + hist_index] += has_weight ? weight[tid] : 1;
+  }
+}
+
+
+#define REGISTER_HISTOGRAMDD_OP(DTYPE)                        \
+template                                                      \
+[[host_name("histogramdd_" #DTYPE)]]                          \
+kernel void histogramdd<DTYPE>(                               \
+  constant DTYPE    * input_                  [[buffer(0)]],  \
+  constant DTYPE    * weight                  [[buffer(1)]],  \
+  device   DTYPE    * local_out               [[buffer(2)]],  \
+  constant uint     * offsets                 [[buffer(3)]],  \
+  constant size_t   & num_dims                [[buffer(4)]],  \
+  constant DTYPE    * bin_seq                 [[buffer(5)]],  \
+  constant int64_t  * num_bin_edges           [[buffer(6)]],  \
+  constant DTYPE    * leftmost_edge           [[buffer(7)]],  \
+  constant DTYPE    * rightmost_edge          [[buffer(8)]],  \
+  constant int64_t  * local_out_strides       [[buffer(9)]],  \
+  constant uint8_t  & bin_selection_algorithm [[buffer(10)]], \
+  constant uint8_t  & has_weight              [[buffer(11)]], \
+  uint tid [[thread_position_in_grid]]);
+
+REGISTER_HISTOGRAMDD_OP(float);
+REGISTER_HISTOGRAMDD_OP(half);
+
+)HISTOGRAM_METAL";
+
+static id<MTLLibrary> compileHistogramOpLibrary(id<MTLDevice> device) {
+  static id<MTLLibrary> histogramLibrary = nil;
+  if (histogramLibrary) {
+    return histogramLibrary;
+  }
+
+  NSError* error = nil;
+  MTLCompileOptions* options = [[MTLCompileOptions new] autorelease];
+  [options setLanguageVersion:MTLLanguageVersion2_3];
+  histogramLibrary = [device newLibraryWithSource:[NSString stringWithCString:METAL_HISTOGRAM
+                                                                     encoding:NSASCIIStringEncoding]
+                                          options:options
+                                            error:&error];
+  TORCH_CHECK(histogramLibrary, "Failed to create Metal histogram library, error: ", [[error description] UTF8String]);
+  return histogramLibrary;
+}
+
+static id<MTLComputePipelineState> histogramPipelineState(id<MTLDevice> device, const std::string& kernel) {
+  static std::unordered_map<std::string, id<MTLComputePipelineState>> psoCache;
+  id<MTLComputePipelineState> pso = psoCache[kernel];
+  if (pso) {
+    return pso;
+  }
+
+  NSError* error = nil;
+  id<MTLLibrary> histogramLib = compileHistogramOpLibrary(device);
+  id<MTLFunction> histogramFunc = [histogramLib newFunctionWithName:[NSString stringWithUTF8String:kernel.c_str()]];
+  TORCH_CHECK(histogramFunc, "Failed to create function state object for: ", kernel);
+  pso = [device newComputePipelineStateWithFunction:histogramFunc error:&error];
+  TORCH_CHECK(pso, "Failed to create pipeline state object, error: ", [[error description] UTF8String]);
+
+  psoCache[kernel] = pso;
+  return pso;
+}
+
+template <typename input_t, BIN_SELECTION_ALGORITHM algorithm>
+void histogramdd_kernel_impl(Tensor& hist_output,
+                             const TensorList& bin_edges,
+                             const Tensor& input,
+                             const c10::optional<Tensor>& weight) {
+  TORCH_CHECK(input.dtype() != at::kDouble, "float64 is not supported on MPS");
+  TORCH_INTERNAL_ASSERT(input.dim() == 2);
+
+  constexpr uint8_t bin_selection_algorithm = algorithm;
+  const int64_t N = input.size(0);
+  const bool has_weight = weight.has_value();
+
+  if (has_weight) {
+    TORCH_CHECK(weight.value().is_contiguous(), "histogramdd(): weight should be contiguous on MPS");
+    TORCH_INTERNAL_ASSERT(weight.value().dim() == 1 && weight.value().numel() == N);
+    TORCH_INTERNAL_ASSERT(weight.value().scalar_type() == input.scalar_type());
+  }
+
+  const int64_t D = input.size(1);
+  size_t bin_edges_numel = 0;
+  TORCH_INTERNAL_ASSERT(int64_t(bin_edges.size()) == D);
+  for (const auto dim : c10::irange(D)) {
+    bin_edges_numel += bin_edges[dim].numel();
+    TORCH_INTERNAL_ASSERT(bin_edges[dim].is_contiguous());
+    TORCH_INTERNAL_ASSERT(hist_output.size(dim) + 1 == bin_edges[dim].numel());
+  }
+
+  if (D == 0) {
+    // hist is an empty tensor in this case; nothing to do here
+    return;
+  }
+
+  std::vector<input_t> bin_seq(bin_edges_numel);
+  std::vector<int64_t> num_bin_edges(D);
+  std::vector<input_t> leftmost_edge(D);
+  std::vector<input_t> rightmost_edge(D);
+  size_t bin_seq_offset = 0;
+
+  for (const auto dim : c10::irange(D)) {
+    for (const auto elem_idx : c10::irange(bin_edges[dim].numel())) {
+      bin_seq[bin_seq_offset + elem_idx] = (bin_edges[dim][elem_idx].item().to<input_t>());
+    }
+    num_bin_edges[dim] = bin_edges[dim].numel();
+    leftmost_edge[dim] = bin_seq[bin_seq_offset];
+    rightmost_edge[dim] = bin_seq[bin_seq_offset + num_bin_edges[dim] - 1];
+    bin_seq_offset += num_bin_edges[dim];
+  }
+
+  const uint32_t stridedIndicesNumThreads = input.numel();
+  const uint32_t numThreads = N;
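+  // The offsets kernel runs one thread per input element; the histogram kernel runs one thread per sample row.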
+  const auto hist_sizes = hist_output.sizes();
+
+  DimVector thread_hist_sizes(hist_sizes.size() + 1); // [n_threads, output_sizes...]
+  thread_hist_sizes[0] = numThreads;
+  std::copy(hist_sizes.begin(), hist_sizes.end(), thread_hist_sizes.begin() + 1);
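+  // Each thread accumulates into a private histogram slice so the kernel needs no atomics; the slices are reduced with at::sum_out below.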
+  Tensor thread_histograms = at::zeros(
+      thread_hist_sizes, hist_output.scalar_type(), c10::nullopt /* layout */, kMPS, c10::nullopt /* pin_memory */
+  );
+  TORCH_INTERNAL_ASSERT(thread_histograms.is_contiguous());
+
+  id<MTLDevice> device = MPSDevice::getInstance()->device();
+  id<MTLBuffer> inputBuffer = getMTLBufferStorage(input);
+  id<MTLBuffer> outputBuffer = getMTLBufferStorage(thread_histograms);
+  id<MTLBuffer> weightBuffer =
+      has_weight ? getMTLBufferStorage(weight.value()) : [[device newBufferWithLength:0 options:0] autorelease];
+  size_t weightOffset = has_weight ? weight.value().storage_offset() * weight.value().element_size() : 0;
+  MPSStream* mpsStream = getCurrentMPSStream();
+  const uint32_t nDim = input.sizes().size();
+
+  dispatch_sync(mpsStream->queue(), ^() {
+    @autoreleasepool {
+      id<MTLCommandBuffer> commandBuffer = mpsStream->commandBuffer();
+      id<MTLComputeCommandEncoder> computeEncoder = [commandBuffer computeCommandEncoder];
+      MTLSize gridSize = MTLSizeMake(stridedIndicesNumThreads, 1, 1);
+      const IntArrayRef& inputShape = input.sizes();
+      std::vector<uint32_t> inputShapeData(inputShape.size());
+      std::vector<uint32_t> strides(input.strides().begin(), input.strides().end());
+
+      for (const auto i : c10::irange(inputShape.size())) {
+        TORCH_CHECK(inputShape[i] <= UINT32_MAX);
+        inputShapeData[i] = (uint32_t)(inputShape[i]);
+      }
+
+      id<MTLBuffer> stridedIndicesBuffer = [[device newBufferWithLength:stridedIndicesNumThreads * sizeof(uint)
+                                                                options:0] autorelease];
+      id<MTLComputePipelineState> stridedIndicesPSO = MPSDevice::getInstance()->metalIndexingPSO("kernel_index_offset");
+
+      [computeEncoder setComputePipelineState:stridedIndicesPSO];
+      [computeEncoder setBytes:strides.data() length:sizeof(uint32_t) * nDim atIndex:0];
+      [computeEncoder setBuffer:stridedIndicesBuffer offset:0 atIndex:1];
+      [computeEncoder setBytes:inputShapeData.data() length:sizeof(uint32_t) * inputShape.size() atIndex:2];
+      [computeEncoder setBytes:&nDim length:sizeof(uint32_t) atIndex:3];
+
+      NSUInteger stridedIndicesTGSize = stridedIndicesPSO.maxTotalThreadsPerThreadgroup;
+      if (stridedIndicesTGSize > stridedIndicesNumThreads)
+        stridedIndicesTGSize = stridedIndicesNumThreads;
+
+      MTLSize stridedIndicesThreadGroupSize = MTLSizeMake(stridedIndicesTGSize, 1, 1);
+      [computeEncoder dispatchThreads:gridSize threadsPerThreadgroup:stridedIndicesThreadGroupSize];
+
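+      // Second pass: bin each sample row into its thread's private histogram using the offsets computed above.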
+      const std::string kernel = "histogramdd_" + scalarToMetalTypeString(input.scalar_type());
+      id<MTLComputePipelineState> histogramPSO = histogramPipelineState(device, kernel);
+      [computeEncoder setComputePipelineState:histogramPSO];
+      [computeEncoder setBuffer:inputBuffer offset:input.storage_offset() * input.element_size() atIndex:0];
+      [computeEncoder setBuffer:weightBuffer offset:weightOffset atIndex:1];
+      [computeEncoder setBuffer:outputBuffer
+                         offset:thread_histograms.storage_offset() * thread_histograms.element_size()
+                        atIndex:2];
+      [computeEncoder setBuffer:stridedIndicesBuffer offset:0 atIndex:3];
+      [computeEncoder setBytes:&D length:sizeof(int64_t) atIndex:4];
+      [computeEncoder setBytes:bin_seq.data() length:sizeof(input_t) * bin_seq_offset atIndex:5];
+      [computeEncoder setBytes:num_bin_edges.data() length:sizeof(int64_t) * D atIndex:6];
+      [computeEncoder setBytes:leftmost_edge.data() length:sizeof(input_t) * D atIndex:7];
+      [computeEncoder setBytes:rightmost_edge.data() length:sizeof(input_t) * D atIndex:8];
+      [computeEncoder setBytes:thread_histograms.strides().data()
+                        length:sizeof(int64_t) * thread_hist_sizes.size()
+                       atIndex:9];
+      [computeEncoder setBytes:&bin_selection_algorithm length:sizeof(uint8_t) atIndex:10];
+      [computeEncoder setBytes:&has_weight length:sizeof(uint8_t) atIndex:11];
+
+      NSUInteger tgSize = histogramPSO.maxTotalThreadsPerThreadgroup;
+      if (tgSize > numThreads) {
+        tgSize = numThreads;
+      }
+      gridSize = MTLSizeMake(numThreads, 1, 1);
+      MTLSize threadGroupSize = MTLSizeMake(tgSize, 1, 1);
+      [computeEncoder dispatchThreads:gridSize threadsPerThreadgroup:threadGroupSize];
+
+      [computeEncoder endEncoding];
+      mpsStream->synchronize(SyncType::COMMIT);
+    }
+  });
+  at::sum_out(hist_output, thread_histograms, /*dim=*/{0});
+}
+
+template <BIN_SELECTION_ALGORITHM bin_algorithm>
+static void histogramdd_out_mps_template(const Tensor& self,
+                                         const c10::optional<Tensor>& weight,
+                                         bool density,
+                                         Tensor& hist,
+                                         const TensorList& bin_edges) {
+  hist.fill_(0);
+
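+  // Flatten the input to (M, N): M sample rows, each with N coordinates.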
+  const int64_t N = self.size(-1);
+  const int64_t M =
+      std::accumulate(self.sizes().begin(), self.sizes().end() - 1, (int64_t)1, std::multiplies<int64_t>());
+
+  const Tensor reshaped_input = self.reshape({M, N});
+
+  const auto reshaped_weight =
+      weight.has_value() ? c10::optional<Tensor>(weight.value().reshape({M})) : c10::optional<Tensor>();
+
+  std::vector<Tensor> bin_edges_contig(bin_edges.size());
+  for (const auto dim : c10::irange(bin_edges_contig.size())) {
+    bin_edges_contig[dim] = bin_edges[dim].contiguous();
+  }
+
+  AT_DISPATCH_FLOATING_TYPES(self.scalar_type(), "histogram_mps", [&]() {
+    mps::histogramdd_kernel_impl<scalar_t, bin_algorithm>(hist, bin_edges_contig, reshaped_input, reshaped_weight);
+  });
+
+  /* Divides each bin's value by the total count/weight in all bins,
+   * and by the bin's volume.
+   */
+  if (density) {
+    const auto hist_sum = hist.sum().item();
+    hist.div_(hist_sum);
+
+    /* For each dimension, divides each bin's value
+     * by the bin's length in that dimension.
+     */
+    for (const auto dim : c10::irange(N)) {
+      const auto bin_lengths = bin_edges[dim].diff();
+
+      // Used to reshape bin_lengths to align with the corresponding dimension of hist.
+      std::vector<int64_t> shape(N, 1);
+      shape[dim] = bin_lengths.numel();
+
+      hist.div_(bin_lengths.reshape(shape));
+    }
+  }
+}
+} // namespace mps
+
+static void histogramdd_kernel(const Tensor& self,
+                               const c10::optional<Tensor>& weight,
+                               bool density,
+                               Tensor& hist,
+                               const TensorList& bin_edges) {
+  mps::histogramdd_out_mps_template<mps::BINARY_SEARCH>(self, weight, density, hist, bin_edges);
+}
+
+static void histogramdd_linear_kernel(const Tensor& self,
+                                      const c10::optional<Tensor>& weight,
+                                      bool density,
+                                      Tensor& hist,
+                                      const TensorList& bin_edges,
+                                      bool local_search) {
+  if (local_search) {
+    // histogramdd codepath: both hist and bin_edges are eventually returned as output,
+    // so we'll keep them consistent
+    mps::histogramdd_out_mps_template<mps::LINEAR_INTERPOLATION_WITH_LOCAL_SEARCH>(
+        self, weight, density, hist, bin_edges);
+  } else {
+    // histc codepath: bin_edges are not returned to the caller
+    mps::histogramdd_out_mps_template<mps::LINEAR_INTERPOLATION>(self, weight, density, hist, bin_edges);
+  }
+}
+
+REGISTER_DISPATCH(histogramdd_stub, &histogramdd_kernel);
+REGISTER_DISPATCH(histogramdd_linear_stub, &histogramdd_linear_kernel);
+} // namespace at::native
diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml
index 3392b00..79190bd 100644
--- a/aten/src/ATen/native/native_functions.yaml
+++ b/aten/src/ATen/native/native_functions.yaml
@@ -9160,46 +9160,46 @@
 
 - func: histc.out(Tensor self, int bins=100, Scalar min=0, Scalar max=0, *, Tensor(a!) out) -> Tensor(a!)
   dispatch:
-    CPU: histogram_histc_cpu_out
+    CPU, MPS: histogram_histc_out
     CUDA: _histc_out_cuda
 
 - func: histc(Tensor self, int bins=100, Scalar min=0, Scalar max=0) -> Tensor
   variants: method, function
   dispatch:
-    CPU: histogram_histc_cpu
+    CPU, MPS: histogram_histc
     CUDA: _histc_cuda
 
 - func: histogram.bins_tensor_out(Tensor self, Tensor bins, *, Tensor? weight=None, bool density=False, Tensor(a!) hist, Tensor(b!) bin_edges) -> (Tensor(a!) hist, Tensor(b!) bin_edges)
   dispatch:
-    CPU: histogram_out_cpu
+    CPU, MPS: histogram_out
 
 - func: histogram.bins_tensor(Tensor self, Tensor bins, *, Tensor? weight=None, bool density=False) -> (Tensor hist, Tensor bin_edges)
   variants: method, function
   dispatch:
-    CPU: histogram_cpu
+    CPU, MPS: histogram
 
 - func: histogram.bin_ct_out(Tensor self, int bins=100, *, float[]? range=None, Tensor? weight=None, bool density=False, Tensor(a!) hist, Tensor(b!) bin_edges) -> (Tensor(a!) hist, Tensor(b!) bin_edges)
   dispatch:
-    CPU: histogram_out_cpu
+    CPU, MPS: histogram_out
 
 - func: histogram.bin_ct(Tensor self, int bins=100, *, float[]? range=None, Tensor? weight=None, bool density=False) -> (Tensor hist, Tensor bin_edges)
   variants: method, function
   dispatch:
-    CPU: histogram_cpu
+    CPU, MPS: histogram
 
 - func: _histogramdd_bin_edges(Tensor self, int[] bins, *, float[]? range=None, Tensor? weight=None, bool density=False) -> Tensor[]
   dispatch:
-    CPU: histogramdd_bin_edges_cpu
+    CPU, MPS: histogramdd_bin_edges
   autogen: _histogramdd_bin_edges.out
 
 - func: _histogramdd_from_bin_cts(Tensor self, int[] bins, *, float[]? range=None, Tensor? weight=None, bool density=False) -> Tensor
   dispatch:
-    CPU: histogramdd_cpu
+    CPU, MPS: _histogramdd
   autogen: _histogramdd_from_bin_cts.out
 
 - func: _histogramdd_from_bin_tensors(Tensor self, Tensor[] bins, *, Tensor? weight=None, bool density=False) -> Tensor
   dispatch:
-    CPU: histogramdd_cpu
+    CPU, MPS: _histogramdd
   autogen: _histogramdd_from_bin_tensors.out
 
 - func: histogramdd(Tensor self, int[] bins, float[]? range=None, Tensor? weight=None, bool density=False) -> (Tensor hist, Tensor[] bin_edges)
diff --git a/test/test_mps.py b/test/test_mps.py
index d8ff872..d6e0564 100644
--- a/test/test_mps.py
+++ b/test/test_mps.py
@@ -89,6 +89,10 @@
         'floor_divide': [torch.float16, torch.float32],
         # derivative for aten::narrow_copy is not implemented on CPU
         'narrow_copy': [torch.float16, torch.float32],
+        # derivative for aten::_histogramdd_from_bin_cts is not implemented on CPU
+        'histogramdd': [torch.float16, torch.float32],
+        # derivative for aten::histogram is not implemented
+        'histogram': [torch.float16, torch.float32],
         # 'bool' object is not iterable
         'allclose': [torch.float16, torch.float32],
         'equal': [torch.float16, torch.float32],
@@ -409,9 +413,6 @@
         'geqrf': None,
         'nn.functional.grid_sample': None,  # Unsupported Border padding mode
         'heaviside': None,
-        'histc': None,
-        'histogram': None,
-        'histogramdd': None,
         'i0': None,
         'igamma': None,
         'igammac': None,
@@ -10244,11 +10245,11 @@
         self.assertEqual(out, "")
 
     def _get_not_implemented_op(self):
-        # This can be changed once we actually implement `torch.histc`
+        # This can be changed once we actually implement `torch.lgamma`
         # Should return fn, args, kwargs, string_version
-        return (torch.histc,
+        return (torch.lgamma,
                 torch.tensor([100], device='mps'), {},
-                "torch.histc(torch.tensor([4], device='mps', dtype=torch.float))")
+                "torch.lgamma(torch.tensor([4], device='mps', dtype=torch.float))")
 
     def test_error_on_not_implemented(self):
         fn, args, kwargs, _ = self._get_not_implemented_op()
diff --git a/torch/csrc/jit/runtime/static/generated_ops.cpp b/torch/csrc/jit/runtime/static/generated_ops.cpp
index 03e3098..fdb5b93 100644
--- a/torch/csrc/jit/runtime/static/generated_ops.cpp
+++ b/torch/csrc/jit/runtime/static/generated_ops.cpp
@@ -3005,13 +3005,12 @@
       const auto min = p_node->Input(2).toScalar();
       const auto max = p_node->Input(3).toScalar();
       if (p_node->Output(0).isNone()) {
-        p_node->Output(0) =
-            at::native::histogram_histc_cpu(self, bins, min, max);
+        p_node->Output(0) = at::native::histogram_histc(self, bins, min, max);
         return;
       }
       auto& out = p_node->Output(0).toTensor();
       fastResizeToZero(out);
-      at::native::histogram_histc_cpu_out(self, bins, min, max, out);
+      at::native::histogram_histc_out(self, bins, min, max, out);
     };
   }
   LogAndDumpSchema(n);