Ported matmul CompositeImplicitAutograd impl into core (#85239)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/85239
Approved by: https://github.com/ezyang, https://github.com/lezcano
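
With the matmul decomposition registered as a Python CompositeImplicitAutograd
kernel and symbolic meta functions added for mm, bmm and max_pool2d_with_indices,
the symbolic-shape tracing xfails for matmul, __rmatmul__, mm, bmm, addmm, clone,
norm and max_pool2d can be removed. A rough sketch of the kind of trace this
enables (the exact graph contents may differ):

    import torch
    from torch.fx.experimental.proxy_tensor import make_fx

    def f(a, b):
        return torch.matmul(a, b)

    # matmul should now decompose through the Python impl while tracing, so the
    # graph is expressed in ops that have symbolic meta support (e.g. aten.mm).
    gm = make_fx(f, tracing_mode="symbolic")(torch.randn(2, 3), torch.randn(3, 4))
    gm.print_readable()
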
diff --git a/functorch/_src/aot_autograd.py b/functorch/_src/aot_autograd.py
index 26436f1..0cf820c 100644
--- a/functorch/_src/aot_autograd.py
+++ b/functorch/_src/aot_autograd.py
@@ -276,6 +276,9 @@
 
 def aot_dispatch_base(flat_fn, flat_args: List[Tensor], aot_config: AOTConfig):
     fw_module = make_fx(flat_fn, aot_config.decompositions)(*flat_args)
+    if config.debug_graphs:
+        print("====== Forward (only) graph ======")
+        fw_module.print_readable()
     with track_graph_compiling("inference"):
         compiled_fw = aot_config.fw_compiler(fw_module, flat_args)
 
diff --git a/test/test_decomp.py b/test/test_decomp.py
index ffd44cd..719240d 100644
--- a/test/test_decomp.py
+++ b/test/test_decomp.py
@@ -24,6 +24,7 @@
     instantiate_device_type_tests,
 )
 from torch.testing._internal.common_methods_invocations import op_db
+from torch._dispatch.python import enable_python_dispatcher
 
 import itertools
 import functools
@@ -486,7 +487,7 @@
                 # explicit clearing is necessary as I will create a fresh mode
                 # for each region
                 decomposed.clear()
-                with enable_torch_dispatch_mode(DecompCrossRefMode):
+                with enable_torch_dispatch_mode(DecompCrossRefMode), enable_python_dispatcher():
                     decomp_out, decomp_vjp_fn = ref_vjp_no_create(fn, *primals)
                 if aten_name in decomposition_names:
                     check_decomposed(aten_name)
@@ -495,7 +496,7 @@
                     cotangents = tree_map(lambda x: torch.randn_like(x), decomp_out)
 
                     decomposed.clear()
-                    with enable_torch_dispatch_mode(DecompCrossRefMode):
+                    with enable_torch_dispatch_mode(DecompCrossRefMode), enable_python_dispatcher():
                         decomp_vjp_fn(cotangents)
                     if not run_all:
                         check_decomposed(op.aten_backward_name)
@@ -504,7 +505,7 @@
                 args = [sample_input.input] + list(sample_input.args)
                 kwargs = sample_input.kwargs
                 decomposed.clear()
-                with enable_torch_dispatch_mode(DecompCrossRefMode):
+                with enable_torch_dispatch_mode(DecompCrossRefMode), enable_python_dispatcher():
                     func(*args, **kwargs)
                 if not run_all:
                     check_decomposed(aten_name)
diff --git a/test/test_ops.py b/test/test_ops.py
index 4cf025a..e96b59a 100644
--- a/test/test_ops.py
+++ b/test/test_ops.py
@@ -1671,7 +1671,6 @@
         '_refs.expand_as',
         '_refs.as_strided',  # _prims._as_strided_meta: "reduce() of empty sequence with no initial value"
         '_refs.copy_to',  # torch._C._jit_get_operation: No such operator aten::copy_to
-        '_refs.clone',  # test_meta.py: view size is not compatible with input tensor's size and stride
         '_refs.equal',  # 'bool' object has no attribute 'dtype'
         '_refs.conj',  # Calls _prims.conj
         '_refs.real',
diff --git a/test/test_proxy_tensor.py b/test/test_proxy_tensor.py
index e238bdb..024e729 100644
--- a/test/test_proxy_tensor.py
+++ b/test/test_proxy_tensor.py
@@ -22,6 +22,7 @@
 
 import types
 import functools
+import itertools
 
 aten = torch.ops.aten
 
@@ -1010,7 +1011,6 @@
     xfail('linalg.eigvals'),
     skip('_masked.logsumexp', ''),  # Tensors of type TensorImpl do not have numel
     xfail('__getitem__', ''),  # aten.size.default - couldn't find symbolic meta function/decomposition
-    xfail('__rmatmul__', ''),  # aten.new_empty.default - couldn't find symbolic meta function/decomposition
     xfail('_masked.amax', ''),  # aten._to_copy.default - couldn't find symbolic meta function/decomposition
     xfail('_masked.amin', ''),  # aten._to_copy.default - couldn't find symbolic meta function/decomposition
     xfail('_masked.argmax', ''),  # aten.argmax.default - couldn't find symbolic meta function/decomposition
@@ -1022,15 +1022,12 @@
     xfail('_masked.mean', ''),  # ones() received an invalid combination of arguments - got (torch.Size, device=torch.device, ...
     xfail('_masked.median', ''),  # aten.nanmedian.dim - couldn't find symbolic meta function/decomposition
     xfail('_masked.norm', ''),  # aten.linalg_vector_norm.default - couldn't find symbolic meta function/decomposition
-    xfail('_masked.normalize', ''),  # aten.linalg_vector_norm.default - couldn't find symbolic meta function/decomposition
     xfail('_masked.prod', ''),  # aten._to_copy.default - couldn't find symbolic meta function/decomposition
     xfail('_masked.softmax', ''),  # aten._to_copy.default - couldn't find symbolic meta function/decomposition
     xfail('_masked.softmin', ''),  # aten._to_copy.default - couldn't find symbolic meta function/decomposition
     xfail('_masked.std', ''),  # ones() received an invalid combination of arguments - got (torch.Size, device=torch.device, d...
     xfail('_masked.sum', ''),  # aten._to_copy.default - couldn't find symbolic meta function/decomposition
     xfail('_masked.var', ''),  # ones() received an invalid combination of arguments - got (torch.Size, device=torch.device, d...
-    xfail('addmm', ''),  # aten.mm.default - couldn't find symbolic meta function/decomposition
-    xfail('addmm', 'decomposed'),  # aten.mm.default - couldn't find symbolic meta function/decomposition
     xfail('addmv', ''),  # aten.addmv.default - couldn't find symbolic meta function/decomposition
     xfail('addr', ''),  # aten.size.default - couldn't find symbolic meta function/decomposition
     xfail('aminmax', ''),  # aten.aminmax.default - couldn't find symbolic meta function/decomposition
@@ -1042,13 +1039,11 @@
     xfail('as_strided_scatter', ''),  # aten.as_strided_scatter.default - couldn't find symbolic meta function/decomposition
     xfail('baddbmm', ''),  # aten.baddbmm.default - couldn't find symbolic meta function/decomposition
     xfail('bernoulli', ''),  # aten.bernoulli.default - couldn't find symbolic meta function/decomposition
-    xfail('bmm', ''),  # aten.bmm.default - couldn't find symbolic meta function/decomposition
     xfail('bucketize', ''),  # aten.bucketize.Tensor - couldn't find symbolic meta function/decomposition
     xfail('cartesian_prod', ''),  # Tensors of type TensorImpl do not have numel
     xfail('cdist', ''),  # aten.size.default - couldn't find symbolic meta function/decomposition
     xfail('cholesky_solve', ''),  # Could not run 'aten::_cholesky_solve_helper' with arguments from the 'Meta' back...
     xfail('chunk', ''),  # aten.size.default - couldn't find symbolic meta function/decomposition
-    xfail('clone', ''),  # aten.clone.default - couldn't find symbolic meta function/decomposition
     xfail('column_stack', ''),  # Tensors of type TensorImpl do not have numel
     xfail('constant_pad_nd', ''),  # aten.fill.Scalar - couldn't find symbolic meta function/decomposition
     xfail('count_nonzero', ''),  # Could not run 'aten::count_nonzero.dim_IntList' with arguments from the 'Meta' ba...
@@ -1150,7 +1145,6 @@
     xfail('linalg.tensorinv', ''),  # aten.size.default - couldn't find symbolic meta function/decomposition
     xfail('linalg.tensorsolve', ''),  # aten.size.default - couldn't find symbolic meta function/decomposition
     xfail('linalg.vander', ''),  # aten.size.default - couldn't find symbolic meta function/decomposition
-    xfail('linalg.vector_norm', ''),  # TensorImpl do not have numel
     xfail('logaddexp2', ''),  # aten.logaddexp2.default - couldn't find symbolic meta function/decomposition
     xfail('logaddexp', ''),  # aten.logaddexp.default - couldn't find symbolic meta function/decomposition
     xfail('logcumsumexp', ''),  # aten.logcumsumexp.default - couldn't find symbolic meta function/decomposition
@@ -1161,14 +1155,12 @@
     xfail('masked_fill', ''),  # expected predicate to be bool, got torch.float32
     xfail('masked_scatter', ''),  # aten.masked_scatter.default - couldn't find symbolic meta function/decomposition
     xfail('masked_select', ''),  # aten.masked_select.default - couldn't find symbolic meta function/decomposition
-    xfail('matmul', ''),  # aten.new_empty.default - couldn't find symbolic meta function/decomposition
     xfail('matrix_exp', ''),  # aten.linalg_matrix_exp.default - couldn't find symbolic meta function/decomposition
     xfail('max', 'reduction_with_dim'),  # aten.max.dim - couldn't find symbolic meta function/decomposition
     xfail('median', ''),  # Could not run 'aten::median' with arguments from the 'Meta' backend. This could be becau...
     xfail('meshgrid', 'list_of_tensors'),  # Tensors of type TensorImpl do not have numel
     xfail('meshgrid', 'variadic_tensors'),  # Tensors of type TensorImpl do not have numel
     xfail('min', 'reduction_with_dim'),  # aten.min.dim - couldn't find symbolic meta function/decomposition
-    xfail('mm', ''),  # aten.mm.default - couldn't find symbolic meta function/decomposition
     xfail('mode', ''),  # aten.mode.default - couldn't find symbolic meta function/decomposition
     xfail('msort', ''),  # aten.sort.default - couldn't find symbolic meta function/decomposition
     xfail('nanquantile', ''),  # Could not run 'aten::equal' with arguments from the 'Meta' backend.
@@ -1213,7 +1205,6 @@
     xfail('nn.functional.local_response_norm', ''),  # Tensors of type TensorImpl do not have numel
     xfail('nn.functional.margin_ranking_loss', ''),  # The underlying op of 'aten.stride' has no overload name '_schema'
     xfail('nn.functional.max_pool1d', ''),  # Trying to call aten.size on a tensor with symbolic shapes.
-    xfail('nn.functional.max_pool2d', ''),  # aten.max_pool2d_with_indices.default - couldn't find symbolic meta function/d...
     xfail('nn.functional.max_pool3d', ''),  # aten.max_pool3d_with_indices.default - couldn't find symbolic meta function/d...
     xfail('nn.functional.max_unpool1d', 'grad'),  # aten.max_unpool2d.default - couldn't find symbolic meta function/decom...
     xfail('nn.functional.max_unpool2d', 'grad'),  # aten.max_unpool2d.default - couldn't find symbolic meta function/decom...
@@ -1234,7 +1225,6 @@
     xfail('nn.functional.unfold', ''),  # aten.im2col.default - couldn't find symbolic meta function/decomposition
     xfail('nn.functional.upsample_bilinear', ''),  # aten.upsample_bilinear2d.vec - couldn't find symbolic meta function/de...
     xfail('nn.functional.upsample_nearest', ''),  # aten.upsample_nearest1d.vec - couldn't find symbolic meta function/deco...
-    xfail('norm', ''),  # TensorImpl does not have numel
     xfail('norm', 'nuc'),  # aten._linalg_svd.default - couldn't find symbolic meta function/decomposition
     xfail('normal', ''),  # aten.normal.Tensor_Tensor - couldn't find symbolic meta function/decomposition
     xfail('normal', 'number_mean'),  # aten.normal.float_Tensor - couldn't find symbolic meta function/decomposition
@@ -1337,7 +1327,9 @@
         return op.op(*args, **kwargs)
     sample_inputs_itr = op.sample_inputs(device, dtype, requires_grad=False)
     new_f = None
-    for sample_input in sample_inputs_itr:
+
+    # Limit ourselves to the first 100 inputs so the symbolic tracing tests don't take too long
+    for sample_input in itertools.islice(sample_inputs_itr, 100):
         args = [sample_input.input] + list(sample_input.args)
         kwargs = sample_input.kwargs
 
@@ -1345,7 +1337,6 @@
             new_f = make_fx(f, tracing_mode=tracing_mode)(args, kwargs)
         except DynamicOutputShapeException as e:
             self.skipTest("Dynamic output shape operation in trace")
-
         for arg in args:
             if isinstance(arg, torch.Tensor) and arg.dtype == torch.float:
                 arg.uniform_(0, 1)
diff --git a/torch/_decomp/decompositions.py b/torch/_decomp/decompositions.py
index 1d09c0b..05dc65d 100644
--- a/torch/_decomp/decompositions.py
+++ b/torch/_decomp/decompositions.py
@@ -1,4 +1,5 @@
 import functools
+import operator
 from enum import Enum
 from itertools import product
 from typing import Callable, Iterable, List, Optional, Tuple
@@ -12,6 +13,8 @@
 from torch._prims_common.wrappers import out_wrapper
 from torch.utils._pytree import tree_flatten, tree_map
 
+DispatchKey = torch._C.DispatchKey  # type: ignore[attr-defined]
+
 # None of these functions are publicly accessible; get at them
 # from torch._decomps
 __all__: List[str] = []
@@ -2030,3 +2033,131 @@
         loss = loss * weight
 
     return apply_loss_reduction(loss, reduction)
+
+
+def should_fold(tensor1: torch.Tensor, dim_tensor2: int) -> bool:
+    dim_tensor1 = tensor1.ndim
+    if dim_tensor1 >= 3 and (dim_tensor2 == 1 or dim_tensor2 == 2):
+        t1_sizes_ptr = tensor1.shape
+        t1_strides = tensor1.stride()
+        if (
+            dim_tensor1 == 3
+            and dim_tensor2 == 2
+            and t1_strides[-1] != 1
+            and t1_strides[0] == t1_sizes_ptr[1] * t1_sizes_ptr[2]
+        ):
+            # First dim is slowest moving, and then the following two dims are
+            # transposed. This can happen for example by permute(0, 2, 1).
+            # First 2 dims could be folded to use mm but would require permutation
+            # with actual data movement, which can be instead handled by BMM with each
+            # GEMM transposed.
+            # This can be generalized to a tensor with dim X + Y + Z where X, Y, and Z
+            # dims are contiguous, Y dims and Z dims are transposed, and X, Y, Z > 0.
+            # For example, this can happen by permute(0, 1, 5, 2, 3, 4), where X = 2,
+            # Y = 3, and Z = 1.
+            return False
+        else:
+            return True
+    else:
+        return False
+
+
+@aten.matmul.default.py_impl(DispatchKey.CompositeImplicitAutograd)
+def matmul(tensor1, tensor2):
+    dim_tensor1 = tensor1.dim()
+    dim_tensor2 = tensor2.dim()
+    assert dim_tensor1 != 0 and dim_tensor2 != 0
+    if dim_tensor1 == 1 and dim_tensor2 == 1:
+        return torch.dot(tensor1, tensor2)
+    elif dim_tensor1 == 2 and dim_tensor2 == 1:
+        return torch.mv(tensor1, tensor2)
+    elif dim_tensor1 == 1 and dim_tensor2 == 2:
+        return torch.squeeze(torch.mm(torch.unsqueeze(tensor1, 0), tensor2), 0)
+    elif dim_tensor1 == 2 and dim_tensor2 == 2:
+        return torch.mm(tensor1, tensor2)
+    elif should_fold(tensor1, dim_tensor2) or should_fold(tensor2, dim_tensor1):
+        # NB: Much of this was written with Copilot! (although still had to fix a bunch of issues)
+
+        # dim_tensor1 >=3 && (dim_tensor2 == 1 || dim_tensor2 == 2) ||
+        # dim_tensor2 >=3 && (dim_tensor1 == 1 || dim_tensor1 == 2)
+        # and some condition on the strides is fulfilled
+
+        # optimization: use mm instead of bmm by folding the batch of the larger tensor
+        # into its leading matrix dimension
+        transpose = dim_tensor2 > dim_tensor1
+        t1 = tensor2.mT if transpose else tensor1
+        t2 = (
+            tensor2 if not transpose else (tensor1.t() if dim_tensor1 == 2 else tensor1)
+        )
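+        # When transpose is True we effectively compute (tensor2.mT @ tensor1.t()).mT
+        # for the matrix case, relying on (A @ B).mT == B.mT @ A.mT, so the batched
+        # operand is always on the left and can be folded below.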
+        # Invariant: t1.dim() >= 3 && (t2.dim() == 1 || t2.dim() == 2)
+        #            and t1 and t2 are matmul-compatible
+
+        # Why not t1.view(-1, sizes_1[-1])?
+        # If the last dim is 0, then view(-1, 0) won't work because the -1 becomes ambiguous.
+        # This can happen in e.g. [3, 5, 0] @ [0, 0].
+        sizes_1 = t1.shape
+        output_shape = list(sizes_1[:-1])
+        folded_dim1 = functools.reduce(operator.mul, output_shape)
+
+        # Readjust output_shape if we are multiplying by a matrix
+        t2_is_matrix = t2.dim() == 2
+        if t2_is_matrix:
+            output_shape.append(t2.shape[1])
+        # HACK: We need reshape with symint support
+        t1 = t1.contiguous()
+        t1_folded = t1.view(folded_dim1, sizes_1[-1])
+        if t2_is_matrix:
+            # FIXME This path always does an unnecessary copy when transpose == True as the returned
+            # result from BLAS is already C-transposed
+            output = t1_folded.mm(t2).view(output_shape)
+            return output.mT.contiguous() if transpose else output
+        else:
+            return t1_folded.mv(t2).view(output_shape)
+
+    elif dim_tensor1 >= 1 and dim_tensor2 >= 1:
+        # We are multiplying b1 x n x m1 by b2 x m2 x p (where b1 and b2 can be lists of batch dims);
+        # we track m1 vs m2 separately even though they must match for nicer error messages
+        n = tensor1.size(-2) if dim_tensor1 > 1 else 1
+        m1 = tensor1.size(-1)
+        batch_tensor1 = tensor1.shape[:-2]
+        m2 = tensor2.size(-2) if dim_tensor2 > 1 else tensor2.size(-1)
+        p = tensor2.size(-1) if dim_tensor2 > 1 else 1
+        batch_tensor2: List[int] = []
+        # TODO: handling of slice
+        for i in range(dim_tensor2 - 2):
+            batch_tensor2.append(tensor2.size(i))
+
+        # expand the batch portion (i.e. cut off matrix dimensions and expand rest)
+        expand_batch_portion = list(
+            torch.broadcast_shapes(batch_tensor1, batch_tensor2)
+        )
+
+        tensor1_expand_size = expand_batch_portion + [n, m1]
+        tensor2_expand_size = expand_batch_portion + [m2, p]
+
+        expand_batch_product = prod(expand_batch_portion)
+
+        # HACK: We need reshape with symint support
+        tensor1_expanded = (
+            tensor1.expand(tensor1_expand_size)
+            .contiguous()
+            .view(expand_batch_product, n, m1)
+        )
+        tensor2_expanded = (
+            tensor2.expand(tensor2_expand_size)
+            .contiguous()
+            .view(expand_batch_product, m2, p)
+        )
+
+        output_shape = expand_batch_portion
+        if dim_tensor1 > 1:
+            output_shape.append(n)
+
+        if dim_tensor2 > 1:
+            output_shape.append(p)
+
+        return tensor1_expanded.bmm(tensor2_expanded).view(output_shape)
+    else:
+        utils.check(False, lambda: "both arguments to matmul need to be at least 1D")
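+
+
+# Rough summary of the dispatch above (shapes are illustrative, not exhaustive):
+#   (n,) @ (n,)               -> dot
+#   (m, n) @ (n,)             -> mv
+#   (n,) @ (n, p)             -> unsqueeze + mm + squeeze
+#   (m, n) @ (n, p)           -> mm
+#   batched @ matrix/vector   -> fold batch dims into a single mm/mv when
+#                                should_fold allows (either operand may be batched)
+#   anything else             -> broadcast batch dims, then bmm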
diff --git a/torch/_meta_registrations.py b/torch/_meta_registrations.py
index e63e277..4684a72 100644
--- a/torch/_meta_registrations.py
+++ b/torch/_meta_registrations.py
@@ -193,6 +193,16 @@
     return self.new_empty(())
 
 
+@register_meta([aten.mm.default], register_dispatcher=False)
+def meta_mm(a, b):
+    check(a.dim() == 2, lambda: "a must be 2D")
+    check(b.dim() == 2, lambda: "b must be 2D")
+    N, M1 = a.shape
+    M2, P = b.shape
+    check(M1 == M2, lambda: "a and b must have same reduction dim")
+    return a.new_empty(N, P)
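+    # A sketch of the rule above: (N, M1) @ (M2, P) requires M1 == M2 and yields an
+    # empty (N, P) tensor via a.new_empty, which is all fake/symbolic tracing needs.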
+
+
 def _compute_reduction_shape(self, dims, keepdim):
     if keepdim:
         return tuple(self.shape[i] if i not in dims else 1 for i in range(self.ndim))
@@ -762,6 +772,234 @@
     return self.view(self.shape)
 
 
+def common_meta_baddbmm_bmm(batch1, batch2, is_bmm, self_baddbmm=None):
+    check(batch1.dim() == 3, lambda: "batch1 must be a 3D tensor")
+    check(batch2.dim() == 3, lambda: "batch2 must be a 3D tensor")
+
+    batch1_sizes = batch1.size()
+    batch2_sizes = batch2.size()
+
+    bs = batch1_sizes[0]
+    contraction_size = batch1_sizes[2]
+    res_rows = batch1_sizes[1]
+    res_cols = batch2_sizes[2]
+    output_size = (bs, res_rows, res_cols)
+
+    check(
+        batch2_sizes[0] == bs and batch2_sizes[1] == contraction_size,
+        lambda: f"Expected size for first two dimensions of batch2 tensor to be: [{bs}"
+        f", {contraction_size}] but got: [{batch2_sizes[0]}, {batch2_sizes[1]}].",
+    )
+
+    # TODO: handle out
+
+    output = batch2.new_empty(output_size)
+
+    if not is_bmm and self_baddbmm is not None:
+        check(self_baddbmm.dim() == 3, lambda: "self must be a 3D tensor")
+        check(
+            self_baddbmm.size() == output_size,
+            lambda: "Expected an input tensor shape with shape {output_size} but got shape: {self.size()}",
+        )
+
+    return output
+
+
+@register_meta(aten.bmm.default, register_dispatcher=False)
+def meta_bmm(self, mat2):
+    return common_meta_baddbmm_bmm(self, mat2, True)
+
+
+def div_rtn(x, y):
+    q = x // y
+    r = x % y
+    # WARNING: explicit bool conversion here is necessary;
+    # would be fixed by SymBool
+    if r != 0 and (bool(r < 0) != bool(y < 0)):
+        q -= 1
+    return q
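+
+
+# Example (a sketch): div_rtn floors the quotient, e.g. div_rtn(7, 2) == 3 and
+# div_rtn(-3, 2) == -2, mirroring the div_rtn helper in ATen's pooling code.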
+
+
+def pooling_output_shape_pad_lr(
+    inputSize, kernelSize, pad_l, pad_r, stride, dilation, ceil_mode
+):
+    outputSize = (
+        div_rtn(
+            inputSize
+            + pad_l
+            + pad_r
+            - dilation * (kernelSize - 1)
+            - 1
+            + (stride - 1 if ceil_mode else 0),
+            stride,
+        )
+        + 1
+    )
+    if ceil_mode:
+        if (outputSize - 1) * stride >= inputSize + pad_l:
+            outputSize -= 1
+    return outputSize
+
+
+def pooling_output_shape(inputSize, kernelSize, pad, stride, dilation, ceil_mode):
+    check(stride != 0, lambda: "stride should not be zero")
+    check(pad >= 0, lambda: f"pad must be non-negative, but got pad: {pad}")
+    check(
+        pad <= kernelSize // 2,
+        lambda: f"pad should be at most half of kernel size, but got pad={pad} and kernel_size={kernelSize}",
+    )
+    return pooling_output_shape_pad_lr(
+        inputSize, kernelSize, pad, pad, stride, dilation, ceil_mode
+    )
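+
+
+# Worked example (a sketch) for pooling_output_shape: inputSize=8, kernelSize=2,
+# pad=0, stride=2, dilation=1, ceil_mode=False gives div_rtn(8 - 1 - 1, 2) + 1 == 4,
+# i.e. the familiar floor((inputSize + 2 * pad - kernelSize) / stride) + 1.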
+
+
+def pool2d_shape_check(
+    input,
+    kH,
+    kW,
+    dH,
+    dW,
+    padH,
+    padW,
+    dilationH,
+    dilationW,
+    nInputPlane,
+    inputHeight,
+    inputWidth,
+    outputHeight,
+    outputWidth,
+    memory_format,
+):
+    ndim = input.dim()
+    nOutputPlane = nInputPlane
+
+    check(
+        kW > 0 and kH > 0,
+        lambda: "kernel size should be greater than zero, but got kH: {kH}, kW: {kW}",
+    )
+    check(
+        dW > 0 and dH > 0,
+        lambda: "stride should be greater than zero, but got dH: {dH}, dW: {dW}",
+    )
+    check(
+        dilationH > 0 and dilationW > 0,
+        lambda: "dilation should be greater than zero, but got dilationH: {dilationH}, dilationW: {dilationW}",
+    )
+
+    valid_dims = input.size(1) != 0 and input.size(2) != 0
+
+    if memory_format == torch.channels_last:
+        check(
+            ndim == 4 and valid_dims and input.size(3) != 0,
+            lambda: "Expected 4D (batch mode) tensor expected for input with channels_last layout"
+            " with optional 0 dim batch size for input, but got: {input.size()}",
+        )
+    else:
+        check(
+            (ndim == 3 and input.size(0) != 0 and valid_dims)
+            or (ndim == 4 and valid_dims and input.size(3) != 0),
+            lambda: f"Expected 3D or 4D (batch mode) tensor with optional 0 dim batch size for input, but got: {input.size()}",
+        )
+
+    check(
+        kW // 2 >= padW and kH // 2 >= padH,
+        lambda: "pad should be smaller than or equal to half of kernel size, but got "
+        f"padW = {padW}, padH = {padH}, kW = {kW}, kH = {kH}",
+    )
+
+    check(
+        outputWidth >= 1 and outputHeight >= 1,
+        lambda: f"Given input size: ({nInputPlane}x{inputHeight}x{inputWidth}). "
+        f"Calculated output size: ({nOutputPlane}x{outputHeight}x{outputWidth}). "
+        "Output size is too small",
+    )
+
+
+@register_meta(aten.max_pool2d_with_indices.default, register_dispatcher=False)
+def meta_max_pool2d_with_indices(
+    input, kernel_size, stride=(), padding=(0,), dilation=(1,), ceil_mode=False
+):
+    # Reference: aten/src/ATen/native/DilatedMaxPool2d.cpp
+    def unpack(name, val):
+        check(
+            len(val) in [1, 2],
+            lambda: f"max_pool2d: {name} must either be a single int, or a tuple of two ints",
+        )
+        H = val[0]
+        W = H if len(val) == 1 else val[1]
+        return H, W
+
+    kH, kW = unpack("kernel_size", kernel_size)
+
+    check(
+        len(stride) in [0, 1, 2],
+        lambda: "max_pool2d: stride must either be omitted, a single int, or a tuple of two ints",
+    )
+    if len(stride) == 0:
+        dH, dW = kH, kW
+    else:
+        dH, dW = unpack("stride", stride)
+
+    padH, padW = unpack("padding", padding)
+    dilationH, dilationW = unpack("dilation", dilation)
+
+    memory_format = utils.suggest_memory_format(input)
+    if memory_format == torch.channels_last:
+        check(
+            input.dim() == 4,
+            lambda: "non-empty 4D (batch mode) tensor expected for input with channels_last layout",
+        )
+    elif memory_format == torch.contiguous_format:
+        check(
+            input.dim() in [3, 4],
+            lambda: "non-empty 3D or 4D (batch mode) tensor expected for input",
+        )
+    else:
+        check(
+            False,
+            lambda: "Unsupport memory format. Supports only ChannelsLast, Contiguous",
+        )
+
+    nbatch = input.size(-4) if input.dim() == 4 else 1
+    nInputPlane = input.size(-3)
+    inputHeight = input.size(-2)
+    inputWidth = input.size(-1)
+
+    outputHeight = pooling_output_shape(inputHeight, kH, padH, dH, dilationH, ceil_mode)
+    outputWidth = pooling_output_shape(inputWidth, kW, padW, dW, dilationW, ceil_mode)
+
+    pool2d_shape_check(
+        input,
+        kH,
+        kW,
+        dH,
+        dW,
+        padH,
+        padW,
+        dilationH,
+        dilationW,
+        nInputPlane,
+        inputHeight,
+        inputWidth,
+        outputHeight,
+        outputWidth,
+        memory_format,
+    )
+
+    if input.dim() == 3:
+        size = [nInputPlane, outputHeight, outputWidth]
+    else:
+        size = [nbatch, nInputPlane, outputHeight, outputWidth]
+    return (
+        torch.empty(
+            size, dtype=input.dtype, device=input.device, memory_format=memory_format
+        ),
+        torch.empty(
+            size, dtype=torch.int64, device=input.device, memory_format=memory_format
+        ),
+    )
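+
+    # Example (a sketch): a (1, 3, 8, 8) contiguous input with kernel_size=[2, 2]
+    # and the default stride yields two (1, 3, 4, 4) outputs (values and int64
+    # indices); this is shape inference only, no pooling kernel is run.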
+
+
 # We must also trigger meta registrations from PrimTorch ref
 # decompositions
 import torch._refs
diff --git a/torch/_refs/__init__.py b/torch/_refs/__init__.py
index 7d334c9..6f59ad1 100644
--- a/torch/_refs/__init__.py
+++ b/torch/_refs/__init__.py
@@ -338,10 +338,7 @@
                 return x
 
             if not utils.same_shape(x.shape, common_shape):
-                common_rank = len(common_shape) + 1
-                start = common_rank - (len(x.shape) + 1)
-                dims = tuple(range(start, len(x.shape) + start))
-                return prims.broadcast_in_dim(x, common_shape, dims)
+                return x.expand(common_shape)
 
             return x
         else:
@@ -1658,6 +1655,7 @@
 #
 # Data Movement References
 #
+@register_decomposition(torch.ops.aten.clone.default)
 def clone(
     a: TensorLikeType, *, memory_format: torch.memory_format = torch.preserve_format
 ) -> TensorLikeType:
diff --git a/torch/_subclasses/fake_tensor.py b/torch/_subclasses/fake_tensor.py
index 38b446e..ee0f8ab 100644
--- a/torch/_subclasses/fake_tensor.py
+++ b/torch/_subclasses/fake_tensor.py
@@ -727,7 +727,6 @@
             aten.empty_strided.default,
             aten.as_strided.default,
             aten.zeros.default,
-            aten.clone.default,
             aten.detach.default,
         ]
         # IDK: feels bad man, sym_numel on as_strided infinite loops otherwise
diff --git a/torch/fx/experimental/proxy_tensor.py b/torch/fx/experimental/proxy_tensor.py
index 34b4012..202990d 100644
--- a/torch/fx/experimental/proxy_tensor.py
+++ b/torch/fx/experimental/proxy_tensor.py
@@ -442,7 +442,8 @@
         if func in [prim.device.default]:
             return func(*args, **kwargs)
 
-        return proxy_call(self, func, args, kwargs)
+        out = proxy_call(self, func, args, kwargs)
+        return out
 
 
 SymInt = torch.SymIntNode
diff --git a/torch/fx/experimental/symbolic_shapes.py b/torch/fx/experimental/symbolic_shapes.py
index e938945..54fe764 100644
--- a/torch/fx/experimental/symbolic_shapes.py
+++ b/torch/fx/experimental/symbolic_shapes.py
@@ -1,6 +1,6 @@
 import torch
 import torch.utils._pytree as pytree
-from typing import Dict, List, Type, Optional, cast
+from typing import Set, Dict, List, Type, Optional, cast
 import operator
 import functools
 from functools import lru_cache
@@ -17,7 +17,7 @@
 
 __all__ = [
     "has_symbolic_sizes_strides", "create_contiguous", "PySymInt", "ShapeEnv",
-    "SymDispatchMode", "PySymFloat", "sym_float"
+    "SymDispatchMode", "PySymFloat", "sym_float", "FloorDiv"
 ]
 
 SYM_FUNCTION_MODE = None
@@ -148,13 +148,31 @@
     def __str__(self):
         return f"{self.expr}"
 
+if HAS_SYMPY:
+    class FloorDiv(sympy.Function):
+        """
+        We maintain this so that:
+        1. We can use divisibility guards to simplify FloorDiv(a, b) to a / b.
+        2. Printing out the expression is nicer (compared to say, representing a//b as (a - a % b) / b)
+        """
+        nargs = (2,)
+
+        @classmethod
+        def eval(cls, base, divisor):
+            if base == 0:
+                return sympy.Integer(0)
+            if divisor == 1:
+                return base
+            if isinstance(base, sympy.Integer) and isinstance(divisor, sympy.Integer):
+                return base // divisor
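+
+        # Examples (a sketch): FloorDiv(7, 2) evaluates to 3, FloorDiv(x, 1) to x and
+        # FloorDiv(0, y) to 0; anything else stays symbolic so that a divisibility
+        # guard can later let simplify() rewrite FloorDiv(a, b) into a / b.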
+
 # Methods that have a `__foo__` as well as `__rfoo__`
 reflectable_magic_methods = {
     'add': lambda a, b: a + b,
     'sub': lambda a, b: a - b,
     'mul': lambda a, b: a * b,
     'mod': lambda a, b: a % b,
-    'floordiv': lambda a, b: (a - (a % b)) / b
+    'floordiv': lambda a, b: FloorDiv(a, b)
 }
 
 magic_methods = {
@@ -225,8 +243,8 @@
         # Maps from sympy ints to expressions representing them
         # Populated from equality guards (i.e. a.shape[0] == b.shape[0])
         self.replacements: Dict["sympy.Symbol", "sympy.Expr"] = {}  #
-        # Keys are Mod(x, y), values are 0 (for ease of substitution)
-        self.divisible: Dict["sympy.Expr", "sympy.Integer"] = {}
+        # Set holds a % b expressions that evaluate to 0.
+        self.divisible: Set["sympy.Expr"] = set()
 
     def _get_key(self):
         """
@@ -267,9 +285,7 @@
         return all(guard.xreplace(new_env.var_to_val) == value for guard, value, _ in self.guards)
 
     def get_nontrivial_guards(self):
-        guards = [(self.simplify(guard), val) for guard, val, _ in self.guards]
-        guards = [guard for guard in guards if len(guard[0].free_symbols) > 0]
-        return guards
+        return [(self.simplify(guard), val) for guard, val, _ in self.guards if self._maybe_evaluate_static(guard) is None]
 
     def get_shape_groups(self):
         shape_groups = collections.defaultdict(list)
@@ -282,6 +298,7 @@
         """
         Tries to evaluate expr without introducing guards
         """
+        expr = self.simplify(expr)
         # Simplifies assuming that shape vars > 1 (since we cache on 0/1 shape values)
         symbols = list(expr.free_symbols)
         new_shape_env = {
@@ -289,7 +306,10 @@
             for idx, k in enumerate(symbols)
         }
         new_expr = expr.xreplace(new_shape_env)
-        new_expr = sympy.expand(new_expr)
+        floor_div_replace = {}
+        for atom in new_expr.atoms(FloorDiv):
+            floor_div_replace[atom] = sympy.floor(atom.args[0] / atom.args[1])
+        new_expr = sympy.expand(new_expr.xreplace(floor_div_replace))
         if len(list(new_expr.free_symbols)) == 0:
             return new_expr
         return None
@@ -301,20 +321,25 @@
 
     @_lru_cache
     def _update_divisible(self):
-        new_divisible = {}
+        new_divisible = set()
         for k in self.divisible:
             res = self.replace(k)
             if len(res.free_symbols) > 0:
-                new_divisible[k] = sympy.Integer(0)
+                new_divisible.add(k)
 
         self.divisible = new_divisible
 
     @_lru_cache
     def simplify(self, expr: "sympy.Expr") -> "sympy.Expr":
         expr = self.replace(expr)
-        if expr.has(sympy.Mod):
+        if expr.has(FloorDiv):
             self._update_divisible()
-            expr = expr.xreplace(self.divisible)
+            div_replacements = {}
+            for atom in expr.atoms(FloorDiv):
+                base, divisor = atom.args
+                if self.replace(base % divisor) in self.divisible:
+                    div_replacements[atom] = base / divisor
+            expr = expr.xreplace(div_replacements)
             expr = sympy.expand(expr)
         return expr
 
@@ -363,24 +388,23 @@
         lhs = expr.lhs
         rhs = expr.rhs
         try:
-            solutions = sympy.solveset(lhs - rhs, free[0], domain=sympy.S.Integers)
-            if not solutions.is_finite_set:
-                if expr.has(sympy.Mod):
-                    mod_expr = tuple(expr.atoms(sympy.Mod))[0]
-                    solutions = sympy.solveset(lhs - rhs, mod_expr, domain=sympy.S.Integers)
-                    if solutions.is_finite_set and len(solutions) == 1 and tuple(solutions)[0] == 0:
-                        self.divisible[mod_expr] = sympy.Integer(0)
+            solutions = sympy.solve(lhs - rhs, free[0], dict=True)
+            if len(solutions) != 1:
                 return
-
-            if not isinstance(solutions, sympy.FiniteSet):
-                return
-
-            solutions = tuple(solutions)
-            if len(solutions) == 1 and all(t.is_integer for t in sympy.preorder_traversal(solutions[0])):
-                new_var = self._find(solutions[0])
+            solution = solutions[0][free[0]]
+            if all(t.is_integer for t in sympy.preorder_traversal(solution)):
+                new_var = self._find(solution)
                 self.replacements[cast(sympy.Symbol, free[0])] = new_var
-        except ZeroDivisionError:
-            pass
+        except NotImplementedError:
+            if expr.has(sympy.Mod):
+                mod_expr = tuple(expr.atoms(sympy.Mod))[0]
+                try:
+                    solutions = sympy.solve(lhs - rhs, mod_expr, dict=True)
+                    if len(solutions) == 1 and solutions[0][mod_expr] == 0:
+                        self.divisible.add(mod_expr)
+                except NotImplementedError:
+                    pass
+            return
 
     @lru_cache(256)
     def evaluate_expr(self, expr: "sympy.Expr"):
@@ -390,7 +414,6 @@
         if len(expr.free_symbols) == 0:
             return expr
         expr = self.simplify(expr)
-
         static_expr = self._maybe_evaluate_static(expr)
         if static_expr is not None:
             return static_expr
diff --git a/torch/fx/proxy.py b/torch/fx/proxy.py
index c7eb3c9..5d65f03 100644
--- a/torch/fx/proxy.py
+++ b/torch/fx/proxy.py
@@ -162,7 +162,6 @@
             return a.node
         elif isinstance(a, base_types) or a is None or a is ...:
             return a
-
         raise NotImplementedError(f"argument of type: {type(a)}")
 
     @compatibility(is_backward_compatible=True)
diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py
index 75ece1f..7f8d9b8 100644
--- a/torch/testing/_internal/common_methods_invocations.py
+++ b/torch/testing/_internal/common_methods_invocations.py
@@ -5558,7 +5558,7 @@
 
     return samples
 
-def sample_inputs_matmul(op_info, device, dtype, requires_grad, **kwargs):
+def sample_inputs_matmul(op_info, device, dtype, requires_grad, is_rmatmul=False, **kwargs):
     test_cases = (((L,), (L,)),
                   ((S, M), (M,)),
                   ((M,), (M, S)),
@@ -5577,12 +5577,10 @@
     for lhs_shape, rhs_shape in test_cases:
         lhs = make_tensor(lhs_shape, dtype=dtype, device=device, low=None, high=None, requires_grad=requires_grad)
         rhs = make_tensor(rhs_shape, dtype=dtype, device=device, low=None, high=None, requires_grad=requires_grad)
-        if op_info.name == 'matmul':
+        if not is_rmatmul:
             sample_inputs.append(SampleInput(lhs, args=(rhs,)))
-        elif op_info.name == '__rmatmul__':
-            sample_inputs.append(SampleInput(rhs, args=(lhs,)))
         else:
-            raise RuntimeError("`op_info.name` must be 'matmul' or '__rmatmul__'")
+            sample_inputs.append(SampleInput(rhs, args=(lhs,)))
     return tuple(sample_inputs)
 
 
@@ -9972,7 +9970,7 @@
            supports_forward_ad=True,
            supports_fwgrad_bwgrad=True,
            check_batched_forward_grad=False,
-           sample_inputs_func=sample_inputs_matmul,
+           sample_inputs_func=partial(sample_inputs_matmul, is_rmatmul=False),
            decorators=[
                # NVIDIA only assures that bfloat16 is supported by bmm if SM >= 5.3
                DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_dtypes', device_type='cuda', active_if=not SM53OrLater),
@@ -12679,7 +12677,7 @@
                                                        *[torch.bfloat16]
                                                        if (SM53OrLater and CUDA11OrLater) or TEST_WITH_ROCM else []),
            assert_autodiffed=True,
-           sample_inputs_func=sample_inputs_matmul,
+           sample_inputs_func=partial(sample_inputs_matmul, is_rmatmul=True),
            # Runs very slowly on slow gradcheck - alternatively reduce input sizes
            gradcheck_fast_mode=True,
            supports_out=False,