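Complete decomposition for aten.round (#118635)

Extend the `torch._refs.round` reference to accept the keyword-only
`decimals` argument and register it as the decomposition for `aten.round`.
When `decimals` is 0 it still returns `prims.round(a)`; otherwise it computes
`round(a * 10**decimals) * 10**(-decimals)`.

The Inductor lowering is now registered only for `aten.round.default`.
`aten::round`, `aten::round.decimals`, `aten::round.decimals_out` and
`aten::round.out` move from `HasDecompTest.test_has_decomposition.expect` to
`HasDecompTest.test_aten_core_operators.expect`, and the `round` entries are
dropped from the skip/xfail lists in `test/test_ops.py` and
`test/test_proxy_tensor.py`.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/118635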
Approved by: https://github.com/peterbell10
diff --git a/test/expect/HasDecompTest.test_aten_core_operators.expect b/test/expect/HasDecompTest.test_aten_core_operators.expect
index 2ea3fbf..98d1d6c 100644
--- a/test/expect/HasDecompTest.test_aten_core_operators.expect
+++ b/test/expect/HasDecompTest.test_aten_core_operators.expect
@@ -419,6 +419,10 @@
 aten::repeat.out
 aten::roll
 aten::roll.out
+aten::round
+aten::round.decimals
+aten::round.decimals_out
+aten::round.out
 aten::round_
 aten::round_.decimals
 aten::rrelu_with_noise_backward
diff --git a/test/expect/HasDecompTest.test_has_decomposition.expect b/test/expect/HasDecompTest.test_has_decomposition.expect
index 9f7e2ed..8cbae2f 100644
--- a/test/expect/HasDecompTest.test_has_decomposition.expect
+++ b/test/expect/HasDecompTest.test_has_decomposition.expect
@@ -1115,10 +1115,6 @@
 aten::resize_as_sparse
 aten::resize_as_sparse.out
 aten::resize_as_sparse_
-aten::round
-aten::round.decimals
-aten::round.decimals_out
-aten::round.out
 aten::row_indices
 aten::row_indices_copy
 aten::row_indices_copy.out
diff --git a/test/test_ops.py b/test/test_ops.py
index e73043c..74a1f86 100644
--- a/test/test_ops.py
+++ b/test/test_ops.py
@@ -1956,7 +1956,6 @@
         '_refs.sum_to_size',
         # ref implementation missing kwargs
         '_refs.full_like',  # missing "layout"
-        '_refs.round',  # missing "decimals"
         '_refs.scalar_tensor',  # missing "layout"
         # other
         '_refs.block_diag',  # only refs._block_diag_iterable is in decomposition table
diff --git a/test/test_proxy_tensor.py b/test/test_proxy_tensor.py
index 93073e3..5571a40 100644
--- a/test/test_proxy_tensor.py
+++ b/test/test_proxy_tensor.py
@@ -1898,9 +1898,6 @@
     xfail('i0', ''),  # aten.i0.default - couldn't find symbolic meta function/decomposition
 
     xfail('linalg.norm', ''),
-    xfail('round', 'decimals_0'),  # Cannot call numel() on tensor with symbolic sizes/strides
-    xfail('round', 'decimals_3'),  # Cannot call numel() on tensor with symbolic sizes/strides
-    xfail('round', 'decimals_neg_3'),  # Cannot call numel() on tensor with symbolic sizes/strides
 }
 
 inplace_symbolic_tensor_failures = {
@@ -1952,10 +1949,6 @@
     xfail('nn.functional.avg_pool2d', ''),
     xfail('nn.functional.linear', ''),
     xfail('qr', ''),
-    xfail('round', ''),
-    xfail('round', 'decimals_0'),
-    xfail('round', 'decimals_3'),
-    xfail('round', 'decimals_neg_3'),
     xfail('scatter_add', ''),
     xfail('scatter', ''),
     xfail('sort', ''),
diff --git a/torch/_inductor/lowering.py b/torch/_inductor/lowering.py
index 258f51d..7ca4fb3 100644
--- a/torch/_inductor/lowering.py
+++ b/torch/_inductor/lowering.py
@@ -752,12 +752,13 @@
     return make_pointwise(fn)(x)
 
 
-@register_lowering(aten.round)
+@register_lowering(aten.round.default)
 def round(x):
     if is_integer_type(x):
         return clone(x)
-    fn = ops_wrapper("round")
-    return make_pointwise(fn)(x)
+    else:
+        fn = ops_wrapper("round")
+        return make_pointwise(fn)(x)
 
 
 @register_lowering(aten.trunc)
diff --git a/torch/_refs/__init__.py b/torch/_refs/__init__.py
index c3a5d83..3a6f9f1 100644
--- a/torch/_refs/__init__.py
+++ b/torch/_refs/__init__.py
@@ -870,13 +870,19 @@
     return prims.reciprocal(a)
 
 
-# TODO: round takes additional kwargs
-@_make_elementwise_unary_reference(
-    ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
-    aten_op=None,  # TODO: this does need a decomp, but kwarg handling is needed
+@register_decomposition(aten.round)
+@out_wrapper()
+@elementwise_type_promotion_wrapper(
+    type_promoting_args=("a",),
+    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
 )
-def round(a):
-    return prims.round(a)
+def round(a: TensorLikeType, *, decimals: int = 0) -> TensorLikeType:
+    if decimals == 0:
+        return prims.round(a)
+    else:
+        ten_pow = 10**decimals
+        ten_neg_pow = 10 ** (-decimals)
+        return prims.mul(prims.round(prims.mul(a, ten_pow)), ten_neg_pow)
 
 
 @_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT)