[reland][opinfo] empty_strided (#101782)
Follows #100223
Previous PR: #100890
Pull Request resolved: https://github.com/pytorch/pytorch/pull/101782
Approved by: https://github.com/ezyang
diff --git a/test/inductor/test_torchinductor_opinfo.py b/test/inductor/test_torchinductor_opinfo.py
index 60f6879..f7c36ca 100644
--- a/test/inductor/test_torchinductor_opinfo.py
+++ b/test/inductor/test_torchinductor_opinfo.py
@@ -255,6 +255,8 @@
"fft.rfft": {f16, f32, f64, b8, i32, i64},
"fft.rfft2": {b8, f16, f32, f64, i32, i64},
"fft.rfftn": {b8, f16, f32, f64, i32, i64},
+ # AssertionError: Scalars are not close!
+ "empty_strided": {b8, i32, i64, f16, f32, f64},
# These return complex tensors
"cdouble": {b8, i32, i64, f16, f32, f64},
"cfloat": {b8, i32, i64, f16, f32, f64},
diff --git a/test/test_decomp.py b/test/test_decomp.py
index a97d73d..81dde16 100644
--- a/test/test_decomp.py
+++ b/test/test_decomp.py
@@ -336,6 +336,8 @@
(None, None, "native_batch_norm"),
(None, None, "_upsample_bilinear2d_aa"),
+
+ (None, None, "empty_strided"), # aten.empty_strided was not decomposed
}
CROSS_REF_BACKWARD_EXCLUDE_SET = {
diff --git a/test/test_mps.py b/test/test_mps.py
index b269c72..d6e0564 100644
--- a/test/test_mps.py
+++ b/test/test_mps.py
@@ -710,6 +710,7 @@
'new_empty': [torch.bool, torch.float16, torch.float32, torch.int16, torch.int32, torch.int64, torch.uint8, torch.int8],
'new_empty_strided': [torch.bool, torch.float16, torch.float32, torch.int16,
torch.int32, torch.int64, torch.uint8, torch.int8],
+ 'empty_strided': [torch.bool, torch.float16, torch.float32, torch.int16, torch.int32, torch.int64, torch.uint8, torch.int8],
# CPU: empty is returning all 0's and there is a mismatch with MPS
# allocation (MacOS 13). According to
# https://pytorch.org/docs/2.0/generated/torch.empty.html
diff --git a/test/test_proxy_tensor.py b/test/test_proxy_tensor.py
index 7fac910..961a14c 100644
--- a/test/test_proxy_tensor.py
+++ b/test/test_proxy_tensor.py
@@ -1409,6 +1409,9 @@
skip('to_sparse'),
# segfaults
skip('block_diag'),
+
+ # AssertionError: Tensor-likes are not close!
+ skip('empty_strided', '', device_type='cpu'),
}
fake_tensor_failures = {
diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py
index 8df9b41..1cf3e6a 100644
--- a/torch/testing/_internal/common_methods_invocations.py
+++ b/torch/testing/_internal/common_methods_invocations.py
@@ -1611,6 +1611,18 @@
else:
yield SampleInput(t, output_shape, **kwargs)
+def sample_inputs_empty_strided(op, device, dtype, requires_grad=False, **kwargs):
+
+ inputs = [
+ ((), (), {}),
+ ((S,), (4,), {'dtype': dtype, 'device': device}),
+ ((S, S), (2, 1), {'dtype': dtype, 'device': device}),
+ ((S, S, S), (2, 0, 1), {'dtype': dtype, 'device': device}),
+ ]
+
+ for shape, strides, kwargs in inputs:
+ yield SampleInput(shape, strides, requires_grad=requires_grad, **kwargs)
+
def sample_inputs_empty(op, device, dtype, requires_grad, **kwargs):
# shape
cases = (
@@ -15786,6 +15798,29 @@
'TestCudaFuserOpInfo', 'test_nvfuser_correctness'),
DecorateInfo(unittest.skip('output is non-deterministic'), 'TestCommon', 'test_compare_cpu'),
)),
+ OpInfo('empty_strided',
+ op=lambda inp, *args, **kwargs: wrapper_set_seed(torch.empty_strided, inp, *args, **kwargs),
+ dtypes=all_types_and_complex_and(torch.bfloat16, torch.bool, torch.half),
+ supports_out=False,
+ supports_autograd=False,
+ sample_inputs_func=sample_inputs_empty_strided,
+ skips=(
+ # FX failed to normalize op - add the op to the op_skip list.
+ DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'),
+ # AssertionError: JIT Test does not execute any logic
+ DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'),
+ # Empty tensor data is garbage so it's hard to make comparisons with it.
+ DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_noncontiguous_samples'),
+ DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_variant_consistency_eager'),
+ DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_non_standard_bool_values'),
+ DecorateInfo(unittest.skip("Skipped!"), 'TestMathBits', 'test_neg_conj_view'),
+ DecorateInfo(unittest.skip("Skipped!"), 'TestMathBits', 'test_neg_view'),
+ DecorateInfo(unittest.skip("Skipped!"), 'TestMathBits', 'test_conj_view'),
+ DecorateInfo(unittest.skip('Skipped!'), 'TestCommon', 'test_compare_cpu'),
+ DecorateInfo(unittest.skip("Expected: empty is not comparable"), 'TestCompositeCompliance', 'test_operator'),
+ # Lazy tensor failures
+ DecorateInfo(unittest.skip("Expected: empty is not comparable"), 'TestLazyOpInfo'),
+ )),
OpInfo('empty',
dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16, torch.chalf),
sample_inputs_func=sample_inputs_empty,
@@ -20200,6 +20235,7 @@
PythonRefInfo(
"_refs.new_empty_strided",
torch_opinfo_name="new_empty_strided",
+ supports_nvfuser=False,
skips=(
DecorateInfo(unittest.skip("Expected: empty_strided is not comparable"),
'TestCommon',
@@ -20224,6 +20260,32 @@
),
),
PythonRefInfo(
+ "_refs.empty_strided",
+ torch_opinfo_name="empty_strided",
+ supports_nvfuser=False,
+ skips=(
+ DecorateInfo(unittest.skip("Expected: empty_strided is not comparable"),
+ 'TestCommon',
+ 'test_python_ref'),
+ DecorateInfo(unittest.skip("Expected: empty_strided is not comparable"),
+ 'TestCommon',
+ 'test_python_ref_torch_fallback'),
+ DecorateInfo(unittest.skip("Expected: empty_strided is not comparable"),
+ 'TestMathBits',
+ 'test_conj_view'),
+ DecorateInfo(unittest.skip("Expected: empty_strided is not comparable"),
+ 'TestMathBits',
+ 'test_neg_conj_view'),
+ DecorateInfo(unittest.skip("Expected: empty_strided is not comparable"),
+ 'TestMathBits',
+ 'test_neg_view'),
+ DecorateInfo(unittest.skip("Expected: empty_strided is not comparable"),
+ 'TestCommon',
+ 'test_python_ref_executor'),
+
+ ),
+ ),
+ PythonRefInfo(
"_refs.new_full",
torch_opinfo_name="new_full",
supports_nvfuser=False,