[MPS] Fix memory leak in copy_from_mps_ (#114197)
Fix it by always calling `[destBuffer release]` before leaving the scope in which the buffer was allocated.
The leak was introduced by https://github.com/pytorch/pytorch/pull/84928
Add a regression test, `test_copy_cast_no_leak`.
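The pattern behind the fix can be modelled in Python, the language of the regression test. This is a hypothetical sketch and not PyTorch's Objective-C++ code. `FakeDeviceAllocator` and `copy_with_cast` are made-up names, and `try`/`finally` stands in for the explicit `[destBuffer release]` call, so the temporary buffer is released on every path out of the scope, including an error path.

```python
class FakeDeviceAllocator:
    """Hypothetical allocator that only tracks bytes currently held."""

    def __init__(self) -> None:
        self.allocated = 0

    def new_buffer(self, nbytes: int) -> int:
        self.allocated += nbytes
        return nbytes  # in this sketch a "buffer" is just its size

    def release(self, buf: int) -> None:
        self.allocated -= buf


def copy_with_cast(alloc: FakeDeviceAllocator, nbytes: int,
                   needs_cast: bool, fail: bool = False) -> str:
    """Hypothetical copy that needs a temporary buffer only when casting."""
    if not needs_cast:
        return "direct copy"  # no temporary buffer on this path
    dest_buffer = alloc.new_buffer(nbytes)
    try:
        # ... cast into dest_buffer, then copy it to the destination ...
        if fail:
            raise RuntimeError("simulated failure during the cast")
        return "cast copy"
    finally:
        # Runs on every exit path, playing the role of [destBuffer release].
        alloc.release(dest_buffer)
```

Without the `finally` clause, every call with `needs_cast=True` would leave `alloc.allocated` higher than before, which is the kind of per-call growth the regression test looks for.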
Before the change:
```
% python ../test/test_mps.py -v -k test_copy_cast_no_leak --repeat 10
test_copy_cast_no_leak (__main__.TestMemoryLeak) ... FAIL
======================================================================
FAIL: test_copy_cast_no_leak (__main__.TestMemoryLeak)
----------------------------------------------------------------------
Traceback (most recent call last):
  File "/Users/nshulga/git/pytorch/pytorch/torch/testing/_internal/common_utils.py", line 2554, in wrapper
    method(*args, **kwargs)
  File "/Users/nshulga/git/pytorch/pytorch/build/../test/test_mps.py", line 1064, in test_copy_cast_no_leak
    self.assertTrue(driver_before == driver_after, f"Detected {driver_after-driver_before} bytes leak of GPU memory")
AssertionError: False is not true : Detected 65536 bytes leak of GPU memory
To execute this test, run the following from the base repo dir:
python test/test_mps.py -k test_copy_cast_no_leak
This message can be suppressed by setting PYTORCH_PRINT_REPRO_ON_FAILURE=0
----------------------------------------------------------------------
Ran 1 test in 1.102s
FAILED (failures=1)
```
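The 65536 bytes reported above can be compared with the sizes of the tensors the test creates, using the standard element sizes of `torch.float16` (2 bytes) and `torch.float32` (4 bytes). The leak equals exactly one `float32` copy of the 128 x 128 tensor. That is consistent with a leaked `float32` staging buffer for the cast, although the log alone does not prove it.

```python
# Byte sizes of the 128 x 128 tensor used by test_copy_cast_no_leak.
numel = 128 * 128
fp16_bytes = numel * 2  # torch.float16: 2 bytes per element
fp32_bytes = numel * 4  # torch.float32: 4 bytes per element
print(fp16_bytes, fp32_bytes)  # 32768 65536
```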
After:
```
% python ../test/test_mps.py -k test_copy_cast_no_leak --repeat 10
.
----------------------------------------------------------------------
Ran 1 test in 0.819s
OK
.
----------------------------------------------------------------------
Ran 1 test in 0.001s
OK
.
----------------------------------------------------------------------
Ran 1 test in 0.002s
OK
...
```
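The regression test follows a simple recipe: flush the cache, read the driver counter, run the operation, flush again, and read the counter again. A hypothetical helper, `measure_driver_leak` (not part of PyTorch), captures that recipe. The memory query and the cache flush are passed in as callables so the logic can be exercised without an MPS device.

```python
from typing import Callable


def measure_driver_leak(
    action: Callable[[], None],
    driver_allocated_memory: Callable[[], int],
    empty_cache: Callable[[], None],
) -> int:
    """Return how many bytes the driver still holds after `action` runs.

    Flushing the cache before each reading means memory that is merely
    cached for reuse is not counted, only memory that is still allocated.
    """
    empty_cache()
    before = driver_allocated_memory()
    action()
    empty_cache()
    after = driver_allocated_memory()
    return after - before
```

On a Mac with MPS, it could plausibly be called as `measure_driver_leak(lambda: x.to(device='cpu', dtype=torch.float32), torch.mps.driver_allocated_memory, torch.mps.empty_cache)` for an MPS tensor `x`. That usage is a sketch and has not been run here.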
Fixes https://github.com/pytorch/pytorch/issues/114096
Pull Request resolved: https://github.com/pytorch/pytorch/pull/114197
Approved by: https://github.com/kit1980
diff --git a/test/test_mps.py b/test/test_mps.py
index 600c7bf..240977e 100644
--- a/test/test_mps.py
+++ b/test/test_mps.py
@@ -1053,6 +1053,16 @@
         with self.assertRaisesRegex(RuntimeError, r"MPS driver API confirmed .+"):
             leak_gpu0()
 
+    def test_copy_cast_no_leak(self):
+        a = torch.randn(128, 128, device='mps', dtype=torch.float16)
+        torch.mps.empty_cache()
+        driver_before = torch.mps.driver_allocated_memory()
+        a = a.to(device='cpu', dtype=torch.float32)
+        a = a.to(device='mps', dtype=torch.float16)
+        torch.mps.empty_cache()
+        driver_after = torch.mps.driver_allocated_memory()
+        self.assertTrue(driver_before == driver_after, f"Detected {driver_after-driver_before} bytes leak of GPU memory")
+
 
 class TestPixelShuffle(TestCaseMPS):
     def test_pixel_shuffle_unshuffle(self):