# Owner(s): ["oncall: jit"]
import contextlib
import unittest
import os
import random
import enum
import copy
from functools import reduce
import operator
import warnings
import torch
from torch.nn import functional
from torch.profiler import profile, ProfilerActivity
from torch.testing._internal.codegen.random_topo_test import runDefaultTestWithSeed
from torch.testing._internal.common_cuda import TEST_MULTIGPU
from torch.testing._internal.common_device_type import instantiate_device_type_tests, ops, OpDTypes
from torch.testing._internal.common_jit import JitCommonTestCase
from torch.testing._internal.common_methods_invocations import op_db, SampleInput
from torch.testing._internal.common_utils import run_tests, ProfilingMode, GRAPH_EXECUTOR, TEST_WITH_ROCM, slowTest, \
is_iterable_of_tensors, freeze_rng_state
from torch.testing._internal.jit_utils import clone_inputs, get_traced_sample_variant_pairs, JitTestCase, RUN_CUDA
from torch.testing._internal.jit_metaprogramming_utils import create_traced_fn
from torch.testing import FileCheck
from jit.test_fuser_common import TestFuserCommon # noqa: F401
import itertools
import numpy as np
import math
from torch.autograd.gradcheck import gradcheck
from typing import List
RUN_NVFUSER = RUN_CUDA and not TEST_WITH_ROCM
CUDA_MAJOR, CUDA_MINOR = 0, 0
if RUN_NVFUSER and torch.version.cuda is not None:
CUDA_MAJOR, CUDA_MINOR = (int(x) for x in torch.version.cuda.split('.')[:2])
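# Configure nvFuser through environment variables before any fusion is compiled:
# enable linear/conv decomposition, disable the fallback path, fma contraction,
# and RNG unrolling, and set the kernel JIT optimization level to 0.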
os.environ['PYTORCH_NVFUSER_ENABLE'] = 'linear_decomposition,conv_decomposition'
os.environ['PYTORCH_NVFUSER_DISABLE'] = 'fallback,fma,unroll_with_rng'
os.environ['PYTORCH_NVFUSER_JIT_OPT_LEVEL'] = '0'
# TODO: enable complex when we fix the extremal cases in OpInfo
# see issue https://github.com/csarofeen/pytorch/issues/1730
# os.environ['PYTORCH_NVFUSER_ENABLE'] = 'complex'
if GRAPH_EXECUTOR == ProfilingMode.PROFILING:
torch._C._jit_set_texpr_fuser_enabled(False)
torch._C._jit_set_profiling_executor(True)
torch._C._jit_set_profiling_mode(True)
FUSION_GROUP = 'prim::CudaFusionGroup'
FUSION_GUARD = 'prim::CudaFusionGuard'
# TODO: revert disabled alias ops
ALIAS_TEST_DISABLED = True
@contextlib.contextmanager
def nvfuser_singleton_fusion(flag):
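    # Temporarily set nvFuser single-node fusion mode to `flag`; the previous
    # mode is restored when the context exits.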
old_value = torch._C._jit_set_nvfuser_single_node_mode(flag)
try:
yield
finally:
torch._C._jit_set_nvfuser_single_node_mode(old_value)
@contextlib.contextmanager
def nvfuser_horizontal_fusion(flag):
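    # Temporarily set nvFuser horizontal fusion mode to `flag`; restored on exit.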
old_value = torch._C._jit_set_nvfuser_horizontal_mode(flag)
try:
yield
finally:
torch._C._jit_set_nvfuser_horizontal_mode(old_value)
def is_pre_volta():
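    # Reduction tests are skipped on pre-Volta devices (compute capability < 7.0).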
if not RUN_NVFUSER:
return False
prop = torch.cuda.get_device_properties(torch.cuda.current_device())
return prop.major < 7
TEST_BF16 = RUN_NVFUSER and torch.cuda.is_bf16_supported()
TEST_LARGE_TENSOR = RUN_NVFUSER
if RUN_NVFUSER:
torch.ones(1).cuda() # initialize cuda context
TEST_LARGE_TENSOR = torch.cuda.get_device_properties(0).total_memory >= 12e9
class CudaFuserTestOptions:
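    # Save the JIT fuser-related flags, turn off the other fusers and autodiff
    # subgraph inlining, and enable nvFuser (with autocast); restore() puts the
    # saved values back.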
def __init__(self):
self.old_cpu_fuse = torch._C._jit_can_fuse_on_cpu()
self.old_gpu_fuse = torch._C._jit_can_fuse_on_gpu()
torch._C._jit_override_can_fuse_on_cpu(False)
torch._C._jit_override_can_fuse_on_gpu(False)
self.old_guard = torch._C._jit_set_nvfuser_guard_mode(False)
torch._C._debug_set_autodiff_subgraph_inlining(False)
self.old_value = torch._C._jit_set_autocast_mode(True)
        if RUN_CUDA:
self.old_nvfuser = torch._C._jit_set_nvfuser_enabled(True)
def restore(self):
        if RUN_CUDA:
torch._C._jit_set_nvfuser_enabled(self.old_nvfuser)
torch._C._jit_override_can_fuse_on_cpu(self.old_cpu_fuse)
torch._C._jit_override_can_fuse_on_gpu(self.old_gpu_fuse)
torch._C._jit_set_nvfuser_guard_mode(self.old_guard)
torch._C._debug_set_autodiff_subgraph_inlining(True)
torch._C._jit_set_autocast_mode(self.old_value)
class TestCudaFuser(JitTestCase):
def assertEqual(self, *args, **kwargs):
kwargs["exact_layout"] = True
super(JitTestCase, self).assertEqual(*args, **kwargs)
def _getSubgraphInFusion(self, graph):
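        # Walk the graph (including nested blocks), assert that exactly one
        # CudaFusionGroup node is present, and return its subgraph.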
num_node = 0
subgraph = None
def count(block, ret):
for n in block.nodes():
if n.kind() == FUSION_GROUP:
ret[0] = ret[0] + 1
self.assertTrue(n.hasAttribute('Subgraph'))
ret[1] = n.g('Subgraph')
for block in n.blocks():
count(block, ret)
ret = [num_node, subgraph]
count(graph, ret)
self.assertEqual(ret[0], 1)
return ret[1]
def setUp(self):
super(TestCudaFuser, self).setUp()
self.skip_node_list = []
disabled_ops = ("aten::batch_norm",
"aten::_batch_norm_impl_index",
"aten::_batch_norm_impl_index_backward",
"aten::native_batch_norm_backward")
for op in disabled_ops:
disabled_flag = torch._C._jit_set_nvfuser_skip_node_kind(op, False)
if disabled_flag:
torch._C._jit_set_nvfuser_skip_node_kind(op, True)
self.skip_node_list.append(op)
        # fall back to CPU to avoid errors in case this is run on a CPU-only machine
dev = 'cuda' if RUN_NVFUSER else 'cpu'
self.special_values = torch.tensor(
[float("-inf"), -10, -math.pi,
-1, -0.5, 0, 1, 0.5,
math.pi, 10, float("inf"),
float("nan")], dtype=torch.float, device=dev)
self.int_types = [
torch.int8,
torch.uint8,
torch.int16,
torch.int32,
torch.int64
]
self.support_tensor_dtypes = [
torch.int32,
torch.int64,
torch.float16,
torch.float32,
torch.float64,
torch.bool,
torch.complex64,
torch.complex128,
]
if TEST_BF16:
self.support_tensor_dtypes.append(torch.bfloat16)
        if RUN_NVFUSER:
self.cuda_fuser_options = CudaFuserTestOptions()
def tearDown(self):
        # restore the skip-node configuration to its state before the test
for op in self.skip_node_list:
disabled_flag = torch._C._jit_set_nvfuser_skip_node_kind(op, False)
if not disabled_flag:
torch._C._jit_set_nvfuser_skip_node_kind(op, True)
        if RUN_NVFUSER:
self.cuda_fuser_options.restore()
super(TestCudaFuser, self).tearDown()
def _run_helper(self, jit_op, op, *args, check_stride=False, num_fusion=1, check_runs=1):
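        # Run the scripted op and the eager op with identical RNG seeds,
        # compare outputs (and optionally strides), and check that the expected
        # number of fusion guards shows up in the profiled graph.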
seed = 123
torch.cuda.manual_seed_all(seed)
jit_o = jit_op(*args)
for i in range(check_runs):
torch.cuda.manual_seed_all(seed + i)
jit_o = jit_op(*args)
torch.cuda.manual_seed_all(seed + i)
o = op(*args)
if type(jit_o) is torch.Tensor:
jit_o = [jit_o, ]
o = [o, ]
for oo, jit_oo in zip(o, jit_o):
self.assertEqual(oo.dtype, jit_oo.dtype)
self.assertEqual(oo, jit_oo)
if check_stride:
self.assertEqual(oo.stride(), jit_oo.stride())
self.assertGraphContainsExactly(jit_op.graph_for(*args), FUSION_GUARD, num_fusion, consider_subgraphs=True)
def _run_training_helper(self, jit_op, op, grads, *args):
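        # Run forward/backward several times so profiling and optimization kick
        # in, then compare against eager and check for a fusion guard in both
        # the forward and the backward graph.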
torch.cuda.manual_seed_all(123)
jit_o = jit_op(*args)
jit_g = jit_o.backward(grads)
torch.cuda.manual_seed_all(123)
jit_o = jit_op(*args)
jit_g = jit_o.backward(grads)
torch.cuda.manual_seed_all(123)
jit_o = jit_op(*args)
jit_g = jit_o.backward(grads)
torch.cuda.manual_seed_all(123)
o = op(*args)
g = o.backward(grads)
self.assertEqual(o, jit_o)
self.assertEqual(g, jit_g)
self.assertGraphContainsExactly(jit_op.graph_for(*args), FUSION_GUARD, 1, consider_subgraphs=True)
bwd_graph = list(
list(jit_op.get_debug_state().execution_plans.values())[
0].code.grad_executor_states()[0].execution_plans.values()
)[0].graph
self.assertGraphContainsExactly(bwd_graph, FUSION_GUARD, 1, consider_subgraphs=True)
@unittest.skipIf(not RUN_NVFUSER, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
"Requires fusion optimization pass to be effective")
def test_half(self):
def t(x: torch.Tensor, y: torch.Tensor, z: torch.Tensor, alpha: float):
o_16 = torch.add(x, y)
o_32_a = torch.add(y, z, alpha=alpha)
o_32_b = torch.add(o_16, z)
return (o_16, o_32_a, o_32_b)
t_jit = torch.jit.script(t)
alpha = 0.5
        # stick to integers; this avoids numerical differences due to our
        # type promotion
x = torch.randint(0, 256, (4, 8)).to(dtype=torch.float16, device="cuda")
y = torch.randint(0, 256, (4, 8)).to(dtype=torch.float16, device="cuda")
z = torch.randint(0, 256, (4, 8)).to(dtype=torch.float16, device="cuda")
jit_o = t_jit(x, y, z, alpha)
jit_o = t_jit(x, y, z, alpha)
o = t(x, y, z, alpha)
for oo, jit_oo in zip(o, jit_o):
self.assertEqual(oo.dtype, jit_oo.dtype)
self.assertEqual(oo, jit_oo)
self.assertGraphContains(t_jit.graph_for(x, y, z, alpha), FUSION_GUARD)
@unittest.skipIf(not TEST_BF16, "device does not support BFloat16")
@unittest.skipIf(not RUN_NVFUSER, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
"Requires fusion optimization pass to be effective")
def test_bfloat(self):
def t(x: torch.Tensor, y: torch.Tensor, z: torch.Tensor, alpha: float):
o_16 = torch.add(x, y)
o_32_a = torch.add(y, z, alpha=alpha)
o_32_b = torch.add(o_16, z)
return (o_16, o_32_a, o_32_b)
t_jit = torch.jit.script(t)
alpha = 0.5
        # stick to integers; this avoids numerical differences due to our
        # type promotion
x = torch.randint(0, 256, (4, 8)).to(dtype=torch.bfloat16, device="cuda")
y = torch.randint(0, 256, (4, 8)).to(dtype=torch.bfloat16, device="cuda")
z = torch.randint(0, 256, (4, 8)).to(dtype=torch.bfloat16, device="cuda")
jit_o = t_jit(x, y, z, alpha)
jit_o = t_jit(x, y, z, alpha)
o = t(x, y, z, alpha)
for oo, jit_oo in zip(o, jit_o):
self.assertEqual(oo.dtype, jit_oo.dtype)
self.assertEqual(oo, jit_oo)
self.assertGraphContains(t_jit.graph_for(x, y, z, alpha), FUSION_GUARD)
@unittest.skipIf(not RUN_NVFUSER, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
"Requires fusion optimization pass to be effective")
def test_const(self):
def t(x, y):
o = x + y
o = o + 2.0
return o
t_jit = torch.jit.script(t)
x = torch.randn(4, 8, dtype=torch.float, device="cuda")
y = torch.randn(4, 8, dtype=torch.float, device="cuda")
jit_o = t_jit(x, y)
jit_o = t_jit(x, y)
o = t(x, y)
self.assertEqual(o, jit_o)
self.assertGraphContains(t_jit.graph_for(x, y), FUSION_GUARD)
@unittest.skipIf(not RUN_NVFUSER, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
"Requires fusion optimization pass to be effective")
def test_chunk(self):
def t(x, y, z, q):
o = x + q
x0, x1 = torch.chunk(o, 2)
o = x0 + x1
o = o + y
o = o * z
o = torch.relu(o)
return o
t_jit = torch.jit.script(t)
x = torch.randn(4, 8, dtype=torch.float, device="cuda")
y = torch.randn(2, 8, dtype=torch.float, device="cuda")
z = torch.randn(2, 8, dtype=torch.float, device="cuda")
q = torch.randn(4, 8, dtype=torch.float, device="cuda")
jit_o = t_jit(x, y, z, q)
jit_o = t_jit(x, y, z, q)
o = t(x, y, z, q)
self.assertEqual(o, jit_o)
self.assertGraphContains(t_jit.graph_for(x, y, z, q), FUSION_GUARD)
@unittest.skipIf(is_pre_volta(), "reduction not supported in pre volta device")
@unittest.skipIf(not RUN_NVFUSER, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
"Requires fusion optimization pass to be effective")
def test_reduction_dtypes_axis(self):
for op in [torch.sum, torch.mean, torch.amax, torch.var, torch.std]:
for dtype in [torch.float16, torch.float32, torch.double]:
for axis in [-1, 2, 0]:
def make_func(op):
def func(x: torch.Tensor):
o = torch.mul(x, 2.0)
o = op(o, dim=[axis])
return o
return func
x = torch.randn(8, 4, 16, dtype=dtype, device="cuda")
t = make_func(op)
t_jit = torch.jit.trace(t, x)
jit_o = t_jit(x)
jit_o = t_jit(x)
o = t(x)
self.assertEqual(o.dtype, jit_o.dtype)
self.assertTrue(self._compare("comparing output failed", o, jit_o, 1e-4))
self.assertGraphContains(t_jit.graph_for(x), FUSION_GUARD)
@unittest.skipIf(is_pre_volta(), "reduction not supported in pre volta device")
@unittest.skipIf(not RUN_NVFUSER, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
"Requires fusion optimization pass to be effective")
def test_variance(self):
for op in [torch.var, torch.std]:
for dtype in [torch.float16, torch.float32, torch.double]:
for axis in [-2, -1, 2, 1]:
for unbiased in [False, True]:
def make_func(op):
def func(x: torch.Tensor):
o = torch.mul(x, 2.0)
                                o = op(o, dim=[axis], unbiased=unbiased)
return o
return func
x = torch.randn(8, 4, 16, dtype=dtype, device="cuda")
t = make_func(op)
t_jit = torch.jit.trace(t, x)
jit_o = t_jit(x)
jit_o = t_jit(x)
o = t(x)
self.assertEqual(o.dtype, jit_o.dtype)
self.assertTrue(self._compare("comparing output failed", o, jit_o, 1e-4))
self.assertGraphContains(t_jit.graph_for(x), FUSION_GUARD)
@unittest.skipIf(not RUN_NVFUSER, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
"Requires fusion optimization pass to be effective")
def test_scalar_input(self):
def t(x: torch.Tensor, y: torch.Tensor, z: float):
o = x + y
o = o + z
return o
t_jit = torch.jit.script(t)
x = torch.randn(4, 8, 32, 32, dtype=torch.float, device="cuda")
y = torch.randn(4, 8, 1, 32, dtype=torch.float, device="cuda")
y = y.expand(4, 8, 32, 32)
jit_o = t_jit(x, y, 2.0)
jit_o = t_jit(x, y, 2.0)
o = t(x, y, 2.0)
self.assertEqual(o, jit_o)
self.assertGraphContains(t_jit.graph_for(x, y, 2.0), FUSION_GUARD)
@unittest.skipIf(not RUN_NVFUSER, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
"Requires fusion optimization pass to be effective")
def test_broadcasting_0(self):
def t(x: torch.Tensor, y: torch.Tensor, z: float):
o = x + y
o = o + z
return o
t_jit = torch.jit.script(t)
x = torch.randn(4, 8, 32, 32, dtype=torch.float, device="cuda")
y = torch.randn(32, 32, dtype=torch.float, device="cuda")
jit_o = t_jit(x, y, 2.0)
jit_o = t_jit(x, y, 2.0)
o = t(x, y, 2.0)
self.assertEqual(o, jit_o)
subgraph = self._getSubgraphInFusion(t_jit.graph_for(x, y, 2.0))
self.assertGraphContainsExactly(subgraph, 'aten::add', 2, consider_subgraphs=False)
@unittest.skipIf(not RUN_NVFUSER, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
"Requires fusion optimization pass to be effective")
def test_broadcasting_1(self):
def t(x: torch.Tensor, y: torch.Tensor, z: float):
o = x + y
o = o + z
return o
t_jit = torch.jit.script(t)
x = torch.randn(4, 8, 32, 32, dtype=torch.float, device="cuda")
y = torch.randn(1, 32, 32, dtype=torch.float, device="cuda")
jit_o = t_jit(x, y, 2.0)
jit_o = t_jit(x, y, 2.0)
o = t(x, y, 2.0)
self.assertEqual(o, jit_o)
subgraph = self._getSubgraphInFusion(t_jit.graph_for(x, y, 2.0))
self.assertGraphContainsExactly(subgraph, 'aten::add', 2, consider_subgraphs=False)
@unittest.skipIf(not RUN_NVFUSER, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
"Requires fusion optimization pass to be effective")
def test_broadcasting_2(self):
def t(x: torch.Tensor, y: torch.Tensor, z: float):
o = x + y
o = o + z
return o
t_jit = torch.jit.script(t)
x = torch.randn(4, 1, 32, 32, dtype=torch.float, device="cuda")
y = torch.randn(8, 32, 32, dtype=torch.float, device="cuda")
jit_o = t_jit(x, y, 2.0)
jit_o = t_jit(x, y, 2.0)
o = t(x, y, 2.0)
self.assertEqual(o, jit_o)
subgraph = self._getSubgraphInFusion(t_jit.graph_for(x, y, 2.0))
self.assertGraphContainsExactly(subgraph, 'aten::add', 2, consider_subgraphs=False)
@unittest.skipIf(not RUN_NVFUSER, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
"Requires fusion optimization pass to be effective")
def test_broadcasting_3(self):
def t(x: torch.Tensor, y: torch.Tensor, z: float):
o = x + y
o = o + z
return o
t_jit = torch.jit.script(t)
x = torch.randn(8, 17, 8, dtype=torch.float, device="cuda")
y = torch.randn(8, 17, 1, dtype=torch.float, device="cuda")
jit_o = t_jit(x, y, 2.0)
jit_o = t_jit(x, y, 2.0)
o = t(x, y, 2.0)
self.assertEqual(o, jit_o)
subgraph = self._getSubgraphInFusion(t_jit.graph_for(x, y, 2.0))
self.assertGraphContainsExactly(subgraph, 'aten::add', 2, consider_subgraphs=False)
    # test_broadcasting_partition_logic_X
    # Tests that the partition logic avoids creating unsupported broadcasting
    # semantics inside a CudaFusionGroup.
@unittest.skipIf(not RUN_NVFUSER, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
"Requires fusion optimization pass to be effective")
def test_broadcasting_partition_logic_0(self):
def t(x: torch.Tensor, y: torch.Tensor, z: torch.Tensor):
x = x + 12.0
o1 = x + y
o2 = x + z
o = o1 + o2
return o
t_jit = torch.jit.script(t)
x = torch.randn(4, 8, 6, 8, dtype=torch.float32, device="cuda")
y = torch.randn(8, 6, 8, dtype=torch.float32, device="cuda")
z = torch.randn(6, 8, dtype=torch.float32, device="cuda")
jit_o = t_jit(x, y, z)
jit_o = t_jit(x, y, z)
o = t(x, y, z)
self.assertEqual(o, jit_o)
subgraph = self._getSubgraphInFusion(t_jit.graph_for(x, y, z))
self.assertGraphContainsExactly(subgraph, 'aten::add', 4, consider_subgraphs=False)
@unittest.skipIf(not RUN_NVFUSER, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
"Requires fusion optimization pass to be effective")
def test_broadcasting_partition_logic_1(self):
def t(x: torch.Tensor, y: torch.Tensor, z: torch.Tensor):
x = x + 12.0
o1 = x + y
o2 = x + z
o = o1 + o2
return o
t_jit = torch.jit.script(t)
x = torch.randn(8, 6, 8, dtype=torch.float32, device="cuda")
y = torch.randn(4, 8, 6, 8, dtype=torch.float32, device="cuda")
z = torch.randn(4, 1, 6, 8, dtype=torch.float32, device="cuda")
jit_o = t_jit(x, y, z)
jit_o = t_jit(x, y, z)
o = t(x, y, z)
self.assertEqual(o, jit_o)
subgraph = self._getSubgraphInFusion(t_jit.graph_for(x, y, z))
self.assertGraphContainsExactly(subgraph, 'aten::add', 4, consider_subgraphs=False)
@unittest.skipIf(True, "Broadcast with different output not supported yet")
@unittest.skipIf(not RUN_NVFUSER, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
"Requires fusion optimization pass to be effective")
def test_broadcasting_multiple_output_shape(self):
def t(x: torch.Tensor, y: torch.Tensor, z: torch.Tensor):
o = x + 12
o1 = o + y
o2 = o + z
oo = o1.sum() + o2.sum()
return oo
t_jit = torch.jit.script(t)
x = torch.randn(32, 32, dtype=torch.float, device="cuda")
y = torch.randn(2, 32, 32, dtype=torch.float, device="cuda")
z = torch.randn(4, 32, 32, dtype=torch.float, device="cuda")
jit_o = t_jit(x, y, z)
jit_o = t_jit(x, y, z)
o = t(x, y, z)
self.assertEqual(o, jit_o)
# Currently cannot fuse this
self.assertGraphContains(t_jit.graph_for(x, y, z), FUSION_GUARD)
@unittest.skipIf(True, "broadcast on branches can't be resolved yet")
@unittest.skipIf(not RUN_NVFUSER, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
"Requires fusion optimization pass to be effective")
def test_broadcasting_multiple_output(self):
def t(x: torch.Tensor, y: torch.Tensor, z: torch.Tensor):
o = x + 12
o1 = o + y
o2 = o + z
oo = o1.sum() + o2.sum()
return oo
t_jit = torch.jit.script(t)
x = torch.randn(32, 32, dtype=torch.float, device="cuda")
y = torch.randn(4, 32, 32, dtype=torch.float, device="cuda")
z = torch.randn(4, 32, 32, dtype=torch.float, device="cuda")
jit_o = t_jit(x, y, z)
jit_o = t_jit(x, y, z)
o = t(x, y, z)
self.assertEqual(o, jit_o)
# Currently cannot fuse this
self.assertGraphContains(t_jit.graph_for(x, y, z), FUSION_GUARD)
def _unary_test_helper(self, operation, dtype, random_data):
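        # Exercise a unary op on either special values (inf/nan/etc.) or random
        # data; float64 random inputs additionally go through gradcheck.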
gradient_check = (dtype == torch.float64) and random_data
shape = self.special_values.shape
torch.cuda.manual_seed_all(211)
# need additional def of t for boolean ops
def t(x: torch.Tensor, y: torch.Tensor):
o = x * y
o = o + 5e-3
o = operation(o)
return o
y = torch.rand(shape, dtype=torch.float32, device="cuda", requires_grad=gradient_check)
y = y.to(dtype=dtype)
if random_data:
x = torch.rand(shape, dtype=torch.float32, device="cuda", requires_grad=gradient_check)
if dtype in self.int_types:
# prefer a larger variance for integer types
x = x * 5
x = x.to(dtype=dtype)
else:
x = self.special_values.to(dtype=dtype)
try:
ref = t(x, y)
except Exception:
# same way as TE checker, if eager mode throws, ignore this test
return
t_jit = torch.jit.script(t)
jit_o = t_jit(x, y)
jit_o = t_jit(x, y)
jit_o = t_jit(x, y)
if gradient_check:
if jit_o.dtype != torch.bool:
# bool dtype has no `-`
gradcheck(t_jit, [x, y], nondet_tol=1e-5)
elif dtype in self.support_tensor_dtypes:
self.assertGraphContains(t_jit.graph_for(x, y), FUSION_GUARD)
o = t(x, y)
self.assertEqual(o.dtype, jit_o.dtype)
if dtype == torch.bfloat16:
                # compare bfloat16 kernels against a float64 ground truth
                # instead of the eager-mode implementation, since mismatched
                # casts add excessive noise.
o = t(x.to(torch.float64), y.to(torch.float64))
if o.dtype.is_floating_point:
o = o.to(torch.bfloat16)
else:
o = t(x, y)
self.assertTrue(self._compare("failing case {}\n{}\n{}\n{}".format(dtype, operation, x, y), o, jit_o, 1e-2))
@unittest.skipIf(not RUN_NVFUSER, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
"Requires fusion optimization pass to be effective")
def test_unary_ops(self):
data_types = [
*self.int_types,
torch.float16,
torch.float32,
torch.float64,
            # TODO: revert this
            # see issue https://github.com/csarofeen/pytorch/issues/1730
# torch.cfloat,
# torch.cdouble,
]
if TEST_BF16:
data_types.append(torch.bfloat16)
operations = [torch.neg,
torch.abs,
torch.log,
torch.log10,
torch.log1p,
torch.log2,
torch.lgamma,
torch.exp,
torch.expm1,
torch.erf,
torch.erfc,
torch.cos,
torch.acos,
torch.cosh,
torch.sin,
torch.asin,
torch.sinh,
torch.tan,
torch.atan,
torch.sqrt,
torch.rsqrt,
torch.ceil,
torch.floor,
torch.round,
torch.trunc,
torch.frac,
torch.reciprocal,
torch.isfinite,
torch.isinf,
torch.isnan,
torch.isneginf,
torch.isposinf,
torch.isreal,
torch.nn.functional.softplus,
torch.nn.functional.gelu,
torch.relu,
torch.sigmoid,
torch.bitwise_not,
torch.tan,
torch.tanh,
torch.nn.functional.silu]
skip_complex = {torch.rsqrt, torch.reciprocal}
for op, dtype in itertools.product(operations, data_types):
if dtype.is_complex and op in skip_complex:
continue
self._unary_test_helper(op, dtype, False) # test special numbers
self._unary_test_helper(op, dtype, True) # test random data
@unittest.skipIf(not RUN_NVFUSER, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
"Requires fusion optimization pass to be effective")
def test_category_rule(self):
def run_tensor(x, z):
def t(x: torch.Tensor, z: torch.Tensor):
o = x + z
o = torch.abs(o)
return o
t_jit = torch.jit.script(t)
jit_o = t_jit(x, z)
jit_o = t_jit(x, z)
o = t(x, z)
self.assertEqual(o.dtype, jit_o.dtype)
self.assertEqual(o, jit_o)
self.assertGraphContains(t_jit.graph_for(x, z), FUSION_GUARD)
def run_scalar(x, z):
def t(x: torch.Tensor, z: float):
o = x + z
o = torch.abs(o)
return o
t_jit = torch.jit.script(t)
jit_o = t_jit(x, z)
jit_o = t_jit(x, z)
o = t(x, z)
self.assertEqual(o.dtype, jit_o.dtype)
self.assertEqual(o, jit_o)
self.assertGraphContains(t_jit.graph_for(x, z), FUSION_GUARD)
# n-dim with 0-dim (no type-promote)
x = torch.randn(4, 8, 32, 32, dtype=torch.float, device="cuda")
z = torch.tensor(2.0, dtype=torch.double, device="cuda")
run_tensor(x, z)
# n-dim with 0-dim (type-promote)
x = torch.randn(4, 8, 32, 32, device="cuda").to(dtype=torch.long)
z = torch.tensor(2.0, dtype=torch.double, device="cuda")
run_tensor(x, z)
# n-dim with n-dim (type-promote)
x = torch.randn(4, 8, 32, 32, dtype=torch.float, device="cuda")
z = torch.randn(4, 8, 32, 32, dtype=torch.double, device="cuda")
run_tensor(x, z)
# n-dim with scalar (no type-promote)
x = torch.randn(4, 8, 32, 32, dtype=torch.float16, device="cuda")
z = torch.tensor(3., dtype=torch.double)
run_scalar(x, z)
if TEST_BF16:
# n-dim with scalar (no type-promote)
x = torch.randn(4, 8, 32, 32, dtype=torch.bfloat16, device="cuda")
z = torch.tensor(3., dtype=torch.double)
run_scalar(x, z)
# n-dim with scalar (type-promote)
x = torch.randn(4, 8, 32, 32, device="cuda").to(dtype=torch.long)
z = torch.tensor(3., dtype=torch.double)
run_scalar(x, z)
@unittest.skipIf(not RUN_NVFUSER, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
"Requires fusion optimization pass to be effective")
def test_unary_bitwise(self):
def bit_not(x: torch.Tensor):
return ~(x + 1)
jitted = torch.jit.script(bit_not)
x = torch.randn(4, 8, 32, 32, dtype=torch.float, device="cuda").mul(5).to(torch.long)
jit_o = jitted(x)
jit_o = jitted(x)
o = bit_not(x)
self.assertEqual(o, jit_o)
jitted.graph_for(x) # Shows up in second instance, not first
self.assertGraphContains(jitted.graph_for(x), FUSION_GUARD)
def bool_not(x: torch.Tensor, y: torch.Tensor):
return ~(x & y)
jitted = torch.jit.script(bool_not)
x = torch.rand(4, 8, 32, 32, dtype=torch.float, device="cuda").round().to(torch.bool)
y = torch.rand(4, 8, 32, 32, dtype=torch.float, device="cuda").round().to(torch.bool)
jit_o = jitted(x, y)
jit_o = jitted(x, y)
o = bool_not(x, y)
self.assertEqual(o, jit_o)
jitted.graph_for(x, y) # Shows up in second instance, not first
self.assertGraphContains(jitted.graph_for(x, y), FUSION_GUARD)
def _get_scalar_binary_test_fn(self, category_and_type1, category_and_type2, operation):
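        # Pick the wrapper whose scalar annotation (int/float/complex) matches
        # the dtype of the scalar (first) operand.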
category1, dtype_arg1 = category_and_type1
category2, dtype_arg2 = category_and_type2
def t_intx_tensory(x: int, y: torch.Tensor):
o = operation(x, y)
o = 2 + o
return o
def t_doublex_tensory(x: float, y: torch.Tensor):
o = operation(x, y)
o = 2 + o
return o
def t_cdoublex_tensory(x: complex, y: torch.Tensor):
o = operation(x, y)
o = 2 + o
return o
# Omit both scalar cases and swap cases
assert category1 == "scalar" and category2 != "scalar"
if dtype_arg1.is_floating_point:
return t_doublex_tensory
if dtype_arg1 == torch.int64 or dtype_arg1 == torch.int32:
return t_intx_tensory
        if dtype_arg1.is_complex:
return t_cdoublex_tensory
raise NotImplementedError
def _binary_test_helper(self, operation, dtypes, random_data, categories="ndim"):
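        # Exercise a binary op across dtype and operand-category combinations
        # ("scalar", "0dim", "0dimcpu", "ndim"), skipping combinations that are
        # unsupported in eager mode or in nvFuser.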
if isinstance(dtypes, tuple):
dtype_arg1, dtype_arg2 = dtypes
else:
dtype_arg1 = dtype_arg2 = dtypes
if isinstance(categories, tuple) and random_data:
category1, category2 = categories
elif not random_data:
category1 = category2 = "ndim"
else:
category1 = category2 = categories
def is_cpu_category(x):
return x == "0dimcpu" or x == "scalar"
# skip unsupported cases
if is_cpu_category(category1) and is_cpu_category(category2):
return
# only test cases with first operand as scalar
if category2 == "scalar":
return
        # skip ops that don't support scalar inputs in eager
if operation in [
torch.atan2,
torch.max,
torch.min,
torch.remainder, # unsupported in nvfuser
]:
if category1 == "scalar" or category2 == "scalar":
return
if operation in [
torch.fmod,
torch.eq,
torch.ne,
torch.ge,
torch.gt,
torch.le,
torch.lt
]:
if category1 == "scalar":
return
        # operators that do not support bfloat16
if operation in [torch.fmod]:
if dtype_arg1 == torch.bfloat16 or dtype_arg2 == torch.bfloat16:
return
def t(x: torch.Tensor, y: torch.Tensor, z: torch.Tensor):
o = operation(x, y)
o = o + z
return o
shape = (4, 32, 32)
shapex = shape if category1 == "ndim" else ()
shapey = shape if category2 == "ndim" else ()
if random_data:
x = (torch.randn(shapex, dtype=torch.float, device="cuda") * 5).to(dtype_arg1)
y = (torch.randn(shapey, dtype=torch.float, device="cuda") * 5).to(dtype_arg2)
else:
x = self.special_values.to(dtype=dtype_arg1)
y = (torch.rand_like(self.special_values) * 5).to(dtype_arg2)
r"""
Category conversion
"""
has_scalar = False
if category1 == "scalar":
has_scalar = True
x = x.item()
if category1 == "0dimcpu":
x = x.to(device="cpu")
if category2 == "scalar":
has_scalar = True
y = y.item()
if category2 == "0dimcpu":
y = y.to(device="cpu")
z = torch.tensor([2], device="cuda").to(dtype_arg1)
is_dtype_arg1_int = dtype_arg1 == torch.int32 or dtype_arg1 == torch.int64
is_dtype_arg2_int = dtype_arg2 == torch.int32 or dtype_arg2 == torch.int64
if operation in [torch.pow]:
if is_dtype_arg1_int and is_dtype_arg2_int:
if category2 == "scalar":
# RuntimeError: Integers to negative integer powers are not allowed
y = abs(y)
if category2 == "0dimcpu" and y == -1:
# https://github.com/pytorch/pytorch/issues/73196
y = y - 1
if category2 == "0dimcpu" and y == -2:
# avoid pow(0, -2), which gives inconsistent results on integer tensor
y = y - 1
# Avoid division by zero for integer tensors
div_like = [torch.div, torch.fmod, torch.remainder]
if operation in div_like and (dtype_arg2 == torch.int32 or dtype_arg2 == torch.int64):
y[y == 0] = 1
test_value = True
if dtype_arg1 == torch.half or dtype_arg2 == torch.half:
test_value = False
if dtype_arg1 == torch.bfloat16 or dtype_arg2 == torch.bfloat16:
test_value = False
try:
if not has_scalar:
o = t(x, y, z)
t_jit = torch.jit.script(t)
jit_o = t_jit(x, y, z)
jit_o = t_jit(x, y, z)
jit_o = t_jit(x, y, z)
self.assertEqual(o.dtype, jit_o.dtype)
if test_value:
self.assertEqual(o, jit_o)
self.assertGraphContains(t_jit.graph_for(x, y, z), FUSION_GUARD)
elif category2 != "scalar": # only test the case where first is scalar
test_fn = self._get_scalar_binary_test_fn((category1, dtype_arg1), (category2, dtype_arg2), operation)
o = test_fn(x, y)
t_jit = torch.jit.script(test_fn)
jit_o = t_jit(x, y)
jit_o = t_jit(x, y)
jit_o = t_jit(x, y)
self.assertEqual(o.dtype, jit_o.dtype)
if test_value:
self.assertEqual(o, jit_o)
self.assertGraphContains(t_jit.graph_for(x, y), FUSION_GUARD)
except Exception as e:
print("failing test for op: ", operation.__name__)
print("with input\n\tx: ", x)
print("\ty: ", y)
print("\tz: ", z)
raise e
@unittest.skipIf(not RUN_NVFUSER, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
"Requires fusion optimization pass to be effective")
def test_binary_ops(self):
data_types = [
torch.int32,
torch.int64,
torch.float16,
torch.float32,
torch.float64,
]
if TEST_BF16:
data_types.append(torch.bfloat16)
operations = [torch.mul,
torch.div,
torch.atan2,
torch.max,
torch.min,
torch.pow,
torch.remainder,
torch.fmod,
torch.eq,
torch.ne,
torch.ge,
torch.gt,
torch.le,
torch.lt]
category_types = [
"scalar",
"0dim",
"0dimcpu",
"ndim"
]
binary_dtype_combinations = list(itertools.combinations(data_types, 2))
category_combinations = list(itertools.combinations(category_types, 2))
for op, dtypes, categories in itertools.product(operations, binary_dtype_combinations, category_combinations):
self._binary_test_helper(op, dtypes, True, categories) # random data
for op, dtypes in itertools.product(operations, binary_dtype_combinations):
self._binary_test_helper(op, dtypes, False) # special numbers
# TODO: revert this
@unittest.skipIf(True, "see issue https://github.com/csarofeen/pytorch/issues/1730")
@unittest.skipIf(not RUN_NVFUSER, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
"Requires fusion optimization pass to be effective")
def test_binary_ops_complex(self):
data_types = [torch.cfloat, torch.cdouble]
operations = [torch.mul, torch.div, torch.pow, torch.eq, torch.ne]
category_types = [
"scalar",
"0dim",
"0dimcpu",
"ndim"
]
binary_dtype_combinations = list(itertools.combinations(data_types, 2))
category_combinations = list(itertools.combinations(category_types, 2))
for op, dtypes, categories in itertools.product(operations, binary_dtype_combinations, category_combinations):
self._binary_test_helper(op, dtypes, True, categories) # random data
for op, dtypes in itertools.product(operations, binary_dtype_combinations):
self._binary_test_helper(op, dtypes, False) # special numbers
@unittest.skipIf(not RUN_NVFUSER, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
"Requires fusion optimization pass to be effective")
def test_binary_bitwise(self):
dtypes = [torch.bool, torch.int32, torch.int64]
for dtype1, dtype2, dtype3 in itertools.product(dtypes, repeat=3):
def jit_and(x: torch.Tensor, y: torch.Tensor, z: torch.Tensor):
return torch.bitwise_and(x, y) & z
def jit_or(x: torch.Tensor, y: torch.Tensor, z: torch.Tensor):
return torch.bitwise_or(x, y) | z
def jit_xor(x: torch.Tensor, y: torch.Tensor, z: torch.Tensor):
return torch.bitwise_xor(x, y) ^ z
def jit_lshift(x: torch.Tensor, y: torch.Tensor, z: torch.Tensor):
return torch.bitwise_left_shift(x, y) << z
def jit_rshift(x: torch.Tensor, y: torch.Tensor, z: torch.Tensor):
return torch.bitwise_right_shift(x, y) >> z
for jit_func in [jit_and, jit_or, jit_xor, jit_lshift, jit_rshift]:
if torch.bool in {dtype1, dtype2, dtype3} and jit_func in {jit_lshift, jit_rshift}:
continue
x = torch.randn(4, 8, 32, 32, dtype=torch.float, device="cuda").mul(5).to(dtype1)
y = torch.randn(4, 8, 32, 32, dtype=torch.float, device="cuda").mul(5).to(dtype2)
z = torch.randn(4, 8, 32, 32, dtype=torch.float, device="cuda").mul(2).to(dtype3)
jitted = torch.jit.script(jit_func)
jit_o = jitted(x, y, z)
jit_o = jitted(x, y, z)
o = jit_func(x, y, z)
self.assertEqual(o, jit_o)
self.assertGraphContains(jitted.graph_for(x, y, z), FUSION_GUARD)
@unittest.skipIf(not RUN_NVFUSER, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
"Requires fusion optimization pass to be effective")
def test_type_as_op(self):
def t(x: torch.Tensor, y: torch.Tensor, z: float):
o = torch.lt(x, z)
o = o.type_as(y)
return o
t_jit = torch.jit.script(t)
x = torch.randn(4, 8, 32, 32, dtype=torch.float, device="cuda")
y = torch.randn(4, 8, 32, 32, dtype=torch.float, device="cuda")
jit_o = t_jit(x, y, 0.5)
jit_o = t_jit(x, y, 0.5)
o = t(x, y, 0.5)
self.assertEqual(o, jit_o)
self.assertGraphContains(t_jit.graph_for(x, y, 0.5), FUSION_GUARD)
def _ternary_integer_test_helper(self, dtype_arg1):
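        # Run clamp/threshold with (mostly integer) scalar min/max/threshold
        # arguments on a tensor of the given dtype to check integer-scalar
        # compatibility.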
shape = (4, 8, 32, 32)
magnitude = 100
        if dtype_arg1 in self.int_types:
x = torch.randint(-magnitude, magnitude, shape, dtype=dtype_arg1, device="cuda")
else:
x = torch.randn(shape, dtype=dtype_arg1, device="cuda") * magnitude
arg2 = int(0)
arg3 = int(magnitude * 0.1)
def clamp0(x: torch.Tensor, f: int):
o = 2. * torch.clamp(x, min=f)
return o
clamp0_jit = torch.jit.script(clamp0)
self._run_helper(clamp0_jit, clamp0, x, arg2)
def clamp1(x: torch.Tensor, f: int, ff: int):
o = 2. * torch.clamp(x, min=f, max=ff)
return o
clamp1_jit = torch.jit.script(clamp1)
self._run_helper(clamp1_jit, clamp1, x, arg2, arg3)
def clamp2(x: torch.Tensor, f: float, ff: int):
o = 2. * torch.clamp(x, min=f, max=ff)
return o
clamp2_jit = torch.jit.script(clamp2)
self._run_helper(clamp2_jit, clamp2, x, float(arg2), arg3)
def clamp3(x: torch.Tensor, f: int, ff: float):
o = 2. * torch.clamp(x, min=f, max=ff)
return o
clamp3_jit = torch.jit.script(clamp3)
self._run_helper(clamp3_jit, clamp3, x, arg2, float(arg3))
def threshold(x: torch.Tensor, th: int, val: int):
o = 2. * torch.threshold(x, th, val)
return o
threshold_jit = torch.jit.script(threshold)
self._run_helper(threshold_jit, threshold, x, arg2, arg3)
@unittest.skipIf(not RUN_NVFUSER, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
"Requires fusion optimization pass to be effective")
def test_ternary_ops_integer_compatibility(self):
data_types = [
torch.float16,
torch.float32,
torch.float64
]
for dtype in data_types:
self._ternary_integer_test_helper(dtype)
def _ternary_test_helper(self, operation, dtypes, random_data):
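        # Exercise a ternary op on random data or special values, with a
        # separate dtype for each operand (torch.where gets a bool condition).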
if isinstance(dtypes, tuple):
dtype_arg1, dtype_arg2, dtype_arg3 = dtypes
else:
dtype_arg1 = dtype_arg2 = dtype_arg3 = dtypes
def t(x: torch.Tensor, y: torch.Tensor, z: torch.Tensor, alpha: torch.Tensor):
o = operation(x, y, z)
o = o + alpha
return o
shape = (4, 32, 32)
if operation is torch.where:
dtype_arg1 = torch.bool
if random_data:
x = torch.randint(0, 2, shape).to(dtype=torch.bool, device="cuda")
y = (torch.randn(shape, dtype=torch.float, device="cuda") * 5).to(dtype_arg2)
z = (torch.randn(shape, dtype=torch.float, device="cuda") * 5).to(dtype_arg3)
else:
x = torch.randint(0, 2, self.special_values.size()).to(dtype=torch.bool, device="cuda")
y = self.special_values.to(dtype=dtype_arg2)
z = (torch.rand_like(self.special_values) * 5).to(dtype_arg3)
elif random_data:
x = (torch.randn(shape, dtype=torch.float, device="cuda") * 5).to(dtype_arg1)
y = (torch.randn(shape, dtype=torch.float, device="cuda") * 5).to(dtype_arg2)
z = (torch.randn(shape, dtype=torch.float, device="cuda") * 5).to(dtype_arg3)
else:
x = self.special_values.to(dtype=dtype_arg1)
y = (torch.rand_like(self.special_values) * 5).to(dtype_arg2)
z = (torch.rand_like(self.special_values) * 5).to(dtype_arg3)
alpha = torch.tensor([2], device="cuda").to(dtype_arg1)
o = t(x, y, z, alpha)
t_jit = torch.jit.script(t)
jit_o = t_jit(x, y, z, alpha)
jit_o = t_jit(x, y, z, alpha)
self.assertEqual(o.dtype, jit_o.dtype)
self.assertEqual(o, jit_o)
self.assertGraphContains(t_jit.graph_for(x, y, z), FUSION_GUARD)
@unittest.skipIf(not RUN_NVFUSER, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
"Requires fusion optimization pass to be effective")
def test_ternary_ops_type_promotion(self):
# TODO: update accuracy tolerance for bf16 / fp16 data types
data_types = [
# torch.float16,
torch.float32,
torch.float64
]
'''
if TEST_BF16:
data_types.append(torch.bfloat16)
'''
# TODO: Add Tensor support for clamp
operations = [torch.clamp]
ternary_dtype_combinations = itertools.combinations(data_types, 3)
for op, dtypes in itertools.product(operations, ternary_dtype_combinations):
self._ternary_test_helper(op, dtypes, True) # random data
self._ternary_test_helper(op, dtypes, False) # special numbers
# We can't test the scalar version of rsub from python
@unittest.skipIf(not RUN_NVFUSER, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING, "Requires fusion optimization pass to be effective")
def test_rsub(self):
x = torch.randn(4, 8, 32, 32, dtype=torch.float, device="cuda")
y = torch.randn(4, 8, 32, 32, dtype=torch.float, device="cuda")
def rsub(x: torch.Tensor, y: torch.Tensor):
o = torch.rsub(x, y)
o = o * 2.
return o
rsub_jit = torch.jit.script(rsub)
self._run_helper(rsub_jit, rsub, x, y)
@unittest.skipIf(not RUN_NVFUSER, "requires CUDA")
# legacy fuser does not work for rand_like, see issue #34361
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING, "Requires fusion optimization pass to be effective")
def test_ternary_ops(self):
x = torch.randn(4, 8, 32, 32, dtype=torch.float, device="cuda")
y = torch.randn(4, 8, 32, 32, dtype=torch.float, device="cuda")
z = torch.randn(4, 8, 32, 32, dtype=torch.float, device="cuda")
cond = torch.randint(0, 2, (4, 8, 32, 32)).to(dtype=torch.bool, device="cuda")
def add(x: torch.Tensor, other: torch.Tensor, alpha: float):
o = torch.relu(x)
o = torch.add(o, other=other, alpha=alpha)
return o
add_jit = torch.jit.script(add)
self._run_helper(add_jit, add, x, y, 2.0)
def clamp0(x: torch.Tensor, f: float):
o = 2. * torch.clamp(x, min=f)
return o
clamp0_jit = torch.jit.script(clamp0)
self._run_helper(clamp0_jit, clamp0, x, 0.5)
def clamp1(x: torch.Tensor, f: float, ff: float):
o = 2. * torch.clamp(x, min=f, max=ff)
return o
clamp1_jit = torch.jit.script(clamp1)
self._run_helper(clamp1_jit, clamp1, x, -0.2, 0.7)
def threshold(x: torch.Tensor, th: float, val: float):
o = 2. * torch.threshold(x, th, val)
return o
threshold_jit = torch.jit.script(threshold)
self._run_helper(threshold_jit, threshold, x, 0.2, 0.9)
def where(x: torch.Tensor, y: torch.Tensor, cond: torch.Tensor):
o = 2. * torch.where(cond, x, y)
return o
where_jit = torch.jit.script(where)
self._run_helper(where_jit, where, x, y, cond)
def lerp(x: torch.Tensor, y: torch.Tensor, z: torch.Tensor):
o = 2. * torch.lerp(x, y, z)
return o
lerp_jit = torch.jit.script(lerp)
self._run_helper(lerp_jit, lerp, x, y, z)
def lerp_scale(x: torch.Tensor, y: torch.Tensor, z: float):
o = 2. * torch.lerp(x, y, z)
return o
lerp_scale_jit = torch.jit.script(lerp_scale)
self._run_helper(lerp_scale_jit, lerp_scale, x, y, 0.5)
@unittest.skipIf(not RUN_NVFUSER, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING, "Requires profiling node to run cuda fuser")
def test_addcmul_ops(self):
x = torch.randn(4, 8, 32, 32, dtype=torch.float, device="cuda")
y = torch.randn(4, 8, 32, 32, dtype=torch.float, device="cuda")
z = torch.randn(4, 8, 32, 32, dtype=torch.float, device="cuda")
def addcmul(x: torch.Tensor, y: torch.Tensor, z: torch.Tensor, value: float):
o = torch.add(x, 0.5)
o = torch.addcmul(o, y, z, value=value)
return o
addcmul_jit = torch.jit.script(addcmul)
self._run_helper(addcmul_jit, addcmul, x, y, z, 2.0)
def addcmul_no_alpha(x: torch.Tensor, y: torch.Tensor, z: torch.Tensor):
o = torch.add(x, 0.5)
o = torch.addcmul(o, y, z)
return o
addcmul_no_alpha_jit = torch.jit.script(addcmul_no_alpha)
self._run_helper(addcmul_no_alpha_jit, addcmul_no_alpha, x, y, z)
def addcmul_const_alpha(x: torch.Tensor, y: torch.Tensor, z: torch.Tensor):
o = torch.add(x, 0.5)
o = torch.addcmul(o, y, z, value=0.75)
return o
addcmul_const_alpha_jit = torch.jit.script(addcmul_const_alpha)
self._run_helper(addcmul_const_alpha_jit, addcmul_const_alpha, x, y, z)
@unittest.skipIf(not RUN_NVFUSER, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
"Requires fusion optimization pass to be effective")
def test_dynamic_size(self):
old_guard = torch._C._jit_set_nvfuser_guard_mode(True)
torch._C._jit_set_bailout_depth(20)
def t(x: torch.Tensor, y: torch.Tensor, z: float):
o = x + y
o = o + z
return o
t_jit = torch.jit.script(t)
x = torch.randn(4, 8, 32, 32, dtype=torch.float, device="cuda")
y = torch.randn(32, 32, dtype=torch.float, device="cuda")
jit_o = t_jit(x, y, 2.0)
jit_o = t_jit(x, y, 2.0)
o = t(x, y, 2.0)
self.assertEqual(o, jit_o)
subgraph = self._getSubgraphInFusion(t_jit.graph_for(x, y, 2.0))
self.assertGraphContainsExactly(subgraph, 'aten::add', 2, consider_subgraphs=False)
        # This test is not ideal: we rely on a bailout to exercise the path,
        # and we don't have a way to inspect the bailout graph to validate
        # that the proper fusion happened.
x = torch.randn(8, 32, 16, 8, dtype=torch.float, device="cuda")
y = torch.randn(16, 8, dtype=torch.float, device="cuda")
jit_o = t_jit(x, y, 2.0)
jit_o = t_jit(x, y, 2.0)
o = t(x, y, 2.0)
self.assertEqual(o, jit_o)
self.assertGraphContains(t_jit.graph_for(x, y, 2.0), FUSION_GUARD)
x = torch.randn(8, 17, 8, dtype=torch.float, device="cuda")
y = torch.randn(8, 17, 1, dtype=torch.float, device="cuda")
jit_o = t_jit(x, y, 2.0)
jit_o = t_jit(x, y, 2.0)
o = t(x, y, 2.0)
self.assertEqual(o, jit_o)
self.assertGraphContains(t_jit.graph_for(x, y, 2.0), FUSION_GUARD)
torch._C._jit_set_nvfuser_guard_mode(old_guard)
@unittest.skipIf(not RUN_NVFUSER, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
"Requires fusion optimization pass to be effective")
def test_random_topo(self):
os.environ["PYTORCH_NVFUSER_DISABLE_FALLBACK"] = "1"
self.assertTrue(runDefaultTestWithSeed(28449))
def _compare(self, desc, inp1, inp2, error):
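        # torch.allclose with equal_nan=True; on mismatch, print the offending
        # elements and the maximum difference to ease debugging.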
a = inp1.clone()
b = inp2.clone()
close = torch.allclose(a, b, rtol=error, atol=error, equal_nan=True)
if not close:
print(desc, close)
z = a - b
index = (torch.abs(z) >= error + error * torch.abs(b)).nonzero()
print("dif : ", z[index])
print("inp1 : ", a[index])
print("inp2 : ", b[index])
print("maximum difference", z[index].max())
return close
    # Permutation helper that applies a binary operation between two tensors:
    # 1. applies separate permutations `perm0` & `perm1` to the two inputs
    # 2. reduces dimension `broadcast_axis` of the second operand to size 1
    # The purpose of this test is to ensure permutation works in complicated
    # cases with arbitrary stride orders and broadcast dimensions.
def _permutation_helper(self, sizes, broadcast_axis, dtype, device, perm0, perm1):
def t(x: torch.Tensor, y: torch.Tensor):
o = torch.add(x, y)
o = torch.relu(o)
return o
x = torch.randn([sizes[i] for i in perm0], dtype=dtype, device=device).permute(
[perm0.index(i) for i in range(len(sizes))])
if broadcast_axis >= 0:
sizes[broadcast_axis] = 1
y = torch.randn([sizes[i] for i in perm1], dtype=dtype, device=device).permute(
[perm1.index(i) for i in range(len(sizes))])
t_jit = torch.jit.script(t)
jit_o = t_jit(x, y)
jit_o = t_jit(x, y)
o = t(x, y)
self.assertEqual(o.dtype, jit_o.dtype)
self.assertEqual(o, jit_o)
self.assertEqual(o.stride(), jit_o.stride())
self.assertGraphContains(t_jit.graph_for(x, y), FUSION_GUARD)
    # End-to-end test of permutation & contiguity handling in the integration.
    # We test inputs with every combination of permutation orders to ensure
    # the integration can generate functionally correct kernels.
@unittest.skipIf(not RUN_NVFUSER, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
"Requires fusion optimization pass to be effective")
def test_binary_ops_permutation(self):
# note that num_dim is exclusive from len(x), so we are not reducing
# to single element (codegen limitation at this moment)
x = [7, 8, 12]
b_axes = range(-1, len(x))
for b_axis in b_axes:
for perm0 in itertools.permutations(range(len(x))):
for perm1 in itertools.permutations(range(len(x))):
x = [7, 8, 12]
self._permutation_helper(x, b_axis, torch.float32, "cuda", perm0, perm1)
@unittest.skipIf(not RUN_NVFUSER, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
"Requires fusion optimization pass to be effective")
def test_binary_ops_channels_last_with_bcast(self):
device = "cuda"
x = torch.randn([4, 3, 2, 5], device=device).to(memory_format=torch.channels_last)
w = torch.randn([2, 5], device=device)
def t(x: torch.Tensor, b: torch.Tensor):
o = x + b
return torch.relu(o)
t_jit = torch.jit.script(t)
jit_o = t_jit(x, w)
jit_o = t_jit(x, w)
jit_o = t_jit(x, w)
o = t(x, w)
self.assertEqual(o.dtype, jit_o.dtype)
self.assertTrue(self._compare("comparing output failed", o, jit_o, 1e-4))
self.assertGraphContains(t_jit.graph_for(x, w), FUSION_GUARD)
def _reduction_helper(self, sizes, reduction_axis, dtype, device, perm0, perm1, keepdim=False):
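        # Sum-reduce the result of an add over `reduction_axis`, with both
        # inputs permuted to exercise arbitrary stride orders.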
class MyReduction(torch.nn.Module):
__constants__ = ['reduction_axis', 'keepdim']
def __init__(self):
super(MyReduction, self).__init__()
self.reduction_axis = reduction_axis
self.keepdim = keepdim
def forward(self, x: torch.Tensor, y: torch.Tensor):
o = torch.add(x, y)
o = torch.sum(o, dim=self.reduction_axis, keepdim=self.keepdim)
return o
t = MyReduction()
x = torch.randn([sizes[i] for i in perm0], dtype=dtype, device=device).permute(
[perm0.index(i) for i in range(len(sizes))])
y = torch.randn([sizes[i] for i in perm1], dtype=dtype, device=device).permute(
[perm1.index(i) for i in range(len(sizes))])
t_jit = torch.jit.script(t)
jit_o = t_jit(x, y)
jit_o = t_jit(x, y)
o = t(x, y)
self.assertEqual(o.dtype, jit_o.dtype)
# numerical issues here due to our scheduling.
# can't use `self.assertEqual(o, jit_o)`
self.assertTrue(self._compare("comparing output failed", o, jit_o, 1e-4))
self.assertGraphContains(t_jit.graph_for(x, y), FUSION_GUARD)
@unittest.skipIf(is_pre_volta(), "reduction not supported in pre volta device")
@unittest.skipIf(not RUN_NVFUSER, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
"Requires fusion optimization pass to be effective")
def test_reduction(self):
for x in ([7, 8, 12], [12, 8, 7, 9, 15], [128, 16, 8, 32]):
# note that num_dim is exclusive from len(x), so we are not reducing
# to single element (codegen limitation at this moment)
for num_reduce_dim in range(1, len(x)):
for axes in itertools.combinations(range(len(x)), num_reduce_dim):
for keepdim in (True, False):
perm0 = range(len(x))
perm1 = range(len(x))
self._reduction_helper(x, axes, torch.float32, "cuda", perm0, perm1, keepdim)
def _layer_norm_autodiff_helper(self, model, grad, shapes, args):
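        # Run the scripted layer_norm model a few times to profile, then
        # compare outputs and gradients against eager and check for fusion in
        # both the forward and the backward graph.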
jit_model = torch.jit.script(model)
eps = np.random.random() * 1e-4
use_cudnn = bool(np.random.randint(0, 2))
# profile/optimization runs
for i in range(3):
jit_o = jit_model(shapes, *args, eps, use_cudnn)
jit_o.backward(grad)
ref_args = [t.detach().clone().requires_grad_() for t in args]
[t.grad.zero_() for t in args]
jit_o = jit_model(shapes, *args, eps, use_cudnn)
jit_o.backward(grad)
o = model(shapes, *ref_args, eps, use_cudnn)
o.backward(grad)
self.assertEqual(jit_o, o)
for arg, ref_arg in zip(args, ref_args):
self.assertEqual(arg.grad, ref_arg.grad)
# check fusion in fw & bw
g = jit_model.graph_for(shapes, *args, eps, use_cudnn)
for node in g.nodes():
n = node
dbg_state = jit_model.get_debug_state()
for val in dbg_state.execution_plans.values():
v = val
state2 = v.code.grad_executor_states()
for val in state2[0].execution_plans.values():
v2 = val
FileCheck().check(FUSION_GUARD).run(g)
FileCheck().check(FUSION_GUARD).run(v2.graph)
@unittest.skipIf(is_pre_volta(), "reduction not supported in pre volta device")
@unittest.skipIf(not RUN_NVFUSER, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
"Requires fusion optimization pass to be effective")
def test_layer_norm_autodiff(self):
def t_wb(shapes: List[int], x, w, b, eps: float, cudnn: bool):
o = torch.layer_norm(x, shapes, w, b, eps, cudnn)
o = torch.relu(o)
return o
def t_w(shapes: List[int], x, w, eps: float, cudnn: bool):
o = torch.layer_norm(x, shapes, w, None, eps, cudnn)
o = torch.relu(o)
return o
def t_b(shapes: List[int], x, b, eps: float, cudnn: bool):
o = torch.layer_norm(x, shapes, None, b, eps, cudnn)
o = torch.relu(o)
return o
def t(shapes: List[int], x, eps: float, cudnn: bool):
o = torch.layer_norm(x, shapes, None, None, eps, cudnn)
o = torch.relu(o)
return o
model = {3: t_wb, 2: t_w, 1: t_b, 0: t}
for w, b in itertools.product([True, False], repeat=2):
batch = [2]
# note: awkward shape here to avoid vectorized fast kernel, which is
# buggy in aten
shapes = [2, 7, 3]
m = model[w * 2 + b]
grad = torch.randn(batch + shapes, dtype=torch.float32, device="cuda")
args = [torch.randn(batch + shapes, dtype=torch.float32, device="cuda").requires_grad_()]
if w:
args.append(torch.randn(shapes, dtype=torch.float32, device="cuda").requires_grad_())
if b:
args.append(torch.randn(shapes, dtype=torch.float32, device="cuda").requires_grad_())
self._layer_norm_autodiff_helper(m, grad, shapes, args)
@unittest.skipIf(is_pre_volta(), "reduction not supported in pre volta device")
@unittest.skipIf(not RUN_NVFUSER, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
"Requires fusion optimization pass to be effective")
def test_layer_norm_parser(self):
dtype = torch.float32
device = "cuda"
x = torch.randn([4, 4, 2], dtype=dtype, device=device)
w = torch.randn([4, 2], dtype=dtype, device=device)
b = torch.randn([4, 2], dtype=dtype, device=device)
def t(x: torch.Tensor, w: torch.Tensor, b: torch.Tensor):
o = torch.relu(x)
o = torch.layer_norm(o, [4, 2], w, b, 1e-5)
return o
o = t(x, w, b)
t_jit = torch.jit.script(t)
jit_o = t_jit(x, w, b)
jit_o = t_jit(x, w, b)
o = t(x, w, b)
self.assertGraphContains(t_jit.graph_for(x, w, b), FUSION_GUARD)
def _native_layer_norm_helper(self, shape, norm_shape, dtype, device, error, affine=True):
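        # Compare scripted vs. eager native_layer_norm (output, mean, rstd)
        # within the given error tolerance and check for a fusion guard.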
class MyLayerNorm(torch.nn.Module):
__constants__ = ['norm_shape']
def __init__(self, elementwise_affine=True):
super(MyLayerNorm, self).__init__()
self.norm_shape = norm_shape
if elementwise_affine:
self.weight = torch.randn(norm_shape, dtype=dtype, device=device)
self.bias = torch.randn(norm_shape, dtype=dtype, device=device)
with torch.no_grad():
self.weight.fill_(1)
self.bias.fill_(0)
else:
self.weight = None
self.bias = None
def forward(self, x: torch.Tensor):
o = torch.relu(x)
o = torch.native_layer_norm(o, self.norm_shape, self.weight, self.bias, 1e-5)
return o
t = MyLayerNorm(affine)
x = torch.randn(shape, dtype=dtype, device=device)
t_jit = torch.jit.script(t)
jit_o, jit_mean, jit_rstd = t_jit(x)
jit_o, jit_mean, jit_rstd = t_jit(x)
o, mean, rstd = t(x)
self.assertEqual(o.dtype, jit_o.dtype)
# numerical issues here due to our scheduling.
# can't use `self.assertEqual(o, jit_o)`
self.assertTrue(self._compare("comparing output failed", o, jit_o, error))
self.assertTrue(self._compare("comparing mean failed", mean, jit_mean, error))
self.assertTrue(self._compare("comparing rstd failed", rstd, jit_rstd, error))
self.assertGraphContains(t_jit.graph_for(x), FUSION_GUARD)
@unittest.skipIf(is_pre_volta(), "reduction not supported in pre volta device")
@unittest.skipIf(not RUN_NVFUSER, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
"Requires fusion optimization pass to be effective")
def test_native_layer_norm(self):
dims = 4
rnds = 3
for idx in range(rnds):
for offset in range(1, dims):
for affine in (True, False):
input_shape = [random.randint(10, 30) for idx in range(dims)]
norm_shape = [input_shape[idx] for idx in range(dims - offset, dims)]
self._native_layer_norm_helper(input_shape, norm_shape, torch.float32, "cuda", 1e-4, affine)
@unittest.skipIf(is_pre_volta(), "reduction not supported in pre volta device")
@unittest.skipIf(not RUN_NVFUSER, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
"Requires fusion optimization pass to be effective")
def test_native_layer_norm_half(self):
dims = 4
rnds = 3
for idx in range(rnds):
for offset in range(1, dims):
input_shape = [random.randint(10, 30) for idx in range(dims)]
norm_shape = [input_shape[idx] for idx in range(dims - offset, dims)]
self._native_layer_norm_helper(input_shape, norm_shape, torch.float16, "cuda", 5e-3)
@unittest.skipIf(is_pre_volta(), "reduction not supported in pre volta device")
@unittest.skipIf(not RUN_NVFUSER, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
"Requires fusion optimization pass to be effective")
@unittest.skipIf(not TEST_BF16, "device does not support BFloat16")
def test_native_layer_norm_bfloat(self):
dims = 4
rnds = 3
for idx in range(rnds):
for offset in range(1, dims):
input_shape = [random.randint(10, 30) for idx in range(dims)]
norm_shape = [input_shape[idx] for idx in range(dims - offset, dims)]
self._native_layer_norm_helper(input_shape, norm_shape, torch.bfloat16, "cuda", 1e-1)
def _norm_helper(self,
shape,
dtype,
device,
error,
is_batch_norm_else_instance_norm,
memory_format=torch.contiguous_format,
*,
layer_dtype=torch.float32):
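        # Compare batch_norm or instance_norm (selected by the flag) between
        # scripted and eager execution, including the updated running stats.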
class MyBatchNorm(torch.nn.Module):
def __init__(self):
super(MyBatchNorm, self).__init__()
def forward(self, x: torch.Tensor, r_mean: torch.Tensor, r_var: torch.Tensor):
o = torch.nn.functional.batch_norm(x, r_mean, r_var, training=True)
o = torch.relu(o)
return o
class MyInstanceNorm(torch.nn.Module):
def __init__(self):
super(MyInstanceNorm, self).__init__()
def forward(self, x: torch.Tensor, r_mean: torch.Tensor, r_var: torch.Tensor):
o = torch.nn.functional.instance_norm(x, r_mean, r_var, use_input_stats=True)
o = torch.relu(o)
return o
t = MyBatchNorm() if is_batch_norm_else_instance_norm else MyInstanceNorm()
x = torch.randn(shape, dtype=dtype, device=device).to(memory_format=memory_format)
running_mean = torch.zeros(shape[1], dtype=layer_dtype, device=device)
running_var = torch.ones(shape[1], dtype=layer_dtype, device=device)
t_jit = torch.jit.script(t)
eager_running_mean = running_mean.clone()
eager_running_var = running_var.clone()
jit_running_mean = running_mean.clone()
jit_running_var = running_var.clone()
jit_o = t_jit(x, running_mean.clone(), running_var.clone())
self.assertTrue(self._compare("prerun comparing running_mean failed", eager_running_mean, jit_running_mean, error))
self.assertTrue(self._compare("prerun comparing running_var failed", eager_running_var, jit_running_var, error))
jit_o = t_jit(x, jit_running_mean, jit_running_var)
o = t(x, eager_running_mean, eager_running_var)
self.assertEqual(o.dtype, jit_o.dtype)
self.assertEqual(o.stride(), jit_o.stride())
# numerical issues here due to our scheduling.
# can't use `self.assertEqual(o, jit_o)`
self.assertTrue(self._compare("comparing output failed", o, jit_o, error))
self.assertTrue(self._compare("comparing running_mean failed", eager_running_mean, jit_running_mean, error))
self.assertTrue(self._compare("comparing running_var failed", eager_running_var, jit_running_var, error))
self.assertGraphContains(t_jit.graph_for(x, running_mean, running_var), FUSION_GUARD)
@unittest.skipIf(is_pre_volta(), "reduction not supported in pre volta device")
@unittest.skipIf(not RUN_NVFUSER, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
"Requires fusion optimization pass to be effective")
def test_layer_norm_trivial_reduce_dim(self):
def t_wb(shapes: List[int], x, w, b, eps: float, cudnn: bool):
o = torch.layer_norm(x, shapes, w, b, eps, cudnn)
o = torch.relu(o)
return o
batch = [1]
shapes = [2, 7, 3]
grad = torch.randn(batch + shapes, dtype=torch.float32, device="cuda")
args = [torch.randn(batch + shapes, dtype=torch.float32, device="cuda").requires_grad_()]
args.append(torch.randn(shapes, dtype=torch.float32, device="cuda").requires_grad_())
args.append(torch.randn(shapes, dtype=torch.float32, device="cuda").requires_grad_())
self._layer_norm_autodiff_helper(t_wb, grad, shapes, args)
@unittest.skipIf(is_pre_volta(), "reduction not supported in pre volta device")
@unittest.skipIf(not RUN_NVFUSER, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
"Requires fusion optimization pass to be effective")
def test_norm_half_layer(self):
size = [2, 4, 2, 2]
for is_batch_norm_else_instance_norm in [False, True]:
for mf in [torch.channels_last, torch.contiguous_format]:
self._norm_helper(size, torch.float16, "cuda", 1e-3, is_batch_norm_else_instance_norm,
memory_format=mf, layer_dtype=torch.float16)
@unittest.skipIf(is_pre_volta(), "reduction not supported in pre volta device")
@unittest.skipIf(not RUN_NVFUSER, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
"Requires fusion optimization pass to be effective")
def test_norm_channels_last(self):
size = [3, 4, 5, 6]
with torch.backends.cudnn.flags(enabled=False):
for is_batch_norm_else_instance_norm in [False, True]:
for mf in [torch.channels_last, torch.contiguous_format]:
self._norm_helper(size, torch.float32, "cuda", 1e-4, is_batch_norm_else_instance_norm, memory_format=mf)
@unittest.skipIf(is_pre_volta(), "reduction not supported in pre volta device")
@unittest.skipIf(not RUN_NVFUSER, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
"Requires fusion optimization pass to be effective")
def test_norm(self):
output_elements = 10000
channel_sizes = [67, 457, 1024, 4096]
with torch.backends.cudnn.flags(enabled=False):
for is_batch_norm_else_instance_norm in [False, True]:
for dims in range(3, 6):
output_size = int(pow(output_elements, 1. / (dims - 1)))
for C in channel_sizes:
x = [output_size for idx in range(dims)]
x[1] = C
self._norm_helper(x, torch.float32, "cuda", 1e-4, is_batch_norm_else_instance_norm)
@unittest.skipIf(is_pre_volta(), "reduction not supported in pre volta device")
@unittest.skipIf(not RUN_NVFUSER, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
"Requires fusion optimization pass to be effective")
def test_norm_large(self):
output_elements = 262144
channel_sizes = 67, 457, 1024
for is_batch_norm_else_instance_norm in [True, False]:
for dims in range(3, 6):
output_size = int(pow(output_elements, 1. / (dims - 1)))
for C in channel_sizes:
x = [output_size for idx in range(dims)]
x[1] = C
self._norm_helper(x, torch.float32, "cuda", 1e-4, is_batch_norm_else_instance_norm)
@unittest.skipIf(is_pre_volta(), "reduction not supported in pre volta device")
@unittest.skipIf(not RUN_NVFUSER, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
"Requires fusion optimization pass to be effective")
def test_norm_half(self):
output_elements = 10000
channel_sizes = [67, 457, 1024, 4096]
with torch.backends.cudnn.flags(enabled=False):
for is_batch_norm_else_instance_norm in [False, True]:
for dims in range(3, 6):
output_size = int(pow(output_elements, 1. / (dims - 1)))
for C in channel_sizes:
x = [output_size for idx in range(dims)]
x[1] = C
self._norm_helper(x, torch.float16, "cuda", 5e-3, is_batch_norm_else_instance_norm)
@unittest.skipIf(is_pre_volta(), "reduction not supported in pre volta device")
@unittest.skipIf(not RUN_NVFUSER, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
"Requires fusion optimization pass to be effective")
@unittest.skipIf(not TEST_BF16, "device does not support BFloat16")
def test_norm_bfloat(self):
output_elements = 10000
channel_sizes = [67, 457, 1024, 4096]
with torch.backends.cudnn.flags(enabled=False):
for is_batch_norm_else_instance_norm in [False, True]:
for dims in range(3, 6):
output_size = int(pow(output_elements, 1. / (dims - 1)))
for C in channel_sizes:
x = [output_size for idx in range(dims)]
x[1] = C
self._norm_helper(x, torch.bfloat16, "cuda", 1e-1, is_batch_norm_else_instance_norm)
def _softmax_helper(self, shape, reduction_axis, is_log_softmax, dtype, device, error):
class MySoftmax(torch.nn.Module):
__constants__ = ['reduction_axis']
def __init__(self):
super(MySoftmax, self).__init__()
self.reduction_axis = reduction_axis
def forward(self, x: torch.Tensor, y: torch.Tensor):
o = torch.add(x, y)
o = torch.nn.functional.softmax(o, dim=self.reduction_axis)
return o
class MyLogSoftmax(torch.nn.Module):
__constants__ = ['reduction_axis']
def __init__(self):
super(MyLogSoftmax, self).__init__()
self.reduction_axis = reduction_axis
def forward(self, x: torch.Tensor, y: torch.Tensor):
o = torch.add(x, y)
o = torch.nn.functional.log_softmax(o, dim=self.reduction_axis)
return o
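# gradcheck needs double-precision inputs for reliable numerical Jacobians,
# so it only runs for the float64 case below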
gradient_check = (dtype == torch.float64)
t = MyLogSoftmax() if is_log_softmax else MySoftmax()
x = torch.randn(shape, dtype=dtype, device=device, requires_grad=gradient_check)
y = torch.randn(shape, dtype=dtype, device=device, requires_grad=gradient_check)
t_jit = torch.jit.script(t)
jit_o = t_jit(x, y)
jit_o = t_jit(x, y)
jit_o = t_jit(x, y)
if gradient_check:
gradcheck(t_jit.forward, [x, y], nondet_tol=1e-5)
else:
o = t(x, y)
self.assertEqual(o.dtype, jit_o.dtype)
# numerical issues here due to our scheduling.
# can't use `self.assertEqual(o, jit_o)`
self.assertTrue(self._compare("comparing output failed", o, jit_o, error))
self.assertGraphContains(t_jit.graph_for(x, y), FUSION_GUARD)
@unittest.skipIf(is_pre_volta(), "reduction not supported in pre volta device")
@unittest.skipIf(not RUN_NVFUSER, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
"Requires fusion optimization pass to be effective")
def test_softmax_dtype(self):
def t(x: torch.Tensor, y: torch.Tensor):
o = torch.mul(x, y)
o = torch.nn.functional.softmax(o, dim=0, dtype=torch.float32)
return o
x = torch.randn([4, 4], dtype=torch.float16, device="cuda").requires_grad_()
y = torch.randn_like(x).requires_grad_()
grad = torch.randn_like(x).float()
ref_x = x.detach().requires_grad_()
ref_y = y.detach().requires_grad_()
o = t(ref_x, ref_y)
o.backward(grad)
t_jit = torch.jit.script(t)
jit_o = t_jit(x, y)
jit_o.backward(grad)
jit_o = t_jit(x, y)
jit_o.backward(grad)
jit_o = t_jit(x, y)
jit_o.backward(grad)
x.grad.zero_()
y.grad.zero_()
jit_o = t_jit(x, y)
jit_o.backward(grad)
self.assertEqual(o.dtype, jit_o.dtype)
self.assertEqual(ref_x.grad, x.grad)
self.assertEqual(ref_y.grad, y.grad)
self.assertTrue(self._compare("comparing output failed", o, jit_o, 1e-3))
self.assertGraphContainsExactly(t_jit.graph_for(x, y), FUSION_GUARD, 1, consider_subgraphs=True)
bwd_graph = list(
list(t_jit.get_debug_state().execution_plans.values())[
0].code.grad_executor_states()[0].execution_plans.values()
)[0].graph
FileCheck().check(FUSION_GUARD).run(bwd_graph)
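# The backward graph above comes from the profiling executor's debug state;
# the same traversal is repeated in several tests below. A minimal sketch of
# a helper capturing that pattern (hypothetical, not used by these tests):
def _bwd_graph_sketch(self, scripted):
# assumes the scripted function has already been run forward and backward,
# so a forward execution plan and a gradient executor state exist
fwd_plan = list(scripted.get_debug_state().execution_plans.values())[0]
grad_state = fwd_plan.code.grad_executor_states()[0]
return list(grad_state.execution_plans.values())[0].graph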
@unittest.skipIf(is_pre_volta(), "reduction not supported in pre volta device")
@unittest.skipIf(not RUN_NVFUSER, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
"Requires fusion optimization pass to be effective")
def test__softmax_function(self):
def t(x: torch.Tensor, y: torch.Tensor):
o = torch.mul(x, y)
o = torch._softmax(o, dim=-1, half_to_float=False)
return o
x = torch.randn([4, 4], dtype=torch.float16, device="cuda")
y = torch.randn_like(x)
o = t(x, y)
t_jit = torch.jit.script(t)
jit_o = t_jit(x, y)
jit_o = t_jit(x, y)
jit_o = t_jit(x, y)
self.assertEqual(o.dtype, jit_o.dtype)
self.assertTrue(self._compare("comparing output failed", o, jit_o, 1e-3))
self.assertGraphContainsExactly(t_jit.graph_for(x, y), FUSION_GUARD, 1, consider_subgraphs=True)
@unittest.skipIf(is_pre_volta(), "reduction not supported in pre volta device")
@unittest.skipIf(not RUN_NVFUSER, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
"Requires fusion optimization pass to be effective")
def test__softmax_function_half_to_float(self):
def t(x: torch.Tensor, y: torch.Tensor):
o = torch.mul(x, y)
o = torch._softmax(o, dim=-1, half_to_float=True)
return o
x = torch.randn([4, 4], dtype=torch.float16, device="cuda")
y = torch.randn_like(x)
o = t(x, y)
t_jit = torch.jit.script(t)
jit_o = t_jit(x, y)
jit_o = t_jit(x, y)
jit_o = t_jit(x, y)
self.assertEqual(o.dtype, jit_o.dtype)
self.assertTrue(self._compare("comparing output failed", o, jit_o, 1e-3))
self.assertGraphContainsExactly(t_jit.graph_for(x, y), FUSION_GUARD, 1, consider_subgraphs=True)
@unittest.skipIf(is_pre_volta(), "reduction not supported in pre volta device")
@unittest.skipIf(not RUN_NVFUSER, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
"Requires fusion optimization pass to be effective")
def test_softmax(self):
output_size = 10000
dims = 4
output_size = int(pow(output_size, 1. / dims))
reduction_sizes = [67, 256, 1024, 4096]
# gradient check
for reduction_dim in range(dims):
for is_log_softmax in [False, True]:
shape = [output_size for idx in range(dims)]
self._softmax_helper(shape, reduction_dim, is_log_softmax, torch.float64, "cuda", 1e-4)
for reduction_dim in range(dims):
for reduction_size in reduction_sizes:
x = [output_size for idx in range(dims)]
x[reduction_dim] = reduction_size
for is_log_softmax in [False, True]:
self._softmax_helper(x, reduction_dim, is_log_softmax, torch.float32, "cuda", 1e-4)
@unittest.skipIf(is_pre_volta(), "reduction not supported in pre volta device")
@unittest.skipIf(not RUN_NVFUSER, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
"Requires fusion optimization pass to be effective")
def test_softmax_half(self):
output_size = 10000
dims = 4
output_size = int(pow(output_size, 1. / dims))
reduction_sizes = [67, 256, 1024, 4096]
for reduction_dim in range(dims):
for reduction_size in reduction_sizes:
x = [output_size for idx in range(dims)]
x[reduction_dim] = reduction_size
for is_log_softmax in [False, True]:
self._softmax_helper(x, reduction_dim, is_log_softmax, torch.float16, "cuda", 5e-3)
@unittest.skipIf(is_pre_volta(), "reduction not supported in pre volta device")
@unittest.skipIf(not RUN_NVFUSER, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
"Requires fusion optimization pass to be effective")
@unittest.skipIf(not TEST_BF16, "device does not support BFloat16")
def test_softmax_bfloat(self):
output_size = 10000
dims = 4
output_size = int(pow(output_size, 1. / dims))
reduction_sizes = [67, 256, 1024, 4096]
for reduction_dim in range(dims):
for reduction_size in reduction_sizes:
x = [output_size for idx in range(dims)]
x[reduction_dim] = reduction_size
for is_log_softmax in [False, True]:
self._softmax_helper(x, reduction_dim, is_log_softmax, torch.bfloat16, "cuda", 1e-1)
@unittest.skipIf(is_pre_volta(), "reduction not supported in pre volta device")
@unittest.skipIf(not RUN_NVFUSER, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
"Requires fusion optimization pass to be effective")
def test_reduction_permutation(self):
x = [7, 8, 12]
# note that num_reduce_dim stops short of len(x), so we never reduce all
# the way down to a single element (a codegen limitation at the moment)
for num_reduce_dim in range(1, len(x)):
for axes in itertools.combinations(range(len(x)), num_reduce_dim):
for perm0 in itertools.permutations(range(len(x))):
for perm1 in itertools.permutations(range(len(x))):
self._reduction_helper(x, axes, torch.float32, "cuda", perm0, perm1)
@unittest.skipIf(is_pre_volta(), "reduction not supported in pre volta device")
@unittest.skipIf(not RUN_NVFUSER, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
"Requires fusion optimization pass to be effective")
def test_reduction_multiple_output(self):
old_guard = torch._C._jit_set_nvfuser_guard_mode(True)
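# raise the bailout depth so the executor can keep re-specializing (the
# inputs switch to channels_last below) rather than falling back early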
torch._C._jit_set_bailout_depth(20)
def t(x: torch.Tensor, y: torch.Tensor, scale: float, z: torch.Tensor):
o = torch.mul(x, y)
o = torch.mul(o, scale)
out1 = torch.mul(o, z)
out2 = torch.sum(out1, dim=[2])
return out1, out2
t_jit = torch.jit.script(t)
x = torch.randn(8, 4, 10, 16, dtype=torch.float, device="cuda")
y = torch.randn(8, 4, 10, 16, dtype=torch.float, device="cuda")
z = torch.randn(8, 4, 10, 16, dtype=torch.float, device="cuda")
scale = 0.5
jit_o = t_jit(x, y, scale, z)
jit_o = t_jit(x, y, scale, z)
o = t(x, y, scale, z)
for oo, jit_oo in zip(o, jit_o):
self.assertEqual(oo.dtype, jit_oo.dtype)
self.assertEqual(oo, jit_oo)
self.assertGraphContains(t_jit.graph_for(x, y, scale, z), FUSION_GUARD)
x = x.to(memory_format=torch.channels_last)
y = y.to(memory_format=torch.channels_last)
z = z.to(memory_format=torch.channels_last)
jit_o = t_jit(x, y, scale, z)
jit_o = t_jit(x, y, scale, z)
o = t(x, y, scale, z)
for oo, jit_oo in zip(o, jit_o):
self.assertEqual(oo.dtype, jit_oo.dtype)
self.assertEqual(oo, jit_oo)
self.assertGraphContains(t_jit.graph_for(x, y, scale, z), FUSION_GUARD)
torch._C._jit_set_nvfuser_guard_mode(old_guard)
@unittest.skipIf(not RUN_NVFUSER, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
"Requires fusion optimization pass to be effective")
def test_channels_last_with_broadcast(self):
# setting this to True forces a new graph to be generated whenever an
# input with a different broadcast shape comes in
torch._C._jit_set_nvfuser_guard_mode(True)
def t(x: torch.Tensor, y: torch.Tensor):
o = torch.mul(x, y)
o = o + 2.0
return o
t_jit = torch.jit.script(t)
# Single Channel broadcasts
# Test 1
x = torch.randn(8, 4, 10, 16, dtype=torch.float, device="cuda")
x = x.to(memory_format=torch.channels_last)
y = torch.randn(8, 4, 10, 1, dtype=torch.float, device="cuda")
y = y.to(memory_format=torch.channels_last)
jit_o = t_jit(x, y)
jit_o = t_jit(x, y)
o = t(x, y)
self.assertEqual(o.dtype, jit_o.dtype)
self.assertEqual(o.is_contiguous(memory_format=torch.channels_last),
jit_o.is_contiguous(memory_format=torch.channels_last))
self.assertEqual(o, jit_o)
# Test 2
y = torch.randn(8, 4, 1, 16, dtype=torch.float, device="cuda")
y = y.to(memory_format=torch.channels_last)
jit_o = t_jit(x, y)
jit_o = t_jit(x, y)
o = t(x, y)
self.assertEqual(o.dtype, jit_o.dtype)
self.assertEqual(o.is_contiguous(memory_format=torch.channels_last),
jit_o.is_contiguous(memory_format=torch.channels_last))
self.assertEqual(o, jit_o)
# Test 3
y = torch.randn(8, 1, 10, 16, dtype=torch.float, device="cuda")
y = y.to(memory_format=torch.channels_last)
jit_o = t_jit(x, y)
jit_o = t_jit(x, y)
o = t(x, y)
self.assertEqual(o.dtype, jit_o.dtype)
self.assertEqual(o.is_contiguous(memory_format=torch.channels_last),
jit_o.is_contiguous(memory_format=torch.channels_last))
self.assertEqual(o, jit_o)
# Test 4
y = torch.randn(1, 4, 10, 16, dtype=torch.float, device="cuda")
y = y.to(memory_format=torch.channels_last)
jit_o = t_jit(x, y)
jit_o = t_jit(x, y)
o = t(x, y)
self.assertEqual(o.dtype, jit_o.dtype)
self.assertEqual(o.is_contiguous(memory_format=torch.channels_last),
jit_o.is_contiguous(memory_format=torch.channels_last))
self.assertEqual(o, jit_o)
'''
Currently the JIT doesn't have the tensor-merge logic needed to add a
tensor that broadcasts in more than one dimension to a non-broadcast
tensor, so either of these tests can fail depending on the sort
implementation. The second test is known to fail.
# Two Channel broadcasts
# Test 1
y = torch.randn(8, 4, 1, 1, dtype=torch.float, device="cuda")
y = y.to(memory_format=torch.channels_last)
jit_o = t_jit(x, y)
jit_o = t_jit(x, y)
o = t(x, y)
self.assertEqual(o.dtype, jit_o.dtype)
self.assertEqual(o.is_contiguous(memory_format=torch.channels_last),
jit_o.is_contiguous(memory_format=torch.channels_last))
self.assertEqual(o, jit_o)
# Test 2
y = torch.randn(8, 4, 1, 1, dtype=torch.float, device="cuda")
y = y.to(memory_format=torch.channels_last).transpose(2,3)
x = x.transpose(2,3)
jit_o = t_jit(x, y)
jit_o = t_jit(x, y)
o = t(x, y)
self.assertEqual(o.dtype, jit_o.dtype)
self.assertEqual(o.is_contiguous(memory_format=torch.channels_last),
jit_o.is_contiguous(memory_format=torch.channels_last))
self.assertEqual(o, jit_o)
'''
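# For reference, channels_last keeps the logical NCHW shape but stores the
# data in NHWC order, so the memory-format assertions above boil down to
# stride checks. A minimal sketch with an assumed shape (not run as a test):
def _channels_last_stride_sketch(self):
# for shape (8, 4, 10, 16) the strides become (C*H*W, 1, W*C, C)
a = torch.randn(8, 4, 10, 16).to(memory_format=torch.channels_last)
assert a.stride() == (640, 1, 64, 4)
assert a.is_contiguous(memory_format=torch.channels_last)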
@unittest.skipIf(is_pre_volta(), "reduction not supported in pre volta device")
@unittest.skipIf(not RUN_NVFUSER, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
"Requires fusion optimization pass to be effective")
def test_pw_single_reduction_partition(self):
sizes = [2, 2, 2]
dtype = torch.float
device = "cuda"
x = torch.randn(sizes, dtype=dtype, device=device)
y = torch.randn(sizes, dtype=dtype, device=device)
z = torch.randn(sizes, dtype=dtype, device=device)
def t(x: torch.Tensor, y: torch.Tensor, z: torch.Tensor):
o = torch.add(x, y)
o = torch.sum(o, dim=[0])
o = torch.add(o, z)
return o
t_jit = torch.jit.script(t)
jit_o = t_jit(x, y, z)
jit_o = t_jit(x, y, z)
o = t(x, y, z)
self.assertEqual(o.dtype, jit_o.dtype)
self.assertEqual(o, jit_o)
self.assertGraphContains(t_jit.graph_for(x, y, z), FUSION_GUARD)
@unittest.skipIf(is_pre_volta(), "reduction not supported in pre volta device")
@unittest.skipIf(not RUN_NVFUSER, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
"Requires fusion optimization pass to be effective")
def test_permutation_preservation(self):
sizes = [2, 3, 4, 5]
dtype = torch.float
device = "cuda"
with nvfuser_singleton_fusion(True):
def t(x: torch.Tensor):
return torch.relu(x)
t_jit = torch.jit.script(t)
x = torch.randn(sizes, dtype=dtype, device=device).to(memory_format=torch.channels_last)
self._run_helper(t_jit, t, x, check_stride=True)
def t(x: torch.Tensor, y: torch.Tensor):
return torch.add(x, y)
t_jit = torch.jit.script(t)
x = torch.randn(sizes, dtype=dtype, device=device).to(memory_format=torch.channels_last)
y = torch.randn(sizes[1:], dtype=dtype, device=device)
self._run_helper(t_jit, t, x, y, check_stride=True)
@unittest.skipIf(not RUN_NVFUSER, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
"Requires fusion optimization pass to be effective")
def test_permutation_preservation_edge_case_0(self):
sizes = [2, 3, 4, 5]
dtype = torch.float
device = "cuda"
x = torch.randn(sizes, dtype=dtype, device=device).to(memory_format=torch.channels_last)
# mismatched rank; note that the PE recognizes a different permutation here
bias = torch.randn(3, dtype=dtype, device=device).unsqueeze(-1).unsqueeze(-1)
def t(x, y):
return x + y
t_jit = torch.jit.script(t)
with nvfuser_singleton_fusion(True):
self._run_helper(t_jit, t, x, bias, check_stride=True)
@unittest.skipIf(not RUN_NVFUSER, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
"Requires fusion optimization pass to be effective")
def test_permutation_preservation_edge_case_1_broken(self):
sizes = [2, 3, 4, 5]
dtype = torch.float
device = "cuda"
x = torch.randn(sizes, dtype=dtype, device=device).to(memory_format=torch.channels_last)
# incompatible permutation; this will cause format propagation to break
bias = torch.randn(4, 5, dtype=dtype, device=device)
def t(x, y):
return x + y
t_jit = torch.jit.script(t)
with nvfuser_singleton_fusion(True):
for _ in range(5):
jit_o = t_jit(x, bias)
o = t(x, bias)
self.assertEqual(o.dtype, jit_o.dtype)
self.assertEqual(o, jit_o)
try:
# nvfuser does not support incompatible permutations; this will throw
self.assertEqual(o.stride(), jit_o.stride())
except Exception:
warnings.warn(
"permutation propagation is broken; proper support should come after the nvfuser permutation scheduler update")
self.assertGraphContains(t_jit.graph_for(x, bias), FUSION_GUARD)
@unittest.skipIf(not RUN_NVFUSER, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
"Requires fusion optimization pass to be effective")
def test_permutation_preservation_edge_case_2(self):
sizes = [2, 3, 4, 5]
dtype = torch.float
device = "cuda"
x = torch.randn(sizes, dtype=dtype, device=device).to(memory_format=torch.channels_last)
y = torch.randn(sizes, dtype=dtype, device=device).to(memory_format=torch.channels_last)
z = torch.randn(sizes, dtype=dtype, device=device).to(memory_format=torch.channels_last)
def t(x, y, w):
tmp = torch.lerp(x, y, w)
tmp = torch.clamp(tmp, -1.0, 0.5)
tmp = torch.nn.functional.softplus(tmp)
return torch.threshold(tmp, -2.0, 0.5)
t_jit = torch.jit.script(t)
self._run_helper(t_jit, t, x, y, z, check_stride=True)
@unittest.skipIf(is_pre_volta(), "reduction not supported in pre volta device")
@unittest.skipIf(not RUN_NVFUSER, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
"Requires fusion optimization pass to be effective")
def test_normalization_partition(self):
sizes = [3, 8, 5]
dtype = torch.float
device = "cuda"
x = torch.randn(sizes, dtype=dtype, device=device)
y = torch.randn(sizes, dtype=dtype, device=device)
z = torch.randn(sizes, dtype=dtype, device=device)
r_m = torch.randn(8, dtype=dtype, device=device)
r_v = torch.randn(8, dtype=dtype, device=device)
def t(x: torch.Tensor, y: torch.Tensor, z: torch.Tensor, r_mean: torch.Tensor, r_var: torch.Tensor):
o = torch.add(x, y)
o = torch.nn.functional.softmax(o, dim=0)
o = torch.add(o, z)
o = torch.nn.functional.batch_norm(o, r_mean, r_var, training=True)
return o
t_jit = torch.jit.script(t)
jit_o = t_jit(x, y, z, r_m, r_v)
jit_o = t_jit(x, y, z, r_m, r_v)
o = t(x, y, z, r_m, r_v)
self.assertEqual(o.dtype, jit_o.dtype)
self.assertEqual(o, jit_o)
self.assertGraphContains(t_jit.graph_for(x, y, z, r_m, r_v), FUSION_GUARD)
@unittest.skipIf(is_pre_volta(), "reduction not supported in pre volta device")
@unittest.skipIf(not RUN_NVFUSER, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
"Requires fusion optimization pass to be effective")
def test_sum_to_one(self):
dtype = torch.float
device = "cuda"
x = torch.randn([4, 5, 6], dtype=dtype, device=device)
def t(x: torch.Tensor):
o = torch.add(x, 1)
o = torch.sum(o, dim=[0, 1, 2])
return o
t_jit = torch.jit.script(t)
jit_o = t_jit(x)
jit_o = t_jit(x)
o = t(x)
self.assertEqual(o.dtype, jit_o.dtype)
self.assertEqual(o, jit_o)
self.assertGraphContains(t_jit.graph_for(x), FUSION_GUARD)
@unittest.skipIf(is_pre_volta(), "reduction not supported in pre volta device")
@unittest.skipIf(not RUN_NVFUSER, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
"Requires fusion optimization pass to be effective")
def test_single_reduction_broadcast(self):
dtype = torch.float
device = "cuda"
x = torch.randn([7, 4, 8], dtype=dtype, device=device)
y = torch.randn([4, 8], dtype=dtype, device=device)
z = torch.randn([1, 4, 8], dtype=dtype, device=device)
def t(x: torch.Tensor, y: torch.Tensor, z: torch.Tensor):
o = torch.add(x, y)
o = torch.add(o, z)
o = torch.sum(o, dim=[0])
return o
t_jit = torch.jit.script(t)
jit_o = t_jit(x, y, z)
jit_o = t_jit(x, y, z)
o = t(x, y, z)
self.assertEqual(o.dtype, jit_o.dtype)
self.assertEqual(o, jit_o)
self.assertGraphContains(t_jit.graph_for(x, y, z), FUSION_GUARD)
@unittest.skipIf(is_pre_volta(), "reduction not supported in pre volta device")
@unittest.skipIf(not RUN_NVFUSER, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
"Requires fusion optimization pass to be effective")
def test_trivial_reduction(self):
dtype = torch.float
device = "cuda"
x = torch.randn([1, 4, 8], dtype=dtype, device=device)
def t(x: torch.Tensor):
o = torch.add(x, 1)
o = torch.sum(o, dim=[0])
o = torch.sum(o, dim=[0])
return o
t_jit = torch.jit.script(t)
jit_o = t_jit(x)
jit_o = t_jit(x)
o = t(x)
self.assertEqual(o.dtype, jit_o.dtype)
self.assertEqual(o, jit_o)
self.assertGraphContains(t_jit.graph_for(x), FUSION_GUARD)
@unittest.skipIf(not RUN_NVFUSER, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
"Requires fusion optimization pass to be effective")
def test_profiling_node(self):
dtype = torch.float
device = "cuda"
x = torch.randn(4, 8, 8, 8, dtype=dtype, device=device)
def repro(x: torch.Tensor, alpha: float):
o = torch.rand_like(x)
o = torch.add(o, alpha)
return o
repro_jit = torch.jit.script(repro)
self._run_helper(repro_jit, repro, x, 0.6)
@unittest.skipIf(not RUN_NVFUSER, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
"Requires fusion optimization pass to be effective")
def test_rand_like(self):
dtype = torch.float
device = "cuda"
def t(x: torch.Tensor, alpha: float):
o = torch.rand_like(x)
o = torch.add(o, alpha)
return o
# disable function caching so that new inputs generate a new graph
t.__disable_jit_function_caching__ = True
for m_format in [torch.contiguous_format, torch.channels_last]:
x = torch.randn(4, 5, 6, 7, dtype=dtype, device=device).to(memory_format=m_format)
t_jit = torch.jit.script(t)
self._run_helper(t_jit, t, x, 0.6, check_stride=True)
@unittest.skipIf(is_pre_volta(), "reduction not supported in pre volta device")
@unittest.skipIf(not RUN_NVFUSER, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
"Requires fusion optimization pass to be effective")
def test_reduction_sizes_op(self):
dtype = torch.float
device = "cuda"
x = torch.randn(2, 3, 4, 5, dtype=dtype, device=device)
y = torch.randn(2, 3, 4, 5, dtype=dtype, device=device)
def t(x: torch.Tensor, y: torch.Tensor):
o = x + y
o = torch.relu(o)
o = o.sum((1, 3))
return o.size()
t_jit = torch.jit.script(t)
jit_o = t_jit(x, y)
jit_o = t_jit(x, y)
o = t(x, y)
self.assertEqual(o, jit_o)
# since the output value is not used at all, the fusion operator should
# have been optimized away
self.assertGraphContainsExactly(t_jit.graph_for(x, y), FUSION_GUARD, 0)
@unittest.skipIf(is_pre_volta(), "reduction not supported in pre volta device")
@unittest.skipIf(not RUN_NVFUSER, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
"Requires fusion optimization pass to be effective")
def test_profile_ivalue(self):
dtype = torch.float
device = "cuda"
x = torch.randn([7, 4, 7], dtype=dtype, device=device)
y = torch.randn([7, 4, 7], dtype=dtype, device=device)
def t(x: torch.Tensor, y: torch.Tensor, dim: List[int], keepdim: bool):
o = torch.add(x, y)
o = o.sum(dim, keepdim=keepdim)
return o
t_jit = torch.jit.script(t)
jit_o = t_jit(x, y, (0, 1), False)
jit_o = t_jit(x, y, (0, 1), False)
o = t(x, y, (0, 1), False)
self.assertEqual(o.dtype, jit_o.dtype)
self.assertEqual(o, jit_o)
self.assertGraphContains(t_jit.graph_for(x, y, (0, 1), False), FUSION_GUARD)
@unittest.skipIf(is_pre_volta(), "reduction not supported in pre volta device")
@unittest.skipIf(not RUN_NVFUSER, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
"Requires fusion optimization pass to be effective")
def test_profile_ivalue_multiple_profiles(self):
dtype = torch.float
device = "cuda"
x = torch.randn([7, 4, 7], dtype=dtype, device=device)
def t(x, num: int):
for i in range(num):
# varying reduction axes should break profile_ivalue
tmp = x.sum(i, keepdim=True)
# in-place add on the input/output; it can't be functionalized or fused
x += tmp
return x
with nvfuser_singleton_fusion(True):
t_jit = torch.jit.script(t)
self._run_helper(t_jit, t, x, 3, num_fusion=0)
@unittest.skipIf(is_pre_volta(), "reduction not supported in pre volta device")
@unittest.skipIf(not RUN_NVFUSER, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
"Requires fusion optimization pass to be effective")
def test_sum_to_size(self):
dtype = torch.float
device = "cuda"
x = torch.randn([2, 4, 4], dtype=dtype, device=device)
y = torch.randn([2, 4, 4], dtype=dtype, device=device)
def t(x: torch.Tensor, y: torch.Tensor, new_size: List[int]):
o = torch.add(x, y)
o = o.sum_to_size(new_size)
return o
t_jit = torch.jit.script(t)
self._run_helper(t_jit, t, x, y, (4, 1))
# update the shape: the old kernel should handle the dynamic shape without
# recompilation
x = torch.randn([2, 5, 8], dtype=dtype, device=device)
y = torch.randn([2, 5, 8], dtype=dtype, device=device)
# (TODO) check executed kernel, should extend autograd.profiler to fused
# kernels
self._run_helper(t_jit, t, x, y, (5, 1))
with nvfuser_singleton_fusion(True):
x = torch.randn([2, 5, 8], dtype=dtype, device=device)
def t(x: torch.Tensor):
# no-op reduction
return x.sum_to_size((2, 5, 8))
t_jit = torch.jit.script(t)
self._run_helper(t_jit, t, x)
@unittest.skipIf(is_pre_volta(), "reduction not supported in pre volta device")
@unittest.skipIf(not RUN_NVFUSER, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
"Requires fusion optimization pass to be effective")
def test_grad_sum_to_size(self):
dtype = torch.float
device = "cuda"
x = torch.randn([2, 4, 4], dtype=dtype, device=device).requires_grad_()
y = torch.randn([4], dtype=dtype, device=device).requires_grad_()
grad = torch.randn([2, 4, 4], dtype=dtype, device=device)
ref_x = x.detach().clone().requires_grad_()
ref_y = y.detach().clone().requires_grad_()
def t(x: torch.Tensor, y: torch.Tensor):
o = torch.add(x, y)
o = torch.relu(o)
return o
# profiling runs for forward & backward
t_jit = torch.jit.script(t)
jit_o = t_jit(x, y)
jit_o.backward(grad)
jit_o = t_jit(x, y)
jit_o.backward(grad)
x.grad = None
y.grad = None
jit_o = t_jit(x, y)
jit_o.backward(grad)
o = t(ref_x, ref_y)
o.backward(grad)
self.assertEqual(o.dtype, jit_o.dtype)
self.assertEqual(o, jit_o)
self.assertEqual(x.grad, ref_x.grad)
self.assertEqual(y.grad, ref_y.grad)
bwd_graph = list(
list(t_jit.get_debug_state().execution_plans.values())[
0].code.grad_executor_states()[0].execution_plans.values()
)[0].graph
FileCheck().check(FUSION_GUARD).run(bwd_graph)
# update the shape: the old kernel should handle the dynamic shape without
# recompilation
x = torch.randn([2, 5, 8], dtype=dtype, device=device).requires_grad_()
y = torch.randn([8], dtype=dtype, device=device).requires_grad_()
ref_x = x.detach().clone().requires_grad_()
ref_y = y.detach().clone().requires_grad_()
grad = torch.randn([2, 5, 8], dtype=dtype, device=device)
jit_o = t_jit(x, y)
# (TODO) check executed kernel, should extend autograd.profiler to fused
# kernels
jit_o.backward(grad)
o = t(ref_x, ref_y)
o.backward(grad)
self.assertEqual(o.dtype, jit_o.dtype)
self.assertEqual(o, jit_o)
self.assertEqual(x.grad, ref_x.grad)
self.assertEqual(y.grad, ref_y.grad)
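# sum_to_size is what conceptually maps the [2, 4, 4] gradient back onto the
# broadcast [4]-shaped input in the backward pass above. A minimal sketch of
# its semantics (illustrative only, not run as a test):
def _sum_to_size_sketch(self):
g = torch.ones(2, 4, 4)
r = g.sum_to_size((4,))
# the broadcast dimensions are summed away: 2 * 4 ones collapse per entry
assert r.shape == (4,)
assert (r == 8.0).all()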
@unittest.skipIf(not RUN_NVFUSER, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
"Requires fusion optimization pass to be effective")
def test_dropout_inference_fusion(self):
dtype = torch.float
device = "cuda"
x = torch.randn([10, 4, 8], dtype=dtype, device=device)
def t(x: torch.Tensor, p: float, train: bool):
o = torch.nn.functional.dropout(x, p, training=train)
o = o + 1.0
return o
t_jit = torch.jit.script(t)
self._run_helper(t_jit, t, x, 0.15, False)
@unittest.skipIf(not TEST_LARGE_TENSOR, "not enough memory")
@unittest.skipIf(not RUN_NVFUSER, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
"Requires fusion optimization pass to be effective")
def test_dropout_train_nograd_fusion(self):
dtype = torch.float
device = "cuda"
x = torch.randn([64, 128, 1024], dtype=dtype, device=device)
def t(x: torch.Tensor, p: float, train: bool):
o = torch.nn.functional.dropout(x, p, training=train)
o = o + 1.0
return o
t_jit = torch.jit.script(t)
self._run_helper(t_jit, t, x, 0.0, True, check_runs=20)
self._run_helper(t_jit, t, x, 1.0, True, check_runs=20)
@unittest.skipIf(not RUN_NVFUSER, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
"Requires fusion optimization pass to be effective")
def test_dropout_train_nograd_prob_check(self):
dtype = torch.float
device = "cuda"
x = torch.randn([1024, 1024], dtype=dtype, device=device)
def t(x: torch.Tensor, p: float, train: bool):
o = torch.nn.functional.dropout(x, p, training=train)
o = o * 2.0
return o
t_jit = torch.jit.script(t)
for prob in [0.0, 0.15, 0.5, 0.85, 1.]:
torch.cuda.manual_seed_all(123)
jit_o = t_jit(x, prob, True)
torch.cuda.manual_seed_all(123)
jit_o = t_jit(x, prob, True)
self.assertTrue(jit_o.detach().isfinite().all().item())
num_elems = x.numel()
num_zeros = num_elems - jit_o.detach().count_nonzero().item()
percent_zeros = num_zeros / num_elems
self.assertTrue((percent_zeros >= (prob - 0.01)) and (percent_zeros <= (prob + 0.01)))
self.assertGraphContainsExactly(t_jit.graph_for(x, prob, True), FUSION_GUARD, 1, consider_subgraphs=True)
@unittest.skipIf(is_pre_volta(), "reduction not supported in pre volta device")
@unittest.skipIf(not RUN_NVFUSER, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
"Requires fusion optimization pass to be effective")
def test_dropout_training_fusion(self):
dtype = torch.float
device = "cuda"
sizes = [2, 3, 4, 5]
def t(x: torch.Tensor, p: float, train: bool):
o = torch.nn.functional.dropout(x, p, training=train)
o = o * 2.0
return o
def t2(x: torch.Tensor, p: float, train: bool):
o = torch.nn.functional.softmax(x, dim=-1)
o = torch.nn.functional.dropout(o, p, training=train)
return o
# disable function caching so that new inputs generate a new graph
t.__disable_jit_function_caching__ = True
t2.__disable_jit_function_caching__ = True
for fn in [t, t2]:
for m_format in [torch.contiguous_format, torch.channels_last]:
fn_jit = torch.jit.script(fn)
x = torch.randn(sizes, dtype=dtype, device=device, requires_grad=True).to(memory_format=m_format)
grads = torch.randn(sizes, dtype=dtype, device=device).to(memory_format=m_format)
# The dropout probability has to be zero here because eager mode and the
# JIT draw random numbers in a different order
self._run_training_helper(fn_jit, fn, grads, x, 0.0, True)
@unittest.skipIf(not RUN_NVFUSER, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
"Requires fusion optimization pass to be effective")
def test_gelu(self):
old_guard = torch._C._jit_set_nvfuser_guard_mode(True)
dtype = torch.float
device = "cuda"
x = torch.randn([1024, 1024], dtype=dtype, device=device, requires_grad=True)
grads = torch.randn([1024, 1024], dtype=dtype, device=device, requires_grad=False)
def t(x: torch.Tensor, mode: str):
o = torch.nn.functional.gelu(x, approximate=mode)
o = o * 2.0
return o
t_jit = torch.jit.script(t)
self._run_training_helper(t_jit, t, grads, x, 'none')
self._run_training_helper(t_jit, t, grads, x, 'tanh')
torch._C._jit_set_nvfuser_guard_mode(old_guard)
@unittest.skipIf(not RUN_NVFUSER, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
"Requires fusion optimization pass to be effective")
def test_dropout_training_prob_check(self):
dtype = torch.float
device = "cuda"
x = torch.randn([1024, 1024], dtype=dtype, device=device, requires_grad=True)
x_nograd = torch.randn([1024, 1024], dtype=dtype, device=device)
def t(x: torch.Tensor, p: float, train: bool):
o = torch.nn.functional.dropout(x, p, training=train)
o = o * 2.0
return o
t_jit = torch.jit.script(t)
for prob in [0.0, 0.15, 0.5, 0.85, 1.]:
torch.cuda.manual_seed_all(123)
jit_o = t_jit(x, prob, True)
torch.cuda.manual_seed_all(123)
jit_o = t_jit(x, prob, True)
torch.cuda.manual_seed_all(123)
jit_o = t_jit(x, prob, True)
self.assertTrue(jit_o.detach().isfinite().all().item())
num_elems = x.numel()
num_zeros = num_elems - jit_o.detach().count_nonzero().item()
percent_zeros = num_zeros / num_elems
self.assertTrue((percent_zeros >= (prob - 0.01)) and (percent_zeros <= (prob + 0.01)))
self.assertGraphContainsExactly(t_jit.graph_for(x, prob, True), FUSION_GUARD, 1, consider_subgraphs=True)
@unittest.skipIf(not RUN_NVFUSER, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
"Requires fusion optimization pass to be effective")
def test_linear(self):
in_feature = 2
out_feature = 8
# Changing the input dims to be 3-D to avoid eager mode bias fusion
# The bias fusion causes some precision issues with TF-32
weight = torch.randn(out_feature, in_feature, dtype=torch.float32, device='cuda')
bias = torch.randn(out_feature, dtype=torch.float32, device='cuda')
def t(x: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor):
o = torch.nn.functional.linear(x, weight, bias)
o = torch.relu(o)
return o
# disable function caching so that new inputs generate a new graph
t.__disable_jit_function_caching__ = True
sizes = [in_feature, ]
for i in range(4):
# increase input rank in each iteration
sizes.insert(0, i + 2)
x = torch.randn(*sizes, dtype=torch.float32, device='cuda')
t_jit = torch.jit.script(t)
# fusion only happens for input rank >= 4
has_fusion = 0 if len(sizes) < 4 else 1
self._run_helper(t_jit, t, x, weight, bias, check_stride=True, num_fusion=has_fusion)
@unittest.skipIf(not RUN_NVFUSER, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
"Requires fusion optimization pass to be effective")
def test_linear_symbolic_shapes(self):
def fn(x: int):
y = torch.zeros((3, 4, x, x + 2)).cuda()
for i in range(2):
inp = torch.rand((3, 4, x, x + i)).cuda()
weight = torch.rand((x + 2, x + i)).cuda()
bias = torch.rand((x, x + 2)).cuda()
y += torch.sin(torch.nn.functional.linear(inp, weight, bias))
return y
fn_s = torch.jit.script(fn)
fn_s(5)
fn_s(5)
@unittest.skipIf(not RUN_NVFUSER, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
"Requires fusion optimization pass to be effective")
def test_conv2d_symbolic_shapes(self):
def fn(x: int):
responses = []
for i in range(2):
inp = torch.rand((3, 3, 32, 32)).cuda()
weight = torch.rand((x + i, 3, 7, 7)).cuda()
bias = torch.rand((x + i)).cuda()
res = torch.nn.functional.conv2d(inp, weight, bias, padding=3)
responses.append(res)
return responses
fn_s = torch.jit.script(fn)
fn_s(5)
fn_s(5)
@unittest.skipIf(not RUN_NVFUSER, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
"Requires fusion optimization pass to be effective")
def test_backward_type(self):
# not super useful to check gradient of integer/bool, so skipping here
type_pairs = [
(torch.float, torch.half),
(torch.double, torch.half),
(torch.float, torch.double),
]
if TEST_BF16:
type_pairs += [
(torch.float, torch.bfloat16),
(torch.double, torch.bfloat16),
]
for x_type, y_type in type_pairs:
x = torch.randn(4, 2, dtype=x_type, device='cuda', requires_grad=True)
y = torch.randn(4, 2, dtype=y_type, device='cuda', requires_grad=True)
grad = torch.randn(4, 2, dtype=torch.float, device='cuda')
def test1(x: torch.Tensor, y: torch.Tensor):
o = torch.add(x, y)
o = torch.add(o, y)
o = torch.add(o, y)
o = torch.add(o, y)
o = o + 1.0
return o
test1_jit = torch.jit.script(test1)
for i in range(3):
jit_o = test1_jit(x, y)
jit_o.backward(grad)
bwd_graph = list(
list(test1_jit.get_debug_state().execution_plans.values())[
0].code.grad_executor_states()[0].execution_plans.values()
)[0].graph
FileCheck().check(FUSION_GROUP).run(bwd_graph)
self.assertEqual(x.grad.dtype, x.dtype)
self.assertEqual(y.grad.dtype, y.dtype)
@unittest.skipIf(is_pre_volta(), "reduction not supported in pre volta device")
@unittest.skipIf(not RUN_NVFUSER, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
"Requires fusion optimization pass to be effective")
def test_autocast_1(self):
def t(x: torch.Tensor, y: torch.Tensor):
o = x * 2.0
o = torch.softmax(o, dim=-1)
o = o * 3.0
o = torch._C._nn.linear(o, y)
return o
x = torch.randn(8, 4, dtype=torch.half, device='cuda', requires_grad=True)
y = torch.randn(4, 4, dtype=torch.float, device='cuda', requires_grad=True)
grad = torch.randn(8, 4, dtype=torch.half, device='cuda', requires_grad=False)
t_jit = torch.jit.script(t)
for i in range(3):
with torch.cuda.amp.autocast():
jit_o = t_jit(x, y)
if i == 2:
fwd_graph = t_jit.graph_for(x, y)
jit_o.backward(grad)
self.assertGraphContainsExactly(fwd_graph, FUSION_GUARD, 1, consider_subgraphs=True)
with torch.cuda.amp.autocast():
bwd_graph = list(
list(t_jit.get_debug_state().execution_plans.values())[
0].code.grad_executor_states()[0].execution_plans.values()
)[0].graph
FileCheck().check(FUSION_GROUP).run(bwd_graph)
self.assertEqual(jit_o.dtype, torch.half)
self.assertEqual(x.grad.dtype, x.dtype)
self.assertEqual(y.grad.dtype, y.dtype)
@unittest.skipIf(is_pre_volta(), "reduction not supported in pre volta device")
@unittest.skipIf(not RUN_NVFUSER, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
"Requires fusion optimization pass to be effective")
def test_autocast_2(self):
def t(x: torch.Tensor):
o = x * 2.0
o = torch.softmax(o, dim=-1)
o = o * 3.0
o = torch.softmax(o, dim=-1)
o = o * 4.0
return o
x = torch.randn(8, 4, dtype=torch.half, device='cuda', requires_grad=True)
grad = torch.randn(8, 4, dtype=torch.float, device='cuda', requires_grad=False)
t_jit = torch.jit.script(t)
for i in range(3):
with torch.cuda.amp.autocast():
jit_o = t_jit(x)
if i == 2:
fwd_graph = t_jit.graph_for(x)
jit_o.backward(grad)
self.assertGraphContainsExactly(fwd_graph, FUSION_GUARD, 1, consider_subgraphs=True)
with torch.cuda.amp.autocast():
bwd_graph = list(
list(t_jit.get_debug_state().execution_plans.values())[
0].code.grad_executor_states()[0].execution_plans.values()
)[0].graph
FileCheck().check(FUSION_GROUP).run(bwd_graph)
self.assertEqual(jit_o.dtype, torch.float)
self.assertEqual(x.grad.dtype, x.dtype)
@unittest.skipIf(is_pre_volta(), "reduction not supported in pre volta device")
@unittest.skipIf(not RUN_NVFUSER, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
"Requires fusion optimization pass to be effective")
@unittest.skipIf(not TEST_BF16, "device does not support BFloat16")
def test_autocast_1_bfloat(self):
def t(x: torch.Tensor, y: torch.Tensor):
o = x * 2.0
o = torch.softmax(o, dim=-1)
o = o * 3.0
o = torch._C._nn.linear(o, y)
return o
x = torch.randn(8, 4, dtype=torch.bfloat16, device='cuda', requires_grad=True)
y = torch.randn(4, 4, dtype=torch.float, device='cuda', requires_grad=True)
grad = torch.randn(8, 4, dtype=torch.bfloat16, device='cuda', requires_grad=False)
t_jit = torch.jit.script(t)
for i in range(3):
with torch.cuda.amp.autocast(dtype=torch.bfloat16):
jit_o = t_jit(x, y)
if i == 2:
fwd_graph = t_jit.graph_for(x, y)
jit_o.backward(grad)
self.assertGraphContainsExactly(fwd_graph, FUSION_GUARD, 1, consider_subgraphs=True)
with torch.cuda.amp.autocast(dtype=torch.bfloat16):
bwd_graph = list(
list(t_jit.get_debug_state().execution_plans.values())[
0].code.grad_executor_states()[0].execution_plans.values()
)[0].graph
FileCheck().check(FUSION_GROUP).run(bwd_graph)
self.assertEqual(jit_o.dtype, torch.bfloat16)
self.assertEqual(x.grad.dtype, x.dtype)
self.assertEqual(y.grad.dtype, y.dtype)
@unittest.skipIf(is_pre_volta(), "reduction not supported in pre volta device")
@unittest.skipIf(not RUN_NVFUSER, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
"Requires fusion optimization pass to be effective")
@unittest.skipIf(not TEST_BF16, "device does not support BFloat16")
def test_autocast_2_bfloat(self):
def t(x: torch.Tensor):
o = x * 2.0
o = torch.softmax(o, dim=-1)
o = o * 3.0
o = torch.softmax(o, dim=-1)
o = o * 4.0
return o
x = torch.randn(8, 4, dtype=torch.bfloat16, device='cuda', requires_grad=True)
grad = torch.randn(8, 4, dtype=torch.float, device='cuda', requires_grad=False)
t_jit = torch.jit.script(t)
for i in range(3):
with torch.cuda.amp.autocast(dtype=torch.bfloat16):
jit_o = t_jit(x)
if i == 2:
fwd_graph = t_jit.graph_for(x)
jit_o.backward(grad)
self.assertGraphContainsExactly(fwd_graph, FUSION_GUARD, 1, consider_subgraphs=True)
with torch.cuda.amp.autocast(dtype=torch.bfloat16):
bwd_graph = list(
list(t_jit.get_debug_state().execution_plans.values())[
0].code.grad_executor_states()[0].execution_plans.values()
)[0].graph
FileCheck().check(FUSION_GROUP).run(bwd_graph)
self.assertEqual(jit_o.dtype, torch.float)
self.assertEqual(x.grad.dtype, x.dtype)
@unittest.skipIf(not RUN_NVFUSER, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
"Requires fusion optimization pass to be effective")
def test_to_dtype_fp32_to_fp16(self):
def t(x: torch.Tensor):
o = x * 2.0
o = o.to(dtype=torch.half)
o = o * 3.0
return o
x = torch.randn(8, 4, dtype=torch.float, device='cuda')
t_jit = torch.jit.script(t)
for i in range(3):
jit_o = t_jit(x)
self.assertGraphContainsExactly(t_jit.graph_for(x), FUSION_GUARD, 1)
self.assertEqual(jit_o.dtype, torch.half)
@unittest.skipIf(not RUN_NVFUSER, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
"Requires fusion optimization pass to be effective")
def test_to_dtype_fp16_to_fp32(self):
def t(x: torch.Tensor):
o = x * 2.0
o = o.to(dtype=torch.float)
o = o * 3.0
return o
x = torch.randn(8, 4, dtype=torch.half, device='cuda')
t_jit = torch.jit.script(t)
for i in range(3):
jit_o = t_jit(x)
self.assertGraphContainsExactly(t_jit.graph_for(x), FUSION_GUARD, 1)
self.assertEqual(jit_o.dtype, torch.float)
@unittest.skipIf(not RUN_NVFUSER, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
"Requires fusion optimization pass to be effective")
def test_to_dtype_fp16_to_fp16(self):
def t(x: torch.Tensor):
o = x * 2.0
o = o.to(dtype=torch.half)
o = o * 3.0
return o
x = torch.randn(8, 4, dtype=torch.half, device='cuda')
t_jit = torch.jit.script(t)
for i in range(3):
jit_o = t_jit(x)
self.assertGraphContainsExactly(t_jit.graph_for(x), FUSION_GUARD, 1)
self.assertEqual(jit_o.dtype, torch.half)
@unittest.skipIf(not RUN_NVFUSER, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
"Requires fusion optimization pass to be effective")
@unittest.skipIf(not TEST_BF16, "device does not support BFloat16")
def test_to_dtype_fp32_to_bf16(self):
def t(x: torch.Tensor):
o = x * 2.0
o = o.to(dtype=torch.bfloat16)
o = o * 3.0
return o
x = torch.randn(8, 4, dtype=torch.float, device='cuda')
t_jit = torch.jit.script(t)
for i in range(3):
jit_o = t_jit(x)
self.assertGraphContainsExactly(t_jit.graph_for(x), FUSION_GUARD, 1)
self.assertEqual(jit_o.dtype, torch.bfloat16)
@unittest.skipIf(not RUN_NVFUSER, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
"Requires fusion optimization pass to be effective")
@unittest.skipIf(not TEST_BF16, "device does not support BFloat16")
def test_to_dtype_bf16_to_fp32(self):
def t(x: torch.Tensor):
o = x * 2.0
o = o.to(dtype=torch.float)
o = o * 3.0
return o
x = torch.randn(8, 4, dtype=torch.bfloat16, device='cuda')
t_jit = torch.jit.script(t)
for i in range(3):
jit_o = t_jit(x)
self.assertGraphContainsExactly(t_jit.graph_for(x), FUSION_GUARD, 1)
self.assertEqual(jit_o.dtype, torch.float)
@unittest.skipIf(not RUN_NVFUSER, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
"Requires fusion optimization pass to be effective")
@unittest.skipIf(not TEST_BF16, "device does not support BFloat16")
def test_to_dtype_bf16_to_bf16(self):
def t(x: torch.Tensor):
o = x * 2.0
o = o.to(dtype=torch.bfloat16)
o = o * 3.0
return o
x = torch.randn(8, 4, dtype=torch.bfloat16, device='cuda')
t_jit = torch.jit.script(t)
for i in range(3):
jit_o = t_jit(x)
self.assertGraphContainsExactly(t_jit.graph_for(x), FUSION_GUARD, 1)
self.assertEqual(jit_o.dtype, torch.bfloat16)
@unittest.skipIf(not RUN_NVFUSER, "requires CUDA")
@unittest.skipIf(not TEST_MULTIGPU, "requires multiple CUDA device")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
"Requires fusion optimization pass to be effective")
def test_multiple_device_pw(self):
def t(x):
o = x + 1.0
o = torch.relu(o)
return o
x = torch.randn(2, dtype=torch.float32, device="cuda")
t_jit = torch.jit.script(t)
for i in range(3):
jit_o = t_jit(x)
self.assertGraphContainsExactly(t_jit.graph_for(x), FUSION_GUARD, 1)
# torch.cuda.device(1) is only a context manager and does nothing when called
# bare; use set_device to actually switch to the second device
torch.cuda.set_device(1)
x = x.to("cuda:1")
jit_o = t_jit(x)
@unittest.skipIf(not RUN_NVFUSER, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
"Requires fusion optimization pass to be effective")
def test_graph_for_with_missing_optimized_engine(self):
x = torch.randn(8, 4, 2, dtype=torch.float, device="cuda").requires_grad_()
def t(x: torch.Tensor, flag: bool):
x = x + 1.0
x = torch.relu(x)
if flag:
o = x + 1.0
o = torch.relu(o)
else:
o = x + 2.0
o = torch.relu(o)
return o
t_jit = torch.jit.script(t)
jit_o = t_jit(x, False)
jit_o = t_jit(x, False)
jit_o = t_jit(x, True)
o = t(x, True)
self.assertEqual(o, jit_o)
# the optimized graph is still expected to contain exactly one fusion guard
# (counting subgraphs)
self.assertGraphContainsExactly(t_jit.graph_for(x, True), FUSION_GUARD, 1, True)
@unittest.skipIf(not RUN_NVFUSER, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
"Requires fusion optimization pass to be effective")
def test_branches(self):
in_feature = 2
out_feature = 4
x = torch.randn(4, in_feature, dtype=torch.float32, device='cuda')
weight = torch.randn(out_feature, in_feature, dtype=torch.float32, device='cuda')
bias = torch.randn(out_feature, dtype=torch.float32, device='cuda')
def t(x: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor, flag: bool):
if flag:
o = torch.nn.functional.linear(x, weight, bias)
o = o + 1.0
o = torch.relu(o)
else:
o = x.sum()
o = o + 2.0
o = torch.relu(o)
return o
t_jit = torch.jit.script(t)
jit_o = t_jit(x, weight, bias, True)
jit_o = t_jit(x, weight, bias, True)
o = t(x, weight, bias, True)
self.assertEqual(o, jit_o)
# the executed (flag=True) branch should contain exactly one fusion guard
self.assertGraphContainsExactly(t_jit.graph_for(x, weight, bias, True), FUSION_GUARD, 1)
@unittest.skipIf(not RUN_NVFUSER, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
"Requires fusion optimization pass to be effective")
def test_scalar_tensor(self):
x = torch.empty([], device="cuda", dtype=torch.float32)
def t(x: torch.Tensor):
o = x + 1.0
o = torch.nn.functional.relu(o)
return o
# 0-dim (scalar) tensor input
t_jit = torch.jit.script(t)
jit_o = t_jit(x)
jit_o = t_jit(x)
o = t(x)
self.assertEqual(o, jit_o)
# the scalar-tensor graph should still contain exactly one fusion guard
self.assertGraphContainsExactly(t_jit.graph_for(x), FUSION_GUARD, 1)
@unittest.skipIf(os.environ.get('PYTORCH_NO_CUDA_MEMORY_CACHING') is not None,
"skipping graph_rng when caching allocator is disabled")
@unittest.skipIf(not RUN_NVFUSER, "requires CUDA")
@unittest.skipIf(CUDA_MAJOR < 11, "requires CUDA11 or above")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
"Requires fusion optimization pass to be effective")
def test_graph_rng(self):
self.assertTrue(torch._C._jit_nvfuser_enabled())
size = 10000
a = torch.randn((size,), device="cuda", dtype=torch.float)
def t(x):
o = x + 1.0
o = torch.nn.functional.dropout(o, p=0.1)
o = o + 1.0
o = torch.nn.functional.dropout(o, p=0.1)
return o
t_jit = torch.jit.script(t)
for _ in range(3):
t_jit(a)
self.assertGraphContainsExactly(t_jit.graph_for(a), FUSION_GUARD, 1)
# Control (jitted, ungraphed)
torch.cuda.manual_seed(5)
eager_out = a.clone()
for _ in range(3):
eager_out = t_jit(eager_out)
graph_in = a.clone()
g = torch.cuda.CUDAGraph()
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
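# CUDA graph capture has to happen on a non-default stream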
with torch.cuda.stream(s):
torch.cuda.manual_seed(5)
g.capture_begin()
graph_out = t_jit(graph_in)
g.capture_end()
torch.cuda.current_stream().wait_stream(s)
# g is now a jitted, graphed version of t.
# Runs a (jitted, graphed) -> (jitted, ungraphed) -> (jitted, graphed) sequence.
# The ops in the overall sequence should be the same as Control.
g.replay()
# graph_out is now filled with g's result. Use it as ungraphed input.
out = t_jit(graph_out)
graph_in.copy_(out)
g.replay()
# If replay() updated RNG state correctly, graph_out should now equal eager_out
self.assertEqual(graph_out, eager_out)
def _test_batch_norm_impl_index_helper(self, batch, c, hw, affine=True,
track_running_stats=True, train=True,
dtype=torch.float32):
# enabling inlining to avoid counter increment in BN forward
torch._C._debug_set_autodiff_subgraph_inlining(True)
class MyModule(torch.nn.Module):
def __init__(self, num_features=10, affine=True, track_running_stats=True):
super(MyModule, self).__init__()
self.bn = torch.nn.BatchNorm2d(num_features,
1e-5,
affine=affine,
track_running_stats=track_running_stats).to(dtype=dtype)
def forward(self, x):
o = self.bn(x)
o = o * 2.0
return o
x = torch.randn(batch, c, hw, hw, dtype=torch.float, device="cuda").to(dtype=dtype).requires_grad_()
grad = torch.randint(-20, 20, (batch, c, hw, hw), device="cuda").to(dtype=dtype).div(-10)
my_module = MyModule(c, affine, track_running_stats).cuda()
ref_module = MyModule(c, affine, track_running_stats).cuda()
if not train:
my_module.eval()
ref_module.eval()
t_jit = torch.jit.script(my_module)
ref_module.load_state_dict(my_module.state_dict())
ref_x = x.detach().requires_grad_()
for i in range(0, 3):
jit_o = t_jit(x)
jit_o.backward(grad)
# TODO: remove this run?
o = ref_module(ref_x)
o.backward(grad)
has_affine = ref_module.bn.weight is not None
has_running_stats = ref_module.bn.running_mean is not None
if has_running_stats:
my_module.bn.running_mean.zero_()
my_module.bn.running_var.fill_(1.0)
ref_module.bn.running_mean.zero_()
ref_module.bn.running_var.fill_(1.0)
# when train is False there is no grad for weight/bias, so only zero them when training
if has_affine and train:
my_module.bn.weight.grad.zero_()
my_module.bn.bias.grad.zero_()
ref_module.bn.weight.grad.zero_()
ref_module.bn.bias.grad.zero_()
x.grad.zero_()
ref_x.grad.zero_()
# real runs
jit_o = t_jit(x)
jit_o.backward(grad)
o = ref_module(ref_x)
o.backward(grad)
# assert forward graph fusion
self.assertGraphContainsExactly(t_jit.graph_for(x), FUSION_GUARD, 1, consider_subgraphs=True)
# assert backward graph fusion
bwd_graph = list(
list(t_jit.get_debug_state().execution_plans.values())[0].code.grad_executor_states()[0]
.execution_plans.values())[0].graph
self.assertGraphContainsExactly(bwd_graph, FUSION_GUARD, 1, consider_subgraphs=True)
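# comparison tolerances; relaxed when running in half precision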
e0 = 1e-5 if dtype is not torch.half else 1e-3
e1 = 1e-4 if dtype is not torch.half else 1e-3
e2 = 1e-3 if dtype is not torch.half else 1e-2
self.assertTrue(self._compare("comparing output failed", jit_o, o, e0))
self.assertTrue(self._compare("comparing input grad failed", x.grad, ref_x.grad, e1))
# TODO: switch to Welford and reduce this to 1e-5
# The 1e-3 looks bad, but without Welford in codegen the numerics differ
# noticeably between the reference and the generated kernel.
if has_affine and train:
self.assertTrue(self._compare("comparing weight grad failed",
my_module.bn.weight.grad,
ref_module.bn.weight.grad,
e2))
self.assertTrue(self._compare("comparing bias grad failed",
my_module.bn.bias.grad,
ref_module.bn.bias.grad,
e1))
if has_running_stats:
self.assertTrue(self._compare("comparing running_mean failed",
my_module.bn.running_mean,
ref_module.bn.running_mean,
e0))
self.assertTrue(self._compare("comparing running_var failed",
my_module.bn.running_var,
ref_module.bn.running_var,
e0))
@unittest.skipIf(is_pre_volta(), "reduction not supported in pre volta device")
@unittest.skipIf(not RUN_NVFUSER, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
"Requires fusion optimization pass to be effective")
def test_batch_norm_half(self):
with torch.backends.cudnn.flags(enabled=True):
setups = [
[True, True],
[False, False],
[True, False],
[False, True]]
for training_and_track, affine in itertools.product(setups, [True, False]):
training, track_running_stats = training_and_track
self._test_batch_norm_impl_index_helper(4, 8, 5, affine, track_running_stats, training, torch.half)
@unittest.skipIf(is_pre_volta(), "reduction not supported in pre volta device")
@unittest.skipIf(not RUN_NVFUSER, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
"Requires fusion optimization pass to be effective")
def test_batch_norm_impl_index_inner_bcast(self):
# the repro
self._test_batch_norm_impl_index_helper(2, 1, 1, False, True, True)
# running the full set
setups = [
[True, True],
[False, False],
[True, False],
[False, True]]
for training_and_track, affine in itertools.product(setups, [True, False]):
training, track_running_stats = training_and_track
self._test_batch_norm_impl_index_helper(2, 1, 1, affine, track_running_stats, training)
@unittest.skipIf(is_pre_volta(), "reduction not supported in pre volta device")
@unittest.skipIf(not RUN_NVFUSER, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
"Requires fusion optimization pass to be effective")
def test_batch_norm_impl_index_correctness(self):
with torch.backends.cudnn.flags(enabled=True):
batch = [2, 7, 16]
channels = [4, 89, 19, 32]
hw = [1, 8, 17, 32]
# avoid tolerance failure in CI
torch.cuda.manual_seed_all(211)
# failing sizes (2, 1, 1, 1)
# failing sizes (2, 89, 8, 8) training False, track True, affine: False
for b, c, hw in itertools.product(batch, channels, hw):
setups = [
[True, True],
[False, False],
[True, False],
[False, True]]
for training_and_track, affine in itertools.product(setups, [True, False]):
training, track_running_stats = training_and_track
self._test_batch_norm_impl_index_helper(b, c, hw, affine, track_running_stats, training)
@unittest.skipIf(not RUN_NVFUSER, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
"Requires fusion optimization pass to be effective")
def test_softplus_fuser(self):
def shifted_softplus(x: torch.Tensor, shift: float):
return functional.softplus(x) - shift
jitted = torch.jit.script(shifted_softplus)
inp = torch.randn(4, 2, dtype=torch.float32, device="cuda").requires_grad_()
inp_ref = inp.detach().clone().requires_grad_()
grad = torch.randn(4, 2, dtype=torch.float32, device="cuda")
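# the shift 0.693147 is ~ln(2), so the shifted softplus is zero at x == 0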
aten_o = shifted_softplus(inp_ref, 0.693147)
aten_o.backward(grad)
aten_grad = inp_ref.grad
for i in range(3):
jit_o = jitted(inp, 0.693147)
inp.grad = None # avoid accumulation on grad
jit_o.backward(grad)
jit_grad = inp.grad
self.assertTrue(torch.allclose(jit_o, aten_o))
self.assertTrue(torch.allclose(jit_grad, aten_grad))
self.assertGraphContains(jitted.graph_for(inp, 0.693147), FUSION_GROUP, True)
@unittest.skipIf(not RUN_NVFUSER, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
"Requires fusion optimization pass to be effective")
def test_inplace_removal(self):
def t(x: torch.Tensor):
o = torch.nn.functional.softmax(x, dim=0)
o += x
return o.relu_()
jitted = torch.jit.script(t)
inp = torch.randn(4, 2, dtype=torch.float32, device="cuda")
for i in range(3):
jit_o = jitted(inp)
graph = jitted.graph_for(inp)
self.assertGraphContains(graph, FUSION_GROUP, True)
self.assertGraphContains(graph, 'aten::add', True)
self.assertGraphContains(graph, 'aten::relu', True)
@unittest.skipIf(not RUN_NVFUSER, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
"Requires fusion optimization pass to be effective")
def test_conv2d_bias(self):
def t(x: torch.Tensor, w: torch.Tensor, bias: torch.Tensor):
o = torch.nn.functional.conv2d(x, w, bias)
return o.relu()
jitted = torch.jit.script(t)
inp = torch.randn(4, 5, 3, 3, dtype=torch.float32, device="cuda")
weight = torch.randn(2, 5, 2, 2, dtype=torch.float32, device="cuda")
bias = torch.randn(2, dtype=torch.float32, device="cuda")
for i in range(3):
jit_o = jitted(inp, weight, bias)
graph = jitted.graph_for(inp)
self.assertGraphContains(graph, FUSION_GROUP, True)
def t_not_fused(x: torch.Tensor, w: torch.Tensor):
o = torch.nn.functional.conv2d(x, w)
return o.relu()
jitted_not_fused = torch.jit.script(t_not_fused)
for i in range(3):
jit_o = jitted_not_fused(inp, weight)
graph = jitted_not_fused.graph_for(inp)
self.assertGraphContainsExactly(graph, FUSION_GROUP, 0)
self.assertGraphContains(graph, 'aten::relu', True)
def t_bias(x: torch.Tensor, w: torch.Tensor, bias: torch.Tensor):
o = torch.nn.functional.conv2d(x, w, bias)
return o.relu()
jitted_bias = torch.jit.script(t_bias)
for i in range(3):
jit_o = jitted_bias(inp, weight, bias)
graph = jitted_bias.graph_for(inp)
self.assertGraphContains(graph, FUSION_GROUP, True)
self.assertGraphContains(graph, 'prim::add_optional', True)
@unittest.skipIf(is_pre_volta(), "reduction not supported in pre volta device")
@unittest.skipIf(not RUN_NVFUSER, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
"Requires fusion optimization pass to be effective")
def test_remove_output_used_only_in_dtype(self):
class MyModule(torch.nn.Module):
def __init__(self, num_features=4):
super(MyModule, self).__init__()
self.bn0 = torch.nn.BatchNorm2d(num_features)
self.bn1 = torch.nn.BatchNorm2d(num_features)
def forward(self, x, y):
o1 = self.bn0(x)
o2 = self.bn1(y)
return torch.relu(o1 + o2)
t = MyModule(4).float().cuda()
jitted = torch.jit.script(t)
x = torch.randn(3, 4, 2, 5, dtype=torch.float32, device="cuda")
y = torch.randn(3, 4, 2, 5, dtype=torch.float32, device="cuda")
with torch.cuda.amp.autocast(True):
for i in range(5):
jit_o = jitted(x, y)
jit_o = jitted(x, y)
o = t(x, y)
self.assertTrue(torch.allclose(jit_o, o))
graph = jitted.graph_for(x, y)
self.assertGraphContains(graph, FUSION_GROUP, True)
@unittest.skipIf(is_pre_volta(), "reduction not supported in pre volta device")
@unittest.skipIf(not RUN_NVFUSER, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
"Requires fusion optimization pass to be effective")
def test_fix_shape_expression_bn(self):
class MyModule(torch.nn.Module):
def __init__(self, num_features=4):
super(MyModule, self).__init__()
self.bn = torch.nn.BatchNorm2d(num_features)
def forward(self, x, y):
out1 = self.bn(x)
out2 = out1 + y
out3 = torch.relu(out2)
return out3
t = MyModule(4).float().cuda()
jitted = torch.jit.script(t)
x = torch.randn(3, 4, 2, 5, dtype=torch.float32, device="cuda")
y = torch.randn(3, 4, 2, 5, dtype=torch.float32, device="cuda")
with torch.cuda.amp.autocast(True):
for i in range(5):
jit_o = jitted(x, y)
jit_o = jitted(x, y)
o = t(x, y)
self.assertTrue(torch.allclose(jit_o, o))
graph = jitted.graph_for(x, y)
self.assertGraphContains(graph, FUSION_GROUP, True)
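# Helper: scripts `func`, warms it up, checks each output (dtype and value)
# against eager, and verifies the listed aten ops were absorbed into the
# fusion group instead of remaining in the outer graph.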
def _run_fwd_helper(self, func, ops, *args):
jitted = torch.jit.script(func)
for i in range(3):
jit_o = jitted(*args)
jit_o = jitted(*args)
o = func(*args)
for oo, jit_oo in zip(o, jit_o):
self.assertEqual(oo.dtype, jit_oo.dtype)
self.assertEqual(oo, jit_oo)
graph = jitted.graph_for(*args)
self.assertGraphContains(graph, FUSION_GROUP, True)
for op in ops:
self.assertGraphContainsExactly(graph, op, 0)
@unittest.skipIf(is_pre_volta(), "reduction not supported in pre volta device")
@unittest.skipIf(not RUN_NVFUSER, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
"Requires fusion optimization pass to be effective")
def test_sibling_fusion(self):
device = "cuda"
dtype = torch.float
x = torch.randn(2, 5, dtype=dtype, device=device)
y = torch.randn(2, 5, dtype=dtype, device=device)
def t(x: torch.Tensor):
o1 = x + 1.0
o2 = x * 0.5
return o1, o2
self._run_fwd_helper(t, ['aten::add', 'aten::mul'], x)
def t2(x: torch.Tensor, y: torch.Tensor):
o1 = x.sum(0)
o2 = (x * y).sum(0)
return o1, o2
self._run_fwd_helper(t2, ['aten::sum', 'aten::mul'], x, y)
@unittest.skipIf(not RUN_NVFUSER, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
"Requires fusion optimization pass to be effective")
def test_clean_profile_ivalue(self):
device = "cuda"
dtype = torch.float
x = torch.randn(2, 5, dtype=dtype, device=device, requires_grad=True)
# turn on autodiff subgraph inlining
# this is to verify that we clean up profile_ivalue nodes outside of the
# fusion code path.
torch._C._debug_set_autodiff_subgraph_inlining(True)
def t(x: torch.Tensor, flag: bool):
return torch.dropout(x, 0.5, flag)
jit_t = torch.jit.script(t)
for idx in range(5):
out = jit_t(x, True)
graph = jit_t.graph_for(x, True)
out = jit_t(x, False)
@unittest.skipIf(not RUN_NVFUSER, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
"Requires fusion optimization pass to be effective")
def test_sibling_fusion_no_scalar_inputs(self):
device = "cuda"
dtype = torch.float
x = torch.randn(2, 5, dtype=dtype, device=device)
y = torch.randn(3, dtype=dtype, device=device)
# no tensor dependency between o1 and o2, so we shouldn't fuse them
def t(x: torch.Tensor, y: torch.Tensor):
o1 = x + 1
o2 = y - 1
return o1, o2
jitted = torch.jit.script(t)
for i in range(3):
jit_o = jitted(x, y)
graph = jitted.graph_for(x, y)
self.assertGraphContainsExactly(graph, FUSION_GROUP, 0)
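# Helper: bias-add -> view -> relu. The pattern should fuse and lower the view
# to prim::view_copy, unless view_shape contains an inferred (-1) dimension, in
# which case fusion is expected to be rejected.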
def _bias_view_relu_helper(self, shape, output_shape, dtype, device, error):
class BiasViewRelu(torch.nn.Module):
def __init__(self):
super(BiasViewRelu, self).__init__()
self.bias = torch.nn.Parameter(torch.randn(shape, dtype=dtype, device=device), requires_grad=False)
with torch.no_grad():
self.bias.fill_(10)
def forward(self, inputs: torch.Tensor, view_shape: List[int]):
o = inputs + self.bias
o = o.view(view_shape)
return torch.relu(o)
t = BiasViewRelu()
x = torch.randn(shape, dtype=dtype, device=device, requires_grad=False)
t_jit = torch.jit.script(t)
# profiling
jit_o = t_jit(x, output_shape)
# optimization
jit_o = t_jit(x, output_shape)
# final
jit_o = t_jit(x, output_shape)
# eager - baseline
o = t(x, output_shape)
self.assertEqual(o.dtype, jit_o.dtype)
self.assertTrue(self._compare("comparing output failed", o, jit_o, error))
graph = t_jit.graph_for(x, output_shape)
has_inferred_dimension = any(dim == -1 for dim in output_shape)
if has_inferred_dimension:
# prohibit fusing when view_shape contains an inferred dimension
self.assertGraphContainsExactly(graph, FUSION_GROUP, 0)
self.assertGraphContainsExactly(graph, 'prim::view_copy', 0)
else:
self.assertGraphContains(graph, FUSION_GUARD)
self.assertGraphContains(graph, 'prim::view_copy', True)
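# Helper: the viewed input is mutated in-place afterwards, so the view aliases
# a mutated tensor; fusion (and prim::view_copy) must be rejected.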
def _alias_bias_view_relu_helper(self, shape, output_shape, dtype, device, error):
class BiasViewRelu(torch.nn.Module):
def __init__(self):
super(BiasViewRelu, self).__init__()
self.bias = torch.nn.Parameter(torch.randn(shape, dtype=dtype, device=device), requires_grad=False)
with torch.no_grad():
self.bias.fill_(10)
def forward(self, inputs : torch.Tensor, bias : torch.Tensor, view_shape : List[int]):
o = inputs.view(view_shape)
inputs.add_(bias)
return torch.relu(o)
t = BiasViewRelu()
x = torch.randn(shape, dtype=dtype, device=device, requires_grad=False)
bias = torch.randn(shape, dtype=dtype, device=device, requires_grad=False)
t_jit = torch.jit.script(t)
# profiling
jit_o = t_jit(x.clone(), bias, output_shape)
# optimization
jit_o = t_jit(x.clone(), bias, output_shape)
# final
jit_o = t_jit(x.clone(), bias, output_shape)
# eager - baseline
o = t(x.clone(), bias, output_shape)
self.assertEqual(o.dtype, jit_o.dtype)
self.assertTrue(self._compare("comparing output failed", o, jit_o, error))
graph = t_jit.graph_for(x, bias, output_shape)
self.assertGraphContainsExactly(graph, FUSION_GUARD, 0)
self.assertGraphContainsExactly(graph, 'prim::view_copy', 0)
# generate random views that are compatible with (same element count as) the original view
def _random_view(self, original_view, max_len=8, max_views=10000):
class Moves(enum.Enum):
Merge = 0
Split = 1
Broadcast = 2
ImplicitBroadcast = 3
Keep = 4
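# Move semantics: Keep copies the current dim, Merge combines two adjacent
# dims, Split factors the current dim, Broadcast appends a size-1 dim to the
# new view, and ImplicitBroadcast inserts a size-1 dim into the old view and
# absorbs it into the current dim.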
def valid(old_view, new_view):
old_view_size = reduce(operator.mul, old_view)
new_view_size = reduce(operator.mul, new_view)
return old_view_size == new_view_size
# pick a random starting point in [2, N - 1] and scan upward for a divisor of N
def find_nearest_divisor(N):
if 2 >= (N - 1):
return -1
result = random.randint(2, N - 1)
while (N % result) != 0:
result += 1
return result
complete_views = set([tuple(original_view)])
to_visit = []
# state: (new view so far, current old view, start pos=0, empty move list, last move=Keep)
to_visit.append(([], original_view, 0, [], Moves.Keep))
# depth-first search of view shapes, starting from the original view
while len(to_visit) > 0 and len(complete_views) < max_views:
new_view, old_view, odx, move_list, last_move = to_visit[-1]
to_visit.pop()
# iterate over each move type
for idx in range(len(Moves)):
state = Moves(idx)
new_view_clone = copy.deepcopy(new_view)
old_view_clone = copy.deepcopy(old_view)
new_move_list = move_list + [state]
new_odx = odx
# Update state using Move state
if state == Moves.Keep:
new_size = old_view_clone[odx]
new_view_clone.append(new_size)
new_odx += 1
elif state == Moves.Merge:
if odx + 1 < len(old_view_clone):
new_size = old_view_clone[odx] * old_view_clone[odx + 1]
new_view_clone.append(new_size)
new_odx += 2
else:
continue
elif state == Moves.Broadcast and last_move != Moves.Broadcast:
new_view_clone.append(1)
elif state == Moves.Split:
new_size = find_nearest_divisor(old_view_clone[odx])
if new_size == -1:
continue
new_view_clone.append(new_size)
old_view_clone[odx] = old_view[odx] // new_size
if old_view_clone[odx] == 1:
new_odx += 1
elif state == Moves.ImplicitBroadcast:
old_view_clone.insert(odx + 1, 1)
new_size = old_view[odx] * 1
new_view_clone.append(new_size)
new_odx += 2
if new_odx < len(old_view_clone) and len(new_move_list) < max_len:
to_visit.append((new_view_clone, old_view_clone, new_odx, new_move_list, state))
elif (valid(original_view, new_view_clone)):
final_new_view = tuple(new_view_clone)
complete_views.add(final_new_view)
return list(complete_views)
# ndims - number of dimensions
# test_fn - view test function
def _view_test_generator(self, ndims, test_fn):
# create a random tensor; cap each dimension so the total element count
# stays at or below max_size
max_size = 10e7
max_value = max(int(pow(max_size, 1. / ndims)), 1)
sizes = [random.randint(1, max_value) for idx in range(ndims)]
x = torch.randn(sizes)
original_sizes = list(x.size())
all_views = self._random_view(original_sizes)
random.shuffle(all_views)
max_samples = 20
max_views = min(len(all_views), max_samples)
total = 0
correct = 0
# test random combinations of compatible views
for idx in range(max_views):
for jdx in range(idx + 1, max_views):
total += 1
test_fn(all_views[idx], all_views[jdx], torch.float, 'cuda', 1e-6)
@unittest.skipIf(ALIAS_TEST_DISABLED, "skipping this test since view is disabled now")
@unittest.skipIf(not RUN_NVFUSER, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
"Requires fusion optimization pass to be effective")
def test_view(self):
torch._C._jit_set_nvfuser_guard_mode(True)
self._bias_view_relu_helper([2, 3, 4, 5], [-1, 4, 5], torch.float, 'cuda', 1e-6)
for ndims in range(1, 5):
self._view_test_generator(ndims, self._bias_view_relu_helper)
self._alias_bias_view_relu_helper([2, 3, 4, 5], [1, 6, 1, 2, 2, 5, 1], torch.float, 'cuda', 1e-6)
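# Helper: bias-add -> flatten -> relu; the fused graph should contain a
# prim::flatten_copy node.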
def _bias_flatten_relu_helper(self, shape, start_dim, end_dim, dtype, device, error):
class BiasFlattenRelu(torch.nn.Module):
def __init__(self):
super(BiasFlattenRelu, self).__init__()
self.bias = torch.nn.Parameter(torch.randn(shape, dtype=dtype, device=device), requires_grad=False)
with torch.no_grad():
self.bias.fill_(10)
def forward(self, inputs : torch.Tensor, start_dim : int, end_dim : int):
o = inputs + self.bias
o = o.flatten(start_dim, end_dim)
return torch.relu(o)
t = BiasFlattenRelu()
x = torch.randn(shape, dtype=dtype, device=device, requires_grad=False)
t_jit = torch.jit.script(t)
self._run_helper(t_jit, t, x, start_dim, end_dim)
self.assertGraphContains(t_jit.graph_for(x, start_dim, end_dim), 'prim::flatten_copy', True)
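# Helper: flatten of an input that is then mutated in-place; fusion and
# prim::flatten_copy must be rejected.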
def _alias_bias_flatten_relu_helper(self, shape, start_dim, end_dim, dtype, device, error):
class BiasFlattenRelu(torch.nn.Module):
def __init__(self):
super(BiasFlattenRelu, self).__init__()
self.bias = torch.nn.Parameter(torch.randn(shape, dtype=dtype, device=device), requires_grad=False)
with torch.no_grad():
self.bias.fill_(10)
def forward(self, inputs : torch.Tensor, bias : torch.Tensor, start_dim : int, end_dim : int):
o = inputs.flatten(start_dim, end_dim)
inputs.add_(bias)
return torch.relu(o)
t = BiasFlattenRelu()
x = torch.randn(shape, dtype=dtype, device=device, requires_grad=False)
bias = torch.randn(shape, dtype=dtype, device=device, requires_grad=False)
t_jit = torch.jit.script(t)
# profiling
jit_o = t_jit(x.clone(), bias, start_dim, end_dim)
# optimization
jit_o = t_jit(x.clone(), bias, start_dim, end_dim)
# final
jit_o = t_jit(x.clone(), bias, start_dim, end_dim)
# eager - baseline
o = t(x.clone(), bias, start_dim, end_dim)
self.assertEqual(o.dtype, jit_o.dtype)
self.assertTrue(self._compare("comparing output failed", o, jit_o, error))
graph = t_jit.graph_for(x, bias, start_dim, end_dim)
self.assertGraphContainsExactly(graph, FUSION_GUARD, 0)
self.assertGraphContainsExactly(graph, 'prim::flatten_copy', 0)
@unittest.skipIf(ALIAS_TEST_DISABLED, "skipping this test since flatten is disabled now")
@unittest.skipIf(not RUN_NVFUSER, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
"Requires fusion optimization pass to be effective")
def test_flatten(self):
torch._C._jit_set_nvfuser_guard_mode(True)
self._bias_flatten_relu_helper([2, 3, 4, 5], 0, -1, torch.float, 'cuda', 1e-6)
self._bias_flatten_relu_helper([2, 3, 4, 5], 1, -1, torch.float, 'cuda', 1e-6)
self._bias_flatten_relu_helper([2, 3, 4, 5], 2, -1, torch.float, 'cuda', 1e-6)
self._bias_flatten_relu_helper([2, 3, 4, 5], 0, 3, torch.float, 'cuda', 1e-6)
self._bias_flatten_relu_helper([2, 3, 4, 5], 1, 2, torch.float, 'cuda', 1e-6)
self._bias_flatten_relu_helper([2, 3, 4, 5], 2, 2, torch.float, 'cuda', 1e-6)
self._alias_bias_flatten_relu_helper([2, 3, 4, 5], 0, -1, torch.float, 'cuda', 1e-6)
self._alias_bias_flatten_relu_helper([2, 3, 4, 5], 1, -1, torch.float, 'cuda', 1e-6)
self._alias_bias_flatten_relu_helper([2, 3, 4, 5], 2, -1, torch.float, 'cuda', 1e-6)
self._alias_bias_flatten_relu_helper([2, 3, 4, 5], 0, 3, torch.float, 'cuda', 1e-6)
self._alias_bias_flatten_relu_helper([2, 3, 4, 5], 1, 2, torch.float, 'cuda', 1e-6)
self._alias_bias_flatten_relu_helper([2, 3, 4, 5], 2, 2, torch.float, 'cuda', 1e-6)
@unittest.skipIf(not RUN_NVFUSER, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
"Requires fusion optimization pass to be effective")
def test_strict_fusion(self):
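# torch.jit.strict_fusion() errors out if any op inside the block is left
# unfused; success() fuses completely, while the matmul in failure() cannot be
# fused and should surface a 'Found unfused operators' error.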
def success(x):
with torch.jit.strict_fusion():
return x + x + x
scripted = self.checkScript(success, (torch.rand([4], device='cuda'),))
g = torch.jit.last_executed_optimized_graph()
FileCheck().check_not("aten::add").check("prim::CudaFusionGroup").run(g)
def failure(x):
with torch.jit.strict_fusion():
return x + torch.mm(x, x) + x
with self.assertRaises(Exception) as error_out:
foo_s = torch.jit.script(failure)
foo_s(torch.rand([4, 4]))
foo_s(torch.rand([4, 4]))
fc = FileCheck().check("Found unfused operators")
fc.check("aten::mm").run(str(error_out.exception))
def _ltc_helper(self, shape, dtype, device, error, approximate=True):
# modeled after LTC linear layer
class LTC(torch.nn.Module):
def __init__(self):
super(LTC, self).__init__()
self.weight = torch.nn.Parameter(torch.randn([1024, 1024], dtype=dtype, device=device), requires_grad=False)
self.bias = torch.nn.Parameter(torch.randn([1, 1024], dtype=dtype, device=device), requires_grad=False)
def forward(self, inputs : torch.Tensor):
o = inputs.view([32768, 1024])
o = torch.mm(o, self.weight)
o = o.view([256, 128, 1024])
o = o + self.bias
o = o.view([32768, 1024])
o = o.view([256, 128, 1024])
return torch.nn.functional.gelu(o)
t = LTC()
x = torch.randn(shape, dtype=dtype, device=device, requires_grad=False)
t_jit = torch.jit.script(t)
# profile/optimization runs
for i in range(3):
jit_o = t_jit(x)
o = t(x)
self.assertEqual(o.dtype, jit_o.dtype)
self.assertTrue(self._compare("comparing output failed", o, jit_o, error))
graph = t_jit.graph_for(x)
self.assertGraphContains(graph, FUSION_GUARD)
self.assertGraphContains(graph, 'prim::view_copy', True)
@unittest.skipIf(ALIAS_TEST_DISABLED, "skipping this test since view is disabled now")
@unittest.skipIf(not RUN_NVFUSER, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
"Requires fusion optimization pass to be effective")
def test_nested_view(self):
self._ltc_helper([256, 128, 1024], torch.float, 'cuda', 1e-6)
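# Helper: bias-add -> squeeze -> relu; expects fusion with a prim::squeeze_copy
# node in the graph.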
def _bias_squeeze_relu_helper(self, shape, dtype, device, error):
class BiasSqueezeRelu(torch.nn.Module):
def __init__(self):
super(BiasSqueezeRelu, self).__init__()
def forward(self, inputs: torch.Tensor, bias: torch.Tensor):
o = inputs + bias
o = torch.squeeze(o)
return torch.relu(o)
t = BiasSqueezeRelu()
x = torch.randn(shape, dtype=dtype, device=device, requires_grad=False)
bias = torch.randn(shape, dtype=dtype, device=device, requires_grad=False)
t_jit = torch.jit.script(t)
jit_o = t_jit(x, bias)
jit_o = t_jit(x, bias)
jit_o = t_jit(x, bias)
o = t(x, bias)
self.assertEqual(o.dtype, jit_o.dtype)
self.assertTrue(self._compare("comparing output failed", o, jit_o, error))
graph = t_jit.graph_for(x, bias)
self.assertGraphContains(graph, FUSION_GUARD)
self.assertGraphContains(graph, 'prim::squeeze_copy', True)
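# Helper: squeeze of an input that is then mutated in-place; fusion and
# prim::squeeze_copy must be rejected.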
def _alias_bias_squeeze_relu_helper(self, shape, dtype, device, error):
class BiasSqueezeRelu(torch.nn.Module):
def __init__(self):
super(BiasSqueezeRelu, self).__init__()
def forward(self, inputs: torch.Tensor, bias: torch.Tensor):
o = torch.squeeze(inputs)
inputs.add_(bias)
return torch.relu(o)
t = BiasSqueezeRelu()
x = torch.randn(shape, dtype=dtype, device=device, requires_grad=False)
bias = torch.randn(shape, dtype=dtype, device=device, requires_grad=False)
t_jit = torch.jit.script(t)
jit_o = t_jit(x.clone(), bias)
jit_o = t_jit(x.clone(), bias)
jit_o = t_jit(x.clone(), bias)
o = t(x.clone(), bias)
self.assertEqual(o.dtype, jit_o.dtype)
self.assertTrue(self._compare("comparing output failed", o, jit_o, error))
graph = t_jit.graph_for(x, bias)
self.assertGraphContainsExactly(graph, FUSION_GUARD, 0)
self.assertGraphContainsExactly(graph, 'prim::squeeze_copy', 0)
@unittest.skipIf(ALIAS_TEST_DISABLED, "skipping this test since squeeze/unsqueeze is disabled now")
@unittest.skipIf(not RUN_NVFUSER, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
"Requires fusion optimization pass to be effective")
def test_squeeze(self):
self._bias_squeeze_relu_helper([1, 6, 1, 2, 2, 5, 1], torch.float, 'cuda', 1e-6)
self._alias_bias_squeeze_relu_helper([1, 6, 1, 2, 2, 5, 1], torch.float, 'cuda', 1e-6)
@unittest.skipIf(ALIAS_TEST_DISABLED, "skipping this test since squeeze/unsqueeze is disabled now")
# remove this after opinfo tests are enabled
@unittest.skipIf(not RUN_NVFUSER, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
"Requires fusion optimization pass to be effective")
def test_squeeze_zero(self):
x = torch.tensor(1.0, dtype=torch.float, device="cuda")
def squeeze_0(x: torch.Tensor):
o = x + 1.
o = torch.squeeze(o, 0)
o = o * 2.
return o
def squeeze_1(x: torch.Tensor):
o = x + 1.
o = torch.squeeze(o, -1)
o = o + .5
return o
squeeze_0_jit = torch.jit.script(squeeze_0)
self._run_helper(squeeze_0_jit, squeeze_0, x)
squeeze_1_jit = torch.jit.script(squeeze_1)
self._run_helper(squeeze_1_jit, squeeze_1, x)
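# Helper: bias-add -> unsqueeze -> relu; expects fusion with a
# prim::unsqueeze_copy node in the graph.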
def _bias_unsqueeze_relu_helper(self, shape, dtype, device, error):
class BiasUnsqueezeRelu(torch.nn.Module):
def __init__(self):
super(BiasUnsqueezeRelu, self).__init__()
def forward(self, inputs: torch.Tensor, bias: torch.Tensor):
o = inputs + bias
o = torch.unsqueeze(o, 0)
return torch.relu(o)
t = BiasUnsqueezeRelu()
x = torch.randn(shape, dtype=dtype, device=device, requires_grad=False)
bias = torch.randn(shape, dtype=dtype, device=device, requires_grad=False)
t_jit = torch.jit.script(t)
jit_o = t_jit(x, bias)
jit_o = t_jit(x, bias)
jit_o = t_jit(x, bias)
o = t(x, bias)
self.assertEqual(o.dtype, jit_o.dtype)
self.assertTrue(self._compare("comparing output failed", o, jit_o, error))
graph = t_jit.graph_for(x, bias)
self.assertGraphContains(graph, FUSION_GUARD)
self.assertGraphContains(graph, 'prim::unsqueeze_copy', True)
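# Helper: unsqueeze of an input that is then mutated in-place; fusion and
# prim::unsqueeze_copy must be rejected.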
def _alias_bias_unsqueeze_relu_helper(self, shape, dtype, device, error):
class BiasUnsqueezeRelu(torch.nn.Module):
def __init__(self):
super(BiasUnsqueezeRelu, self).__init__()
def forward(self, inputs : torch.Tensor, bias : torch.Tensor):
o = torch.unsqueeze(inputs, 0)
inputs.add_(bias)
return torch.relu(o)
t = BiasUnsqueezeRelu()
x = torch.randn(shape, dtype=dtype, device=device, requires_grad=False)
bias = torch.randn(shape, dtype=dtype, device=device, requires_grad=False)
t_jit = torch.jit.script(t)
jit_o = t_jit(x.clone(), bias)
jit_o = t_jit(x.clone(), bias)
jit_o = t_jit(x.clone(), bias)
o = t(x.clone(), bias)
self.assertEqual(o.dtype, jit_o.dtype)
self.assertTrue(self._compare("comparing output failed", o, jit_o, error))
graph = t_jit.graph_for(x, bias)
self.assertGraphContainsExactly(graph, FUSION_GUARD, 0)
self.assertGraphContainsExactly(graph, 'prim::unsqueeze_copy', 0)
@unittest.skipIf(ALIAS_TEST_DISABLED, "skipping this test since squeeze/unsqueeze is disabled now")
@unittest.skipIf(not RUN_NVFUSER, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
"Requires fusion optimization pass to be effective")
def test_unsqueeze(self):
self._bias_unsqueeze_relu_helper([2, 3, 4, 5], torch.float, 'cuda', 1e-6)
self._alias_bias_unsqueeze_relu_helper([2, 3, 4, 5], torch.float, 'cuda', 1e-6)
@unittest.skipIf(ALIAS_TEST_DISABLED, "skipping this test since unsqueeze is disabled now")
@unittest.skipIf(not RUN_NVFUSER, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
"Requires fusion optimization pass to be effective")
def test_alias_pass_fix(self):
x = torch.randn(4, 24, 2, 2, dtype=torch.float, device="cuda")
w = torch.randn(24, 24, 1, 1, dtype=torch.float, device="cuda")
b = torch.randn(24, dtype=torch.float, device="cuda")
def t(x, w, b):
b2 = b + 1.0
o = torch.conv2d(x, w, b2)
return o
t_jit = torch.jit.script(t)
self._run_helper(t_jit, t, x, w, b)
@unittest.skipIf(ALIAS_TEST_DISABLED, "skipping this test since squeeze/unsqueeze is disabled now")
@unittest.skipIf(not RUN_NVFUSER, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
"Requires fusion optimization pass to be effective")
def test_squeeze_negative_dim(self):
x = torch.randn(4, 24, 1, 2, dtype=torch.float, device="cuda")
def t(x):
o = x + 1.0
o = o.squeeze(-2)
o = o * 2.0
return o
t_jit = torch.jit.script(t)
self._run_helper(t_jit, t, x)
@unittest.skipIf(not RUN_NVFUSER, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
"Requires fusion optimization pass to be effective")
def test_singleton_fusion(self):
x = torch.randn(4, 2, device="cuda")
with nvfuser_singleton_fusion(True):
def t(x):
return x.relu()
t_jit = torch.jit.script(t)
self._run_helper(t_jit, t, x)
@unittest.skipIf(not RUN_NVFUSER, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
"Requires fusion optimization pass to be effective")
def test_issue1445_fusion(self):
def f(t0, t1, t2, t3):
masked_input = torch.where(t1, t2, t3)
total = masked_input.sum([0, 1, 2, 3])
sizes : List[int] = []
t10 = torch.reshape(t0, sizes)
t7 = total / t10
t4 = t7.to(dtype=torch.float)
return t4
x = torch.randn(1, 1, 1, 1, device='cuda').to(dtype=torch.long)
y = torch.randn(3, 2, 1, 1, device='cuda').to(dtype=torch.bool).expand([3, 2, 1, 2])
z = torch.randn(3, 2, 1, 2, device='cuda')
w = torch.tensor(1.5, device='cuda')
f_jit = torch.jit.script(f)
for i in range(5):
out_jit = f_jit(x, y, z, w)
out = f(x, y, z, w)
self.assertEqual(out, out_jit)
self.assertGraphContainsExactly(f_jit.graph_for(x, y, z, w), FUSION_GROUP, 1)
@unittest.skipIf(not RUN_NVFUSER, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
"Requires fusion optimization pass to be effective")
def test_disable_sibling_fuse(self):
x = torch.randn(4, 2, device="cuda")
y = torch.randn(8, device="cuda")
s = torch.tensor(1.5, device="cuda")
with nvfuser_horizontal_fusion(False):
def t(x, y, s):
o1 = x + s
o2 = y + s
return o1, o2
t_jit = torch.jit.script(t)
for i in range(5):
t_jit(x, y, s)
# sibling fusion should be disabled with the flag
self.assertGraphContainsExactly(t_jit.graph_for(x, y, s), FUSION_GUARD, 0)
@unittest.skipIf(not RUN_NVFUSER, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
"Requires fusion optimization pass to be effective")
def test_build_shape_expression_native_dropout(self):
x = torch.randn(4, 2, device="cuda")
def t(x):
o, mask = torch.native_dropout(x, 0.0, True)
o1 = o.sigmoid()
o2 = mask.float().sigmoid()
return (o1, o2)
t_jit = torch.jit.script(t)
jit_o = t_jit(x)
jit_o = t_jit(x)
o = t(x)
for oo, jit_oo in zip(o, jit_o):
self.assertEqual(oo.dtype, jit_oo.dtype)
self.assertEqual(oo, jit_oo)
self.assertGraphContains(t_jit.graph_for(x), FUSION_GUARD)
@unittest.skipIf(not RUN_NVFUSER, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
"Requires fusion optimization pass to be effective")
def test_scalar_tensor_permuted(self):
x = torch.randn(4, 2, 3, device="cuda").permute([1, 2, 0])
y = torch.tensor(1.0, device="cuda")
with nvfuser_singleton_fusion(True):
def t(x, y):
return x + y
t_jit = torch.jit.script(t)
self._run_helper(t_jit, t, x, y)
@unittest.skipIf(not RUN_NVFUSER, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
"Requires fusion optimization pass to be effective")
def test_cpu_scalar(self):
x = torch.randn(4, 2, 3, device="cuda")
y = torch.tensor(1.0, device="cpu")
z = torch.tensor(2.0, device="cpu")
with nvfuser_singleton_fusion(True):
# testing cpu scalar tensor promotion
def t(x, y):
return x + y
t_jit = torch.jit.script(t)
self._run_helper(t_jit, t, x, y)
# scalar cpu tensor add should NOT be fused
@torch.jit.script
def t1(y, z):
return y * z
for _ in range(5):
t1(y, z)
self.assertGraphContainsExactly(t1.graph_for(y, z), FUSION_GUARD, 0)
# everything, including the scalar cpu tensor add, should be fused
@torch.jit.script
def t2(x, y, z):
tmp = y + z
return tmp + x
for _ in range(5):
t2(x, y, z)
self.assertGraphContainsExactly(t2.graph_for(x, y, z), 'aten::add', 0)
self.assertGraphContainsExactly(t2.graph_for(x, y, z), FUSION_GUARD, 1)
# 'cpu_tmp = y + z' shouldn't be fused.
@torch.jit.script
def t3(x, y, z):
cpu_tmp = y + z
out = x + y
return cpu_tmp, out
for _ in range(5):
t3(x, y, z)
self.assertGraphContainsExactly(t3.graph_for(x, y, z), FUSION_GUARD, 1)
self.assertGraphContainsExactly(t3.graph_for(x, y, z), 'aten::add', 1)
@unittest.skipIf(ALIAS_TEST_DISABLED, "skipping this test since squeeze/unsqueeze is disabled now")
@unittest.skipIf(not RUN_NVFUSER, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
"Requires fusion optimization pass to be effective")
def test_shape_expression(self):
x = torch.randn(4, 2, 1, 3, device="cuda")
def t_unsqueeze(x):
t0 = x.relu()
t1 = t0.unsqueeze(1)
t2 = t1 + 1.0
t3 = t1.size()
return t2, t3
def t_squeeze(x):
t0 = x.relu()
t1 = t0.squeeze()
t2 = t1 + 1.0
t3 = t1.size()
return t2, t3
def t_squeeze_dim(x):
t0 = x.relu()
t1 = t0.squeeze(-2)
t2 = t1 + 1.0
t3 = t1.size()
return t2, t3
# squeezing a dimension that is not size 1 should be a no-op
def t_squeeze_dim_no_op(x):
t0 = x.relu()
t1 = t0.squeeze(1)
t2 = t1 + 1.0
t3 = t1.size()
return t2, t3
def run(fn):
jit_fn = torch.jit.script(fn)
jit_o = jit_fn(x)
jit_o = jit_fn(x)
jit_o = jit_fn(x)
o = fn(x)
# output 0 is a tensor, so we check dtype and value
self.assertEqual(o[0].dtype, jit_o[0].dtype)
self.assertEqual(o[0], jit_o[0])
# output 1 is shape
self.assertEqual(o[1], jit_o[1])
self.assertGraphContainsExactly(jit_fn.graph_for(x), FUSION_GUARD, 1)
for t in [t_unsqueeze, t_squeeze, t_squeeze_dim, t_squeeze_dim_no_op]:
run(t)
@unittest.skipIf(not RUN_NVFUSER, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
"Requires fusion optimization pass to be effective")
def test_scalar_cuda_tensor(self):
x = torch.tensor(2.0, device="cuda")
with nvfuser_singleton_fusion(True):
def t(x):
return x + 1.0
t_jit = torch.jit.script(t)
self._run_helper(t_jit, t, x)
@torch.jit.script
def t_jitted(x):
return x.sum(0)
for i in range(5):
t_jitted(x)
self.assertGraphContainsExactly(t_jitted.graph_for(x), FUSION_GUARD, 0)
@unittest.skipIf(not RUN_NVFUSER, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
"Requires fusion optimization pass to be effective")
def test_overlapped_input(self):
x = torch.randn(8, device="cuda").as_strided((2, 4), (1, 1))
with nvfuser_singleton_fusion(True):
def t(x):
return x + 1.0
t_jit = torch.jit.script(t)
self._run_helper(t_jit, t, x)
@unittest.skipIf(not RUN_NVFUSER, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
"Requires fusion optimization pass to be effective")
@unittest.skipIf(is_pre_volta(), "reduction not supported in pre volta device")
def test_reduction_empty_axes(self):
x = torch.randn(4, 2, 3, device="cuda").permute([1, 2, 0])
with nvfuser_singleton_fusion(True):
def t(x):
sizes : List[int] = []
return x.sum(sizes)
t_jit = torch.jit.script(t)
self._run_helper(t_jit, t, x)
@unittest.skipIf(not RUN_NVFUSER, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
"Requires fusion optimization pass to be effective")
@unittest.skipIf(is_pre_volta(), "reduction not supported in pre volta device")
def test_int_tensor_input(self):
x = torch.randn(4, 2, device="cuda").to(dtype=torch.int)
with nvfuser_singleton_fusion(True):
def t(x):
return x.amax(dim=0)
t_jit = torch.jit.script(t)
self._run_helper(t_jit, t, x)
@unittest.skipIf(not RUN_NVFUSER, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
"Requires fusion optimization pass to be effective")
def test_to_boolean(self):
x = torch.randn(4, 2, device="cuda")
with nvfuser_singleton_fusion(True):
def t(x):
return x.to(dtype=torch.bool)
t_jit = torch.jit.script(t)
self._run_helper(t_jit, t, x)
@unittest.skipIf(not RUN_NVFUSER, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
"Requires fusion optimization pass to be effective")
def test_to_copy(self):
x = torch.randn(4, 2, device="cuda")
with nvfuser_singleton_fusion(True):
def t(x, dtype : torch.dtype):
o = torch.ops.aten._to_copy(x, dtype=dtype)
return o
t.__disable_jit_function_caching__ = True
t_jit = torch.jit.script(t)
for dtype in [torch.float16, torch.bool, torch.float64]:
self._run_helper(t_jit, t, x, dtype)
def t_none(x):
with torch.jit.strict_fusion():
o = torch.ops.aten._to_copy(x, dtype=None)
return o
t_jit_none = torch.jit.script(t_none)
self._run_helper(t_jit_none, t_none, x)
@unittest.skipIf(ALIAS_TEST_DISABLED, "skipping this test since reshape is disabled now")
@unittest.skipIf(not RUN_NVFUSER, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
"Requires fusion optimization pass to be effective")
def test_view_copy_graph_guard(self):
x = torch.randn(4, 2, 3, device="cuda").permute([1, 2, 0])
y = [4, 6]
with nvfuser_singleton_fusion(True):
def t(x, y : List[int]):
t1 = x + 1.0
t2 = t1 * 1.0
out = t2.reshape(y)
return out.relu()
t_jit = torch.jit.script(t)
self._run_helper(t_jit, t, x, y)
@unittest.skipIf(ALIAS_TEST_DISABLED, "skipping this test since view is disabled now")
@unittest.skipIf(not RUN_NVFUSER, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
"Requires fusion optimization pass to be effective")
def test_view_copy_graph_guard_double_fusion(self):
x = torch.randn(2, 2, 5, device="cuda")
w = torch.randn(5, 5, device="cuda")
with nvfuser_singleton_fusion(True):
def t(x, w):
o = x.view([4, x.size()[-1]])
o = torch.matmul(o, w)
o = o.view([2, 2, o.size()[1]])
return o
t_jit = torch.jit.script(t)
for i in range(3):
jit_o = t_jit(x, w)
o = t(x, w)
self.assertEqual(jit_o, o)
self.assertGraphContainsExactly(t_jit.graph_for(x, w), FUSION_GUARD, 2, consider_subgraphs=True)
@unittest.skipIf(not RUN_NVFUSER, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
"Requires fusion optimization pass to be effective")
def test_input_output_passthrough(self):
def t(t0, t1, t2):
mask = t1.to(dtype=torch.bool)
masked_input = torch.where(t0, mask, t2)
return masked_input, mask
t_jit = torch.jit.script(t)
# stick to boolean/integer types; this avoids numerical differences due to
# our type promotion
x = torch.randn(4, 4, device='cuda').to(dtype=torch.bool)
y = torch.randn(4, 4, device='cuda').to(dtype=torch.bool)
z = torch.tensor(1.0, device='cuda').to(dtype=torch.bool)
jit_o = t_jit(x, y, z)
jit_o = t_jit(x, y, z)
o = t(x, y, z)
for oo, jit_oo in zip(o, jit_o):
self.assertEqual(oo.dtype, jit_oo.dtype)
self.assertEqual(oo, jit_oo)
self.assertGraphContains(t_jit.graph_for(x, y, z), FUSION_GUARD)
@unittest.skipIf(not RUN_NVFUSER, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
"Requires fusion optimization pass to be effective")
def test_pointwise_reference_tensor(self):
def t(input1, input2, scalar):
_unsafe_view = torch.ops.aten._unsafe_view(input1, [2, 4, 16])
add_ = torch.ops.aten.add_(_unsafe_view, input2)
gelu_ = torch.ops.aten.gelu(add_)
view_ = torch.ops.aten.view(gelu_, [8, 16])
mul_ = torch.ops.aten.mul(add_, scalar)
return [view_, mul_]
x = torch.randn(8, 16, device="cuda")
bias = torch.randn(16, device="cuda")
scalar = torch.ones(torch.Size([]), device="cuda")
t_jit = torch.jit.script(t)
for i in range(3):
jit_o = t_jit(x, bias, scalar)
o = t(x, bias, scalar)
self.assertEqual(jit_o, o)
self.assertGraphContains(t_jit.graph_for(x, bias, scalar), FUSION_GUARD)
@unittest.skipIf(not RUN_NVFUSER, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
"Requires fusion optimization pass to be effective")
@unittest.skipIf(is_pre_volta(), "reduction not supported in pre volta device")
def test_native_batch_norm_backward(self):
grad_output = torch.randn(4, 2, 3, device="cuda")
input = torch.randn(4, 2, 3, device="cuda")
weight = torch.randn(2, device="cuda")
r_m = torch.randn(2, device="cuda")
r_v = torch.randn(2, device="cuda").abs()
save_mean = torch.randn(2, device="cuda")
save_invstd = torch.randn(2, device="cuda").abs()
with nvfuser_singleton_fusion(True):
def t(grad_out, input, weight, r_m, r_v, save_mean, save_invstd, train: bool, eps: float, mask: List[bool]):
return torch.ops.aten.native_batch_norm_backward(grad_out, input, weight, r_m, r_v, save_mean,
save_invstd, train, eps, mask)
t_jit = torch.jit.script(t)
for i in range(4):
jit_o = t_jit(grad_output, input, weight, r_m.clone(), r_v.clone(),
save_mean, save_invstd, True, 1e-5, [True, True, True])
ref_m = r_m.clone()
ref_v = r_v.clone()
jit_o = t_jit(grad_output, input, weight, r_m, r_v, save_mean, save_invstd, True, 1e-5, [True, True, True])
o = t(grad_output, input, weight, ref_m, ref_v, save_mean, save_invstd, True, 1e-5, [True, True, True])
for oo, jit_oo in zip(o, jit_o):
self.assertEqual(oo.dtype, jit_oo.dtype)
self.assertEqual(oo, jit_oo)
self.assertEqual(ref_m.dtype, r_m.dtype)
self.assertEqual(ref_m, r_m)
self.assertEqual(ref_v.dtype, r_v.dtype)
self.assertEqual(ref_v, r_v)
self.assertGraphContains(t_jit.graph_for(grad_output, input, weight, r_m.clone(), r_v.clone(), save_mean,
save_invstd, True, 1e-5, [True, True, True]), FUSION_GUARD)
@unittest.skipIf(not RUN_NVFUSER, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
"Requires fusion optimization pass to be effective")
def test_contiguous_on_broadcasted(self):
x = torch.randn(4, 1, device="cuda")
y = torch.randn(4, 128, device="cuda")
with nvfuser_singleton_fusion(True):
def t(x, y):
t1 = x.expand([4, 128])
t2 = t1 * y
return t2
t_jit = torch.jit.script(t)
self._run_helper(t_jit, t, x, y)
@unittest.skipIf(not RUN_NVFUSER, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
"Requires fusion optimization pass to be effective")
def test_skip_parser(self):
x = torch.randn(4, 12, device="cuda")
with nvfuser_singleton_fusion(True):
def fn(x):
t1 = x + 1.0
return t1.relu()
fn_jit = torch.jit.script(fn)
self._run_helper(fn_jit, fn, x)
# add node should have been merged into fusion
self.assertGraphContains(fn_jit.graph_for(x), FUSION_GUARD)
self.assertGraphContainsExactly(fn_jit.graph_for(x), 'aten::add', 0)
# flip skip-parsing for `aten::add`; the following fusion should leave the
# add node out of the fusion group
self.assertFalse(torch._C._jit_set_nvfuser_skip_node_kind("aten::add", True))
def fn_1(x):
t1 = x + 2.0 # change const value so we'll not reuse plan
return t1.relu()
fn_1_jit = torch.jit.script(fn_1)
self._run_helper(fn_1_jit, fn_1, x)
# add node should NOT have been merged into the fusion since parsing is skipped
self.assertGraphContains(fn_1_jit.graph_for(x), FUSION_GUARD)
self.assertGraphContainsExactly(fn_1_jit.graph_for(x), 'aten::add', 1)
# flip skip-parsing for `aten::add` back; the next fusion should fuse the add node
self.assertTrue(torch._C._jit_set_nvfuser_skip_node_kind("aten::add", True))
def fn_2(x):
t1 = x + 2.0 # change const value so we'll not reuse plan
return t1.relu()
fn_2_jit = torch.jit.script(fn_2)
self._run_helper(fn_2_jit, fn_2, x)
# add node should have been merged into fusion
self.assertGraphContains(fn_2_jit.graph_for(x), FUSION_GUARD)
self.assertGraphContainsExactly(fn_2_jit.graph_for(x), 'aten::add', 0)
@unittest.skipIf(not RUN_NVFUSER, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
"Requires fusion optimization pass to be effective")
def test_cuda_fusion_guard(self):
old_guard = torch._C._jit_set_nvfuser_guard_mode(True)
class ConvModule(torch.nn.Module):
def __init__(self):
super().__init__()
def forward(self, x):
return x.sin().sigmoid()
mod = ConvModule().to(device="cuda")
inputs = [torch.randn(20, 16, 50, 100, device="cuda", requires_grad=True)]
def reduce_scalar(temp):
return temp.sum()
scripted = torch.jit.script(mod)
with torch.no_grad():
scripted(*inputs)
res = scripted(*inputs)
reduce_scalar(res).backward()
torch._C._jit_set_nvfuser_guard_mode(old_guard)
@unittest.skipIf(not RUN_NVFUSER, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
"Requires fusion optimization pass to be effective")
def test_nvfuser_comparison_callbacks_with_fallback(self):
try:
fused_result = None
unfused_result = None
graph_ir = None
def callback(fused_outputs, unfused_outputs, graph_str):
nonlocal unfused_result
nonlocal fused_result
nonlocal graph_ir
unfused_result = unfused_outputs[-1]
fused_result = fused_outputs[-1]
graph_ir = graph_str
torch._C._jit_nvfuser_set_comparison_callback(True, callback)
def fn(x, y):
z = torch.add(x, y)
return torch.relu(z)
x = torch.rand((4, 4)).cuda() - 0.5
y = torch.rand((4, 4)).cuda() - 0.5
fn_s = torch.jit.script(fn)
fn_s(x, y)
fn_s(x, y)
fn_s(x, y)
expected = fn(x, y)
self.assertEqual(expected, fused_result)
self.assertEqual(expected, unfused_result)
FileCheck().check("aten::add").run(graph_ir)
finally:
torch._C._jit_nvfuser_clear_comparison_callback()
@unittest.skipIf(not RUN_NVFUSER, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
"Requires fusion optimization pass to be effective")
def test_nvfuser_comparison_callbacks_without_fallback(self):
try:
fused_result = None
unfused_result = None
graph_ir = None
def callback(fused_outputs, unfused_outputs, graph_str):
nonlocal unfused_result
nonlocal fused_result
nonlocal graph_ir
if len(unfused_outputs) > 0:
unfused_result = unfused_outputs[-1]
fused_result = fused_outputs[-1]
graph_ir = graph_str
torch._C._jit_nvfuser_set_comparison_callback(False, callback)
def fn(x, y):
z = torch.add(x, y)
return torch.relu(z)
x = torch.rand((4, 4)).cuda() - 0.5
y = torch.rand((4, 4)).cuda() - 0.5
fn_s = torch.jit.script(fn)
fn_s(x, y)
fn_s(x, y)
fn_s(x, y)
expected = fn(x, y)
self.assertEqual(expected, fused_result)
self.assertEqual(None, unfused_result)
FileCheck().check("aten::add").run(graph_ir)
finally:
torch._C._jit_nvfuser_clear_comparison_callback()
@unittest.skipIf(not RUN_NVFUSER, "requires NVFuser")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
"Requires fusion optimization pass to be effective")
def test_cuda_fusion_guard_backward(self):
old_guard = torch._C._jit_set_nvfuser_guard_mode(True)
inp = torch.randn(10, device="cuda", requires_grad=True)
grad = torch.randn(10, device="cuda")
def f(x):
a = x.cos().cos()
return a
scripted = torch.jit.script(f)
with profile(activities=[ProfilerActivity.CPU]) as prof:
for _ in range(5):
inp.grad = None
out = scripted(inp)
out.backward(grad)
# check that no fallback path was triggered
self.assertEqual(prof.events().table().find("fallback"), -1)
torch._C._jit_set_nvfuser_guard_mode(old_guard)
# TODO: generalize this
@unittest.skipIf(not RUN_NVFUSER, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
"Requires fusion optimization pass to be effective")
@unittest.skipIf(is_pre_volta(), "reduction not supported in pre volta device")
def test_inf_quick_patch(self):
inputs = [torch.tensor([-float('inf'), float('inf'), 4.0], device="cuda"),
torch.tensor([1.0, float('inf'), 4.0], device="cuda"),
torch.tensor([-float('inf'), -1.5, 4.0], device="cuda"),
torch.tensor([1.0, -3.0, float('nan')], device="cuda"),
torch.tensor([-float('inf'), -float('inf'), -float('inf')], device="cuda"),
torch.tensor([float('inf'), float('inf'), float('inf')], device="cuda"),
torch.tensor([float('nan'), float('nan'), float('nan')], device="cuda")]
def fn_amax(x):
return x.amax(dim=0)
def fn_amin(x):
return x.amin(dim=0)
def fn_add_nan(x):
return x.relu() + float('nan')
def fn_add(x):
return x + 1.0
with nvfuser_singleton_fusion(True):
for t in [fn_amax, fn_amin, fn_add, fn_add_nan]:
for x in inputs:
t_jit = torch.jit.script(t)
self._run_helper(t_jit, t, x)
@unittest.skipIf(not RUN_NVFUSER, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
"Requires fusion optimization pass to be effective")
def test_clamp_reversed_bound(self):
x = torch.tensor([1., -float('inf'), 2., float('inf'), float('nan')], device="cuda")
def t(x):
return x.clamp(min=1., max=0.5)
with nvfuser_singleton_fusion(True):
jit_t = torch.jit.script(t)
self._run_helper(jit_t, t, x)
@unittest.skipIf(not RUN_NVFUSER, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
"Requires fusion optimization pass to be effective")
def test_issue_1785(self):
class Fusion(torch.nn.Module):
def __init__(self):
super(Fusion, self).__init__()
def forward(self, x, a, b):
out = torch.mul(x.unsqueeze(-1), a)
out = out + b
return out
x = torch.randn(1024, 192, 3, device='cuda')
a = torch.randn(3, 128, device='cuda')
b = torch.randn(3, 128, device='cuda')
model = Fusion()
jit_model = torch.jit.script(model)
with torch.jit.fuser('fuser2'):
for _ in range(4):
out_ref = model(x, a, b)
out_jit = jit_model(x, a, b)
out_ref = model(x, a, b)
out_jit = jit_model(x, a, b)
self.assertTrue(self._compare("comparing output failed", out_ref, out_jit, 1e-5))
@unittest.skipIf(not RUN_NVFUSER, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
"Requires fusion optimization pass to be effective")
def test_high_rank_fusion(self):
# we currently limit fusion to nodes whose tensor inputs have rank <= 8
rank_limit = 8
shapes = [4 for i in range(rank_limit + 1)]
x = torch.randn(shapes, device="cuda")
with nvfuser_singleton_fusion(True):
def t(x):
return x.relu()
jit_t = torch.jit.script(t)
for i in range(5):
jit_t(x)
self.assertGraphContainsExactly(jit_t.graph_for(x), FUSION_GUARD, 0)
@unittest.skipIf(not RUN_NVFUSER, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
"Requires fusion optimization pass to be effective")
def test_clamp(self):
x = torch.tensor([1., float('inf'), 2., float('nan'), float('-inf')], device="cuda")
def clamp_max(x):
return x.clamp(max=1.5)
def clamp_min(x):
return x.clamp(min=1.5)
def clamp_min_max(x):
return x.clamp(min=1., max=3.)
with nvfuser_singleton_fusion(True):
for t in [clamp_max, clamp_min, clamp_min_max]:
t_jit = torch.jit.script(t)
self._run_helper(t_jit, t, x)
@unittest.skipIf(not RUN_NVFUSER, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
"Requires fusion optimization pass to be effective")
def test_device_constant(self):
x = torch.randn(4, 2, device="cuda")
def t(x):
return torch.rand_like(x, device=torch.device(type='cuda'))
# cpu tensor shouldn't be fused
def t_cpu(x):
return torch.rand_like(x, device=torch.device(type='cpu'))
with nvfuser_singleton_fusion(True):
t_jit = torch.jit.script(t)
self._run_helper(t_jit, t, x)
t_cpu_jit = torch.jit.script(t_cpu)
for i in range(5):
t_cpu_jit(x)
self.assertGraphContainsExactly(t_cpu_jit.graph_for(x), FUSION_GUARD, 0)
@unittest.skipIf(not RUN_NVFUSER, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
"Requires fusion optimization pass to be effective")
def test_expand(self):
device = "cuda"
x = torch.randn(3, 5, device=device)
y = torch.randn(4, 2, 3, 5, device=device)
def t(x, y):
with torch.jit.strict_fusion():
x = x.relu()
o0 = x.expand(2, 3, 5)
o1 = x.expand_as(y)
return o0, o1
t_jit = torch.jit.script(t)
self._run_helper(t_jit, t, x, y, check_stride=True)
def t2(x, y):
o0 = x.expand(2, 3, 5)
o1 = x.expand_as(y)
x.add_(1)
return o0, o1
t2_jit = torch.jit.script(t2)
self._run_helper(t2_jit, t2, x, y, check_stride=True, num_fusion=0)
@unittest.skipIf(not RUN_NVFUSER, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
"Requires fusion optimization pass to be effective")
def test_scheduler_with_polymorphic_broadcast(self):
device = "cuda"
x0 = torch.randn(10, 128, device=device)
x1 = torch.rand_like(x0)
x2 = torch.randn(10, device=device)
def t(x0, x1, x2):
x3 = x2.unsqueeze(-1)
x4 = x3 + x0
x5 = x3 + x1
x6 = x5.sum(0)
return x4, x6
t_jit = torch.jit.script(t)
self._run_helper(t_jit, t, x0, x1, x2, check_stride=True)
x2 = torch.randn(128, device=device)
def t2(x0, x1, x2):
x3 = x2.unsqueeze(0)
x4 = x3 + x0
x5 = x3 + x1
x6 = x5.sum(1)
return x4, x6
t2_jit = torch.jit.script(t2)
self._run_helper(t2_jit, t2, x0, x1, x2, check_stride=True)
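# Tests for turning the CUDA fuser on and off: the torch.jit.fuser('fuser2')
# context manager as well as the torch._C._jit_set_nvfuser_enabled() /
# _jit_nvfuser_can_be_enabled() toggles on CUDA, CPU-only, and ROCm builds.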
class TestEnableDisableCudaFuser(JitTestCase):
def setUp(self):
super().setUp()
if RUN_NVFUSER:
self.is_enabled = torch._C._jit_set_nvfuser_enabled(False)
def tearDown(self):
if RUN_NVFUSER:
torch._C._jit_set_nvfuser_enabled(self.is_enabled)
super().tearDown()
@unittest.skipIf(not RUN_NVFUSER, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
"Requires fusion optimization pass to be effective")
def test_context_manager_test(self):
x = torch.randn(4, 8, dtype=torch.float, device="cuda")
y = torch.randn(4, 8, dtype=torch.float, device="cuda")
with torch.jit.fuser('fuser2'):
with torch.jit.fuser('fuser2'):
def t1(x, y):
o = x + y
o = o + 2.0
return o
t_jit = torch.jit.script(t1)
t_jit(x, y)
t_jit(x, y)
self.assertGraphContains(t_jit.graph_for(x, y), FUSION_GUARD)
def t2(x, y):
o = x + y
o = o + 3.0
return o
t_jit_2 = torch.jit.script(t2)
t_jit_2(x, y)
t_jit_2(x, y)
self.assertGraphContains(t_jit_2.graph_for(x, y), FUSION_GUARD)
def t3(x, y):
o = x + y
o = o + 4.0
return o
t_jit_3 = torch.jit.script(t3)
t_jit_3(x, y)
t_jit_3(x, y)
self.assertGraphContainsExactly(t_jit_3.graph_for(x, y), FUSION_GUARD, 0)
@unittest.skipIf(not RUN_NVFUSER, "requires CUDA")
def test_register_fuser(self):
self.assertFalse(torch._C._jit_set_nvfuser_enabled(True))
self.assertTrue(torch._C._jit_nvfuser_enabled())
self.assertTrue(torch._C._jit_set_nvfuser_enabled(True))
self.assertTrue(torch._C._jit_nvfuser_enabled())
self.assertTrue(torch._C._jit_set_nvfuser_enabled(False))
self.assertFalse(torch._C._jit_nvfuser_enabled())
@unittest.skipIf(RUN_CUDA, "Testing on CPU only")
def test_register_fuser_cpu(self):
with self.assertRaises(RuntimeError):
torch._C._jit_set_nvfuser_enabled(True)
torch._C._jit_set_nvfuser_enabled(False)
@unittest.skipIf(not RUN_CUDA, "requires CUDA")
@unittest.skipIf(not TEST_WITH_ROCM, "ROCM test only")
def test_register_fuser_rocm(self):
with self.assertRaises(RuntimeError):
torch._C._jit_set_nvfuser_enabled(True)
torch._C._jit_set_nvfuser_enabled(False)
def test_can_be_enabled_nvfuser(self):
if TEST_WITH_ROCM:
expected = False
else:
expected = RUN_CUDA
self.assertEqual(expected, torch._C._jit_nvfuser_can_be_enabled())
# See TestNNCOpInfoParent
class TestCudaFuserOpInfoParent(JitCommonTestCase):
pass
class TestCudaFuserOpInfo(TestCudaFuserOpInfoParent):
def setUp(self):
super(TestCudaFuserOpInfoParent, self).setUp()
if RUN_NVFUSER:
self.cuda_fuser_options = CudaFuserTestOptions()
# enable guard mode since tracing could change the graph and violate the guard.
torch._C._jit_set_nvfuser_guard_mode(True)
self.nvfuser_single_node_mode = torch._C._jit_set_nvfuser_single_node_mode(True)
def tearDown(self):
if RUN_NVFUSER:
self.cuda_fuser_options.restore()
torch._C._jit_set_nvfuser_single_node_mode(self.nvfuser_single_node_mode)
super(TestCudaFuserOpInfoParent, self).tearDown()
@slowTest
@unittest.skipIf(not RUN_NVFUSER, "requires CUDA")
@ops(op_db, dtypes=OpDTypes.supported)
def test_nvfuser_correctness(self, device, dtype, op):
if not op.supports_tracing:
self.skipTest("nvfuser requires tracing support")
variant_sample_pairs = get_traced_sample_variant_pairs(device, dtype, op)
for variant, sample in variant_sample_pairs:
trace = create_traced_fn(self, variant, cache_traced_fn=True)
ref = variant(*clone_inputs((sample.input, *sample.args)), **sample.kwargs)
trace(*clone_inputs((sample.input, *sample.args)), **sample.kwargs)
val = trace(*clone_inputs((sample.input, *sample.args)), **sample.kwargs)
self.assertEqual(ref, val, exact_layout=True)
# Note: Clearing CU after NVFuser tests
# https://github.com/pytorch/pytorch/issues/35600
# each torch.jit.trace adds state to the _python_cu compilation unit
# since this test traces a lot of functions, out-of-memory can occur
# if the CU is not cleared.
torch.jit._state._python_cu.drop_all_functions()
@slowTest
@unittest.skipIf(not RUN_NVFUSER, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
"Requires fusion optimization pass to be effective")
@ops(op_db, allowed_dtypes=(torch.float16, torch.bfloat16, torch.float32,
torch.float64, torch.complex64, torch.complex128))
def test_nvfuser_extremal_values(self, device, dtype, op):
if not op.supports_tracing:
self.skipTest("nvfuser requires tracing support")
variant_sample_pairs = get_traced_sample_variant_pairs(device, dtype, op)
def _get_extremal_tensor(x, val, dtype):
if x.dtype != dtype:
return x
return torch.full_like(x, val)
def _get_extremal_input(x, val, dtype):
if isinstance(x, torch.Tensor):
return _get_extremal_tensor(x, val, dtype)
elif is_iterable_of_tensors(x):
return [_get_extremal_tensor(y, val, dtype) for y in x]
return x
def _get_extremal_sample(sample: SampleInput, val, dtype):
extremal_sample = SampleInput(
input=_get_extremal_input(sample.input, val, dtype),
args=[_get_extremal_input(x, val, dtype) for x in sample.args],
kwargs={k: _get_extremal_input(v, val, dtype) for k, v in sample.kwargs.items()},
)
return extremal_sample
def _get_extremal_samples(sample: SampleInput, dtype):
vals = [float('inf'), float('-inf'), float('nan')]
if dtype.is_complex:
complex_vals = itertools.product(vals, vals)
vals = list(map(lambda x: complex(*x), complex_vals))
for val in vals:
yield _get_extremal_sample(sample, val, dtype)
variant_sample_pairs = get_traced_sample_variant_pairs(device, dtype, op)
for variant, sample in variant_sample_pairs:
trace = create_traced_fn(self, variant, cache_traced_fn=True)
trace(*clone_inputs((sample.input, *sample.args)), **sample.kwargs)
trace(*clone_inputs((sample.input, *sample.args)), **sample.kwargs)
for extremal_sample in _get_extremal_samples(sample, dtype):
try:
with freeze_rng_state():
ref = variant(*clone_inputs((extremal_sample.input, *extremal_sample.args)),
**extremal_sample.kwargs)
except (torch._C._LinAlgError, RuntimeError, ValueError):
# if eager errors out, then don't expect NVFuser to pass
continue
with freeze_rng_state():
val = trace(*clone_inputs((extremal_sample.input, *extremal_sample.args)),
**extremal_sample.kwargs)
self.assertEqual(val, ref, equal_nan=True, exact_device=True)
# See [Note: Clearing CU after NVFuser tests]
torch.jit._state._python_cu.drop_all_functions()
instantiate_device_type_tests(TestCudaFuserOpInfo, globals(), only_for=("cuda",))
if __name__ == '__main__':
run_tests()