Complete decomposition for aten.round (#118635)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/118635
Approved by: https://github.com/peterbell10
diff --git a/test/expect/HasDecompTest.test_aten_core_operators.expect b/test/expect/HasDecompTest.test_aten_core_operators.expect
index 2ea3fbf..98d1d6c 100644
--- a/test/expect/HasDecompTest.test_aten_core_operators.expect
+++ b/test/expect/HasDecompTest.test_aten_core_operators.expect
@@ -419,6 +419,10 @@
aten::repeat.out
aten::roll
aten::roll.out
+aten::round
+aten::round.decimals
+aten::round.decimals_out
+aten::round.out
aten::round_
aten::round_.decimals
aten::rrelu_with_noise_backward
diff --git a/test/expect/HasDecompTest.test_has_decomposition.expect b/test/expect/HasDecompTest.test_has_decomposition.expect
index 9f7e2ed..8cbae2f 100644
--- a/test/expect/HasDecompTest.test_has_decomposition.expect
+++ b/test/expect/HasDecompTest.test_has_decomposition.expect
@@ -1115,10 +1115,6 @@
aten::resize_as_sparse
aten::resize_as_sparse.out
aten::resize_as_sparse_
-aten::round
-aten::round.decimals
-aten::round.decimals_out
-aten::round.out
aten::row_indices
aten::row_indices_copy
aten::row_indices_copy.out
diff --git a/test/test_ops.py b/test/test_ops.py
index e73043c..74a1f86 100644
--- a/test/test_ops.py
+++ b/test/test_ops.py
@@ -1956,7 +1956,6 @@
'_refs.sum_to_size',
# ref implementation missing kwargs
'_refs.full_like', # missing "layout"
- '_refs.round', # missing "decimals"
'_refs.scalar_tensor', # missing "layout"
# other
'_refs.block_diag', # only refs._block_diag_iterable is in decomposition table
diff --git a/test/test_proxy_tensor.py b/test/test_proxy_tensor.py
index 93073e3..5571a40 100644
--- a/test/test_proxy_tensor.py
+++ b/test/test_proxy_tensor.py
@@ -1898,9 +1898,6 @@
xfail('i0', ''), # aten.i0.default - couldn't find symbolic meta function/decomposition
xfail('linalg.norm', ''),
- xfail('round', 'decimals_0'), # Cannot call numel() on tensor with symbolic sizes/strides
- xfail('round', 'decimals_3'), # Cannot call numel() on tensor with symbolic sizes/strides
- xfail('round', 'decimals_neg_3'), # Cannot call numel() on tensor with symbolic sizes/strides
}
inplace_symbolic_tensor_failures = {
@@ -1952,10 +1949,6 @@
xfail('nn.functional.avg_pool2d', ''),
xfail('nn.functional.linear', ''),
xfail('qr', ''),
- xfail('round', ''),
- xfail('round', 'decimals_0'),
- xfail('round', 'decimals_3'),
- xfail('round', 'decimals_neg_3'),
xfail('scatter_add', ''),
xfail('scatter', ''),
xfail('sort', ''),
diff --git a/torch/_inductor/lowering.py b/torch/_inductor/lowering.py
index 258f51d..7ca4fb3 100644
--- a/torch/_inductor/lowering.py
+++ b/torch/_inductor/lowering.py
@@ -752,12 +752,13 @@
return make_pointwise(fn)(x)
-@register_lowering(aten.round)
+@register_lowering(aten.round.default)
def round(x):
if is_integer_type(x):
return clone(x)
- fn = ops_wrapper("round")
- return make_pointwise(fn)(x)
+ else:
+ fn = ops_wrapper("round")
+ return make_pointwise(fn)(x)
@register_lowering(aten.trunc)
diff --git a/torch/_refs/__init__.py b/torch/_refs/__init__.py
index c3a5d83..3a6f9f1 100644
--- a/torch/_refs/__init__.py
+++ b/torch/_refs/__init__.py
@@ -870,13 +870,19 @@
return prims.reciprocal(a)
-# TODO: round takes additional kwargs
-@_make_elementwise_unary_reference(
- ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
- aten_op=None, # TODO: this does need a decomp, but kwarg handling is needed
+@register_decomposition(aten.round)
+@out_wrapper()
+@elementwise_type_promotion_wrapper(
+ type_promoting_args=("a",),
+ type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
)
-def round(a):
- return prims.round(a)
+def round(a: TensorLikeType, *, decimals: int = 0) -> TensorLikeType:
+ if decimals == 0:
+ return prims.round(a)
+ else:
+ ten_pow = 10**decimals
+ ten_neg_pow = 10 ** (-decimals)
+ return prims.mul(prims.round(prims.mul(a, ten_pow)), ten_neg_pow)
@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT)