# Owner(s): ["oncall: jit"]

import os
import sys

from itertools import product
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.testing import FileCheck
import unittest

try:
    import torchvision
    HAS_TORCHVISION = True
except ImportError:
    HAS_TORCHVISION = False
skipIfNoTorchVision = unittest.skipIf(not HAS_TORCHVISION, "no torchvision")

# Make the helper files in test/ importable
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
sys.path.append(pytorch_test_dir)
from torch.testing._internal.jit_utils import JitTestCase

if __name__ == '__main__':
    raise RuntimeError("This test file is not meant to be run directly, use:\n\n"
                       "\tpython test/test_jit.py TESTNAME\n\n"
                       "instead.")

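# Each functional activation below also has an in-place aten variant (written
# with a trailing underscore, e.g. aten::relu vs. aten::relu_); the JIT passes
# exercised in this file rewrite graphs between the two forms.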
activations = [
    F.celu,
    F.elu,
    F.hardsigmoid,
    F.hardswish,
    F.hardtanh,
    F.leaky_relu,
    F.relu,
    F.relu6,
    F.rrelu,
    F.selu,
    F.silu,
]

class TestFunctionalToInplaceActivation(JitTestCase):
    def test_check_no_type_promotion(self):
        dtypes = [
            torch.bool,
            torch.int8,
            torch.int16,
            torch.int32,
            torch.int64,
            torch.float32,
            torch.float64,
        ]
        # restore_mutation.h contains a mapping from each activation operator
        # to whether it allows type promotion. This check guards that mapping:
        # if a later change breaks the assumption that these activations
        # preserve the input dtype, the mapping must be updated accordingly.
        for activation, dtype in product(activations, dtypes):
            inp = torch.normal(0, 5, size=(4, 4)).to(dtype)
            try:
                out = activation(inp)
                self.assertEqual(dtype, out.dtype)
            except RuntimeError:
                # Skip dtypes for which this activation is not implemented
                pass

    def test_functional_to_inplace_activation(self):
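        # The functional_to_inplace_activation pass rewrites a functional
        # activation call into its in-place variant when doing so is safe
        # (the unsafe cases are covered by test_no_functional_to_inplace
        # below). Roughly, as an IR sketch (not checked verbatim here):
        #   before: %z = aten::relu(%y)
        #   after:  %z = aten::relu_(%y)
        # The numeric check at the end verifies the rewrite preserves results.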
        for activation in activations:
            def test_basic(x):
                y = x + 1
                z = activation(y)
                return z

            fn = torch.jit.script(test_basic)
            self.run_pass("inline", fn.graph)
            self.run_pass("constant_propagation", fn.graph)
            FileCheck().check(f"aten::{activation.__name__}(").run(fn.graph)
            self.run_pass('functional_to_inplace_activation', fn.graph)
            FileCheck().check_not(f"aten::{activation.__name__}(").run(fn.graph)
            FileCheck().check(f"aten::{activation.__name__}_").run(fn.graph)
            inp = torch.rand([2, 2])
            self.assertEqual(fn(inp), test_basic(inp))

    def test_no_functional_to_inplace(self):
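        # Each case below violates one of the pass's safety conditions, so
        # the functional call must be left untouched.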
        # In-place conversion should not happen because sigmoid may
        # perform a type conversion (its output dtype can differ from
        # its input dtype).
        def test1():
            y = torch.ones([2, 2])
            z = torch.sigmoid(y)
            return z

        fn = torch.jit.script(test1)
        self.run_pass('functional_to_inplace_activation', fn.graph)
        FileCheck().check_not("aten::sigmoid_").run(fn.graph)

        # In-place conversion should not happen because y aliases
        # the graph input x.
        def test2(x):
            y = x[0]
            z = torch.relu(y)
            return z

        fn = torch.jit.script(test2)
        self.run_pass('functional_to_inplace_activation', fn.graph)
        FileCheck().check_not("aten::relu_").run(fn.graph)

        # In-place conversion should not happen because self.x is a
        # module attribute that is visible beyond the scope of the graph.
        class Test3(nn.Module):
            def __init__(self, x):
                super().__init__()
                self.x = x

            def forward(self):
                y = torch.relu(self.x)
                return y

        fn = torch.jit.script(Test3(torch.rand([2, 2])).eval())
        self.run_pass('functional_to_inplace_activation', fn.graph)
        FileCheck().check_not("aten::relu_").run(fn.graph)

    @skipIfNoTorchVision
    def test_resnet18_correctness(self):
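        # Running the pass on a frozen ResNet-18 must not change its output
        # relative to the original eager-mode model.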
        model = torchvision.models.resnet18()
        frozen_model = torch.jit.freeze(torch.jit.script(model.eval()))
        N, C, H, W = 10, 3, 224, 224
        inp = torch.randn(N, C, H, W)
        self.run_pass('functional_to_inplace_activation', frozen_model.graph)
        self.assertEqual(model(inp), frozen_model(inp))


class TestInplaceToFunctionalActivation(JitTestCase):
    def test_inplace_to_functional_activation(self):
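        # The inplace_to_functional_activation pass replaces each in-place
        # activation with its out-of-place form; later uses of the mutated
        # tensor are then expected to see the activation's output. Roughly,
        # as an IR sketch (not checked verbatim here):
        #   before: aten::relu_(%y)      ... later uses of %y
        #   after:  %z = aten::relu(%y)  ... later uses read %z instead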
        for activation in activations:
            def test_basic(x):
                y = x + 1
                activation(y, inplace=True)
                return y

            fn = torch.jit.script(test_basic)
            self.run_pass("inline", fn.graph)
            self.run_pass("constant_propagation", fn.graph)
            FileCheck().check(f"aten::{activation.__name__}_").run(fn.graph)
            self.run_pass('inplace_to_functional_activation', fn.graph)
            FileCheck().check_not(f"aten::{activation.__name__}_").run(fn.graph)
            FileCheck().check(f"aten::{activation.__name__}(").run(fn.graph)

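        # Also cover torch-level in-place ops whose functional counterpart is
        # obtained by dropping the trailing underscore from the name.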
        for activation in [
            torch.relu_,
            torch.sigmoid_,
            torch.tanh_,
        ]:
            def test_basic(x):
                y = x + 1
                activation(y)
                return y

            fn = torch.jit.script(test_basic)
            self.run_pass("inline", fn.graph)
            self.run_pass("constant_propagation", fn.graph)
            FileCheck().check(f"aten::{activation.__name__}").run(fn.graph)
            self.run_pass('inplace_to_functional_activation', fn.graph)
            FileCheck().check_not(f"aten::{activation.__name__}").run(fn.graph)
            FileCheck().check(f"aten::{activation.__name__[:-1]}(").run(fn.graph)

            inp = torch.rand([2, 2])
            self.assertEqual(fn(inp), test_basic(inp))

    @skipIfNoTorchVision
    def test_resnet18_correctness(self):
        model = torchvision.models.resnet18()
        frozen_model = torch.jit.freeze(torch.jit.script(model.eval()))
        N, C, H, W = 10, 3, 224, 224
        inp = torch.randn(N, C, H, W)
        self.run_pass('inplace_to_functional_activation', frozen_model.graph)
        self.assertEqual(model(inp), frozen_model(inp))