[MPS] Add histogram ops (#96652)
Adds `torch.histc`, `torch.histogram`, and `torch.histogramdd` for the MPS backend.
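
A minimal usage sketch on an MPS device (assumes a macOS build with the MPS backend available; illustrative only, not part of the patch below):

    import torch

    x = torch.randn(10_000, device="mps")

    # torch.histc: counts over a fixed number of equal-width bins in [min, max]
    counts = torch.histc(x, bins=100, min=-4.0, max=4.0)

    # torch.histogram: returns counts and bin edges; bins may be an int or a 1-D tensor of edges
    hist, bin_edges = torch.histogram(x, bins=50)

    # torch.histogramdd: multi-dimensional histogram, here 2-D with 16x16 bins
    xy = torch.randn(10_000, 2, device="mps")
    hist2d, edges2d = torch.histogramdd(xy, bins=[16, 16])
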
Pull Request resolved: https://github.com/pytorch/pytorch/pull/96652
Approved by: https://github.com/kulinseth, https://github.com/malfet
diff --git a/aten/src/ATen/mps/IndexKernels.h b/aten/src/ATen/mps/IndexKernels.h
index 4be40ca..a845e5f 100644
--- a/aten/src/ATen/mps/IndexKernels.h
+++ b/aten/src/ATen/mps/IndexKernels.h
@@ -224,6 +224,21 @@
}
}
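+// Converts each thread's linear index into a strided element offset using the iterator shape
+// and strides, so that non-contiguous inputs can be gathered by downstream kernels.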
+kernel void kernel_index_offset(constant uint * strides [[buffer(0)]],
+ device uint * data_offsets [[buffer(1)]],
+ constant uint * iter_shape [[buffer(2)]],
+ constant uint & num_dimensions [[buffer(3)]],
+ uint thread_index [[thread_position_in_grid]]) {
+ uint32_t idx = thread_index;
+ for (uint32_t dim = 0; dim < num_dimensions; dim++) {
+ uint32_t reversed_dim = num_dimensions - dim - 1;
+ uint32_t remainder = idx % iter_shape[reversed_dim];
+ idx /= iter_shape[reversed_dim];
+
+ data_offsets[thread_index] += remainder * strides[reversed_dim];
+ }
+}
+
template<typename T, typename E>
kernel void index_put_accumulate_native_dtypes(
#if __METAL_VERSION__ >= 300
diff --git a/aten/src/ATen/native/Histogram.cpp b/aten/src/ATen/native/Histogram.cpp
index 116e91a..2a22650 100644
--- a/aten/src/ATen/native/Histogram.cpp
+++ b/aten/src/ATen/native/Histogram.cpp
@@ -16,11 +16,13 @@
#include <ATen/ops/_histogramdd_from_bin_tensors.h>
#include <ATen/ops/_histogramdd_from_bin_tensors_native.h>
#include <ATen/ops/aminmax.h>
+#include <ATen/ops/amin.h>
+#include <ATen/ops/amax.h>
#include <ATen/ops/empty.h>
#include <ATen/ops/histc_native.h>
#include <ATen/ops/histogram_native.h>
#include <ATen/ops/histogramdd_native.h>
-#include <ATen/ops/linspace_native.h>
+#include <ATen/ops/linspace.h>
#endif
#include <numeric>
@@ -193,7 +195,18 @@
} else if (input.numel() > 0) {
// non-empty input
AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "histogramdd", [&]() {
- infer_bin_edges_from_input<scalar_t>(input, N, leftmost_edges, rightmost_edges);
+ if (input.is_mps()) {
+ // aminmax has not been implemented on MPS; compute per-dimension min/max instead.
+ Tensor min = at::amin(input, 0);
+ Tensor max = at::amax(input, 0);
+
+ for (const auto i : c10::irange(N)) {
+ leftmost_edges[i] = min[i].item().to<scalar_t>();
+ rightmost_edges[i] = max[i].item().to<scalar_t>();
+ }
+ } else {
+ infer_bin_edges_from_input<scalar_t>(input, N, leftmost_edges, rightmost_edges);
+ }
});
}
@@ -226,9 +239,18 @@
double rightmost_edge = max.to<double>();
if (leftmost_edge == rightmost_edge && input.numel() > 0) {
- auto extrema = aminmax(input);
- leftmost_edge = std::get<0>(extrema).item<double>();
- rightmost_edge = std::get<1>(extrema).item<double>();
+ if (input.is_mps()) {
+ // aminmax has not been implemented on MPS; compute min and max separately instead.
+ Tensor min = at::amin(input);
+ Tensor max = at::amax(input);
+
+ leftmost_edge = min.item<double>();
+ rightmost_edge = max.item<double>();
+ } else {
+ auto extrema = aminmax(input);
+ leftmost_edge = std::get<0>(extrema).item<double>();
+ rightmost_edge = std::get<1>(extrema).item<double>();
+ }
}
if (leftmost_edge == rightmost_edge) {
@@ -259,7 +281,7 @@
/* Versions of histogramdd in which bins is a Tensor[] defining the sequences of bin edges.
*/
-Tensor& histogramdd_out_cpu(const Tensor& self, TensorList bins,
+Tensor& histogramdd_out(const Tensor& self, TensorList bins,
const c10::optional<Tensor>& weight, bool density,
Tensor& hist, TensorList& bin_edges) {
histogramdd_check_inputs(self, bins, weight);
@@ -273,20 +295,20 @@
return hist;
}
-Tensor histogramdd_cpu(const Tensor& self, TensorList bins,
+Tensor _histogramdd(const Tensor& self, TensorList bins,
const c10::optional<Tensor>& weight, bool density) {
Tensor hist = at::empty({0}, self.options(), MemoryFormat::Contiguous);
std::vector<Tensor> bin_edges_out = allocate_bin_edges_tensors(self);
TensorList bin_edges_out_tl(bin_edges_out);
- histogramdd_out_cpu(self, bins, weight, density, hist, bin_edges_out_tl);
+ histogramdd_out(self, bins, weight, density, hist, bin_edges_out_tl);
return hist;
}
/* Versions of histogramdd in which bins is an int[]
* defining the number of bins in each dimension.
*/
-std::vector<Tensor>& histogramdd_bin_edges_out_cpu(const Tensor& self, IntArrayRef bin_ct,
+std::vector<Tensor>& histogramdd_bin_edges_out(const Tensor& self, IntArrayRef bin_ct,
c10::optional<c10::ArrayRef<double>> range,
const c10::optional<Tensor>& weight, bool density,
std::vector<Tensor>& bin_edges_out) {
@@ -304,25 +326,25 @@
N == bin_size,
"histogramdd: The size of bins must be equal to the innermost dimension of the input.");
for (const auto dim : c10::irange(N)) {
- linspace_out(outer_bin_edges.first[dim], outer_bin_edges.second[dim],
- bin_ct[dim] + 1, bin_edges_out[dim]);
+ at::linspace_out(bin_edges_out[dim], outer_bin_edges.first[dim], outer_bin_edges.second[dim],
+ bin_ct[dim] + 1);
}
return bin_edges_out;
}
-std::vector<Tensor> histogramdd_bin_edges_cpu(const Tensor& self, IntArrayRef bin_ct,
+std::vector<Tensor> histogramdd_bin_edges(const Tensor& self, IntArrayRef bin_ct,
c10::optional<c10::ArrayRef<double>> range,
const c10::optional<Tensor>& weight, bool density) {
std::vector<Tensor> bin_edges_out = allocate_bin_edges_tensors(self);
- return histogramdd_bin_edges_out_cpu(self, bin_ct, range, weight, density, bin_edges_out);
+ return histogramdd_bin_edges_out(self, bin_ct, range, weight, density, bin_edges_out);
}
-Tensor& histogramdd_out_cpu(const Tensor& self, IntArrayRef bin_ct,
+Tensor& histogramdd_out(const Tensor& self, IntArrayRef bin_ct,
c10::optional<c10::ArrayRef<double>> range,
const c10::optional<Tensor>& weight, bool density,
Tensor& hist, TensorList& bin_edges) {
- std::vector<Tensor> bins = histogramdd_bin_edges_cpu(self, bin_ct, range, weight, density);
+ std::vector<Tensor> bins = histogramdd_bin_edges(self, bin_ct, range, weight, density);
histogramdd_check_inputs(self, bins, weight);
histogramdd_prepare_out(self, bins, hist, bin_edges);
@@ -335,21 +357,21 @@
return hist;
}
-Tensor histogramdd_cpu(const Tensor& self, IntArrayRef bin_ct,
+Tensor _histogramdd(const Tensor& self, IntArrayRef bin_ct,
c10::optional<c10::ArrayRef<double>> range,
const c10::optional<Tensor>& weight, bool density) {
Tensor hist = at::empty({0}, self.options(), MemoryFormat::Contiguous);
std::vector<Tensor> bin_edges_out = allocate_bin_edges_tensors(self);
TensorList bin_edges_out_tl(bin_edges_out);
- histogramdd_out_cpu(self, bin_ct, range, weight, density, hist, bin_edges_out_tl);
+ histogramdd_out(self, bin_ct, range, weight, density, hist, bin_edges_out_tl);
return hist;
}
/* Versions of histogram in which bins is a Tensor defining the sequence of bin edges.
*/
std::tuple<Tensor&, Tensor&>
-histogram_out_cpu(const Tensor& self, const Tensor& bins,
+histogram_out(const Tensor& self, const Tensor& bins,
const c10::optional<Tensor>& weight, bool density,
Tensor& hist, Tensor& bin_edges) {
Tensor reshaped_self = self.reshape({ self.numel(), 1 });
@@ -358,23 +380,23 @@
TensorList bins_in = bins;
TensorList bins_out = bin_edges;
- histogramdd_out_cpu(reshaped_self, bins_in, reshaped_weight, density, hist, bins_out);
+ histogramdd_out(reshaped_self, bins_in, reshaped_weight, density, hist, bins_out);
return std::forward_as_tuple(hist, bin_edges);
}
std::tuple<Tensor, Tensor>
-histogram_cpu(const Tensor& self, const Tensor& bins,
+histogram(const Tensor& self, const Tensor& bins,
const c10::optional<Tensor>& weight, bool density) {
Tensor hist = at::empty({0}, self.options(), MemoryFormat::Contiguous);
Tensor bin_edges = at::empty({0}, bins.options(), MemoryFormat::Contiguous);
- return histogram_out_cpu(self, bins, weight, density, hist, bin_edges);
+ return histogram_out(self, bins, weight, density, hist, bin_edges);
}
/* Versions of histogram in which bins is an integer specifying the number of equal-width bins.
*/
std::tuple<Tensor&, Tensor&>
-histogram_out_cpu(const Tensor& self, int64_t bin_ct, c10::optional<c10::ArrayRef<double>> range,
+histogram_out(const Tensor& self, int64_t bin_ct, c10::optional<c10::ArrayRef<double>> range,
const c10::optional<Tensor>& weight, bool density,
Tensor& hist, Tensor& bin_edges) {
Tensor reshaped_self = self.reshape({ self.numel(), 1 });
@@ -385,7 +407,7 @@
histogramdd_prepare_out(reshaped_self, std::vector<int64_t>{bin_ct}, hist, bins_out);
auto outer_bin_edges = select_outer_bin_edges(reshaped_self, range);
- linspace_out(outer_bin_edges.first[0], outer_bin_edges.second[0], bin_ct + 1, bin_edges);
+ at::linspace_out(bin_edges, outer_bin_edges.first[0], outer_bin_edges.second[0], bin_ct + 1);
histogramdd_check_inputs(reshaped_self, bins_in, reshaped_weight);
@@ -394,16 +416,16 @@
}
std::tuple<Tensor, Tensor>
-histogram_cpu(const Tensor& self, int64_t bin_ct, c10::optional<c10::ArrayRef<double>> range,
+histogram(const Tensor& self, int64_t bin_ct, c10::optional<c10::ArrayRef<double>> range,
const c10::optional<Tensor>& weight, bool density) {
Tensor hist = at::empty({0}, self.options(), MemoryFormat::Contiguous);
Tensor bin_edges_out = at::empty({0}, self.options());
- return histogram_out_cpu(self, bin_ct, range, weight, density, hist, bin_edges_out);
+ return histogram_out(self, bin_ct, range, weight, density, hist, bin_edges_out);
}
/* Narrowed interface for the legacy torch.histc function.
*/
-Tensor& histogram_histc_cpu_out(const Tensor& self, int64_t bin_ct,
+Tensor& histogram_histc_out(const Tensor& self, int64_t bin_ct,
const Scalar& min, const Scalar& max, Tensor& hist) {
Tensor bin_edges = at::empty({0}, self.options());
@@ -414,7 +436,7 @@
histogramdd_prepare_out(reshaped, std::vector<int64_t>{bin_ct}, hist, bins_out);
auto outer_bin_edges = histc_select_outer_bin_edges(self, min, max);
- linspace_out(outer_bin_edges.first, outer_bin_edges.second, bin_ct + 1, bin_edges);
+ at::linspace_out(bin_edges, outer_bin_edges.first, outer_bin_edges.second, bin_ct + 1);
histogramdd_check_inputs(reshaped, bins_in, {});
@@ -423,10 +445,10 @@
return hist;
}
-Tensor histogram_histc_cpu(const Tensor& self, int64_t bin_ct,
+Tensor histogram_histc(const Tensor& self, int64_t bin_ct,
const Scalar& min, const Scalar& max) {
Tensor hist = at::empty({0}, self.options(), MemoryFormat::Contiguous);
- return histogram_histc_cpu_out(self, bin_ct, min, max, hist);
+ return histogram_histc_out(self, bin_ct, min, max, hist);
}
std::tuple<Tensor, std::vector<Tensor>> histogramdd(
diff --git a/aten/src/ATen/native/mps/operations/HistogramKernel.mm b/aten/src/ATen/native/mps/operations/HistogramKernel.mm
new file mode 100644
index 0000000..4a9beaf
--- /dev/null
+++ b/aten/src/ATen/native/mps/operations/HistogramKernel.mm
@@ -0,0 +1,391 @@
+#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
+#include <ATen/Dispatch.h>
+#include <ATen/native/Histogram.h>
+#include <ATen/native/mps/OperationUtils.h>
+
+#ifndef AT_PER_OPERATOR_HEADERS
+#include <ATen/Functions.h>
+#include <ATen/NativeFunctions.h>
+#else
+#include <ATen/ops/sum.h>
+#endif
+
+namespace at::native {
+namespace mps {
+
+enum BIN_SELECTION_ALGORITHM {
+ LINEAR_INTERPOLATION,
+ LINEAR_INTERPOLATION_WITH_LOCAL_SEARCH,
+ BINARY_SEARCH,
+};
+
+static const char* METAL_HISTOGRAM = R"HISTOGRAM_METAL(
+
+#include <metal_stdlib>
+using namespace metal;
+
+enum BIN_SELECTION_ALGORITHM {
+ LINEAR_INTERPOLATION,
+ LINEAR_INTERPOLATION_WITH_LOCAL_SEARCH,
+ BINARY_SEARCH,
+};
+
+// Re-implementation of std::upper_bound with some modifications.
+template<typename T, typename U>
+U upper_bound(constant T * arr, U first, U len, T val) {
+ while (len > 0) {
+ U half_ = len >> 1;
+ U middle = first + half_;
+
+ if (val < arr[middle]) {
+ len = half_;
+ } else {
+ first = middle + 1;
+ len -= half_ + 1;
+ }
+ }
+ return first;
+}
+
+// The implementation here is mostly adapted from the CPU implementation, with some modifications.
+// See `aten/src/ATen/native/cpu/HistogramKernel.cpp` for details.
+template<typename T>
+kernel void histogramdd(constant T * input_ [[buffer(0)]],
+ constant T * weight [[buffer(1)]],
+ device T * local_out [[buffer(2)]],
+ constant uint * offsets [[buffer(3)]],
+ constant size_t & num_dims [[buffer(4)]],
+ constant T * bin_seq [[buffer(5)]],
+ constant int64_t * num_bin_edges [[buffer(6)]],
+ constant T * leftmost_edge [[buffer(7)]],
+ constant T * rightmost_edge [[buffer(8)]],
+ constant int64_t * local_out_strides [[buffer(9)]],
+ constant uint8_t & algorithm [[buffer(10)]],
+ constant uint8_t & has_weight [[buffer(11)]],
+ uint tid [[thread_position_in_grid]]) {
+
+ constexpr T eps = 4e-6;
+ bool skip_element = false;
+ int64_t hist_index = 0;
+ int64_t bin_seq_offset = 0;
+
+ for (size_t dim = 0; dim < num_dims; dim++) {
+ T element = input_[offsets[tid * num_dims + dim]];
+
+ // Skip elements that fall outside the specified bins, as well as NaN elements.
+ // An eps is added to the edges to avoid precision issues that would otherwise cause elements to be
+ // accidentally skipped, likely due to small implementation differences between the CPU and MPS linspace.
+ if (!(element >= (leftmost_edge[dim] - eps) && element <= (rightmost_edge[dim] + eps))) {
+ skip_element = true;
+ break;
+ }
+ int64_t pos = -1;
+
+ if (algorithm == BIN_SELECTION_ALGORITHM::BINARY_SEARCH) {
+ pos = upper_bound(
+ bin_seq,
+ bin_seq_offset,
+ num_bin_edges[dim],
+ element
+ ) - bin_seq_offset - 1;
+ } else if (
+ algorithm == BIN_SELECTION_ALGORITHM::LINEAR_INTERPOLATION ||
+ algorithm == BIN_SELECTION_ALGORITHM::LINEAR_INTERPOLATION_WITH_LOCAL_SEARCH) {
+ pos = static_cast<int64_t>((element - leftmost_edge[dim])
+ * (num_bin_edges[dim] - 1)
+ / (rightmost_edge[dim] - leftmost_edge[dim]));
+ if (algorithm == LINEAR_INTERPOLATION_WITH_LOCAL_SEARCH) {
+ int64_t pos_min = max(static_cast<int64_t>(0), pos - 1);
+ int64_t pos_max = min(pos + 2, num_bin_edges[dim]);
+ pos = upper_bound(
+ bin_seq,
+ bin_seq_offset + pos_min,
+ pos_max - pos_min,
+ element
+ ) - bin_seq_offset - 1;
+ }
+ }
+
+ if (pos == (num_bin_edges[dim] - 1)) {
+ pos -= 1;
+ }
+ hist_index += local_out_strides[dim + 1] * pos;
+ bin_seq_offset += num_bin_edges[dim];
+ }
+ if (!skip_element) {
+ // In the unweighted case, the default weight is 1
+ local_out[local_out_strides[0] * tid + hist_index] += has_weight ? weight[tid] : 1;
+ }
+}
+
+
+#define REGISTER_HISTOGRAMDD_OP(DTYPE) \
+template \
+[[host_name("histogramdd_" #DTYPE)]] \
+kernel void histogramdd<DTYPE>( \
+ constant DTYPE * input_ [[buffer(0)]], \
+ constant DTYPE * weight [[buffer(1)]], \
+ device DTYPE * local_out [[buffer(2)]], \
+ constant uint * offsets [[buffer(3)]], \
+ constant size_t & num_dims [[buffer(4)]], \
+ constant DTYPE * bin_seq [[buffer(5)]], \
+ constant int64_t * num_bin_edges [[buffer(6)]], \
+ constant DTYPE * leftmost_edge [[buffer(7)]], \
+ constant DTYPE * rightmost_edge [[buffer(8)]], \
+ constant int64_t * local_out_strides [[buffer(9)]], \
+ constant uint8_t & bin_selection_algorithm [[buffer(10)]], \
+ constant uint8_t & has_weight [[buffer(11)]], \
+ uint tid [[thread_position_in_grid]]);
+
+REGISTER_HISTOGRAMDD_OP(float);
+REGISTER_HISTOGRAMDD_OP(half);
+
+)HISTOGRAM_METAL";
+
+static id<MTLLibrary> compileHistogramOpLibrary(id<MTLDevice> device) {
+ static id<MTLLibrary> histogramLibrary = nil;
+ if (histogramLibrary) {
+ return histogramLibrary;
+ }
+
+ NSError* error = nil;
+ MTLCompileOptions* options = [[MTLCompileOptions new] autorelease];
+ [options setLanguageVersion:MTLLanguageVersion2_3];
+ histogramLibrary = [device newLibraryWithSource:[NSString stringWithCString:METAL_HISTOGRAM
+ encoding:NSASCIIStringEncoding]
+ options:options
+ error:&error];
+ TORCH_CHECK(histogramLibrary, "Failed to create Metal histogram library, error: ", [[error description] UTF8String]);
+ return histogramLibrary;
+}
+
+static id<MTLComputePipelineState> histogramPipelineState(id<MTLDevice> device, const std::string& kernel) {
+ static std::unordered_map<std::string, id<MTLComputePipelineState>> psoCache;
+ id<MTLComputePipelineState> pso = psoCache[kernel];
+ if (pso) {
+ return pso;
+ }
+
+ NSError* error = nil;
+ id<MTLLibrary> crossLib = compileHistogramOpLibrary(device);
+ id<MTLFunction> crossFunc = [crossLib newFunctionWithName:[NSString stringWithUTF8String:kernel.c_str()]];
+ TORCH_CHECK(crossFunc, "Failed to create function state object for: ", kernel);
+ pso = [device newComputePipelineStateWithFunction:crossFunc error:&error];
+ TORCH_CHECK(pso, "Failed to created pipeline state object, error: ", [[error description] UTF8String]);
+
+ psoCache[kernel] = pso;
+ return pso;
+}
+
+template <typename input_t, BIN_SELECTION_ALGORITHM algorithm>
+void histogramdd_kernel_impl(Tensor& hist_output,
+ const TensorList& bin_edges,
+ const Tensor& input,
+ const c10::optional<Tensor>& weight) {
+ TORCH_CHECK(input.dtype() != at::kDouble, "float64 is not supported on MPS");
+ TORCH_INTERNAL_ASSERT(input.dim() == 2);
+
+ constexpr uint8_t bin_selection_algorithm = algorithm;
+ const int64_t N = input.size(0);
+ const bool has_weight = weight.has_value();
+
+ if (has_weight) {
+ TORCH_CHECK(weight.value().is_contiguous(), "histogramdd(): weight should be contiguous on MPS");
+ TORCH_INTERNAL_ASSERT(weight.value().dim() == 1 && weight.value().numel() == N);
+ TORCH_INTERNAL_ASSERT(weight.value().scalar_type() == input.scalar_type());
+ }
+
+ const int64_t D = input.size(1);
+ size_t bin_edges_numel = 0;
+ TORCH_INTERNAL_ASSERT(int64_t(bin_edges.size()) == D);
+ for (const auto dim : c10::irange(D)) {
+ bin_edges_numel += bin_edges[dim].numel();
+ TORCH_INTERNAL_ASSERT(bin_edges[dim].is_contiguous());
+ TORCH_INTERNAL_ASSERT(hist_output.size(dim) + 1 == bin_edges[dim].numel());
+ }
+
+ if (D == 0) {
+ // hist is an empty tensor in this case; nothing to do here
+ return;
+ }
+
+ std::vector<input_t> bin_seq(bin_edges_numel);
+ std::vector<int64_t> num_bin_edges(D);
+ std::vector<input_t> leftmost_edge(D);
+ std::vector<input_t> rightmost_edge(D);
+ size_t bin_seq_offset = 0;
+
+ for (const auto dim : c10::irange(D)) {
+ for (const auto elem_idx : c10::irange(bin_edges[dim].numel())) {
+ bin_seq[bin_seq_offset + elem_idx] = (bin_edges[dim][elem_idx].item().to<input_t>());
+ }
+ num_bin_edges[dim] = bin_edges[dim].numel();
+ leftmost_edge[dim] = bin_seq[bin_seq_offset];
+ rightmost_edge[dim] = bin_seq[bin_seq_offset + num_bin_edges[dim] - 1];
+ bin_seq_offset += num_bin_edges[dim];
+ }
+
+ const uint32_t stridedIndicesNumThreads = input.numel();
+ const uint32_t numThreads = N;
+ const auto hist_sizes = hist_output.sizes();
+
+ DimVector thread_hist_sizes(hist_sizes.size() + 1); // [n_threads, output_sizes...]
+ thread_hist_sizes[0] = numThreads;
+ std::copy(hist_sizes.begin(), hist_sizes.end(), thread_hist_sizes.begin() + 1);
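+ // Each thread accumulates into its own slice of thread_histograms; the slices are summed
+ // into hist_output after the kernel runs, avoiding atomic adds in the shader.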
+ Tensor thread_histograms = at::zeros(
+ thread_hist_sizes, hist_output.scalar_type(), c10::nullopt /* layout */, kMPS, c10::nullopt /* pin_memory */
+ );
+ TORCH_INTERNAL_ASSERT(thread_histograms.is_contiguous());
+
+ id<MTLDevice> device = MPSDevice::getInstance()->device();
+ id<MTLBuffer> inputBuffer = getMTLBufferStorage(input);
+ id<MTLBuffer> outputBuffer = getMTLBufferStorage(thread_histograms);
+ id<MTLBuffer> weightBuffer =
+ has_weight ? getMTLBufferStorage(weight.value()) : [[device newBufferWithLength:0 options:0] autorelease];
+ size_t weightOffset = has_weight ? weight.value().storage_offset() * weight.value().element_size() : 0;
+ MPSStream* mpsStream = getCurrentMPSStream();
+ const uint32_t nDim = input.sizes().size();
+
+ dispatch_sync(mpsStream->queue(), ^() {
+ @autoreleasepool {
+ id<MTLCommandBuffer> commandBuffer = mpsStream->commandBuffer();
+ id<MTLComputeCommandEncoder> computeEncoder = [commandBuffer computeCommandEncoder];
+ MTLSize gridSize = MTLSizeMake(stridedIndicesNumThreads, 1, 1);
+ const IntArrayRef& inputShape = input.sizes();
+ std::vector<uint32_t> inputShapeData(inputShape.size());
+ std::vector<uint32_t> strides(input.strides().begin(), input.strides().end());
+
+ for (const auto i : c10::irange(inputShape.size())) {
+ TORCH_CHECK(inputShape[i] <= UINT32_MAX);
+ inputShapeData[i] = (uint32_t)(inputShape[i]);
+ }
+
+ id<MTLBuffer> stridedIndicesBuffer = [[device newBufferWithLength:stridedIndicesNumThreads * sizeof(uint)
+ options:0] autorelease];
+ id<MTLComputePipelineState> stridedIndicesPSO = MPSDevice::getInstance()->metalIndexingPSO("kernel_index_offset");
+
+ [computeEncoder setComputePipelineState:stridedIndicesPSO];
+ [computeEncoder setBytes:strides.data() length:sizeof(uint32_t) * nDim atIndex:0];
+ [computeEncoder setBuffer:stridedIndicesBuffer offset:0 atIndex:1];
+ [computeEncoder setBytes:inputShapeData.data() length:sizeof(uint32_t) * inputShape.size() atIndex:2];
+ [computeEncoder setBytes:&nDim length:sizeof(uint32_t) atIndex:3];
+
+ NSUInteger stridedIndicesTGSize = stridedIndicesPSO.maxTotalThreadsPerThreadgroup;
+ if (stridedIndicesTGSize > stridedIndicesNumThreads)
+ stridedIndicesTGSize = stridedIndicesNumThreads;
+
+ MTLSize stridedIndicesThreadGroupSize = MTLSizeMake(stridedIndicesTGSize, 1, 1);
+ [computeEncoder dispatchThreads:gridSize threadsPerThreadgroup:stridedIndicesThreadGroupSize];
+
+ const std::string kernel = "histogramdd_" + scalarToMetalTypeString(input.scalar_type());
+ id<MTLComputePipelineState> histogramPSO = histogramPipelineState(device, kernel);
+ [computeEncoder setComputePipelineState:histogramPSO];
+ [computeEncoder setBuffer:inputBuffer offset:input.storage_offset() * input.element_size() atIndex:0];
+ [computeEncoder setBuffer:weightBuffer offset:weightOffset atIndex:1];
+ [computeEncoder setBuffer:outputBuffer
+ offset:thread_histograms.storage_offset() * thread_histograms.element_size()
+ atIndex:2];
+ [computeEncoder setBuffer:stridedIndicesBuffer offset:0 atIndex:3];
+ [computeEncoder setBytes:&D length:sizeof(int64_t) atIndex:4];
+ [computeEncoder setBytes:bin_seq.data() length:sizeof(input_t) * bin_seq_offset atIndex:5];
+ [computeEncoder setBytes:num_bin_edges.data() length:sizeof(int64_t) * D atIndex:6];
+ [computeEncoder setBytes:leftmost_edge.data() length:sizeof(input_t) * D atIndex:7];
+ [computeEncoder setBytes:rightmost_edge.data() length:sizeof(input_t) * D atIndex:8];
+ [computeEncoder setBytes:thread_histograms.strides().data()
+ length:sizeof(int64_t) * thread_hist_sizes.size()
+ atIndex:9];
+ [computeEncoder setBytes:&bin_selection_algorithm length:sizeof(uint8_t) atIndex:10];
+ [computeEncoder setBytes:&has_weight length:sizeof(uint8_t) atIndex:11];
+
+ NSUInteger tgSize = histogramPSO.maxTotalThreadsPerThreadgroup;
+ if (tgSize > numThreads) {
+ tgSize = numThreads;
+ }
+ gridSize = MTLSizeMake(numThreads, 1, 1);
+ MTLSize threadGroupSize = MTLSizeMake(tgSize, 1, 1);
+ [computeEncoder dispatchThreads:gridSize threadsPerThreadgroup:threadGroupSize];
+
+ [computeEncoder endEncoding];
+ mpsStream->synchronize(SyncType::COMMIT);
+ }
+ });
+ at::sum_out(hist_output, thread_histograms, /*dim=*/{0});
+}
+
+template <BIN_SELECTION_ALGORITHM bin_algorithm>
+static void histogramdd_out_mps_template(const Tensor& self,
+ const c10::optional<Tensor>& weight,
+ bool density,
+ Tensor& hist,
+ const TensorList& bin_edges) {
+ hist.fill_(0);
+
+ const int64_t N = self.size(-1);
+ const int64_t M =
+ std::accumulate(self.sizes().begin(), self.sizes().end() - 1, (int64_t)1, std::multiplies<int64_t>());
+
+ const Tensor reshaped_input = self.reshape({M, N});
+
+ const auto reshaped_weight =
+ weight.has_value() ? c10::optional<Tensor>(weight.value().reshape({M})) : c10::optional<Tensor>();
+
+ std::vector<Tensor> bin_edges_contig(bin_edges.size());
+ for (const auto dim : c10::irange(bin_edges_contig.size())) {
+ bin_edges_contig[dim] = bin_edges[dim].contiguous();
+ }
+
+ AT_DISPATCH_FLOATING_TYPES(self.scalar_type(), "histogram_mps", [&]() {
+ mps::histogramdd_kernel_impl<scalar_t, bin_algorithm>(hist, bin_edges_contig, reshaped_input, reshaped_weight);
+ });
+
+ /* Divides each bin's value by the total count/weight in all bins,
+ * and by the bin's volume.
+ */
+ if (density) {
+ const auto hist_sum = hist.sum().item();
+ hist.div_(hist_sum);
+
+ /* For each dimension, divides each bin's value
+ * by the bin's length in that dimension.
+ */
+ for (const auto dim : c10::irange(N)) {
+ const auto bin_lengths = bin_edges[dim].diff();
+
+ // Used to reshape bin_lengths to align with the corresponding dimension of hist.
+ std::vector<int64_t> shape(N, 1);
+ shape[dim] = bin_lengths.numel();
+
+ hist.div_(bin_lengths.reshape(shape));
+ }
+ }
+}
+} // namespace mps
+
+static void histogramdd_kernel(const Tensor& self,
+ const c10::optional<Tensor>& weight,
+ bool density,
+ Tensor& hist,
+ const TensorList& bin_edges) {
+ mps::histogramdd_out_mps_template<mps::BINARY_SEARCH>(self, weight, density, hist, bin_edges);
+}
+
+static void histogramdd_linear_kernel(const Tensor& self,
+ const c10::optional<Tensor>& weight,
+ bool density,
+ Tensor& hist,
+ const TensorList& bin_edges,
+ bool local_search) {
+ if (local_search) {
+ // histogramdd codepath: both hist and bin_edges are eventually returned as output,
+ // so we'll keep them consistent
+ mps::histogramdd_out_mps_template<mps::LINEAR_INTERPOLATION_WITH_LOCAL_SEARCH>(
+ self, weight, density, hist, bin_edges);
+ } else {
+ // histc codepath: bin_edges are not returned to the caller
+ mps::histogramdd_out_mps_template<mps::LINEAR_INTERPOLATION>(self, weight, density, hist, bin_edges);
+ }
+}
+
+REGISTER_DISPATCH(histogramdd_stub, &histogramdd_kernel);
+REGISTER_DISPATCH(histogramdd_linear_stub, &histogramdd_linear_kernel);
+} // namespace at::native
diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml
index 3392b00..79190bd 100644
--- a/aten/src/ATen/native/native_functions.yaml
+++ b/aten/src/ATen/native/native_functions.yaml
@@ -9160,46 +9160,46 @@
- func: histc.out(Tensor self, int bins=100, Scalar min=0, Scalar max=0, *, Tensor(a!) out) -> Tensor(a!)
dispatch:
- CPU: histogram_histc_cpu_out
+ CPU, MPS: histogram_histc_out
CUDA: _histc_out_cuda
- func: histc(Tensor self, int bins=100, Scalar min=0, Scalar max=0) -> Tensor
variants: method, function
dispatch:
- CPU: histogram_histc_cpu
+ CPU, MPS: histogram_histc
CUDA: _histc_cuda
- func: histogram.bins_tensor_out(Tensor self, Tensor bins, *, Tensor? weight=None, bool density=False, Tensor(a!) hist, Tensor(b!) bin_edges) -> (Tensor(a!) hist, Tensor(b!) bin_edges)
dispatch:
- CPU: histogram_out_cpu
+ CPU, MPS: histogram_out
- func: histogram.bins_tensor(Tensor self, Tensor bins, *, Tensor? weight=None, bool density=False) -> (Tensor hist, Tensor bin_edges)
variants: method, function
dispatch:
- CPU: histogram_cpu
+ CPU, MPS: histogram
- func: histogram.bin_ct_out(Tensor self, int bins=100, *, float[]? range=None, Tensor? weight=None, bool density=False, Tensor(a!) hist, Tensor(b!) bin_edges) -> (Tensor(a!) hist, Tensor(b!) bin_edges)
dispatch:
- CPU: histogram_out_cpu
+ CPU, MPS: histogram_out
- func: histogram.bin_ct(Tensor self, int bins=100, *, float[]? range=None, Tensor? weight=None, bool density=False) -> (Tensor hist, Tensor bin_edges)
variants: method, function
dispatch:
- CPU: histogram_cpu
+ CPU, MPS: histogram
- func: _histogramdd_bin_edges(Tensor self, int[] bins, *, float[]? range=None, Tensor? weight=None, bool density=False) -> Tensor[]
dispatch:
- CPU: histogramdd_bin_edges_cpu
+ CPU, MPS: histogramdd_bin_edges
autogen: _histogramdd_bin_edges.out
- func: _histogramdd_from_bin_cts(Tensor self, int[] bins, *, float[]? range=None, Tensor? weight=None, bool density=False) -> Tensor
dispatch:
- CPU: histogramdd_cpu
+ CPU, MPS: _histogramdd
autogen: _histogramdd_from_bin_cts.out
- func: _histogramdd_from_bin_tensors(Tensor self, Tensor[] bins, *, Tensor? weight=None, bool density=False) -> Tensor
dispatch:
- CPU: histogramdd_cpu
+ CPU, MPS: _histogramdd
autogen: _histogramdd_from_bin_tensors.out
- func: histogramdd(Tensor self, int[] bins, float[]? range=None, Tensor? weight=None, bool density=False) -> (Tensor hist, Tensor[] bin_edges)
diff --git a/test/test_mps.py b/test/test_mps.py
index d8ff872..d6e0564 100644
--- a/test/test_mps.py
+++ b/test/test_mps.py
@@ -89,6 +89,10 @@
'floor_divide': [torch.float16, torch.float32],
# derivative for aten::narrow_copy is not implemented on CPU
'narrow_copy': [torch.float16, torch.float32],
+ # derivative for aten::_histogramdd_from_bin_cts is not implemented on CPU
+ 'histogramdd': [torch.float16, torch.float32],
+ # derivative for aten::histogram is not implemented
+ 'histogram': [torch.float16, torch.float32],
# 'bool' object is not iterable
'allclose': [torch.float16, torch.float32],
'equal': [torch.float16, torch.float32],
@@ -409,9 +413,6 @@
'geqrf': None,
'nn.functional.grid_sample': None, # Unsupported Border padding mode
'heaviside': None,
- 'histc': None,
- 'histogram': None,
- 'histogramdd': None,
'i0': None,
'igamma': None,
'igammac': None,
@@ -10244,11 +10245,11 @@
self.assertEqual(out, "")
def _get_not_implemented_op(self):
- # This can be changed once we actually implement `torch.histc`
+ # This can be changed once we actually implement `torch.lgamma`
# Should return fn, args, kwargs, string_version
- return (torch.histc,
+ return (torch.lgamma,
torch.tensor([100], device='mps'), {},
- "torch.histc(torch.tensor([4], device='mps', dtype=torch.float))")
+ "torch.lgamma(torch.tensor([4], device='mps', dtype=torch.float))")
def test_error_on_not_implemented(self):
fn, args, kwargs, _ = self._get_not_implemented_op()
diff --git a/torch/csrc/jit/runtime/static/generated_ops.cpp b/torch/csrc/jit/runtime/static/generated_ops.cpp
index 03e3098..fdb5b93 100644
--- a/torch/csrc/jit/runtime/static/generated_ops.cpp
+++ b/torch/csrc/jit/runtime/static/generated_ops.cpp
@@ -3005,13 +3005,12 @@
const auto min = p_node->Input(2).toScalar();
const auto max = p_node->Input(3).toScalar();
if (p_node->Output(0).isNone()) {
- p_node->Output(0) =
- at::native::histogram_histc_cpu(self, bins, min, max);
+ p_node->Output(0) = at::native::histogram_histc(self, bins, min, max);
return;
}
auto& out = p_node->Output(0).toTensor();
fastResizeToZero(out);
- at::native::histogram_histc_cpu_out(self, bins, min, max, out);
+ at::native::histogram_histc_out(self, bins, min, max, out);
};
}
LogAndDumpSchema(n);