blob: f522c37e178948069c59f5505d2213cf484e1a1c [file] [log] [blame] [edit]
import os
import sys
from tempfile import NamedTemporaryFile
import torch.package.package_exporter
from torch.testing._internal.common_utils import IS_WINDOWS, TestCase
class PackageTestCase(TestCase):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._temporary_files = []
def temp(self):
t = NamedTemporaryFile()
name = t.name
if IS_WINDOWS:
t.close() # can't read an open file in windows
else:
self._temporary_files.append(t)
return name
def setUp(self):
"""Add test/package/ to module search path. This ensures that
importing our fake packages via, e.g. `import package_a` will always
work regardless of how we invoke the test.
"""
super().setUp()
self.package_test_dir = os.path.dirname(os.path.realpath(__file__))
self.orig_sys_path = sys.path.copy()
sys.path.append(self.package_test_dir)
torch.package.package_exporter._gate_torchscript_serialization = False
def tearDown(self):
super().tearDown()
sys.path = self.orig_sys_path
# remove any temporary files
for t in self._temporary_files:
t.close()
self._temporary_files = []