[MPS] Add argmin (#80828)
This PR:
1. Adds `argmin`.
2. Refactors `reduction_type` in `ReduceOps.mm` to use an enum.
Co-authored-by: Kulin Seth <[email protected]>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/80828
Approved by: https://github.com/malfet
diff --git a/test/test_mps.py b/test/test_mps.py
index f8614ea..dc857a3 100644
--- a/test/test_mps.py
+++ b/test/test_mps.py
@@ -2187,9 +2187,14 @@
helper((2, 3, 4, 5))
- # Test forward argmax
- def test_argmax(self):
- def helper(n, c, h, w, dtype=torch.float32):
+ # Test forward argmin argmax
+ def test_argmin_argmax(self):
+ def helper(n, c, h, w, reduction_type, dtype=torch.float32):
+ if reduction_type == "max":
+ arg_reduction_fn = torch.argmax
+ else:
+ arg_reduction_fn = torch.argmin
+
cpu_x = None
x = None
if(dtype not in [torch.float32, torch.bool]):
@@ -2202,46 +2207,50 @@
cpu_x = torch.randn(n, c, h, w, device='cpu', dtype=dtype, requires_grad=True)
x = cpu_x.detach().clone().to('mps').requires_grad_()
- y = torch.argmax(x)
- ref_y = torch.argmax(cpu_x)
+ y = arg_reduction_fn(x)
+ ref_y = arg_reduction_fn(cpu_x)
self.assertEqual(y, ref_y)
- y_0 = torch.argmax(x, dim=0)
- refy_0 = torch.argmax(cpu_x, dim=0)
+ y_0 = arg_reduction_fn(x, dim=0)
+ refy_0 = arg_reduction_fn(cpu_x, dim=0)
self.assertEqual(y_0, refy_0)
- y_0dim = torch.argmax(x, dim=0, keepdim=True)
- refy_0dim = torch.argmax(cpu_x, dim=0, keepdim=True)
+ y_0dim = arg_reduction_fn(x, dim=0, keepdim=True)
+ refy_0dim = arg_reduction_fn(cpu_x, dim=0, keepdim=True)
self.assertEqual(y_0dim, refy_0dim)
- y_1 = torch.argmax(x, dim=1)
- refy_1 = torch.argmax(cpu_x, dim=1)
+ y_1 = arg_reduction_fn(x, dim=1)
+ refy_1 = arg_reduction_fn(cpu_x, dim=1)
self.assertEqual(y_1, refy_1)
- y_1dim = torch.argmax(x, dim=1, keepdim=True)
- refy_1dim = torch.argmax(cpu_x, dim=1, keepdim=True)
+ y_1dim = arg_reduction_fn(x, dim=1, keepdim=True)
+ refy_1dim = arg_reduction_fn(cpu_x, dim=1, keepdim=True)
self.assertEqual(y_1dim, refy_1dim)
- y_2 = torch.argmax(x, dim=2)
- refy_2 = torch.argmax(cpu_x, dim=2)
+ y_2 = arg_reduction_fn(x, dim=2)
+ refy_2 = arg_reduction_fn(cpu_x, dim=2)
self.assertEqual(y_2, refy_2)
- y_2dim = torch.argmax(x, dim=2, keepdim=True)
- refy_2dim = torch.argmax(cpu_x, dim=2, keepdim=True)
+ y_2dim = arg_reduction_fn(x, dim=2, keepdim=True)
+ refy_2dim = arg_reduction_fn(cpu_x, dim=2, keepdim=True)
self.assertEqual(y_2dim, refy_2dim)
- y_3 = torch.argmax(x, dim=3)
- refy_3 = torch.argmax(cpu_x, dim=3)
+ y_3 = arg_reduction_fn(x, dim=3)
+ refy_3 = arg_reduction_fn(cpu_x, dim=3)
self.assertEqual(y_3, refy_3)
- y_3dim = torch.argmax(x, dim=3, keepdim=True)
- refy_3dim = torch.argmax(cpu_x, dim=3, keepdim=True)
+ y_3dim = arg_reduction_fn(x, dim=3, keepdim=True)
+ refy_3dim = arg_reduction_fn(cpu_x, dim=3, keepdim=True)
self.assertEqual(y_3dim, refy_3dim)
- helper(2, 8, 4, 4, torch.float32)
- helper(2, 8, 4, 4, torch.int32)
- helper(2, 8, 4, 4, torch.float16)
- helper(2, 8, 4, 4, torch.int64)
+ helper(2, 8, 4, 4, "max", torch.float32)
+ helper(2, 8, 4, 4, "max", torch.int32)
+ helper(2, 8, 4, 4, "max", torch.float16)
+ helper(2, 8, 4, 4, "max", torch.int64)
+ helper(2, 8, 4, 4, "min", torch.float32)
+ helper(2, 8, 4, 4, "min", torch.int32)
+ helper(2, 8, 4, 4, "min", torch.float16)
+ helper(2, 8, 4, 4, "min", torch.int64)
# Test forward max
# Note - don't test grad now