Ported matmul CompositeImplicitAutograd impl into core (#85239)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/85239
Approved by: https://github.com/ezyang, https://github.com/lezcano
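
With the matmul decomposition living in core, symbolic tracing of matmul (and
the mm/bmm/addmm calls it lowers to) is expected to succeed, which is why the
corresponding xfails are dropped from test_proxy_tensor.py below. A minimal
sketch of the kind of trace this enables, assuming the same make_fx /
tracing_mode="symbolic" entry point the test file already uses:

    import torch
    from torch.fx.experimental.proxy_tensor import make_fx

    def f(a, b):
        return torch.matmul(a, b)

    # With the ported decomposition, matmul lowers to ops that have symbolic
    # meta support, so tracing no longer hits a missing meta function.
    gm = make_fx(f, tracing_mode="symbolic")(torch.randn(3, 4), torch.randn(4, 5))
    print(gm.graph)
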
diff --git a/test/test_proxy_tensor.py b/test/test_proxy_tensor.py
index e238bdb..024e729 100644
--- a/test/test_proxy_tensor.py
+++ b/test/test_proxy_tensor.py
@@ -22,6 +22,7 @@
 
 import types
 import functools
+import itertools
 
 aten = torch.ops.aten
 
@@ -1010,7 +1011,6 @@
     xfail('linalg.eigvals'),
     skip('_masked.logsumexp', ''),  # Tensors of type TensorImpl do not have numel
     xfail('__getitem__', ''),  # aten.size.default - couldn't find symbolic meta function/decomposition
-    xfail('__rmatmul__', ''),  # aten.new_empty.default - couldn't find symbolic meta function/decomposition
     xfail('_masked.amax', ''),  # aten._to_copy.default - couldn't find symbolic meta function/decomposition
     xfail('_masked.amin', ''),  # aten._to_copy.default - couldn't find symbolic meta function/decomposition
     xfail('_masked.argmax', ''),  # aten.argmax.default - couldn't find symbolic meta function/decomposition
@@ -1022,15 +1022,12 @@
     xfail('_masked.mean', ''),  # ones() received an invalid combination of arguments - got (torch.Size, device=torch.device, ...
     xfail('_masked.median', ''),  # aten.nanmedian.dim - couldn't find symbolic meta function/decomposition
     xfail('_masked.norm', ''),  # aten.linalg_vector_norm.default - couldn't find symbolic meta function/decomposition
-    xfail('_masked.normalize', ''),  # aten.linalg_vector_norm.default - couldn't find symbolic meta function/decomposition
     xfail('_masked.prod', ''),  # aten._to_copy.default - couldn't find symbolic meta function/decomposition
     xfail('_masked.softmax', ''),  # aten._to_copy.default - couldn't find symbolic meta function/decomposition
     xfail('_masked.softmin', ''),  # aten._to_copy.default - couldn't find symbolic meta function/decomposition
     xfail('_masked.std', ''),  # ones() received an invalid combination of arguments - got (torch.Size, device=torch.device, d...
     xfail('_masked.sum', ''),  # aten._to_copy.default - couldn't find symbolic meta function/decomposition
     xfail('_masked.var', ''),  # ones() received an invalid combination of arguments - got (torch.Size, device=torch.device, d...
-    xfail('addmm', ''),  # aten.mm.default - couldn't find symbolic meta function/decomposition
-    xfail('addmm', 'decomposed'),  # aten.mm.default - couldn't find symbolic meta function/decomposition
     xfail('addmv', ''),  # aten.addmv.default - couldn't find symbolic meta function/decomposition
     xfail('addr', ''),  # aten.size.default - couldn't find symbolic meta function/decomposition
     xfail('aminmax', ''),  # aten.aminmax.default - couldn't find symbolic meta function/decomposition
@@ -1042,13 +1039,11 @@
     xfail('as_strided_scatter', ''),  # aten.as_strided_scatter.default - couldn't find symbolic meta function/decomposition
     xfail('baddbmm', ''),  # aten.baddbmm.default - couldn't find symbolic meta function/decomposition
     xfail('bernoulli', ''),  # aten.bernoulli.default - couldn't find symbolic meta function/decomposition
-    xfail('bmm', ''),  # aten.bmm.default - couldn't find symbolic meta function/decomposition
     xfail('bucketize', ''),  # aten.bucketize.Tensor - couldn't find symbolic meta function/decomposition
     xfail('cartesian_prod', ''),  # Tensors of type TensorImpl do not have numel
     xfail('cdist', ''),  # aten.size.default - couldn't find symbolic meta function/decomposition
     xfail('cholesky_solve', ''),  # Could not run 'aten::_cholesky_solve_helper' with arguments from the 'Meta' back...
     xfail('chunk', ''),  # aten.size.default - couldn't find symbolic meta function/decomposition
-    xfail('clone', ''),  # aten.clone.default - couldn't find symbolic meta function/decomposition
     xfail('column_stack', ''),  # Tensors of type TensorImpl do not have numel
     xfail('constant_pad_nd', ''),  # aten.fill.Scalar - couldn't find symbolic meta function/decomposition
     xfail('count_nonzero', ''),  # Could not run 'aten::count_nonzero.dim_IntList' with arguments from the 'Meta' ba...
@@ -1150,7 +1145,6 @@
     xfail('linalg.tensorinv', ''),  # aten.size.default - couldn't find symbolic meta function/decomposition
     xfail('linalg.tensorsolve', ''),  # aten.size.default - couldn't find symbolic meta function/decomposition
     xfail('linalg.vander', ''),  # aten.size.default - couldn't find symbolic meta function/decomposition
-    xfail('linalg.vector_norm', ''),  # TensorImpl do not have numel
     xfail('logaddexp2', ''),  # aten.logaddexp2.default - couldn't find symbolic meta function/decomposition
     xfail('logaddexp', ''),  # aten.logaddexp.default - couldn't find symbolic meta function/decomposition
     xfail('logcumsumexp', ''),  # aten.logcumsumexp.default - couldn't find symbolic meta function/decomposition
@@ -1161,14 +1155,12 @@
     xfail('masked_fill', ''),  # expected predicate to be bool, got torch.float32
     xfail('masked_scatter', ''),  # aten.masked_scatter.default - couldn't find symbolic meta function/decomposition
     xfail('masked_select', ''),  # aten.masked_select.default - couldn't find symbolic meta function/decomposition
-    xfail('matmul', ''),  # aten.new_empty.default - couldn't find symbolic meta function/decomposition
     xfail('matrix_exp', ''),  # aten.linalg_matrix_exp.default - couldn't find symbolic meta function/decomposition
     xfail('max', 'reduction_with_dim'),  # aten.max.dim - couldn't find symbolic meta function/decomposition
     xfail('median', ''),  # Could not run 'aten::median' with arguments from the 'Meta' backend. This could be becau...
     xfail('meshgrid', 'list_of_tensors'),  # Tensors of type TensorImpl do not have numel
     xfail('meshgrid', 'variadic_tensors'),  # Tensors of type TensorImpl do not have numel
     xfail('min', 'reduction_with_dim'),  # aten.min.dim - couldn't find symbolic meta function/decomposition
-    xfail('mm', ''),  # aten.mm.default - couldn't find symbolic meta function/decomposition
     xfail('mode', ''),  # aten.mode.default - couldn't find symbolic meta function/decomposition
     xfail('msort', ''),  # aten.sort.default - couldn't find symbolic meta function/decomposition
     xfail('nanquantile', ''),  # Could not run 'aten::equal' with arguments from the 'Meta' backend.
@@ -1213,7 +1205,6 @@
     xfail('nn.functional.local_response_norm', ''),  # Tensors of type TensorImpl do not have numel
     xfail('nn.functional.margin_ranking_loss', ''),  # The underlying op of 'aten.stride' has no overload name '_schema'
     xfail('nn.functional.max_pool1d', ''),  # Trying to call aten.size on a tensor with symbolic shapes.
-    xfail('nn.functional.max_pool2d', ''),  # aten.max_pool2d_with_indices.default - couldn't find symbolic meta function/d...
     xfail('nn.functional.max_pool3d', ''),  # aten.max_pool3d_with_indices.default - couldn't find symbolic meta function/d...
     xfail('nn.functional.max_unpool1d', 'grad'),  # aten.max_unpool2d.default - couldn't find symbolic meta function/decom...
     xfail('nn.functional.max_unpool2d', 'grad'),  # aten.max_unpool2d.default - couldn't find symbolic meta function/decom...
@@ -1234,7 +1225,6 @@
     xfail('nn.functional.unfold', ''),  # aten.im2col.default - couldn't find symbolic meta function/decomposition
     xfail('nn.functional.upsample_bilinear', ''),  # aten.upsample_bilinear2d.vec - couldn't find symbolic meta function/de...
     xfail('nn.functional.upsample_nearest', ''),  # aten.upsample_nearest1d.vec - couldn't find symbolic meta function/deco...
-    xfail('norm', ''),  # TensorImpl does not have numel
     xfail('norm', 'nuc'),  # aten._linalg_svd.default - couldn't find symbolic meta function/decomposition
     xfail('normal', ''),  # aten.normal.Tensor_Tensor - couldn't find symbolic meta function/decomposition
     xfail('normal', 'number_mean'),  # aten.normal.float_Tensor - couldn't find symbolic meta function/decomposition
@@ -1337,7 +1327,9 @@
         return op.op(*args, **kwargs)
     sample_inputs_itr = op.sample_inputs(device, dtype, requires_grad=False)
     new_f = None
-    for sample_input in sample_inputs_itr:
+
+    # Limit ourselves to first 100 inputs so symbolic tracing tests don't take too long
+    for sample_input in itertools.islice(sample_inputs_itr, 100):
         args = [sample_input.input] + list(sample_input.args)
         kwargs = sample_input.kwargs
 
@@ -1345,7 +1337,6 @@
             new_f = make_fx(f, tracing_mode=tracing_mode)(args, kwargs)
         except DynamicOutputShapeException as e:
             self.skipTest("Dynamic output shape operation in trace")
-
         for arg in args:
             if isinstance(arg, torch.Tensor) and arg.dtype == torch.float:
                 arg.uniform_(0, 1)