| # mypy: ignore-errors |
| |
| """Wrapper to mimic (parts of) np.random API surface. |
| |
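NumPy has strict guarantees on reproducibility, etc.; here we don't give any.

Open question: NumPy's default float dtype is float64, while the samplers below
take their float dtype from ``_dtypes_impl.default_dtypes().float_dtype``,
which may differ.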
| |
| """ |
| from __future__ import annotations |
| |
| import functools |
| from math import sqrt |
| from typing import Optional |
| |
| import torch |
| |
| from . import _dtypes_impl, _util |
| from ._normalizations import array_or_scalar, ArrayLike, normalizer |
| |
| |
# Public names exported by ``from <this module> import *``.
__all__ = [
    "seed", "random_sample", "sample", "random",
    "rand", "randn", "normal", "choice",
    "randint", "shuffle", "uniform",
]
| |
| |
| def use_numpy_random(): |
| # local import to avoid ref cycles |
| import torch._dynamo.config as config |
| |
| return config.use_numpy_random_stream |
| |
| |
def deco_stream(func):
    """Decorate a sampler so it can be served by ``numpy.random`` instead.

    When ``use_numpy_random()`` is false the decorated function runs as-is.
    Otherwise the NumPy function with the same name,
    ``numpy.random.<func.__name__>``, is called instead. Positional and
    keyword arguments that are ``ndarray`` wrappers are unwrapped to NumPy
    arrays first. A ``numpy.ndarray`` result is wrapped back into an
    ``ndarray``. Any other result (e.g. a Python scalar, or None) is returned
    unchanged.
    """

    @functools.wraps(func)
    def inner(*args, **kwds):
        if not use_numpy_random():
            return func(*args, **kwds)

        import numpy

        from ._ndarray import ndarray

        def to_numpy(obj):
            # NumPy functions accept numpy.ndarray, not our wrapper type.
            return obj.tensor.numpy() if isinstance(obj, ndarray) else obj

        numpy_func = getattr(numpy.random, func.__name__)
        result = numpy_func(
            *map(to_numpy, args),
            **{name: to_numpy(value) for name, value in kwds.items()},
        )

        if not isinstance(result, numpy.ndarray):
            return result
        return ndarray(torch.as_tensor(result))

    return inner
| |
| |
@deco_stream
def seed(seed=None):
    """Seed torch's global RNG via ``torch.random.manual_seed``.

    Passing ``None`` (the default) leaves the generator state untouched.
    """
    if seed is None:
        return
    torch.random.manual_seed(seed)
| |
| |
@deco_stream
def random_sample(size=None):
    """Return samples drawn by ``Tensor.uniform_`` with its default bounds (0 to 1).

    ``size=None`` requests a scalar result. Otherwise ``size`` is passed to
    ``torch.empty`` as the output shape. The dtype is the default float dtype
    from ``_dtypes_impl``.
    """
    shape = () if size is None else size
    float_dtype = _dtypes_impl.default_dtypes().float_dtype
    draws = torch.empty(shape, dtype=float_dtype)
    draws.uniform_()
    return array_or_scalar(draws, return_scalar=shape == ())
| |
| |
def rand(*size):
    """Call ``random_sample`` with the dimensions given as separate arguments.

    ``rand()`` forwards ``size=None`` (a scalar request). ``rand(2, 3)``
    forwards ``size=(2, 3)``.
    """
    return random_sample(size if size else None)
| |
| |
# Alternative public names for random_sample, mirroring np.random.sample and
# np.random.random.
sample = random = random_sample
| |
| |
@deco_stream
def uniform(low=0.0, high=1.0, size=None):
    """Return samples drawn by ``Tensor.uniform_(low, high)``.

    ``size=None`` requests a scalar result. Otherwise ``size`` is the output
    shape. The dtype is the default float dtype from ``_dtypes_impl``.
    """
    shape = () if size is None else size
    out = torch.empty(shape, dtype=_dtypes_impl.default_dtypes().float_dtype)
    out.uniform_(low, high)
    return array_or_scalar(out, return_scalar=shape == ())
| |
| |
@deco_stream
def randn(*size):
    """Return ``torch.randn`` samples, with the dimensions given as separate arguments.

    Calling with no dimensions requests a scalar result. The dtype is the
    default float dtype from ``_dtypes_impl``.
    """
    float_dtype = _dtypes_impl.default_dtypes().float_dtype
    want_scalar = size == ()
    samples = torch.randn(size, dtype=float_dtype)
    return array_or_scalar(samples, return_scalar=want_scalar)
| |
| |
@deco_stream
def normal(loc=0.0, scale=1.0, size=None):
    """Return samples drawn by ``Tensor.normal_(loc, scale)`` (mean, std).

    ``size=None`` requests a scalar result. Otherwise ``size`` is the output
    shape. The dtype is the default float dtype from ``_dtypes_impl``.
    """
    shape = () if size is None else size
    out = torch.empty(shape, dtype=_dtypes_impl.default_dtypes().float_dtype)
    out.normal_(loc, scale)
    return array_or_scalar(out, return_scalar=shape == ())
| |
| |
@deco_stream
def shuffle(x):
    """Permute ``x`` in place along its first dimension.

    ``x`` must be a ``torch.Tensor`` or an ``ndarray`` wrapper; anything else
    raises NotImplementedError. Returns None.
    """
    # Deliberately no @normalizer: we do not cast e.g. lists to tensors, since
    # an in-place shuffle of a freshly created tensor would not reach the caller.
    from ._ndarray import ndarray

    if isinstance(x, torch.Tensor):
        target = x
    elif isinstance(x, ndarray):
        target = x.tensor
    else:
        raise NotImplementedError("We do not random.shuffle lists in-place")

    order = torch.randperm(target.shape[0])
    # Advanced indexing materialises a permuted copy, which is then written back.
    target.copy_(target[order])
| |
| |
@deco_stream
def randint(low, high=None, size=None):
    """Return integers drawn by ``torch.randint(low, high)``.

    With ``high=None`` the range becomes ``[0, low)``. ``size=None`` requests
    a scalar result. A non-tuple/list ``size`` is treated as a 1-D length.
    """
    if size is None:
        shape = ()
    elif isinstance(size, (tuple, list)):
        shape = size
    else:
        shape = (size,)

    if high is None:
        # Single-bound form: the given value is the exclusive upper bound.
        low, high = 0, low

    drawn = torch.randint(low, high, size=shape)
    return array_or_scalar(drawn, int, return_scalar=shape == ())
| |
| |
@deco_stream
@normalizer
def choice(a: ArrayLike, size=None, replace=True, p: Optional[ArrayLike] = None):
    """Draw random samples from ``a``, mirroring ``np.random.choice``.

    Parameters:
        a: either a 0-d integer, meaning "sample from ``arange(a)``", or a
            1-D array whose elements are sampled.
        size: ``None`` draws a single element (0-d result). An int draws that
            many elements (1-D result). A sequence gives the result shape.
        replace: sample with (True) or without (False) replacement.
        p: optional 1-D probabilities, one per element of ``a``. They must sum
            to 1 within ``sqrt(eps)`` of their float dtype. They default to
            uniform. A non-floating ``p`` is cast to the default float dtype.

    Raises:
        ValueError: if ``a`` is a non-integer scalar or is not 1-D, if ``p``
            is not 1-D or its length differs from ``a``'s, or if ``p`` does
            not sum to 1.
    """
    # https://stackoverflow.com/questions/59461811/random-choice-with-pytorch
    if a.ndim == 0:
        # Scalar population size: sample from arange(a). As in NumPy, reject
        # non-integer scalars such as choice(3.4).
        if a.dtype.is_floating_point or a.dtype.is_complex:
            raise ValueError("a must be 1-dimensional or an integer")
        a = torch.arange(a)
    elif a.ndim != 1:
        # A 1-element 1-D array like [5] is a population, not a size; only 1-D
        # populations make sense for the per-element probabilities below.
        raise ValueError("a must be 1-dimensional")

    # Number of draws: one for size=None, otherwise the product of the shape.
    if size is None:
        num_el = 1
    elif _util.is_sequence(size):
        num_el = 1
        for el in size:
            num_el *= el
    else:
        num_el = size

    # Prepare the probabilities.
    if p is None:
        p = torch.ones_like(a) / a.shape[0]
    else:
        if p.ndim != 1:
            raise ValueError("'p' must be 1-dimensional")
        if p.shape[0] != a.shape[0]:
            raise ValueError("'a' and 'p' must have same size")
        if not p.is_floating_point():
            # torch.finfo below (and multinomial) need a floating dtype.
            p = p.to(_dtypes_impl.default_dtypes().float_dtype)

    # cf https://github.com/numpy/numpy/blob/main/numpy/random/mtrand.pyx#L973
    atol = sqrt(torch.finfo(p.dtype).eps)
    if abs(p.sum() - 1.0) > atol:
        raise ValueError("probabilities do not sum to 1.")

    # Actually sample: indices into `a`, shape (num_el,).
    indices = torch.multinomial(p, num_el, replacement=replace)

    if size is None:
        # A single draw: drop the length-1 axis so the result is one element
        # (0-d), matching np.random.choice and the numpy-stream code path.
        indices = indices.reshape(())
    elif _util.is_sequence(size):
        indices = indices.reshape(size)

    return a[indices]