# Owner(s): ["module: nvfuser"]

import unittest
from typing import List

import torch
from torch.testing._internal.common_utils import run_tests, TEST_WITH_ROCM, TestCase
from torch.testing._internal.jit_utils import RUN_CUDA
import torch._refs as refs
import torch._prims as prims

# Will only create the _nvfuser module if CUDA is available
if hasattr(torch._C, "_nvfuser"):
    from torch._C._nvfuser import Fusion, FusionCache, FusionDefinition, DataType

RUN_NVFUSER = RUN_CUDA and not TEST_WITH_ROCM
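
# Volta GPUs report CUDA compute capability 7.x, so a major version below 7
# indicates a pre-Volta device.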
def is_pre_volta():
    if not RUN_NVFUSER:
        return False
    prop = torch.cuda.get_device_properties(torch.cuda.current_device())
    return prop.major < 7

@unittest.skipIf(not RUN_NVFUSER, "requires CUDA")
@unittest.skipIf(is_pre_volta(), "Only supported on Volta and newer devices.")
class TestNvFuserFrontend(TestCase):
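    # Defines the same pointwise add/mul/sum fusion twice, then a third time
    # from a fusion id, checking that the FusionCache is hit and that all
    # three executions match eager mode.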
    def test_basic(self):
        input1 = torch.ones(2, 4, 8, device='cuda')
        input2 = torch.ones(2, 4, 8, device='cuda')
        fc = FusionCache.get()
        before_fusions = fc.num_fusions()

        fs1 = Fusion()
        with FusionDefinition(fs1) as fd:
            t0 = fd.define_tensor(3)
            t1 = fd.define_tensor(3)
            c0 = fd.define_constant(3.0)

            t2 = fd.ops.add(t0, t1)
            t3 = fd.ops.mul(t2, c0)
            t4 = fd.ops.sum(t3, [-1], False, DataType.Float)

            fd.add_output(t4)

        # Expected output is a tensor of 48s
        nvf_out1 = fs1.execute([input1, input2])[0]

        # Create a new fusion with the same definition; it should hit the cache!
        fs2 = Fusion()
        with FusionDefinition(fs2) as fd:
            t0 = fd.define_tensor(3)
            t1 = fd.define_tensor(3)
            c0 = fd.define_constant(3.0)

            t2 = fd.ops.add(t0, t1)
            t3 = fd.ops.mul(t2, c0)
            t4 = fd.ops.sum(t3, [-1], False, DataType.Float)

            fd.add_output(t4)

        nvf_out2 = fs2.execute([input1, input2])[0]

        # Check there is still only 1 cache entry
        fc = FusionCache.get()
        self.assertEqual(fc.num_fusions() - before_fusions, 1)

        # Create a fusion from a fusion id and make sure it executes!
        fs3 = Fusion(fs2.id())
        nvf_out3 = fs3.execute([input1, input2])[0]

        eager_out = torch.sum((input1 + input2) * 3.0, dim=-1)
        self.assertEqual(eager_out, nvf_out1)
        self.assertEqual(eager_out, nvf_out2)
        self.assertEqual(eager_out, nvf_out3)
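
    # Same fusion as test_basic but with half-precision inputs; the sum is
    # accumulated in float and the result cast back to half.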
    def test_basic_fp16(self):
        fs = Fusion()
        with FusionDefinition(fs) as fd:
            t0 = fd.define_tensor(3, DataType.Half)
            t1 = fd.define_tensor(3, DataType.Half)
            c0 = fd.define_constant(3.0)

            t2 = fd.ops.add(t0, t1)
            t3 = fd.ops.mul(t2, c0)
            t4 = fd.ops.sum(t3, [-1], False, DataType.Float)

            t5 = fd.ops.cast(t4, DataType.Half)
            fd.add_output(t5)

        input1 = torch.ones(2, 4, 8, device='cuda', dtype=torch.float16)
        input2 = torch.ones(2, 4, 8, device='cuda', dtype=torch.float16)

        # Expected output is a tensor of 48s
        nvf_out = fs.execute([input1, input2])[0]
        eager_out = torch.sum((input1 + input2) * 3.0, dim=-1)
        self.assertEqual(eager_out, nvf_out)
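
    # Inputs arrive as double, are cast down to half inside the fusion, and
    # the relu(add) result is emitted in half precision.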
    def test_cast_double_to_half(self):
        fs = Fusion()
        with FusionDefinition(fs) as fd:
            t0 = fd.define_tensor(2, DataType.Double)
            t1 = fd.define_tensor(2, DataType.Double)

            t0h = fd.ops.cast(t0, DataType.Half)
            t1h = fd.ops.cast(t1, DataType.Half)
            t2 = fd.ops.add(t0h, t1h)
            t3 = fd.ops.relu(t2)
            t4 = fd.ops.cast(t3, DataType.Half)

            fd.add_output(t4)

        input1 = torch.randn(2, 4, device='cuda', dtype=torch.float64)
        input2 = torch.randn(2, 4, device='cuda', dtype=torch.float64)

        nvf_out = fs.execute([input1, input2])[0]
        eager_out = torch.relu(input1.to(torch.half) + input2.to(torch.half))
        self.assertEqual(eager_out, nvf_out)
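
    # Mixed half/double inputs should follow standard type promotion: the
    # add is computed in double, matching eager mode.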
    def test_promote_to_double(self):
        fs = Fusion()
        with FusionDefinition(fs) as fd:
            t0 = fd.define_tensor(2, DataType.Half)
            t1 = fd.define_tensor(2, DataType.Double)

            t2 = fd.ops.add(t0, t1)
            t5 = fd.ops.relu(t2)

            fd.add_output(t5)

        input1 = torch.randn(2, 4, device='cuda', dtype=torch.float16)
        input2 = torch.randn(2, 4, device='cuda', dtype=torch.float64)

        nvf_out = fs.execute([input1, input2])[0]
        eager_out = torch.relu(input1 + input2)
        self.assertEqual(eager_out, nvf_out)
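
    # A rank-1 input is broadcast up to the rank-3 input's shape with
    # broadcast_in_dim before the add.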
    def test_implicit_broadcast_input(self):
        fs = Fusion()
        with FusionDefinition(fs) as fd:
            t0 = fd.define_tensor(1)
            t1 = fd.define_tensor(3)

            t0_b = fd.ops.broadcast_in_dim(t0, [2, 3, 4], [1])
            t2 = fd.ops.add(t0_b, t1)

            fd.add_output(t2)

        input1 = torch.randn(3, device='cuda')
        input2 = torch.randn(2, 3, 4, device='cuda')

        nvf_out = fs.execute([input1, input2])[0]
        eager_out = refs.add(prims.broadcast_in_dim(input1, [2, 3, 4], [1]), input2)
        self.assertEqual(eager_out, nvf_out)
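
    # Tensors are defined from concrete sizes/strides, and the size-1 leading
    # dimensions are expanded explicitly with broadcast_in_dim.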
    def test_explicit_broadcast_input(self):
        input1 = torch.randn(1, 1, 4, device='cuda')
        input2 = torch.randn(2, 3, 4, device='cuda')

        fs = Fusion()
        with FusionDefinition(fs) as fd:
            t0 = fd.define_tensor(sizes=input1.size(), strides=input1.stride())
            t1 = fd.define_tensor(sizes=input2.size(), strides=input2.stride())

            t0_b = fd.ops.broadcast_in_dim(t0, [2, 3, 4], [0, 1, 2])
            t2 = fd.ops.add(t0_b, t1)

            fd.add_output(t2)

        nvf_out = fs.execute([input1, input2])[0]
        eager_out = refs.add(prims.broadcast_in_dim(input1, [2, 3, 4], [0, 1, 2]), input2)
        self.assertEqual(eager_out, nvf_out)
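
    # Mixes an implicitly broadcast size-1 dimension (the [3, 1] input) with
    # an explicit broadcast_in_dim of a rank-1 input.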
    def test_broadcast_mixing(self):
        fs = Fusion()
        with FusionDefinition(fs) as fd:
            t0 = fd.define_tensor([3, 1], [1, 1])
            t1 = fd.define_tensor(1)

            t1_b = fd.ops.broadcast_in_dim(t1, [3, 3], [0])
            t2 = fd.ops.add(t0, t1_b)

            fd.add_output(t2)

        input1 = torch.randn(3, 1, device='cuda')
        input2 = torch.randn(3, device='cuda')

        nvf_out = fs.execute([input1, input2])[0]
        eager_out = refs.add(input1, prims.broadcast_in_dim(input2, [3, 3], [0]))
        self.assertEqual(eager_out, nvf_out)
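
    # Forward layer norm built three ways: an eager-mode reference, a fusion
    # composed of primitive ops, and a fusion using the fused var_mean op.
    # Repeated definitions should produce exactly two cache entries.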
    def test_prim_layer_norm_fwd(self):
        def primitive_definition(
            inputs: torch.Tensor,
            weight: torch.Tensor,
            bias: torch.Tensor,
            normalization_axis: int,
            keepdim: bool,
        ) -> torch.Tensor:
            mean = inputs.mean(normalization_axis, keepdim=keepdim)
            diff = inputs - mean
            diff_sq = diff * diff
            var = diff_sq.mean(normalization_axis, keepdim=keepdim)
            pre_shift_scale_norm_output = (inputs - mean) / torch.sqrt(var + 1e-12)
            norm_output = weight * pre_shift_scale_norm_output + bias
            return norm_output
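
        # Layer norm spelled out with primitive nvFuser ops: sum/div for the
        # mean and variance, rsqrt for the inverse standard deviation.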
        def nvfuser_fusion(
            fd: FusionDefinition,
            normalization_axis: int,
            norm_size: int,
            input_shape: List[int],
            eps: float,
            keepDim: bool,
        ) -> None:
            inputs = fd.define_tensor(
                symbolic_sizes=[-1, -1, -1], contiguous=[True, True, True], dtype=DataType.Float
            )
            weights = fd.define_tensor(symbolic_sizes=[-1], contiguous=[True], dtype=DataType.Float)
            bias = fd.define_tensor(symbolic_sizes=[-1], contiguous=[True], dtype=DataType.Float)
            sum0 = fd.ops.sum(inputs, axes=[normalization_axis], keepdim=keepDim)
            norm_const = fd.define_constant(norm_size)
            mean = fd.ops.div(sum0, norm_const)
            diff = fd.ops.sub(inputs, mean)
            diff_sq = fd.ops.mul(diff, diff)
            sum1 = fd.ops.sum(diff_sq, axes=[normalization_axis], keepdim=keepDim)
            var = fd.ops.div(sum1, norm_const)
            eps_const = fd.define_constant(eps)
            var_eps = fd.ops.add(var, eps_const)
            invstd = fd.ops.rsqrt(var_eps)
            pre_scale_bias = fd.ops.mul(diff, invstd)
            weights_bcast = fd.ops.broadcast_in_dim(weights, output_shape=input_shape, broadcast_dims=[2])
            scale = fd.ops.mul(pre_scale_bias, weights_bcast)
            bias_bcast = fd.ops.broadcast_in_dim(bias, output_shape=input_shape, broadcast_dims=[2])
            out = fd.ops.add(scale, bias_bcast)
            fd.add_output(out)
            fd.add_output(mean)
            fd.add_output(invstd)
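
        # Same computation, but the mean and variance come from the single
        # fused var_mean op (correction=0 gives the biased variance).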
        def nvfuser_fusion_var_mean(
            fd: FusionDefinition,
            normalization_axis: int,
            norm_size: int,
            input_shape: List[int],
            eps: float,
            keepDim: bool,
        ) -> None:
            inputs = fd.define_tensor(
                symbolic_sizes=[-1, -1, -1], contiguous=[True, True, True], dtype=DataType.Float
            )
            weights = fd.define_tensor(symbolic_sizes=[-1], contiguous=[True], dtype=DataType.Float)
            bias = fd.define_tensor(symbolic_sizes=[-1], contiguous=[True], dtype=DataType.Float)
            var, mean = fd.ops.var_mean(inputs, axes=[normalization_axis], correction=0, keepdim=keepDim)
            eps_const = fd.define_constant(eps)
            var_eps = fd.ops.add(var, eps_const)
            invstd = fd.ops.rsqrt(var_eps)
            diff = fd.ops.sub(inputs, mean)
            pre_scale_bias = fd.ops.mul(diff, invstd)
            weights_bcast = fd.ops.broadcast_in_dim(weights, output_shape=input_shape, broadcast_dims=[2])
            scale = fd.ops.mul(pre_scale_bias, weights_bcast)
            bias_bcast = fd.ops.broadcast_in_dim(bias, output_shape=input_shape, broadcast_dims=[2])
            out = fd.ops.add(scale, bias_bcast)
            fd.add_output(out)
            fd.add_output(mean)
            fd.add_output(invstd)
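
        # Each loop redefines an identical fusion, so the cache should grow
        # by exactly one entry per definition style (two in total).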
        input_size = [64, 128, 1024]
        dtype = torch.float32
        device = 'cuda'
        inputs = torch.randn(*input_size, device=device, requires_grad=True)
        weights = torch.nn.Parameter(torch.randn(input_size[2], dtype=dtype, device=device))
        biases = torch.nn.Parameter(torch.randn(input_size[2], dtype=dtype, device=device))

        fc = FusionCache.get()
        before_fusions = fc.num_fusions()

        for _ in range(5):
            nvf_fusion = Fusion()
            with FusionDefinition(nvf_fusion) as fd:
                nvfuser_fusion(fd, 2, inputs.size()[2], inputs.size(), 1e-12, True)
            nvf_out = nvf_fusion.execute([inputs, weights, biases])

        for _ in range(5):
            nvf_var_mean_fusion = Fusion()
            with FusionDefinition(nvf_var_mean_fusion) as fd:
                nvfuser_fusion_var_mean(fd, 2, inputs.size()[2], inputs.size(), 1e-12, True)
            nvf_var_mean_out = nvf_var_mean_fusion.execute([inputs, weights, biases])

        for _ in range(5):
            eager_out = primitive_definition(inputs, weights, biases, 2, True)

        self.assertEqual(eager_out, nvf_out[0])
        self.assertEqual(eager_out, nvf_var_mean_out[0])
        self.assertEqual(fc.num_fusions() - before_fusions, 2)
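
    # Forward RMS norm: like layer norm but without mean subtraction or bias.
    # Five identical definitions should add only one cache entry.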
    def test_prim_rms_norm_fwd(self):
        def primitive_definition(
            inputs: torch.Tensor,
            weight: torch.Tensor,
            normalization_axis: int,
            keepdim: bool,
        ) -> torch.Tensor:
            var = inputs.mul(inputs).mean(normalization_axis, keepdim)
            pre_shift_scale_norm_output = inputs / torch.sqrt(var + 1e-12)
            norm_output = weight * pre_shift_scale_norm_output
            return norm_output
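
        # RMS norm with primitive ops: mean of x*x, rsqrt, then scale by the
        # broadcast weights.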
        def nvfuser_fusion(
            fd: FusionDefinition,
            normalization_axis: int,
            norm_size: int,
            input_shape: List[int],
            eps: float,
            keepDim: bool,
        ) -> None:
            inputs = fd.define_tensor(
                symbolic_sizes=[-1, -1, -1], contiguous=[True, True, True], dtype=DataType.Float
            )
            weights = fd.define_tensor(symbolic_sizes=[-1], contiguous=[True], dtype=DataType.Float)
            inputs_sq = fd.ops.mul(inputs, inputs)
            sum0 = fd.ops.sum(inputs_sq, axes=[normalization_axis], keepdim=keepDim)
            norm_const = fd.define_constant(norm_size)
            var = fd.ops.div(sum0, norm_const)
            eps_const = fd.define_constant(eps)
            var_eps = fd.ops.add(var, eps_const)
            invstd = fd.ops.rsqrt(var_eps)
            pre_scale = fd.ops.mul(inputs, invstd)
            weights_bcast = fd.ops.broadcast_in_dim(weights, output_shape=input_shape, broadcast_dims=[2])
            out = fd.ops.mul(pre_scale, weights_bcast)
            fd.add_output(out)
            fd.add_output(invstd)

        input_size = [64, 128, 1024]
        dtype = torch.float32
        device = 'cuda'
        inputs = torch.randn(*input_size, device=device, requires_grad=True)
        weights = torch.nn.Parameter(torch.randn(input_size[2], dtype=dtype, device=device))

        fc = FusionCache.get()
        before_fusions = fc.num_fusions()

        for _ in range(5):
            nvf_fusion = Fusion()
            with FusionDefinition(nvf_fusion) as fd:
                nvfuser_fusion(fd, 2, inputs.size()[2], inputs.size(), 1e-12, True)
            nvf_out = nvf_fusion.execute([inputs, weights])

        for _ in range(5):
            eager_out = primitive_definition(inputs, weights, 2, True)

        self.assertEqual(eager_out, nvf_out[0])
        self.assertEqual(fc.num_fusions() - before_fusions, 1)

if __name__ == '__main__':
    run_tests()