add Half support for maxpool on CPU (#98819)
### Testing
Single socket (28 cores):
shape | fp32 forward / ms | fp16 forward / ms | bf16 forward / ms | fp32 backward / ms | fp16 backward / ms | bf16 backward / ms
-- | -- | -- | -- | -- | -- | --
size: (1, 56, 264, 264), kernel: 3, stride: 1, mem_format: contig | 4.12895 | 6.9669 | 5.30297 | 0.55775 | 1.98917 | 0.72233
size: (1, 56, 264, 264), kernel: 3, stride: 1, mem_format: CL | 0.85093 | 1.88813 | 1.38063 | 5.5742 | 36.5086 | 10.58552
size: (32, 16, 200, 200), kernel: 3, stride: 1, mem_format: contig | 22.37212 | 37.90383 | 30.94482 | 6.85868 | 10.6116 | 3.9993
size: (32, 16, 200, 200), kernel: 3, stride: 1, mem_format: CL | 5.41658 | 4.71098 | 4.66578 | 6.69875 | 14.7171 | 5.1167
size: (32, 32, 100, 100), kernel: 3, stride: 1, mem_format: contig | 10.69831 | 18.0468 | 13.71657 | 2.61192 | 4.96172 | 1.68635
size: (32, 32, 100, 100), kernel: 3, stride: 1, mem_format: CL | 2.52637 | 2.0096 | 2.0055 | 2.60314 | 7.2093 | 2.49843
size: (4, 19, 10, 16, 16), kernel: 3, stride: 1, mem_format: contig | 0.47605 | 0.88398 | 0.65326 | 0.06525 | 0.115489 | 0.0674
size: (4, 19, 10, 16, 16), kernel: 3, stride: 1, mem_format: CL3d | 0.10902 | 0.25293 | 0.157475 | 0.11386 | 0.53319 | 0.17836
Single core:
shape | fp32 forward / ms | fp16 forward / ms | bf16 forward / ms | fp32 backward / ms | fp16 backward / ms | bf16 backward / ms
-- | -- | -- | -- | -- | -- | --
size: (1, 56, 264, 264), kernel: 3, stride: 1, mem_format: contig | 90.9809 | 163.473 | 126.1276 | 6.57721 | 41.40833 | 11.82505
size: (1, 56, 264, 264), kernel: 3, stride: 1, mem_format: CL | 9.88405 | 38.39137 | 29.62069 | 7.10636 | 36.97535 | 11.0525
size: (32, 16, 200, 200), kernel: 3, stride: 1, mem_format: contig | 476.782 | 855.4769 | 648.2248 | 46.6488 | 219.2586 | 67.10599
size: (32, 16, 200, 200), kernel: 3, stride: 1, mem_format: CL | 80.29271 | 91.33854 | 87.80345 | 48.81692 | 203.9974 | 63.39004
size: (32, 32, 100, 100), kernel: 3, stride: 1, mem_format: contig | 235.2113 | 419.0799 | 315.4284 | 20.6049 | 107.1524 | 32.39169
size: (32, 32, 100, 100), kernel: 3, stride: 1, mem_format: CL | 29.47653 | 33.54905 | 32.82823 | 22.59674 | 98.5586 | 30.05763
size: (4, 19, 10, 16, 16), kernel: 3, stride: 1, mem_format: contig | 7.90684 | 13.9208 | 10.03272 | 0.23725 | 1.35269 | 0.41728
size: (4, 19, 10, 16, 16), kernel: 3, stride: 1, mem_format: CL3d | 2.33638 | 3.36894 | 2.64635 | 0.26535 | 1.244 | 0.38895
Pull Request resolved: https://github.com/pytorch/pytorch/pull/98819
Approved by: https://github.com/mingfeima, https://github.com/mikaylagawarecki
diff --git a/aten/src/ATen/native/cpu/MaxPoolKernel.cpp b/aten/src/ATen/native/cpu/MaxPoolKernel.cpp
index d53678c..9380aa9 100644
--- a/aten/src/ATen/native/cpu/MaxPoolKernel.cpp
+++ b/aten/src/ATen/native/cpu/MaxPoolKernel.cpp
@@ -5,6 +5,7 @@
#include <ATen/Dispatch.h>
#include <ATen/Parallel.h>
#include <ATen/cpu/vec/vec.h>
+#include <ATen/cpu/vec/functional.h>
#include <ATen/native/Pool.h>
#include <ATen/native/cpu/utils.h>
#include <c10/util/irange.h>
@@ -60,13 +61,15 @@
return ret;
}
-template <typename scalar_t, typename accscalar_t>
-inline void compute_internal(
+template <typename scalar_t, typename opmath_t>
+inline
+typename std::enable_if<std::is_same<scalar_t, opmath_t>::value, void>::type
+compute_internal(
scalar_t* input_data,
scalar_t* out_data,
- accscalar_t* max_ptr,
- vec::int_same_size_t<accscalar_t>* index_ptr,
- int64_t* ind,
+ opmath_t* max_ptr,
+ vec::int_same_size_t<opmath_t>* index_ptr,
+ int64_t* ind,
int64_t input_depth, int64_t input_height, int64_t input_width, int64_t channels,
int64_t n,
int64_t len,
@@ -78,7 +81,7 @@
int64_t dilationH,
int64_t dilationW) {
using Vec = vec::Vectorized<scalar_t>;
- using integer_t = vec::int_same_size_t<accscalar_t>;
+ using integer_t = vec::int_same_size_t<opmath_t>;
using iVec = vec::Vectorized<integer_t>;
// Pass I: init out lane
iVec index0_vec = iVec(id0 * input_height * input_width + ih0 * input_width + iw0);
@@ -130,13 +133,16 @@
}
}
-template <>
-inline void compute_internal(
- BFloat16* input_data,
- BFloat16* out_data,
- float* max_ptr,
- int32_t* index_ptr,
- int64_t* ind,
+// std::is_same<scalar_t, at::BFloat16> || std::is_same<scalar_t, at::Half>
+template <typename scalar_t, typename opmath_t>
+inline
+typename std::enable_if<!std::is_same<scalar_t, opmath_t>::value, void>::type
+compute_internal(
+ scalar_t* input_data,
+ scalar_t* out_data,
+ opmath_t* max_ptr,
+ vec::int_same_size_t<opmath_t>* index_ptr,
+ int64_t* ind,
int64_t input_depth, int64_t input_height, int64_t input_width, int64_t channels,
int64_t n,
int64_t len,
@@ -147,12 +153,12 @@
int64_t dilationD,
int64_t dilationH,
int64_t dilationW) {
- using bVec = vec::Vectorized<BFloat16>;
- using fVec = vec::Vectorized<float>;
+ using Vec = vec::Vectorized<scalar_t>;
+ using fVec = vec::Vectorized<opmath_t>;
using iVec = vec::Vectorized<int32_t>;
// Pass I: init out lane
iVec index0_vec = iVec(id0 * input_height * input_width + ih0 * input_width + iw0);
- fVec out_vec = fVec(-std::numeric_limits<float>::infinity());
+ fVec out_vec = fVec(-std::numeric_limits<opmath_t>::infinity());
int64_t d1 = 0;
for (; d1 < len; d1 += fVec::size()) {
index0_vec.store(index_ptr + d1);
@@ -160,21 +166,21 @@
}
for (; d1 < size; d1++) {
ind[d1] = ih0 * input_width + iw0;
- max_ptr[d1] = -std::numeric_limits<float>::infinity();
+ max_ptr[d1] = -std::numeric_limits<opmath_t>::infinity();
}
// Pass II: compute local max
for (int64_t id = id0; id < id1; id += dilationD) {
for (int64_t ih = ih0; ih < ih1; ih += dilationH) {
for (int64_t iw = iw0; iw < iw1; iw += dilationW) {
- BFloat16* in = input_data + (n * input_depth * input_height * input_width +
+ scalar_t* in = input_data + (n * input_depth * input_height * input_width +
id * input_height * input_width + ih * input_width + iw) * channels;
int64_t d2 = 0;
- for (; d2 < len; d2 += bVec::size()) {
+ for (; d2 < len; d2 += Vec::size()) {
iVec index_ivec = iVec(id * input_height * input_width + ih * input_width + iw);
- bVec val_bvec = bVec::loadu(in + d2);
+ Vec val_bvec = Vec::loadu(in + d2);
fVec val_fvec0, val_fvec1;
- std::tie(val_fvec0, val_fvec1) = convert_bfloat16_float(val_bvec);
+ std::tie(val_fvec0, val_fvec1) = convert_to_float<scalar_t>(val_bvec);
iVec maxindex_ivec0 = iVec::loadu(index_ptr + d2);
iVec maxindex_ivec1 = iVec::loadu(index_ptr + d2 + iVec::size());
@@ -200,9 +206,9 @@
}
for (; d2 < size; d2++) {
int64_t index = id * input_height * input_width + ih * input_width + iw;
- float val = float(in[d2]);
+ opmath_t val = opmath_t(in[d2]);
int64_t maxindex = ind[d2];
- float maxval = max_ptr[d2];
+ opmath_t maxval = max_ptr[d2];
bool mask = (val > maxval) || std::isnan(val);
max_ptr[d2] = mask ? val : maxval;
@@ -211,16 +217,16 @@
}
}
}
- // Convert max values from float to bfloat16
+ // Convert max values from float to bfloat16/half
int64_t d3 = 0;
- for (; d3 < len; d3 += bVec::size()) {
+ for (; d3 < len; d3 += Vec::size()) {
fVec max_fvec0 = fVec::loadu(max_ptr + d3);
fVec max_fvec1 = fVec::loadu(max_ptr + d3 + fVec::size());
- bVec max_bvec = convert_float_bfloat16(max_fvec0, max_fvec1);
+ Vec max_bvec = convert_from_float<scalar_t>(max_fvec0, max_fvec1);
max_bvec.store(out_data + d3);
}
for (; d3 < size; d3++) {
- out_data[d3] = BFloat16(max_ptr[d3]);
+ out_data[d3] = scalar_t(max_ptr[d3]);
}
}
@@ -281,7 +287,7 @@
int64_t output_height = output.size(-2);
int64_t output_width = output.size(-1);
- using accscalar_t = at::opmath_type<scalar_t>;
+ using opmath_t = at::opmath_type<scalar_t>;
// parallel on dim N, C
at::parallel_for(0, channels, 0, [&](int64_t begin, int64_t end) {
for (int64_t c = begin; c < end; c++) {
@@ -306,17 +312,18 @@
// compute local max
int64_t maxindex = id0 * input_height * input_width + ih0 * input_width + iw0;
- accscalar_t maxval;
- if (std::numeric_limits<accscalar_t>::has_infinity) {
- maxval = -std::numeric_limits<accscalar_t>::infinity();
+ opmath_t maxval;
+ if (std::numeric_limits<opmath_t>::has_infinity) {
+ maxval = -std::numeric_limits<opmath_t>::infinity();
} else {
- maxval = std::numeric_limits<accscalar_t>::min();
+ maxval = std::numeric_limits<opmath_t>::min();
}
+
for (int64_t id = id0; id < id1; id += dilationD) {
for (int64_t ih = ih0; ih < ih1; ih += dilationH) {
for (int64_t iw = iw0; iw < iw1; iw += dilationW) {
int64_t index = id * input_height * input_width + ih * input_width + iw;
- accscalar_t val = input_ptr[index];
+ opmath_t val = input_ptr[index];
if ((val > maxval) || is_nan(static_cast<double>(val))) {
maxval = val;
maxindex = index;
@@ -396,9 +403,9 @@
int64_t output_height = output.size(-2);
int64_t output_width = output.size(-1);
- using accscalar_t = at::opmath_type<scalar_t>;
+ using opmath_t = at::opmath_type<scalar_t>;
using Vec = vec::Vectorized<scalar_t>;
- using integer_t = vec::int_same_size_t<accscalar_t>;
+ using integer_t = vec::int_same_size_t<opmath_t>;
// for the convience of vectorization, use integer of the same size of scalar_t,
// e.g. int32_t for float, int64_t for double
// need to make sure doesn't overflow
@@ -418,11 +425,11 @@
// NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays)
std::unique_ptr<integer_t []> index_buffer(new integer_t[len]);
integer_t * index_ptr = index_buffer.get();
- // temp buffer holding max value with accscalar_t
- std::unique_ptr<accscalar_t []> max_arr;
- accscalar_t* max_ptr = nullptr;
- if (!std::is_same<scalar_t, accscalar_t>::value) {
- max_arr = std::make_unique<accscalar_t[]>(size);
+ // temp buffer holding max value with opmath_t
+ std::unique_ptr<opmath_t []> max_arr;
+ opmath_t* max_ptr = nullptr;
+ if (!std::is_same<scalar_t, opmath_t>::value) {
+ max_arr = std::make_unique<opmath_t[]>(size);
max_ptr = max_arr.get();
}
@@ -598,13 +605,13 @@
int dilationW, int dilationH) {
switch (input.suggest_memory_format()) {
case at::MemoryFormat::Contiguous: {
- AT_DISPATCH_ALL_TYPES_AND(ScalarType::BFloat16, input.scalar_type(), "max_pool2d", [&] {
+ AT_DISPATCH_ALL_TYPES_AND2(ScalarType::BFloat16, ScalarType::Half, input.scalar_type(), "max_pool2d", [&] {
cpu_max_pool<scalar_t, /*is 3d*/false>(output, indices, input, {kW, kH}, {dW, dH}, {padW, padH}, {dilationW, dilationH});
});
break;
}
case at::MemoryFormat::ChannelsLast: {
- AT_DISPATCH_ALL_TYPES_AND(ScalarType::BFloat16, input.scalar_type(), "max_pool2d_channels_last", [&] {
+ AT_DISPATCH_ALL_TYPES_AND2(ScalarType::BFloat16, ScalarType::Half, input.scalar_type(), "max_pool2d_channels_last", [&] {
cpu_max_pool_channels_last<scalar_t, false>(output, indices, input, {kW, kH}, {dW, dH}, {padW, padH}, {dilationW, dilationH});
});
break;
@@ -637,7 +644,7 @@
DimVector indices_sizes(indices.sizes().begin(), indices.sizes().end());
indices_sizes.insert(indices_sizes.begin(), 1);
indices.resize_(indices_sizes, at::MemoryFormat::ChannelsLast3d);
- AT_DISPATCH_ALL_TYPES_AND(ScalarType::BFloat16, input.scalar_type(), "max_pool3d_channels_last", [&] {
+ AT_DISPATCH_ALL_TYPES_AND2(ScalarType::BFloat16, ScalarType::Half, input.scalar_type(), "max_pool3d_channels_last", [&] {
cpu_max_pool_channels_last<scalar_t, /*is 3d*/true>(output, indices, input_cl_check,
{kW, kH, kD}, {dW, dH, dD}, {padW, padH, padD}, {dilationW, dilationH, dilationD});
});
@@ -648,14 +655,14 @@
}
switch (input.suggest_memory_format()) {
case at::MemoryFormat::Contiguous: {
- AT_DISPATCH_ALL_TYPES_AND(ScalarType::BFloat16, input.scalar_type(), "max_pool3d", [&] {
+ AT_DISPATCH_ALL_TYPES_AND2(ScalarType::BFloat16, ScalarType::Half, input.scalar_type(), "max_pool3d", [&] {
cpu_max_pool<scalar_t, /*is 3d*/true>(output, indices, input,
{kW, kH, kD}, {dW, dH, dD}, {padW, padH, padD}, {dilationW, dilationH, dilationD});
});
break;
}
case at::MemoryFormat::ChannelsLast3d: {
- AT_DISPATCH_ALL_TYPES_AND(ScalarType::BFloat16, input.scalar_type(), "max_pool3d_channels_last", [&] {
+ AT_DISPATCH_ALL_TYPES_AND2(ScalarType::BFloat16, ScalarType::Half, input.scalar_type(), "max_pool3d_channels_last", [&] {
cpu_max_pool_channels_last<scalar_t, true>(output, indices, input,
{kW, kH, kD}, {dW, dH, dD}, {padW, padH, padD}, {dilationW, dilationH, dilationD});
});
@@ -672,13 +679,13 @@
const Tensor& indices) {
switch (grad_output.suggest_memory_format()) {
case at::MemoryFormat::Contiguous: {
- AT_DISPATCH_FLOATING_TYPES_AND(ScalarType::BFloat16, grad_output.scalar_type(), "max_pool2d_backward", [&] {
+ AT_DISPATCH_FLOATING_TYPES_AND2(ScalarType::BFloat16, ScalarType::Half, grad_output.scalar_type(), "max_pool2d_backward", [&] {
cpu_max_pool_backward<scalar_t, /*is 3d*/ false>(grad_input, grad_output, indices);
});
break;
}
case at::MemoryFormat::ChannelsLast: {
- AT_DISPATCH_FLOATING_TYPES_AND(ScalarType::BFloat16, grad_output.scalar_type(), "max_pool2d_backward_channels_last", [&] {
+ AT_DISPATCH_FLOATING_TYPES_AND2(ScalarType::BFloat16, ScalarType::Half, grad_output.scalar_type(), "max_pool2d_backward_channels_last", [&] {
cpu_max_pool_backward_channels_last<scalar_t, /*is 3d*/ false>(grad_input, grad_output, indices);
});
break;
@@ -705,7 +712,7 @@
sizes.insert(sizes.begin(), 1);
grad_input.resize_(sizes, at::MemoryFormat::ChannelsLast3d);
auto _indices = indices.unsqueeze(0).contiguous(at::MemoryFormat::ChannelsLast3d);
- AT_DISPATCH_FLOATING_TYPES_AND(ScalarType::BFloat16, grad_output.scalar_type(), "max_pool3d_backward_channels_last", [&] {
+ AT_DISPATCH_FLOATING_TYPES_AND2(ScalarType::BFloat16, ScalarType::Half, grad_output.scalar_type(), "max_pool3d_backward_channels_last", [&] {
cpu_max_pool_backward_channels_last<scalar_t, /*is_3d*/ true>(grad_input, grad_output_cl_check, _indices);
});
grad_input.squeeze_(0);
@@ -714,13 +721,13 @@
}
switch (grad_output.suggest_memory_format()) {
case at::MemoryFormat::Contiguous: {
- AT_DISPATCH_FLOATING_TYPES_AND(ScalarType::BFloat16, grad_output.scalar_type(), "max_pool3d_backward", [&] {
+ AT_DISPATCH_FLOATING_TYPES_AND2(ScalarType::BFloat16, ScalarType::Half, grad_output.scalar_type(), "max_pool3d_backward", [&] {
cpu_max_pool_backward<scalar_t, /*is_3d*/ true>(grad_input, grad_output, indices);
});
break;
}
case at::MemoryFormat::ChannelsLast3d: {
- AT_DISPATCH_FLOATING_TYPES_AND(ScalarType::BFloat16, grad_output.scalar_type(), "max_pool3d_backward_channels_last", [&] {
+ AT_DISPATCH_FLOATING_TYPES_AND2(ScalarType::BFloat16, ScalarType::Half, grad_output.scalar_type(), "max_pool3d_backward_channels_last", [&] {
cpu_max_pool_backward_channels_last<scalar_t, /*is_3d*/ true>(grad_input, grad_output, indices);
});
break;
diff --git a/aten/src/ATen/native/cpu/MaxPooling.cpp b/aten/src/ATen/native/cpu/MaxPooling.cpp
index 1f30294..70443e6 100644
--- a/aten/src/ATen/native/cpu/MaxPooling.cpp
+++ b/aten/src/ATen/native/cpu/MaxPooling.cpp
@@ -1,7 +1,7 @@
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
-#include <ATen/core/Tensor.h>
#include <ATen/Dispatch.h>
#include <ATen/Parallel.h>
+#include <ATen/core/Tensor.h>
#include <ATen/cpu/vec/vec.h>
#include <ATen/native/MaxPooling.h>
#include <c10/util/irange.h>
@@ -31,25 +31,30 @@
Tensor& output,
const Tensor& input,
const PoolingParams1D& p) {
- AT_DISPATCH_FLOATING_TYPES_AND(ScalarType::BFloat16, input.scalar_type(), "max_pool1d_impl", [&] {
- const Tensor in = input.contiguous();
- scalar_t* const OP = output.data_ptr<scalar_t>();
- const scalar_t* const IP = in.data_ptr<scalar_t>();
+ AT_DISPATCH_FLOATING_TYPES_AND2(
+ ScalarType::BFloat16,
+ ScalarType::Half,
+ input.scalar_type(),
+ "max_pool1d_impl",
+ [&] {
+ const Tensor in = input.contiguous();
+ scalar_t* const OP = output.data_ptr<scalar_t>();
+ const scalar_t* const IP = in.data_ptr<scalar_t>();
- // Value used for padding
- scalar_t FILL = std::numeric_limits<scalar_t>::has_infinity
- ? -std::numeric_limits<scalar_t>::infinity()
- : std::numeric_limits<scalar_t>::lowest();
+ // Value used for padding
+ scalar_t FILL = std::numeric_limits<scalar_t>::has_infinity
+ ? -std::numeric_limits<scalar_t>::infinity()
+ : std::numeric_limits<scalar_t>::lowest();
- at::parallel_for(0, p.NB * p.NC, 0, [&](int64_t begin, int64_t end) {
- for (const auto it : c10::irange(begin, end)) {
- scalar_t* op = OP + it * p.OW;
- const scalar_t* ip = IP + it * p.IW;
- std::fill_n(op, p.OW, FILL);
- max_pool1d_kernel(op, ip, p);
- }
- });
- });
+ at::parallel_for(0, p.NB * p.NC, 0, [&](int64_t begin, int64_t end) {
+ for (const auto it : c10::irange(begin, end)) {
+ scalar_t* op = OP + it * p.OW;
+ const scalar_t* ip = IP + it * p.IW;
+ std::fill_n(op, p.OW, FILL);
+ max_pool1d_kernel(op, ip, p);
+ }
+ });
+ });
}
} // namespace
diff --git a/test/nn/test_pooling.py b/test/nn/test_pooling.py
index af61358..f92ff745 100644
--- a/test/nn/test_pooling.py
+++ b/test/nn/test_pooling.py
@@ -871,7 +871,7 @@
helper(1, 100000, 1, 4, ks=(1, 4)) # test for max_pool1d
@onlyNativeDeviceTypes
- @dtypes(torch.bfloat16, torch.float, torch.double)
+ @dtypes(torch.half, torch.bfloat16, torch.float, torch.double)
@dtypesIfCUDA(torch.half, torch.float, torch.double)
@gcIfJetson
def test_max_pool2d_nhwc(self, device, dtype):
@@ -908,7 +908,7 @@
helper(1, 129, 8, 8, 3, stride=2)
@onlyNativeDeviceTypes
- @dtypes(torch.bfloat16, torch.float, torch.double)
+ @dtypes(torch.half, torch.bfloat16, torch.float, torch.double)
@dtypesIfCUDA(torch.half, torch.float, torch.double)
@gcIfJetson
def test_max_pool3d_ndhwc(self, device, dtype):
@@ -974,9 +974,10 @@
helper(0, 79, 4, 4, 4, 3, stride=2)
@onlyCPU
- def test_max_pool_bfloat16(self, device):
- def helper(shape, kernel_size, stride, memory_format):
- input = torch.randn(shape, dtype=torch.float32, device=device).bfloat16()
+ @dtypes(torch.half, torch.bfloat16)
+ def test_max_pool_bfloat16_half(self, device, dtype):
+ def helper(shape, kernel_size, stride, memory_format, dtype):
+ input = torch.randn(shape, dtype=dtype, device=device)
input = input.to(memory_format=memory_format).requires_grad_()
if len(shape) == 4:
pool = torch.nn.MaxPool2d(kernel_size, stride, return_indices=True).to(device)
@@ -991,22 +992,22 @@
out2.sum().backward()
self.assertTrue(out.is_contiguous(memory_format=memory_format))
- self.assertEqual(out.dtype, torch.bfloat16)
- self.assertEqual(input.grad.dtype, torch.bfloat16)
- self.assertEqual(out, out2.bfloat16())
+ self.assertEqual(out.dtype, dtype)
+ self.assertEqual(input.grad.dtype, dtype)
+ self.assertEqual(out, out2.to(dtype=dtype))
self.assertEqual(ind, ind2)
- self.assertEqual(input.grad, input2.grad.bfloat16())
+ self.assertEqual(input.grad, input2.grad.to(dtype=dtype))
- helper((4, 30, 8, 8), 7, 1, torch.contiguous_format)
- helper((4, 65, 8, 8), 7, 1, torch.channels_last)
- helper((1, 19, 20, 10), 8, 2, torch.contiguous_format)
- helper((1, 19, 20, 10), 8, 2, torch.channels_last)
- helper((4, 30, 8, 8), 7, 1, torch.contiguous_format)
- helper((4, 65, 8, 8), 7, 1, torch.channels_last)
- helper((1, 19, 10, 10, 10), 8, 2, torch.contiguous_format)
- helper((1, 19, 10, 9, 14), 8, 2, torch.channels_last_3d)
- helper((4, 10, 3, 8, 8), 3, 1, torch.contiguous_format)
- helper((4, 10, 8, 8, 8), 7, 1, torch.channels_last_3d)
+ helper((4, 30, 8, 8), 7, 1, torch.contiguous_format, dtype)
+ helper((4, 65, 8, 8), 7, 1, torch.channels_last, dtype)
+ helper((1, 19, 20, 10), 8, 2, torch.contiguous_format, dtype)
+ helper((1, 19, 20, 10), 8, 2, torch.channels_last, dtype)
+ helper((4, 30, 8, 8), 7, 1, torch.contiguous_format, dtype)
+ helper((4, 65, 8, 8), 7, 1, torch.channels_last, dtype)
+ helper((1, 19, 10, 10, 10), 8, 2, torch.contiguous_format, dtype)
+ helper((1, 19, 10, 9, 14), 8, 2, torch.channels_last_3d, dtype)
+ helper((4, 10, 3, 8, 8), 3, 1, torch.contiguous_format, dtype)
+ helper((4, 10, 8, 8, 8), 7, 1, torch.channels_last_3d, dtype)
@onlyCUDA
@gcIfJetson
diff --git a/test/test_mps.py b/test/test_mps.py
index 7a06efc..1159e8d 100644
--- a/test/test_mps.py
+++ b/test/test_mps.py
@@ -10599,6 +10599,9 @@
'nn.functional.triplet_margin_loss',
'nn.functional.triplet_margin_with_distance_loss',
'round', 'xlogy', 'addcmul',
+ 'nn.functional.max_pool2d',
+ 'nn.functional.gelu',
+ 'nn.functional.glu',
# for macOS 12
'masked.normalize', 'masked.sum', 'masked.var',
@@ -10756,10 +10759,6 @@
cpu_grad_inputs = torch.autograd.grad(diff_cpu_out, diff_cpu_arg, grad_outputs=cpu_grad_outputs, allow_unused=True)
mps_grad_inputs = torch.autograd.grad(diff_mps_out, diff_mps_arg, grad_outputs=mps_grad_outputs, allow_unused=True)
- if op.name in ["nn.functional.gelu", "nn.functional.glu"] and dtype == torch.float16:
- atol = 1e-3
- rtol = 1e-3
-
self.assertEqual(cpu_grad_inputs, mps_grad_inputs, atol=atol, rtol=rtol)
diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py
index 7ce26ae..7bb3612 100644
--- a/torch/testing/_internal/common_methods_invocations.py
+++ b/torch/testing/_internal/common_methods_invocations.py
@@ -13114,7 +13114,7 @@
check_batched_forward_grad=False,
# TODO: add shape checks
assert_jit_shape_analysis=False,
- dtypes=floating_types_and(torch.bfloat16),
+ dtypes=floating_types_and(torch.bfloat16, torch.float16),
dtypesIfCUDA=floating_types_and(torch.float16, torch.bfloat16),
skips=(
# Pre-existing condition; Needs to be fixed
@@ -13139,7 +13139,7 @@
# got: Batching rule not implemented for aten::flatten.using_ints
check_batched_forward_grad=False,
assert_jit_shape_analysis=True,
- dtypes=all_types_and(torch.bfloat16),
+ dtypes=all_types_and(torch.float16, torch.bfloat16),
dtypesIfCUDA=floating_types_and(torch.float16, torch.bfloat16),
error_inputs_func=error_inputs_max_pool2d,
sample_inputs_func=sample_inputs_max_pool),
@@ -13156,8 +13156,7 @@
supports_fwgrad_bwgrad=True,
check_batched_forward_grad=False,
assert_jit_shape_analysis=False,
- dtypes=floating_types_and(torch.bfloat16),
- dtypesIfCUDA=floating_types_and(torch.float16, torch.bfloat16),
+ dtypes=floating_types_and(torch.bfloat16, torch.float16),
sample_inputs_func=sample_inputs_max_pool,
skips=(
# We've defined a custom op here, and we don't handle the case where we receive an out kwarg
@@ -13179,7 +13178,7 @@
check_batched_forward_grad=False,
# TODO: add shape checks
assert_jit_shape_analysis=False,
- dtypes=all_types_and(torch.bfloat16),
+ dtypes=all_types_and(torch.bfloat16, torch.float16),
dtypesIfCUDA=floating_types_and(torch.float16, torch.bfloat16),
# TODO: investigate nondeterminism
gradcheck_nondet_tol=GRADCHECK_NONDET_TOL,