# blob: 4af895b2f894e03e398425baf5a98988d7c1aef4
# Owner(s): ["module: nvfuser"]
import torch
from torch.testing._internal.common_utils import set_default_dtype
# Pull in the nvfuser TorchScript tests when the optional `_nvfuser` package is
# importable. The star import is expected to bring `run_tests` (called by the
# script entry point) into this module's namespace -- TODO confirm against
# `_nvfuser.test_torchscript`.
try:
    from _nvfuser.test_torchscript import *  # noqa: F403,F401
except ImportError:
    # The nvfuser tests could not be imported: define a stand-in so that
    # running this file is a silent no-op rather than a NameError.
    def run_tests():
        """Do nothing; placeholder used when the nvfuser tests are unavailable."""
if __name__ == '__main__':
    # TODO: Update nvfuser to work with float default dtype
    # Until that is done, the whole run happens with double as the default
    # dtype, applied through the `set_default_dtype` context manager.
    double_default = set_default_dtype(torch.double)
    with double_default:
        run_tests()