# Owner(s): ["oncall: jit"]

import unittest
from itertools import product

import torch
from torch.jit._passes._property_propagation import apply_input_props_using_example
from torch.testing._internal.common_utils import TEST_CUDA
from torch.testing._internal.jit_utils import JitTestCase

try:
    from torchvision import models
except ImportError:
    models = None

if __name__ == "__main__":
    raise RuntimeError(
        "This test file is not meant to be run directly, use:\n\n"
        "\tpython test/test_jit.py TESTNAME\n\n"
        "instead."
    )

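# These tests exercise TorchScript's device propagation pass
# (torch._C._jit_pass_propagate_device), which infers the torch.device of
# each value in a graph from the device annotations on the graph inputs.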
class TestDeviceAnalysis(JitTestCase):
    @classmethod
    def setUpClass(cls):
        cls.cpu = torch.device("cpu")
        cls.cuda = torch.device("cuda")
        cls.vulkan = torch.device("vulkan")
        # MKLDNN can't mix with other device types at all
        cls.mkldnn = torch.device("mkldnn")
        cls.device_types = [cls.cpu, cls.cuda, cls.vulkan]

    @staticmethod
    def node_output_device(graph):
        graph_out = list(graph.outputs())
        assert len(graph_out) == 1
        return graph_out[0].type().device()

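    # Annotate the graph inputs with example devices (and shapes, if given),
    # then run shape and device propagation over the graph.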
    def prop_device_on_graph(self, graph, example_devices, in_shapes=None):
        graph_inputs = list(graph.inputs())
        torch._C._jit_pass_erase_shape_information(graph)

        self.assertEqual(len(graph_inputs), len(example_devices))
        for graph_i, device_i in zip(graph_inputs, example_devices):
            if device_i is not None:
                graph_i.setType(graph_i.type().with_device(device_i))

        if in_shapes:
            for graph_i, shapes_i in zip(graph_inputs, in_shapes):
                if shapes_i is not None:
                    graph_i.setType(graph_i.type().with_sizes(shapes_i))

            torch._C._jit_pass_propagate_shapes_on_graph(graph)

        torch._C._jit_pass_propagate_device(graph)

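    # Script `fn`, propagate `in_devices` through its graph, and check that
    # the device inferred for the graph output matches `expected_device`.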
    def assert_device_equal(
        self, fn, in_devices, expected_device, in_shapes=None, subtest_str=""
    ):
        with self.subTest(
            f"In device: {in_devices}, expected: {expected_device}, \n {subtest_str}"
        ):
            graph = torch.jit.script(fn).graph
            self.prop_device_on_graph(graph, in_devices, in_shapes)
            actual_device = self.node_output_device(graph)

            if expected_device is None or actual_device is None:
                self.assertEqual(actual_device, expected_device)
            else:
                self.assertEqual(
                    actual_device.type, expected_device.type, "Failed Verification"
                )

    def test_device_apply(self):
        # Test if the device is properly applied to the input
        def add_self(x):
            return x + x

        graph = torch.jit.script(add_self).graph
        graph_input = next(graph.inputs())
        graph_input.setType(graph_input.type().with_device(self.cpu))
        # self.prop_device_on_graph(graph, [self.cpu])
        self.assertEqual(graph_input.type().device(), self.cpu)

    @unittest.skipIf(models is None, "Requires torchvision")
    def test_mobilenet(self):
        in_cpu = torch.randn(1, 3, 224, 224, device=self.cpu)
        in_example = in_cpu
        expected_device = self.cpu

        m = torch.jit.script(models.mobilenet_v3_small())
        m.eval()
        graph = torch.jit.freeze(m).graph
        # torch._C._jit_pass_erase_shape_information(graph)
        apply_input_props_using_example(graph, in_example)
        torch._C._jit_pass_propagate_shapes_on_graph(graph)
        torch._C._jit_pass_propagate_device(graph)

        actual_device = self.node_output_device(graph)
        if expected_device is None or actual_device is None:
            self.assertEqual(actual_device, expected_device)
        else:
            self.assertEqual(
                actual_device.type, expected_device.type, "Failed Verification"
            )

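    # Elementwise and in-place unary ops should simply preserve the input
    # device.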
    def test_simple(self):
        def add_self(x):
            return x + x

        def relu_(x):
            return torch.nn.functional.relu_(x)

        functions = [add_self, relu_]
        for in_device, fn in product(self.device_types, functions):
            self.assert_device_equal(fn, [in_device], in_device)

    def test_set_device(self):
        def set_device(x):
            return x.to("cpu")

        for in_device in self.device_types:
            self.assert_device_equal(set_device, [in_device], self.cpu)

    def test_device_arg(self):
        # No device should be propagated when the target device is only
        # known at runtime, as an argument.
        def set_device(x, device_name: torch.device):
            return x.to(device=device_name)

        for in_device in self.device_types:
            self.assert_device_equal(set_device, [in_device, None], None)

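    # view_as/expand_as/reshape_as take only the *shape* from the second
    # tensor, so the output device follows the first input; type_as copies
    # the second tensor's type, so the output device follows the second.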
    def test_tensor_as_fns(self):
        def view_as_fn(x, y):
            return x.view_as(y)

        def expand_as_fn(x, y):
            return x.expand_as(y)

        def reshape_as_fn(x, y):
            return x.reshape_as(y)

        for test_fn in [view_as_fn, expand_as_fn, reshape_as_fn]:
            self.assert_device_equal(test_fn, [self.cpu, self.cpu], self.cpu)
            self.assert_device_equal(test_fn, [self.cuda, None], self.cuda)
            self.assert_device_equal(test_fn, [None, self.mkldnn], None)

        def type_as_fn(x, y):
            return x.type_as(y)

        self.assert_device_equal(type_as_fn, [self.cpu, self.cpu], self.cpu)
        self.assert_device_equal(type_as_fn, [self.cuda, None], None)
        self.assert_device_equal(type_as_fn, [None, self.mkldnn], self.mkldnn)

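    # Zerodim (scalar) CPU tensors can be mixed with tensors on other devices
    # in eager mode, so device propagation must stay conservative when shapes
    # are unknown; the expected device below is taken from an eager run.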
    def zerodim_test_core(self, device_pairs):
        # Test the support of zerodim tensors with non-zerodim tensors
        def mul(x, y):
            return x * y

        def add(x, y):
            return x + y

        fns = [mul, add]
        input_shapes = [
            ((1, 2, 2), (2, 2)),  # Different dims, non-zerodim
            ((1, 2, 2), ()),  # one zerodim
            ((), ()),  # both zerodim
        ]

        for fn, shapes, devices in product(fns, input_shapes, device_pairs):
            subtest_str = f"{fn.__name__} \n shapes: {shapes}, \n devices: {devices}"
            in0 = torch.rand(shapes[0], device=devices[0])
            in1 = torch.rand(shapes[1], device=devices[1])

            try:
                out = fn(in0, in1)
            except Exception as e:
                # Don't expect eager failures for CPU zerodim tensors
                for i in range(len(devices)):
                    if shapes[i] == () and devices[i] == self.cpu:
                        raise e

                # Only expect eager failures on mismatched devices
                if devices[0] == devices[1]:
                    raise e

                # Expect the result device to be None for the failure cases
                self.assert_device_equal(fn, devices, None, shapes, subtest_str)
                continue

            self.assert_device_equal(fn, devices, out.device, shapes, subtest_str)

            # Test that without shapes, we either get the same device or None,
            # i.e. that the analysis is conservative when tensor shapes are
            # unknown.
            graph = torch.jit.script(fn).graph
            self.prop_device_on_graph(graph, devices)
            actual_device = self.node_output_device(graph)
            self.assertTrue(
                (actual_device is None) or (actual_device.type == out.device.type)
            )

    def test_zerodim_cpu(self):
        # Allow for minimal testing locally
        self.zerodim_test_core([(self.cpu, self.cpu)])

    def test_zerodim_no_device(self):
        # If a device is missing from any input, the output device should
        # never be inferred.
        def mul(x, y):
            return x * y

        def add(x, y):
            return x + y

        fns = [mul, add]
        device_pairs = [
            (self.cpu, None),
            (None, self.cpu),
            (None, None),
        ]
        input_shapes = [
            ((1, 2, 2), (2, 2)),  # Different dims, non-zerodim
            ((1, 2, 2), ()),  # one zerodim
            ((), ()),  # both zerodim
        ]

        for fn, shapes, devices in product(fns, input_shapes, device_pairs):
            self.assert_device_equal(fn, devices, None, shapes)

    @unittest.skipIf(not TEST_CUDA, "No CUDA")
    def test_zerodim_gpu(self):
        device_pairs = [
            (self.cpu, self.cuda),
            (self.cuda, self.cpu),
            (self.cuda, self.cuda),
        ]
        self.zerodim_test_core(device_pairs)

    def test_custom_device_op(self):
        # Test ops that move a tensor to a fixed device (cuda/cpu/mkldnn)
        # and check that the device type is correctly applied, regardless of
        # the input device.
        def set_cuda(x):
            return x.cuda()

        def set_cpu(x):
            return x.cpu()

        def set_mkldnn(x):
            return x.to_mkldnn()

        device_pairs = (
            (set_cuda, self.cuda),
            (set_cpu, self.cpu),
            (set_mkldnn, self.mkldnn),
        )

        for fn, out_device in device_pairs:
            for in_device in self.device_types:
                self.assert_device_equal(fn, [in_device], out_device)

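    # Control flow: an if-node's output device is only known when both
    # branches agree on the device.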
    def test_device_if_propagation(self):
        def test_fn(x, y, z: bool):
            if z:
                return x + 3
            else:
                return y * 2

        self.assert_device_equal(test_fn, [self.cpu, self.cpu, None], self.cpu)
        self.assert_device_equal(test_fn, [self.mkldnn, self.mkldnn, None], self.mkldnn)
        self.assert_device_equal(test_fn, [self.cpu, self.cuda, None], None)

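    # Loops may run zero times, so the loop output device is only known when
    # the zero-iteration value and the loop-carried value agree.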
    def test_loop_simple(self):
        def test_fn(x, y, z: int):
            for _ in range(z):
                y = x

            return y

        self.assert_device_equal(test_fn, [self.cpu, self.cpu, None], self.cpu)
        self.assert_device_equal(test_fn, [self.cpu, self.cuda, None], None)
        self.assert_device_equal(test_fn, [self.cpu, None, None], None)

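    # A loop body that moves the tensor to CUDA yields a known output device
    # only when the input already lives on CUDA, which covers the
    # zero-iteration case.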
    def test_loop_device_change(self):
        def test_fn(x, z: int):
            for _ in range(z):
                x = x.cuda()

            return x

        self.assert_device_equal(test_fn, [self.cpu, None], None)
        self.assert_device_equal(test_fn, [self.cuda, None], self.cuda)
        self.assert_device_equal(test_fn, [None, None], None)

    def test_while_change(self):
        def test_fn(x, z: int):
            while z > 0:
                x = x.cuda()
                z = 0

            return x

        self.assert_device_equal(test_fn, [self.cpu, None], None)
        self.assert_device_equal(test_fn, [self.cuda, None], self.cuda)
        self.assert_device_equal(test_fn, [None, None], None)

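    # Here the outer loop pins x to CPU via x.cpu(), so the result is only
    # known to be CPU when the zero-iteration path starts on CPU as well.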
    def test_nested_loops(self):
        def test_fn(x, z: int):
            for i in range(z):
                x = x.cpu()
                for _ in range(i):
                    x = x + 1

            return x

        self.assert_device_equal(test_fn, [self.cpu, None], self.cpu)
        self.assert_device_equal(test_fn, [self.cuda, None], None)
        self.assert_device_equal(test_fn, [None, None], None)

    def test_if_loop_mix(self):
        def test_fn(x, y, z: bool, a: bool):
            c = x
            while a:
                if z:
                    c = x + 3
                else:
                    c = y * 2
                a = False

            return c

        self.assert_device_equal(test_fn, [self.cpu, self.cpu, None, None], self.cpu)
        self.assert_device_equal(
            test_fn, [self.mkldnn, self.mkldnn, None, None], self.mkldnn
        )
        self.assert_device_equal(test_fn, [self.cpu, self.cuda, None, None], None)