| # Owner(s): ["module: unknown"] |
| |
| import glob |
| import io |
| import os |
| import unittest |
| |
| import torch |
| from torch.testing._internal.common_utils import TestCase, run_tests |
| |
| |
# The bundling helper is only importable from a PyTorch source checkout;
# elsewhere it is None and the source-tree test below is skipped.
try:
    from third_party.build_bundled import create_bundled
except ImportError:
    create_bundled = None

# Checked-in aggregate of third-party licenses (path relative to the repo root).
license_file = "third_party/LICENSES_BUNDLED.txt"
# Opening phrase expected in a wheel's LICENSE when the bundle is appended.
starting_txt = "The Pytorch repository and source distributions bundle"
# Directory two levels above torch/__init__.py, i.e. the directory that
# contains the installed ``torch`` package.
site_packages = os.path.dirname(os.path.dirname(torch.__file__))
# Any wheel metadata directories for torch living next to the package.
distinfo = glob.glob(os.path.join(site_packages, "torch-*dist-info"))
| |
class TestLicense(TestCase):
    """Checks that the bundled third-party license text is current and shipped."""

    @unittest.skipIf(not create_bundled, "can only be run in a source tree")
    def test_license_for_wheel(self):
        """Regenerate the bundled-license text and compare it to the checked-in file.

        Fails when ``license_file`` no longer matches what ``create_bundled``
        produces from the ``third_party`` directory.
        """
        # NOTE(review): both paths are relative, so this presumably has to run
        # from the repository root -- confirm against the test harness.
        current = io.StringIO()
        create_bundled('third_party', current)
        with open(license_file) as fid:
            src_tree = fid.read()
        # Raise with a custom message (rather than assertEqual) so the failure
        # points directly at the command that regenerates the file.
        if src_tree != current.getvalue():
            raise AssertionError(
                f'the contents of "{license_file}" do not '
                'match the current state of the third_party files. Use '
                '"python third_party/build_bundled.py" to regenerate it')

    @unittest.skipIf(not distinfo, "no installation in site-package to test")
    def test_distinfo_license(self):
        """Check that an installed wheel's LICENSE includes the bundled licenses.

        When PyTorch is installed from a wheel, the license lives at
        ``site-packages/torch-*dist-info/LICENSE``. This test asserts that it
        contains the third-party license bundle, identified by ``starting_txt``.

        Raises:
            AssertionError: if more than one ``torch-*dist-info`` directory is
                found, or if the LICENSE lacks the bundle text.
        """
        # More than one dist-info means the install is ambiguous; refuse to guess.
        if len(distinfo) > 1:
            raise AssertionError('Found too many "torch-*dist-info" directories '
                                 f'in "{site_packages}", expected only one')
        with open(os.path.join(distinfo[0], 'LICENSE')) as fid:
            txt = fid.read()
        # assertIn reports the missing text on failure, unlike assertTrue.
        self.assertIn(starting_txt, txt)
| |
if __name__ == "__main__":
    # Executed directly as a script: hand off to PyTorch's test runner.
    run_tests()