[opinfo] item (#100313)
Follows #100223
Pull Request resolved: https://github.com/pytorch/pytorch/pull/100313
Approved by: https://github.com/ezyang
diff --git a/test/functorch/test_vmap.py b/test/functorch/test_vmap.py
index fc5690b..7eef9e5 100644
--- a/test/functorch/test_vmap.py
+++ b/test/functorch/test_vmap.py
@@ -3594,6 +3594,9 @@
xfail('addcmul'),
xfail('clamp'),
+ # TypeError: expected Tensor as element 0 in argument 0, but got float
+ xfail('item'),
+
# UBSAN: runtime error: shift exponent -1 is negative
decorate('bitwise_left_shift', decorator=unittest.skipIf(TEST_WITH_UBSAN, "Fails with above error")),
decorate('bitwise_right_shift', decorator=unittest.skipIf(TEST_WITH_UBSAN, "Fails with above error")),
@@ -3645,6 +3648,8 @@
xfail('take'),
xfail('tensor_split'),
xfail('to_sparse'),
+ # TypeError: expected Tensor as element 0 in argument 0, but got float
+ xfail('item'),
xfail('tril'), # Exception not raised on error input
xfail('triu'), # Exception not raised on error input
xfail('__getitem__', ''),
diff --git a/test/inductor/test_torchinductor_opinfo.py b/test/inductor/test_torchinductor_opinfo.py
index aadb484..8446d81 100644
--- a/test/inductor/test_torchinductor_opinfo.py
+++ b/test/inductor/test_torchinductor_opinfo.py
@@ -169,6 +169,8 @@
"index_add": {f16},
"index_reduce": {f16, f32, f64},
"istft": {f32, f64},
+ # Unsupported: data dependent operator: aten._local_scalar_dense.default
+ "item": {b8, f16, f32, f64, i32, i64},
"linalg.eig": {f32, f64},
"linalg.eigh": {f32, f64},
"linalg.eigvals": {f32, f64},
@@ -275,6 +277,8 @@
"equal": {b8, f16, f32, f64, i32, i64},
"index_reduce": {f16, f32, f64},
"istft": {f32, f64},
+ # Unsupported: data dependent operator: aten._local_scalar_dense.default
+ "item": {b8, f16, f32, f64, i32, i64},
"linalg.eig": {f32, f64},
"linalg.eigh": {f32, f64},
"linalg.eigvals": {f32, f64},
diff --git a/test/test_decomp.py b/test/test_decomp.py
index ffbf11f..a9f3143 100644
--- a/test/test_decomp.py
+++ b/test/test_decomp.py
@@ -305,6 +305,9 @@
(None, None, "empty_like"),
(None, None, "empty"),
+ # AssertionError: False is not true : aten.item was not decomposed, saw calls for: aten._local_scalar_dense.default.
+ (None, None, "item"),
+
# It's the only in-place op without an out-of-place equivalent in the Python API
# Its OpInfo wrongly registers it as `torch.zero_(x.clone())`.
(None, None, "zero_"),
diff --git a/test/test_meta.py b/test/test_meta.py
index 693636b..6a4c9da 100644
--- a/test/test_meta.py
+++ b/test/test_meta.py
@@ -603,7 +603,7 @@
torch.nonzero : {f64, i32, c128, i64, i16, c32, f16, u8, c64, bf16, b8, i8, f32},
torch.Tensor.nonzero : {f64, i32, c128, i64, i16, c32, f16, u8, c64, bf16, b8, i8, f32},
torch.ormqr : {f64, c64, c128, f32},
- torch.Tensor.item : {f64, i32, c128, i64, i16, f16, u8, c64, bf16, b8, i8, f32},
+ torch.Tensor.item : {f64, i32, c128, i64, i16, f16, u8, c32, c64, bf16, b8, i8, f32},
torch.bincount : {i32, i64, u8, i16, i8},
torch.frexp : {f64, f16, bf16, f32},
torch.functional.unique : {f64, i32, i64, u8, i16, f16, bf16, b8, i8, f32},
diff --git a/test/test_mps.py b/test/test_mps.py
index 28eba43..69dceac 100644
--- a/test/test_mps.py
+++ b/test/test_mps.py
@@ -95,6 +95,8 @@
# 'bool' object is not iterable
'allclose': [torch.float16, torch.float32],
'equal': [torch.float16, torch.float32],
+ # 'float' object is not iterable
+ 'item': [torch.float16, torch.float32],
# "mse_backward_cpu_out" not implemented for 'Half'
'nn.functional.mse_loss': [torch.float16],
# "smooth_l1_backward_cpu_out" not implemented for 'Half'
diff --git a/test/test_ops.py b/test/test_ops.py
index ecf638c..3ea07fa 100644
--- a/test/test_ops.py
+++ b/test/test_ops.py
@@ -1746,7 +1746,6 @@
'_refs.equal',
'_refs.full',
'_refs.full_like',
- '_refs.item',
'_refs.to',
'_refs.ones',
'_refs.ones_like',
diff --git a/test/test_proxy_tensor.py b/test/test_proxy_tensor.py
index 67e45bf..24ad617 100644
--- a/test/test_proxy_tensor.py
+++ b/test/test_proxy_tensor.py
@@ -1374,6 +1374,7 @@
skip('linalg.lstsq'), # flaky, probably just a precision issue
# data-dependent control flow
+ skip('item'),
xfail('cov'),
xfail('istft'),
xfail('nn.functional.gaussian_nll_loss'),
diff --git a/torch/_refs/__init__.py b/torch/_refs/__init__.py
index ba0decf..c412476 100644
--- a/torch/_refs/__init__.py
+++ b/torch/_refs/__init__.py
@@ -186,7 +186,7 @@
#
"clone",
"copy_to", # TODO: add OpInfo (or implement .to)
- "item", # TODO: add OpInfo
+ "item",
"to",
#
# Reduction ops
diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py
index b787a40..f47674c 100644
--- a/torch/testing/_internal/common_methods_invocations.py
+++ b/torch/testing/_internal/common_methods_invocations.py
@@ -370,6 +370,36 @@
yield SampleInput(make_arg((1, 2, 3)), args=(make_arg((2, 1, 3)),), kwargs={'dim': -2})
yield SampleInput(make_arg((2, 3)), args=(make_arg((2, 1, 3)),), kwargs={'dim': -1})
+
+def sample_inputs_item(op_info, device, dtype, requires_grad, **kwargs):
+ make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=False)
+
+ cases = (
+ (),
+ (()),
+ (1),
+ ((1,)),
+ )
+
+ for shape in cases:
+ yield SampleInput(make_arg(shape))
+
+def error_inputs_item(op, device, **kwargs):
+ make_arg = partial(make_tensor, dtype=torch.float32, device=device, requires_grad=False)
+
+ cases = (
+ (M),
+ ((S,)),
+ (S, S),
+ (S, M, L),
+ )
+
+ for shape in cases:
+ yield ErrorInput(
+ SampleInput(make_arg(shape)), error_type=RuntimeError,
+ error_regex="elements cannot be converted to Scalar")
+
+
def sample_inputs_batch_norm(op_info, device, dtype, requires_grad, **kwargs):
make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
make_arg_without_requires_grad = partial(make_tensor, device=device, dtype=dtype, requires_grad=False)
@@ -9021,6 +9051,28 @@
'test_reference_numerics_extremal_values',
dtypes=(torch.complex64, torch.complex128)),
)),
+ OpInfo('item',
+ op=lambda inp, *args, **kwargs: wrapper_set_seed(torch.Tensor.item, inp, *args, **kwargs),
+ ref=np.ndarray.item,
+ method_variant=None,
+ dtypes=all_types_and_complex_and(torch.bfloat16, torch.float16, torch.chalf, torch.bool),
+ supports_out=False,
+ supports_autograd=False,
+ error_inputs_func=error_inputs_item,
+ sample_inputs_func=sample_inputs_item,
+ skips=(
+ # Error testing item function variant
+ DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit',
+ dtypes=(torch.float32, torch.complex64)),
+ # FX failed to normalize op - add the op to the op_skip list.
+ DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'),
+ # RuntimeError: Composite compliance check failed with the above error.
+ DecorateInfo(unittest.expectedFailure, 'TestCompositeCompliance', 'test_operator'),
+ # Booleans mismatch: AssertionError: False is not true
+ DecorateInfo(unittest.expectedFailure, 'TestFakeTensor', 'test_fake_autocast'),
+ # Booleans mismatch: AssertionError: False is not true
+ DecorateInfo(unittest.expectedFailure, 'TestFakeTensor', 'test_fake'),
+ )),
OpInfo('arange',
dtypes=all_types_and(torch.bfloat16, torch.float16),
supports_out=True,
@@ -18441,6 +18493,16 @@
# https://github.com/pytorch/pytorch/issues/85258
supports_nvfuser=False,
),
+ PythonRefInfo(
+ "_refs.item",
+ torch_opinfo_name="item",
+ supports_nvfuser=False,
+ skips=(
+ # RuntimeError: Cannot cast FakeTensor(FakeTensor(..., device='meta', size=()), cpu) to number
+ DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref_meta'),
+ # ValueError: Can't convert a tensor with 10 elements to a number!
+ DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref_errors'),),
+ ),
ElementwiseUnaryPythonRefInfo(
"_refs.conj_physical",
torch_opinfo_name="conj_physical",