Meta register all foreach ops (#112281)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/112281
Approved by: https://github.com/lezcano
diff --git a/torch/_meta_registrations.py b/torch/_meta_registrations.py
index 73879f3..dc8948b 100644
--- a/torch/_meta_registrations.py
+++ b/torch/_meta_registrations.py
@@ -1,5 +1,6 @@
 import math
 from enum import Enum
+from functools import partial
 from typing import List, Optional, Sequence, Tuple, Union
 
 import torch
@@ -2959,37 +2960,166 @@
     return self.new_empty(self.size())
 
 
-@register_meta(
-    [
-        aten._foreach_abs_.default,
-        aten._foreach_neg_.default,
-        aten._foreach_reciprocal_.default,
-        aten._foreach_sqrt_.default,
-        aten._foreach_sign_.default,
-    ]
-)
-def meta__foreach_unaop_(self):
-    torch._check(
-        isinstance(self, List),
-        lambda: f"Expect List[Tensor] but got {type(self)}",
-    )
+def register_meta_foreach(ops):
+    def wrapper(fn):
+        def register(op):
+            op_name = str(op).split(".")[1]
+            scalar_op = getattr(aten, op_name.replace("_foreach_", ""))
+
+            _add_op_to_registry(
+                meta_table,
+                op,
+                partial(
+                    fn,
+                    _scalar_op=scalar_op,
+                ),
+            )
+
+        pytree.tree_map_(register, ops)
+        return fn
+
+    return wrapper
 
 
-@register_meta(
+@register_meta_foreach(
     [
-        aten._foreach_abs.default,
-        aten._foreach_neg.default,
-        aten._foreach_reciprocal.default,
-        aten._foreach_sqrt.default,
-        aten._foreach_sign.default,
+        aten._foreach_abs,
+        aten._foreach_acos,
+        aten._foreach_asin,
+        aten._foreach_atan,
+        aten._foreach_ceil,
+        aten._foreach_cos,
+        aten._foreach_cosh,
+        aten._foreach_erf,
+        aten._foreach_erfc,
+        aten._foreach_exp,
+        aten._foreach_expm1,
+        aten._foreach_frac,
+        aten._foreach_floor,
+        aten._foreach_lgamma,
+        aten._foreach_log,
+        aten._foreach_log10,
+        aten._foreach_log1p,
+        aten._foreach_log2,
+        aten._foreach_neg,
+        aten._foreach_reciprocal,
+        aten._foreach_round,
+        aten._foreach_sigmoid,
+        aten._foreach_sign,
+        aten._foreach_sin,
+        aten._foreach_sinh,
+        aten._foreach_sqrt,
+        aten._foreach_tan,
+        aten._foreach_tanh,
+        aten._foreach_trunc,
+        aten._foreach_zero,
+        aten._foreach_add,
+        aten._foreach_sub,
+        aten._foreach_mul,
+        aten._foreach_div,
+        aten._foreach_clamp_min,
+        aten._foreach_clamp_max,
+        aten._foreach_lerp,
+    ],
+)
+def _meta_foreach_out_of_place(*args, _scalar_op=None, **kwargs):
+    torch._check(
+        isinstance(args[0], list),
+        lambda: (f"The first argument must be List[Tensor], but got {type(args[0])}."),
+    )
+
+    nelem = len(args[0])
+    torch._check(
+        nelem > 0,
+        lambda: ("Tensor list must have at least one tensor."),
+    )
+
+    nlists = 1
+    for iarg, arg in enumerate(args[1:]):
+        if isinstance(arg, list):
+            nlists += 1
+            torch._check(
+                len(arg) == nelem,
+                lambda: (
+                    f"self and argument-{iarg+2} must match in length, "
+                    f"but got {nelem} and {len(arg)}."
+                ),
+            )
+        elif isinstance(arg, Tensor):
+            torch._check(
+                arg.dim() == 0 and arg.numel() == 1,
+                lambda: (
+                    "scalar tensor expected to be 0 dim but it has "
+                    f"{arg.dim()} dimensions and {arg.numel()} elements."
+                ),
+            )
+        else:
+            break
+
+    result = []
+    for elem in range(nelem):
+        each_args = [args[i][elem] for i in range(nlists)]
+        result.append(_scalar_op(*each_args, *args[nlists:], **kwargs))
+
+    return result
+
+
+@register_meta_foreach(
+    [
+        aten._foreach_abs_,
+        aten._foreach_acos_,
+        aten._foreach_asin_,
+        aten._foreach_atan_,
+        aten._foreach_ceil_,
+        aten._foreach_cos_,
+        aten._foreach_cosh_,
+        aten._foreach_erf_,
+        aten._foreach_erfc_,
+        aten._foreach_exp_,
+        aten._foreach_expm1_,
+        aten._foreach_frac_,
+        aten._foreach_floor_,
+        aten._foreach_lgamma_,
+        aten._foreach_log_,
+        aten._foreach_log10_,
+        aten._foreach_log1p_,
+        aten._foreach_log2_,
+        aten._foreach_neg_,
+        aten._foreach_reciprocal_,
+        aten._foreach_round_,
+        aten._foreach_sigmoid_,
+        aten._foreach_sign_,
+        aten._foreach_sin_,
+        aten._foreach_sinh_,
+        aten._foreach_sqrt_,
+        aten._foreach_tan_,
+        aten._foreach_tanh_,
+        aten._foreach_trunc_,
+        aten._foreach_zero_,
+        aten._foreach_add_,
+        aten._foreach_sub_,
+        aten._foreach_mul_,
+        aten._foreach_div_,
+        aten._foreach_clamp_min_,
+        aten._foreach_clamp_max_,
+        aten._foreach_lerp_,
+        aten._foreach_copy_,
     ]
 )
-def meta__foreach_unaop(self):
+def _meta_foreach_inplace(*args, _scalar_op=None, **kwargs):
+    _meta_foreach_out_of_place(*args, _scalar_op=_scalar_op, **kwargs)
+    return
+
+
+@register_meta([aten._foreach_pow.ScalarAndTensor])
+def meta__foreach_pow_scalar_and_tensor(self, exponent):
+    # Only foreach_pow has a ScalarAndTensor method and needs special
+    # handling because it does not work with _meta_foreach_out_of_place.
     torch._check(
-        isinstance(self, List),
-        lambda: f"Expect List[Tensor] but got {type(self)}",
+        isinstance(exponent, List),
+        lambda: f"exponent must be a tensor list but got {type(exponent)}",
     )
-    return [torch.empty_like(s) for s in self]
+    return [torch.empty_like(e) for e in exponent]
 
 
 def _check_foreach_binop_tensor_lists(self, other):
@@ -3011,100 +3141,67 @@
 
 @register_meta(
     [
-        aten._foreach_add.List,
-        aten._foreach_sub.List,
-        aten._foreach_mul.List,
-        aten._foreach_div.List,
-        aten._foreach_maximum.List,
-        aten._foreach_minimum.List,
-        aten._foreach_clamp_min.List,
-        aten._foreach_clamp_max.List,
+        aten._foreach_maximum,
+        aten._foreach_minimum,
     ]
 )
-def meta__foreach_binop_list(self, other, alpha=1):
-    _check_foreach_binop_tensor_lists(self, other)
+def meta__foreach_binop_scalar(*args):
+    # aten.maximum(Tensor, Scalar) does not exist.
+    return _meta_foreach_out_of_place(*args, _scalar_op=aten.clamp_min)
+
+
+@register_meta(
+    [
+        aten._foreach_maximum_,
+        aten._foreach_minimum_,
+    ]
+)
+def meta__foreach_binop__scalar(*args):
+    # aten.maximum(Tensor, Scalar) does not exist
+    _meta_foreach_inplace(*args, _scalar_op=aten.clamp_min_)
+    return
+
+
+@register_meta(
+    [
+        aten._foreach_addcdiv.Scalar,
+        aten._foreach_addcmul.Scalar,
+    ]
+)
+def meta__foreach_addcop_scalar(self, tensor1, tensor2, scalar=1):
+    # forach_addcdiv and addcdiv have different signatures and
+    # cannot use _meta_foreach_out_of_place.
+    torch._check(
+        all(isinstance(l, List) for l in [self, tensor1, tensor2]),
+        lambda: (
+            "All arguments must be List[Tensor], "
+            f"but got {type(self)}, {type(tensor1)}, and {type(tensor2)}"
+        ),
+    )
+    torch._check(len(self) > 0, lambda: "input tensor list must not be empty.")
+    torch._check(
+        len(self) == len(tensor1) and len(self) == len(tensor2),
+        lambda: "All input tensor lists must have the same length",
+    )
+
     return [torch.empty_like(s) for s in self]
 
 
-@register_meta(
-    [
-        aten._foreach_add_.List,
-        aten._foreach_sub_.List,
-        aten._foreach_mul_.List,
-        aten._foreach_div_.List,
-        aten._foreach_maximum_.List,
-        aten._foreach_minimum_.List,
-        aten._foreach_clamp_min_.List,
-        aten._foreach_clamp_max_.List,
-    ]
-)
-def meta__foreach_binop__list(self, other, alpha=1):
-    _check_foreach_binop_tensor_lists(self, other)
-
-
-@register_meta(
-    [
-        aten._foreach_add.Tensor,
-    ]
-)
-def meta__foreach_binop_tensor(self, other, alpha=1):
+@register_meta([aten._foreach_addcdiv_.Tensor, aten._foreach_addcmul_.Tensor])
+def meta__foreach_addcop_tensor(self, tensor1, tensor2, scalars):
     torch._check(
-        isinstance(self, List),
-        lambda: f"The first argument must be List[Tensor], but got {type(self)}.",
+        all(isinstance(l, List) for l in [self, tensor1, tensor2])
+        and isinstance(scalars, torch.Tensor),
+        lambda: (
+            "_foreach_addc*_ op expects arguments of type: List[Tensor], List[Tensor], List[Tensor], tensor, "
+            f"but got: {type(self)}, {type(tensor1)}, {type(tensor2)}, and {type(scalars)}"
+        ),
     )
+    torch._check(len(self) > 0, lambda: "input tensor list must not be empty.")
     torch._check(
-        isinstance(other, torch.Tensor),
-        lambda: f"The second argument must be Tensor, but got {type(other)}.",
+        len(self) == len(tensor1) and len(self) == len(tensor2),
+        lambda: "All input tensor lists must have the same length",
     )
-    return [torch.empty_like(s) for s in self]
-
-
-@register_meta(
-    [
-        aten._foreach_add_.Tensor,
-    ]
-)
-def meta__foreach_binop__tensor(self, other, alpha=1):
-    torch._check(
-        isinstance(self, List),
-        lambda: f"The first argument must be List[Tensor], but got {type(self)}.",
-    )
-    torch._check(
-        isinstance(other, torch.Tensor),
-        lambda: f"The second argument must be Tensor, but got {type(other)}.",
-    )
-
-
-@register_meta(
-    [
-        aten._foreach_add_.Scalar,
-        aten._foreach_mul_.Scalar,
-        aten._foreach_sub_.Scalar,
-        aten._foreach_div_.Scalar,
-        aten._foreach_maximum_.Scalar,
-    ]
-)
-def meta__foreach_binop__scalar(self, scalar=1):
-    torch._check(
-        isinstance(self, List),
-        lambda: f"The first argument of must be List[Tensor], but got {type(self)}.",
-    )
-
-
-@register_meta(
-    [
-        aten._foreach_add.Scalar,
-        aten._foreach_div.Scalar,
-        aten._foreach_mul.Scalar,
-        aten._foreach_sub.Scalar,
-    ]
-)
-def meta__foreach_binop_scalar(self, scalar=1):
-    torch._check(
-        isinstance(self, List),
-        lambda: f"The first argument of must be List[Tensor], but got {type(self)}.",
-    )
-    return [torch.empty_like(s) for s in self]
 
 
 @register_meta(
@@ -3128,69 +3225,6 @@
     )
 
 
-@register_meta(
-    [
-        aten._foreach_lerp_.Scalar,
-    ]
-)
-def meta__foreach_lerp__scalar(self, other, scalar=1):
-    _check_foreach_binop_tensor_lists(self, other)
-
-
-@register_meta(
-    [
-        aten._foreach_addcdiv.Scalar,
-        aten._foreach_addcmul.Scalar,
-    ]
-)
-def meta__foreach_addcop_scalar(self, tensor1, tensor2, scalar=1):
-    torch._check(
-        all(isinstance(l, List) for l in [self, tensor1, tensor2]),
-        lambda: (
-            "All arguments must be List[Tensor], "
-            f"but got {type(self)}, {type(tensor1)}, and {type(tensor2)}"
-        ),
-    )
-    torch._check(len(self) > 0, lambda: "input tensor list must not be empty.")
-    torch._check(
-        len(self) == len(tensor1) and len(self) == len(tensor2),
-        lambda: "All input tensor lists must have the same length",
-    )
-
-    return [torch.empty_like(s) for s in self]
-
-
-@register_meta([aten._foreach_pow.ScalarAndTensor])
-def meta__foreach_pow_scalar_and_tensor(self, exponent):
-    torch._check(
-        isinstance(exponent, List),
-        lambda: f"exponent must be a tensor list but got {type(exponent)}",
-    )
-    return [torch.empty_like(e) for e in exponent]
-
-
-@register_meta([aten._foreach_addcdiv_.Tensor, aten._foreach_addcmul_.Tensor])
-def meta__foreach_addcop_tensor(self, tensor1, tensor2, scalars):
-    torch._check(
-        all(isinstance(l, List) for l in [self, tensor1, tensor2])
-        and isinstance(scalars, torch.Tensor),
-        lambda: (
-            "_foreach_addc*_ op expects arguments of type: List[Tensor], List[Tensor], List[Tensor], tensor, "
-            f"but got: {type(self)}, {type(tensor1)}, {type(tensor2)}, and {type(scalars)}"
-        ),
-    )
-    torch._check(len(self) > 0, lambda: "input tensor list must not be empty.")
-    torch._check(
-        len(self) == len(tensor1) and len(self) == len(tensor2),
-        lambda: "All input tensor lists must have the same length",
-    )
-
-
-@register_meta([aten._foreach_copy_])
-def meta__foreach_copy_inplace(self, src, non_blocking=False):
-    _check_foreach_binop_tensor_lists(self, src)
-
-
 @register_meta([aten._fused_adam_.default])
 def meta__fused_adam_(
     self,
diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py
index fde2a7b..2629c48 100644
--- a/torch/testing/_internal/common_methods_invocations.py
+++ b/torch/testing/_internal/common_methods_invocations.py
@@ -8974,128 +8974,38 @@
         'exp',
         foreach_inputs_sample_func(1, False, False),
         backward_requires_result=True,
-        skips=(
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_meta_inplace"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_inplace"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_inplace"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_meta_outplace"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_outplace"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_inplace_all_strides"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace_all_strides"),
-        ),
     ),
     ForeachFuncInfo(
         'acos',
         foreach_inputs_sample_func(1, False, False),
-        skips=(
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_meta_inplace"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_inplace"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_inplace"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_meta_outplace"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_outplace"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_inplace_all_strides"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace_all_strides"),
-        ),
     ),
     ForeachFuncInfo(
         'asin',
         foreach_inputs_sample_func(1, False, False),
-        skips=(
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_meta_inplace"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_inplace"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_inplace"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_meta_outplace"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_outplace"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_inplace_all_strides"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace_all_strides"),
-        ),
     ),
     ForeachFuncInfo(
         'atan',
         foreach_inputs_sample_func(1, False, False),
-        skips=(
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_meta_inplace"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_inplace"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_inplace"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_meta_outplace"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_outplace"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_inplace_all_strides"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace_all_strides"),
-        ),
     ),
     ForeachFuncInfo(
         'cos',
         foreach_inputs_sample_func(1, False, False),
-        skips=(
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_meta_inplace"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_inplace"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_inplace"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_meta_outplace"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_outplace"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_inplace_all_strides"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace_all_strides"),
-        ),
     ),
     ForeachFuncInfo(
         'cosh',
         foreach_inputs_sample_func(1, False, False),
-        skips=(
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_meta_inplace"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_inplace"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_inplace"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_meta_outplace"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_outplace"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_inplace_all_strides"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace_all_strides"),
-        ),
     ),
     ForeachFuncInfo(
         'log',
         foreach_inputs_sample_func(1, False, False),
-        skips=(
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_meta_inplace"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_inplace"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_inplace"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_meta_outplace"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_outplace"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_inplace_all_strides"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace_all_strides"),
-        ),
     ),
     ForeachFuncInfo(
         'log10',
         foreach_inputs_sample_func(1, False, False),
-        skips=(
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_meta_inplace"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_inplace"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_inplace"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_meta_outplace"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_outplace"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_inplace_all_strides"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace_all_strides"),
-        ),
     ),
     ForeachFuncInfo(
         'log2',
         foreach_inputs_sample_func(1, False, False),
-        skips=(
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_meta_inplace"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_inplace"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_inplace"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_meta_outplace"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_outplace"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_inplace_all_strides"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace_all_strides"),
-        ),
     ),
     ForeachFuncInfo(
         'tan',
@@ -9114,16 +9024,6 @@
                 device_type='cuda'
             ),
         ),
-        skips=(
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_meta_inplace"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_inplace"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_inplace"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_meta_outplace"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_outplace"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_inplace_all_strides"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace_all_strides"),
-        ),
     ),
     ForeachFuncInfo(
         'tanh',
@@ -9139,44 +9039,14 @@
                 device_type='cuda'
             ),
         ),
-        skips=(
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_meta_inplace"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_inplace"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_inplace"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_meta_outplace"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_outplace"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_inplace_all_strides"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace_all_strides"),
-        ),
     ),
     ForeachFuncInfo(
         'sin',
         foreach_inputs_sample_func(1, False, False),
-        skips=(
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_meta_inplace"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_inplace"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_inplace"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_meta_outplace"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_outplace"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_inplace_all_strides"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace_all_strides"),
-        ),
     ),
     ForeachFuncInfo(
         'sinh',
         foreach_inputs_sample_func(1, False, False),
-        skips=(
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_meta_inplace"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_inplace"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_inplace"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_meta_outplace"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_outplace"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_inplace_all_strides"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace_all_strides"),
-        ),
     ),
     ForeachFuncInfo(
         'neg',
@@ -9196,48 +9066,18 @@
         foreach_inputs_sample_func(1, False, False),
         dtypes=all_types_and(torch.bfloat16),
         dtypesIfCUDA=all_types_and(torch.half, torch.bfloat16),
-        skips=(
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_meta_inplace"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_inplace"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_inplace"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_meta_outplace"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_outplace"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_inplace_all_strides"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace_all_strides"),
-        ),
     ),
     ForeachFuncInfo(
         'erf',
         foreach_inputs_sample_func(1, False, False),
         dtypes=floating_types_and(torch.bfloat16),
         dtypesIfCUDA=floating_types_and(torch.half, torch.bfloat16),
-        skips=(
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_meta_inplace"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_inplace"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_inplace"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_meta_outplace"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_outplace"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_inplace_all_strides"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace_all_strides"),
-        ),
     ),
     ForeachFuncInfo(
         'erfc',
         foreach_inputs_sample_func(1, False, False),
         dtypes=floating_types_and(torch.bfloat16),
         dtypesIfCUDA=floating_types_and(torch.half, torch.bfloat16),
-        skips=(
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_meta_inplace"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_inplace"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_inplace"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_meta_outplace"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_outplace"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_inplace_all_strides"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace_all_strides"),
-        ),
     ),
     ForeachFuncInfo(
         'expm1',
@@ -9245,80 +9085,30 @@
         dtypes=floating_and_complex_types_and(torch.bfloat16),
         dtypesIfCUDA=floating_and_complex_types_and(torch.half, torch.bfloat16),
         backward_requires_result=True,
-        skips=(
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_meta_inplace"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_inplace"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_inplace"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_meta_outplace"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_outplace"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_inplace_all_strides"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace_all_strides"),
-        ),
     ),
     ForeachFuncInfo(
         'floor',
         foreach_inputs_sample_func(1, False, False),
         dtypes=all_types_and(torch.bfloat16),
         dtypesIfCUDA=all_types_and(torch.half, torch.bfloat16),
-        skips=(
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_meta_inplace"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_inplace"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_inplace"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_meta_outplace"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_outplace"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_inplace_all_strides"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace_all_strides"),
-        ),
     ),
     ForeachFuncInfo(
         'log1p',
         foreach_inputs_sample_func(1, False, False),
         dtypes=floating_and_complex_types_and(torch.bfloat16),
         dtypesIfCUDA=floating_and_complex_types_and(torch.half),
-        skips=(
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_meta_inplace"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_inplace"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_inplace"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_meta_outplace"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_outplace"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_inplace_all_strides"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace_all_strides"),
-        ),
     ),
     ForeachFuncInfo(
         'round',
         foreach_inputs_sample_func(1, False, False),
         dtypes=all_types_and(torch.bfloat16),
         dtypesIfCUDA=all_types_and(torch.half, torch.bfloat16),
-        skips=(
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_meta_inplace"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_inplace"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_inplace"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_meta_outplace"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_outplace"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_inplace_all_strides"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace_all_strides"),
-        ),
     ),
     ForeachFuncInfo(
         'frac',
         foreach_inputs_sample_func(1, False, False),
         dtypes=floating_types_and(torch.bfloat16),
         dtypesIfCUDA=floating_types_and(torch.half, torch.bfloat16),
-        skips=(
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_meta_inplace"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_inplace"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_inplace"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_meta_outplace"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_outplace"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_inplace_all_strides"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace_all_strides"),
-        ),
     ),
     ForeachFuncInfo(
         'reciprocal',
@@ -9333,32 +9123,12 @@
         dtypes=floating_types_and(torch.bfloat16),
         dtypesIfCUDA=floating_types_and(torch.half),
         backward_requires_result=True,
-        skips=(
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_meta_inplace"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_inplace"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_inplace"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_meta_outplace"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_outplace"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_inplace_all_strides"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace_all_strides"),
-        ),
     ),
     ForeachFuncInfo(
         'trunc',
         foreach_inputs_sample_func(1, False, False),
         dtypes=all_types_and(torch.bfloat16),
         dtypesIfCUDA=all_types_and(torch.half, torch.bfloat16),
-        skips=(
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_meta_inplace"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_inplace"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_inplace"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_meta_outplace"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_outplace"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_inplace_all_strides"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace_all_strides"),
-        ),
     ),
     ForeachFuncInfo(
         'abs',
@@ -9368,21 +9138,11 @@
         supports_fwgrad_bwgrad=True,
         skips=(
             DecorateInfo(unittest.skip("In-place abs not supported for complex tensors"), "TestMeta",
+                         "test_dispatch_symbolic_meta_inplace", dtypes=complex_types()),
+            DecorateInfo(unittest.skip("In-place abs not supported for complex tensors"), "TestMeta",
                          "test_dispatch_meta_inplace", dtypes=complex_types()),
             DecorateInfo(unittest.skip("In-place abs not supported for complex tensors"), "TestMeta",
                          "test_meta_inplace", dtypes=complex_types()),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_inplace",
-                         dtypes=complex_types()),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_meta_outplace",
-                         dtypes=complex_types()),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace",
-                         dtypes=complex_types()),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_outplace",
-                         dtypes=complex_types()),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_inplace_all_strides",
-                         dtypes=complex_types()),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace_all_strides",
-                         dtypes=complex_types()),
         ),
     ),
     ForeachFuncInfo(
@@ -9391,13 +9151,9 @@
         dtypes=all_types_and_complex_and(torch.bfloat16, torch.half),
         has_no_out_of_place=True,
         skips=(
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_meta_inplace"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_inplace"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_inplace"),
             DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_meta_outplace"),
             DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace"),
             DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_outplace"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_inplace_all_strides"),
             DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace_all_strides"),
         ),
     ),
@@ -9413,14 +9169,12 @@
         dtypes=all_types_and(torch.bool, torch.bfloat16, torch.half),
         dtypesIfCUDA=all_types_and(torch.bool, torch.float16),
         skips=(
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_meta_inplace"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_inplace"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_inplace"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_meta_outplace"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_outplace"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_inplace_all_strides"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace_all_strides"),
+            DecorateInfo(unittest.skip("In-place lgamma not supported for integral tensors"), "TestMeta",
+                         "test_dispatch_symbolic_meta_inplace", dtypes=integral_types_and(torch.bool)),
+            DecorateInfo(unittest.skip("In-place lgamma not supported for integral tensors"), "TestMeta",
+                         "test_dispatch_meta_inplace", dtypes=integral_types_and(torch.bool)),
+            DecorateInfo(unittest.skip("In-place lgamma not supported for integral tensors"), "TestMeta",
+                         "test_meta_inplace", dtypes=integral_types_and(torch.bool)),
         ),
     ),
 ]
@@ -9433,14 +9187,16 @@
         dtypesIfCUDA=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.float16),
         supports_alpha_param=True,
         skips=(
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_meta_inplace"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_inplace"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_inplace"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_meta_outplace"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace"),
+            # These tests fail with aten._local_scalar_dense not being implemented.
             DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_outplace"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_inplace_all_strides"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace_all_strides"),
+            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_inplace"),
+            # Samples have complex types and inplace only works if the dtype is complex.
+            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_meta_inplace",
+                         dtypes=all_types_and(torch.bool, torch.bfloat16, torch.float16)),
+            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_inplace",
+                         dtypes=all_types_and(torch.bool, torch.bfloat16, torch.float16)),
+            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_inplace_all_strides",
+                         dtypes=all_types_and(torch.bool, torch.bfloat16, torch.float16)),
         ),
     ),
     ForeachFuncInfo(
@@ -9466,14 +9222,15 @@
         dtypesIfCUDA=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.float16),
         sample_inputs_func=foreach_inputs_sample_func(2, True, True, True),
         skips=(
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_meta_inplace"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_inplace"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_inplace"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_meta_outplace"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_outplace"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_inplace_all_strides"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace_all_strides"),
+            # Samples have complex types and inplace only works if the dtype is complex.
+            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_meta_inplace",
+                         dtypes=all_types_and(torch.bool, torch.bfloat16, torch.float16)),
+            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_inplace",
+                         dtypes=all_types_and(torch.bool, torch.bfloat16, torch.float16)),
+            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_inplace",
+                         dtypes=all_types_and(torch.bool, torch.bfloat16, torch.float16)),
+            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_inplace_all_strides",
+                         dtypes=all_types_and(torch.bool, torch.bfloat16, torch.float16)),
         ),
     ),
     ForeachFuncInfo(
@@ -9482,14 +9239,24 @@
         dtypesIfCUDA=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.float16),
         sample_inputs_func=foreach_inputs_sample_func(2, True, True, True),
         skips=(
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_meta_inplace"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_inplace"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_inplace"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_meta_outplace"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_outplace"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_inplace_all_strides"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace_all_strides"),
+            # Samples have complex types and inplace only works if the dtype is complex.
+            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_meta_inplace",
+                         dtypes=all_types_and(torch.bool, torch.bfloat16, torch.float16)),
+            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_inplace",
+                         dtypes=all_types_and(torch.bool, torch.bfloat16, torch.float16)),
+            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_inplace",
+                         dtypes=all_types_and(torch.bool, torch.bfloat16, torch.float16)),
+            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_inplace_all_strides",
+                         dtypes=all_types_and(torch.bool, torch.bfloat16, torch.float16)),
+            # fails with div_cpu is not implemented with ComplexHalf
+            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_meta_outplace",
+                         dtypes=(torch.float16,), device_type='cpu'),
+            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace",
+                         dtypes=(torch.float16,), device_type='cpu'),
+            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_outplace",
+                         dtypes=(torch.float16,), device_type='cpu'),
+            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace_all_strides",
+                         dtypes=(torch.float16,), device_type='cpu'),
         ),
     ),
     ForeachFuncInfo(
@@ -9660,16 +9427,6 @@
         dtypes=floating_and_complex_types_and(torch.float16, torch.bfloat16),
         dtypesIfCUDA=floating_and_complex_types_and(torch.half, torch.bfloat16),
         dtypesIfROCM=floating_and_complex_types_and(torch.half, torch.bfloat16),
-        skips=(
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_meta_inplace"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_inplace"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_inplace"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_meta_outplace"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_outplace"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_inplace_all_strides"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace_all_strides"),
-        ),
     ),
 ]