blob: 8a3fc029584be702aaff106e062e9e0090a634bd [file] [log] [blame]
# Owner(s): ["module: dynamo"]
from unittest import skipIf, SkipTest
import numpy
import pytest
from pytest import raises as assert_raises
from torch.testing._internal.common_utils import (
instantiate_parametrized_tests,
parametrize,
run_tests,
TEST_WITH_TORCHDYNAMO,
TestCase,
xpassIfTorchDynamo,
)
# If we are going to trace through these, we should use NumPy
# If testing on eager mode, we use torch._numpy
if TEST_WITH_TORCHDYNAMO:
import numpy as np
import numpy.core.numeric as _util # for normalize_axis_tuple
from numpy.testing import (
assert_allclose,
assert_almost_equal,
assert_array_equal,
assert_equal,
)
else:
import torch._numpy as np
from torch._numpy import _util
from torch._numpy.testing import (
assert_allclose,
assert_almost_equal,
assert_array_equal,
assert_equal,
)
class TestFlatnonzero(TestCase):
    """Checks for np.flatnonzero on a 1-d input."""

    def test_basic(self):
        # arange(-2, 3) is [-2, -1, 0, 1, 2]; only position 2 holds a zero.
        values = np.arange(-2, 3)
        nonzero_positions = np.flatnonzero(values)
        assert_equal(nonzero_positions, [0, 1, 3, 4])
class TestAny(TestCase):
    """Checks for np.any on flat and nested inputs."""

    def test_basic(self):
        one_hit = [0, 0, 1, 0]
        no_hits = [0, 0, 0, 0]
        two_hits = [1, 0, 1, 0]
        # A single nonzero entry is enough to make the reduction truthy.
        assert np.any(one_hit)
        assert np.any(two_hits)
        assert not np.any(no_hits)

    def test_nd(self):
        grid = [[0, 0, 0], [0, 1, 0], [1, 1, 0]]
        assert np.any(grid)
        # Reduce down the columns (axis=0), then across the rows (axis=1).
        assert_equal(np.any(grid, axis=0), [1, 1, 0])
        assert_equal(np.any(grid, axis=1), [0, 1, 1])
        assert_equal(np.any(grid), True)
        # An axis reduction must hand back an ndarray.
        assert isinstance(np.any(grid, axis=1), np.ndarray)

    # YYY: deduplicate (TestAll carries the same method-vs-function check)
    def test_method_vs_function(self):
        arr = np.array([[0, 1, 0, 3], [1, 0, 2, 0]])
        # The ndarray method and the free function must agree.
        assert_equal(np.any(arr), arr.any())
class TestAll(TestCase):
    """Checks for np.all on flat and nested inputs."""

    def test_basic(self):
        mixed = [0, 1, 1, 0]
        zeros = [0, 0, 0, 0]
        ones = [1, 1, 1, 1]
        assert not np.all(mixed)
        assert np.all(ones)
        assert not np.all(zeros)
        # Bitwise NOT of integer zeros yields -1 everywhere, all truthy.
        assert np.all(~np.array(zeros))

    def test_nd(self):
        grid = [[0, 0, 1], [0, 1, 1], [1, 1, 1]]
        assert not np.all(grid)
        # Reduce down the columns (axis=0), then across the rows (axis=1).
        assert_equal(np.all(grid, axis=0), [0, 0, 1])
        assert_equal(np.all(grid, axis=1), [0, 0, 1])
        assert_equal(np.all(grid), False)

    def test_method_vs_function(self):
        arr = np.array([[0, 1, 0, 3], [1, 0, 2, 0]])
        # The ndarray method and the free function must agree.
        assert_equal(np.all(arr), arr.all())
class TestMean(TestCase):
    """Checks for np.mean / ndarray.mean, including axis= and where=."""

    def test_mean(self):
        A = [[1, 2, 3], [4, 5, 6]]
        assert np.mean(A) == 3.5
        assert np.all(np.mean(A, 0) == np.array([2.5, 3.5, 4.5]))
        assert np.all(np.mean(A, 1) == np.array([2.0, 5.0]))

        # XXX: numpy emits a warning on empty slice
        assert np.isnan(np.mean([]))

        m = np.asarray(A)
        assert np.mean(A) == m.mean()

    def test_mean_values(self):
        """mean(x, axis) times the reduced length must equal sum(x, axis)."""
        import warnings

        # rmat = np.random.random((4, 5))
        rmat = np.arange(20, dtype=float).reshape((4, 5))
        cmat = rmat + 1j * rmat

        with warnings.catch_warnings():
            # Any warning raised by mean/sum turns into a test failure.
            warnings.simplefilter("error")
            for mat in [rmat, cmat]:
                for axis in [0, 1]:
                    tgt = mat.sum(axis=axis)
                    res = np.mean(mat, axis=axis) * mat.shape[axis]
                    assert_allclose(res, tgt)

                # axis=None reduces over every element, so scale by .size.
                tgt = mat.sum(axis=None)
                res = np.mean(mat, axis=None) * mat.size
                assert_allclose(res, tgt)

    def test_mean_float16(self):
        # This fails if the sum inside mean is done in float16 instead
        # of float32.
        assert np.mean(np.ones(100000, dtype="float16")) == 1

    @xpassIfTorchDynamo  # (reason="XXX: mean(..., where=...) not implemented")
    def test_mean_where(self):
        a = np.arange(16).reshape((4, 4))
        wh_full = np.array(
            [
                [False, True, False, True],
                [True, False, True, False],
                [True, True, False, False],
                [False, False, True, True],
            ]
        )
        wh_partial = np.array([[False], [True], [True], [False]])
        # (axis, where mask, expected mean)
        _cases = [
            (1, True, [1.5, 5.5, 9.5, 13.5]),
            (0, wh_full, [6.0, 5.0, 10.0, 9.0]),
            (1, wh_full, [2.0, 5.0, 8.5, 14.5]),
            (0, wh_partial, [6.0, 7.0, 8.0, 9.0]),
        ]
        for _ax, _wh, _res in _cases:
            assert_allclose(a.mean(axis=_ax, where=_wh), np.array(_res))
            assert_allclose(np.mean(a, axis=_ax, where=_wh), np.array(_res))

        a3d = np.arange(16).reshape((2, 2, 4))
        _wh_partial = np.array([False, True, True, False])
        _res = [[1.5, 5.5], [9.5, 13.5]]
        assert_allclose(a3d.mean(axis=2, where=_wh_partial), np.array(_res))
        assert_allclose(np.mean(a3d, axis=2, where=_wh_partial), np.array(_res))

        # Reductions over an all-False selection yield nan and must emit a
        # RuntimeWarning; the recorded warnings themselves are not inspected.
        with pytest.warns(RuntimeWarning):
            assert_allclose(
                a.mean(axis=1, where=wh_partial), np.array([np.nan, 5.5, 9.5, np.nan])
            )
        with pytest.warns(RuntimeWarning):
            assert_equal(a.mean(where=False), np.nan)
        with pytest.warns(RuntimeWarning):
            assert_equal(np.mean(a, where=False), np.nan)
@instantiate_parametrized_tests
class TestSum(TestCase):
    """Checks for np.sum: keepdims, accuracy, dtype handling, initial=/where=."""

    def test_sum(self):
        m = [[1, 2, 3], [4, 5, 6], [7, 8, 9]]
        tgt = [[6], [15], [24]]
        out = np.sum(m, axis=1, keepdims=True)
        assert_equal(tgt, out)

        am = np.asarray(m)
        assert_equal(np.sum(m), am.sum())

    def test_sum_stability(self):
        a = np.ones(500, dtype=np.float32)
        zero = np.zeros(1, dtype="float32")[0]
        assert_allclose((a / 10.0).sum() - a.size / 10.0, zero, atol=1.5e-4)

        a = np.ones(500, dtype=np.float64)
        assert_allclose((a / 10.0).sum() - a.size / 10.0, 0.0, atol=1.5e-13)

    def test_sum_boolean(self):
        a = np.arange(7) % 2 == 0
        res = a.sum()
        assert_equal(res, 4)

        res_float = a.sum(dtype=np.float64)
        assert_allclose(res_float, 4.0, atol=1e-15)
        assert res_float.dtype == "float64"

    # Compare versions numerically: a plain string comparison would order
    # e.g. "1.9" after "1.24".
    @skipIf(
        numpy.lib.NumpyVersion(numpy.__version__) < "1.24.0",
        reason="NP_VER: fails on NumPy 1.23.x",
    )
    @xpassIfTorchDynamo  # (reason="sum: does not warn on overflow")
    def test_sum_dtypes_warnings(self):
        """Each overflowing step (target cast, sum, reversed sum) adds exactly
        one RuntimeWarning; non-overflowing steps add none."""
        import warnings

        for dt in (int, np.float16, np.float32, np.float64):
            for v in (0, 1, 2, 7, 8, 9, 15, 16, 19, 127, 128, 1024, 1235):
                # warning if sum overflows, which it does in float16
                with warnings.catch_warnings(record=True) as w:
                    warnings.simplefilter("always", RuntimeWarning)

                    tgt = dt(v * (v + 1) / 2)
                    overflow = not np.isfinite(tgt)
                    assert_equal(len(w), 1 * overflow)

                    d = np.arange(1, v + 1, dtype=dt)

                    assert_almost_equal(np.sum(d), tgt)
                    assert_equal(len(w), 2 * overflow)

                    assert_almost_equal(np.sum(np.flip(d)), tgt)
                    assert_equal(len(w), 3 * overflow)

    def test_sum_dtypes_2(self):
        for dt in (int, np.float16, np.float32, np.float64):
            d = np.ones(500, dtype=dt)
            assert_almost_equal(np.sum(d[::2]), 250.0)
            assert_almost_equal(np.sum(d[1::2]), 250.0)
            assert_almost_equal(np.sum(d[::3]), 167.0)
            assert_almost_equal(np.sum(d[1::3]), 167.0)
            assert_almost_equal(np.sum(np.flip(d)[::2]), 250.0)
            assert_almost_equal(np.sum(np.flip(d)[1::2]), 250.0)

            assert_almost_equal(np.sum(np.flip(d)[::3]), 167.0)
            assert_almost_equal(np.sum(np.flip(d)[1::3]), 167.0)

            # sum with first reduction entry != 0
            # NOTE(review): as written this exercises in-place `+=`, not
            # np.sum; kept as-is to mirror the upstream NumPy test.
            d = np.ones((1,), dtype=dt)
            d += d
            assert_almost_equal(d, 2.0)

    @parametrize("dt", [np.complex64, np.complex128])
    def test_sum_complex_1(self, dt):
        for v in (0, 1, 2, 7, 8, 9, 15, 16, 19, 127, 128, 1024, 1235):
            # Real part counts 1..v upward, imaginary part counts downward.
            tgt = dt(v * (v + 1) / 2) - dt((v * (v + 1) / 2) * 1j)
            d = np.empty(v, dtype=dt)
            d.real = np.arange(1, v + 1)
            d.imag = -np.arange(1, v + 1)
            assert_allclose(np.sum(d), tgt, atol=1.5e-5)
            assert_allclose(np.sum(np.flip(d)), tgt, atol=1.5e-7)

    @parametrize("dt", [np.complex64, np.complex128])
    def test_sum_complex_2(self, dt):
        d = np.ones(500, dtype=dt) + 1j
        assert_allclose(np.sum(d[::2]), 250.0 + 250j, atol=1.5e-7)
        assert_allclose(np.sum(d[1::2]), 250.0 + 250j, atol=1.5e-7)
        assert_allclose(np.sum(d[::3]), 167.0 + 167j, atol=1.5e-7)
        assert_allclose(np.sum(d[1::3]), 167.0 + 167j, atol=1.5e-7)
        assert_allclose(np.sum(np.flip(d)[::2]), 250.0 + 250j, atol=1.5e-7)
        assert_allclose(np.sum(np.flip(d)[1::2]), 250.0 + 250j, atol=1.5e-7)

        assert_allclose(np.sum(np.flip(d)[::3]), 167.0 + 167j, atol=1.5e-7)
        assert_allclose(np.sum(np.flip(d)[1::3]), 167.0 + 167j, atol=1.5e-7)

        # sum with first reduction entry != 0
        # NOTE(review): as written this exercises in-place `+=`, not np.sum.
        d = np.ones((1,), dtype=dt) + 1j
        d += d
        assert_allclose(d, 2.0 + 2j, atol=1.5e-7)

    @xpassIfTorchDynamo  # (reason="initial=... need implementing")
    def test_sum_initial(self):
        # Integer, single axis
        assert_equal(np.sum([3], initial=2), 5)

        # Floating point
        assert_almost_equal(np.sum([0.2], initial=0.1), 0.3)

        # Multiple non-adjacent axes
        assert_equal(
            np.sum(np.ones((2, 3, 5), dtype=np.int64), axis=(0, 2), initial=2),
            [12, 12, 12],
        )

    @xpassIfTorchDynamo  # (reason="where=... need implementing")
    def test_sum_where(self):
        # More extensive tests done in test_reduction_with_where.
        assert_equal(np.sum([[1.0, 2.0], [3.0, 4.0]], where=[True, False]), 4.0)
        assert_equal(
            np.sum([[1.0, 2.0], [3.0, 4.0]], axis=0, initial=5.0, where=[True, False]),
            [9.0, 5.0],
        )
# Axis arguments exercised by the generic reduction tests: single ints
# (including negative ones) and tuples of axes in several orders.
_REDUCTION_AXES = [0, 1, 2, -1, -2, (0, 1), (1, 0), (0, 1, 2), (1, -1, 0)]
parametrize_axis = parametrize("axis", _REDUCTION_AXES)

# Reduction functions run through TestGenericReductions.
_REDUCTION_FUNCS = [
    np.any,
    np.all,
    np.argmin,
    np.argmax,
    np.min,
    np.max,
    np.mean,
    np.sum,
    np.prod,
    np.std,
    np.var,
    np.count_nonzero,
]
parametrize_func = parametrize("func", _REDUCTION_FUNCS)

# Functions for which the tuple-axis tests are skipped.
fails_axes_tuples = {np.any, np.all, np.argmin, np.argmax, np.prod}

# Functions for which the out= tests are skipped.
fails_out_arg = {np.count_nonzero}

# Functions skipped by the out/axis dtype test, since its requested out
# dtypes would imply float -> int casts of their results.
restricts_dtype_casts = {np.var, np.std}

# Functions for which the axis=() test is skipped.
fails_empty_tuple = {np.argmin, np.argmax}
@instantiate_parametrized_tests
class TestGenericReductions(TestCase):
    """Generic checks that every function in ``parametrize_func`` behaves like
    a reduction operation.

    Covers the ``axis=...``, ``keepdims=...`` and ``out=...`` parameters.
    Functions known not to support a feature are listed in the module-level
    ``fails_axes_tuples`` / ``fails_out_arg`` / ``fails_empty_tuple`` /
    ``restricts_dtype_casts`` sets; the matching tests are skipped for them.
    """

    @parametrize_func
    def test_bad_axis(self, func):
        # Basic check of functionality
        m = np.array([[0, 1, 7, 0, 0], [3, 0, 0, 2, 19]])

        assert_raises(TypeError, func, m, axis="foo")
        assert_raises(np.AxisError, func, m, axis=3)
        assert_raises(TypeError, func, m, axis=np.array([[1], [2]]))
        assert_raises(TypeError, func, m, axis=1.5)

        # TODO: add tests with np.int32(3) etc, when implemented

    @parametrize_func
    def test_array_axis(self, func):
        a = np.array([[0, 1, 7, 0, 0], [3, 0, 0, 2, 19]])
        # A 0-d integer array is accepted as an axis and acts like the int.
        assert_equal(func(a, axis=np.array(-1)), func(a, axis=-1))

        # A 1-d array is not a valid axis.
        with assert_raises(TypeError):
            func(a, axis=np.array([1, 2]))

    @parametrize_func
    def test_axis_empty_generic(self, func):
        if func in fails_empty_tuple:
            raise SkipTest("func(..., axis=()) is not valid")

        # axis=() reduces over no axes; compare with reducing over a freshly
        # inserted length-1 leading axis.
        a = np.array([[0, 0, 1], [1, 0, 1]])
        assert_array_equal(func(a, axis=()), func(np.expand_dims(a, axis=0), axis=0))

    @parametrize_func
    def test_axis_bad_tuple(self, func):
        if func in fails_axes_tuples:
            raise SkipTest(f"{func.__name__} does not allow tuple axis.")

        # Basic check of functionality
        m = np.array([[0, 1, 7, 0, 0], [3, 0, 0, 2, 19]])
        # A repeated axis inside the tuple must be rejected.
        with assert_raises(ValueError):
            func(m, axis=(1, 1))

    @parametrize_axis
    @parametrize_func
    def test_keepdims_generic(self, axis, func):
        if func in fails_axes_tuples:
            raise SkipTest(f"{func.__name__} does not allow tuple axis.")

        a = np.arange(2 * 3 * 4).reshape((2, 3, 4))
        with_keepdims = func(a, axis, keepdims=True)
        expanded = np.expand_dims(func(a, axis=axis), axis=axis)
        assert_array_equal(with_keepdims, expanded)

    # Compare versions numerically: a plain string comparison would order
    # e.g. "1.9" after "1.24".
    @skipIf(
        numpy.lib.NumpyVersion(numpy.__version__) < "1.24.0",
        reason="NP_VER: fails on CI w/old numpy",
    )
    @parametrize_func
    def test_keepdims_generic_axis_none(self, func):
        a = np.arange(2 * 3 * 4).reshape((2, 3, 4))
        with_keepdims = func(a, axis=None, keepdims=True)
        scalar = func(a, axis=None)
        expanded = np.full((1,) * a.ndim, fill_value=scalar)
        assert_array_equal(with_keepdims, expanded)

    @parametrize_func
    def test_out_scalar(self, func):
        # out no axis: scalar
        if func in fails_out_arg:
            raise SkipTest(f"{func.__name__} does not have out= arg.")

        a = np.arange(2 * 3 * 4).reshape((2, 3, 4))

        result = func(a)
        out = np.empty_like(result)
        result_with_out = func(a, out=out)

        assert result_with_out is out
        assert_array_equal(result, result_with_out)

    def _check_out_axis(self, axis, dtype, keepdims, func=None):
        """Check that ``func(..., axis=axis, keepdims=keepdims, out=out)``
        writes into ``out``, returns it, and matches the out-less result cast
        to ``dtype``.

        ``func`` falls back to ``self.func`` when omitted, for the old
        subclass-style usage; the parametrized tests pass it explicitly.
        """
        if func is None:
            func = self.func

        # out with axis
        a = np.arange(2 * 3 * 4).reshape((2, 3, 4))
        result = func(a, axis=axis, keepdims=keepdims).astype(dtype)

        out = np.empty_like(result, dtype=dtype)
        result_with_out = func(a, axis=axis, keepdims=keepdims, out=out)

        assert result_with_out is out
        assert result_with_out.dtype == dtype
        assert_array_equal(result, result_with_out)

        # TODO: what if result.dtype != out.dtype; does out typecast the result?

        # out of wrong shape (any/out does not broadcast)
        # np.any(m, out=np.empty_like(m)) raises a ValueError (wrong number
        # of dimensions.)
        # pytorch.any emits a warning and resizes the out array.
        # Here we follow pytorch, since the result is a superset
        # of the numpy functionality

    @parametrize("keepdims", [True, False])
    @parametrize("dtype", [bool, "int32", "float64"])
    @parametrize_func
    @parametrize_axis
    def test_out_axis(self, func, axis, dtype, keepdims):
        # out with axis
        if func in fails_out_arg:
            raise SkipTest(f"{func.__name__} does not have out= arg.")
        if func in fails_axes_tuples:
            raise SkipTest(f"{func.__name__} does not handle tuple axis.")
        if func in restricts_dtype_casts:
            raise SkipTest(f"{func.__name__}: test implies float->int casts")

        self._check_out_axis(axis, dtype, keepdims, func=func)

    @parametrize_func
    @parametrize_axis
    def test_keepdims_out(self, func, axis):
        if func in fails_out_arg:
            raise SkipTest(f"{func.__name__} does not have out= arg.")
        if func in fails_axes_tuples:
            raise SkipTest(f"{func.__name__} does not handle tuple axis.")

        d = np.ones((3, 5, 7, 11))
        # Defensive: parametrize_axis currently contains no None entry.
        if axis is None:
            shape_out = (1,) * d.ndim
        else:
            axis_norm = _util.normalize_axis_tuple(axis, d.ndim)
            # keepdims=True keeps reduced axes with length 1.
            shape_out = tuple(
                1 if i in axis_norm else d.shape[i] for i in range(d.ndim)
            )
        out = np.empty(shape_out)

        result = func(d, axis=axis, keepdims=True, out=out)
        assert result is out
        assert_equal(result.shape, shape_out)
# Cumulative reductions covered by TestGenericCumSumProd.
_CUMULATIVE_FUNCS = [np.cumsum, np.cumprod]


@instantiate_parametrized_tests
class TestGenericCumSumProd(TestCase):
    """Generic sanity checks for the cumulative reductions cumsum and cumprod.

    These tests expect ``axis=None`` to operate on the flattened input and a
    tuple ``axis`` to be rejected with TypeError.
    """

    @parametrize("func", _CUMULATIVE_FUNCS)
    def test_bad_axis(self, func):
        # Basic check of functionality
        arr = np.array([[0, 1, 7, 0, 0], [3, 0, 0, 2, 19]])

        # (expected exception, offending axis value), checked in order.
        bad_axes = [
            (TypeError, "foo"),
            (np.AxisError, 3),
            (TypeError, np.array([[1], [2]])),
            (TypeError, 1.5),
        ]
        for expected_exc, bad_axis in bad_axes:
            assert_raises(expected_exc, func, arr, axis=bad_axis)

        # TODO: add tests with np.int32(3) etc, when implemented

    @parametrize("func", _CUMULATIVE_FUNCS)
    def test_array_axis(self, func):
        arr = np.array([[0, 1, 7, 0, 0], [3, 0, 0, 2, 19]])
        # A 0-d array axis must behave exactly like the equivalent int.
        via_array = func(arr, axis=np.array(-1))
        via_int = func(arr, axis=-1)
        assert_equal(via_array, via_int)

        # A 1-d array is not an acceptable axis.
        with assert_raises(TypeError):
            func(arr, axis=np.array([1, 2]))

    @parametrize("func", _CUMULATIVE_FUNCS)
    def test_axis_empty_generic(self, func):
        arr = np.array([[0, 0, 1], [1, 0, 1]])
        flattened = arr.ravel()
        # axis=None must match accumulating along the flattened input.
        assert_array_equal(func(arr, axis=None), func(flattened, axis=0))

    @parametrize("func", _CUMULATIVE_FUNCS)
    def test_axis_bad_tuple(self, func):
        # Basic check of functionality
        arr = np.array([[0, 1, 7, 0, 0], [3, 0, 0, 2, 19]])
        with assert_raises(TypeError):
            func(arr, axis=(1, 1))
if __name__ == "__main__":
    # Script entry point: delegate to run_tests() imported from
    # torch.testing._internal.common_utils.
    run_tests()