# blob: 0c06fb69d349ed83498656f3ed57fb0a36f27c0d [file] [log] [blame]
# Owner(s): ["oncall: jit"]
import os
import sys
from itertools import product
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.testing import FileCheck
import unittest
# torchvision is an optional dependency: record whether it can be imported so
# that tests relying on it are skipped rather than failing at import time.
try:
    import torchvision
except ImportError:
    HAS_TORCHVISION = False
else:
    HAS_TORCHVISION = True

# Decorator that skips the decorated test when torchvision is unavailable.
skipIfNoTorchVision = unittest.skipIf(not HAS_TORCHVISION, "no torchvision")
# Make the helper files in test/ importable: start from the directory that
# holds this file, step up one level, and put that directory on sys.path.
pytorch_test_dir = os.path.dirname(os.path.realpath(__file__))
pytorch_test_dir = os.path.dirname(pytorch_test_dir)
sys.path.append(pytorch_test_dir)
from torch.testing._internal.jit_utils import JitTestCase
# Refuse to run as a standalone script; the message below points the user at
# the entry point that is meant to drive these tests.
if __name__ == '__main__':
    raise RuntimeError(
        "This test file is not meant to be run directly, use:\n\n"
        "\tpython test/test_jit.py TESTNAME\n\n"
        "instead."
    )
# Functional activations exercised by the tests below. Each one is called both
# as ``activation(x)`` and as ``activation(x, inplace=True)``, and its
# ``__name__`` is used to build the ``aten::<name>`` / ``aten::<name>_`` node
# names that the tests look for in the graph.
activations = [
    F.celu,
    F.elu,
    F.hardsigmoid,
    F.hardswish,
    F.hardtanh,
    F.leaky_relu,
    F.relu,
    F.relu6,
    F.rrelu,
    F.selu,
    F.silu,
]
class TestFunctionalToInplaceActivation(JitTestCase):
    """Tests for the ``functional_to_inplace_activation`` JIT graph pass.

    Each test scripts a small function or module, runs the pass on its graph
    and uses ``FileCheck`` to verify which ``aten::`` activation nodes are
    present afterwards.
    """

    def test_check_no_type_promotion(self):
        """Check that every activation returns a tensor of its input's dtype."""
        dtypes = [
            torch.bool,
            torch.int8,
            torch.int16,
            torch.int32,
            torch.int64,
            torch.float32,
            torch.float64,
        ]
        # restore_mutation.h contains a mapping from activation operators
        # to whether they allow type conversion. This check guards that
        # mapping: if a later change breaks the assumption, the mapping
        # needs to be updated correspondingly.
        for activation, dtype in product(activations, dtypes):
            # subTest makes a failure report name the offending combination.
            with self.subTest(activation=activation.__name__, dtype=dtype):
                inp = torch.normal(0, 5, size=(4, 4)).to(dtype)
                try:
                    out = activation(inp)
                except RuntimeError:
                    # Skip the not implemented error. Only the activation
                    # call is guarded, so the dtype assertion below can never
                    # be swallowed by this handler.
                    continue
                self.assertEqual(dtype, out.dtype)

    def test_functional_to_inplace_activation(self):
        """Check that the pass rewrites ``aten::<op>(`` into ``aten::<op>_``.

        For every entry of ``activations`` the scripted function must contain
        the functional node before the pass and only the in-place node after
        it, and must still produce the same result as the eager function.
        """
        for activation in activations:
            op_name = activation.__name__
            with self.subTest(activation=op_name):
                def test_basic(x):
                    y = x + 1
                    z = activation(y)
                    return z

                fn = torch.jit.script(test_basic)
                self.run_pass("inline", fn.graph)
                self.run_pass("constant_propagation", fn.graph)
                FileCheck().check(f"aten::{op_name}(").run(fn.graph)
                self.run_pass('functional_to_inplace_activation', fn.graph)
                FileCheck().check_not(f"aten::{op_name}(").run(fn.graph)
                FileCheck().check(f"aten::{op_name}_").run(fn.graph)
                inp = torch.rand([2, 2])
                self.assertEqual(fn(inp), test_basic(inp))

    def test_no_functional_to_inplace(self):
        """Check three cases in which the pass must leave the graph untouched."""
        # inplace conversion should not happen because sigmoid may
        # perform type conversion
        def test1():
            y = torch.ones([2, 2])
            z = torch.sigmoid(y)
            return z

        fn = torch.jit.script(test1)
        self.run_pass('functional_to_inplace_activation', fn.graph)
        FileCheck().check_not("aten::sigmoid_").run(fn.graph)

        # inplace conversion should not happen because y aliases
        # the input x
        def test2(x):
            y = x[0]
            z = torch.relu(y)
            return z

        fn = torch.jit.script(test2)
        self.run_pass('functional_to_inplace_activation', fn.graph)
        FileCheck().check_not("aten::relu_").run(fn.graph)

        # inplace conversion should not happen because self.x is
        # at the global scope
        class Test3(nn.Module):
            def __init__(self, x):
                super().__init__()
                self.x = x

            def forward(self):
                y = torch.relu(self.x)
                return y

        fn = torch.jit.script(Test3(torch.rand([2, 2])).eval())
        self.run_pass('functional_to_inplace_activation', fn.graph)
        FileCheck().check_not("aten::relu_").run(fn.graph)

    @skipIfNoTorchVision
    def test_resnet18_correctness(self):
        """Run the pass on a frozen, scripted ResNet-18 and compare outputs.

        The eager model and the frozen model (after the pass has been applied
        to its graph) must agree on a random input batch.
        """
        model = torchvision.models.resnet18()
        frozen_model = torch.jit.freeze(torch.jit.script(model.eval()))
        N, C, H, W = 10, 3, 224, 224
        inp = torch.randn(N, C, H, W)
        self.run_pass('functional_to_inplace_activation', frozen_model.graph)
        self.assertEqual(model(inp), frozen_model(inp))
class TestInplaceToFunctionalActivation(JitTestCase):
    """Tests for the ``inplace_to_functional_activation`` JIT graph pass.

    Each test scripts a function or model, runs the pass on its graph and
    uses ``FileCheck`` to verify that in-place activation nodes have been
    replaced by their functional counterparts.
    """

    def test_inplace_to_functional_activation(self):
        """Check that the pass rewrites ``aten::<op>_`` into ``aten::<op>(``.

        Covers both the ``inplace=True`` form of every entry in
        ``activations`` and the ``torch.relu_`` / ``torch.sigmoid_`` /
        ``torch.tanh_`` functions. In each case the rewritten function must
        still produce the same result as the eager function.
        """
        for activation in activations:
            op_name = activation.__name__
            # subTest makes a failure report name the offending activation.
            with self.subTest(activation=op_name):
                def test_basic(x):
                    y = x + 1
                    activation(y, inplace=True)
                    return y

                fn = torch.jit.script(test_basic)
                self.run_pass("inline", fn.graph)
                self.run_pass("constant_propagation", fn.graph)
                FileCheck().check(f"aten::{op_name}_").run(fn.graph)
                self.run_pass('inplace_to_functional_activation', fn.graph)
                FileCheck().check_not(f"aten::{op_name}_").run(fn.graph)
                FileCheck().check(f"aten::{op_name}(").run(fn.graph)
                # Also execute the rewritten graph, as the loop below does,
                # so that a structurally correct but numerically wrong
                # rewrite is caught.
                inp = torch.rand([2, 2])
                self.assertEqual(fn(inp), test_basic(inp))

        for activation in [
            torch.relu_,
            torch.sigmoid_,
            torch.tanh_,
        ]:
            inplace_name = activation.__name__
            with self.subTest(activation=inplace_name):
                def test_basic(x):
                    y = x + 1
                    activation(y)
                    return y

                fn = torch.jit.script(test_basic)
                self.run_pass("inline", fn.graph)
                self.run_pass("constant_propagation", fn.graph)
                FileCheck().check(f"aten::{inplace_name}").run(fn.graph)
                self.run_pass('inplace_to_functional_activation', fn.graph)
                FileCheck().check_not(f"aten::{inplace_name}").run(fn.graph)
                # Dropping the trailing underscore gives the functional name.
                FileCheck().check(f"aten::{inplace_name[:-1]}(").run(fn.graph)
                inp = torch.rand([2, 2])
                self.assertEqual(fn(inp), test_basic(inp))

    @skipIfNoTorchVision
    def test_resnet18_correctness(self):
        """Run the pass on a frozen, scripted ResNet-18 and compare outputs.

        The eager model and the frozen model (after the pass has been applied
        to its graph) must agree on a random input batch.
        """
        model = torchvision.models.resnet18()
        frozen_model = torch.jit.freeze(torch.jit.script(model.eval()))
        N, C, H, W = 10, 3, 224, 224
        inp = torch.randn(N, C, H, W)
        self.run_pass('inplace_to_functional_activation', frozen_model.graph)
        self.assertEqual(model(inp), frozen_model(inp))