[MPS] Add glu (#79866)
Adds an MPS backend implementation of `aten::glu.out`, plus a test that compares its forward and backward results against CPU.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/79866
Approved by: https://github.com/kulinseth, https://github.com/albanD
diff --git a/test/test_mps.py b/test/test_mps.py
index e41138f..200fd1f 100644
--- a/test/test_mps.py
+++ b/test/test_mps.py
@@ -3542,8 +3542,30 @@
for alpha in [0.000001, 1.0, 2.3, 0.34, 23]:
helper(shape, alpha)
- # Test softplus
+ # Test glu
+    def test_glu(self):
+        def helper(shape, dim=0):
+            # GLU splits the input in half along `dim`, so every tested size along it must be even.
+            cpu_x = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=True)
+            x = cpu_x.detach().clone().to('mps').requires_grad_()
+            glu = torch.nn.GLU(dim=dim)
+            glu_result = glu(x)
+            glu_result_cpu = glu(cpu_x)
+
+            cpu_grad = torch.randn(glu_result_cpu.shape)
+            grad = cpu_grad.to('mps')
+
+            glu_result.backward(gradient=grad)
+            glu_result_cpu.backward(gradient=cpu_grad)
+
+            self.assertEqual(glu_result, glu_result_cpu)
+            self.assertEqual(x.grad, cpu_x.grad)
+
+        for shape in [[4], (2, 4), (2, 8, 4, 6)]:
+            # Cover negative dims as well; nn.GLU itself defaults to dim=-1.
+            for dim in range(-len(shape), len(shape)):
+                helper(shape, dim)
+
+ # Test softplus
def test_softplus(self):
def helper(shape):
cpu_x = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=True)