[MPS] Fix std/var cache issue (#80502)
Use `getTensorsStringKey`, which includes the input tensor's shape in the graph cache key, so that a cached graph is no longer wrongly reused when std/var is called on an input of a different shape.
Fixes #80499
Pull Request resolved: https://github.com/pytorch/pytorch/pull/80502
Approved by: https://github.com/malfet, https://github.com/kulinseth
diff --git a/test/test_mps.py b/test/test_mps.py
index 200fd1f..a61a76d 100644
--- a/test/test_mps.py
+++ b/test/test_mps.py
@@ -2796,6 +2796,8 @@
self.assertEqual(two_three_keepdim_std, two_three_dim_keepstd_cpu)
helper((4, 5, 6, 7))
+ # verify if a change in shape of input would cause problems with graph caching
+ helper((9, 5, 6, 7))
# Test var
def test_var(self):
@@ -2894,6 +2896,8 @@
self.assertEqual(two_three_keepdim_var, two_three_dim_keepvar_cpu)
helper((4, 5, 6, 7))
+ # verify if a change in shape of input would cause problems with graph caching
+ helper((9, 5, 6, 7))
# Test forward amax
def test_amax(self):