[MPS] Fix SDP training (#134719)
Check whether the input tensors require grad. If they do, we skip the fast path and fall back to the composite implicit implementation.
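The change itself lives in the MPS backend's SDPA dispatch; as a rough illustration only, here is a minimal Python sketch of the selection logic. The `needs_backward` helper and the `sdpa` wrapper are hypothetical names, not part of this PR:

```python
import torch
import torch.nn.functional as F
from torch.nn.attention import SDPBackend, sdpa_kernel

def needs_backward(*tensors: torch.Tensor) -> bool:
    # Gradients are only needed when autograd is enabled AND at least one
    # input is tracked; under torch.no_grad() the fast path is still usable
    # even if the inputs have requires_grad=True.
    return torch.is_grad_enabled() and any(t.requires_grad for t in tensors)

def sdpa(q, k, v, is_causal=False):
    if needs_backward(q, k, v):
        # Restrict to the math (composite) backend, which autograd can
        # differentiate through its decomposition.
        with sdpa_kernel([SDPBackend.MATH]):
            return F.scaled_dot_product_attention(q, k, v, is_causal=is_causal)
    # Otherwise let the dispatcher take the fused MPS fast path.
    return F.scaled_dot_product_attention(q, k, v, is_causal=is_causal)
```

The new tests below mirror this split: one run with `requires_grad=True` exercises the fallback and its backward, and a second run under `torch.no_grad()` checks that grad-tracked inputs still work when autograd is globally disabled.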
Fixes #134678
Pull Request resolved: https://github.com/pytorch/pytorch/pull/134719
Approved by: https://github.com/malfet
diff --git a/test/test_mps.py b/test/test_mps.py
index cd28de9..281c68b 100644
--- a/test/test_mps.py
+++ b/test/test_mps.py
@@ -9360,18 +9360,37 @@
err = ((y - ref).abs() / denom).mean().item()
self.assertLess(err, 0.01)
- def _test_sdpa_no_mask(self, is_causal: bool, dtype: torch.dtype, L: int = 1, S: int = 72, NH: int = 32, HS: int = 128):
+ def _test_sdpa_no_mask(
+ self,
+ is_causal: bool,
+ dtype: torch.dtype,
+ L: int = 1,
+ S: int = 72,
+ NH: int = 32,
+ HS: int = 128,
+ requires_grad: bool = False
+ ):
+
torch.manual_seed(1729)
with torch.nn.attention.sdpa_kernel([torch.nn.attention.SDPBackend.MATH]):
- q = torch.randn([1, NH, L, HS], dtype=dtype, device="mps")
+ q = torch.randn([1, NH, L, HS], dtype=dtype, device="mps", requires_grad=requires_grad)
k = torch.randn([1, NH, S, HS], dtype=q.dtype, device="mps")
v = torch.randn([1, NH, S, HS], dtype=q.dtype, device="mps")
+ q_cpu = q.detach().cpu().requires_grad_(requires_grad)
+ k_cpu = k.cpu()
+ v_cpu = v.cpu()
y = F.scaled_dot_product_attention(q, k, v, dropout_p=0.0, is_causal=is_causal)
- y_ref = F.scaled_dot_product_attention(q.cpu(), k.cpu(), v.cpu(), dropout_p=0.0, is_causal=is_causal)
+ y_ref = F.scaled_dot_product_attention(q_cpu, k_cpu, v_cpu, dropout_p=0.0, is_causal=is_causal)
self._compare_tensors(y.cpu(), y_ref)
+ if requires_grad and torch.is_grad_enabled():
+ y.sum().backward()
+ y_ref.sum().backward()
+
+ self._compare_tensors(q.grad.cpu(), q_cpu.grad)
+
def test_sdpa_no_mask_no_causal_fp32(self):
self._test_sdpa_no_mask(False, torch.float32)
@@ -9393,6 +9412,12 @@
def test_sdpa_no_mask_causal_fp16_L7_S17_NH23_HS121(self):
self._test_sdpa_no_mask(True, torch.float16, 7, 17, 23, 121)
+ def test_sdpa_no_mask_no_causal_fp32_grad(self):
+ self._test_sdpa_no_mask(False, torch.float32, requires_grad=True)
+
+ with torch.no_grad():
+ self._test_sdpa_no_mask(False, torch.float32, requires_grad=True)
+
def _test_sdpa_mask(self, dtype: torch.dtype, L: int = 1, S: int = 72, NH: int = 32, HS: int = 128):
torch.manual_seed(1729)
causal_mask = torch.tril(torch.ones(S, S, dtype=torch.bool, device='mps'))
@@ -9421,7 +9446,7 @@
self._test_sdpa_mask(torch.float16, 6)
def test_sdpa_mask_fp16_L6_S17_NH23_HS121(self):
- self._test_sdpa_no_mask(True, torch.float16, 7, 17, 23, 121)
+ self._test_sdpa_mask(torch.float16, 7, 17, 23, 121)
class TestGatherScatter(TestCaseMPS):