"""This module exists because `torch.testing` exposed a lot of stuff that shouldn't have been public. Although this
was never documented anywhere, some internal FB projects as well as downstream OSS projects might use it. Thus,
we don't internalize it without warning, but go through a deprecation cycle instead.
"""
| |
| import functools |
| import random |
| import warnings |
| from typing import Any, Callable, Dict, Optional, Tuple, Union |
| |
| import torch |
| |
| from . import _legacy |
| |
| |
| __all__ = [ |
| "rand", |
| "randn", |
| "assert_allclose", |
| "get_all_device_types", |
| "make_non_contiguous", |
| ] |
| |
| |
def warn_deprecated(instructions: Union[str, Callable[[str, Tuple[Any, ...], Dict[str, Any], Any], str]]) -> Callable:
    """Create a decorator that marks a function as a deprecated ``torch.testing`` API.

    The decorated function forwards all arguments to the original function and returns its result unchanged. After
    the original function has returned, a :class:`FutureWarning` is emitted. Its message starts with a fixed
    deprecation notice built from the function's ``__name__`` and ends with ``instructions``.

    Args:
        instructions: Either a fixed string, or a callable that is invoked as
            ``instructions(name, args, kwargs, return_value)`` and returns the string to append to the deprecation
            notice. The callable form allows the hint to depend on the arguments or the result of the call.

    Returns:
        A decorator that wraps a function as described above.
    """

    def outer_wrapper(fn: Callable) -> Callable:
        name = fn.__name__
        head = f"torch.testing.{name}() is deprecated since 1.12 and will be removed in 1.14. "

        @functools.wraps(fn)
        def inner_wrapper(*args: Any, **kwargs: Any) -> Any:
            # The function is called first, because callable instructions receive its return value. As a
            # consequence, no deprecation warning is emitted if `fn` raises.
            return_value = fn(*args, **kwargs)
            tail = instructions(name, args, kwargs, return_value) if callable(instructions) else instructions
            msg = (head + tail).strip()
            # stacklevel=2 attributes the warning to the code that called the deprecated function rather than to
            # this wrapper, so the reported file and line point at what actually needs to be changed.
            warnings.warn(msg, FutureWarning, stacklevel=2)
            return return_value

        return inner_wrapper

    return outer_wrapper
| |
| |
# Deprecated aliases of the random creation functions: each one forwards to the corresponding `torch` function and,
# once that call has returned, emits a FutureWarning whose message ends with the replacement hint given here.
rand = warn_deprecated("Use torch.rand() instead.")(torch.rand)
randn = warn_deprecated("Use torch.randn() instead.")(torch.randn)
| |
| |
# Default ``(rtol, atol)`` pair for each floating point dtype that has one. Dtypes that are not listed here are
# treated as having zero tolerance by `_get_default_rtol_and_atol`.
_DTYPE_PRECISIONS = {
    torch.float16: (1e-3, 1e-3),
    torch.float32: (1e-4, 1e-5),
    torch.float64: (1e-5, 1e-8),
}


def _get_default_rtol_and_atol(actual: torch.Tensor, expected: torch.Tensor) -> Tuple[float, float]:
    """Return the default ``(rtol, atol)`` for comparing ``actual`` with ``expected``.

    Each tensor's dtype is looked up in ``_DTYPE_PRECISIONS``, falling back to ``(0.0, 0.0)`` for dtypes without an
    entry. The relative and the absolute tolerance are then picked independently as the larger of the two candidates.
    """
    no_tolerance = (0.0, 0.0)
    candidates = (_DTYPE_PRECISIONS.get(tensor.dtype, no_tolerance) for tensor in (actual, expected))
    rtols, atols = zip(*candidates)
    return max(rtols), max(atols)
| |
| |
@warn_deprecated(
    "Use torch.testing.assert_close() instead. "
    "For detailed upgrade instructions see https://github.com/pytorch/pytorch/issues/61844."
)
def assert_allclose(
    actual: Any,
    expected: Any,
    rtol: Optional[float] = None,
    atol: Optional[float] = None,
    equal_nan: bool = True,
    msg: str = "",
) -> None:
    """Assert that ``actual`` and ``expected`` are close by delegating to :func:`torch.testing.assert_close`.

    Args:
        actual: Tensor, or anything :func:`torch.tensor` accepts, to check.
        expected: Tensor, or anything :func:`torch.tensor` accepts, to compare against. If it is not a tensor, it is
            converted with the dtype of ``actual``.
        rtol: Relative tolerance.
        atol: Absolute tolerance. If both ``rtol`` and ``atol`` are ``None``, defaults are selected based on the
            dtypes of the two tensors.
        equal_nan: Forwarded to :func:`torch.testing.assert_close`.
        msg: Custom error message. An empty string is forwarded as ``None``.
    """
    # Anything that is not a tensor yet is converted; `expected` adopts the dtype of the (possibly converted) `actual`.
    actual = actual if isinstance(actual, torch.Tensor) else torch.tensor(actual)
    expected = expected if isinstance(expected, torch.Tensor) else torch.tensor(expected, dtype=actual.dtype)

    # Defaults are only filled in when neither tolerance was passed. Any other combination is forwarded as is.
    neither_tolerance_given = rtol is None and atol is None
    if neither_tolerance_given:
        rtol, atol = _get_default_rtol_and_atol(actual, expected)

    comparison_options = dict(
        rtol=rtol,
        atol=atol,
        equal_nan=equal_nan,
        check_device=True,
        check_dtype=False,
        check_stride=False,
        msg=msg or None,
    )
    torch.testing.assert_close(actual, expected, **comparison_options)
| |
| |
def getter_instructions(name: str, args: Tuple[Any, ...], kwargs: Dict[str, Any], return_value: Any) -> str:
    """Build the replacement hint for a deprecated getter from the value the call just returned.

    The signature matches what ``warn_deprecated`` passes to callable instructions. Only ``return_value`` is used;
    it is interpolated into the message so that the call can be replaced with the value itself.
    """
    return f"This call can be replaced with {return_value}."
| |
# Deprecate and expose all dtype getters
# Every getter named in `_legacy.__all_dtype_getters__` is wrapped so that calling it emits a FutureWarning, published
# as a module attribute under its original name, and registered in `__all__`.
# NOTE: `name` and `fn` are ordinary module-level loop variables and stay defined after the loop has finished.
for name in _legacy.__all_dtype_getters__:
    fn = getattr(_legacy, name)
    globals()[name] = warn_deprecated(getter_instructions)(fn)
    __all__.append(name)

# `get_all_device_types` is listed in the static part of `__all__`, so it only has to be wrapped here.
get_all_device_types = warn_deprecated(getter_instructions)(_legacy.get_all_device_types)
| |
| |
@warn_deprecated(
    "Depending on the use case there are different replacement options:\n\n"
    "- If you are using `make_non_contiguous` in combination with a creation function to create a noncontiguous tensor "
    "with random values, use `torch.testing.make_tensor(..., noncontiguous=True)` instead.\n"
    "- If you are using `make_non_contiguous` with a specific tensor, you can replace this call with "
    "`torch.repeat_interleave(input, 2, dim=-1)[..., ::2]`.\n"
    "- If you are using `make_non_contiguous` in the PyTorch test suite, use "
    "`torch.testing._internal.common_utils.noncontiguous_like` instead."
)
def make_non_contiguous(tensor: torch.Tensor) -> torch.Tensor:
    """Return a copy of ``tensor`` that is carved out of a larger, randomly sized buffer.

    The goal is a tensor with the same size and values as ``tensor`` but a non-contiguous memory layout. Tensors with
    at most one element cannot be made non-contiguous and are simply cloned.

    The buffer size and the position of the copy inside it are drawn from Python's :mod:`random` module, so the
    resulting strides depend on the state of that generator.

    Args:
        tensor: Tensor to copy.

    Returns:
        A new tensor holding the values of ``tensor``.
    """
    if tensor.numel() <= 1:  # can't make non-contiguous
        return tensor.clone()
    osize = list(tensor.size())

    # randomly inflate a few dimensions in osize
    for _ in range(2):
        dim = random.randint(0, len(osize) - 1)
        add = random.randint(4, 15)
        osize[dim] = osize[dim] + add

    # narrow doesn't make a non-contiguous tensor if we only narrow the 0-th dimension,
    # (which will always happen with a 1-dimensional tensor), so let's make a new
    # right-most dimension and cut it off
    buffer = tensor.new(torch.Size(osize + [random.randint(2, 3)]))
    buffer = buffer.select(-1, random.randint(0, 1))

    # now extract a view of the correct size from the buffer: every inflated dimension is narrowed
    # back to the original size, starting at a random non-zero offset
    for dim, size in enumerate(tensor.size()):
        surplus = buffer.size(dim) - size
        if surplus != 0:
            offset = random.randint(1, surplus)
            buffer = buffer.narrow(dim, offset, size)

    buffer.copy_(tensor)

    # Use .data here to hide the view relation between the result and the temporary tensors above
    return buffer.data