Turn on meta converter for complex (#98869)
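
This turns on the fake/meta tensor converter for complex inputs. Many complex
operators still produce incorrect striding or other metadata from their meta
kernels, so the affected OpInfo entries are xfailed or skipped in the
AOTAutograd, inductor, fake tensor, and proxy tensor test suites, and the fake
tensor tests now also tolerate
torch._subclasses.fake_tensor.UnsupportedOperatorException.

As a rough sketch of what this enables (not code from this PR; that
FakeTensorMode.from_tensor accepts complex inputs after this change, and the
failure mode of any particular op, are assumptions here):

```python
import torch
from torch._subclasses.fake_tensor import FakeTensorMode, UnsupportedOperatorException

# Create the real tensor outside the mode so it is not itself fake.
real = torch.randn(4, 4, dtype=torch.complex64)

mode = FakeTensorMode()
# The converter previously rejected complex dtypes; it now returns a
# FakeTensor that carries dtype/shape/stride metadata but no real storage.
fake = mode.from_tensor(real)
assert fake.dtype == torch.complex64
assert fake.shape == real.shape

# Ops without a usable meta kernel may raise UnsupportedOperatorException;
# the updated tests catch it and treat the op as skipped.
with mode:
    try:
        torch.fft.fft(fake)
    except UnsupportedOperatorException:
        pass
```
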
Pull Request resolved: https://github.com/pytorch/pytorch/pull/98869
Approved by: https://github.com/ngimel
diff --git a/test/functorch/test_aotdispatch.py b/test/functorch/test_aotdispatch.py
index 08601a2..42abf2f 100644
--- a/test/functorch/test_aotdispatch.py
+++ b/test/functorch/test_aotdispatch.py
@@ -2439,6 +2439,34 @@
xfail('_segment_reduce', 'lengths'),
skip('nn.functional.nll_loss', ''), # UBSAN failure!
+ # many complex operators have incorrect striding/metadata
+ xfail("pca_lowrank"),
+ xfail("norm", "nuc"),
+ xfail('fft.fft', ''),
+ xfail('fft.hfft2', ''),
+ xfail('fft.hfft', ''),
+ xfail('fft.hfftn', ''),
+ xfail('fft.ifft', ''),
+ xfail('fft.ihfft2', ''),
+ xfail('fft.ihfft', ''),
+ xfail('fft.ihfftn', ''),
+ xfail('fft.irfft2', ''),
+ xfail('fft.irfft', ''),
+ xfail('fft.irfftn', ''),
+ xfail('fft.rfft2', ''),
+ xfail('fft.rfft', ''),
+ xfail('fft.rfftn', ''),
+
+ xfail('linalg.svdvals', ''),
+ xfail('linalg.cond', ''),
+ xfail('linalg.matrix_norm', ''),
+ xfail('linalg.norm', ''),
+ xfail('linalg.norm', 'subgradients_at_zero'),
+ xfail('linalg.svd', ''),
+ xfail('svd', ''),
+ xfail('svd_lowrank', ''),
+ xfail('stft', ''),
+
# Misc
xfail('to_sparse'),
xfail('corrcoef'),
diff --git a/test/inductor/test_torchinductor_opinfo.py b/test/inductor/test_torchinductor_opinfo.py
index 173a58b..bac4c5b 100644
--- a/test/inductor/test_torchinductor_opinfo.py
+++ b/test/inductor/test_torchinductor_opinfo.py
@@ -125,6 +125,7 @@
inductor_skips = defaultdict(dict)
+
inductor_skips["cpu"] = {
"linalg.ldl_solve": {b8, f16, f32, f64, i32, i64}, # segfault
"linalg.ldl_factor": {f32, f64}, # flaky
@@ -208,6 +209,14 @@
"sparse.sampled_addmm": {f32, f64},
("sparse.mm", "reduce"): {bf16, f32, f64},
"stft": {f32, f64},
+ "svd": {f32, f64},
+ "svd_lowrank": {f32, f64},
+ "linalg.cond": {f32, f64},
+ "linalg.svd": {f32, f64},
+ "linalg.svdvals": {f32, f64},
+ "linalg.matrix_rank": {f32, f64},
+ "pca_lowrank": {f32, f64},
+ ("norm", "nuc"): {f32, f64},
"tensor_split": {b8, f16, f32, f64, i32, i64},
"to_sparse": {f32, f64},
# AssertionError: Tensor-likes are not close!
@@ -238,8 +247,8 @@
"fft.irfft2": {b8, f16, f32, f64, i32, i64},
"fft.irfftn": {b8, f16, f32, f64, i32, i64},
"fft.rfft": {f16, f32, f64, b8, i32, i64},
- "fft.rfft2": {f16, f32, f64},
- "fft.rfftn": {f16, f32, f64},
+ "fft.rfft2": {b8, f16, f32, f64, i32, i64},
+ "fft.rfftn": {b8, f16, f32, f64, i32, i64},
# These return complex tensors
"cdouble": {b8, i32, i64, f16, f32, f64},
"cfloat": {b8, i32, i64, f16, f32, f64},
@@ -271,6 +280,7 @@
"linalg.eigh": {f32, f64},
"linalg.eigvals": {f32, f64},
"linalg.eigvalsh": {f32, f64},
+ "linalg.householder_product": {f32, f64},
"linalg.lstsq": {f32, f64},
("linalg.lstsq", "grad_oriented"): {f32, f64},
"masked_scatter": {f16, f32, f64},
@@ -320,6 +330,11 @@
# (including _linalg_svd), possibly we should have something similar here
"linalg.cond": {f32, f64},
"linalg.svdvals": {f32, f64},
+ "linalg.matrix_rank": {f32, f64},
+ "linalg.svd": {f32, f64},
+ "pca_lowrank": {f32, f64},
+ "svd_lowrank": {f32, f64},
+ "svd": {f32, f64},
("norm", "nuc"): {f32, f64},
# AssertionError: Scalars are not close!
"nn.functional.soft_margin_loss": {f16},
@@ -339,8 +354,8 @@
"fft.irfft2": {b8, f16, f32, f64, i32, i64},
"fft.irfftn": {b8, f16, f32, f64, i32, i64},
"fft.rfft": {f16, f32, f64, b8, i32, i64},
- "fft.rfft2": {f16, f32, f64},
- "fft.rfftn": {f16, f32, f64},
+ "fft.rfft2": {b8, f16, f32, f64, i32, i64},
+ "fft.rfftn": {b8, f16, f32, f64, i32, i64},
# These return complex tensors
"cdouble": {b8, i32, i64, f16, f32, f64},
"cfloat": {b8, i32, i64, f16, f32, f64},
diff --git a/test/test_ops.py b/test/test_ops.py
index b662455..5547a5e 100644
--- a/test/test_ops.py
+++ b/test/test_ops.py
@@ -315,6 +315,8 @@
continue
except torch._subclasses.fake_tensor.DataDependentOutputException:
continue
+ except torch._subclasses.fake_tensor.UnsupportedOperatorException:
+ continue
if isinstance(result, torch.Tensor):
self.assertTrue(isinstance(meta_result, FakeTensor))
@@ -1976,7 +1978,7 @@
"linalg.svd",
}
-fake_backward_xfails = fake_tensor_stride_failing_ops | {
+fake_backward_skips = {
"linalg.cond",
"linalg.matrix_norm",
"linalg.norm",
@@ -1989,12 +1991,10 @@
"cholesky",
}
-fake_backward_xfails = {xfail(stride_skip) for stride_skip in fake_backward_xfails} | {
+fake_backward_xfails = {skip(s) for s in fake_backward_skips} | {
xfail("_segment_reduce", "lengths"),
- xfail("norm", "nuc"),
- xfail("linalg.norm", "subgradients_at_zero"), # can accept vector inputs
skip('nn.functional.ctc_loss'),
-}
+} | {skip(stride_skip) for stride_skip in fake_tensor_stride_failing_ops}
fake_autocast_backward_xfails = {
skip("nn.functional.binary_cross_entropy"),
@@ -2065,6 +2065,8 @@
except torch._subclasses.fake_tensor.UnsupportedFakeTensorException:
pass
+ except torch._subclasses.fake_tensor.UnsupportedOperatorException:
+ pass
except torch._subclasses.fake_tensor.DynamicOutputShapeException:
self.assertTrue(name in dynamic_output_op_tests or name in sometimes_dynamic_output_op_test)
except torch._subclasses.fake_tensor.DataDependentOutputException:
@@ -2150,12 +2152,15 @@
)
# TODO: enable check_aliasing, batch norm fails
- with torch._subclasses.CrossRefFakeMode(ignore_op_fn=lambda fn: fn in common_skip_ops, check_aliasing=True):
- with warnings.catch_warnings(), context(), torch.autograd.set_multithreading_enabled(False):
- composite_compliance.compute_expected_grads(
- op.get_op(), args, kwargs,
- sample.output_process_fn_grad,
- op.gradcheck_wrapper)
+ try:
+ with torch._subclasses.CrossRefFakeMode(ignore_op_fn=lambda fn: fn in common_skip_ops, check_aliasing=True):
+ with warnings.catch_warnings(), context(), torch.autograd.set_multithreading_enabled(False):
+ composite_compliance.compute_expected_grads(
+ op.get_op(), args, kwargs,
+ sample.output_process_fn_grad,
+ op.gradcheck_wrapper)
+ except torch._subclasses.fake_tensor.UnsupportedOperatorException:
+ pass
@onlyCUDA
@ops([op for op in op_db if op.supports_autograd], allowed_dtypes=(torch.float,))
diff --git a/test/test_proxy_tensor.py b/test/test_proxy_tensor.py
index 49d0af7..928e84c 100644
--- a/test/test_proxy_tensor.py
+++ b/test/test_proxy_tensor.py
@@ -1330,6 +1330,22 @@
xfail('nanquantile'),
xfail('narrow'),
+ # many complex operators have incorrect striding/metadata
+ skip('fft.fft', ''),
+ skip('fft.hfft2', ''),
+ skip('fft.hfft', ''),
+ skip('fft.hfftn', ''),
+ skip('fft.ifft', ''),
+ skip('fft.ihfft2', ''),
+ skip('fft.ihfft', ''),
+ skip('fft.ihfftn', ''),
+ skip('fft.irfft2', ''),
+ skip('fft.irfft', ''),
+ skip('fft.irfftn', ''),
+ skip('fft.rfft2', ''),
+ skip('fft.rfft', ''),
+ skip('fft.rfftn', ''),
+
# Seems like it's creating a sparse tensor that isn't captured by tensor.is_sparse
xfail('sparse.sampled_addmm'),
xfail('sparse.mm', 'reduce'),
@@ -1350,6 +1366,20 @@
xfail('repeat_interleave'),
# ASAN failures due to divide by 0
skip('nn.functional.nll_loss'),
+
+ xfail('linalg.cond', ''),
+ xfail('linalg.matrix_norm'),
+ xfail('linalg.matrix_rank'),
+ xfail('linalg.norm'),
+ xfail('linalg.norm', 'subgradients_at_zero'),
+ xfail('linalg.svd'),
+ xfail('linalg.svdvals'),
+
+ xfail('norm', 'nuc'),
+ xfail('pca_lowrank'),
+ xfail('stft'),
+ xfail('svd'),
+ xfail('svd_lowrank'),
}
symbolic_tensor_failures = {