| # Owner(s): ["module: cpp"] |
| |
| |
| import os |
| |
| from cpp_api_parity import ( |
| functional_impl_check, |
| module_impl_check, |
| sample_functional, |
| sample_module, |
| ) |
| from cpp_api_parity.parity_table_parser import parse_parity_tracker_table |
| from cpp_api_parity.utils import is_torch_nn_functional_test |
| |
| import torch |
| import torch.testing._internal.common_nn as common_nn |
| import torch.testing._internal.common_utils as common |
| |
| |
# NOTE: debugging aid -- set to True to print the source code of every
# generated C++ test.
PRINT_CPP_SOURCE = False

# Devices passed to the test generators; each test param dict is expanded
# once per entry in this list.
devices = ["cpu", "cuda"]

# The parity tracker markdown file lives in the `cpp_api_parity` directory
# that sits next to this file.
_CPP_API_PARITY_DIR = os.path.join(os.path.dirname(__file__), "cpp_api_parity")
PARITY_TABLE_PATH = os.path.join(_CPP_API_PARITY_DIR, "parity-tracker.md")

# Parsed parity table, handed to every `write_test_to_test_class` call below.
parity_table = parse_parity_tracker_table(PARITY_TABLE_PATH)
| |
| |
| @torch.testing._internal.common_utils.markDynamoStrictTest |
| class TestCppApiParity(common.TestCase): |
| module_test_params_map = {} |
| functional_test_params_map = {} |
| |
| |
# Test param dicts for which parity tests were written onto TestCppApiParity;
# used to sanity-check the number of generated tests.
expected_test_params_dicts = []


def _assert_num_generated_tests(substring, expected):
    """Assert that exactly *expected* attributes of TestCppApiParity contain *substring*."""
    actual = sum(substring in name for name in TestCppApiParity.__dict__)
    assert actual == expected, (
        f"expected {expected} generated tests matching {substring!r} "
        f"on TestCppApiParity, found {actual}"
    )


# Test generation, the count checks and the C++ build are all skipped on ARM64.
# The checks must live under the same guard as the generation: on ARM64 no tests
# are written, so the sample-test count checks would otherwise fail at import.
if not common.IS_ARM64:
    for test_params_dicts, test_instance_class in [
        (sample_module.module_tests, common_nn.NewModuleTest),
        (sample_functional.functional_tests, common_nn.NewModuleTest),
        (common_nn.module_tests, common_nn.NewModuleTest),
        (common_nn.new_module_tests, common_nn.NewModuleTest),
        (common_nn.criterion_tests, common_nn.CriterionTest),
    ]:
        for test_params_dict in test_params_dicts:
            # Individual test dicts can opt out of C++ API parity testing.
            if not test_params_dict.get("test_cpp_api_parity", True):
                continue
            # Functional and module tests share the same generator interface;
            # only the implementing module differs.
            impl_check = (
                functional_impl_check
                if is_torch_nn_functional_test(test_params_dict)
                else module_impl_check
            )
            impl_check.write_test_to_test_class(
                TestCppApiParity,
                test_params_dict,
                test_instance_class,
                parity_table,
                devices,
            )
            expected_test_params_dicts.append(test_params_dict)

    # Every non-skipped NN module/functional test dict must appear in the parity
    # test once per device.
    _assert_num_generated_tests(
        "test_torch_nn_", len(expected_test_params_dicts) * len(devices)
    )

    # There must be auto-generated tests for `SampleModule` and
    # `sample_functional`: 2 non-skipped test dicts each, times the device count.
    _assert_num_generated_tests("SampleModule", 2 * len(devices))
    _assert_num_generated_tests("sample_functional", 2 * len(devices))

    module_impl_check.build_cpp_tests(
        TestCppApiParity, print_cpp_source=PRINT_CPP_SOURCE
    )
    functional_impl_check.build_cpp_tests(
        TestCppApiParity, print_cpp_source=PRINT_CPP_SOURCE
    )
| |
if __name__ == "__main__":
    # Script entry point. Set the TestCase default-dtype check flag before
    # handing control to the shared PyTorch runner (presumably it verifies that
    # tests leave the default dtype unchanged -- confirm in common_utils).
    common.TestCase._default_dtype_check_enabled = True
    common.run_tests()