Added complex support for `torch.logsumexp` (#133187)
Added complex support for `torch.logsumexp`, including an implementation of the complex backward pass.
Fixes #133047
Pull Request resolved: https://github.com/pytorch/pytorch/pull/133187
Approved by: https://github.com/amjames, https://github.com/lezcano
diff --git a/test/test_mps.py b/test/test_mps.py
index f7f36e5..f3dd38d 100644
--- a/test/test_mps.py
+++ b/test/test_mps.py
@@ -438,6 +438,7 @@
'logical_not',
'logical_or',
'logical_xor',
+ 'logsumexp',
'long',
'masked_fill',
'masked.mean',
@@ -445,6 +446,7 @@
'masked.std',
'masked.sum',
'masked.var',
+ 'masked.logsumexp',
'matmul',
'mean',
'mm',
@@ -6540,6 +6542,18 @@
helper((2, 8, 4, 5))
+ def test_logsumexp(self):
+ def helper(shape):
+ cpu_x = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=False)
+ x = cpu_x.detach().clone().to('mps')
+
+ log_result = torch.logsumexp(x, -1)
+ log_result_cpu = torch.logsumexp(cpu_x, -1)
+
+ self.assertEqual(log_result, log_result_cpu)
+
+ helper((2, 8, 4, 5))
+
# Test concat forward
def test_cat2(self):