Ported the matmul CompositeImplicitAutograd implementation into core (#85239)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/85239
Approved by: https://github.com/ezyang, https://github.com/lezcano
diff --git a/test/test_decomp.py b/test/test_decomp.py
index ffd44cd..719240d 100644
--- a/test/test_decomp.py
+++ b/test/test_decomp.py
@@ -24,6 +24,7 @@
     instantiate_device_type_tests,
 )
 from torch.testing._internal.common_methods_invocations import op_db
+from torch._dispatch.python import enable_python_dispatcher
 
 import itertools
 import functools
@@ -486,7 +487,7 @@
                 # explicit clearing is necessary as I will create a fresh mode
                 # for each region
                 decomposed.clear()
-                with enable_torch_dispatch_mode(DecompCrossRefMode):
+                with enable_torch_dispatch_mode(DecompCrossRefMode), enable_python_dispatcher():
                     decomp_out, decomp_vjp_fn = ref_vjp_no_create(fn, *primals)
                 if aten_name in decomposition_names:
                     check_decomposed(aten_name)
@@ -495,7 +496,7 @@
                     cotangents = tree_map(lambda x: torch.randn_like(x), decomp_out)
 
                     decomposed.clear()
-                    with enable_torch_dispatch_mode(DecompCrossRefMode):
+                    with enable_torch_dispatch_mode(DecompCrossRefMode), enable_python_dispatcher():
                         decomp_vjp_fn(cotangents)
                     if not run_all:
                         check_decomposed(op.aten_backward_name)
@@ -504,7 +505,7 @@
                 args = [sample_input.input] + list(sample_input.args)
                 kwargs = sample_input.kwargs
                 decomposed.clear()
-                with enable_torch_dispatch_mode(DecompCrossRefMode):
+                with enable_torch_dispatch_mode(DecompCrossRefMode), enable_python_dispatcher():
                     func(*args, **kwargs)
                 if not run_all:
                     check_decomposed(aten_name)
diff --git a/test/test_ops.py b/test/test_ops.py
index 4cf025a..e96b59a 100644
--- a/test/test_ops.py
+++ b/test/test_ops.py
@@ -1671,7 +1671,6 @@
         '_refs.expand_as',
         '_refs.as_strided',  # _prims._as_strided_meta: "reduce() of empty sequence with no initial value"
         '_refs.copy_to',  # torch._C._jit_get_operation: No such operator aten::copy_to
-        '_refs.clone',  # test_meta.py: view size is not compatible with input tensor's size and stride
         '_refs.equal',  # 'bool' object has no attribute 'dtype'
         '_refs.conj',  # Calls _prims.conj
         '_refs.real',
diff --git a/test/test_proxy_tensor.py b/test/test_proxy_tensor.py
index e238bdb..024e729 100644
--- a/test/test_proxy_tensor.py
+++ b/test/test_proxy_tensor.py
@@ -22,6 +22,7 @@
 
 import types
 import functools
+import itertools
 
 aten = torch.ops.aten
 
@@ -1010,7 +1011,6 @@
     xfail('linalg.eigvals'),
     skip('_masked.logsumexp', ''),  # Tensors of type TensorImpl do not have numel
     xfail('__getitem__', ''),  # aten.size.default - couldn't find symbolic meta function/decomposition
-    xfail('__rmatmul__', ''),  # aten.new_empty.default - couldn't find symbolic meta function/decomposition
     xfail('_masked.amax', ''),  # aten._to_copy.default - couldn't find symbolic meta function/decomposition
     xfail('_masked.amin', ''),  # aten._to_copy.default - couldn't find symbolic meta function/decomposition
     xfail('_masked.argmax', ''),  # aten.argmax.default - couldn't find symbolic meta function/decomposition
@@ -1022,15 +1022,12 @@
     xfail('_masked.mean', ''),  # ones() received an invalid combination of arguments - got (torch.Size, device=torch.device, ...
     xfail('_masked.median', ''),  # aten.nanmedian.dim - couldn't find symbolic meta function/decomposition
     xfail('_masked.norm', ''),  # aten.linalg_vector_norm.default - couldn't find symbolic meta function/decomposition
-    xfail('_masked.normalize', ''),  # aten.linalg_vector_norm.default - couldn't find symbolic meta function/decomposition
     xfail('_masked.prod', ''),  # aten._to_copy.default - couldn't find symbolic meta function/decomposition
     xfail('_masked.softmax', ''),  # aten._to_copy.default - couldn't find symbolic meta function/decomposition
     xfail('_masked.softmin', ''),  # aten._to_copy.default - couldn't find symbolic meta function/decomposition
     xfail('_masked.std', ''),  # ones() received an invalid combination of arguments - got (torch.Size, device=torch.device, d...
     xfail('_masked.sum', ''),  # aten._to_copy.default - couldn't find symbolic meta function/decomposition
     xfail('_masked.var', ''),  # ones() received an invalid combination of arguments - got (torch.Size, device=torch.device, d...
-    xfail('addmm', ''),  # aten.mm.default - couldn't find symbolic meta function/decomposition
-    xfail('addmm', 'decomposed'),  # aten.mm.default - couldn't find symbolic meta function/decomposition
     xfail('addmv', ''),  # aten.addmv.default - couldn't find symbolic meta function/decomposition
     xfail('addr', ''),  # aten.size.default - couldn't find symbolic meta function/decomposition
     xfail('aminmax', ''),  # aten.aminmax.default - couldn't find symbolic meta function/decomposition
@@ -1042,13 +1039,11 @@
     xfail('as_strided_scatter', ''),  # aten.as_strided_scatter.default - couldn't find symbolic meta function/decomposition
     xfail('baddbmm', ''),  # aten.baddbmm.default - couldn't find symbolic meta function/decomposition
     xfail('bernoulli', ''),  # aten.bernoulli.default - couldn't find symbolic meta function/decomposition
-    xfail('bmm', ''),  # aten.bmm.default - couldn't find symbolic meta function/decomposition
     xfail('bucketize', ''),  # aten.bucketize.Tensor - couldn't find symbolic meta function/decomposition
     xfail('cartesian_prod', ''),  # Tensors of type TensorImpl do not have numel
     xfail('cdist', ''),  # aten.size.default - couldn't find symbolic meta function/decomposition
     xfail('cholesky_solve', ''),  # Could not run 'aten::_cholesky_solve_helper' with arguments from the 'Meta' back...
     xfail('chunk', ''),  # aten.size.default - couldn't find symbolic meta function/decomposition
-    xfail('clone', ''),  # aten.clone.default - couldn't find symbolic meta function/decomposition
     xfail('column_stack', ''),  # Tensors of type TensorImpl do not have numel
     xfail('constant_pad_nd', ''),  # aten.fill.Scalar - couldn't find symbolic meta function/decomposition
     xfail('count_nonzero', ''),  # Could not run 'aten::count_nonzero.dim_IntList' with arguments from the 'Meta' ba...
@@ -1150,7 +1145,6 @@
     xfail('linalg.tensorinv', ''),  # aten.size.default - couldn't find symbolic meta function/decomposition
     xfail('linalg.tensorsolve', ''),  # aten.size.default - couldn't find symbolic meta function/decomposition
     xfail('linalg.vander', ''),  # aten.size.default - couldn't find symbolic meta function/decomposition
-    xfail('linalg.vector_norm', ''),  # TensorImpl do not have numel
     xfail('logaddexp2', ''),  # aten.logaddexp2.default - couldn't find symbolic meta function/decomposition
     xfail('logaddexp', ''),  # aten.logaddexp.default - couldn't find symbolic meta function/decomposition
     xfail('logcumsumexp', ''),  # aten.logcumsumexp.default - couldn't find symbolic meta function/decomposition
@@ -1161,14 +1155,12 @@
     xfail('masked_fill', ''),  # expected predicate to be bool, got torch.float32
     xfail('masked_scatter', ''),  # aten.masked_scatter.default - couldn't find symbolic meta function/decomposition
     xfail('masked_select', ''),  # aten.masked_select.default - couldn't find symbolic meta function/decomposition
-    xfail('matmul', ''),  # aten.new_empty.default - couldn't find symbolic meta function/decomposition
     xfail('matrix_exp', ''),  # aten.linalg_matrix_exp.default - couldn't find symbolic meta function/decomposition
     xfail('max', 'reduction_with_dim'),  # aten.max.dim - couldn't find symbolic meta function/decomposition
     xfail('median', ''),  # Could not run 'aten::median' with arguments from the 'Meta' backend. This could be becau...
     xfail('meshgrid', 'list_of_tensors'),  # Tensors of type TensorImpl do not have numel
     xfail('meshgrid', 'variadic_tensors'),  # Tensors of type TensorImpl do not have numel
     xfail('min', 'reduction_with_dim'),  # aten.min.dim - couldn't find symbolic meta function/decomposition
-    xfail('mm', ''),  # aten.mm.default - couldn't find symbolic meta function/decomposition
     xfail('mode', ''),  # aten.mode.default - couldn't find symbolic meta function/decomposition
     xfail('msort', ''),  # aten.sort.default - couldn't find symbolic meta function/decomposition
     xfail('nanquantile', ''),  # Could not run 'aten::equal' with arguments from the 'Meta' backend.
@@ -1213,7 +1205,6 @@
     xfail('nn.functional.local_response_norm', ''),  # Tensors of type TensorImpl do not have numel
     xfail('nn.functional.margin_ranking_loss', ''),  # The underlying op of 'aten.stride' has no overload name '_schema'
     xfail('nn.functional.max_pool1d', ''),  # Trying to call aten.size on a tensor with symbolic shapes.
-    xfail('nn.functional.max_pool2d', ''),  # aten.max_pool2d_with_indices.default - couldn't find symbolic meta function/d...
     xfail('nn.functional.max_pool3d', ''),  # aten.max_pool3d_with_indices.default - couldn't find symbolic meta function/d...
     xfail('nn.functional.max_unpool1d', 'grad'),  # aten.max_unpool2d.default - couldn't find symbolic meta function/decom...
     xfail('nn.functional.max_unpool2d', 'grad'),  # aten.max_unpool2d.default - couldn't find symbolic meta function/decom...
@@ -1234,7 +1225,6 @@
     xfail('nn.functional.unfold', ''),  # aten.im2col.default - couldn't find symbolic meta function/decomposition
     xfail('nn.functional.upsample_bilinear', ''),  # aten.upsample_bilinear2d.vec - couldn't find symbolic meta function/de...
     xfail('nn.functional.upsample_nearest', ''),  # aten.upsample_nearest1d.vec - couldn't find symbolic meta function/deco...
-    xfail('norm', ''),  # TensorImpl does not have numel
     xfail('norm', 'nuc'),  # aten._linalg_svd.default - couldn't find symbolic meta function/decomposition
     xfail('normal', ''),  # aten.normal.Tensor_Tensor - couldn't find symbolic meta function/decomposition
     xfail('normal', 'number_mean'),  # aten.normal.float_Tensor - couldn't find symbolic meta function/decomposition
@@ -1337,7 +1327,9 @@
         return op.op(*args, **kwargs)
     sample_inputs_itr = op.sample_inputs(device, dtype, requires_grad=False)
     new_f = None
-    for sample_input in sample_inputs_itr:
+
+    # Limit ourselves to the first 100 sample inputs so symbolic tracing tests don't take too long
+    for sample_input in itertools.islice(sample_inputs_itr, 100):
         args = [sample_input.input] + list(sample_input.args)
         kwargs = sample_input.kwargs
 
@@ -1345,7 +1337,6 @@
             new_f = make_fx(f, tracing_mode=tracing_mode)(args, kwargs)
         except DynamicOutputShapeException as e:
             self.skipTest("Dynamic output shape operation in trace")
-
         for arg in args:
             if isinstance(arg, torch.Tensor) and arg.dtype == torch.float:
                 arg.uniform_(0, 1)