[MPS] Fix SDP training (#134719)

Check whether the input tensors require grad. If any of them does, skip the fast path and fall back to the composite implicit implementation.

Fixes #134678
Pull Request resolved: https://github.com/pytorch/pytorch/pull/134719
Approved by: https://github.com/malfet
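
The gist of the fix, sketched in Python for illustration only (the helper name can_use_fast_sdpa is hypothetical; the actual check happens in the MPS SDPA dispatch, not in Python):

    import torch

    def can_use_fast_sdpa(q, k, v):
        # Hypothetical sketch: the fast MPS kernel has no backward, so it is
        # only eligible when autograd will not need gradients for any input.
        needs_grad = torch.is_grad_enabled() and any(
            t.requires_grad for t in (q, k, v)
        )
        return not needs_grad

This mirrors the condition exercised by the new test: with requires_grad=True the composite implicit path must produce gradients that match CPU, while under torch.no_grad() the fast path remains eligible.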
diff --git a/test/test_mps.py b/test/test_mps.py
index cd28de9..281c68b 100644
--- a/test/test_mps.py
+++ b/test/test_mps.py
@@ -9360,18 +9360,37 @@
         err = ((y - ref).abs() / denom).mean().item()
         self.assertLess(err, 0.01)
 
-    def _test_sdpa_no_mask(self, is_causal: bool, dtype: torch.dtype, L: int = 1, S: int = 72, NH: int = 32, HS: int = 128):
+    def _test_sdpa_no_mask(
+        self,
+        is_causal: bool,
+        dtype: torch.dtype,
+        L: int = 1,
+        S: int = 72,
+        NH: int = 32,
+        HS: int = 128,
+        requires_grad: bool = False
+    ):
+
         torch.manual_seed(1729)
         with torch.nn.attention.sdpa_kernel([torch.nn.attention.SDPBackend.MATH]):
-            q = torch.randn([1, NH, L, HS], dtype=dtype, device="mps")
+            q = torch.randn([1, NH, L, HS], dtype=dtype, device="mps", requires_grad=requires_grad)
             k = torch.randn([1, NH, S, HS], dtype=q.dtype, device="mps")
             v = torch.randn([1, NH, S, HS], dtype=q.dtype, device="mps")
+            q_cpu = q.detach().cpu().requires_grad_(requires_grad)  # detached CPU copy for the reference path
+            k_cpu = k.cpu()
+            v_cpu = v.cpu()
 
             y = F.scaled_dot_product_attention(q, k, v, dropout_p=0.0, is_causal=is_causal)
-            y_ref = F.scaled_dot_product_attention(q.cpu(), k.cpu(), v.cpu(), dropout_p=0.0, is_causal=is_causal)
+            y_ref = F.scaled_dot_product_attention(q_cpu, k_cpu, v_cpu, dropout_p=0.0, is_causal=is_causal)
 
             self._compare_tensors(y.cpu(), y_ref)
 
+            if requires_grad and torch.is_grad_enabled():
+                y.sum().backward()
+                y_ref.sum().backward()
+
+                self._compare_tensors(q.grad.cpu(), q_cpu.grad)
+
     def test_sdpa_no_mask_no_causal_fp32(self):
         self._test_sdpa_no_mask(False, torch.float32)
 
@@ -9393,6 +9412,12 @@
     def test_sdpa_no_mask_causal_fp16_L7_S17_NH23_HS121(self):
         self._test_sdpa_no_mask(True, torch.float16, 7, 17, 23, 121)
 
+    def test_sdpa_no_mask_no_causal_fp32_grad(self):
+        self._test_sdpa_no_mask(False, torch.float32, requires_grad=True)
+
+        with torch.no_grad():
+            self._test_sdpa_no_mask(False, torch.float32, requires_grad=True)
+
     def _test_sdpa_mask(self, dtype: torch.dtype, L: int = 1, S: int = 72, NH: int = 32, HS: int = 128):
         torch.manual_seed(1729)
         causal_mask = torch.tril(torch.ones(S, S, dtype=torch.bool, device='mps'))
@@ -9421,7 +9446,7 @@
         self._test_sdpa_mask(torch.float16, 6)
 
     def test_sdpa_mask_fp16_L6_S17_NH23_HS121(self):
-        self._test_sdpa_no_mask(True, torch.float16, 7, 17, 23, 121)
+        self._test_sdpa_mask(torch.float16, 7, 17, 23, 121)
 
 
 class TestGatherScatter(TestCaseMPS):