| # mypy: ignore-errors |
| |
"""Assorted utilities that depend on nothing other than torch and the standard library.
| """ |
| |
| import operator |
| |
| import torch |
| |
| from . import _dtypes_impl |
| |
| |
# https://github.com/numpy/numpy/blob/v1.23.0/numpy/distutils/misc_util.py#L497-L504
def is_sequence(seq):
    """Return True when ``seq`` supports ``len()`` and is not a ``str``.

    Strings are deliberately excluded even though they have a length. Any
    exception raised by ``len(seq)`` marks the object as a non-sequence.
    """
    if isinstance(seq, str):
        return False
    has_length = True
    try:
        len(seq)
    except Exception:
        has_length = False
    return has_length
| |
| |
class AxisError(ValueError, IndexError):
    """Raised when an axis argument is out of range for an array's rank.

    Subclasses both ``ValueError`` and ``IndexError`` so that handlers for
    either exception type catch it.
    """
| |
| |
class UFuncTypeError(TypeError, RuntimeError):
    """Type-related error whose name mirrors NumPy's ``UFuncTypeError``.

    Subclasses both ``TypeError`` and ``RuntimeError`` so that handlers for
    either exception type catch it.
    """
| |
| |
def cast_if_needed(tensor, dtype):
    """Return ``tensor`` converted to ``dtype``, or unchanged if no cast is needed.

    ``dtype=None`` means "keep the current dtype". A tensor that already has
    the requested dtype is returned as the very same object.
    """
    needs_cast = dtype is not None and tensor.dtype != dtype
    return tensor.to(dtype) if needs_cast else tensor
| |
| |
def cast_int_to_float(x):
    """Promote integer and bool tensors to the default floating dtype.

    Tensors whose ``_dtypes_impl._category`` is 2 or higher are returned
    untouched. Per the original comment, categories below 2 are ints and bools.
    """
    if _dtypes_impl._category(x.dtype) >= 2:
        return x
    return x.to(_dtypes_impl.default_dtypes().float_dtype)
| |
| |
# a replica of the version in ./numpy/numpy/core/src/multiarray/common.h
def normalize_axis_index(ax, ndim, argname=None):
    """Map a possibly negative axis ``ax`` into ``range(ndim)``.

    Raises ``AxisError`` when ``ax`` lies outside ``[-ndim, ndim)``.
    ``argname`` is accepted for signature compatibility but is not used.
    """
    if ax < -ndim or ax >= ndim:
        raise AxisError(f"axis {ax} is out of bounds for array of dimension {ndim}")
    return ax + ndim if ax < 0 else ax
| |
| |
# from https://github.com/numpy/numpy/blob/main/numpy/core/numeric.py#L1378
def normalize_axis_tuple(axis, ndim, argname=None, allow_duplicate=False):
    """
    Normalizes an axis argument into a tuple of non-negative integer axes.

    A single integer-like axis such as ``1`` becomes ``(1,)``. Negative
    indices are resolved via `normalize_axis_index`.

    Unless ``allow_duplicate`` is set, the same axis may not appear twice.
    Used internally by multi-axis-checking logic.

    Parameters
    ----------
    axis : int, iterable of int
        The un-normalized index or indices of the axis.
    ndim : int
        The number of dimensions of the array that `axis` should be normalized
        against.
    argname : str, optional
        A prefix to put before the error message, typically the name of the
        argument.
    allow_duplicate : bool, optional
        If False, the default, disallow an axis from being specified twice.

    Returns
    -------
    normalized_axes : tuple of int
        The normalized axis index, such that `0 <= normalized_axis < ndim`

    Raises
    ------
    ValueError
        If an axis is repeated and ``allow_duplicate`` is False.
    """
    # Fast path for the common case: wrap a lone index-like value in a list.
    # Non-index values (e.g. other iterables) fall through unchanged.
    if type(axis) not in (tuple, list):
        try:
            axis = [operator.index(axis)]
        except TypeError:
            pass

    # Going via an iterator directly is slower than via list comprehension.
    normalized = tuple([normalize_axis_index(ax, ndim, argname) for ax in axis])

    if not allow_duplicate:
        distinct = set(map(int, normalized))
        if len(distinct) != len(normalized):
            message = (
                f"repeated axis in `{argname}` argument" if argname else "repeated axis"
            )
            raise ValueError(message)
    return normalized
| |
| |
def allow_only_single_axis(axis):
    """Collapse a length-1 axis sequence to its single element.

    ``None`` passes through unchanged. Any sequence whose length is not
    exactly 1 raises ``NotImplementedError``.
    """
    if axis is not None:
        if len(axis) != 1:
            raise NotImplementedError("does not handle tuple axis")
        axis = axis[0]
    return axis
| |
| |
def expand_shape(arr_shape, axis):
    """Return ``arr_shape`` as a list with size-1 entries inserted at ``axis``.

    ``axis`` may be a single axis or a list/tuple of axes. Axes are
    normalized against the *output* rank ``len(arr_shape) + len(axes)``.
    """
    # taken from numpy 1.23.x, expand_dims function
    axes = axis if type(axis) in (list, tuple) else (axis,)
    out_ndim = len(arr_shape) + len(axes)
    axes = normalize_axis_tuple(axes, out_ndim)
    remaining = iter(arr_shape)
    return [1 if dim in axes else next(remaining) for dim in range(out_ndim)]
| |
| |
def apply_keepdims(tensor, axis, ndim):
    """Reshape ``tensor`` so the axes in ``axis`` reappear with size 1.

    With ``axis=None`` the tensor is treated as a scalar result and is
    expanded to shape ``(1,) * ndim`` and then materialized contiguously.
    Otherwise size-1 axes are inserted at ``axis`` via ``expand_shape``.
    Presumably used to implement ``keepdims=True`` for reductions.
    """
    if axis is not None:
        return tensor.reshape(expand_shape(tensor.shape, axis))
    # tensor was a scalar
    return tensor.expand((1,) * ndim).contiguous()
| |
| |
def axis_none_flatten(*tensors, axis=None):
    """Flatten the arrays if axis is None.

    Returns ``(tensors, axis)``. When ``axis`` is None, every tensor is
    flattened to 1-D and the returned axis is 0. Otherwise both inputs are
    returned unchanged.
    """
    if axis is not None:
        return tensors, axis
    return tuple(t.flatten() for t in tensors), 0
| |
| |
def typecast_tensor(t, target_dtype, casting):
    """Dtype-cast tensor to target_dtype.

    Parameters
    ----------
    t : torch.Tensor
        The tensor to cast
    target_dtype : torch dtype object
        The dtype to cast the tensor to
    casting : str
        The casting mode, see `np.can_cast`

    Returns
    -------
    `torch.Tensor` of the `target_dtype` dtype. If `t` already has that
    dtype, it is returned as-is (no copy).

    Raises
    ------
    TypeError
        if the argument cannot be cast according to the `casting` rule

    """
    can_cast = _dtypes_impl.can_cast_impl

    # Check the casting rule first so that a disallowed cast fails loudly
    # instead of silently converting the data.
    if not can_cast(t.dtype, target_dtype, casting=casting):
        raise TypeError(
            f"Cannot cast array data from {t.dtype} to"
            f" {target_dtype} according to the rule '{casting}'"
        )
    return cast_if_needed(t, target_dtype)
| |
| |
def typecast_tensors(tensors, target_dtype, casting):
    """Apply ``typecast_tensor`` to every element of ``tensors``.

    Returns a tuple of the converted tensors, in input order. The first
    disallowed cast propagates its ``TypeError``.
    """
    converted = []
    for item in tensors:
        converted.append(typecast_tensor(item, target_dtype, casting))
    return tuple(converted)
| |
| |
def _try_convert_to_tensor(obj):
    """Return ``torch.as_tensor(obj)``, re-raising any failure as NotImplementedError.

    The new error message embeds both ``obj`` and the text of the original
    exception.
    """
    try:
        return torch.as_tensor(obj)
    except Exception as e:
        mesg = f"failed to convert {obj} to ndarray. \nInternal error is: {str(e)}."
        raise NotImplementedError(mesg)  # noqa: B904
| |
| |
def _coerce_to_tensor(obj, dtype=None, copy=False, ndmin=0):
    """The core logic of the array(...) function.

    Parameters
    ----------
    obj : tensor_like
        The thing to coerce. A ``torch.Tensor`` is used directly; anything
        else is converted with ``torch.as_tensor``.
    dtype : torch.dtype object or None
        Coerce to this torch dtype; ``None`` keeps the inferred dtype.
    copy : bool
        If True, always return a fresh copy (via ``clone``). If False, the
        result may be ``obj`` itself (e.g. a tensor input needing no cast or
        reshape).
    ndmin : int
        The result has at least this many dimensions; missing dimensions
        are prepended as size-1 axes.

    Returns
    -------
    tensor : torch.Tensor
        a tensor object with requested dtype, ndim and copy semantics.

    Raises
    ------
    NotImplementedError
        If a non-tensor ``obj`` cannot be converted by ``torch.as_tensor``.

    Notes
    -----
    This is almost a "tensor_like" coercion function. Does not handle wrapper
    ndarrays (those should be handled in the ndarray-aware layer prior to
    invoking this function).
    """
    if isinstance(obj, torch.Tensor):
        tensor = obj
    else:
        # tensor.dtype is the pytorch default, typically float32. If obj's elements
        # are not exactly representable in float32, we've lost precision:
        # >>> torch.as_tensor(1e12).item() - 1e12
        # -4096.0
        # So temporarily switch the default dtype while converting, and restore
        # it in `finally` even if conversion fails.
        # NOTE(review): this mutates torch's process-wide default dtype;
        # presumably unsafe if other threads convert concurrently.
        default_dtype = torch.get_default_dtype()
        torch.set_default_dtype(_dtypes_impl.get_default_dtype_for(torch.float32))
        try:
            tensor = _try_convert_to_tensor(obj)
        finally:
            torch.set_default_dtype(default_dtype)

    # type cast if requested
    tensor = cast_if_needed(tensor, dtype)

    # adjust ndim if needed: prepend size-1 dimensions up to ndmin
    ndim_extra = ndmin - tensor.ndim
    if ndim_extra > 0:
        tensor = tensor.view((1,) * ndim_extra + tensor.shape)

    # copy if requested
    if copy:
        tensor = tensor.clone()

    return tensor
| |
| |
def ndarrays_to_tensors(*inputs):
    """Convert all ndarrays from `inputs` to tensors. (other things are intact)

    - A single ``ndarray`` argument is replaced by its ``.tensor``.
    - A single ``tuple`` argument is converted element-wise (recursively)
      and returned as a tuple.
    - Any other single argument is returned unchanged.
    - Several arguments are treated as one tuple and converted as above.

    Raises
    ------
    ValueError
        If called with no arguments.
    """
    if len(inputs) == 0:
        # Raise (rather than return) the error, so that misuse surfaces
        # immediately instead of handing callers an exception object as a result.
        raise ValueError("ndarrays_to_tensors requires at least one input")

    # Imported lazily, presumably to avoid an import cycle with ._ndarray.
    from ._ndarray import ndarray

    if len(inputs) == 1:
        input_ = inputs[0]
        if isinstance(input_, ndarray):
            return input_.tensor
        elif isinstance(input_, tuple):
            result = []
            for sub_input in input_:
                sub_result = ndarrays_to_tensors(sub_input)
                result.append(sub_result)
            return tuple(result)
        else:
            return input_
    else:
        assert isinstance(inputs, tuple)  # sanity check
        return ndarrays_to_tensors(inputs)