add Half support for maxpool on CPU (#98819)

### Testing
Single socket (28 cores):

shape | fp32 forward / ms | fp16 forward / ms | bf16 forward / ms | fp32 backward / ms | fp16 backward / ms | bf16 backward / ms
-- | -- | -- | -- | -- | -- | --
size: (1, 56, 264, 264), kernel: 3,   stride: 1, mem_format: contig | 4.12895 | 6.9669 | 5.30297 | 0.55775 | 1.98917 | 0.72233
size: (1, 56, 264, 264), kernel: 3,   stride: 1, mem_format: CL | 0.85093 | 1.88813 | 1.38063 | 5.5742 | 36.5086 | 10.58552
size: (32, 16, 200, 200), kernel: 3,   stride: 1, mem_format: contig | 22.37212 | 37.90383 | 30.94482 | 6.85868 | 10.6116 | 3.9993
size: (32, 16, 200, 200), kernel: 3,   stride: 1, mem_format: CL | 5.41658 | 4.71098 | 4.66578 | 6.69875 | 14.7171 | 5.1167
size: (32, 32, 100, 100), kernel: 3,   stride: 1, mem_format: contig | 10.69831 | 18.0468 | 13.71657 | 2.61192 | 4.96172 | 1.68635
size: (32, 32, 100, 100), kernel: 3,   stride: 1, mem_format: CL | 2.52637 | 2.0096 | 2.0055 | 2.60314 | 7.2093 | 2.49843
size: (4, 19, 10, 16, 16), kernel: 3,   stride: 1, mem_format: contig | 0.47605 | 0.88398 | 0.65326 | 0.06525 | 0.115489 | 0.0674
size: (4, 19, 10, 16, 16), kernel: 3,   stride: 1, mem_format: CL3d | 0.10902 | 0.25293 | 0.157475 | 0.11386 | 0.53319 | 0.17836

Single core:

shape | fp32 forward / ms | fp16 forward / ms | bf16 forward / ms | fp32 backward / ms | fp16 backward / ms | bf16 backward / ms
-- | -- | -- | -- | -- | -- | --
size: (1, 56, 264, 264), kernel: 3,   stride: 1, mem_format: contig | 90.9809 | 163.473 | 126.1276 | 6.57721 | 41.40833 | 11.82505
size: (1, 56, 264, 264), kernel: 3,   stride: 1, mem_format: CL | 9.88405 | 38.39137 | 29.62069 | 7.10636 | 36.97535 | 11.0525
size: (32, 16, 200, 200), kernel: 3,   stride: 1, mem_format: contig | 476.782 | 855.4769 | 648.2248 | 46.6488 | 219.2586 | 67.10599
size: (32, 16, 200, 200), kernel: 3,   stride: 1, mem_format: CL | 80.29271 | 91.33854 | 87.80345 | 48.81692 | 203.9974 | 63.39004
size: (32, 32, 100, 100), kernel: 3,   stride: 1, mem_format: contig | 235.2113 | 419.0799 | 315.4284 | 20.6049 | 107.1524 | 32.39169
size: (32, 32, 100, 100), kernel: 3,   stride: 1, mem_format: CL | 29.47653 | 33.54905 | 32.82823 | 22.59674 | 98.5586 | 30.05763
size: (4, 19, 10, 16, 16), kernel: 3,   stride: 1, mem_format: contig | 7.90684 | 13.9208 | 10.03272 | 0.23725 | 1.35269 | 0.41728
size: (4, 19, 10, 16, 16), kernel: 3,   stride: 1, mem_format: CL3d | 2.33638 | 3.36894 | 2.64635 | 0.26535 | 1.244 | 0.38895

Pull Request resolved: https://github.com/pytorch/pytorch/pull/98819
Approved by: https://github.com/mingfeima, https://github.com/mikaylagawarecki
diff --git a/aten/src/ATen/native/cpu/MaxPoolKernel.cpp b/aten/src/ATen/native/cpu/MaxPoolKernel.cpp
index d53678c..9380aa9 100644
--- a/aten/src/ATen/native/cpu/MaxPoolKernel.cpp
+++ b/aten/src/ATen/native/cpu/MaxPoolKernel.cpp
@@ -5,6 +5,7 @@
 #include <ATen/Dispatch.h>
 #include <ATen/Parallel.h>
 #include <ATen/cpu/vec/vec.h>
+#include <ATen/cpu/vec/functional.h>
 #include <ATen/native/Pool.h>
 #include <ATen/native/cpu/utils.h>
 #include <c10/util/irange.h>
@@ -60,13 +61,15 @@
   return ret;
 }
 
-template <typename scalar_t, typename accscalar_t>
-inline void compute_internal(
+template <typename scalar_t, typename opmath_t>
+inline
+typename std::enable_if<std::is_same<scalar_t, opmath_t>::value, void>::type
+compute_internal(
   scalar_t* input_data,
   scalar_t* out_data,
-  accscalar_t* max_ptr,
-  vec::int_same_size_t<accscalar_t>* index_ptr,
-  int64_t*  ind,
+  opmath_t* max_ptr,
+  vec::int_same_size_t<opmath_t>* index_ptr,
+  int64_t* ind,
   int64_t input_depth, int64_t input_height, int64_t input_width, int64_t channels,
   int64_t n,
   int64_t len,
@@ -78,7 +81,7 @@
   int64_t dilationH,
   int64_t dilationW) {
   using Vec = vec::Vectorized<scalar_t>;
-  using integer_t = vec::int_same_size_t<accscalar_t>;
+  using integer_t = vec::int_same_size_t<opmath_t>;
   using iVec = vec::Vectorized<integer_t>;
   // Pass I: init out lane
   iVec index0_vec = iVec(id0 * input_height * input_width + ih0 * input_width + iw0);
@@ -130,13 +133,16 @@
   }
 }
 
-template <>
-inline void compute_internal(
-  BFloat16* input_data,
-  BFloat16* out_data,
-  float* max_ptr,
-  int32_t* index_ptr,
-  int64_t*  ind,
+// std::is_same<scalar_t, at::BFloat16> || std::is_same<scalar_t, at::Half>
+template <typename scalar_t, typename opmath_t>
+inline
+typename std::enable_if<!std::is_same<scalar_t, opmath_t>::value, void>::type
+compute_internal(
+  scalar_t* input_data,
+  scalar_t* out_data,
+  opmath_t* max_ptr,
+  vec::int_same_size_t<opmath_t>* index_ptr,
+  int64_t* ind,
   int64_t input_depth, int64_t input_height, int64_t input_width, int64_t channels,
   int64_t n,
   int64_t len,
@@ -147,12 +153,12 @@
   int64_t dilationD,
   int64_t dilationH,
   int64_t dilationW) {
-  using bVec = vec::Vectorized<BFloat16>;
-  using fVec = vec::Vectorized<float>;
+  using Vec = vec::Vectorized<scalar_t>;
+  using fVec = vec::Vectorized<opmath_t>;
   using iVec = vec::Vectorized<int32_t>;
   // Pass I: init out lane
   iVec index0_vec = iVec(id0 * input_height * input_width + ih0 * input_width + iw0);
-  fVec out_vec = fVec(-std::numeric_limits<float>::infinity());
+  fVec out_vec = fVec(-std::numeric_limits<opmath_t>::infinity());
   int64_t d1 = 0;
   for (; d1 < len; d1 += fVec::size()) {
     index0_vec.store(index_ptr + d1);
@@ -160,21 +166,21 @@
   }
   for (; d1 < size; d1++) {
     ind[d1] = ih0 * input_width + iw0;
-    max_ptr[d1] = -std::numeric_limits<float>::infinity();
+    max_ptr[d1] = -std::numeric_limits<opmath_t>::infinity();
   }
   // Pass II: compute local max
   for (int64_t id = id0; id < id1; id += dilationD) {
     for (int64_t ih = ih0; ih < ih1; ih += dilationH) {
       for (int64_t iw = iw0; iw < iw1; iw += dilationW) {
-        BFloat16* in = input_data + (n * input_depth * input_height * input_width +
+        scalar_t* in = input_data + (n * input_depth * input_height * input_width +
             id * input_height * input_width + ih * input_width + iw) * channels;
 
         int64_t d2 = 0;
-        for (; d2 < len; d2 += bVec::size()) {
+        for (; d2 < len; d2 += Vec::size()) {
           iVec index_ivec = iVec(id * input_height * input_width + ih * input_width + iw);
-          bVec val_bvec = bVec::loadu(in + d2);
+          Vec val_bvec = Vec::loadu(in + d2);
           fVec val_fvec0, val_fvec1;
-          std::tie(val_fvec0, val_fvec1) = convert_bfloat16_float(val_bvec);
+          std::tie(val_fvec0, val_fvec1) = convert_to_float<scalar_t>(val_bvec);
 
           iVec maxindex_ivec0 = iVec::loadu(index_ptr + d2);
           iVec maxindex_ivec1 = iVec::loadu(index_ptr + d2 + iVec::size());
@@ -200,9 +206,9 @@
         }
         for (; d2 < size; d2++) {
           int64_t index = id * input_height * input_width + ih * input_width + iw;
-          float val = float(in[d2]);
+          opmath_t val = opmath_t(in[d2]);
           int64_t maxindex = ind[d2];
-          float maxval = max_ptr[d2];
+          opmath_t maxval = max_ptr[d2];
 
           bool mask = (val > maxval) || std::isnan(val);
           max_ptr[d2] = mask ? val : maxval;
@@ -211,16 +217,16 @@
       }
     }
   }
-  // Convert max values from float to bfloat16
+  // Convert max values from float to bfloat16/half
   int64_t d3 = 0;
-  for (; d3 < len; d3 += bVec::size()) {
+  for (; d3 < len; d3 += Vec::size()) {
     fVec max_fvec0 = fVec::loadu(max_ptr + d3);
     fVec max_fvec1 = fVec::loadu(max_ptr + d3 + fVec::size());
-    bVec max_bvec = convert_float_bfloat16(max_fvec0, max_fvec1);
+    Vec max_bvec = convert_from_float<scalar_t>(max_fvec0, max_fvec1);
     max_bvec.store(out_data + d3);
   }
   for (; d3 < size; d3++) {
-    out_data[d3] = BFloat16(max_ptr[d3]);
+    out_data[d3] = scalar_t(max_ptr[d3]);
   }
 }
 
@@ -281,7 +287,7 @@
   int64_t output_height = output.size(-2);
   int64_t output_width = output.size(-1);
 
-  using accscalar_t = at::opmath_type<scalar_t>;
+  using opmath_t = at::opmath_type<scalar_t>;
   // parallel on dim N, C
   at::parallel_for(0, channels, 0, [&](int64_t begin, int64_t end) {
     for (int64_t c = begin; c < end; c++) {
@@ -306,17 +312,18 @@
 
             // compute local max
             int64_t maxindex = id0 * input_height * input_width + ih0 * input_width + iw0;
-            accscalar_t maxval;
-            if (std::numeric_limits<accscalar_t>::has_infinity) {
-              maxval = -std::numeric_limits<accscalar_t>::infinity();
+            opmath_t maxval;
+            if (std::numeric_limits<opmath_t>::has_infinity) {
+              maxval = -std::numeric_limits<opmath_t>::infinity();
             } else {
-              maxval = std::numeric_limits<accscalar_t>::min();
+              maxval = std::numeric_limits<opmath_t>::min();
             }
+
             for (int64_t id = id0; id < id1; id += dilationD) {
               for (int64_t ih = ih0; ih < ih1; ih += dilationH) {
                 for (int64_t iw = iw0; iw < iw1; iw += dilationW) {
                   int64_t index = id * input_height * input_width + ih * input_width + iw;
-                  accscalar_t val = input_ptr[index];
+                  opmath_t val = input_ptr[index];
                   if ((val > maxval) || is_nan(static_cast<double>(val))) {
                     maxval = val;
                     maxindex = index;
@@ -396,9 +403,9 @@
   int64_t output_height = output.size(-2);
   int64_t output_width = output.size(-1);
 
-  using accscalar_t = at::opmath_type<scalar_t>;
+  using opmath_t = at::opmath_type<scalar_t>;
   using Vec = vec::Vectorized<scalar_t>;
-  using integer_t = vec::int_same_size_t<accscalar_t>;
+  using integer_t = vec::int_same_size_t<opmath_t>;
   // for the convience of vectorization, use integer of the same size of scalar_t,
   //   e.g. int32_t for float, int64_t for double
   // need to make sure doesn't overflow
@@ -418,11 +425,11 @@
     // NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays)
     std::unique_ptr<integer_t []> index_buffer(new integer_t[len]);
     integer_t * index_ptr = index_buffer.get();
-    // temp buffer holding max value with accscalar_t
-    std::unique_ptr<accscalar_t []> max_arr;
-    accscalar_t* max_ptr = nullptr;
-    if (!std::is_same<scalar_t, accscalar_t>::value) {
-      max_arr = std::make_unique<accscalar_t[]>(size);
+    // temp buffer holding max value with opmath_t
+    std::unique_ptr<opmath_t []> max_arr;
+    opmath_t* max_ptr = nullptr;
+    if (!std::is_same<scalar_t, opmath_t>::value) {
+      max_arr = std::make_unique<opmath_t[]>(size);
       max_ptr = max_arr.get();
     }
 
@@ -598,13 +605,13 @@
     int dilationW, int dilationH) {
   switch (input.suggest_memory_format()) {
     case at::MemoryFormat::Contiguous: {
-      AT_DISPATCH_ALL_TYPES_AND(ScalarType::BFloat16, input.scalar_type(), "max_pool2d", [&] {
+      AT_DISPATCH_ALL_TYPES_AND2(ScalarType::BFloat16, ScalarType::Half, input.scalar_type(), "max_pool2d", [&] {
         cpu_max_pool<scalar_t, /*is 3d*/false>(output, indices, input, {kW, kH}, {dW, dH}, {padW, padH}, {dilationW, dilationH});
       });
       break;
     }
     case at::MemoryFormat::ChannelsLast: {
-      AT_DISPATCH_ALL_TYPES_AND(ScalarType::BFloat16, input.scalar_type(), "max_pool2d_channels_last", [&] {
+      AT_DISPATCH_ALL_TYPES_AND2(ScalarType::BFloat16, ScalarType::Half, input.scalar_type(), "max_pool2d_channels_last", [&] {
         cpu_max_pool_channels_last<scalar_t, false>(output, indices, input, {kW, kH}, {dW, dH}, {padW, padH}, {dilationW, dilationH});
       });
       break;
@@ -637,7 +644,7 @@
       DimVector indices_sizes(indices.sizes().begin(), indices.sizes().end());
       indices_sizes.insert(indices_sizes.begin(), 1);
       indices.resize_(indices_sizes, at::MemoryFormat::ChannelsLast3d);
-      AT_DISPATCH_ALL_TYPES_AND(ScalarType::BFloat16, input.scalar_type(), "max_pool3d_channels_last", [&] {
+      AT_DISPATCH_ALL_TYPES_AND2(ScalarType::BFloat16, ScalarType::Half, input.scalar_type(), "max_pool3d_channels_last", [&] {
         cpu_max_pool_channels_last<scalar_t, /*is 3d*/true>(output, indices, input_cl_check,
           {kW, kH, kD}, {dW, dH, dD}, {padW, padH, padD}, {dilationW, dilationH, dilationD});
       });
@@ -648,14 +655,14 @@
   }
   switch (input.suggest_memory_format()) {
     case at::MemoryFormat::Contiguous: {
-      AT_DISPATCH_ALL_TYPES_AND(ScalarType::BFloat16, input.scalar_type(), "max_pool3d", [&] {
+      AT_DISPATCH_ALL_TYPES_AND2(ScalarType::BFloat16, ScalarType::Half, input.scalar_type(), "max_pool3d", [&] {
         cpu_max_pool<scalar_t, /*is 3d*/true>(output, indices, input,
             {kW, kH, kD}, {dW, dH, dD}, {padW, padH, padD}, {dilationW, dilationH, dilationD});
       });
       break;
     }
     case at::MemoryFormat::ChannelsLast3d: {
-      AT_DISPATCH_ALL_TYPES_AND(ScalarType::BFloat16, input.scalar_type(), "max_pool3d_channels_last", [&] {
+      AT_DISPATCH_ALL_TYPES_AND2(ScalarType::BFloat16, ScalarType::Half, input.scalar_type(), "max_pool3d_channels_last", [&] {
         cpu_max_pool_channels_last<scalar_t, true>(output, indices, input,
           {kW, kH, kD}, {dW, dH, dD}, {padW, padH, padD}, {dilationW, dilationH, dilationD});
       });
@@ -672,13 +679,13 @@
     const Tensor& indices) {
   switch (grad_output.suggest_memory_format()) {
     case at::MemoryFormat::Contiguous: {
-      AT_DISPATCH_FLOATING_TYPES_AND(ScalarType::BFloat16, grad_output.scalar_type(), "max_pool2d_backward", [&] {
+      AT_DISPATCH_FLOATING_TYPES_AND2(ScalarType::BFloat16, ScalarType::Half, grad_output.scalar_type(), "max_pool2d_backward", [&] {
         cpu_max_pool_backward<scalar_t, /*is 3d*/ false>(grad_input, grad_output, indices);
       });
       break;
     }
     case at::MemoryFormat::ChannelsLast: {
-      AT_DISPATCH_FLOATING_TYPES_AND(ScalarType::BFloat16, grad_output.scalar_type(), "max_pool2d_backward_channels_last", [&] {
+      AT_DISPATCH_FLOATING_TYPES_AND2(ScalarType::BFloat16, ScalarType::Half, grad_output.scalar_type(), "max_pool2d_backward_channels_last", [&] {
         cpu_max_pool_backward_channels_last<scalar_t, /*is 3d*/ false>(grad_input, grad_output, indices);
       });
       break;
@@ -705,7 +712,7 @@
       sizes.insert(sizes.begin(), 1);
       grad_input.resize_(sizes, at::MemoryFormat::ChannelsLast3d);
       auto _indices = indices.unsqueeze(0).contiguous(at::MemoryFormat::ChannelsLast3d);
-      AT_DISPATCH_FLOATING_TYPES_AND(ScalarType::BFloat16, grad_output.scalar_type(), "max_pool3d_backward_channels_last", [&] {
+      AT_DISPATCH_FLOATING_TYPES_AND2(ScalarType::BFloat16, ScalarType::Half, grad_output.scalar_type(), "max_pool3d_backward_channels_last", [&] {
         cpu_max_pool_backward_channels_last<scalar_t, /*is_3d*/ true>(grad_input, grad_output_cl_check, _indices);
       });
       grad_input.squeeze_(0);
@@ -714,13 +721,13 @@
   }
   switch (grad_output.suggest_memory_format()) {
     case at::MemoryFormat::Contiguous: {
-      AT_DISPATCH_FLOATING_TYPES_AND(ScalarType::BFloat16, grad_output.scalar_type(), "max_pool3d_backward", [&] {
+      AT_DISPATCH_FLOATING_TYPES_AND2(ScalarType::BFloat16, ScalarType::Half, grad_output.scalar_type(), "max_pool3d_backward", [&] {
         cpu_max_pool_backward<scalar_t, /*is_3d*/ true>(grad_input, grad_output, indices);
       });
       break;
     }
     case at::MemoryFormat::ChannelsLast3d: {
-      AT_DISPATCH_FLOATING_TYPES_AND(ScalarType::BFloat16, grad_output.scalar_type(), "max_pool3d_backward_channels_last", [&] {
+      AT_DISPATCH_FLOATING_TYPES_AND2(ScalarType::BFloat16, ScalarType::Half, grad_output.scalar_type(), "max_pool3d_backward_channels_last", [&] {
         cpu_max_pool_backward_channels_last<scalar_t, /*is_3d*/ true>(grad_input, grad_output, indices);
       });
       break;
diff --git a/aten/src/ATen/native/cpu/MaxPooling.cpp b/aten/src/ATen/native/cpu/MaxPooling.cpp
index 1f30294..70443e6 100644
--- a/aten/src/ATen/native/cpu/MaxPooling.cpp
+++ b/aten/src/ATen/native/cpu/MaxPooling.cpp
@@ -1,7 +1,7 @@
 #define TORCH_ASSERT_ONLY_METHOD_OPERATORS
-#include <ATen/core/Tensor.h>
 #include <ATen/Dispatch.h>
 #include <ATen/Parallel.h>
+#include <ATen/core/Tensor.h>
 #include <ATen/cpu/vec/vec.h>
 #include <ATen/native/MaxPooling.h>
 #include <c10/util/irange.h>
@@ -31,25 +31,30 @@
     Tensor& output,
     const Tensor& input,
     const PoolingParams1D& p) {
-  AT_DISPATCH_FLOATING_TYPES_AND(ScalarType::BFloat16, input.scalar_type(), "max_pool1d_impl", [&] {
-    const Tensor in = input.contiguous();
-    scalar_t* const OP = output.data_ptr<scalar_t>();
-    const scalar_t* const IP = in.data_ptr<scalar_t>();
+  AT_DISPATCH_FLOATING_TYPES_AND2(
+      ScalarType::BFloat16,
+      ScalarType::Half,
+      input.scalar_type(),
+      "max_pool1d_impl",
+      [&] {
+        const Tensor in = input.contiguous();
+        scalar_t* const OP = output.data_ptr<scalar_t>();
+        const scalar_t* const IP = in.data_ptr<scalar_t>();
 
-    // Value used for padding
-    scalar_t FILL = std::numeric_limits<scalar_t>::has_infinity
-        ? -std::numeric_limits<scalar_t>::infinity()
-        : std::numeric_limits<scalar_t>::lowest();
+        // Value used for padding
+        scalar_t FILL = std::numeric_limits<scalar_t>::has_infinity
+            ? -std::numeric_limits<scalar_t>::infinity()
+            : std::numeric_limits<scalar_t>::lowest();
 
-    at::parallel_for(0, p.NB * p.NC, 0, [&](int64_t begin, int64_t end) {
-      for (const auto it : c10::irange(begin, end)) {
-        scalar_t* op = OP + it * p.OW;
-        const scalar_t* ip = IP + it * p.IW;
-        std::fill_n(op, p.OW, FILL);
-        max_pool1d_kernel(op, ip, p);
-      }
-    });
-  });
+        at::parallel_for(0, p.NB * p.NC, 0, [&](int64_t begin, int64_t end) {
+          for (const auto it : c10::irange(begin, end)) {
+            scalar_t* op = OP + it * p.OW;
+            const scalar_t* ip = IP + it * p.IW;
+            std::fill_n(op, p.OW, FILL);
+            max_pool1d_kernel(op, ip, p);
+          }
+        });
+      });
 }
 
 } // namespace
diff --git a/test/nn/test_pooling.py b/test/nn/test_pooling.py
index af61358..f92ff745 100644
--- a/test/nn/test_pooling.py
+++ b/test/nn/test_pooling.py
@@ -871,7 +871,7 @@
         helper(1, 100000, 1, 4, ks=(1, 4))  # test for max_pool1d
 
     @onlyNativeDeviceTypes
-    @dtypes(torch.bfloat16, torch.float, torch.double)
+    @dtypes(torch.half, torch.bfloat16, torch.float, torch.double)
     @dtypesIfCUDA(torch.half, torch.float, torch.double)
     @gcIfJetson
     def test_max_pool2d_nhwc(self, device, dtype):
@@ -908,7 +908,7 @@
         helper(1, 129, 8, 8, 3, stride=2)
 
     @onlyNativeDeviceTypes
-    @dtypes(torch.bfloat16, torch.float, torch.double)
+    @dtypes(torch.half, torch.bfloat16, torch.float, torch.double)
     @dtypesIfCUDA(torch.half, torch.float, torch.double)
     @gcIfJetson
     def test_max_pool3d_ndhwc(self, device, dtype):
@@ -974,9 +974,10 @@
         helper(0, 79, 4, 4, 4, 3, stride=2)
 
     @onlyCPU
-    def test_max_pool_bfloat16(self, device):
-        def helper(shape, kernel_size, stride, memory_format):
-            input = torch.randn(shape, dtype=torch.float32, device=device).bfloat16()
+    @dtypes(torch.half, torch.bfloat16)
+    def test_max_pool_bfloat16_half(self, device, dtype):
+        def helper(shape, kernel_size, stride, memory_format, dtype):
+            input = torch.randn(shape, dtype=dtype, device=device)
             input = input.to(memory_format=memory_format).requires_grad_()
             if len(shape) == 4:
                 pool = torch.nn.MaxPool2d(kernel_size, stride, return_indices=True).to(device)
@@ -991,22 +992,22 @@
             out2.sum().backward()
 
             self.assertTrue(out.is_contiguous(memory_format=memory_format))
-            self.assertEqual(out.dtype, torch.bfloat16)
-            self.assertEqual(input.grad.dtype, torch.bfloat16)
-            self.assertEqual(out, out2.bfloat16())
+            self.assertEqual(out.dtype, dtype)
+            self.assertEqual(input.grad.dtype, dtype)
+            self.assertEqual(out, out2.to(dtype=dtype))
             self.assertEqual(ind, ind2)
-            self.assertEqual(input.grad, input2.grad.bfloat16())
+            self.assertEqual(input.grad, input2.grad.to(dtype=dtype))
 
-        helper((4, 30, 8, 8), 7, 1, torch.contiguous_format)
-        helper((4, 65, 8, 8), 7, 1, torch.channels_last)
-        helper((1, 19, 20, 10), 8, 2, torch.contiguous_format)
-        helper((1, 19, 20, 10), 8, 2, torch.channels_last)
-        helper((4, 30, 8, 8), 7, 1, torch.contiguous_format)
-        helper((4, 65, 8, 8), 7, 1, torch.channels_last)
-        helper((1, 19, 10, 10, 10), 8, 2, torch.contiguous_format)
-        helper((1, 19, 10, 9, 14), 8, 2, torch.channels_last_3d)
-        helper((4, 10, 3, 8, 8), 3, 1, torch.contiguous_format)
-        helper((4, 10, 8, 8, 8), 7, 1, torch.channels_last_3d)
+        helper((4, 30, 8, 8), 7, 1, torch.contiguous_format, dtype)
+        helper((4, 65, 8, 8), 7, 1, torch.channels_last, dtype)
+        helper((1, 19, 20, 10), 8, 2, torch.contiguous_format, dtype)
+        helper((1, 19, 20, 10), 8, 2, torch.channels_last, dtype)
+        helper((4, 30, 8, 8), 7, 1, torch.contiguous_format, dtype)
+        helper((4, 65, 8, 8), 7, 1, torch.channels_last, dtype)
+        helper((1, 19, 10, 10, 10), 8, 2, torch.contiguous_format, dtype)
+        helper((1, 19, 10, 9, 14), 8, 2, torch.channels_last_3d, dtype)
+        helper((4, 10, 3, 8, 8), 3, 1, torch.contiguous_format, dtype)
+        helper((4, 10, 8, 8, 8), 7, 1, torch.channels_last_3d, dtype)
 
     @onlyCUDA
     @gcIfJetson
diff --git a/test/test_mps.py b/test/test_mps.py
index 7a06efc..1159e8d 100644
--- a/test/test_mps.py
+++ b/test/test_mps.py
@@ -10599,6 +10599,9 @@
         'nn.functional.triplet_margin_loss',
         'nn.functional.triplet_margin_with_distance_loss',
         'round', 'xlogy', 'addcmul',
+        'nn.functional.max_pool2d',
+        'nn.functional.gelu',
+        'nn.functional.glu',
 
         # for macOS 12
         'masked.normalize', 'masked.sum', 'masked.var',
@@ -10756,10 +10759,6 @@
             cpu_grad_inputs = torch.autograd.grad(diff_cpu_out, diff_cpu_arg, grad_outputs=cpu_grad_outputs, allow_unused=True)
             mps_grad_inputs = torch.autograd.grad(diff_mps_out, diff_mps_arg, grad_outputs=mps_grad_outputs, allow_unused=True)
 
-            if op.name in ["nn.functional.gelu", "nn.functional.glu"] and dtype == torch.float16:
-                atol = 1e-3
-                rtol = 1e-3
-
             self.assertEqual(cpu_grad_inputs, mps_grad_inputs, atol=atol, rtol=rtol)
 
 
diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py
index 7ce26ae..7bb3612 100644
--- a/torch/testing/_internal/common_methods_invocations.py
+++ b/torch/testing/_internal/common_methods_invocations.py
@@ -13114,7 +13114,7 @@
            check_batched_forward_grad=False,
            # TODO: add shape checks
            assert_jit_shape_analysis=False,
-           dtypes=floating_types_and(torch.bfloat16),
+           dtypes=floating_types_and(torch.bfloat16, torch.float16),
            dtypesIfCUDA=floating_types_and(torch.float16, torch.bfloat16),
            skips=(
                # Pre-existing condition; Needs to be fixed
@@ -13139,7 +13139,7 @@
            # got: Batching rule not implemented for aten::flatten.using_ints
            check_batched_forward_grad=False,
            assert_jit_shape_analysis=True,
-           dtypes=all_types_and(torch.bfloat16),
+           dtypes=all_types_and(torch.float16, torch.bfloat16),
            dtypesIfCUDA=floating_types_and(torch.float16, torch.bfloat16),
            error_inputs_func=error_inputs_max_pool2d,
            sample_inputs_func=sample_inputs_max_pool),
@@ -13156,8 +13156,7 @@
            supports_fwgrad_bwgrad=True,
            check_batched_forward_grad=False,
            assert_jit_shape_analysis=False,
-           dtypes=floating_types_and(torch.bfloat16),
-           dtypesIfCUDA=floating_types_and(torch.float16, torch.bfloat16),
+           dtypes=floating_types_and(torch.bfloat16, torch.float16),
            sample_inputs_func=sample_inputs_max_pool,
            skips=(
                # We've defined a custom op here, and we don't handle the case where we receive an out kwarg
@@ -13179,7 +13178,7 @@
            check_batched_forward_grad=False,
            # TODO: add shape checks
            assert_jit_shape_analysis=False,
-           dtypes=all_types_and(torch.bfloat16),
+           dtypes=all_types_and(torch.bfloat16, torch.float16),
            dtypesIfCUDA=floating_types_and(torch.float16, torch.bfloat16),
            # TODO: investigate nondeterminism
            gradcheck_nondet_tol=GRADCHECK_NONDET_TOL,