| # Owner(s): ["module: mtia"] |
| |
| import os |
| import shutil |
| import sys |
| import tempfile |
| import unittest |
| |
| import torch |
| import torch.testing._internal.common_utils as common |
| import torch.utils.cpp_extension |
| from torch.testing._internal.common_utils import ( |
| IS_ARM64, |
| IS_LINUX, |
| skipIfTorchDynamo, |
| TEST_CUDA, |
| TEST_PRIVATEUSE1, |
| TEST_XPU, |
| ) |
| from torch.utils.cpp_extension import CUDA_HOME, ROCM_HOME |
| |
| |
# Backend flags consumed by the skip condition on the MTIA test class below.
# TEST_ROCM: the imported TEST_CUDA flag is set, torch is a HIP build
# (torch.version.hip is not None), and ROCM_HOME was found.
# Define TEST_ROCM before changing TEST_CUDA: it must read the value imported
# from common_utils, not the narrowed value assigned on the next line.
TEST_ROCM = TEST_CUDA and torch.version.hip is not None and ROCM_HOME is not None
# TEST_CUDA: narrowed to additionally require CUDA_HOME to be set.
TEST_CUDA = TEST_CUDA and CUDA_HOME is not None
| |
| |
def remove_build_path():
    """Delete the default cpp_extension build root if it exists.

    Does nothing on Windows, where the extensions build folder is left alone.
    Errors during deletion are ignored.
    """
    # Not wiping extensions build folder because Windows
    keep_build_folder = sys.platform == "win32"
    if not keep_build_folder:
        root = torch.utils.cpp_extension.get_default_build_root()
        if os.path.exists(root):
            shutil.rmtree(root, ignore_errors=True)
| |
| |
@unittest.skipIf(
    IS_ARM64 or not IS_LINUX or TEST_CUDA or TEST_PRIVATEUSE1 or TEST_ROCM or TEST_XPU,
    "Only on linux platform and mutual exclusive to other backends",
)
@torch.testing._internal.common_utils.markDynamoStrictTest
class TestCppExtensionMTIABackend(common.TestCase):
    """Tests MTIA backend with C++ extensions.

    ``setUpClass`` builds and loads ``cpp_extensions/mtia_extension.cpp`` (a
    fake MTIA device guard implementation). The tests then exercise the
    ``torch.mtia`` stream and device APIs against it.
    """

    # Set by setUpClass to the return value of torch.utils.cpp_extension.load().
    module = None

    def setUp(self):
        super().setUp()
        # cpp extensions use relative paths. Those paths are relative to
        # this file, so we'll change the working directory temporarily
        self.old_working_dir = os.getcwd()
        os.chdir(os.path.dirname(os.path.abspath(__file__)))

    def tearDown(self):
        # Return the working directory (see setUp). Do it in `finally` so a
        # failing base-class tearDown cannot leak the chdir into later tests.
        try:
            super().tearDown()
        finally:
            os.chdir(self.old_working_dir)

    @classmethod
    def tearDownClass(cls):
        remove_build_path()
        super().tearDownClass()

    @classmethod
    def setUpClass(cls):
        super().setUpClass()
        remove_build_path()
        build_dir = tempfile.mkdtemp()
        # Remove the private build directory when the class finishes. Unlike
        # tearDownClass, class cleanups also run if setUpClass itself fails
        # (e.g. the extension does not compile).
        cls.addClassCleanup(shutil.rmtree, build_dir, ignore_errors=True)
        # setUpClass runs before setUp's chdir, so resolve the extension
        # sources against this file's directory rather than the current cwd.
        cpp_dir = os.path.join(
            os.path.dirname(os.path.abspath(__file__)), "cpp_extensions"
        )
        # Load the fake device guard impl.
        cls.module = torch.utils.cpp_extension.load(
            name="mtia_extension",
            sources=[os.path.join(cpp_dir, "mtia_extension.cpp")],
            build_directory=build_dir,
            extra_include_paths=[
                cpp_dir,
                # Presumably here to check that include paths containing
                # spaces and quotes are passed through to the build intact.
                "path / with spaces in it",
                "path with quote'",
            ],
            is_python_module=False,
            verbose=True,
        )

    @skipIfTorchDynamo("Not a TorchDynamo suitable test")
    def test_get_device_module(self):
        device = torch.device("mtia:0")
        default_stream = torch.get_device_module(device).current_stream()
        self.assertEqual(
            default_stream.device_type, int(torch._C._autograd.DeviceType.MTIA)
        )

    @skipIfTorchDynamo("Not a TorchDynamo suitable test")
    def test_stream_basic(self):
        default_stream = torch.mtia.current_stream()
        user_stream = torch.mtia.Stream()
        self.assertEqual(torch.mtia.current_stream(), default_stream)
        self.assertNotEqual(default_stream, user_stream)
        # Check mtia_extension.cpp, default stream id starts from 0.
        self.assertEqual(default_stream.stream_id, 0)
        self.assertNotEqual(user_stream.stream_id, 0)
        with torch.mtia.stream(user_stream):
            self.assertEqual(torch.mtia.current_stream(), user_stream)
        self.assertTrue(user_stream.query())
        default_stream.synchronize()
        self.assertTrue(default_stream.query())

    @skipIfTorchDynamo("Not a TorchDynamo suitable test")
    def test_stream_context(self):
        mtia_stream_0 = torch.mtia.Stream(device="mtia:0")
        mtia_stream_1 = torch.mtia.Stream(device="mtia:0")
        with torch.mtia.stream(mtia_stream_0):
            current_stream = torch.mtia.current_stream()
            msg = f"current_stream {current_stream} should be {mtia_stream_0}"
            self.assertEqual(current_stream, mtia_stream_0, msg=msg)

        with torch.mtia.stream(mtia_stream_1):
            current_stream = torch.mtia.current_stream()
            msg = f"current_stream {current_stream} should be {mtia_stream_1}"
            self.assertEqual(current_stream, mtia_stream_1, msg=msg)

    @skipIfTorchDynamo("Not a TorchDynamo suitable test")
    def test_stream_context_different_device(self):
        device_0 = torch.device("mtia:0")
        device_1 = torch.device("mtia:1")
        mtia_stream_0 = torch.mtia.Stream(device=device_0)
        mtia_stream_1 = torch.mtia.Stream(device=device_1)
        orig_current_device = torch.mtia.current_device()
        # Entering a stream context on another device should switch the
        # current device, and leaving it should restore the original one.
        with torch.mtia.stream(mtia_stream_0):
            current_stream = torch.mtia.current_stream()
            self.assertEqual(torch.mtia.current_device(), device_0.index)
            msg = f"current_stream {current_stream} should be {mtia_stream_0}"
            self.assertEqual(current_stream, mtia_stream_0, msg=msg)
        self.assertEqual(torch.mtia.current_device(), orig_current_device)
        with torch.mtia.stream(mtia_stream_1):
            current_stream = torch.mtia.current_stream()
            self.assertEqual(torch.mtia.current_device(), device_1.index)
            msg = f"current_stream {current_stream} should be {mtia_stream_1}"
            self.assertEqual(current_stream, mtia_stream_1, msg=msg)
        self.assertEqual(torch.mtia.current_device(), orig_current_device)

    @skipIfTorchDynamo("Not a TorchDynamo suitable test")
    def test_device_context(self):
        device_0 = torch.device("mtia:0")
        device_1 = torch.device("mtia:1")
        with torch.mtia.device(device_0):
            self.assertEqual(torch.mtia.current_device(), device_0.index)

        with torch.mtia.device(device_1):
            self.assertEqual(torch.mtia.current_device(), device_1.index)
| |
| |
# When executed directly as a script, run this file's tests through PyTorch's
# common test runner.
if __name__ == "__main__":
    common.run_tests()