| # Owner(s): ["module: inductor"] |
| import contextlib |
| import copy |
| import itertools |
| import math |
| import platform |
| import sys |
| import unittest |
| from typing import Callable |
| from unittest.mock import patch |
| |
| import numpy as np |
| import sympy |
| import torch |
| from torch._C import FileCheck |
| from torch._dynamo.exc import BackendCompilerFailed |
| from torch._dynamo.testing import rand_strided |
| from torch._dynamo.utils import same |
| from torch._inductor import codecache, config, metrics |
| from torch._inductor.codegen.common import OptimizationContext |
| from torch._inductor.codegen.cpp import ( |
| CppOverrides, |
| CppVecKernelChecker, |
| CppVecOverrides, |
| ) |
| from torch._inductor.compile_fx import ( |
| compile_fx, |
| compile_fx_inner, |
| complex_memory_overlap, |
| ) |
| from torch._inductor.graph import GraphLowering |
| from torch._inductor.ir import InterpreterShim |
| from torch._inductor.utils import timed |
| from torch._inductor.virtualized import V |
| from torch.fx.experimental.proxy_tensor import make_fx |
| from torch.nn import functional as F |
| from torch.testing._internal.common_utils import IS_MACOS, slowTest |
| from torch.utils._python_dispatch import TorchDispatchMode |
| |
| try: |
| try: |
| from . import test_torchinductor |
| except ImportError: |
| import test_torchinductor |
| except unittest.SkipTest: |
| if __name__ == "__main__": |
| sys.exit(0) |
| raise |
| |
| |
| vec_dtypes = test_torchinductor.vec_dtypes |
| _lowp_fp_dtypes = ( |
| torch.bfloat16, |
| torch.float16, |
| ) |
| run_and_get_cpp_code = test_torchinductor.run_and_get_cpp_code |
| TestCase = test_torchinductor.TestCase |
| aten = torch.ops.aten |
| check_model = test_torchinductor.check_model |
| |
| |
| class LstmModule(torch.nn.Module): |
| def __init__( |
| self, |
| input_size, |
| hidden_size, |
| num_layers, |
| bias=True, |
| bidirectional=False, |
| batch_first=False, |
| ): |
| super().__init__() |
| self.lstm = torch.nn.LSTM( |
| input_size=input_size, |
| hidden_size=hidden_size, |
| num_layers=num_layers, |
| bias=bias, |
| bidirectional=bidirectional, |
| batch_first=batch_first, |
| ) |
| |
| def forward(self, x, h=None): |
| x, h = self.lstm(x, h) |
| return x, h |
| |
| |
| class CPUReproTests(TestCase): |
| common = check_model |
| |
| def test_conv_stride_constraints(self): |
| for fmt in [torch.contiguous_format, torch.channels_last]: |
| # TorchDispatch doesn't work in our cuda invocation for some reason |
| m = torch.nn.Conv2d(5, 6, [3, 3]) |
| |
| def fn(inp, weight): |
| return ( |
| F.conv2d( |
| inp, weight, None, m.stride, m.padding, m.dilation, m.groups |
| ), |
| ) |
| |
| inp = torch.randn([2, 5, 16, 16]) |
| inps = [inp, m.weight.to(memory_format=fmt)] |
| fn_fx = make_fx(fn)(*inps) |
| fn_compiled = compile_fx_inner(fn_fx, inps) |
| test_self = self |
| conv_seen = False |
| |
| class RecordFunctions(TorchDispatchMode): |
| def __torch_dispatch__(self, func, types, args=(), kwargs=None): |
| kwargs = kwargs if kwargs else {} |
| if func == torch.ops.aten.convolution.default: |
                        # For CPU with mkldnn enabled, we always use channels last
| nonlocal fmt |
| if ( |
| torch.backends.mkldnn.enabled |
| and torch.backends.mkldnn.is_available() |
| ): |
| fmt = torch.channels_last |
| test_self.assertTrue(args[0].is_contiguous(memory_format=fmt)) |
| test_self.assertTrue(args[1].is_contiguous(memory_format=fmt)) |
| nonlocal conv_seen |
| conv_seen = True |
| |
| return func(*args, **kwargs) |
| |
| with RecordFunctions(): |
| out = fn_compiled(inps) |
| |
| self.assertTrue(conv_seen) |
| |
| @patch("torch.cuda.is_available", lambda: False) |
| def test_conv2d_bn_mixed_dtype(self): |
| class Model(torch.nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.conv = torch.nn.Conv2d( |
| 3, |
| 16, |
| kernel_size=3, |
| stride=1, |
| padding=1, |
| bias=False, |
| dtype=torch.bfloat16, |
| ) |
| self.bn = torch.nn.BatchNorm2d( |
| 16, eps=0.001, momentum=0.1, affine=True, track_running_stats=True |
| ) |
| |
| def forward(self, x): |
| x = self.conv(x) |
| x = self.bn(x) |
| return x |
| |
| v = torch.randn(1, 3, 64, 64, dtype=torch.bfloat16) |
| mod = Model().eval() |
| with torch.no_grad(): |
| self.common( |
| mod, |
| (v,), |
| ) |
| |
| @unittest.skipIf(not torch.backends.mkldnn.is_available(), "MKLDNN is not enabled") |
| @patch("torch.cuda.is_available", lambda: False) |
| def test_conv2d_packed(self): |
| options = itertools.product([[3, 56, 56]], [True, False], [0, (0,)]) |
| for x_shape, mode_train, padding in options: |
| mod = torch.nn.Sequential( |
| torch.nn.Conv2d(3, 64, 3, 3, padding=padding) |
| ).train(mode=mode_train) |
| v = torch.randn(x_shape, dtype=torch.float32) |
| |
| with torch.no_grad(): |
| self.common( |
| mod, |
| (v,), |
| ) |
| |
| @patch("torch.cuda.is_available", lambda: False) |
| def test_conv2d_autocast(self): |
| v = torch.randn(1, 3, 28, 18, dtype=torch.float32) |
| mod = torch.nn.Sequential(torch.nn.Conv2d(3, 64, 3, 3)).eval() |
| with torch.no_grad(), torch.cpu.amp.autocast(): |
| self.common( |
| mod, |
| (v,), |
| ) |
| |
| @unittest.skipIf(not torch.backends.mkldnn.is_available(), "MKLDNN is not enabled") |
| @patch("torch.cuda.is_available", lambda: False) |
| def test_unsupported_conv_transpose(self): |
| class Model(torch.nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.conv_transpose = torch.nn.ConvTranspose2d( |
| 3, 6, 3, stride=1, padding=1, output_padding=1 |
| ) |
| |
| def forward(self, input_tensor): |
| x = self.conv_transpose(input_tensor) |
| output = torch.tanh(x) |
| return output |
| |
| input = torch.randn(1, 3, 28, 28) |
| m = Model().eval() |
| |
| with torch.no_grad(): |
| compiled_m = torch.compile(m) |
| with self.assertRaisesRegex( |
| RuntimeError, |
| "output padding must be smaller than either stride or dilation", |
| ): |
| compiled_m(input) |
| |
| @unittest.skipIf(not torch.backends.mkldnn.is_available(), "MKLDNN is not enabled") |
| @patch("torch.cuda.is_available", lambda: False) |
| def test_conv_used_from_multiple_places(self): |
| class M(torch.nn.Module): |
| def __init__(self, conv_in_channel, conv_out_channel) -> None: |
| super().__init__() |
| self.conv = torch.nn.Conv2d(conv_in_channel, conv_out_channel, (3, 3)) |
| |
| def forward(self, x): |
| res = self.conv(x) |
| res = F.relu(res) |
| res = self.conv(res) |
| return res |
| |
| with torch.no_grad(): |
| mod = M(3, 3).eval() |
| x = torch.randn(1, 3, 224, 224) |
| self.common( |
| mod, |
| (x,), |
| ) |
| |
| @unittest.skipIf(not torch.backends.mkldnn.is_available(), "MKLDNN is not enabled") |
| @patch("torch.cuda.is_available", lambda: False) |
| def test_linear_used_from_multiple_places(self): |
| class M(torch.nn.Module): |
| def __init__(self, in_channel, out_channel) -> None: |
| super().__init__() |
| self.linear = torch.nn.Linear(in_channel, out_channel) |
| |
| def forward(self, x): |
| res = self.linear(x) |
| res = F.relu(res) |
| res = self.linear(res) |
| return res |
| |
| if torch.ops.mkldnn._is_mkldnn_bf16_supported(): |
| with torch.no_grad(): |
| m = M(224, 224).bfloat16().eval() |
| m_opt = torch.compile(m) |
| x = torch.randn(224, 224, dtype=torch.bfloat16) |
| m_opt(x) |
| self.assertEqual(m(x), m_opt(x)) |
| |
| @config.patch(implicit_fallbacks=True) |
| def test_multihead_attention_cpu(self): |
| def fn( |
| q, |
| k, |
| v, |
| embed_dim, |
| num_heads, |
| qkv_weight, |
| qkv_bias, |
| proj_weight, |
| proj_bias, |
| mask, |
| need_weights, |
| ): |
| return torch._native_multi_head_attention( |
| q, |
| k, |
| v, |
| embed_dim, |
| num_heads, |
| qkv_weight, |
| qkv_bias, |
| proj_weight, |
| proj_bias, |
| mask, |
| need_weights, |
| ) |
| |
| B = 1 |
| T = 3 |
| embed_dim = 6 |
| num_heads = 2 |
| q = torch.randn([B, T, embed_dim]) |
| k = torch.randn([B, T, embed_dim]) |
| v = torch.randn([B, T, embed_dim]) |
| qkv_weight = torch.randn([3 * embed_dim, embed_dim]) |
| qkv_bias = torch.randn([3 * embed_dim]) |
| proj_weight = torch.randn([3 * embed_dim, embed_dim]) |
| proj_bias = torch.randn([3 * embed_dim]) |
| mask = None |
| need_weights = False |
| |
| inps = [ |
| q, |
| k, |
| v, |
| embed_dim, |
| num_heads, |
| qkv_weight, |
| qkv_bias, |
| proj_weight, |
| proj_bias, |
| mask, |
| need_weights, |
| ] |
| self.common(fn, inps) |
| |
| @unittest.skipIf(not torch.backends.mkldnn.is_available(), "MKLDNN is not enabled") |
| @patch("torch.cuda.is_available", lambda: False) |
| def test_linear_packed(self): |
| options = itertools.product( |
| [[2, 3, 10], [2, 10], [10], [2, 0]], [3, 0], [True, False] |
| ) |
| for input_shape, out_dim, bias in options: |
| mod = torch.nn.Sequential( |
| torch.nn.Linear(input_shape[-1], out_dim, bias=bias) |
| ).eval() |
| |
| v = torch.randn(input_shape) |
| with torch.no_grad(): |
| self.common( |
| mod, |
| (v,), |
| ) |
| if torch.ops.mkldnn._is_mkldnn_bf16_supported() and len(input_shape) > 1: |
| mod = mod.to(torch.bfloat16) |
| v = v.to(torch.bfloat16) |
| with torch.no_grad(): |
| self.common( |
| mod, |
| (v,), |
| ) |
| |
| @unittest.skipIf(not torch.backends.mkldnn.is_available(), "MKLDNN is not enabled") |
| @patch("torch.cuda.is_available", lambda: False) |
| def test_conv_transpose2d_packed_cpu(self): |
| options = itertools.product([[1, 3, 28, 28], [3, 28, 28]], [0, (0,)]) |
| for x_shape, padding in options: |
| mod = torch.nn.Sequential( |
| torch.nn.ConvTranspose2d(3, 64, 3, 3, padding=padding) |
| ).eval() |
| v = torch.randn(x_shape, dtype=torch.float32) |
| with torch.no_grad(): |
| self.common( |
| mod, |
| (v,), |
| ) |
| |
| @unittest.skipIf(not torch._C._has_mkldnn, "MKLDNN is not enabled") |
| @patch("torch.cuda.is_available", lambda: False) |
| @torch._dynamo.config.patch(dynamic_shapes=True) |
| @torch._dynamo.config.patch(assume_static_by_default=False) |
| @torch._dynamo.config.patch(allow_rnn=True) |
| @config.patch(freezing=True) |
| def _test_lstm_packed(self, params_dict, change_input_sizes=False): |
| from torch._dynamo.utils import counters |
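        # Note: the order of values in params_dict must match the order of the
        # tuple unpacking in the itertools.product loop below.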
| |
| for ( |
| unbatched, |
| input_size, |
| hidden_size, |
| num_layers, |
| bidirectional, |
| bias, |
| empty_state, |
| batch_first, |
| batch_size, |
| seq_len, |
| ) in itertools.product(*list(params_dict.values())): |
| dtypes = [torch.float] |
| if torch.ops.mkldnn._is_mkldnn_bf16_supported(): |
| dtypes.append(torch.bfloat16) |
| for dtype in dtypes: |
| counters.clear() |
| num_directions = 2 if bidirectional else 1 |
| |
| seq_len_var = seq_len + 3 |
| if unbatched: |
| v = torch.randn(seq_len, input_size) |
| v_var = torch.randn(seq_len_var, input_size) |
| h = torch.randn(num_layers * num_directions, hidden_size) |
| c = torch.randn(num_layers * num_directions, hidden_size) |
| else: |
| if batch_first: |
| v = torch.randn(batch_size, seq_len, input_size) |
| v_var = torch.randn(batch_size, seq_len_var, input_size) |
| else: |
| v = torch.randn(seq_len, batch_size, input_size) |
| v_var = torch.randn(seq_len_var, batch_size, input_size) |
| h = torch.randn( |
| num_layers * num_directions, batch_size, hidden_size |
| ) |
| c = torch.randn( |
| num_layers * num_directions, batch_size, hidden_size |
| ) |
| |
| mod = LstmModule( |
| input_size, |
| hidden_size, |
| num_layers, |
| bias, |
| bidirectional, |
| batch_first, |
| ).eval() |
| maybe_autocast = ( |
| torch.cpu.amp.autocast() |
| if dtype == torch.bfloat16 |
| else contextlib.nullcontext() |
| ) |
| |
| with torch.no_grad(), maybe_autocast: |
| inps = [v] |
| if not empty_state: |
| inps.append((h, c)) |
| |
| fn_opt = torch._dynamo.optimize("inductor")(mod) |
| _, code = run_and_get_cpp_code(fn_opt, *inps) |
| |
| # Check that _flat_weights are not functional_tensor, otherwise |
| # deepcopy will fail during recompilation. |
| fn_opt_copy = copy.deepcopy(fn_opt) |
| _flat_weights = fn_opt_copy.lstm._flat_weights |
| for _flat_weight in _flat_weights: |
| self.assertFalse(torch._is_functional_tensor(_flat_weight)) |
| |
| self.assertTrue("aten.mkldnn_rnn_layer" in code) |
| self.assertEqual(fn_opt(*inps), mod(*inps)) |
| self.assertEqual( |
| counters["inductor"]["pattern_matcher_count"], |
                        num_layers * num_directions
                        + 2,  # number of mkldnn_rnn_layer calls + 2 view calls on the concatenated hy, cy.
| ) |
| |
| # Change input sizes |
| if change_input_sizes: |
| inps_var = [v_var] |
| self.assertEqual(fn_opt(*inps_var), mod(*inps_var)) |
| |
| @slowTest |
| def test_lstm_packed(self): |
| params_dict = { |
| "unbatched": [True, False], |
| "input_size": [1, 2], |
| "hidden_size": [5, 32], |
| "num_layers": [1, 3], |
| "bidirectional": [False, True], |
| "bias": [False, True], |
| "empty_state": [False, True], |
| "batch_first": [True, False], |
| "batch_size": [1, 2], |
| "seq_len": [1, 3], |
| } |
| self._test_lstm_packed(params_dict) |
| |
| def test_lstm_packed_change_input_sizes_cpu(self): |
| params_dict = { |
| "unbatched": [False], |
| "input_size": [2], |
| "hidden_size": [5], |
| "num_layers": [3], |
| "bidirectional": [True], |
| "bias": [True], |
| "empty_state": [False], |
| "batch_first": [False], |
| "batch_size": [2], |
| "seq_len": [3], |
| } |
| self._test_lstm_packed(params_dict, change_input_sizes=True) |
| |
| @torch._dynamo.config.patch(dynamic_shapes=True) |
| @torch._dynamo.config.patch(assume_static_by_default=False) |
| @torch._dynamo.config.patch(allow_rnn=True) |
| def test_pack_padded_sequence_lstm(self): |
| embedding_dim = 12 |
| hidden_dim = 10 |
| batch_size = 24 |
| num_layers = 1 |
| bidirectional = True |
| num_direc = 2 |
| max_lens = 96 |
| |
| sent = torch.randn(batch_size, max_lens, embedding_dim) |
| hid_0 = torch.rand(num_layers * num_direc, batch_size, hidden_dim) |
| hid_1 = torch.randn(num_layers * num_direc, batch_size, hidden_dim) |
| |
| sent_lens = torch.Tensor( |
| [1, 2, 3, 4, 5, 1, 3, 2, 96, 5, 3, 1, 1, 2, 1, 2, 3, 6, 1, 2, 4, 6, 2, 1] |
| ) |
| |
| assert sent_lens.shape[0] == batch_size |
| assert sent_lens.max().item() == max_lens |
| |
| hidden_0 = hid_0.clone().requires_grad_(False) |
| hidden_1 = hid_1.clone().requires_grad_(False) |
| embeds = torch.nn.utils.rnn.pack_padded_sequence( |
| sent, sent_lens, batch_first=True, enforce_sorted=False |
| ) |
| |
| mod = LstmModule( |
| embedding_dim, |
| hidden_dim, |
| num_layers=num_layers, |
| bias=True, |
| bidirectional=bidirectional, |
| batch_first=True, |
| ).eval() |
| |
| with torch.no_grad(): |
| inps = [embeds, (hidden_0, hidden_1)] |
| fn_opt = torch._dynamo.optimize("inductor")(mod) |
| _, code = run_and_get_cpp_code(fn_opt, *inps) |
            # The pack_padded_sequence input is not supported, so the LSTM should
            # not be rewritten to torch.ops.mkldnn._lstm here.
| self.assertFalse("torch.ops.mkldnn._lstm" in code) |
| self.assertEqual(fn_opt(*inps), mod(*inps)) |
| |
| @patch("torch.cuda.is_available", lambda: False) |
| def test_conv_transpose2d_has_output_size_input(self): |
| # https://github.com/pytorch/pytorch/issues/100344. |
| class M(torch.nn.Module): |
| def __init__(self) -> None: |
| super().__init__() |
| self.conv_transpose = torch.nn.ConvTranspose2d( |
| in_channels=3, out_channels=1, kernel_size=3, stride=1, padding=1 |
| ) |
| |
| def forward(self, x): |
| return self.conv_transpose(x, output_size=(10, 10)) |
| |
| mod = M().eval() |
| v = torch.randn(1, 3, 10, 10, dtype=torch.float32) |
| with torch.no_grad(): |
| self.common( |
| mod, |
| (v,), |
| ) |
| |
| def test_pad_with_nan_value(self): |
| # https://github.com/pytorch/pytorch/issues/100988. |
| class Model(torch.nn.Module): |
| def forward(self, x): |
| x = F.pad(x, (1, 1, 1, 1), value=float("nan")) |
| return x |
| |
| mod = Model().eval() |
| v = torch.randn(1, 3, 10, 10, dtype=torch.float32) |
| with torch.no_grad(): |
| self.common( |
| mod, |
| (v,), |
| ) |
| |
| def test_masked_fill_with_inf_or_nan_value(self): |
| def fn(value, mask): |
| y1 = torch.masked_fill(value, mask, float("inf")) |
| y2 = torch.masked_fill(value, mask, float("-inf")) |
| y3 = torch.masked_fill(value, mask, float("nan")) |
| return y1, y2, y3 |
| |
| value = torch.randn((2, 17)) |
| mask = torch.randint(0, 1, size=(2, 17), dtype=torch.uint8).to(torch.bool) |
| with torch.no_grad(): |
| self.common( |
| fn, |
| (value, mask), |
| ) |
| |
| @config.patch(implicit_fallbacks=True) |
| def test_repeat_interleave(self): |
| def fn(y): |
| return torch.repeat_interleave(y, 2, output_size=8) |
| |
| a = torch.tensor([[1, 2], [3, 4]]) |
| self.common( |
| fn, |
| (a,), |
| ) |
| |
| def test_inplace_squeeze_needed(self): |
| mod = torch.nn.Sequential( |
| torch.nn.Linear(10, 10), |
| torch.nn.LayerNorm(10), |
| torch.nn.ReLU(), |
| ).eval() |
| |
| def fn(x): |
| return mod(x) |
| |
| v = torch.randn(10) |
        # TODO: OMP parallel reduction order is not deterministic.
        # Hence, the accuracy might vary up and down. For the short term,
        # we increase the tolerance and will fix it later by using
        # aten parallel.
| self.common(fn, (v,), atol=5e-1, rtol=5e-1) |
| |
| def test_cat_mul(self): |
| # https://github.com/pytorch/pytorch/issues/93365 |
| def fn(p0, p1): |
| y1 = torch.cat([p0, p1], dim=0) |
| y2 = torch.mul(y1, y1) |
| return y1, y2 |
| |
| p0 = torch.randn(3, 4) |
| p1 = torch.randn(3, 4) |
| self.common(fn, (p0, p1)) |
| |
| def test_pow_cos(self): |
| # https://github.com/pytorch/pytorch/issues/98149 |
| def fn(x): |
| t = x.pow(5) |
| return torch.cos(t) |
| |
| x = torch.tensor([4], dtype=torch.uint8) |
| self.common(fn, (x,)) |
| |
| def test_reduce_with_masked(self): |
| # https://github.com/pytorch/pytorch/issues/96484 |
| def fn(a, b): |
| a = torch.nn.functional.pad(a, (0, -1)) |
| c = a + b |
| return c.min(0).values |
| |
| a = torch.randn([2]) |
| b = torch.randn([2]) |
| self.common(fn, (a, b)) |
| |
| def test_scalar_sign_with_min(self): |
| # https://github.com/pytorch/pytorch/issues/101340 |
| def fn(a): |
| t1 = torch.tanh(a) |
| t2 = torch.sign(t1) |
| return torch.min(t1, t2) |
| |
| a = torch.randn(1, 3) |
| self.common(fn, (a,)) |
| |
| def test_index_propagation_issue_102065(self): |
| def fn(x): |
| x = torch.arange(x.numel()) |
| return (x.unsqueeze(0) - x.unsqueeze(1)) ** 2 |
| |
| self.common( |
| fn, |
| (torch.randn(8),), |
| ) |
| |
| def test_ModularIndexing_range_issue_103133(self): |
| def fn(q, k): |
| einsum = torch.einsum("bcxd,bcyd->bcxy", (q, k)) |
| constant_pad_nd = torch.ops.aten.constant_pad_nd.default( |
| einsum, [0, 0, 0, 1], 0.0 |
| ) |
| view = torch.ops.aten.view.default(constant_pad_nd, [12, 1, 512, 513]) |
| y = view.new_zeros((12, 2, 256, 513)) |
| y[:, :-1, :, 256:] = view[:, :, :256, :257] |
| return y |
| |
| self.common( |
| fn, |
| ( |
| torch.empty_strided((12, 1, 512, 64), (64, 196608, 768, 1)), |
| torch.empty_strided((12, 1, 512, 64), (64, 196608, 768, 1)), |
| ), |
| ) |
| |
| @patch("torch.cuda.is_available", lambda: False) |
| def test_max_reduction_lowp_fp(self): |
| def fn(x): |
| return torch.ops.aten.max(x, 1, keepdim=True)[0].float() |
| |
| for dtype in _lowp_fp_dtypes: |
| self.common( |
| fn, |
| (torch.randn(1, 32, 4, 4).to(dtype),), |
| ) |
| |
| @patch("torch.cuda.is_available", lambda: False) |
| def test_vec_transpose_lowp_fp(self): |
| for dtype in _lowp_fp_dtypes: |
| |
| def fn(x): |
| return x.to(memory_format=torch.channels_last).to(dtype) |
| |
| self.common( |
| fn, |
| (torch.randn(2, 3, 4, 4),), |
| ) |
| |
| def test_load_inf_bf16(self): |
| def fn1(x): |
| return torch.where(x > 0, x, math.inf) |
| |
| def fn2(x): |
| return torch.where(x > 0, x, -math.inf) |
| |
| for fn in [fn1, fn2]: |
| self.common( |
| fn, |
| (torch.randn(1, 3, 16, 16),), |
| ) |
| |
| @patch("torch.cuda.is_available", lambda: False) |
| def test_fp32_load_with_to_lowp_fp(self): |
| # From llama model. |
| class Model(torch.nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.cache_k = torch.zeros(8, 4, 2, 2) |
| |
| def forward(self, x, xk): |
| bsz, seqlen, _ = x.shape |
| self.cache_k = self.cache_k.to(x) |
| self.cache_k[:bsz, 1 : 1 + seqlen] = xk |
| return self.cache_k |
| |
| for dtype in _lowp_fp_dtypes: |
| ref_model = Model().eval() |
| opt_model = torch.compile()(Model().eval()) |
| x = torch.randn(4, 2, 2).to(dtype) |
| xk = torch.randn(4, 2, 2, 2).to(dtype) |
| self.assertEqual(opt_model(x, xk), ref_model(x, xk)) |
| |
| @unittest.skipIf( |
| not codecache.valid_vec_isa_list(), "Does not support vectorization" |
| ) |
| @patch("torch.cuda.is_available", lambda: False) |
| def test_sigmoid_with_reduction(self): |
| def fn(x): |
| x = torch.ops.aten.sigmoid.default(x) |
| return torch.ops.aten.mean.dim(x, [-1, -2], True) |
| |
| x = torch.randn((1, 8, 8, 8)) |
| with config.patch({"cpp.simdlen": None}): |
| torch._dynamo.reset() |
| metrics.reset() |
| self.common(fn, (x,)) |
| |
| def test_slice_scatter_default_end_value(self): |
| # From HF AllenaiLongformerBase. |
| def fn(query, key, window_overlap): |
| batch_size, seq_len, num_heads, head_dim = query.size() |
| assert ( |
| seq_len % (window_overlap * 2) == 0 |
| ), f"Sequence length should be multiple of {window_overlap * 2}. Given {seq_len}" |
| |
| chunks_count = torch.div(seq_len, window_overlap, rounding_mode="trunc") - 1 |
| diagonal_chunked_attention_scores = key |
| diagonal_attention_scores = diagonal_chunked_attention_scores.new_zeros( |
| ( |
| batch_size * num_heads, |
| chunks_count + 1, |
| window_overlap, |
| window_overlap * 2 + 1, |
| ) |
| ) |
| diagonal_attention_scores[ |
| :, :3, :, window_overlap: |
| ] = diagonal_chunked_attention_scores[ |
| :, :, :window_overlap, : window_overlap + 1 |
| ] |
| return diagonal_attention_scores |
| |
| self.common( |
| fn, |
| ( |
| torch.randn(1, 1024, 12, 64), |
| torch.randn(12, 3, 512, 513), |
| 256, |
| ), |
| ) |
| |
| @unittest.skipIf( |
| not codecache.valid_vec_isa_list(), "Does not support vectorization" |
| ) |
| @patch("torch.cuda.is_available", lambda: False) |
| def test_to_uint8_rounding_method(self): |
| def fn(x): |
| return x.to(torch.uint8) |
| |
        numerical_testsuite = [4.4, 4.5, 4.6, 5.5]
        for numerical_number in numerical_testsuite:
| x = torch.ones(17) * numerical_number |
| with config.patch({"cpp.simdlen": None}): |
| torch._dynamo.reset() |
| metrics.reset() |
| self.common(fn, (x,)) |
| assert metrics.generated_cpp_vec_kernel_count == 1 |
| |
| @unittest.skipIf( |
| not codecache.valid_vec_isa_list(), "Does not support vectorization" |
| ) |
| @patch("torch.cuda.is_available", lambda: False) |
| def test_decomposed_dequant_relu_quant(self): |
| def fn(x, scale, zero_point, use_dequant, use_quant): |
| # For quantized_decomposed.dequantize_per_tensor |
| # Refer to torch/ao/quantization/fx/_decomposed.py |
| if use_dequant: |
| x = (x.to(torch.float32) - zero_point) * scale |
| |
| x = torch.relu(x) |
| |
| # For quantized_decomposed.quantize_per_tensor |
| # Refer to torch/ao/quantization/fx/_decomposed.py |
| if use_quant: |
| inv_scale = 1.0 / scale |
| x = torch.clamp(torch.round(x * inv_scale) + zero_point, 0, 255).to( |
| torch.uint8 |
| ) |
| return x |
| |
| use_dequant_list = [False, True] |
| use_quant_list = [False, True] |
| for use_dequant, use_quant in itertools.product( |
| use_dequant_list, use_quant_list |
| ): |
| x = torch.clamp( |
| torch.randn((1, 7, 7, 9), dtype=torch.float32) * 100, 0, 255 |
| ) |
| if use_dequant: |
| x = x.to(torch.uint8) |
| zero_point = 100 |
| scale = 0.01 |
| with config.patch({"cpp.simdlen": None}): |
| torch._dynamo.reset() |
| metrics.reset() |
| self.common(fn, (x, scale, zero_point, use_dequant, use_quant)) |
| assert metrics.generated_cpp_vec_kernel_count == 1 |
| |
| @unittest.skipIf( |
| not codecache.valid_vec_isa_list(), "Does not support vectorization" |
| ) |
| @patch("torch.cuda.is_available", lambda: False) |
| def test_dequant_quant_lowering(self): |
| def fn(x, scale, zero_point, use_dequant, use_quant): |
| if use_dequant: |
| x = torch.ops.quantized_decomposed.dequantize_per_tensor( |
| x, scale, zero_point, 0, 255, torch.uint8 |
| ) |
| |
| x = torch.relu(x) |
| |
| if use_quant: |
| x = torch.ops.quantized_decomposed.quantize_per_tensor( |
| x, scale, zero_point, 0, 255, torch.uint8 |
| ) |
| return x |
| |
| use_dequant_list = [False, True] |
| use_quant_list = [False, True] |
| use_tensor_overload_list = [False, True] |
| for use_dequant, use_quant, use_tensor_overload in itertools.product( |
| use_dequant_list, use_quant_list, use_tensor_overload_list |
| ): |
| x = torch.clamp( |
| torch.randn((1, 7, 7, 9), dtype=torch.float32) * 100, 0, 255 |
| ) |
| if use_dequant: |
| x = x.to(torch.uint8) |
| zero_point = 100 |
| scale = 0.01 |
| if use_tensor_overload: |
| zero_point = torch.tensor(zero_point, dtype=torch.int64) |
| scale = torch.tensor(scale) |
| with config.patch({"cpp.simdlen": None}): |
| torch._dynamo.reset() |
| metrics.reset() |
| self.common(fn, (x, scale, zero_point, use_dequant, use_quant)) |
| assert metrics.generated_cpp_vec_kernel_count == 1 |
| |
| @unittest.skipIf( |
| not codecache.valid_vec_isa_list(), "Does not support vectorization" |
| ) |
| @patch("torch.cuda.is_available", lambda: False) |
| def test_dequant_maxpool2d_lowering(self): |
| def fn(x, scale, zero_point): |
| x = torch.ops.quantized_decomposed.dequantize_per_tensor( |
| x, scale, zero_point, 0, 255, torch.uint8 |
| ) |
| max_pool2d_with_indices_default = ( |
| torch.ops.aten.max_pool2d_with_indices.default( |
| x, [2, 2], [2, 2], [1, 1] |
| )[0] |
| ) |
| return max_pool2d_with_indices_default |
| |
| use_tensor_overload_list = [False, True] |
| for use_tensor_overload in use_tensor_overload_list: |
| x = ( |
| torch.clamp( |
| torch.randn((3, 16, 8, 8), dtype=torch.float32) * 100, 0, 255 |
| ) |
| .to(torch.uint8) |
| .contiguous(memory_format=torch.channels_last) |
| ) |
| zero_point = 100 |
| scale = 0.01 |
| if use_tensor_overload: |
| zero_point = torch.tensor(zero_point, dtype=torch.int64) |
| scale = torch.tensor(scale) |
| with config.patch({"cpp.simdlen": None}): |
| torch._dynamo.reset() |
| metrics.reset() |
| self.common(fn, (x, scale, zero_point)) |
| assert metrics.generated_cpp_vec_kernel_count == 1 |
| |
| @unittest.skipIf( |
| not codecache.valid_vec_isa_list(), "Does not support vectorization" |
| ) |
| @patch("torch.cuda.is_available", lambda: False) |
| def test_tile2d_load_decomposed_dequant_add_relu_quant(self): |
| def fn( |
| x, |
| scale, |
| zero_point, |
| x2, |
| scale2, |
| zero_point2, |
| output_scale, |
| output_zero_point, |
| use_dequant, |
| use_dequant2, |
| use_quant, |
| ): |
| if use_dequant: |
| x = torch.ops.quantized_decomposed.dequantize_per_tensor( |
| x, scale, zero_point, 0, 255, torch.uint8 |
| ) |
| if use_dequant2: |
| x2 = torch.ops.quantized_decomposed.dequantize_per_tensor( |
| x2, scale2, zero_point2, 0, 255, torch.uint8 |
| ) |
| temp = x + x2 |
| y = torch.relu(temp) |
| |
| if use_quant: |
| y = torch.ops.quantized_decomposed.quantize_per_tensor( |
| y, output_scale, output_zero_point, 0, 255, torch.uint8 |
| ) |
| return y.contiguous() |
| |
| use_dequant_list = [False, True] |
| use_dequant_list2 = [False, True] |
| use_quant_list = [False, True] |
| for use_dequant, use_dequant2, use_quant in itertools.product( |
| use_dequant_list, use_dequant_list2, use_quant_list |
| ): |
| x = torch.clamp( |
| torch.randn((1, 1024, 14, 14), dtype=torch.float32) * 100, 0, 255 |
| ).contiguous(memory_format=torch.channels_last) |
| x2 = torch.clamp( |
| torch.randn((1, 1024, 14, 14), dtype=torch.float32) * 100, 0, 255 |
| ).contiguous(memory_format=torch.channels_last) |
| if use_dequant: |
| x = x.to(torch.uint8).contiguous(memory_format=torch.channels_last) |
| if use_dequant2: |
| x2 = x2.to(torch.uint8).contiguous(memory_format=torch.channels_last) |
| zero_point = 1 |
| scale = 0.01 |
| zero_point2 = 2 |
| scale2 = 0.02 |
| output_zero_point = 3 |
| output_scale = 0.03 |
| with config.patch({"cpp.simdlen": None}): |
| torch._dynamo.reset() |
| metrics.reset() |
| self.common( |
| fn, |
| ( |
| x, |
| scale, |
| zero_point, |
| x2, |
| scale2, |
| zero_point2, |
| output_scale, |
| output_zero_point, |
| use_dequant, |
| use_dequant2, |
| use_quant, |
| ), |
| ) |
| assert metrics.generated_cpp_vec_kernel_count == 2 |
| |
| @unittest.skipIf( |
| not codecache.valid_vec_isa_list(), "Does not support vectorization" |
| ) |
| @patch("torch.cuda.is_available", lambda: False) |
| def test_non_contiguous_load_buf_quant(self): |
| def fn( |
| x1, |
| x2, |
| groups, |
| ): |
| x = torch.cat((x1, x2), dim=1) |
| batchsize, num_channels, height, width = x.size() |
| channels_per_group = num_channels // groups |
| x = torch.ops.quantized_decomposed.dequantize_per_tensor( |
| x, 1.0, 0, 0, 255, torch.uint8 |
| ) |
| x = x.view(batchsize, groups, channels_per_group, height, width) |
| x = torch.ops.quantized_decomposed.quantize_per_tensor( |
| x, 1.0, 0, 0, 255, torch.uint8 |
| ) |
| x = torch.ops.quantized_decomposed.dequantize_per_tensor( |
| x, 1.0, 0, 0, 255, torch.uint8 |
| ) |
| x = torch.transpose(x, 1, 2).contiguous() |
| x = x.view(batchsize, num_channels, height, width) |
| return x |
| |
| x = torch.randint(0, 8, (1, 116, 28, 28), dtype=torch.uint8).contiguous( |
| memory_format=torch.channels_last |
| ) |
| x2 = torch.randint(0, 8, (1, 116, 28, 28), dtype=torch.uint8).contiguous( |
| memory_format=torch.channels_last |
| ) |
| |
| with config.patch({"cpp.simdlen": None}): |
| torch._dynamo.reset() |
| metrics.reset() |
| self.common( |
| fn, |
| ( |
| x, |
| x2, |
| 2, |
| ), |
| ) |
| assert metrics.generated_cpp_vec_kernel_count == 1 |
| |
| @unittest.skipIf( |
| not codecache.valid_vec_isa_list(), "Does not support vectorization" |
| ) |
| @patch("torch.cuda.is_available", lambda: False) |
| def test_tile2d_store_channel_shuffle_cl_quant_output(self): |
| def channel_shuffle(x, groups, output_scale, output_zero_point): |
| batchsize, num_channels, height, width = x.size() |
| channels_per_group = num_channels // groups |
| x = x.view(batchsize, groups, channels_per_group, height, width) |
| x = torch.transpose(x, 1, 2).contiguous() |
| x = x.view(batchsize, -1, height, width) |
| x = torch.ops.quantized_decomposed.quantize_per_tensor( |
| x, output_scale, output_zero_point, 0, 255, torch.uint8 |
| ) |
| return x.contiguous(memory_format=torch.channels_last) |
| |
| with config.patch({"cpp.simdlen": None}): |
| torch._dynamo.reset() |
| metrics.reset() |
| x = torch.randn(64, 58, 28, 28) |
| output_zero_point = 3 |
| output_scale = 0.03 |
| self.common(channel_shuffle, (x, 2, output_scale, output_zero_point)) |
| assert metrics.generated_cpp_vec_kernel_count == 2 |
| |
| @unittest.skipIf( |
| not codecache.valid_vec_isa_list(), "Does not support vectorization" |
| ) |
| @patch("torch.cuda.is_available", lambda: False) |
| def test_dequant_relu_quant_dequant_relu_quant_lowering(self): |
| def fn(x, scale, zero_point, scale2, zero_point2, scale3, zero_point3): |
| x = torch.ops.quantized_decomposed.dequantize_per_tensor( |
| x, scale, zero_point, 0, 255, torch.uint8 |
| ) |
| x = torch.relu(x) |
| x = torch.ops.quantized_decomposed.quantize_per_tensor( |
| x, scale2, zero_point2, 0, 255, torch.uint8 |
| ) |
| x = torch.ops.quantized_decomposed.dequantize_per_tensor( |
| x, scale2, zero_point2, 0, 255, torch.uint8 |
| ) |
| x = torch.relu(x) |
| x = torch.ops.quantized_decomposed.quantize_per_tensor( |
| x, scale3, zero_point3, 0, 255, torch.uint8 |
| ) |
| return x |
| |
| for use_tensor_overload in [True, False]: |
| x = torch.clamp( |
| torch.randn((1, 7, 7, 9), dtype=torch.float32) * 100, 0, 255 |
| ).to(torch.uint8) |
| zero_point_list = [100, 101, 102] |
| scale_list = [0.01, 0.02, 0.03] |
| if use_tensor_overload: |
| for i in range(len(zero_point_list)): |
| zero_point_list[i] = torch.tensor( |
| zero_point_list[i], dtype=torch.int64 |
| ) |
| scale_list[i] = torch.tensor(scale_list[i]) |
| zero_point, zero_point2, zero_point3 = zero_point_list |
| scale, scale2, scale3 = scale_list |
| with config.patch({"cpp.simdlen": None}): |
| torch._dynamo.reset() |
| metrics.reset() |
| self.common( |
| fn, |
| (x, scale, zero_point, scale2, zero_point2, scale3, zero_point3), |
| rtol=1e-2, |
| atol=1e-2, |
| ) |
| assert metrics.generated_cpp_vec_kernel_count == 1 |
| |
| def test_inplace_add_alpha(self): |
| def fn(x, y): |
| aten.add_.Tensor(x, y, alpha=0.55) |
| return (x,) |
| |
| x1 = torch.zeros(10) |
| x2 = torch.zeros(10) |
| x3 = torch.zeros(10) |
| y = torch.randn(10) |
| fn_fx = make_fx(fn)(x1, y) |
| fn_compiled = compile_fx_inner(fn_fx, [x1, y]) |
| fn(x2, y) |
| fn_compiled([x3, y]) |
| assert same(x2, x3) |
| |
| def test_int_div(self): |
| def fn(x, y): |
| s3 = x.size(1) |
| a = torch.zeros((1 + s3) // 2) |
| a += y |
| return a, s3 |
| |
| p0 = torch.randint(5, (1, 8)) |
| p1 = torch.randn(1) |
| self.common(fn, (p0, p1)) |
| |
| def test_no_op_squeeze(self): |
| @torch._dynamo.optimize("inductor") |
| def forward(arg0_1): |
| return torch.ops.aten.squeeze.dim(arg0_1, 1) |
| |
| x = torch.randn((10, 20)) |
| self.common(forward, (x,)) |
| |
| def test_parallel_num_threads(self): |
| @torch._dynamo.optimize("inductor") |
| def fn(x1, x2): |
| return x1 + x2 |
| |
        @contextlib.contextmanager
        def set_num_threads(num_threads):
            orig_num_threads = torch.get_num_threads()
            torch.set_num_threads(num_threads)
            try:
                yield
            finally:
                torch.set_num_threads(orig_num_threads)
| |
| x1 = torch.randn((10, 20)) |
| x2 = torch.randn((10, 20)) |
| with set_num_threads(1): |
| assert same(x1 + x2, fn(x1, x2)) |
| with set_num_threads(4): |
| assert same(x1 + x2, fn(x1, x2)) |
| |
| @patch("torch.cuda.is_available", lambda: False) |
| def test_timed_cpu_only(self): |
| timed(lambda: torch.randn(10), ()) |
| |
| def test_complex_memory_overlap(self): |
| dense = torch.zeros(64, 32) |
| self.assertFalse(complex_memory_overlap(dense)) |
| self.assertFalse(complex_memory_overlap(dense.t())) |
| |
| strided = dense.split(4, dim=1) |
| self.assertFalse(complex_memory_overlap(strided[0])) |
| self.assertFalse(complex_memory_overlap(strided[0].t())) |
| |
| unsqueezed = dense.unsqueeze(1) |
| self.assertFalse(complex_memory_overlap(unsqueezed)) |
| self.assertFalse(complex_memory_overlap(unsqueezed.permute(1, 2, 0))) |
| |
| gathered = dense.index_select(0, torch.IntTensor([1, 0, 1])) |
| self.assertFalse(complex_memory_overlap(gathered)) |
| self.assertFalse(complex_memory_overlap(gathered.t())) |
| |
| @unittest.skipIf( |
| not codecache.valid_vec_isa_list(), "Does not support vectorization" |
| ) |
| def test_vec_dynamic_shapes(self): |
| def fn(x): |
| return torch.softmax(x, -1) |
| |
| value = torch.randn((2, 10)) |
| with config.patch({"cpp.simdlen": None}): |
| torch._dynamo.reset() |
| metrics.reset() |
| self.common(fn, (value,)) |
| |
| @unittest.skipIf( |
| platform.machine() != "x86_64" or not codecache.valid_vec_isa_list(), |
| "Does not support vectorization or not x86_64 machine", |
| ) |
| @patch("torch.cuda.is_available", lambda: False) |
| def test_auto_simd(self): |
| vec_avx512 = codecache.supported_vec_isa_list[0] |
| vec_avx2 = codecache.supported_vec_isa_list[1] |
| self.assertTrue(vec_avx512.bit_width() == 512) |
| self.assertTrue(vec_avx2.bit_width() == 256) |
| self.assertTrue(vec_avx512.nelements() == 16) |
| self.assertTrue(vec_avx2.nelements() == 8) |
| self.assertTrue(vec_avx512.nelements(torch.bfloat16) == 32) |
| self.assertTrue(vec_avx2.nelements(torch.bfloat16) == 16) |
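        # In the config.patch blocks below: cpp.simdlen=None lets pick_vec_isa choose
        # the widest supported ISA, while 0, 1, or an unsupported bit width yields a
        # falsy (disabled) ISA.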
| |
| with config.patch({"cpp.simdlen": None}): |
| isa = codecache.pick_vec_isa() |
| if vec_avx512 in codecache.valid_vec_isa_list(): |
| self.assertTrue(isa == vec_avx512) |
| else: |
| self.assertTrue(isa == vec_avx2) |
| |
| with config.patch({"cpp.simdlen": 0}): |
| isa = codecache.pick_vec_isa() |
| self.assertFalse(isa) |
| |
| with config.patch({"cpp.simdlen": 1}): |
| isa = codecache.pick_vec_isa() |
| self.assertFalse(isa) |
| |
| with config.patch({"cpp.simdlen": 257}): |
| isa = codecache.pick_vec_isa() |
| self.assertFalse(isa) |
| |
        with config.patch({"cpp.simdlen": 513}):
            isa_list = codecache.valid_vec_isa_list()
            if vec_avx512 in isa_list:
                isa = codecache.pick_vec_isa()
                self.assertFalse(isa)
| |
| with config.patch({"cpp.simdlen": 512}): |
| isa_list = codecache.valid_vec_isa_list() |
| if vec_avx512 in isa_list: |
| isa = codecache.pick_vec_isa() |
| self.assertTrue(isa == vec_avx512) |
| |
| with config.patch({"cpp.simdlen": 256}): |
| isa_list = codecache.valid_vec_isa_list() |
| if vec_avx2 in isa_list: |
| isa = codecache.pick_vec_isa() |
| self.assertTrue(isa == vec_avx2) |
| |
| @unittest.skipIf( |
| not codecache.valid_vec_isa_list(), "Does not support vectorization" |
| ) |
| @patch("torch.cuda.is_available", lambda: False) |
| def test_masked_fill_softmax(self): |
| def fn(value, mask): |
| mask = mask.to(torch.bool) |
| x = torch.masked_fill(value, mask, -33.0) |
| return torch.softmax(x, -1) |
| |
| for dtype in vec_dtypes: |
| value = torch.randn((2, 17), dtype=dtype) |
| mask = torch.randint(0, 1, size=(2, 17), dtype=torch.uint8) |
| with config.patch({"cpp.simdlen": None}): |
| for cpp_wrapper_flag in [True, False]: |
| with config.patch({"cpp_wrapper": cpp_wrapper_flag}): |
| torch._dynamo.reset() |
| metrics.reset() |
| # fp16 inputs are not supported for C++ wrappers on CPU yet |
| if cpp_wrapper_flag and dtype == torch.float16: |
| with self.assertRaisesRegex( |
| BackendCompilerFailed, |
| "Unsupported input dtype torch.float16", |
| ): |
| self.common(fn, (value, mask)) |
| assert metrics.generated_cpp_vec_kernel_count == 0 |
| else: |
| self.common(fn, (value, mask)) |
| assert metrics.generated_cpp_vec_kernel_count >= 1 |
| |
| def test_load_same_bool_tensor_twice(self): |
| @torch._dynamo.optimize("inductor") |
| def fn(a, b): |
| x = torch.masked_fill(a, b, -33.0) |
| y = torch.masked_fill(a, b, -33.0) |
| return x, y |
| |
| value = torch.randn((2, 17)) |
| mask = torch.randint(0, 1, size=(2, 17), dtype=torch.uint8).to(torch.bool) |
| fn(value, mask) |
| |
| def test_cpu_vec_cosim(self): |
| cpp_vec_op_list = [] |
| cpp_op_list = [] |
| |
| for k, v in CppVecOverrides.__dict__.items(): |
| if isinstance(v, staticmethod): |
| cpp_vec_op_list.append(k) |
| for k, v in CppOverrides.__dict__.items(): |
| if isinstance(v, staticmethod): |
| cpp_op_list.append(k) |
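        # The ops listed below are the known differences between the scalar and the
        # vectorized overrides; the assertion checks that every scalar op is either
        # vectorized in CppVecOverrides or on this list.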
| |
| diff = [ |
| "index_expr", |
| "signbit", |
| "isinf", |
| "mod", |
| "masked", |
| "randn", |
| "isnan", |
| "rand", |
| "randint64", |
| "logical_and", |
| "logical_not", |
| "logical_or", |
| "logical_xor", |
| "bitwise_and", |
| "bitwise_left_shift", |
| "bitwise_not", |
| "bitwise_right_shift", |
| "bitwise_or", |
| "bitwise_xor", |
| "to_dtype_bitcast", |
| ] |
| union = {*cpp_vec_op_list, *diff} |
| self.assertTrue( |
| set(cpp_op_list).issubset(union), f"unexpected: {set(cpp_op_list) - union}" |
| ) |
| |
| def test_atomic_add_lowp_fp(self): |
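        # torch.gather's backward is a scatter_add, so this exercises the atomic-add
        # path for low-precision (bf16/fp16) inputs.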
| def fn(test_args): |
| res = torch.gather(**test_args) |
| return res |
| |
| for dtype in _lowp_fp_dtypes: |
| input_tensor_for_ref = torch.tensor( |
| [[3.0, -5.0]], dtype=dtype, requires_grad=True |
| ) |
| input_tensor_for_opt = torch.tensor( |
| [[3.0, -5.0]], dtype=dtype, requires_grad=True |
| ) |
| |
| test_args_for_ref = { |
| "input": input_tensor_for_ref, |
| "dim": 1, |
| "index": torch.tensor([[1]]), |
| } |
| test_args_for_opt = { |
| "input": input_tensor_for_opt, |
| "dim": 1, |
| "index": torch.tensor([[1]]), |
| } |
| |
| opt_fn = torch.compile(fn) |
| |
| ref_fwd = fn(test_args_for_ref) |
| res_fwd = opt_fn(test_args_for_opt) |
| self.assertEqual(res_fwd, ref_fwd) |
| |
| torch.manual_seed(1) |
| bwd_tensor_for_ref = torch.randn(ref_fwd.shape, dtype=dtype) |
| torch.manual_seed(1) |
| bwd_tensor_for_opt = torch.randn(res_fwd.shape, dtype=dtype) |
| self.assertEqual(bwd_tensor_for_ref, bwd_tensor_for_opt) |
| |
| ref_fwd.backward(bwd_tensor_for_ref) |
| res_fwd.backward(bwd_tensor_for_opt) |
| |
| ref_grad = test_args_for_ref["input"].grad |
| res_grad = test_args_for_opt["input"].grad |
| self.assertEqual(ref_grad, res_grad) |
| |
| @patch("torch.cuda.is_available", lambda: False) |
| def test_scatter_using_atomic_add(self): |
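        # With cpp.fallback_scatter_reduce_sum disabled, the generated C++ kernel is
        # expected to implement the scatter reduction via atomic_add.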
| def fn(a, dim, index, b): |
| return aten.scatter(a, dim, index, b, reduce="add") |
| |
| inps = ( |
| torch.randn(5, 29, 13), |
| 2, |
| torch.tensor([[[3, 5, 7, 9]]]), |
| torch.randn(1, 1, 10), |
| ) |
| |
| fn_opt = torch.compile()(fn) |
| with config.patch({"cpp.fallback_scatter_reduce_sum": False}): |
| _, code = run_and_get_cpp_code(fn_opt, *inps) |
| FileCheck().check("atomic_add").run(code) |
| |
| self.assertEqual( |
| fn(*inps), |
| fn_opt(*inps), |
| ) |
| |
| @unittest.skipIf( |
| not codecache.valid_vec_isa_list(), "Does not support vectorization" |
| ) |
| @patch("torch.cuda.is_available", lambda: False) |
| def test_new_vec_op_cpu_only(self): |
| def fn(x): |
| return torch.log1p(torch.expm1(torch.erf(x))) |
| |
| for dtype in vec_dtypes: |
| torch.manual_seed(0) |
| x = torch.randn((2, 9), dtype=dtype) |
| x[0, 0] = torch.nan |
| x[1, -1] = torch.nan |
| |
| tol = 1e-2 if dtype == torch.bfloat16 else 1e-4 |
| |
| with config.patch({"cpp.simdlen": None}): |
| for cpp_wrapper_flag in [True, False]: |
| with config.patch({"cpp_wrapper": cpp_wrapper_flag}): |
| torch._dynamo.reset() |
| metrics.reset() |
| # fp16 inputs are not supported for C++ wrappers on CPU yet |
| if cpp_wrapper_flag and dtype == torch.float16: |
| with self.assertRaisesRegex( |
| BackendCompilerFailed, |
| "Unsupported input dtype torch.float16", |
| ): |
| self.common(fn, (x,)) |
| assert metrics.generated_cpp_vec_kernel_count == 0 |
| else: |
| self.common(fn, (x,)) |
| assert metrics.generated_cpp_vec_kernel_count == 1 |
| |
| @unittest.skipIf( |
| not codecache.valid_vec_isa_list(), "Does not support vectorization" |
| ) |
| @patch("torch.cuda.is_available", lambda: False) |
| def test_vec_cpu_only_for_all_available_isa(self): |
| def fn(x): |
| return torch.sin(torch.cos(torch.erf(x))) |
| |
| x = torch.randn((2, 9)) |
| x[0, 0] = torch.nan |
| x[1, -1] = torch.nan |
| |
| bit_widths = [isa._bit_width for isa in codecache.valid_vec_isa_list()] + [None] |
| for item in bit_widths: |
| with config.patch({"cpp.simdlen": item}): |
| torch._dynamo.reset() |
| metrics.reset() |
| self.common(fn, (x,)) |
| assert metrics.generated_cpp_vec_kernel_count == 1 |
| |
| @slowTest |
| @unittest.skipIf( |
| not codecache.valid_vec_isa_list(), "Does not support vectorization" |
| ) |
| @patch("torch.cuda.is_available", lambda: False) |
| def test__adaptive_avg_pool2d(self): |
| def wrap_fn(oh, ow): |
| def fn(x): |
| return torch._adaptive_avg_pool2d(x, (oh, ow)) |
| |
| return fn |
| |
| bit_widths = [isa._bit_width for isa in codecache.valid_vec_isa_list()] |
| ih = [16, 65] |
| iw = ih |
| oh = ih |
| ow = ih |
| for _ih, _iw, _oh, _ow, _simd_len, dtype in itertools.product( |
| ih, iw, oh, ow, bit_widths, vec_dtypes |
| ): |
| x = torch.randn(2, 3, _ih, _iw, dtype=dtype).to( |
| memory_format=torch.channels_last |
| ) |
| _fn = wrap_fn(_oh, _ow) |
| with config.patch({"cpp.simdlen": _simd_len}): |
| torch._dynamo.reset() |
| metrics.reset() |
| self.common(_fn, (x,)) |
| assert metrics.generated_cpp_vec_kernel_count == 1 |
| |
| @unittest.skipIf( |
| not codecache.valid_vec_isa_list(), "Does not support vectorization" |
| ) |
| @patch("torch.cuda.is_available", lambda: False) |
| def test_vec_logical(self): |
| def wrap_fn1(op: Callable): |
| def fn(x: torch.Tensor): |
| return torch.where(op(x), 1.0, 0.0) |
| |
| return fn |
| |
| def wrap_fn2(op: Callable): |
| def fn(x: torch.Tensor, y: torch.Tensor): |
| return torch.where(op(x, y), 1.0, 0.0) |
| |
| return fn |
| |
| for dtype in vec_dtypes: |
| x = torch.randn(64, dtype=dtype) |
| y = torch.randn(64, dtype=dtype) |
| logical_fns = [ |
| torch.logical_and, |
| torch.logical_not, |
| torch.logical_or, |
| torch.logical_xor, |
| ] |
| for logical_fn in logical_fns: |
| torch._dynamo.reset() |
| metrics.reset() |
| if logical_fn == torch.logical_not: |
| _fn = wrap_fn1(logical_fn) |
| _args = (x,) |
| else: |
| _fn = wrap_fn2(logical_fn) |
| _args = (x, y) |
| self.common(_fn, _args) |
| assert metrics.generated_cpp_vec_kernel_count == 1 |
| |
| @unittest.skipIf( |
| not codecache.valid_vec_isa_list(), "Does not support vectorization" |
| ) |
| @patch("torch.cuda.is_available", lambda: False) |
| def test_vec_compare_op_cpu_only(self): |
| def fn(x): |
| y1 = torch.eq(x, 1.0) |
| x = torch.where(y1, x, -x) |
| y2 = torch.ne(x, 0.0) |
| x = torch.where(y2, x, -x) |
| y3 = torch.lt(x, 5.0) |
| x = torch.where(y3, x, x - 1.0) |
| y4 = torch.gt(x, -2.0) |
| x = torch.where(y4, x, x + 1.0) |
| y5 = torch.le(x, 8.0) |
| x = torch.where(y5, x, x - 1.0) |
| y6 = torch.ge(x, -3.0) |
| x = torch.where(y6, x, x + 1.0) |
| y7 = x == 1.0 |
| x = torch.where(y7, x, -x) |
| y8 = x != 0.0 |
| x = torch.where(y8, x, -x) |
| y9 = x < 5.0 |
| x = torch.where(y9, x, x - 1.0) |
| y10 = x > -2.0 |
| x = torch.where(y10, x, x + 1.0) |
| y11 = x <= 8.0 |
| x = torch.where(y11, x, x - 1.0) |
| y12 = x >= -3.0 |
| x = torch.where(y12, x, x + 1.0) |
| return x |
| |
| for dtype in vec_dtypes: |
| x = torch.randn((2, 9), dtype=dtype) |
| |
| with config.patch({"cpp.simdlen": None}): |
| torch._dynamo.reset() |
| metrics.reset() |
| self.common(fn, (x,)) |
| assert metrics.generated_cpp_vec_kernel_count == 1 |
| assert ( |
| metrics.generated_kernel_count |
| - metrics.generated_cpp_vec_kernel_count |
| ) == 0 |
| |
| def test_skip_cpp_codegen(self): |
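        # With disable_cpp_codegen, Inductor should fall back instead of emitting C++
        # kernels, so no "void kernel" should appear in the generated code.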
| with config.patch({"disable_cpp_codegen": True}): |
| inps = torch.ones([20]), torch.rand([20]) |
| |
| def f(x, y): |
| return x + y + torch.tensor(1) |
| |
| f_opt = torch.compile()(f) |
| |
| _, code = run_and_get_cpp_code(f_opt, inps[0], inps[1]) |
| FileCheck().check_not("void kernel").run(code) |
| |
| self.assertEqual( |
| f(*inps), |
| f_opt(*inps), |
| ) |
| |
| # constant needs to be propagated on fallback |
| def f(x): |
| return x[torch.tensor(1) :] * 2 |
| |
| f_opt = torch.compile()(f) |
| _, code = run_and_get_cpp_code(f_opt, inps[0]) |
| FileCheck().check_not("void kernel").run(code) |
| self.assertEqual(f_opt(inps[0]), f(inps[0])) |
| |
| class Model(torch.nn.Module): |
| def __init__( |
| self, |
| ): |
| super().__init__() |
| |
| def forward(self, v1: torch.Tensor): |
| vx = v1.min(dim=1).values |
| v2 = torch.randn_like(vx) |
| return v2 |
| |
| model = Model() |
| x = torch.rand(10, 3, 0) |
| model_f = torch.compile()(model) |
| |
| self.assertEqual(model(x), model_f(x)) |
| |
| def test_redundant_to_node_elimination_lowp_fp(self): |
| def fn(x, y): |
| res = x + y |
| res = torch.mean(res) |
| return res |
| |
| for dtype in _lowp_fp_dtypes: |
| x = torch.randn((2, 9), dtype=dtype) |
| y = torch.randn((2, 9), dtype=dtype) |
| |
| for torch_compile_debug in [True, False]: |
| with config.patch( |
| {"trace.enabled": torch_compile_debug, "cpp.simdlen": None} |
| ): |
| torch._dynamo.reset() |
| metrics.reset() |
| self.common(fn, (x, y)) |
| if codecache.valid_vec_isa_list(): |
| assert metrics.generated_cpp_vec_kernel_count == 1 |
| |
| def test_do_not_insert_to_dtype_for_memory_copy_only_kernel(self): |
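        # A plain clone is a memory copy, so no to_dtype conversion should be inserted
        # even though the tensor is bfloat16.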
| def fn(x): |
| res = x.clone() |
| return res |
| |
| x = torch.randn((100, 100), dtype=torch.bfloat16) |
| |
| torch._dynamo.reset() |
| metrics.reset() |
| self.common(fn, (x,)) |
| assert metrics.cpp_to_dtype_count == 0 |
| if codecache.valid_vec_isa_list(): |
| assert metrics.generated_cpp_vec_kernel_count == 1 |
| |
| def test_insert_to_dtype_count(self): |
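        # relu on a bfloat16 tensor is presumably computed in float32, so one to_dtype
        # is expected on the load and another on the store (hence a count of 2).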
| def fn(x): |
| res = x.relu() |
| return res |
| |
| x = torch.randn((100, 100), dtype=torch.bfloat16) |
| |
| torch._dynamo.reset() |
| metrics.reset() |
| self.common(fn, (x,)) |
| assert metrics.cpp_to_dtype_count == 2 |
| if codecache.valid_vec_isa_list(): |
| assert metrics.generated_cpp_vec_kernel_count == 1 |
| |
| def test_memory_copy_with_fusion(self): |
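        # The in-place copy back into x is expected to fuse with the relu, so only a
        # single (vectorized) kernel should be generated.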
| def fn(x): |
| res = x.relu() |
| x.copy_(res) |
| return (res,) |
| |
| x = torch.randn((100, 100), dtype=torch.bfloat16) |
| |
| torch._dynamo.reset() |
| metrics.reset() |
| self.common(fn, (x,)) |
| assert metrics.cpp_to_dtype_count == 2 |
| if codecache.valid_vec_isa_list(): |
| assert metrics.generated_cpp_vec_kernel_count == 1 |
| |
| @unittest.skipIf( |
| not codecache.valid_vec_isa_list(), "Does not support vectorization" |
| ) |
| @patch("torch.cuda.is_available", lambda: False) |
| def test_cpp_vec_constant_checker(self): |
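        # Build a small FX graph feeding integer and floating-point values into
        # ops.constant, then check that CppVecKernelChecker keeps simd_vec True only
        # when the constant values can be safely represented as int32/fp32.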
| _graph: torch.fx.Graph = torch.fx.Graph() |
| a: torch.fx.Node = _graph.create_node("placeholder", "ops") |
| iv: torch.fx.Node = _graph.create_node("placeholder", "iv") |
| fv: torch.fx.Node = _graph.create_node("placeholder", "fv") |
| b: torch.fx.Node = _graph.create_node( |
| "call_method", |
| "constant", |
| args=( |
| a, |
| iv, |
| torch.int64, |
| ), |
| ) |
| c: torch.fx.Node = _graph.create_node( |
| "call_method", |
| "constant", |
| args=( |
| a, |
| fv, |
| torch.double, |
| ), |
| ) |
| d: torch.fx.Node = _graph.create_node( |
| "call_method", |
| "ge", |
| args=( |
| a, |
| b, |
| b, |
| ), |
| ) |
| _graph.output((d, c)) |
| |
| def get_index(): |
| return "" |
| |
| submodules = {"get_index": get_index} |
| |
| graph_lowering = GraphLowering( |
| torch.fx.GraphModule(submodules, _graph), |
| shape_env=None, |
| num_static_inputs=0, |
| ) |
| |
| def set_opt_dtype(graph): |
| for node in graph.nodes: |
| if node.target == "constant": |
| if OptimizationContext.key in node.meta: |
| opt_ctx = node.meta[OptimizationContext.key] |
| else: |
| opt_ctx = OptimizationContext() |
| opt_ctx.dtype = node.args[-1] |
| node.meta[OptimizationContext.key] = opt_ctx |
| |
| with patch.object(graph_lowering, "wrapper_code", ""), V.set_graph_handler( |
| graph_lowering |
| ): |
            # Use the fp32 lane count of the chosen ISA as the tiling factor.
| tiling_factor = codecache.pick_vec_isa().nelements(dtype=torch.float) |
| with CppVecKernelChecker( |
| args=None, num_threads=1, tiling_factor=tiling_factor |
| ) as vec_checker: |
| i32_iinfo = np.iinfo(np.int32) |
| f32_iinfo = np.finfo(np.float32) |
| set_opt_dtype(_graph) |
| InterpreterShim(_graph, submodules).run( |
| V.get_ops_handler(), i32_iinfo.max, f32_iinfo.max |
| ) |
| self.assertTrue(vec_checker.simd_vec) |
| |
| vec_checker.simd_vec = True |
| set_opt_dtype(_graph) |
| InterpreterShim(_graph, submodules).run( |
| V.get_ops_handler(), i32_iinfo.min, f32_iinfo.min |
| ) |
| self.assertTrue(vec_checker.simd_vec) |
| |
| vec_checker.simd_vec = True |
| set_opt_dtype(_graph) |
| InterpreterShim(_graph, submodules).run( |
| V.get_ops_handler(), i32_iinfo.min, np.inf |
| ) |
| self.assertTrue(vec_checker.simd_vec) |
| |
| vec_checker.simd_vec = True |
| set_opt_dtype(_graph) |
| InterpreterShim(_graph, submodules).run( |
| V.get_ops_handler(), i32_iinfo.min, -np.inf |
| ) |
| self.assertTrue(vec_checker.simd_vec) |
| |
| vec_checker.simd_vec = True |
| set_opt_dtype(_graph) |
| InterpreterShim(_graph, submodules).run( |
| V.get_ops_handler(), i32_iinfo.min - 1, f32_iinfo.min |
| ) |
| self.assertFalse(vec_checker.simd_vec) |
| |
| vec_checker.simd_vec = True |
| set_opt_dtype(_graph) |
| InterpreterShim(_graph, submodules).run( |
| V.get_ops_handler(), i32_iinfo.max + 1, f32_iinfo.max |
| ) |
| self.assertFalse(vec_checker.simd_vec) |
| |
| vec_checker.simd_vec = True |
| set_opt_dtype(_graph) |
| InterpreterShim(_graph, submodules).run( |
| V.get_ops_handler(), i32_iinfo.min, f32_iinfo.min * (1 + 1e-5) |
| ) |
| self.assertFalse(vec_checker.simd_vec) |
| |
| vec_checker.simd_vec = True |
| set_opt_dtype(_graph) |
| InterpreterShim(_graph, submodules).run( |
| V.get_ops_handler(), i32_iinfo.max, f32_iinfo.max * (1 + 1e-5) |
| ) |
| self.assertFalse(vec_checker.simd_vec) |
| |
| @unittest.skipIf( |
| not codecache.valid_vec_isa_list(), "Does not support vectorization" |
| ) |
| @patch("torch.cuda.is_available", lambda: False) |
| def test_cpp_vec_index_expr_checker(self): |
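        # Build an FX graph around ops.index_expr and check that CppVecKernelChecker
        # keeps simd_vec True only when the index expression does not depend on the
        # innermost loop variable and its value range fits into int32.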
| _graph: torch.fx.Graph = torch.fx.Graph() |
| a: torch.fx.Node = _graph.create_node("placeholder", "ops") |
| b: torch.fx.Node = _graph.create_node("call_module", "get_index", args=()) |
| c: torch.fx.Node = _graph.create_node( |
| "call_method", |
| "index_expr", |
| args=( |
| a, |
| b, |
| torch.int64, |
| ), |
| ) |
| d: torch.fx.Node = _graph.create_node( |
| "call_method", |
| "ge", |
| args=( |
| a, |
| c, |
| c, |
| ), |
| ) |
| _graph.output(d) |
| |
| def get_index(): |
| return "" |
| |
| submodules = {"get_index": get_index} |
| graph_lowering = GraphLowering( |
| torch.fx.GraphModule(submodules, _graph), |
| shape_env=None, |
| num_static_inputs=0, |
| ) |
| with patch.object(graph_lowering, "wrapper_code", ""), V.set_graph_handler( |
| graph_lowering |
| ): |
| itervars = [sympy.Symbol("i"), sympy.Symbol("j"), sympy.Symbol("k")] |
| |
| tiling_factor = codecache.pick_vec_isa().nelements(dtype=torch.float) |
            # The innermost loop variable is used in the index_expr
| with CppVecKernelChecker( |
| args=None, num_threads=1, tiling_factor=tiling_factor |
| ) as vec_checker: |
| |
| def get_index(): |
| return -itervars[0] ** 2 + 2 * itervars[0] + itervars[1] |
| |
| ranges = [0, 100, 200] |
| vec_checker.itervars = itervars[:2] |
| vec_checker.ranges = ranges[:2] |
| submodules = {"get_index": get_index} |
| InterpreterShim(_graph, submodules).run(V.get_ops_handler()) |
| self.assertFalse(vec_checker.simd_vec) |
| |
            # The innermost loop variable is irrelevant
| with CppVecKernelChecker( |
| args=None, num_threads=1, tiling_factor=tiling_factor |
| ) as vec_checker: |
| |
| def get_index(): |
| return -itervars[0] ** 2 + 2 * itervars[0] + itervars[1] |
| |
| ranges = [0, 100, 200] |
| vec_checker.itervars = itervars |
| vec_checker.ranges = ranges |
| submodules = {"get_index": get_index} |
| InterpreterShim(_graph, submodules).run(V.get_ops_handler()) |
| self.assertTrue(vec_checker.simd_vec) |
| |
| i32_iinfo = np.iinfo(np.int32) |
| _max_value = i32_iinfo.max + 1 |
| ranges = [_max_value, _max_value, _max_value] |
            # The innermost loop variable is irrelevant, but the max index value is
            # greater than the max value of INT32
| with CppVecKernelChecker( |
| args=None, num_threads=1, tiling_factor=tiling_factor |
| ) as vec_checker: |
| |
| def get_index(): |
| return itervars[0] |
| |
| submodules = {"get_index": get_index} |
| vec_checker.itervars = itervars |
| vec_checker.ranges = ranges |
| InterpreterShim(_graph, submodules).run(V.get_ops_handler()) |
| self.assertFalse(vec_checker.simd_vec) |
| |
            # The innermost loop variable is irrelevant, but the min index value is
            # smaller than the min value of INT32
| with CppVecKernelChecker( |
| args=None, num_threads=1, tiling_factor=tiling_factor |
| ) as vec_checker: |
| |
| def get_index(): |
| return -itervars[0] - 2 |
| |
| submodules = {"get_index": get_index} |
| vec_checker.itervars = itervars |
| vec_checker.ranges = ranges |
| InterpreterShim(_graph, submodules).run(V.get_ops_handler()) |
| self.assertFalse(vec_checker.simd_vec) |
| |
| @unittest.skipIf( |
| not codecache.valid_vec_isa_list(), "Does not support vectorization" |
| ) |
| @patch("torch.cuda.is_available", lambda: False) |
| def test_maxpool2d_cpu_only(self): |
| for dtype in vec_dtypes: |
| input = torch.randn(26, 32, 112, 112, dtype=dtype).to( |
| memory_format=torch.channels_last |
| ) |
| maxpool = torch.nn.MaxPool2d(kernel_size=3, stride=2, padding=1) |
| |
| def func(x): |
| return maxpool(x) |
| |
| with patch.object(config.cpp, "simdlen", None): |
| torch._dynamo.reset() |
| metrics.reset() |
| self.common(func, (input,)) |
| assert metrics.generated_cpp_vec_kernel_count == 1 |
| |
| @unittest.skipIf( |
| not codecache.valid_vec_isa_list(), "Does not support vectorization" |
| ) |
| @patch("torch.cuda.is_available", lambda: False) |
| def test_maxpool2d_with_pre_loop_collapse_cpu_only(self): |
| x1 = torch.randn(2, 3, 20, 20).to(memory_format=torch.channels_last) |
| x2 = torch.randn(2, 3, 20, 20).to(memory_format=torch.channels_last) |
| maxpool = torch.nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True) |
| |
| def func(x1, x2): |
| y = x1 + x2 |
| return maxpool(y) |
| |
| with patch.object(config.cpp, "simdlen", None): |
| torch._dynamo.reset() |
| metrics.reset() |
| self.common(func, (x1, x2)) |
| assert metrics.generated_cpp_vec_kernel_count == 2 |
| |
| @unittest.skipIf( |
| not codecache.valid_vec_isa_list(), "Does not support vectorization" |
| ) |
| @patch("torch.cuda.is_available", lambda: False) |
| def test_sign_cpu_only(self): |
| def fn(x): |
| return torch.sign(x) |
| |
| for dtype in vec_dtypes: |
| x = torch.randn((2, 9), dtype=dtype) |
| x[0, 0] = torch.nan |
| x[1, -1] = torch.nan |
| |
| with config.patch({"cpp.simdlen": None}): |
| torch._dynamo.reset() |
| metrics.reset() |
| self.common(fn, (x,)) |
| assert metrics.generated_cpp_vec_kernel_count == 1 |
| |
| @unittest.skipIf( |
| not codecache.valid_vec_isa_list(), "Does not support vectorization" |
| ) |
| @patch("torch.cuda.is_available", lambda: False) |
| def test_reduction_cpu_only(self): |
| def fn(x): |
| return torch.argmax(x, -1) |
| |
| for dtype in vec_dtypes: |
| x = torch.randn((10, 10), dtype=dtype) |
| |
| with config.patch({"cpp.simdlen": None}): |
| torch._dynamo.reset() |
| metrics.reset() |
| self.common(fn, (x,)) |
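| # argmax is a reduction that the C++ backend does not currently vectorize, |
| # so no vectorized kernel is expected. |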
| assert metrics.generated_cpp_vec_kernel_count == 0 |
| |
| # Currently, AVX2 and AVX512 are enabled for vectorization. If the platform is not |
| # supported, vectorization will not work and this test case is skipped. To support ARM or |
| # other platforms, we just need to add the ISA info to supported_vector_isa |
| # and include the proper aten vectorization header file. |
| @unittest.skipIf( |
| not codecache.valid_vec_isa_list(), "Does not support vectorization" |
| ) |
| @patch("torch.cuda.is_available", lambda: False) |
| def test_vec_kernel_cpu_only(self): |
| def fn(x1, x2): |
| # Currently, there are the following limitations: |
| # rsqrt: |
| # assert [both a fallback and a decomp for same kernel: aten.rsqrt.default] |
| # round: |
| # couldn't find symbolic meta function/decomposition |
| # fmod/logical_and/logical_or: |
| # the vec kernel does not support to_type |
| x = torch.abs(x1) |
| x = torch.sin(x) |
| x = torch.neg(x) |
| x = torch.square(x) |
| x = torch.sigmoid(x) |
| x = torch.relu(x) |
| x = torch.cos(x) |
| x = torch.exp(x) |
| x = torch.sqrt(x) |
| x = torch.add(x, x1) |
| x = torch.sub(x, x2) |
| x = torch.mul(x, x1) |
| x = torch.div(x, x1) |
| x = torch.pow(x, 10) |
| x = torch.log(x) |
| x = torch.floor(x) |
| x = torch.ceil(x) |
| x = torch.trunc(x) |
| x = torch.lgamma(x) |
| x = torch.fmod(x, x2) |
| x = torch.sign(x) |
| res = x + x2 |
| return res |
| |
| for dtype in vec_dtypes: |
| torch.manual_seed(0) |
| x1 = torch.randn((5, 20), dtype=dtype) |
| x2 = torch.randn((5, 20), dtype=dtype) |
| |
| tol = 1e-2 if dtype == torch.bfloat16 else 1e-4 |
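| # cpp.simdlen == 1 forces scalar codegen, so no vectorized kernel should be generated. |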
| with config.patch({"cpp.simdlen": 1}): |
| torch._dynamo.reset() |
| metrics.reset() |
| self.common(fn, (x1, x2)) |
| assert metrics.generated_cpp_vec_kernel_count == 0 |
| |
| with config.patch({"cpp.simdlen": None}): |
| torch._dynamo.reset() |
| metrics.reset() |
| self.common(fn, (x1, x2)) |
| assert metrics.generated_cpp_vec_kernel_count == 1 |
| |
| with config.patch({"cpp.simdlen": None}): |
| torch._dynamo.reset() |
| metrics.reset() |
| x1 = torch.randn(10, 20).permute(1, 0) |
| x2 = torch.randn((20, 10)) |
| self.common(fn, (x1, x2)) |
| assert metrics.generated_cpp_vec_kernel_count == 2 |
| |
| torch._dynamo.reset() |
| metrics.reset() |
| x1 = torch.randn((10, 7)) |
| x2 = torch.randn((10, 7)) |
| self.common(fn, (x1, x2)) |
| assert metrics.generated_cpp_vec_kernel_count == 1 |
| |
| @unittest.skipIf( |
| sys.platform != "linux", "cpp kernel profile only support linux now" |
| ) |
| @patch("torch.cuda.is_available", lambda: False) |
| @config.patch({"cpp.enable_kernel_profile": True}) |
| @config.patch({"cpp.descriptive_names": "original_aten"}) |
| def test_cpp_kernel_profile(self): |
| from torch.profiler import profile |
| |
| @torch._dynamo.optimize("inductor", nopython=True) |
| def fn(a, b): |
| return a + b |
| |
| a = torch.rand((100,)) |
| b = torch.rand((100,)) |
| with profile() as prof: |
| fn(a, b) |
| |
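| # With cpp.descriptive_names="original_aten", the generated kernel should show up |
| # in the profiler trace with "cpp_fused_add_0" in its name. |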
| kernel_profile_events = [] |
| for e in prof.profiler.function_events: |
| if "cpp_fused_add_0" in e.name: |
| kernel_profile_events.append(e.name) |
| assert len(kernel_profile_events) > 0 |
| |
| @unittest.skipIf( |
| not codecache.valid_vec_isa_list(), "Does not support vectorization" |
| ) |
| def test_channel_shuffle_cl_output(self): |
| """code and shape extracted from shufflenet_v2_x1_0""" |
| |
| def channel_shuffle(x, groups): |
| batchsize, num_channels, height, width = x.size() |
| channels_per_group = num_channels // groups |
| x = x.view(batchsize, groups, channels_per_group, height, width) |
| x = torch.transpose(x, 1, 2).contiguous() |
| x = x.view(batchsize, -1, height, width) |
| return x.contiguous(memory_format=torch.channels_last) |
| |
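| # Sweep simdlen over the default (None), 256, and 1 (scalar); whenever vectorization |
| # is enabled, the channels-last shuffle should produce two vectorized kernels. |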
| for simdlen in (None, 256, 1): |
| with config.patch({"cpp.simdlen": simdlen}): |
| torch._dynamo.reset() |
| metrics.reset() |
| x = torch.randn(64, 58, 28, 28) |
| self.common(channel_shuffle, (x, 2)) |
| if simdlen != 1: |
| assert metrics.generated_cpp_vec_kernel_count == 2 |
| |
| @slowTest |
| @unittest.skipIf( |
| not codecache.valid_vec_isa_list(), "Does not support vectorization" |
| ) |
| def test_transpose_with_norm(self): |
| """a sub-module from TIMM gmlp_s16_224""" |
| |
| class Model(torch.nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.linear = torch.nn.Linear( |
| in_features=256, out_features=1536, bias=True |
| ) |
| self.act = torch.nn.GELU() |
| self.norm = torch.nn.LayerNorm(768) |
| self.proj = torch.nn.Linear(196, 196) |
| self.fc = torch.nn.Linear(in_features=768, out_features=256, bias=True) |
| |
| def forward(self, x): |
| x = self.linear(x) |
| x = self.act(x) |
| u, v = x.chunk(2, dim=-1) |
| v = self.norm(v) |
| v = self.proj(v.transpose(-1, -2)) |
| y = u * v.transpose(-1, -2) |
| return self.fc(y) |
| |
| x = torch.randn(128, 196, 256) |
| for simdlen in (None, 256, 1): |
| with config.patch({"cpp.simdlen": simdlen}): |
| for eval_mode in [True, False]: |
| torch._dynamo.reset() |
| metrics.reset() |
| m = Model().eval() if eval_mode else Model() |
| self.common(m, (x,)) |
| if simdlen != 1: |
| assert metrics.generated_cpp_vec_kernel_count == 6 |
| |
| @unittest.skipIf( |
| not codecache.valid_vec_isa_list(), "Does not support vectorization" |
| ) |
| def test_transpose_copy(self): |
| def fn(a): |
| return a.t().contiguous() |
| |
| for simdlen in (None, 256, 1): |
| with config.patch({"cpp.simdlen": simdlen}): |
| for dtype in (torch.float, torch.bfloat16): |
| for shape in ( |
| (7, 7), |
| (8, 8), |
| (9, 9), |
| (16, 16), |
| (17, 17), |
| (32, 32), |
| (33, 33), |
| ): |
| torch._dynamo.reset() |
| metrics.reset() |
| x = torch.randn(shape, dtype=dtype) |
| self.common(fn, (x,)) |
| if simdlen != 1: |
| assert metrics.generated_cpp_vec_kernel_count == 2 |
| |
| def test_horizontal_fusion(self): |
| def fn(a, b, c, idx): |
| _a = torch.index_select(a, dim=0, index=idx) |
| _b = torch.index_select(b, dim=0, index=idx) |
| _c = torch.index_select(c, dim=0, index=idx) |
| return _a, _b, _c |
| |
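| # cpp.max_horizontal_fusion_size caps how many independent kernels may be fused |
| # side by side: 0 or 1 keeps the three index_select kernels separate, 2 fuses two |
| # of them (two kernels remain), and 3 fuses all three into a single kernel. |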
| with config.patch({"cpp.max_horizontal_fusion_size": 0}): |
| metrics.reset() |
| torch._dynamo.reset() |
| a = torch.randn(size=(4, 16), dtype=torch.bfloat16) |
| b = torch.randn(size=(4, 16), dtype=torch.bfloat16) |
| c = torch.randn(size=(4, 16), dtype=torch.bfloat16) |
| idx = torch.zeros(size=[4], dtype=torch.int64) |
| opt_fn = torch._dynamo.optimize("inductor")(fn) |
| opt_fn(a, b, c, idx) |
| self.assertEqual(metrics.generated_kernel_count, 3) |
| self.assertTrue(same(fn(a, b, c, idx), opt_fn(a, b, c, idx))) |
| |
| with config.patch({"cpp.max_horizontal_fusion_size": 1}): |
| metrics.reset() |
| torch._dynamo.reset() |
| a = torch.randn(size=(4, 32), dtype=torch.bfloat16) |
| b = torch.randn(size=(4, 32), dtype=torch.bfloat16) |
| c = torch.randn(size=(4, 32), dtype=torch.bfloat16) |
| idx = torch.zeros(size=[4], dtype=torch.int64) |
| opt_fn = torch._dynamo.optimize("inductor")(fn) |
| opt_fn(a, b, c, idx) |
| self.assertEqual(metrics.generated_kernel_count, 3) |
| self.assertTrue(same(fn(a, b, c, idx), opt_fn(a, b, c, idx))) |
| |
| with config.patch({"cpp.max_horizontal_fusion_size": 2}): |
| metrics.reset() |
| torch._dynamo.reset() |
| a = torch.randn(size=(4, 64), dtype=torch.bfloat16) |
| b = torch.randn(size=(4, 64), dtype=torch.bfloat16) |
| c = torch.randn(size=(4, 64), dtype=torch.bfloat16) |
| idx = torch.zeros(size=[4], dtype=torch.int64) |
| opt_fn = torch._dynamo.optimize("inductor")(fn) |
| opt_fn(a, b, c, idx) |
| self.assertEqual(metrics.generated_kernel_count, 2) |
| self.assertTrue(same(fn(a, b, c, idx), opt_fn(a, b, c, idx))) |
| |
| with config.patch({"cpp.max_horizontal_fusion_size": 3}): |
| metrics.reset() |
| torch._dynamo.reset() |
| a = torch.randn(size=(4, 128), dtype=torch.bfloat16) |
| b = torch.randn(size=(4, 128), dtype=torch.bfloat16) |
| c = torch.randn(size=(4, 128), dtype=torch.bfloat16) |
| idx = torch.zeros(size=[4], dtype=torch.int64) |
| opt_fn = torch._dynamo.optimize("inductor")(fn) |
| opt_fn(a, b, c, idx) |
| self.assertEqual(metrics.generated_kernel_count, 1) |
| self.assertTrue(same(fn(a, b, c, idx), opt_fn(a, b, c, idx))) |
| |
| def test_lowp_fp_neg_abs(self): |
| def fn(x): |
| return x.neg().abs() |
| |
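| # neg followed by abs on low-precision floats should vectorize directly, |
| # without inserting any intermediate dtype conversions. |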
| for dtype in _lowp_fp_dtypes: |
| metrics.reset() |
| x = torch.randn(100, 100).to(dtype) |
| opt_fn = torch._dynamo.optimize("inductor")(fn) |
| self.assertTrue(same(fn(x), opt_fn(x))) |
| assert metrics.cpp_to_dtype_count == 0 |
| assert metrics.generated_cpp_vec_kernel_count == 1 |
| |
| def test_transpose_non_contiguous(self): |
| def fn(a): |
| # From part of timm HaloAttn: |
| # (https://github.com/rwightman/pytorch-image-models/blob/main/timm/layers/halo_attn.py#L97). |
| # Fixes the accuracy issue reported in https://github.com/pytorch/pytorch/issues/94269. |
| as_strided = torch.ops.aten.as_strided.default( |
| a, [1, 384, 2, 20, 12], [153600, 1, 61440, 384, 7680] |
| ) |
| as_strided_1 = torch.ops.aten.as_strided.default( |
| as_strided, |
| [1, 384, 2, 2, 12, 12], |
| [153600, 1, 61440, 3072, 7680, 384], |
| ) |
| clone_1 = torch.ops.aten.clone.default( |
| as_strided_1, memory_format=torch.contiguous_format |
| ) |
| _unsafe_view_1 = torch.ops.aten._unsafe_view.default( |
| clone_1, [8, 48, 4, 144] |
| ) |
| permute_2 = torch.ops.aten.permute.default(_unsafe_view_1, [0, 2, 3, 1]) |
| split_with_sizes = torch.ops.aten.split_with_sizes.default( |
| permute_2, [16, 32], -1 |
| ) |
| getitem = split_with_sizes[0] |
| getitem_1 = split_with_sizes[1] |
| permute_3 = torch.ops.aten.permute.default(getitem, [0, 1, 3, 2]) |
| expand_1 = torch.ops.aten.expand.default(permute_3, [8, 4, 16, 144]) |
| clone_3 = torch.ops.aten.clone.default( |
| expand_1, memory_format=torch.contiguous_format |
| ) |
| return clone_3 |
| |
| metrics.reset() |
| x = torch.randn(1, 384, 20, 20).to(memory_format=torch.channels_last) |
| self.common(fn, (x,)) |
| assert metrics.generated_cpp_vec_kernel_count == 1 |
| |
| def test_non_contiguous_index_with_constant_stride(self): |
| def fn(x): |
| x1 = x[:, :, :, ::2] |
| x2 = x[:, :, :, 1::2] |
| x = torch.stack((-x2, x1), dim=-1) |
| return x.flatten(-2) |
| |
| metrics.reset() |
| x = torch.randn(1, 32, 16, 68) |
| opt_fn = torch._dynamo.optimize("inductor")(fn) |
| _, code = run_and_get_cpp_code(opt_fn, x) |
| self.assertTrue(same(fn(x), opt_fn(x))) |
| # "cpp_fused" should appear exactly twice: once in the kernel definition and once at the call site |
| FileCheck().check_count("cpp_fused", 2, exactly=True).run(code) |
| |
| def test_invalid_index_of_empty_tensor(self): |
| def fn(a): |
| b = a[[0]] |
| return b |
| |
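| # Index 0 is out of bounds for an empty tensor; compiling and running this |
| # should raise a RuntimeError rather than producing a kernel. |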
| a = torch.tensor([]) |
| with self.assertRaises(RuntimeError): |
| torch.compile(fn)(a) |
| |
| def test_ir_node_str(self): |
| @torch.compile |
| def fn(x: torch.Tensor) -> torch.Tensor: |
| return x.sin(), torch.nn.Softmax(dim=1)(x.cos()) |
| |
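| # Patch GraphLowering.run_node so every lowered IR node gets stringified; this |
| # exercises the IR nodes' __str__ implementations without changing lowering behavior. |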
| def run_node_alt(*args, **kwargs): |
| rv = run_node(*args, **kwargs) |
| strings.append(str(rv)) |
| return rv |
| |
| strings = [] |
| run_node = GraphLowering.run_node |
| with patch.object(GraphLowering, "run_node", run_node_alt): |
| fn(torch.randn([8, 128])) |
| self.assertGreater(len(strings), 3) |
| |
| def test_vertical_sum_cpu_only(self): |
| def fn1(a): |
| return a.sum(dim=0) |
| |
| def fn2(a): |
| return a.sum(dim=1) |
| |
| metrics.reset() |
| x = torch.randn(100, 100) |
| self.common(fn1, (x,)) |
| assert metrics.generated_cpp_vec_kernel_count == 1 |
| |
| metrics.reset() |
| x = torch.randn(100, 100, 100) |
| self.common(fn2, (x,)) |
| assert metrics.generated_cpp_vec_kernel_count == 1 |
| |
| def test_transpose_vertical_sum_cpu_only(self): |
| def fn(a, b): |
| c = a * b |
| return c.sum(dim=1) |
| |
| metrics.reset() |
| x = torch.randn(100, 50, 50) |
| y = torch.randn(100, 50, 50).transpose(1, 2) |
| self.common(fn, (x, y)) |
| assert metrics.generated_cpp_vec_kernel_count == 2 |
| |
| def test_transpose_sum2d_cpu_only(self): |
| def fn(a, b): |
| c = a * b |
| return c.sum() |
| |
| metrics.reset() |
| x = torch.randn(50, 50) |
| y = torch.randn(50, 50).transpose(0, 1) |
| self.common(fn, (x, y)) |
| assert metrics.generated_cpp_vec_kernel_count == 2 |
| |
| def test_transpose_sum_outer(self): |
| # https://github.com/pytorch/pytorch/issues/98573 |
| def fn(a): |
| return a.transpose(2, 3).sum(dim=1).contiguous() |
| |
| metrics.reset() |
| x = torch.randn(10, 50, 50, 50) |
| self.common(fn, (x,)) |
| assert metrics.generated_cpp_vec_kernel_count == 1 |
| |
| def test_to_dtype_bool_float(self): |
| # https://github.com/pytorch/pytorch/issues/100800 |
| def f(a): |
| return torch.where( |
| torch.ones_like(a).to(torch.bool), |
| torch.zeros_like(a), |
| torch.ones_like(a) * 2, |
| ) |
| |
| self.common(f, (torch.ones(16),)) |
| |
| def test_to_dtype_float_bool(self): |
| # https://github.com/pytorch/pytorch/issues/100466 |
| def f(a): |
| a = a * torch.tensor(a >= 0, dtype=torch.float32) |
| return a |
| |
| x = torch.rand(16) |
| self.common(f, (x,)) |
| |
| def test_constant_store(self): |
| # https://github.com/pytorch/pytorch/issues/104515 |
| def f(a): |
| a[0, [3, 3]] = -float("inf") |
| return a |
| |
| x = torch.rand(4, 5) |
| self.common(f, (x,)) |
| |
| def test_to_channels_last_lowp_fp(self): |
| def f(a): |
| return a.to(memory_format=torch.channels_last) |
| |
| for dtype in _lowp_fp_dtypes: |
| x = torch.rand(2, 3, 14, 14).to(dtype) |
| self.common(f, (x,)) |
| |
| def test_broadcast_mul_lowp_fp(self): |
| def f(a, b): |
| return a * b |
| |
| for dtype in _lowp_fp_dtypes: |
| a = torch.randn(2, 16, 16).to(dtype) |
| b = torch.randn(2, 1, 1).to(dtype) |
| self.common(f, (a, b)) |
| |
| def test_linear_buffer_reuse(self): |
| class M(torch.nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.linear1 = torch.nn.Linear(16, 16) |
| self.tanh = torch.nn.Tanh() |
| self.linear2 = torch.nn.Linear(16, 16) |
| |
| def forward(self, x): |
| x = self.linear1(x) |
| x = self.tanh(x) |
| x = self.linear2(x) |
| return x |
| |
| mod = M().eval() |
| v = torch.randn(1, 16) |
| |
| with torch.no_grad(): |
| |
| def compile_fx_wrapper(model_, example_inputs_): |
| return compile_fx(model_, example_inputs_) |
| |
| def run(*ex, **kwargs): |
| return mod(*ex, **kwargs) |
| |
| run = torch._dynamo.optimize(compile_fx_wrapper)(run) |
| _, code = run_and_get_cpp_code(run, v) |
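| # Buffer reuse should keep as_strided reinterpretations out of the generated code. |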
| self.assertNotIn("= as_strided(", code) |
| self.assertEqual(run(v), mod(v)) |
| |
| @config.patch(inplace_buffers=True) |
| def test_in_out_buffer(self): |
| def fn(x, y): |
| z = torch.matmul(x, y.transpose(-1, -2)) / 8.0 |
| return z |
| |
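| # With inplace_buffers enabled, the division should write into the matmul's output |
| # buffer, which surfaces as an in_out_ptr argument in the generated kernel. |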
| inps = [torch.randn(1, 2, 8, 4), torch.randn(1, 2, 8, 4)] |
| fn_opt = torch._dynamo.optimize("inductor")(fn) |
| _, code = run_and_get_cpp_code(fn_opt, *inps) |
| self.assertIn("in_out_ptr", code) |
| self.assertEqual(fn_opt(*inps), fn(*inps)) |
| |
| def test_eliminate_meaningless_copy(self): |
| def fn(x1, x2): |
| permute = torch.ops.aten.permute.default(x2, [0, 2, 1, 3]) |
| clone = torch.ops.aten.clone.default( |
| permute, memory_format=torch.contiguous_format |
| ) |
| view = torch.ops.aten.view.default(clone, [1024, -1, 32]) |
| bmm = torch.ops.aten.bmm.default(view, x1) |
| permute = torch.ops.aten.permute.default(view, [0, 2, 1]) |
| return (bmm, permute) |
| |
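| # The contiguous clone feeding the bmm is redundant for these strides and should be |
| # eliminated, leaving a single generated kernel. |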
| metrics.reset() |
| self.common( |
| fn, |
| [ |
| rand_strided( |
| (1024, 32, 128), (4096, 1, 32), device="cpu", dtype=torch.float32 |
| ), |
| rand_strided( |
| (64, 128, 16, 32), |
| (65536, 512, 32, 1), |
| device="cpu", |
| dtype=torch.float32, |
| ), |
| ], |
| ) |
| self.assertEqual(metrics.generated_kernel_count, 1) |
| |
| def test_scalar_mul_bfloat16(self): |
| def f(x): |
| return torch.ops.aten.mul.Tensor(x, 1.7015043497085571) |
| |
| metrics.reset() |
| x = torch.randn(4, 5, dtype=torch.bfloat16) |
| self.common(f, (x,)) |
| assert metrics.generated_cpp_vec_kernel_count == 1 |
| |
| def test_bf16_zeros(self): |
| def fn(): |
| x = torch.zeros(1, 1, 32, dtype=torch.bfloat16) |
| return x |
| |
| self.common(fn, ()) |
| |
| def test_select_tiling_with_index_expr(self): |
| def fn(x, y): |
| x = torch.ops.aten.view.default(x, [8, 8, 8, 3136]) |
| x = torch.ops.aten.permute.default(x, [0, 1, 3, 2]) |
| y = torch.ops.aten.mul.Tensor(y, x) |
| return torch.ops.aten.constant_pad_nd.default(y, [0, 0, 1, 0, 0, 0], 0.0) |
| |
| x = torch.randn(8, 64, 56, 56) |
| y = torch.randn(8, 8, 3136, 8) |
| self.common(fn, (x, y)) |
| |
| @unittest.skipIf(not torch.backends.mkldnn.is_available(), "MKLDNN is not enabled") |
| @patch("torch.cuda.is_available", lambda: False) |
| @config.patch(freezing=True) |
| def test_linear_with_no_default_contiguous_input(self): |
| mod = torch.nn.Sequential(torch.nn.Linear(16, 16)).eval() |
| temp = torch.randn(1, 16, 1, 1) |
| v = torch.as_strided(temp, [1, 16], [0, 1], 0) |
| self.assertTrue(v.is_contiguous()) |
| with torch.no_grad(): |
| self.common( |
| mod, |
| (v,), |
| ) |
| if torch.ops.mkldnn._is_mkldnn_bf16_supported(): |
| mod = mod.to(torch.bfloat16) |
| v = v.to(torch.bfloat16) |
| with torch.no_grad(): |
| self.common( |
| mod, |
| (v,), |
| ) |
| |
| @patch("torch.cuda.is_available", lambda: False) |
| @config.patch(freezing=True) |
| def test_linear_with_reshape(self): |
| class M(torch.nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.linear = torch.nn.Linear(16, 16, bias=False) |
| |
| def forward(self, x): |
| x = self.linear(x) |
| return x.view(4, 4, 4) |
| |
| mod = M().eval() |
| v = torch.randn(4, 16) |
| with torch.no_grad(): |
| torch._dynamo.reset() |
| metrics.reset() |
| self.common( |
| mod, |
| (v,), |
| ) |
| assert metrics.generated_kernel_count == 0 |
| |
| @config.patch(implicit_fallbacks=True) |
| def test_aten_normal_dtype(self): |
| for dtype in [torch.float64, torch.float16, None]: |
| |
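| # torch.normal should honor an explicitly requested dtype and default to float32 |
| # when dtype is None, under both backends. |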
| def fn(): |
| return torch.normal(2, 3, (10, 10), dtype=dtype, device="cpu") |
| |
| self.assertEqual( |
| torch.compile(fn, backend="aot_eager_decomp_partition")().dtype, |
| dtype if dtype else torch.float32, |
| ) |
| self.assertEqual( |
| torch.compile(fn, backend="inductor")().dtype, |
| dtype if dtype else torch.float32, |
| ) |
| |
| def test_group_norm_vec(self): |
| class M(torch.nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.group_norm = torch.nn.GroupNorm(32, 32) |
| |
| def forward(self, x): |
| return self.group_norm(x) |
| |
| metrics.reset() |
| mod = M().eval() |
| x = torch.randn(2, 32, 32, 32) |
| with torch.no_grad(): |
| self.common(mod, (x,)) |
| # 2 generated kernels (one for var_mean, the other for result) |
| assert metrics.generated_cpp_vec_kernel_count == 2 |
| |
| def test_int_div_vec(self): |
| def fn(x, y, mode): |
| return torch.div(x, y, rounding_mode=mode) |
| |
| x = torch.randint(1, 100, (32, 32)) |
| y = torch.randint(1, 100, (32, 32)) |
| for mode in [None, "trunc", "floor"]: |
| with torch.no_grad(): |
| metrics.reset() |
| self.common(fn, (x, y, mode)) |
| # TODO: support vectorization for int div |
| assert metrics.generated_cpp_vec_kernel_count == 0 |
| |
| def test_uint8_add(self): |
| # https://github.com/pytorch/pytorch/issues/113016 |
| def fn(x, y): |
| return torch.add(x, y).neg().to(torch.int32) |
| |
| x = torch.randint(0, 255, (3, 3), dtype=torch.uint8) |
| y = torch.randint(0, 255, (3, 3), dtype=torch.uint8) |
| self.common(fn, (x, y)) |
| |
| def test_uint8_sub(self): |
| # https://github.com/pytorch/pytorch/issues/113016 |
| def fn(x, y): |
| return torch.sub(x, y).neg().to(torch.int32) |
| |
| x = torch.randint(0, 255, (3, 3), dtype=torch.uint8) |
| y = torch.randint(0, 255, (3, 3), dtype=torch.uint8) |
| self.common(fn, (x, y)) |
| |
| def test_non_contiguous_reduction_store(self): |
| # https://github.com/pytorch/pytorch/issues/113018 |
| class M(torch.nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.conv = torch.nn.Conv2d(39, 1, kernel_size=(1, 17), stride=(2, 2)) |
| |
| def forward(self, x): |
| return self.conv(x.max(3).values) |
| |
| m = M() |
| x = torch.randn(1, 39, 1, 18, 17) |
| self.common(m, (x,)) |
| |
| |
| if __name__ == "__main__": |
| from torch._dynamo.test_case import run_tests |
| from torch.testing._internal.inductor_utils import HAS_CPU |
| |
| if HAS_CPU and not IS_MACOS: |
| run_tests(needs="filelock") |