| #include <gtest/gtest.h> |
| |
| #include <torch/cuda.h> |
| |
| #include <iostream> |
| #include <string> |
| |
| std::string add_negative_flag(const std::string& flag) { |
| std::string filter = ::testing::GTEST_FLAG(filter); |
| if (filter.find('-') == std::string::npos) { |
| filter.push_back('-'); |
| } else { |
| filter.push_back(':'); |
| } |
| filter += flag; |
| return filter; |
| } |
| |
| int main(int argc, char* argv[]) { |
| ::testing::InitGoogleTest(&argc, argv); |
| if (!torch::cuda::is_available()) { |
| std::cout << "CUDA not available. Disabling CUDA and MultiCUDA tests" |
| << std::endl; |
| ::testing::GTEST_FLAG(filter) = add_negative_flag("*_CUDA:*_MultiCUDA"); |
| } else if (torch::cuda::device_count() < 2) { |
| std::cout << "Only one CUDA device detected. Disabling MultiCUDA tests" |
| << std::endl; |
| ::testing::GTEST_FLAG(filter) = add_negative_flag("*_MultiCUDA"); |
| } |
| |
| return RUN_ALL_TESTS(); |
| } |