| # Owner(s): ["module: scatter & gather ops"] |
| |
| from itertools import product |
| from functools import partial |
| |
| import numpy as np |
| import torch |
| from torch.testing._internal.common_device_type import ( |
| instantiate_device_type_tests, |
| dtypes, |
| ) |
| from torch.testing._internal.common_utils import ( |
| TestCase, |
| run_tests, |
| gradcheck, |
| parametrize, |
| skipIfRocm, |
| ) |
| |
| |
# Reduction names exercised by every test in this module.
reductions = [
    "max",
    "mean",
    "min",
    "sum",
    "prod",
]
| |
| |
def get_default_value(initial_value, reduction):
    """Return the value an empty segment is expected to reduce to.

    An explicitly supplied ``initial_value`` (anything other than ``None``)
    is returned unchanged.  Otherwise the result is looked up by reduction
    name: ``-inf`` for "max", ``nan`` for "mean", ``+inf`` for "min",
    ``0.0`` for "sum" and ``1.0`` for "prod".  Any other reduction name
    yields ``None``.
    """
    if initial_value is None:
        fallbacks = {
            "max": -float("Inf"),
            "mean": float("nan"),
            "min": float("Inf"),
            "sum": 0.0,
            "prod": 1.0,
        }
        return fallbacks.get(reduction)
    return initial_value
| |
| |
class TestSegmentReductions(TestCase):
    """Forward and backward checks for ``torch._segment_reduce``.

    Segments are described either by ``lengths`` or by ``offsets``; the
    tests compare results against hand-written or NumPy-derived expectations
    for every reduction listed in the module-level ``reductions``.
    """

    def _test_common(
        self,
        reduction,
        device,
        dtype,
        unsafe,
        axis,
        initial_value,
        data_arr,
        lengths_arr,
        expected_arr,
        expected_grad_arr,
        check_backward,
        lengths_dtype=torch.int,
    ):
        """Run one segment-reduce scenario in both 'lengths' and 'offsets' mode.

        Args:
            reduction: reduction name forwarded as ``reduce``.
            device: device on which all tensors are created.
            dtype: dtype of the data, expected result and expected gradient.
            unsafe: forwarded as the ``unsafe`` keyword.
            axis: forwarded as the ``axis`` keyword.
            initial_value: forwarded as the ``initial`` keyword (may be None).
            data_arr: input data, converted with ``torch.tensor``.
            lengths_arr: segment lengths, converted with ``torch.tensor``.
            expected_arr: expected forward result.
            expected_grad_arr: expected gradient of ``result.sum()`` w.r.t. data.
            check_backward: when false, only the forward result is checked.
            lengths_dtype: integer dtype used for the lengths/offsets tensors.
        """
        lengths = torch.tensor(lengths_arr, device=device, dtype=lengths_dtype)
        # generate offsets from lengths: prepend a zero along the last
        # dimension and take the running sum
        zeros_shape = list(lengths.shape)
        zeros_shape[-1] = 1
        offsets = torch.cat((lengths.new_zeros(zeros_shape), lengths), -1).cumsum_(-1)

        data = torch.tensor(
            data_arr,
            device=device,
            dtype=dtype,
            requires_grad=True,
        )
        expected_result = torch.tensor(expected_arr, device=device, dtype=dtype)
        expected_grad = torch.tensor(expected_grad_arr, device=device, dtype=dtype)
        for mode in ['lengths', 'offsets']:
            segment_reduce_kwargs = dict(
                axis=axis,
                unsafe=unsafe,
                initial=initial_value)
            if mode == 'lengths':
                segment_reduce_kwargs['lengths'] = lengths
            else:
                segment_reduce_kwargs['offsets'] = offsets
            actual_result = torch._segment_reduce(
                data=data,
                reduce=reduction,
                **segment_reduce_kwargs
            )
            self.assertEqual(
                expected_result, actual_result, rtol=1e-02, atol=1e-05, equal_nan=True
            )

            if not check_backward:
                # Skip only the backward checks for this mode; returning here
                # would leave the 'offsets' forward path untested.
                continue

            # Test backward
            actual_result.sum().backward()
            self.assertEqual(
                expected_grad, data.grad, rtol=1e-02, atol=1e-05, equal_nan=True
            )
            # Fresh leaf tensor so gradients do not accumulate across modes.
            data = data.clone().detach().requires_grad_(True)

            # gradcheck does not work well with bfloat16 or fp16 cpu types
            # also there is small numerical difference with fp32
            if dtype not in [torch.half, torch.bfloat16, torch.float]:
                # gradcheck does not like "nan" input, setting to random 10
                d_non_nan = np.nan_to_num(data_arr, nan=10)
                new_data = torch.tensor(
                    d_non_nan,
                    device=device,
                    dtype=dtype,
                    requires_grad=True,
                )
                self.assertTrue(
                    gradcheck(
                        lambda x: torch._segment_reduce(
                            data=x,
                            reduce=reduction,
                            **segment_reduce_kwargs
                        ),
                        (new_data,),
                    )
                )

    @dtypes(
        *product(
            (torch.half, torch.bfloat16, torch.float, torch.double),
            (torch.int, torch.int64),
        )
    )
    def test_simple_1d(self, device, dtypes):
        """1-D data containing a NaN, with and without an explicit initial value."""
        val_dtype, length_type = dtypes
        lengths = [1, 2, 3, 0]
        data = [1, float("nan"), 3, 4, 5, 5]

        for reduction in reductions:
            for initial in [0, None]:
                # Backward is only checked when an initial value is supplied.
                check_backward = initial is not None
                initial_value = initial
                default_value = get_default_value(initial_value, reduction)
                if reduction == "max":
                    expected_result = [1, float("nan"), 5, default_value]
                    expected_grad = [1, 1, 0, 0, 0.5, 0.5]
                elif reduction == "mean":
                    expected_result = [1, float("nan"), 4.666, default_value]
                    expected_grad = [1.0, 0.5, 0.5, 0.333, 0.333, 0.333]
                elif reduction == "min":
                    if initial is not None:
                        initial_value = 1000  # some high number
                        default_value = get_default_value(initial_value, reduction)
                    expected_result = [1, float("nan"), 4, default_value]
                    expected_grad = [1.0, 1.0, 0, 1, 0, 0]
                elif reduction == "sum":
                    expected_result = [1, float("nan"), 14, default_value]
                    expected_grad = [1.0, 1.0, 1.0, 1.0, 1.0, 1.0]
                elif reduction == "prod":
                    if initial is not None:
                        initial_value = 2  # 0 initial_value will zero out everything for prod
                        default_value = get_default_value(initial_value, reduction)
                        expected_result = [2, float("nan"), 200, default_value]
                        expected_grad = [2.0, 6.0, float("nan"), 50.0, 40.0, 40.0]
                    else:
                        expected_result = [1, float("nan"), 100, default_value]
                        expected_grad = [1.0, 3.0, float("nan"), 25.0, 20.0, 20.0]
                for axis in [0, -1]:
                    for unsafe in [True, False]:
                        self._test_common(
                            reduction,
                            device,
                            val_dtype,
                            unsafe,
                            axis,
                            initial_value,
                            data,
                            lengths,
                            expected_result,
                            expected_grad,
                            check_backward,
                            length_type,
                        )

    @dtypes(
        *product(
            (torch.half, torch.bfloat16, torch.float, torch.double),
            (torch.int, torch.int64),
        )
    )
    def test_simple_zero_length(self, device, dtypes):
        """Empty data with two zero-length segments: every output is the default value."""
        val_dtype, length_type = dtypes
        lengths = [0, 0]
        data = torch.ones(0)
        axis = 0
        # There are no input elements, so the expected gradient is empty.
        expected_grad = []

        for reduction in reductions:
            for initial in [0, None]:
                check_backward = initial is not None
                initial_value = initial
                if initial is not None:
                    if reduction == "min":
                        initial_value = 1000  # some high number
                    elif reduction == "prod":
                        initial_value = 2  # 0 initial_value will zero out everything for prod
                default_value = get_default_value(initial_value, reduction)
                expected_result = [default_value, default_value]
                for unsafe in [True, False]:
                    self._test_common(
                        reduction,
                        device,
                        val_dtype,
                        unsafe,
                        axis,
                        initial_value,
                        data,
                        lengths,
                        expected_result,
                        expected_grad,
                        check_backward,
                        length_type,
                    )

    @skipIfRocm
    @dtypes(
        *product(
            (torch.half, torch.bfloat16, torch.float, torch.double),
            (torch.int, torch.int64),
        )
    )
    def test_multi_d_simple(self, device, dtypes):
        """2-D data reduced along axis 0, including NaN entries and an empty segment."""
        val_dtype, length_type = dtypes
        axis = 0
        lengths = [1, 2, 3, 0]
        data = [[1, 1], [float("nan"), 1], [3, float("nan")], [4, 1], [3, 2], [2, 3]]

        for reduction in reductions:
            for initial in [0, None]:
                check_backward = initial is not None
                initial_value = initial
                default_value = get_default_value(initial_value, reduction)
                if reduction == "max":
                    expected_result = [
                        [1, 1],
                        [float("nan"), float("nan")],
                        [4, 3],
                        [default_value, default_value],
                    ]
                    expected_grad = [
                        [1, 1],
                        [1, 0],
                        [0, 1],
                        [1, 0],
                        [0, 0],
                        [0, 1],
                    ]
                elif reduction == "mean":
                    expected_result = [
                        [1, 1],
                        [float("nan"), float("nan")],
                        [3, 2],
                        [default_value, default_value],
                    ]
                    expected_grad = [
                        [1.0, 1.0],
                        [0.5, 0.5],
                        [0.5, 0.5],
                        [0.333, 0.333],
                        [0.333, 0.333],
                        [0.333, 0.333],
                    ]
                elif reduction == "min":
                    if initial is not None:
                        initial_value = 1000  # some high number
                        default_value = get_default_value(initial_value, reduction)
                    expected_result = [
                        [1, 1],
                        [float("nan"), float("nan")],
                        [2, 1],
                        [default_value, default_value],
                    ]
                    expected_grad = [
                        [1.0, 1.0],
                        [1, 0],
                        [0, 1],
                        [0, 1],
                        [0, 0],
                        [1, 0],
                    ]
                elif reduction == "sum":
                    expected_result = [
                        [1, 1],
                        [float("nan"), float("nan")],
                        [9, 6],
                        [default_value, default_value],
                    ]
                    expected_grad = [
                        [1.0, 1.0],
                        [1.0, 1.0],
                        [1.0, 1.0],
                        [1.0, 1.0],
                        [1.0, 1.0],
                        [1.0, 1.0],
                    ]
                elif reduction == "prod":
                    if initial is not None:
                        initial_value = 2  # 0 initial_value will zero out everything for prod
                        default_value = get_default_value(initial_value, reduction)
                        expected_result = [
                            [2, 2],
                            [float("nan"), float("nan")],
                            [48, 12],
                            [default_value, default_value],
                        ]
                        expected_grad = [
                            [2.0, 2.0],
                            [6.0, float("nan")],
                            [float("nan"), 2.0],
                            [12.0, 12.0],
                            [16.0, 6.0],
                            [24.0, 4.0],
                        ]
                    else:
                        expected_result = [
                            [1, 1],
                            [float("nan"), float("nan")],
                            [24, 6],
                            [default_value, default_value],
                        ]
                        expected_grad = [
                            [1.0, 1.0],
                            [3.0, float("nan")],
                            [float("nan"), 1.0],
                            [6.0, 6.0],
                            [8.0, 3.0],
                            [12.0, 2.0],
                        ]
                for unsafe in [True, False]:
                    self._test_common(
                        reduction,
                        device,
                        val_dtype,
                        unsafe,
                        axis,
                        initial_value,
                        data,
                        lengths,
                        expected_result,
                        expected_grad,
                        check_backward,
                        # Forward the parametrized lengths dtype; otherwise the
                        # int64 variant silently re-runs with torch.int.
                        length_type,
                    )

    @dtypes(
        *product(
            (torch.half, torch.bfloat16, torch.float, torch.double),
            (torch.int, torch.int64),
        )
    )
    @parametrize("reduce", ['sum', 'prod', 'min', 'max', 'mean'])
    def test_pytorch_scatter_test_cases(self, device, dtypes, reduce):
        """Check lengths and offsets modes against a table of expected results.

        Each table entry carries the source data, the segment boundaries
        (``indptr``) and the expected output per reduction.  For float64 the
        gradients are additionally verified with ``gradcheck``.
        """
        val_dtype, length_dtype = dtypes
        # zero-length segments are filled with reduction inits contrary to pytorch_scatter.
        tests = [
            {
                'src': [1, 2, 3, 4, 5, 6],
                'index': [0, 0, 1, 1, 1, 3],
                'indptr': [0, 2, 5, 5, 6],
                'sum': [3, 12, 0, 6],
                'prod': [2, 60, 1, 6],
                'mean': [1.5, 4, float('nan'), 6],
                'min': [1, 3, float('inf'), 6],
                'max': [2, 5, -float('inf'), 6],
            },
            {
                'src': [[1, 2], [3, 4], [5, 6], [7, 8], [9, 10], [11, 12]],
                'index': [0, 0, 1, 1, 1, 3],
                'indptr': [0, 2, 5, 5, 6],
                'sum': [[4, 6], [21, 24], [0, 0], [11, 12]],
                'prod': [[3, 8], [315, 480], [1, 1], [11, 12]],
                'mean': [[2, 3], [7, 8], [float('nan'), float('nan')], [11, 12]],
                'min': [[1, 2], [5, 6], [float('inf'), float('inf')], [11, 12]],
                'max': [[3, 4], [9, 10], [-float('inf'), -float('inf')], [11, 12]],
            },
            {
                'src': [[1, 3, 5, 7, 9, 11], [2, 4, 6, 8, 10, 12]],
                'index': [[0, 0, 1, 1, 1, 3], [0, 0, 0, 1, 1, 2]],
                'indptr': [[0, 2, 5, 5, 6], [0, 3, 5, 6, 6]],
                'sum': [[4, 21, 0, 11], [12, 18, 12, 0]],
                'prod': [[3, 315, 1, 11], [48, 80, 12, 1]],
                'mean': [[2, 7, float('nan'), 11], [4, 9, 12, float('nan')]],
                'min': [[1, 5, float('inf'), 11], [2, 8, 12, float('inf')]],
                'max': [[3, 9, -float('inf'), 11], [6, 10, 12, -float('inf')]],
            },
            {
                'src': [[[1, 2], [3, 4], [5, 6]], [[7, 9], [10, 11], [12, 13]]],
                'index': [[0, 0, 1], [0, 2, 2]],
                'indptr': [[0, 2, 3, 3], [0, 1, 1, 3]],
                'sum': [[[4, 6], [5, 6], [0, 0]], [[7, 9], [0, 0], [22, 24]]],
                'prod': [[[3, 8], [5, 6], [1, 1]], [[7, 9], [1, 1], [120, 143]]],
                'mean': [[[2, 3], [5, 6], [float('nan'), float('nan')]],
                         [[7, 9], [float('nan'), float('nan')], [11, 12]]],
                'min': [[[1, 2], [5, 6], [float('inf'), float('inf')]],
                        [[7, 9], [float('inf'), float('inf')], [10, 11]]],
                'max': [[[3, 4], [5, 6], [-float('inf'), -float('inf')]],
                        [[7, 9], [-float('inf'), -float('inf')], [12, 13]]],
            },
            {
                'src': [[1, 3], [2, 4]],
                'index': [[0, 0], [0, 0]],
                'indptr': [[0, 2], [0, 2]],
                'sum': [[4], [6]],
                'prod': [[3], [8]],
                'mean': [[2], [3]],
                'min': [[1], [2]],
                'max': [[3], [4]],
            },
            {
                'src': [[[1, 1], [3, 3]], [[2, 2], [4, 4]]],
                'index': [[0, 0], [0, 0]],
                'indptr': [[0, 2], [0, 2]],
                'sum': [[[4, 4]], [[6, 6]]],
                'prod': [[[3, 3]], [[8, 8]]],
                'mean': [[[2, 2]], [[3, 3]]],
                'min': [[[1, 1]], [[2, 2]]],
                'max': [[[3, 3]], [[4, 4]]],
            },
        ]
        for test in tests:
            data = torch.tensor(test['src'], dtype=val_dtype, device=device, requires_grad=True)
            indptr = torch.tensor(test['indptr'], dtype=length_dtype, device=device)
            # Segments run along the last dimension of indptr.
            dim = indptr.ndim - 1
            # calculate lengths from indptr
            lengths = torch.diff(indptr, dim=dim)
            expected = torch.tensor(test[reduce], dtype=val_dtype, device=device)

            actual_result = torch._segment_reduce(
                data=data,
                reduce=reduce,
                lengths=lengths,
                axis=dim,
                unsafe=True,
            )
            self.assertEqual(actual_result, expected)

            # test offsets
            actual_result = torch._segment_reduce(
                data=data,
                reduce=reduce,
                offsets=indptr,
                axis=dim,
                unsafe=True,
            )
            self.assertEqual(actual_result, expected)

            if val_dtype == torch.float64:
                def fn(x, mode='lengths'):
                    initial = 1
                    # supply initial values to prevent gradcheck from failing for 0 length segments
                    # where nan/inf are reduction identities that produce nans when calculating the numerical jacobian
                    if reduce == 'min':
                        initial = 1000
                    elif reduce == 'max':
                        initial = -1000
                    # Must be an ordered sequence: a set would unpack its
                    # elements in arbitrary order.
                    segment_reduce_args = (x, reduce)
                    segment_reduce_kwargs = dict(axis=dim, unsafe=True, initial=initial)
                    if mode == 'lengths':
                        segment_reduce_kwargs[mode] = lengths
                    elif mode == 'offsets':
                        segment_reduce_kwargs[mode] = indptr
                    return torch._segment_reduce(*segment_reduce_args, **segment_reduce_kwargs)

                for mode in ('lengths', 'offsets'):
                    grad_input = data.clone().detach().requires_grad_(True)
                    self.assertTrue(gradcheck(partial(fn, mode=mode), (grad_input,)))

    @dtypes(
        *product(
            (torch.half, torch.bfloat16, torch.float, torch.double),
            (torch.int, torch.int64),
        )
    )
    def test_multi_d(self, device, dtypes):
        """3-D data reduced along axis 0, with expectations computed by NumPy."""
        val_dtype, length_type = dtypes
        axis = 0
        lengths = [0, 2, 3, 0]
        data = np.arange(50).reshape(5, 2, 5).tolist()
        expected_grad = []

        # TODO: calculate grad and check correctness
        check_backward = False

        # NumPy counterpart of each reduction, used to build the expectation.
        numpy_reducers = {
            "max": np.max,
            "mean": np.mean,
            "min": np.min,
            "sum": np.sum,
            "prod": np.prod,
        }
        # Initial value per reduction; it is also the expected output of the
        # two zero-length segments.
        initial_values = {
            "max": 0,
            "mean": 0,
            "min": 1000,  # some high number
            "sum": 0,
            "prod": 1,
        }

        for reduction in reductions:
            initial_value = initial_values[reduction]
            np_reduce = numpy_reducers[reduction]
            expected_result = [
                np.full((2, 5), initial_value).tolist(),
                np_reduce(data[:2], axis=0).tolist(),
                np_reduce(data[2:], axis=0).tolist(),
                np.full((2, 5), initial_value).tolist(),
            ]
            for unsafe in [True, False]:
                self._test_common(
                    reduction,
                    device,
                    val_dtype,
                    unsafe,
                    axis,
                    initial_value,
                    data,
                    lengths,
                    expected_result,
                    expected_grad,
                    check_backward,
                    # Forward the parametrized lengths dtype; otherwise the
                    # int64 variant silently re-runs with torch.int.
                    length_type,
                )

    @dtypes(torch.int, torch.int64)
    def test_unsafe_flag(self, device, dtype):
        """With ``unsafe=False``, lengths that do not match the data size must raise."""
        length_type = dtype
        lengths = torch.tensor([0, 2, 3, 0], device=device, dtype=length_type)
        data = torch.arange(6, dtype=torch.float, device=device)

        # test for error on 1-D lengths
        with self.assertRaisesRegex(RuntimeError, "Expected all rows of lengths along axis"):
            torch._segment_reduce(data, 'sum', lengths=lengths, axis=0, unsafe=False)

        # test for error on multi-D lengths
        nd_lengths = torch.tensor([[0, 3, 3, 0], [2, 3, 0, 0]], dtype=length_type, device=device)
        nd_data = torch.arange(12, dtype=torch.float, device=device).reshape(2, 6)
        with self.assertRaisesRegex(RuntimeError, "Expected all rows of lengths along axis"):
            torch._segment_reduce(nd_data, 'sum', lengths=nd_lengths, axis=1, unsafe=False)
| |
| |
| |
| |
# Instantiate the device-type tests for TestSegmentReductions, passing this
# module's globals as the target namespace.
instantiate_device_type_tests(TestSegmentReductions, globals())

# Run the test suite when this file is executed as a script.
if __name__ == "__main__":
    run_tests()