| # Owner(s): ["module: nvfuser"] |
| |
| import unittest |
| from typing import List |
| |
| import torch |
| from torch.testing._internal.common_utils import run_tests, TEST_WITH_ROCM, TestCase |
| from torch.testing._internal.jit_utils import RUN_CUDA |
| import torch._refs as refs |
| import torch._prims as prims |
| |
# torch._C only carries the _nvfuser extension module in some builds (per the
# original note: only when CUDA is available), so its bindings are imported
# conditionally.
if hasattr(torch._C, "_nvfuser"):
    from torch._C._nvfuser import (
        DataType,
        Fusion,
        FusionCache,
        FusionDefinition,
    )

# Single gate used by the skip decorators below: CUDA must be runnable and the
# run must not be flagged as a ROCm test run.
RUN_NVFUSER = RUN_CUDA and not TEST_WITH_ROCM
| |
def is_pre_volta():
    """Report whether the current CUDA device predates Volta.

    Returns False without touching CUDA when ``RUN_NVFUSER`` is falsy.
    Otherwise returns True when the major compute capability of the current
    device is below 7.
    """
    if RUN_NVFUSER:
        device_index = torch.cuda.current_device()
        return torch.cuda.get_device_properties(device_index).major < 7
    return False
| |
@unittest.skipIf(not RUN_NVFUSER, "requires CUDA")
@unittest.skipIf(is_pre_volta(), "Only supported on Volta and newer devices.")
class TestNvFuserFrontend(TestCase):
    """Tests for the nvFuser Python frontend (``Fusion`` / ``FusionDefinition``).

    Every test records a fusion through ``FusionDefinition``, executes it on
    CUDA tensors and compares the first output against the same computation
    done with eager PyTorch (or ``torch._refs`` / ``torch._prims``).
    """

    def test_basic(self):
        """Record the same fusion twice and check that it is cached once.

        Also rebuilds a ``Fusion`` from the id of an existing one and checks
        that all three executions agree with eager mode.
        """
        def define_fusion(fd: FusionDefinition) -> None:
            # (t0 + t1) * 3.0, then summed over the last axis.
            t0 = fd.define_tensor(3)
            t1 = fd.define_tensor(3)
            c0 = fd.define_constant(3.0)

            t2 = fd.ops.add(t0, t1)
            t3 = fd.ops.mul(t2, c0)
            t4 = fd.ops.sum(t3, [-1], False, DataType.Float)

            fd.add_output(t4)

        input1 = torch.ones(2, 4, 8, device='cuda')
        input2 = torch.ones(2, 4, 8, device='cuda')
        fc = FusionCache.get()
        before_fusions = fc.num_fusions()

        fs1 = Fusion()
        with FusionDefinition(fs1) as fd:
            define_fusion(fd)

        # Expected output is a tensor of 48's: (1 + 1) * 3 summed over 8 values.
        nvf_out1 = fs1.execute([input1, input2])[0]

        # Create a new fusion with the same definition, it should hit the cache!
        fs2 = Fusion()
        with FusionDefinition(fs2) as fd:
            define_fusion(fd)

        nvf_out2 = fs2.execute([input1, input2])[0]

        # Check there is still only 1 cache entry.
        fc = FusionCache.get()
        self.assertEqual(fc.num_fusions() - before_fusions, 1)

        # Create a fusion from a fusion id and make sure it executes!
        fs3 = Fusion(fs2.id())
        nvf_out3 = fs3.execute([input1, input2])[0]

        eager_out = torch.sum((input1 + input2) * 3.0, dim=-1)
        self.assertEqual(eager_out, nvf_out1)
        self.assertEqual(eager_out, nvf_out2)
        self.assertEqual(eager_out, nvf_out3)

    def test_basic_fp16(self):
        """Run the add/mul/sum fusion on Half inputs and cast the result to Half."""
        fs = Fusion()
        with FusionDefinition(fs) as fd:
            t0 = fd.define_tensor(3, DataType.Half)
            t1 = fd.define_tensor(3, DataType.Half)
            c0 = fd.define_constant(3.0)

            t2 = fd.ops.add(t0, t1)
            t3 = fd.ops.mul(t2, c0)
            # The reduction is requested with DataType.Float and the result is
            # cast to Half before it is returned.
            t4 = fd.ops.sum(t3, [-1], False, DataType.Float)

            t5 = fd.ops.cast(t4, DataType.Half)
            fd.add_output(t5)

        input1 = torch.ones(2, 4, 8, device='cuda', dtype=torch.float16)
        input2 = torch.ones(2, 4, 8, device='cuda', dtype=torch.float16)

        # Expected output is a tensor of 48's.
        nvf_out = fs.execute([input1, input2])[0]
        eager_out = torch.sum((input1 + input2) * 3.0, dim=-1)
        self.assertEqual(eager_out, nvf_out)

    def test_cast_double_to_half(self):
        """Cast Double inputs to Half inside the fusion, then add and relu."""
        fs = Fusion()
        with FusionDefinition(fs) as fd:
            t0 = fd.define_tensor(2, DataType.Double)
            t1 = fd.define_tensor(2, DataType.Double)

            t0h = fd.ops.cast(t0, DataType.Half)
            t1h = fd.ops.cast(t1, DataType.Half)
            t2 = fd.ops.add(t0h, t1h)
            t3 = fd.ops.relu(t2)
            t4 = fd.ops.cast(t3, DataType.Half)

            fd.add_output(t4)

        input1 = torch.randn(2, 4, device='cuda', dtype=torch.float64)
        input2 = torch.randn(2, 4, device='cuda', dtype=torch.float64)

        nvf_out = fs.execute([input1, input2])[0]
        eager_out = torch.relu(input1.to(torch.half) + input2.to(torch.half))
        self.assertEqual(eager_out, nvf_out)

    def test_promote_to_double(self):
        """Add a Half tensor to a Double tensor and compare with eager mode."""
        fs = Fusion()

        with FusionDefinition(fs) as fd:
            t0 = fd.define_tensor(2, DataType.Half)
            t1 = fd.define_tensor(2, DataType.Double)

            t2 = fd.ops.add(t0, t1)
            t5 = fd.ops.relu(t2)

            fd.add_output(t5)

        input1 = torch.randn(2, 4, device='cuda', dtype=torch.float16)
        input2 = torch.randn(2, 4, device='cuda', dtype=torch.float64)

        nvf_out = fs.execute([input1, input2])[0]
        eager_out = torch.relu(input1 + input2)
        self.assertEqual(eager_out, nvf_out)

    def test_implicit_broadcast_input(self):
        """Broadcast a 1-D input into dim 1 of a [2, 3, 4] shape and add."""
        fs = Fusion()
        with FusionDefinition(fs) as fd:
            t0 = fd.define_tensor(1)
            t1 = fd.define_tensor(3)

            t0_b = fd.ops.broadcast_in_dim(t0, [2, 3, 4], [1])
            t2 = fd.ops.add(t0_b, t1)

            fd.add_output(t2)

        input1 = torch.randn(3, device='cuda')
        input2 = torch.randn(2, 3, 4, device='cuda')

        nvf_out = fs.execute([input1, input2])[0]
        eager_out = refs.add(prims.broadcast_in_dim(input1, [2, 3, 4], [1]), input2)
        self.assertEqual(eager_out, nvf_out)

    def test_explicit_broadcast_input(self):
        """Broadcast a [1, 1, 4] input defined from concrete sizes and strides."""
        input1 = torch.randn(1, 1, 4, device='cuda')
        input2 = torch.randn(2, 3, 4, device='cuda')

        fs = Fusion()
        with FusionDefinition(fs) as fd:
            t0 = fd.define_tensor(sizes=input1.size(), strides=input1.stride())
            t1 = fd.define_tensor(sizes=input2.size(), strides=input2.stride())

            t0_b = fd.ops.broadcast_in_dim(t0, [2, 3, 4], [0, 1, 2])
            t2 = fd.ops.add(t0_b, t1)

            fd.add_output(t2)

        nvf_out = fs.execute([input1, input2])[0]
        eager_out = refs.add(prims.broadcast_in_dim(input1, [2, 3, 4], [0, 1, 2]), input2)
        self.assertEqual(eager_out, nvf_out)

    def test_broadcast_mixing(self):
        """Add a [3, 1] input to a 1-D input broadcast to [3, 3]."""
        fs = Fusion()
        with FusionDefinition(fs) as fd:
            t0 = fd.define_tensor([3, 1], [1, 1])
            t1 = fd.define_tensor(1)

            t1_b = fd.ops.broadcast_in_dim(t1, [3, 3], [0])
            t2 = fd.ops.add(t0, t1_b)

            fd.add_output(t2)

        input1 = torch.randn(3, 1, device='cuda')
        input2 = torch.randn(3, device='cuda')

        nvf_out = fs.execute([input1, input2])[0]
        eager_out = refs.add(input1, prims.broadcast_in_dim(input2, [3, 3], [0]))
        self.assertEqual(eager_out, nvf_out)

    def test_prim_layer_norm_fwd(self):
        """Compare two layer-norm forward fusions against an eager reference.

        One fusion builds mean and variance from ``sum``/``div``; the other
        uses ``var_mean``. Each is defined five times and the test expects
        exactly two new fusion cache entries in total.
        """
        def primitive_definition(
            inputs: torch.Tensor,
            weight: torch.Tensor,
            bias: torch.Tensor,
            normalization_axis: int,
            keepdim: bool,
        ) -> torch.Tensor:
            # Eager reference: normalise along the axis, then scale and shift.
            mean = inputs.mean(normalization_axis, keepdim=keepdim)
            diff = inputs - mean
            diff_sq = diff * diff
            var = diff_sq.mean(normalization_axis, keepdim=keepdim)
            pre_shift_scale_norm_output = diff / torch.sqrt(var + 1e-12)
            norm_output = weight * pre_shift_scale_norm_output + bias
            return norm_output

        def nvfuser_fusion(
            fd: FusionDefinition,
            normalization_axis: int,
            norm_size: int,
            input_shape: List[int],
            eps: float,
            keepdim: bool,
        ) -> None:
            # Layer norm with mean and variance built from sum / div.
            inputs = fd.define_tensor(
                symbolic_sizes=[-1, -1, -1],
                contiguous=[True, True, True],
                dtype=DataType.Float,
            )
            weights = fd.define_tensor(symbolic_sizes=[-1], contiguous=[True], dtype=DataType.Float)
            bias = fd.define_tensor(symbolic_sizes=[-1], contiguous=[True], dtype=DataType.Float)
            sum0 = fd.ops.sum(inputs, axes=[normalization_axis], keepdim=keepdim)
            norm_const = fd.define_constant(norm_size)
            mean = fd.ops.div(sum0, norm_const)
            diff = fd.ops.sub(inputs, mean)
            diff_sq = fd.ops.mul(diff, diff)
            sum1 = fd.ops.sum(diff_sq, axes=[normalization_axis], keepdim=keepdim)
            var = fd.ops.div(sum1, norm_const)
            eps_const = fd.define_constant(eps)
            var_eps = fd.ops.add(var, eps_const)
            invstd = fd.ops.rsqrt(var_eps)
            pre_scale_bias = fd.ops.mul(diff, invstd)
            weights_bcast = fd.ops.broadcast_in_dim(weights, output_shape=input_shape, broadcast_dims=[2])
            scale = fd.ops.mul(pre_scale_bias, weights_bcast)
            bias_bcast = fd.ops.broadcast_in_dim(bias, output_shape=input_shape, broadcast_dims=[2])
            out = fd.ops.add(scale, bias_bcast)
            fd.add_output(out)
            fd.add_output(mean)
            fd.add_output(invstd)

        def nvfuser_fusion_var_mean(
            fd: FusionDefinition,
            normalization_axis: int,
            norm_size: int,
            input_shape: List[int],
            eps: float,
            keepdim: bool,
        ) -> None:
            # Same layer norm, with the statistics taken from var_mean.
            # norm_size is unused here; it is kept so both helpers share a
            # signature.
            inputs = fd.define_tensor(
                symbolic_sizes=[-1, -1, -1],
                contiguous=[True, True, True],
                dtype=DataType.Float,
            )
            weights = fd.define_tensor(symbolic_sizes=[-1], contiguous=[True], dtype=DataType.Float)
            bias = fd.define_tensor(symbolic_sizes=[-1], contiguous=[True], dtype=DataType.Float)
            var, mean = fd.ops.var_mean(inputs, axes=[normalization_axis], correction=0, keepdim=keepdim)
            eps_const = fd.define_constant(eps)
            var_eps = fd.ops.add(var, eps_const)
            invstd = fd.ops.rsqrt(var_eps)
            diff = fd.ops.sub(inputs, mean)
            pre_scale_bias = fd.ops.mul(diff, invstd)
            weights_bcast = fd.ops.broadcast_in_dim(weights, output_shape=input_shape, broadcast_dims=[2])
            scale = fd.ops.mul(pre_scale_bias, weights_bcast)
            bias_bcast = fd.ops.broadcast_in_dim(bias, output_shape=input_shape, broadcast_dims=[2])
            out = fd.ops.add(scale, bias_bcast)
            fd.add_output(out)
            fd.add_output(mean)
            fd.add_output(invstd)

        input_size = [64, 128, 1024]
        dtype = torch.float32
        device = 'cuda'
        # Pass dtype explicitly so the input matches the DataType.Float fusion
        # inputs even if the global default dtype has been changed.
        inputs = torch.randn(*input_size, dtype=dtype, device=device, requires_grad=True)
        weights = torch.nn.Parameter(torch.randn(input_size[2], dtype=dtype, device=device))
        biases = torch.nn.Parameter(torch.randn(input_size[2], dtype=dtype, device=device))
        fc = FusionCache.get()
        before_fusions = fc.num_fusions()

        # Each definition is recorded and executed five times; only the output
        # of the last iteration is kept for the comparison below.
        for _ in range(5):
            nvf_fusion = Fusion()
            with FusionDefinition(nvf_fusion) as fd:
                nvfuser_fusion(fd, 2, inputs.size()[2], inputs.size(), 1e-12, True)
            nvf_out = nvf_fusion.execute([inputs, weights, biases])

        for _ in range(5):
            nvf_var_mean_fusion = Fusion()
            with FusionDefinition(nvf_var_mean_fusion) as fd:
                nvfuser_fusion_var_mean(fd, 2, inputs.size()[2], inputs.size(), 1e-12, True)
            nvf_var_mean_out = nvf_var_mean_fusion.execute([inputs, weights, biases])

        # NOTE(review): presumably repeated to mirror the fused loops above.
        for _ in range(5):
            eager_out = primitive_definition(inputs, weights, biases, 2, True)

        self.assertEqual(eager_out, nvf_out[0])
        self.assertEqual(eager_out, nvf_var_mean_out[0])
        # Two distinct definitions were recorded, so two new cache entries.
        fc = FusionCache.get()
        self.assertEqual(fc.num_fusions() - before_fusions, 2)

    def test_prim_rms_norm_fwd(self):
        """Compare an RMS-norm forward fusion against an eager reference.

        The fusion is defined five times and the test expects exactly one new
        fusion cache entry.
        """
        def primitive_definition(
            inputs: torch.Tensor,
            weight: torch.Tensor,
            normalization_axis: int,
            keepdim: bool,
        ) -> torch.Tensor:
            # Eager reference: divide by the root mean square, then scale.
            var = inputs.mul(inputs).mean(normalization_axis, keepdim)
            pre_shift_scale_norm_output = inputs / torch.sqrt(var + 1e-12)
            norm_output = weight * pre_shift_scale_norm_output
            return norm_output

        def nvfuser_fusion(
            fd: FusionDefinition,
            normalization_axis: int,
            norm_size: int,
            input_shape: List[int],
            eps: float,
            keepdim: bool,
        ) -> None:
            inputs = fd.define_tensor(
                symbolic_sizes=[-1, -1, -1],
                contiguous=[True, True, True],
                dtype=DataType.Float,
            )
            weights = fd.define_tensor(symbolic_sizes=[-1], contiguous=[True], dtype=DataType.Float)
            inputs_sq = fd.ops.mul(inputs, inputs)
            sum0 = fd.ops.sum(inputs_sq, axes=[normalization_axis], keepdim=keepdim)
            norm_const = fd.define_constant(norm_size)
            var = fd.ops.div(sum0, norm_const)
            eps_const = fd.define_constant(eps)
            var_eps = fd.ops.add(var, eps_const)
            invstd = fd.ops.rsqrt(var_eps)
            pre_scale = fd.ops.mul(inputs, invstd)
            weights_bcast = fd.ops.broadcast_in_dim(weights, output_shape=input_shape, broadcast_dims=[2])
            out = fd.ops.mul(pre_scale, weights_bcast)
            fd.add_output(out)
            fd.add_output(invstd)

        input_size = [64, 128, 1024]
        dtype = torch.float32
        device = 'cuda'
        # Pass dtype explicitly so the input matches the DataType.Float fusion
        # inputs even if the global default dtype has been changed.
        inputs = torch.randn(*input_size, dtype=dtype, device=device, requires_grad=True)
        weights = torch.nn.Parameter(torch.randn(input_size[2], dtype=dtype, device=device))
        fc = FusionCache.get()
        before_fusions = fc.num_fusions()

        # Recorded and executed five times; only the last output is compared.
        for _ in range(5):
            nvf_fusion = Fusion()
            with FusionDefinition(nvf_fusion) as fd:
                nvfuser_fusion(fd, 2, inputs.size()[2], inputs.size(), 1e-12, True)
            nvf_out = nvf_fusion.execute([inputs, weights])

        # NOTE(review): presumably repeated to mirror the fused loop above.
        for _ in range(5):
            eager_out = primitive_definition(inputs, weights, 2, True)

        self.assertEqual(eager_out, nvf_out[0])
        self.assertEqual(fc.num_fusions() - before_fusions, 1)
| |
# Hand control to the imported test runner when this file is executed directly.
if __name__ == "__main__":
    run_tests()