Print output during MPS test import tests (#79163)
Simplify `test_no_warning_on_import` so that it just captures any output and asserts that it is empty.
Copy its implementation to `test_testing.py` as this is not specific to MPS
Pull Request resolved: https://github.com/pytorch/pytorch/pull/79163
Approved by: https://github.com/janeyx99, https://github.com/kulinseth
diff --git a/test/test_mps.py b/test/test_mps.py
index a7ae3ba..ebf56a9 100644
--- a/test/test_mps.py
+++ b/test/test_mps.py
@@ -4358,24 +4358,15 @@
self.assertEqual(cpu_weight_grad, mps_weight_grad)
class TestFallbackWarning(TestCase):
+ # TODO: Remove once test_testing.py is running on MPS devices
def test_no_warning_on_import(self):
- script = """
-import warnings
-
-with warnings.catch_warnings(record=True) as w:
- import torch
-
-exit(len(w))
-"""
- try:
- subprocess.check_output(
- [sys.executable, '-W', 'all', '-c', script],
- stderr=subprocess.STDOUT,
- # On Windows, opening the subprocess with the default CWD makes `import torch`
- # fail, so just set CWD to this script's directory
- cwd=os.path.dirname(os.path.realpath(__file__)),)
- except subprocess.CalledProcessError as e:
- self.assertTrue(False, "There was a warning when importing torch.")
+ out = subprocess.check_output(
+ [sys.executable, "-W", "all", "-c", "import torch"],
+ stderr=subprocess.STDOUT,
+ # On Windows, opening the subprocess with the default CWD makes `import torch`
+ # fail, so just set CWD to this script's directory
+ cwd=os.path.dirname(os.path.realpath(__file__)),).decode("utf-8")
+ self.assertEquals(out, "")
def _get_not_implemented_op(self):
# This can be changed once we actually implement `torch.bincount`
@@ -4402,6 +4393,7 @@
import torch
if len(w) > 0:
+ print(w)
exit(1)
# This should run just fine and raise warning about perf
@@ -4409,6 +4401,7 @@
{op}
if len(w) != 1:
+ print(w)
exit(2)
"""
@@ -4421,12 +4414,14 @@
cwd=os.path.dirname(os.path.realpath(__file__)),)
except subprocess.CalledProcessError as e:
if e.returncode == 1:
- self.assertTrue(False, "There was a warning when importing torch when PYTORCH_ENABLE_MPS_FALLBACK is set.")
+ self.assertTrue(False, "There was a warning when importing torch when PYTORCH_ENABLE_MPS_FALLBACK is set." +
+ e.output.decode("utf-8"))
elif e.returncode == 2:
self.assertTrue(False, "There wasn't exactly one warning when running not implemented op with "
- "PYTORCH_ENABLE_MPS_FALLBACK set.")
+ f"PYTORCH_ENABLE_MPS_FALLBACK set. {e.output}")
else:
- self.assertTrue(False, "Running a not implemented op failed even though PYTORCH_ENABLE_MPS_FALLBACK is set.")
+ self.assertTrue(False, "Running a not implemented op failed even though PYTORCH_ENABLE_MPS_FALLBACK is set. " +
+ e.output.decode("utf-8"))
class TestNoRegression(TestCase):
def test_assert_close(self):