Port matmul's CompositeImplicitAutograd implementation into core (#85239)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/85239
Approved by: https://github.com/ezyang, https://github.com/lezcano
diff --git a/test/test_proxy_tensor.py b/test/test_proxy_tensor.py
index e238bdb..024e729 100644
--- a/test/test_proxy_tensor.py
+++ b/test/test_proxy_tensor.py
@@ -22,6 +22,7 @@
import types
import functools
+import itertools
aten = torch.ops.aten
@@ -1010,7 +1011,6 @@
xfail('linalg.eigvals'),
skip('_masked.logsumexp', ''), # Tensors of type TensorImpl do not have numel
xfail('__getitem__', ''), # aten.size.default - couldn't find symbolic meta function/decomposition
- xfail('__rmatmul__', ''), # aten.new_empty.default - couldn't find symbolic meta function/decomposition
xfail('_masked.amax', ''), # aten._to_copy.default - couldn't find symbolic meta function/decomposition
xfail('_masked.amin', ''), # aten._to_copy.default - couldn't find symbolic meta function/decomposition
xfail('_masked.argmax', ''), # aten.argmax.default - couldn't find symbolic meta function/decomposition
@@ -1022,15 +1022,12 @@
xfail('_masked.mean', ''), # ones() received an invalid combination of arguments - got (torch.Size, device=torch.device, ...
xfail('_masked.median', ''), # aten.nanmedian.dim - couldn't find symbolic meta function/decomposition
xfail('_masked.norm', ''), # aten.linalg_vector_norm.default - couldn't find symbolic meta function/decomposition
- xfail('_masked.normalize', ''), # aten.linalg_vector_norm.default - couldn't find symbolic meta function/decomposition
xfail('_masked.prod', ''), # aten._to_copy.default - couldn't find symbolic meta function/decomposition
xfail('_masked.softmax', ''), # aten._to_copy.default - couldn't find symbolic meta function/decomposition
xfail('_masked.softmin', ''), # aten._to_copy.default - couldn't find symbolic meta function/decomposition
xfail('_masked.std', ''), # ones() received an invalid combination of arguments - got (torch.Size, device=torch.device, d...
xfail('_masked.sum', ''), # aten._to_copy.default - couldn't find symbolic meta function/decomposition
xfail('_masked.var', ''), # ones() received an invalid combination of arguments - got (torch.Size, device=torch.device, d...
- xfail('addmm', ''), # aten.mm.default - couldn't find symbolic meta function/decomposition
- xfail('addmm', 'decomposed'), # aten.mm.default - couldn't find symbolic meta function/decomposition
xfail('addmv', ''), # aten.addmv.default - couldn't find symbolic meta function/decomposition
xfail('addr', ''), # aten.size.default - couldn't find symbolic meta function/decomposition
xfail('aminmax', ''), # aten.aminmax.default - couldn't find symbolic meta function/decomposition
@@ -1042,13 +1039,11 @@
xfail('as_strided_scatter', ''), # aten.as_strided_scatter.default - couldn't find symbolic meta function/decomposition
xfail('baddbmm', ''), # aten.baddbmm.default - couldn't find symbolic meta function/decomposition
xfail('bernoulli', ''), # aten.bernoulli.default - couldn't find symbolic meta function/decomposition
- xfail('bmm', ''), # aten.bmm.default - couldn't find symbolic meta function/decomposition
xfail('bucketize', ''), # aten.bucketize.Tensor - couldn't find symbolic meta function/decomposition
xfail('cartesian_prod', ''), # Tensors of type TensorImpl do not have numel
xfail('cdist', ''), # aten.size.default - couldn't find symbolic meta function/decomposition
xfail('cholesky_solve', ''), # Could not run 'aten::_cholesky_solve_helper' with arguments from the 'Meta' back...
xfail('chunk', ''), # aten.size.default - couldn't find symbolic meta function/decomposition
- xfail('clone', ''), # aten.clone.default - couldn't find symbolic meta function/decomposition
xfail('column_stack', ''), # Tensors of type TensorImpl do not have numel
xfail('constant_pad_nd', ''), # aten.fill.Scalar - couldn't find symbolic meta function/decomposition
xfail('count_nonzero', ''), # Could not run 'aten::count_nonzero.dim_IntList' with arguments from the 'Meta' ba...
@@ -1150,7 +1145,6 @@
xfail('linalg.tensorinv', ''), # aten.size.default - couldn't find symbolic meta function/decomposition
xfail('linalg.tensorsolve', ''), # aten.size.default - couldn't find symbolic meta function/decomposition
xfail('linalg.vander', ''), # aten.size.default - couldn't find symbolic meta function/decomposition
- xfail('linalg.vector_norm', ''), # TensorImpl do not have numel
xfail('logaddexp2', ''), # aten.logaddexp2.default - couldn't find symbolic meta function/decomposition
xfail('logaddexp', ''), # aten.logaddexp.default - couldn't find symbolic meta function/decomposition
xfail('logcumsumexp', ''), # aten.logcumsumexp.default - couldn't find symbolic meta function/decomposition
@@ -1161,14 +1155,12 @@
xfail('masked_fill', ''), # expected predicate to be bool, got torch.float32
xfail('masked_scatter', ''), # aten.masked_scatter.default - couldn't find symbolic meta function/decomposition
xfail('masked_select', ''), # aten.masked_select.default - couldn't find symbolic meta function/decomposition
- xfail('matmul', ''), # aten.new_empty.default - couldn't find symbolic meta function/decomposition
xfail('matrix_exp', ''), # aten.linalg_matrix_exp.default - couldn't find symbolic meta function/decomposition
xfail('max', 'reduction_with_dim'), # aten.max.dim - couldn't find symbolic meta function/decomposition
xfail('median', ''), # Could not run 'aten::median' with arguments from the 'Meta' backend. This could be becau...
xfail('meshgrid', 'list_of_tensors'), # Tensors of type TensorImpl do not have numel
xfail('meshgrid', 'variadic_tensors'), # Tensors of type TensorImpl do not have numel
xfail('min', 'reduction_with_dim'), # aten.min.dim - couldn't find symbolic meta function/decomposition
- xfail('mm', ''), # aten.mm.default - couldn't find symbolic meta function/decomposition
xfail('mode', ''), # aten.mode.default - couldn't find symbolic meta function/decomposition
xfail('msort', ''), # aten.sort.default - couldn't find symbolic meta function/decomposition
xfail('nanquantile', ''), # Could not run 'aten::equal' with arguments from the 'Meta' backend.
@@ -1213,7 +1205,6 @@
xfail('nn.functional.local_response_norm', ''), # Tensors of type TensorImpl do not have numel
xfail('nn.functional.margin_ranking_loss', ''), # The underlying op of 'aten.stride' has no overload name '_schema'
xfail('nn.functional.max_pool1d', ''), # Trying to call aten.size on a tensor with symbolic shapes.
- xfail('nn.functional.max_pool2d', ''), # aten.max_pool2d_with_indices.default - couldn't find symbolic meta function/d...
xfail('nn.functional.max_pool3d', ''), # aten.max_pool3d_with_indices.default - couldn't find symbolic meta function/d...
xfail('nn.functional.max_unpool1d', 'grad'), # aten.max_unpool2d.default - couldn't find symbolic meta function/decom...
xfail('nn.functional.max_unpool2d', 'grad'), # aten.max_unpool2d.default - couldn't find symbolic meta function/decom...
@@ -1234,7 +1225,6 @@
xfail('nn.functional.unfold', ''), # aten.im2col.default - couldn't find symbolic meta function/decomposition
xfail('nn.functional.upsample_bilinear', ''), # aten.upsample_bilinear2d.vec - couldn't find symbolic meta function/de...
xfail('nn.functional.upsample_nearest', ''), # aten.upsample_nearest1d.vec - couldn't find symbolic meta function/deco...
- xfail('norm', ''), # TensorImpl does not have numel
xfail('norm', 'nuc'), # aten._linalg_svd.default - couldn't find symbolic meta function/decomposition
xfail('normal', ''), # aten.normal.Tensor_Tensor - couldn't find symbolic meta function/decomposition
xfail('normal', 'number_mean'), # aten.normal.float_Tensor - couldn't find symbolic meta function/decomposition
@@ -1337,7 +1327,9 @@
return op.op(*args, **kwargs)
sample_inputs_itr = op.sample_inputs(device, dtype, requires_grad=False)
new_f = None
- for sample_input in sample_inputs_itr:
+
+ # Limit ourselves to first 100 inputs so symbolic tracing tests don't take too long
+ for sample_input in itertools.islice(sample_inputs_itr, 100):
args = [sample_input.input] + list(sample_input.args)
kwargs = sample_input.kwargs
@@ -1345,7 +1337,6 @@
new_f = make_fx(f, tracing_mode=tracing_mode)(args, kwargs)
except DynamicOutputShapeException as e:
self.skipTest("Dynamic output shape operation in trace")
-
for arg in args:
if isinstance(arg, torch.Tensor) and arg.dtype == torch.float:
arg.uniform_(0, 1)