| # Owner(s): ["module: unknown"] |
| |
| import glob |
| import io |
| import os |
| import unittest |
| |
| import torch |
| from torch.testing._internal.common_utils import run_tests, TestCase |
| |
| |
# create_bundled comes from the repository's third_party package. When that
# package is not importable, create_bundled is set to None, which makes
# test_license_for_wheel skip itself.
try:
    from third_party.build_bundled import create_bundled
except ImportError:
    create_bundled = None

# Relative path of the checked-in bundled-license file; it is opened as-is, so
# it is resolved against the current working directory.
license_file = "third_party/LICENSES_BUNDLED.txt"
# Text that test_distinfo_license expects to find in the installed LICENSE.
starting_txt = "The PyTorch repository and source distributions bundle"
# Parent directory of the directory that contains torch.__file__.
site_packages = os.path.dirname(os.path.dirname(torch.__file__))
# Every "torch-*dist-info" entry directly under site_packages (may be empty).
distinfo = glob.glob(os.path.join(site_packages, "torch-*dist-info"))
| |
| |
class TestLicense(TestCase):
    """Checks that the bundled third-party license text is current and shipped."""

    @unittest.skipIf(not create_bundled, "can only be run in a source tree")
    def test_license_for_wheel(self):
        """Compare the checked-in ``license_file`` with freshly generated text.

        ``create_bundled`` writes its output for the ``third_party`` directory
        into an in-memory buffer; that output must equal the file on disk.
        Both paths are relative, so they resolve against the current working
        directory.

        Raises:
            AssertionError: if the file content differs from the generated text.
        """
        current = io.StringIO()
        create_bundled("third_party", current)
        with open(license_file) as fid:
            src_tree = fid.read()
        # Raise with a short, actionable message instead of a comparison
        # assertion on the two full license texts.
        if src_tree != current.getvalue():
            raise AssertionError(
                f'the contents of "{license_file}" do not '
                "match the current state of the third_party files. Use "
                '"python third_party/build_bundled.py" to regenerate it'
            )

    @unittest.skipIf(not distinfo, "no installation in site-package to test")
    def test_distinfo_license(self):
        """Check that the installed LICENSE contains the third-party bundle.

        If run when pytorch is installed via a wheel, the license will be in
        site-package/torch-*dist-info/LICENSE. Make sure it contains the third
        party bundle of licenses.

        Raises:
            AssertionError: if more than one ``torch-*dist-info`` directory is
                found, or if ``starting_txt`` is missing from the LICENSE file.
        """
        # More than one match makes it ambiguous which LICENSE to inspect.
        if len(distinfo) > 1:
            raise AssertionError(
                'Found too many "torch-*dist-info" directories '
                f'in "{site_packages}", expected only one'
            )
        with open(os.path.join(distinfo[0], "LICENSE")) as fid:
            txt = fid.read()
        # assertIn reports both operands on failure, unlike assertTrue(a in b).
        self.assertIn(starting_txt, txt)
| |
| |
# When executed as a script, hand control to run_tests (imported above from
# torch.testing._internal.common_utils).
if __name__ == "__main__":
    run_tests()