| # Owner(s): ["module: inductor"] |
| |
| import unittest |
| |
| import torch |
| |
| from torch._inductor.ir import Pointwise |
| from torch._inductor.lowering import register_lowering |
| from torch._inductor.virtualized import ops |
| |
| from torch.testing._internal.common_utils import TestCase as TorchTestCase |
| from torch.testing._internal.inductor_utils import HAS_CPU, HAS_CUDA |
| |
| |
| # These tests check issues for lowerings that aren't in the main pytorch repo |
| class TestCustomLowering(TorchTestCase): |
| @classmethod |
| def setUpClass(cls): |
| super().setUpClass() |
| cls.test_inductor_ops = torch.library.Library("test_inductor_ops", "DEF") |
| cls.impl_cuda = torch.library.Library("test_inductor_ops", "IMPL", "CUDA") |
| cls.impl_meta = torch.library.Library("test_inductor_ops", "IMPL", "Meta") |
| cls._register_jagged_to_padded_dense() |
| |
| @classmethod |
| def tearDown(cls): |
| super().tearDownClass() |
| |
| @classmethod |
| def _register_jagged_to_padded_dense(cls): |
| # Approximation of fbgemm.jagged_to_padded_dense_forward |
| cls.test_inductor_ops.define( |
| "jagged_to_padded_dense(Tensor input, Tensor offsets, SymInt max_seq_len, Scalar pad_value) -> Tensor" |
| ) |
| |
| def j2pd_meta(inp, offsets, max_seq_len, pad_value): |
| return torch.empty( |
| (offsets.shape[0] - 1, max_seq_len, inp.shape[1]), |
| device=inp.device, |
| dtype=inp.dtype, |
| ) |
| |
| def j2pd_cuda(inp, offsets, max_seq_len, pad_value): |
| res = torch.full( |
| (offsets.shape[0] - 1, max_seq_len, inp.shape[1]), |
| pad_value, |
| device=inp.device, |
| dtype=inp.dtype, |
| ) |
| for b in range(offsets.shape[0] - 1): |
| for r in range(offsets[b + 1] - offsets[b]): |
| res[b][r] = inp[offsets[b] + r] |
| return res |
| |
| def j2pd_lowering(inp, offsets, max_seq_len, pad_value): |
| offsets_loader = offsets.make_loader() |
| inp_loader = inp.make_loader() |
| jagged_len = inp.get_size()[0] |
| offsets_dtype = offsets.get_dtype() |
| |
| def inner_fn(index): |
| batch_idx, seq_idx, emb_idx = index |
| |
| begin_idx = ops.indirect_indexing( |
| offsets_loader([batch_idx]), |
| jagged_len + 1, |
| ) |
| end_idx = offsets_loader([batch_idx + 1]) |
| jagged_idx = begin_idx + seq_idx |
| |
| return ops.masked( |
| ops.lt( |
| ops.index_expr(jagged_idx, offsets_dtype), |
| end_idx, |
| ), |
| lambda: inp_loader([jagged_idx, emb_idx]), |
| pad_value, |
| ) |
| |
| return Pointwise.create( |
| device=inp.get_device(), |
| dtype=inp.get_dtype(), |
| inner_fn=inner_fn, |
| ranges=[offsets.get_size()[0] - 1, max_seq_len, inp.get_size()[1]], |
| ) |
| |
| register_lowering( |
| torch.ops.test_inductor_ops.jagged_to_padded_dense, type_promotion_kind=None |
| )(j2pd_lowering) |
| |
| cls.impl_meta.impl("jagged_to_padded_dense", j2pd_meta) |
| cls.impl_cuda.impl("jagged_to_padded_dense", j2pd_cuda) |
| |
| @unittest.skipIf(not HAS_CUDA, "CUDA needed") |
| def test_jagged_to_padded_dense_sanity_cuda(self): |
| def fn(inp, offsets, max_seq_len): |
| return torch.ops.test_inductor_ops.jagged_to_padded_dense( |
| inp, offsets, max_seq_len, 60.0 |
| ) |
| |
| inp = torch.rand((9, 96), device="cuda") |
| offsets = torch.tensor([0, 2, 5, 9], dtype=torch.int32, device="cuda") |
| max_seq_len = 4 |
| |
| res = fn(inp, offsets, max_seq_len) |
| self.assertEqual(inp[0], res[0][0]) |
| self.assertEqual(inp[1], res[0][1]) |
| self.assertEqual(inp[2], res[1][0]) |
| self.assertEqual(inp[3], res[1][1]) |
| self.assertEqual(inp[5], res[2][0]) |
| self.assertEqual(inp[8], res[2][3]) |
| |
| fn_opt = torch.compile(fn) |
| |
| self.assertEqual( |
| fn(inp, offsets, max_seq_len), fn_opt(inp, offsets, max_seq_len) |
| ) |
| |
| @unittest.skipIf(not HAS_CUDA, "CUDA needed") |
| def test_jagged_to_padded_dense_zero_size(self): |
| # Previously, the masking was being completely stripped for the |
| # masked load of the input value. That would lead to an IMA |
| # because cuda was trying to read index 0 of a zero-size tensor. |
| def fn(inp, offsets, max_seq_len): |
| inp = torch.bmm(inp, torch.ones((1, 96, 1), device="cuda")).view((0, 1)) |
| return torch.ops.test_inductor_ops.jagged_to_padded_dense( |
| inp, offsets, max_seq_len, 60.0 |
| ) |
| |
| inp = torch.rand((1, 0, 96), device="cuda") |
| offsets = torch.zeros(1025, device="cuda", dtype=torch.int32) |
| max_seq_len = 20 |
| |
| fn_opt = torch.compile(fn) |
| |
| self.assertEqual( |
| fn(inp, offsets, max_seq_len), fn_opt(inp, offsets, max_seq_len) |
| ) |
| |
| |
| if __name__ == "__main__": |
| from torch._dynamo.test_case import run_tests |
| |
| if HAS_CPU or HAS_CUDA: |
| run_tests(needs="filelock") |