[MPS] Fix correctness issues with Upsample 1D and 2D (#91669)
- Implemented the following new ops: upsample_nearest1d_backward
upsample_nearest_exact1d
upsample_nearest_exact1d_backward
- Moved Upsample code from Shape.mm to Upsample.mm
- Fallback to CPU for nearest mode on Monterey
Pull Request resolved: https://github.com/pytorch/pytorch/pull/91669
Approved by: https://github.com/malfet
diff --git a/aten/src/ATen/mps/MPSFallback.mm b/aten/src/ATen/mps/MPSFallback.mm
index e5dfde1..69dd47f 100644
--- a/aten/src/ATen/mps/MPSFallback.mm
+++ b/aten/src/ATen/mps/MPSFallback.mm
@@ -62,6 +62,7 @@
m.impl("linalg_vector_norm", torch::CppFunction::makeFromBoxedFunction<&mps_fallback>());
m.impl("sgn.out", torch::CppFunction::makeFromBoxedFunction<&mps_fallback>());
m.impl("_slow_conv2d_forward", slow_conv2d_forward_mps);
+ m.impl("upsample_nearest3d.vec", torch::CppFunction::makeFromBoxedFunction<&mps_fallback>());
}
} // namespace at
diff --git a/aten/src/ATen/native/UpSample.h b/aten/src/ATen/native/UpSample.h
index d4e8112..92ee725 100644
--- a/aten/src/ATen/native/UpSample.h
+++ b/aten/src/ATen/native/UpSample.h
@@ -56,7 +56,7 @@
inline c10::optional<double> get_scale_value(c10::optional<c10::ArrayRef<double>> scales, int idx) {
if (!scales) {
- return nullopt;
+ return c10::nullopt;
}
return scales->at(idx);
}
diff --git a/aten/src/ATen/native/mps/MPSGraphVenturaOps.h b/aten/src/ATen/native/mps/MPSGraphVenturaOps.h
index 71d6748..19434c0 100644
--- a/aten/src/ATen/native/mps/MPSGraphVenturaOps.h
+++ b/aten/src/ATen/native/mps/MPSGraphVenturaOps.h
@@ -3,18 +3,89 @@
// TODO: Remove me when moved to MacOS 13
@interface MPSGraph (VenturaOps)
-- (MPSGraphTensor *)cumulativeSumWithTensor:(MPSGraphTensor *)tensor
- axis:(NSInteger)axis
- name:(NSString *)name;
-- (MPSGraphTensor *)sortWithTensor:(MPSGraphTensor *)tensor
- axis:(NSInteger)axis
- name:(NSString *)name;
+#if !defined(__MAC_13_0) && \
+ (!defined(MAC_OS_X_VERSION_13_0) || (MAC_OS_X_VERSION_MIN_REQUIRED < MAC_OS_X_VERSION_13_0))
-- (MPSGraphTensor *)argSortWithTensor:(MPSGraphTensor *)tensor
- axis:(NSInteger)axis
- name:(NSString *)name;
+typedef NS_ENUM(NSUInteger, MPSGraphResizeNearestRoundingMode)
+{
+ MPSGraphResizeNearestRoundingModeRoundPreferCeil = 0L,
+ MPSGraphResizeNearestRoundingModeRoundPreferFloor = 1L,
+ MPSGraphResizeNearestRoundingModeCeil = 2L,
+ MPSGraphResizeNearestRoundingModeFloor = 3L,
+ MPSGraphResizeNearestRoundingModeRoundToEven = 4L,
+ MPSGraphResizeNearestRoundingModeRoundToOdd = 5L,
+};
+#endif
-- (MPSGraphTensor *)inverseOfTensor: (MPSGraphTensor *)tensor
- name:(NSString *)name;
-@end
+- (MPSGraphTensor * _Nonnull)cumulativeSumWithTensor:(MPSGraphTensor * _Nonnull)tensor
+ axis:(NSInteger)axis
+ name:(NSString * _Nullable)name;
+
+- (MPSGraphTensor * _Nonnull)sortWithTensor:(MPSGraphTensor * _Nonnull)tensor
+ axis:(NSInteger)axis
+ name:(NSString * _Nullable)name;
+
+- (MPSGraphTensor * _Nonnull)argSortWithTensor:(MPSGraphTensor * _Nonnull)tensor
+ axis:(NSInteger)axis
+ name:(NSString * _Nullable)name;
+
+- (MPSGraphTensor * _Nonnull)inverseOfTensor:(MPSGraphTensor * _Nonnull) inputTensor
+ name:(NSString * _Nullable)name;
+
+- (MPSGraphTensor * _Nonnull) resizeNearestWithTensor:(MPSGraphTensor * _Nonnull) imagesTensor
+ sizeTensor:(MPSGraphTensor * _Nonnull) size
+ nearestRoundingMode:(MPSGraphResizeNearestRoundingMode) nearestRoundingMode
+ centerResult:(BOOL) centerResult
+ alignCorners:(BOOL) alignCorners
+ layout:(MPSGraphTensorNamedDataLayout) layout
+ name:(NSString * _Nullable) name;
+
+- (MPSGraphTensor * _Nonnull) resizeNearestWithTensor:(MPSGraphTensor * _Nonnull) imagesTensor
+ sizeTensor:(MPSGraphTensor * _Nonnull) size
+ scaleOffsetTensor:(MPSGraphTensor * _Nonnull) scaleOffset
+ nearestRoundingMode:(MPSGraphResizeNearestRoundingMode) nearestRoundingMode
+ layout:(MPSGraphTensorNamedDataLayout) layout
+ name:(NSString * _Nullable) name;
+
+- (MPSGraphTensor * _Nonnull) resizeBilinearWithTensor:(MPSGraphTensor * _Nonnull) imagesTensor
+ sizeTensor:(MPSGraphTensor * _Nonnull) size
+ centerResult:(BOOL) centerResult
+ alignCorners:(BOOL) alignCorners
+ layout:(MPSGraphTensorNamedDataLayout) layout
+ name:(NSString * _Nullable) name;
+
+- (MPSGraphTensor * _Nonnull) resizeBilinearWithTensor:(MPSGraphTensor * _Nonnull) imagesTensor
+ sizeTensor:(MPSGraphTensor * _Nonnull) size
+ scaleOffsetTensor:(MPSGraphTensor * _Nonnull) scaleOffset
+ layout:(MPSGraphTensorNamedDataLayout) layout
+ name:(NSString * _Nullable) name;
+
+- (MPSGraphTensor * _Nonnull) resizeNearestWithGradientTensor:(MPSGraphTensor * _Nonnull) gradient
+ input:(MPSGraphTensor * _Nonnull) input
+ nearestRoundingMode:(MPSGraphResizeNearestRoundingMode) nearestRoundingMode
+ centerResult:(BOOL) centerResult
+ alignCorners:(BOOL) alignCorners
+ layout:(MPSGraphTensorNamedDataLayout) layout
+ name:(NSString * _Nullable) name;
+
+- (MPSGraphTensor * _Nonnull) resizeNearestWithGradientTensor:(MPSGraphTensor * _Nonnull) gradient
+ input:(MPSGraphTensor * _Nonnull) input
+ scaleOffsetTensor:(MPSGraphTensor * _Nonnull) scaleOffset
+ nearestRoundingMode:(MPSGraphResizeNearestRoundingMode) nearestRoundingMode
+ layout:(MPSGraphTensorNamedDataLayout) layout
+ name:(NSString * _Nullable) name;
+
+- (MPSGraphTensor * _Nonnull) resizeBilinearWithGradientTensor:(MPSGraphTensor * _Nonnull) gradient
+ input:(MPSGraphTensor * _Nonnull) input
+ centerResult:(BOOL) centerResult
+ alignCorners:(BOOL) alignCorners
+ layout:(MPSGraphTensorNamedDataLayout) layout
+ name:(NSString * _Nullable) name;
+
+- (MPSGraphTensor * _Nonnull) resizeBilinearWithGradientTensor:(MPSGraphTensor * _Nonnull) gradient
+ input:(MPSGraphTensor * _Nonnull) input
+ scaleOffsetTensor:(MPSGraphTensor * _Nonnull) scaleOffset
+ layout:(MPSGraphTensorNamedDataLayout) layout
+ name:(NSString * _Nullable) name;
+@end
\ No newline at end of file
diff --git a/aten/src/ATen/native/mps/operations/Shape.mm b/aten/src/ATen/native/mps/operations/Shape.mm
index f491f2f..3190991 100644
--- a/aten/src/ATen/native/mps/operations/Shape.mm
+++ b/aten/src/ATen/native/mps/operations/Shape.mm
@@ -1,18 +1,10 @@
// Copyright © 2022 Apple Inc.
-#include <ATen/ATen.h>
#include <ATen/MemoryOverlap.h>
-#include <ATen/Tensor.h>
-#include <ATen/TensorUtils.h>
-#include <ATen/Utils.h>
#include <ATen/WrapDimUtils.h>
-#include <ATen/mps/MPSStream.h>
#include <ATen/native/TypeProperties.h>
#include <ATen/native/TensorShape.h>
#include <ATen/native/mps/OperationUtils.h>
-#include <c10/core/MemoryFormat.h>
-#include <c10/util/Optional.h>
-#include <torch/library.h>
namespace at {
namespace native {
@@ -460,307 +452,5 @@
}
-void upsample_backward_out_mps(const Tensor& grad_output,
- IntArrayRef output_size,
- IntArrayRef input_size,
- c10::optional<double> scales_h,
- c10::optional<double> scales_w,
- const Tensor& grad_input,
- MPSGraphResizeMode requested_mode,
- bool requested_align_corners
- )
-{
- using namespace mps;
- int64_t input_dims = input_size.size();
-
- TORCH_CHECK((input_dims == 4),
- "NCHW tensor expected for input");
-
- struct CachedGraph : public MPSCachedGraph {
- CachedGraph(MPSGraph *graph) : MPSCachedGraph(graph) {}
- MPSGraphTensor *gradInputTensor = nil, *gradOutputTensor = nil;
- };
- MPSGraphCache* cache_ = MPSGraphCache::getInstance();
- /* sizes */
- int64_t output_height = output_size[0];
- int64_t output_width = output_size[1];
-
- int64_t input_n = input_size[0];
- int64_t input_c = input_size[1];
- int64_t input_height = input_size[2];
- int64_t input_width = input_size[3];
-
- @autoreleasepool {
- MPSShape* output_shape = getMPSShape(grad_output);
- string key = string("upsample_backward:") + mps::getMPSShapeString(output_shape) + ":" +
- getMPSTypeString(grad_output.scalar_type()) +
- ":oh" + to_string(output_height) + ":ow" + to_string(output_width) +
- ":ih" + to_string(input_height) + ":iw" + to_string(input_width) +
- ":mode" + to_string(requested_mode);
-
- CachedGraph* cachedGraph = static_cast<CachedGraph *>(cache_->LookUp(key));
- if(!cachedGraph) {
- cachedGraph = static_cast<CachedGraph*>(cache_->CreateCachedGraph(key, ^ MPSCachedGraph * () {
-
- CachedGraph *newCachedGraph = nil;
- @autoreleasepool {
- MPSGraph* mpsGraph = make_mps_graph();
- newCachedGraph = new CachedGraph(mpsGraph);
-
- newCachedGraph->gradOutputTensor = mpsGraphRankedPlaceHolder(mpsGraph, getMPSDataType(grad_input.scalar_type()), output_shape);
- MPSGraphTensor * shapeTensor = [mpsGraph constantWithScalar:0
- shape:@[[NSNumber numberWithLong: input_n],
- [NSNumber numberWithLong: input_c],
- [NSNumber numberWithLong:input_height],
- [NSNumber numberWithLong:input_width]]
- dataType:getMPSDataType(grad_output.scalar_type())];
-
- newCachedGraph->gradInputTensor = [mpsGraph resizeWithGradientTensor: newCachedGraph->gradOutputTensor
- input: shapeTensor
- mode: requested_mode
- centerResult: true
- alignCorners: requested_align_corners
- layout: MPSGraphTensorNamedDataLayoutNCHW
- name: nil];
-
- }
- return newCachedGraph;
- }));
- }
- Placeholder gradOutputPlaceholder = Placeholder(cachedGraph->gradOutputTensor, grad_output);
- Placeholder gradInputPlaceholder = Placeholder(cachedGraph->gradInputTensor, grad_input);
-
- NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* feeds = @{
- gradOutputPlaceholder.getMPSGraphTensor() : gradOutputPlaceholder.getMPSGraphTensorData(),
- };
- NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* results = @{
- gradInputPlaceholder.getMPSGraphTensor() : gradInputPlaceholder.getMPSGraphTensorData()
- };
- runMPSGraph(getCurrentMPSStream(), cachedGraph->graph(), feeds, results);
- }
-}
-
-TORCH_IMPL_FUNC(_upsample_nearest_exact2d_backward_out_mps) (
- const Tensor& grad_output,
- IntArrayRef output_size,
- IntArrayRef input_size,
- c10::optional<double> scales_h,
- c10::optional<double> scales_w,
- const Tensor& grad_input)
-{
- upsample_backward_out_mps(grad_output, output_size, input_size, scales_h, scales_w, grad_input, MPSGraphResizeNearest, false);
-}
-
-TORCH_IMPL_FUNC(upsample_nearest2d_backward_out_mps) (
- const Tensor& grad_output,
- IntArrayRef output_size,
- IntArrayRef input_size,
- c10::optional<double> scales_h,
- c10::optional<double> scales_w,
- const Tensor& grad_input)
-{
- upsample_backward_out_mps(grad_output, output_size, input_size, scales_h, scales_w, grad_input, MPSGraphResizeNearest, false);
-}
-
-TORCH_IMPL_FUNC(upsample_bilinear2d_backward_out_mps) (
- const Tensor& grad_output,
- IntArrayRef output_size,
- IntArrayRef input_size,
- bool align_corners,
- c10::optional<double> scales_h,
- c10::optional<double> scales_w,
- const Tensor& grad_input)
-{
- upsample_backward_out_mps(grad_output, output_size, input_size, scales_h, scales_w, grad_input, MPSGraphResizeBilinear, align_corners);
-}
-
-void upsample_out_mps(const Tensor& input,
- IntArrayRef output_size,
- c10::optional<double> scales_h,
- c10::optional<double> scales_w,
- const Tensor& output,
- MPSGraphResizeMode requested_mode,
- bool requested_align_corners)
-{
- // Get stream
- using namespace mps;
- struct CachedGraph : public MPSCachedGraph {
- CachedGraph(MPSGraph *graph) : MPSCachedGraph(graph) {}
- MPSGraphTensor *inputTensor = nil, *outputTensor = nil;
- };
- MPSGraphCache* cache_ = MPSGraphCache::getInstance();
-
- /* sizes */
- int64_t output_height = output_size[0];
- int64_t output_width = output_size[1];
- @autoreleasepool {
- MPSShape* input_shape = getMPSShape(input);
- string key = string("upsample_2d:") + mps::getMPSShapeString(input_shape) + ":" +
- getMPSTypeString(input.scalar_type()) +
- ":h" + to_string(output_height) + ":w" + to_string(output_width) +
- ":mode" + to_string(requested_mode);
-
- CachedGraph* cachedGraph = static_cast<CachedGraph *>(cache_->LookUp(key));
- if(!cachedGraph) {
- cachedGraph = static_cast<CachedGraph*>(cache_->CreateCachedGraph(key, ^ MPSCachedGraph * () {
-
- CachedGraph *newCachedGraph = nil;
-
- @autoreleasepool {
- MPSGraph* mpsGraph = make_mps_graph();
- newCachedGraph = new CachedGraph(mpsGraph);
-
- newCachedGraph->inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, getMPSDataType(input.scalar_type()), input_shape);
- newCachedGraph->outputTensor = [mpsGraph resizeTensor:newCachedGraph->inputTensor
- size:@[ @(output_height), @(output_width)]
- mode:requested_mode
- centerResult: true
- alignCorners: requested_align_corners
- layout: MPSGraphTensorNamedDataLayoutNCHW
- name:nil];
- }
- return newCachedGraph;
- }));
- }
- Placeholder inputPlaceholder = Placeholder(cachedGraph->inputTensor, input);
- Placeholder outputPlaceholder = Placeholder(cachedGraph->outputTensor, output);
-
- NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* feeds = @{
- inputPlaceholder.getMPSGraphTensor() : inputPlaceholder.getMPSGraphTensorData(),
- };
- NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* results = @{
- outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData()
- };
- runMPSGraph(getCurrentMPSStream(), cachedGraph->graph(), feeds, results);
- }
-}
-
-TORCH_IMPL_FUNC(_upsample_nearest_exact2d_out_mps) (
- const Tensor& input,
- IntArrayRef output_size,
- c10::optional<double> scales_h,
- c10::optional<double> scales_w,
- const Tensor& output)
-{
- // Note: this differs from the CPU implementation in the way
- // ties are resolved wrt to nearest mostly in cases where the scale
- // is not an integer.
- // Example:
- // For upsampling from (2, 5) to (2, 16)
- // MPS:
- // tensor([[[[0., 0., 0., 0., 1., 1., 1., 2., 2., 2., 3., 3., 3., 4., 4., 4.],
- // [5., 5., 5., 5., 6., 6., 6., 7., 7., 7., 8., 8., 8., 9., 9., 9.]]]])
- // CPU:
- // tensor([[[[0., 0., 0., 1., 1., 1., 2., 2., 2., 2., 3., 3., 3., 4., 4., 4.],
- // [5., 5., 5., 6., 6., 6., 7., 7., 7., 7., 8., 8., 8., 9., 9., 9.]]]])
- using namespace mps;
- upsample_out_mps(input, output_size, scales_h, scales_w, output, MPSGraphResizeNearest, false);
-}
-
-
-TORCH_IMPL_FUNC(upsample_nearest2d_out_mps) (
- const Tensor& input,
- IntArrayRef output_size,
- c10::optional<double> scales_h,
- c10::optional<double> scales_w,
- const Tensor& output)
-{
- // Note: this differs from the CPU implementation in the way
- // ties are resolved wrt to nearest mostly in cases where the scale
- // is not an integer.
- // Example:
- // For upsampling from (2, 5) to (2, 16)
- // MPS:
- // tensor([[[[0., 0., 0., 0., 1., 1., 1., 2., 2., 2., 3., 3., 3., 4., 4., 4.],
- // [5., 5., 5., 5., 6., 6., 6., 7., 7., 7., 8., 8., 8., 9., 9., 9.]]]])
- // CPU:
- // tensor([[[[0., 0., 0., 1., 1., 1., 2., 2., 2., 2., 3., 3., 3., 4., 4., 4.],
- // [5., 5., 5., 6., 6., 6., 7., 7., 7., 7., 8., 8., 8., 9., 9., 9.]]]])
- using namespace mps;
- upsample_out_mps(input, output_size, scales_h, scales_w, output, MPSGraphResizeNearest, false);
-}
-
-TORCH_IMPL_FUNC(upsample_bilinear2d_out_mps) (
- const Tensor& input,
- IntArrayRef output_size,
- bool align_corners,
- c10::optional<double> scales_h,
- c10::optional<double> scales_w,
- const Tensor& output)
-{
- using namespace mps;
- upsample_out_mps(input, output_size, scales_h, scales_w, output, MPSGraphResizeBilinear, align_corners);
-}
-
-void upsample1d_out_mps(const Tensor& input,
- IntArrayRef output_size,
- c10::optional<double> scales,
- const Tensor& output,
- MPSGraphResizeMode requested_mode)
-{
- // Get stream
- using namespace mps;
- using CachedGraph = MPSUnaryCachedGraph;
- MPSGraphCache* cache_ = MPSGraphCache::getInstance();
-
- /* sizes */
- int64_t out_size = output_size[0];
- @autoreleasepool {
- MPSShape* input_shape = getMPSShape(input);
- string key = string("upsample_1d:") + mps::getMPSShapeString(input_shape) + ":" +
- getMPSTypeString(input.scalar_type()) +
- ":size" + to_string(out_size) +
- ":mode" + to_string(requested_mode);
-
- CachedGraph* cachedGraph = cache_->LookUpAs<CachedGraph>(key);
- if(!cachedGraph) {
- cachedGraph = static_cast<CachedGraph*>(cache_->CreateCachedGraph(key, ^ MPSCachedGraph * () {
-
- CachedGraph *newCachedGraph = nil;
-
- @autoreleasepool {
- MPSGraph* mpsGraph = make_mps_graph();
- newCachedGraph = new CachedGraph(mpsGraph);
-
- newCachedGraph->inputTensor_ = mpsGraphRankedPlaceHolder(mpsGraph, getMPSDataType(input.scalar_type()), input_shape);
- newCachedGraph->outputTensor_ = [mpsGraph resizeTensor:newCachedGraph->inputTensor_
- size:@[ @(out_size), @(1)]
- mode:requested_mode
- centerResult: true
- alignCorners: true
- layout: MPSGraphTensorNamedDataLayoutCHW
- name:nil];
- }
- return newCachedGraph;
- }));
- }
- Placeholder inputPlaceholder = Placeholder(cachedGraph->inputTensor_, input);
- Placeholder outputPlaceholder = Placeholder(cachedGraph->outputTensor_, output);
-
- NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* feeds = @{
- inputPlaceholder.getMPSGraphTensor() : inputPlaceholder.getMPSGraphTensorData(),
- };
- NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* results = @{
- outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData()
- };
- runMPSGraph(getCurrentMPSStream(), cachedGraph->graph(), feeds, results);
- }
-}
-
-
-TORCH_IMPL_FUNC(upsample_nearest1d_out_mps) (
- const Tensor& input,
- IntArrayRef output_size,
- c10::optional<double> scales,
- const Tensor& output)
-{
- using namespace mps;
- upsample1d_out_mps(input, output_size, scales, output, MPSGraphResizeNearest);
-}
-
-
-
-
-
} // namespace native
-} // namespace at
+} // namespace at
\ No newline at end of file
diff --git a/aten/src/ATen/native/mps/operations/UpSample.mm b/aten/src/ATen/native/mps/operations/UpSample.mm
new file mode 100644
index 0000000..0f05582
--- /dev/null
+++ b/aten/src/ATen/native/mps/operations/UpSample.mm
@@ -0,0 +1,393 @@
+// Copyright © 2023 Apple Inc.
+
+#include <ATen/native/mps/OperationUtils.h>
+#include <ATen/native/mps/MPSGraphVenturaOps.h>
+#include <ATen/native/UpSample.h>
+
+namespace at {
+namespace native {
+namespace mps {
+
+// Upsampling operations (1D/2D forward and backward)
+// supported resize_mode: 'nearest' | 'bilinear' | 'nearest-exact'
+void upsample_out_template(const Tensor& input,
+ IntArrayRef output_size,
+ c10::optional<IntArrayRef> input_size_opt, // only used for backward pass
+ c10::optional<double> scale_h_opt,
+ c10::optional<double> scale_w_opt,
+ const Tensor& output,
+ bool align_corners,
+ const c10::string_view resize_mode_str) {
+ if (input.numel() == 0) {
+ return;
+ }
+ const auto input_dim = input.sizes();
+ if (input_dim.size() <= 3) {
+ native::upsample_1d_common_check(input.sizes(), output_size);
+ } else {
+ native::upsample_2d_common_check(input.sizes(), output_size);
+ }
+ bool centerResults = false;
+ MPSGraphResizeMode resizeMode = MPSGraphResizeNearest;
+ MPSGraphResizeNearestRoundingMode nearestRoundingMode = MPSGraphResizeNearestRoundingModeFloor;
+ MPSGraphTensorNamedDataLayout dataLayout = input_dim.size() > 3 ?
+ MPSGraphTensorNamedDataLayoutNCHW :
+ MPSGraphTensorNamedDataLayoutCHW;
+ if (resize_mode_str == "nearest") {
+ resizeMode = MPSGraphResizeNearest;
+ } else if (resize_mode_str == "bilinear") {
+ resizeMode = MPSGraphResizeBilinear;
+ centerResults = true;
+ } else if (resize_mode_str == "nearest-exact") {
+ centerResults = true;
+ nearestRoundingMode = MPSGraphResizeNearestRoundingModeRoundPreferCeil;
+ } else {
+ AT_ERROR("Unsupported resize mode ", resize_mode_str);
+ }
+
+ const bool is_macOS_13_0_or_newer = is_macos_13_or_newer();
+ const int64_t output_width = output_size.size() > 1 ? output_size[1] : output_size[0];
+ const int64_t output_height = output_size.size() > 1 ? output_size[0] : 1;
+ const float scale_w = (scale_w_opt.value_or(0.) > 0.) ? static_cast<float>(scale_w_opt.value()) : 0.;
+ const float scale_h = (scale_h_opt.value_or(0.) > 0.) ? static_cast<float>(scale_h_opt.value()) : 1.;
+ const float offset_y = centerResults ? (scale_h - 1.0f) / 2.0f : 0.0f;
+ const float offset_x = centerResults ? (scale_w - 1.0f) / 2.0f : 0.0f;
+
+ IntArrayRef input_size;
+ const bool is_backward_pass = input_size_opt.has_value();
+ if (is_backward_pass) {
+ input_size = input_size_opt.value();
+ }
+ struct CachedGraph : public MPSCachedGraph {
+ CachedGraph(MPSGraph *graph) : MPSCachedGraph(graph) {}
+ MPSGraphTensor *inputTensor = nil, *outputTensor = nil;
+ MPSGraphTensor *outputSizeTensor = nil;
+ };
+ MPSStream* stream = getCurrentMPSStream();
+
+ @autoreleasepool {
+ string key = "upsample_" + std::string(resize_mode_str) + (align_corners ? "_aligned_corners" : "") +
+ getTensorsStringKey({input}) + ":[" + to_string(scale_h) + "," + to_string(scale_w) + "]:[" +
+ (is_backward_pass ? getArrayRefString(input_size) : "Undefined") + "]";
+
+ MPSGraphCache* cache_ = MPSGraphCache::getInstance();
+ CachedGraph* cachedGraph = cache_->LookUpAs<CachedGraph>(key);
+ if(!cachedGraph) {
+ cachedGraph = cache_->CreateCachedGraphAs<CachedGraph>(key, ^ MPSCachedGraph * () {
+ CachedGraph *newCachedGraph = nil;
+ @autoreleasepool {
+ MPSGraph* mpsGraph = make_mps_graph();
+ newCachedGraph = new CachedGraph(mpsGraph);
+
+ newCachedGraph->inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, input);
+ newCachedGraph->outputSizeTensor = mpsGraphRankedPlaceHolder(mpsGraph, MPSDataTypeInt32, @[@(2)]);
+
+ MPSGraphTensor* scaleOffsetTensor = nullptr;
+ MPSGraphTensor* inputSizeTensor = nullptr;
+
+ if (scale_w > 0.0) {
+ const float outScales[4] = {scale_h, scale_w, offset_y, offset_x};
+ scaleOffsetTensor = [mpsGraph constantWithData: [NSData dataWithBytes: outScales length: sizeof(outScales)]
+ shape: @[@4]
+ dataType: MPSDataTypeFloat32];
+ }
+ if (is_backward_pass) {
+ std::vector<NSNumber*> inputSizeVec(4);
+ inputSizeVec[0] = @(input_size[0]);
+ inputSizeVec[1] = @(input_size[1]);
+ inputSizeVec[2] = @(input_size[2]);
+ inputSizeVec[3] = @(input_dim.size() > 3 ? input_size[3] : 1);
+ inputSizeTensor = [mpsGraph constantWithScalar: 0
+ shape: [NSArray arrayWithObjects:inputSizeVec.data() count:input_dim.size()]
+ dataType: getMPSDataType(input.scalar_type())];
+ }
+ if (is_macOS_13_0_or_newer) {
+ if (!is_backward_pass) {
+ if (scaleOffsetTensor && !align_corners) {
+ if (resizeMode == MPSGraphResizeNearest) {
+ newCachedGraph->outputTensor = [mpsGraph resizeNearestWithTensor: newCachedGraph->inputTensor
+ sizeTensor: newCachedGraph->outputSizeTensor
+ scaleOffsetTensor: scaleOffsetTensor
+ nearestRoundingMode: nearestRoundingMode
+ layout: dataLayout
+ name: nil];
+ } else { // bilinear forward
+ newCachedGraph->outputTensor = [mpsGraph resizeBilinearWithTensor: newCachedGraph->inputTensor
+ sizeTensor: newCachedGraph->outputSizeTensor
+ scaleOffsetTensor: scaleOffsetTensor
+ layout: dataLayout
+ name: nil];
+ }
+ } else { // scaleOffsetTensor == nil || align_corners
+ if (resizeMode == MPSGraphResizeNearest) {
+ newCachedGraph->outputTensor = [mpsGraph resizeNearestWithTensor: newCachedGraph->inputTensor
+ sizeTensor: newCachedGraph->outputSizeTensor
+ nearestRoundingMode: nearestRoundingMode
+ centerResult: centerResults
+ alignCorners: align_corners
+ layout: dataLayout
+ name: nil];
+ } else { // bilinear forward
+ newCachedGraph->outputTensor = [mpsGraph resizeBilinearWithTensor: newCachedGraph->inputTensor
+ sizeTensor: newCachedGraph->outputSizeTensor
+ centerResult: centerResults
+ alignCorners: align_corners
+ layout: dataLayout
+ name: nil];
+ }
+ }
+ } else { // is_backward_pass == true
+ if (scaleOffsetTensor && !align_corners) {
+ if (resizeMode == MPSGraphResizeNearest) {
+ newCachedGraph->outputTensor = [mpsGraph resizeNearestWithGradientTensor: newCachedGraph->inputTensor
+ input: inputSizeTensor
+ scaleOffsetTensor: scaleOffsetTensor
+ nearestRoundingMode: nearestRoundingMode
+ layout: dataLayout
+ name: nil];
+ } else { // bilinear backward
+ newCachedGraph->outputTensor = [mpsGraph resizeBilinearWithGradientTensor: newCachedGraph->inputTensor
+ input: inputSizeTensor
+ scaleOffsetTensor: scaleOffsetTensor
+ layout: dataLayout
+ name: nil];
+ }
+ } else { // scaleOffsetTensor == nil || align_corners
+ if (resizeMode == MPSGraphResizeNearest) {
+ newCachedGraph->outputTensor = [mpsGraph resizeNearestWithGradientTensor: newCachedGraph->inputTensor
+ input: inputSizeTensor
+ nearestRoundingMode: nearestRoundingMode
+ centerResult: centerResults
+ alignCorners: align_corners
+ layout: dataLayout
+ name: nil];
+ } else { // bilinear backward
+ newCachedGraph->outputTensor = [mpsGraph resizeBilinearWithGradientTensor: newCachedGraph->inputTensor
+ input: inputSizeTensor
+ centerResult: centerResults
+ alignCorners: align_corners
+ layout: dataLayout
+ name: nil];
+ }
+ }
+ }
+ } else { // if macOS version < 13.0 (for backwards compatibility)
+ if (!is_backward_pass) {
+ newCachedGraph->outputTensor = [mpsGraph resizeTensor: newCachedGraph->inputTensor
+ sizeTensor: newCachedGraph->outputSizeTensor
+ mode: resizeMode
+ centerResult: centerResults
+ alignCorners: align_corners
+ layout: dataLayout
+ name: nil];
+ } else {
+ newCachedGraph->outputTensor = [mpsGraph resizeWithGradientTensor: newCachedGraph->inputTensor
+ input: inputSizeTensor
+ mode: resizeMode
+ centerResult: centerResults
+ alignCorners: align_corners
+ layout: dataLayout
+ name: nil];
+ }
+ }
+ }
+ return newCachedGraph;
+ });
+ }
+ MPSNDArrayDescriptor *sizeDesc = [MPSNDArrayDescriptor descriptorWithDataType: MPSDataTypeInt32 shape: @[@(2)]];
+ MPSNDArray *sizeNDArray = [[[MPSNDArray alloc] initWithDevice: stream->device() descriptor: sizeDesc] autorelease];
+ [sizeNDArray writeBytes: (int32_t[]) {(int32_t)output_height, (int32_t)output_width} strideBytes: nil];
+ MPSGraphTensorData* sizeTensorData = [[[MPSGraphTensorData alloc] initWithMPSNDArray: sizeNDArray] autorelease];
+
+ Placeholder inputPlaceholder = Placeholder(cachedGraph->inputTensor, input);
+ Placeholder outputPlaceholder = Placeholder(cachedGraph->outputTensor, output);
+
+ NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* feeds = @{
+ inputPlaceholder.getMPSGraphTensor() : inputPlaceholder.getMPSGraphTensorData(),
+ cachedGraph->outputSizeTensor : sizeTensorData,
+ };
+ NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* results = @{
+ outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData()
+ };
+ runMPSGraph(stream, cachedGraph->graph(), feeds, results);
+ }
+}
+
+} // namespace mps
+
+static bool check_mps_compatibility(const c10::string_view resize_mode_str, c10::optional<double> scale)
+{
+ static const bool is_macOS_13_0_or_newer = is_macos_13_or_newer();
+ if (!is_macOS_13_0_or_newer) {
+ // passing scale factors to MPS's resize APIs is not supported on macOS < 13
+ if (scale.value_or(0.) > 0.) {
+ TORCH_WARN_ONCE("MPS: passing scale factor to upsample ops is supported natively starting from macOS 13.0. ",
+ "Falling back on CPU. This may have performance implications.");
+ return false;
+ // nearest mode on Monterey uses round() to compute source indices which
+ // is incompatible with PyTorch that uses floor(). So we fallback to CPU on Monterey.
+ // The nearest mode should work fine on Ventura.
+ } else if (resize_mode_str == "nearest" || resize_mode_str == "nearest-exact") {
+ TORCH_WARN_ONCE("MPS: '", resize_mode_str, "' mode upsampling is supported natively starting from macOS 13.0. ",
+ "Falling back on CPU. This may have performance implications.");
+ return false;
+ }
+ }
+ return true;
+}
+
+TORCH_IMPL_FUNC(upsample_nearest1d_out_mps) (
+ const Tensor& input,
+ IntArrayRef output_size,
+ c10::optional<double> scale,
+ const Tensor& output)
+{
+ if (check_mps_compatibility("nearest", scale)) {
+ mps::upsample_out_template(input, output_size, c10::nullopt, c10::nullopt, scale, output, false, "nearest");
+ } else {
+ // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
+ const_cast<Tensor&>(output) = at::upsample_nearest1d(input.to("cpu"), output_size, scale).clone().to("mps");
+ }
+}
+
+TORCH_IMPL_FUNC(upsample_nearest1d_backward_out_mps) (
+ const Tensor& grad_output,
+ IntArrayRef output_size,
+ IntArrayRef input_size,
+ c10::optional<double> scale,
+ const Tensor& grad_input)
+{
+ if (check_mps_compatibility("nearest", scale)) {
+ mps::upsample_out_template(grad_output, output_size, input_size, c10::nullopt, scale, grad_input, false, "nearest");
+ } else {
+ // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
+ const_cast<Tensor&>(grad_input) = at::upsample_nearest1d_backward(grad_output.to("cpu"), output_size, input_size, scale).clone().to("mps");
+ }
+}
+
+TORCH_IMPL_FUNC(_upsample_nearest_exact1d_out_mps) (
+ const Tensor& input,
+ IntArrayRef output_size,
+ c10::optional<double> scale,
+ const Tensor& output)
+{
+ if (check_mps_compatibility("nearest-exact", scale)) {
+ mps::upsample_out_template(input, output_size, c10::nullopt, c10::nullopt, scale, output, false, "nearest-exact");
+ } else {
+ // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
+ const_cast<Tensor&>(output) = at::_upsample_nearest_exact1d(input.to("cpu"), output_size, scale).clone().to("mps");
+ }
+}
+
+TORCH_IMPL_FUNC(_upsample_nearest_exact1d_backward_out_mps) (
+ const Tensor& grad_output,
+ IntArrayRef output_size,
+ IntArrayRef input_size,
+ c10::optional<double> scale,
+ const Tensor& grad_input)
+{
+ if (check_mps_compatibility("nearest-exact", scale)) {
+ mps::upsample_out_template(grad_output, output_size, input_size, c10::nullopt, scale, grad_input, false, "nearest-exact");
+ } else {
+ // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
+ const_cast<Tensor&>(grad_input) = at::_upsample_nearest_exact1d_backward(grad_output.to("cpu"), output_size, input_size, scale).clone().to("mps");
+ }
+}
+
+TORCH_IMPL_FUNC(upsample_nearest2d_out_mps) (
+ const Tensor& input,
+ IntArrayRef output_size,
+ c10::optional<double> scales_h,
+ c10::optional<double> scales_w,
+ const Tensor& output)
+{
+ if (check_mps_compatibility("nearest", scales_w)) {
+ mps::upsample_out_template(input, output_size, c10::nullopt, scales_h, scales_w, output, false, "nearest");
+ } else {
+ // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
+ const_cast<Tensor&>(output) = at::upsample_nearest2d(input.to("cpu"), output_size, scales_h, scales_w).clone().to("mps");
+ }
+}
+
+TORCH_IMPL_FUNC(upsample_nearest2d_backward_out_mps) (
+ const Tensor& grad_output,
+ IntArrayRef output_size,
+ IntArrayRef input_size,
+ c10::optional<double> scales_h,
+ c10::optional<double> scales_w,
+ const Tensor& grad_input)
+{
+ if (check_mps_compatibility("nearest", scales_w)) {
+ mps::upsample_out_template(grad_output, output_size, input_size, scales_h, scales_w, grad_input, false, "nearest");
+ } else {
+ // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
+ const_cast<Tensor&>(grad_input) = at::upsample_nearest2d_backward(grad_output.to("cpu"), output_size, input_size, scales_h, scales_w).clone().to("mps");
+ }
+}
+
+TORCH_IMPL_FUNC(_upsample_nearest_exact2d_out_mps) (
+ const Tensor& input,
+ IntArrayRef output_size,
+ c10::optional<double> scales_h,
+ c10::optional<double> scales_w,
+ const Tensor& output)
+{
+ if (check_mps_compatibility("nearest-exact", scales_w)) {
+ mps::upsample_out_template(input, output_size, c10::nullopt, scales_h, scales_w, output, false, "nearest-exact");
+ } else {
+ // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
+ const_cast<Tensor&>(output) = at::_upsample_nearest_exact2d(input.to("cpu"), output_size, scales_h, scales_w).clone().to("mps");
+ }
+}
+
+TORCH_IMPL_FUNC(_upsample_nearest_exact2d_backward_out_mps) (
+ const Tensor& grad_output,
+ IntArrayRef output_size,
+ IntArrayRef input_size,
+ c10::optional<double> scales_h,
+ c10::optional<double> scales_w,
+ const Tensor& grad_input)
+{
+ if (check_mps_compatibility("nearest-exact", scales_w)) {
+ mps::upsample_out_template(grad_output, output_size, input_size, scales_h, scales_w, grad_input, false, "nearest-exact");
+ } else {
+ // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
+ const_cast<Tensor&>(grad_input) = at::_upsample_nearest_exact2d_backward(grad_output.to("cpu"), output_size, input_size, scales_h, scales_w).clone().to("mps");
+ }
+}
+
+TORCH_IMPL_FUNC(upsample_bilinear2d_out_mps) (
+ const Tensor& input,
+ IntArrayRef output_size,
+ bool align_corners,
+ c10::optional<double> scales_h,
+ c10::optional<double> scales_w,
+ const Tensor& output)
+{
+ if (check_mps_compatibility("bilinear", scales_w)) {
+ mps::upsample_out_template(input, output_size, c10::nullopt, scales_h, scales_w, output, align_corners, "bilinear");
+ } else {
+ // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
+ const_cast<Tensor&>(output) = at::upsample_bilinear2d(input.to("cpu"), output_size, align_corners, scales_h, scales_w).clone().to("mps");
+ }
+}
+
+TORCH_IMPL_FUNC(upsample_bilinear2d_backward_out_mps) (
+ const Tensor& grad_output,
+ IntArrayRef output_size,
+ IntArrayRef input_size,
+ bool align_corners,
+ c10::optional<double> scales_h,
+ c10::optional<double> scales_w,
+ const Tensor& grad_input)
+{
+ if (check_mps_compatibility("bilinear", scales_w)) {
+ mps::upsample_out_template(grad_output, output_size, input_size, scales_h, scales_w, grad_input, align_corners, "bilinear");
+ } else {
+ // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
+ const_cast<Tensor&>(grad_input) = at::upsample_bilinear2d_backward(grad_output.to("cpu"), output_size, input_size, align_corners, scales_h, scales_w).clone().to("mps");
+ }
+}
+
+} // namespace native
+} // namespace at
diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml
index 1aca9c6..ce1bee6 100644
--- a/aten/src/ATen/native/native_functions.yaml
+++ b/aten/src/ATen/native/native_functions.yaml
@@ -11684,6 +11684,7 @@
dispatch:
CPU: _upsample_nearest_exact1d_out_cpu
CUDA: _upsample_nearest_exact1d_out_cuda
+ MPS: _upsample_nearest_exact1d_out_mps
- func: upsample_nearest1d(Tensor self, SymInt[1] output_size, float? scales=None) -> Tensor
python_module: nn
@@ -11699,6 +11700,7 @@
dispatch:
CPU: upsample_nearest1d_backward_out_cpu
CUDA: upsample_nearest1d_backward_out_cuda
+ MPS: upsample_nearest1d_backward_out_mps
- func: _upsample_nearest_exact1d_backward.grad_input(Tensor grad_output, SymInt[1] output_size, SymInt[3] input_size, float? scales=None, *, Tensor(a!) grad_input) -> Tensor(a!)
python_module: nn
@@ -11706,6 +11708,7 @@
dispatch:
CPU: _upsample_nearest_exact1d_backward_out_cpu
CUDA: _upsample_nearest_exact1d_backward_out_cuda
+ MPS: _upsample_nearest_exact1d_backward_out_mps
- func: upsample_nearest1d_backward(Tensor grad_output, SymInt[1] output_size, SymInt[3] input_size, float? scales=None) -> Tensor
python_module: nn
diff --git a/test/test_mps.py b/test/test_mps.py
index e953906..360ebf0 100644
--- a/test/test_mps.py
+++ b/test/test_mps.py
@@ -4033,26 +4033,6 @@
helper((1, 5))
helper((5, 9, 7, 4))
- def test_upsample_nearest_exact2d(self):
- def helper(N, C, H, W):
- inputCPU = torch.arange(N * C * H * W, device='cpu', dtype=torch.float,
- requires_grad=True).reshape(N, C, H, W)
- inputCPU.retain_grad()
- inputMPS = inputCPU.detach().clone().to('mps').requires_grad_()
-
- outputCPU = torch.nn.functional.interpolate(inputCPU, size=(5, 5), mode='nearest-exact')
- outputMPS = torch.nn.functional.interpolate(inputMPS, size=(5, 5), mode='nearest-exact')
-
- self.assertEqual(outputCPU, outputMPS)
-
- outputCPU.backward(gradient=torch.full_like(outputCPU, 0.3))
- outputMPS.backward(gradient=torch.full_like(outputMPS, 0.3))
-
- self.assertEqual(inputCPU.grad, inputMPS.grad)
-
- helper(1, 1, 4, 4)
- helper(7, 5, 3, 2)
-
def test_upsample_nearest2d(self):
def helper(N, C, H, W):
inputCPU = torch.arange(N * C * H * W, device='cpu', dtype=torch.float,
@@ -4118,19 +4098,49 @@
helper(1, 1, 4, 4)
helper(7, 5, 3, 2)
- def test_upsample_nearest1d(self):
- def helper(N, C, H, W):
- inputCPU = torch.arange(C * H * W, device='cpu', dtype=torch.float,
- requires_grad=True).reshape(C, H, W)
- inputMPS = inputCPU.detach().clone().to('mps')
+ def test_interpolate(self):
+ def helper(shape, output_size, scales, mode, align_corners=False):
+ inputCPU = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=True)
+ inputCPU.retain_grad()
+ inputMPS = inputCPU.detach().clone().to('mps').requires_grad_()
- outputCPU = torch.nn.functional.interpolate(inputCPU, scale_factor=2.0, mode='nearest')
- outputMPS = torch.nn.functional.interpolate(inputMPS, scale_factor=2.0, mode='nearest')
+ # align_corners is used for 2D interpolation only
+ if (align_corners is True and len(shape) > 3 and mode == 'bilinear'):
+ if scales is not None:
+ outputCPU = nn.functional.interpolate(inputCPU, scale_factor=scales, mode=mode, align_corners=align_corners)
+ outputMPS = nn.functional.interpolate(inputMPS, scale_factor=scales, mode=mode, align_corners=align_corners)
+ else:
+ outputCPU = nn.functional.interpolate(inputCPU, size=output_size, mode=mode, align_corners=align_corners)
+ outputMPS = nn.functional.interpolate(inputMPS, size=output_size, mode=mode, align_corners=align_corners)
+ elif scales is not None:
+ outputCPU = nn.functional.interpolate(inputCPU, scale_factor=scales, mode=mode)
+ outputMPS = nn.functional.interpolate(inputMPS, scale_factor=scales, mode=mode)
+ else:
+ outputCPU = nn.functional.interpolate(inputCPU, size=output_size, mode=mode)
+ outputMPS = nn.functional.interpolate(inputMPS, size=output_size, mode=mode)
self.assertEqual(outputCPU, outputMPS)
- helper(1, 1, 4, 4)
- helper(7, 5, 3, 2)
+ # backward pass (chose 0.6 just to have the grad_output != 1)
+ outputCPU.backward(gradient=torch.full_like(outputCPU, 0.6))
+ outputMPS.backward(gradient=torch.full_like(outputMPS, 0.6))
+ self.assertEqual(inputCPU.grad, inputMPS.grad)
+
+ # 1D interpolation
+ for mode in ['nearest', 'nearest-exact']:
+ helper([2, 3, 4], [3], None, mode) # downsample with size
+ helper([2, 3, 4], [6], None, mode) # upsample with size
+ helper([2, 3, 4], None, [0.6], mode) # downsample with scale factor
+ helper([2, 3, 4], None, [1.7], mode) # upsample with scale factor
+ # 2D interpolation
+ for mode in ['nearest', 'nearest-exact', 'bilinear']:
+            helper([2, 3, 4, 5], [3, 4], None, mode)      # downsample with size
+            helper([2, 3, 4, 5], [6, 7], None, mode)      # upsample with size
+            helper([2, 3, 4, 5], None, [0.6, 0.7], mode)  # downsample with scale factor
+            helper([2, 3, 4, 5], None, [1.4, 1.7], mode)  # upsample with scale factor
+ # align_corners=True
+ helper([2, 3, 4, 5], [3, 4], None, 'bilinear', True)
+ helper([2, 3, 4, 5], None, [1.4, 1.7], 'bilinear', True)
# Test concat forward
def test_cat1(self):
@@ -8234,6 +8244,7 @@
'nn.functional.triplet_margin_loss': ['f32', 'i16', 'i32', 'i64'],
'nn.functional.triplet_margin_with_distance_loss': ['f32', 'i16', 'i32', 'i64'],
'nn.functional.upsample_bilinear': ['f32'],
+ 'nn.functional.upsample_nearest': ['f32'],
'norm': ['f32', 'f16'],
'positive': ['f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
'pow': ['f16'],