| # mypy: ignore-errors |
| |
| from __future__ import annotations |
| |
| import builtins |
| import math |
| import operator |
| from typing import Sequence |
| |
| import torch |
| |
| from . import _dtypes, _dtypes_impl, _funcs, _ufuncs, _util |
| from ._normalizations import ( |
| ArrayLike, |
| normalize_array_like, |
| normalizer, |
| NotImplementedType, |
| ) |
| |
| |
# NumPy-style spelling for "insert a new axis"; it is simply `None`.
newaxis = None

# Full names of every flag a `Flags` object recognises.
FLAGS = [
    # memory layout
    "C_CONTIGUOUS",
    "F_CONTIGUOUS",
    # ownership / access
    "OWNDATA",
    "WRITEABLE",
    "ALIGNED",
    "WRITEBACKIFCOPY",
    # combined flags
    "FNC",
    "FORC",
    "BEHAVED",
    "CARRAY",
    "FARRAY",
]

# Abbreviated spellings accepted when subscripting a `Flags` object,
# mapped to their full names in `FLAGS`.
SHORTHAND_TO_FLAGS = dict(
    C="C_CONTIGUOUS",
    F="F_CONTIGUOUS",
    O="OWNDATA",
    W="WRITEABLE",
    A="ALIGNED",
    X="WRITEBACKIFCOPY",
    B="BEHAVED",
    CA="CARRAY",
    FA="FARRAY",
)
| |
| |
class Flags:
    """Read-only container of array flags.

    Values can be read either as lower-case attributes (``flags.c_contiguous``)
    or by subscripting with a full flag name or one of the abbreviations in
    ``SHORTHAND_TO_FLAGS`` (``flags["C_CONTIGUOUS"]``, ``flags["C"]``).
    Any attempt to modify a flag raises ``NotImplementedError``.
    """

    def __init__(self, flag_to_value: dict):
        # sanity check: only names listed in FLAGS may be supplied
        assert all(name in FLAGS for name in flag_to_value)
        self._flag_to_value = flag_to_value

    def __getattr__(self, attr: str):
        # Only reached when normal attribute lookup fails.
        if not (attr.islower() and attr.upper() in FLAGS):
            raise AttributeError(f"No flag attribute '{attr}'")
        return self[attr.upper()]

    def __getitem__(self, key):
        # Expand an abbreviation to its full name; other keys pass through.
        key = SHORTHAND_TO_FLAGS.get(key, key)
        if key not in FLAGS:
            raise KeyError(f"No flag key '{key}'")
        try:
            return self._flag_to_value[key]
        except KeyError as e:
            # A valid flag name for which no value was provided.
            raise NotImplementedError(f"{key=}") from e

    def __setattr__(self, attr, value):
        names_a_flag = attr.islower() and attr.upper() in FLAGS
        if not names_a_flag:
            # Ordinary attribute (e.g. `_flag_to_value`): store it normally.
            super().__setattr__(attr, value)
            return
        self[attr.upper()] = value

    def __setitem__(self, key, value):
        recognised = key in FLAGS or key in SHORTHAND_TO_FLAGS
        if not recognised:
            raise KeyError(f"No flag key '{key}'")
        raise NotImplementedError("Modifying flags is not implemented")
| |
| |
def create_method(fn, name=None):
    """Return a new function that forwards every call to `fn`.

    The wrapper's ``__name__`` is `name` (falling back to ``fn.__name__`` when
    `name` is falsy) and its ``__qualname__`` is ``"ndarray.<name>"``.
    """
    method_name = name if name else fn.__name__

    def forward(*args, **kwargs):
        return fn(*args, **kwargs)

    forward.__name__ = method_name
    forward.__qualname__ = f"ndarray.{method_name}"
    return forward
| |
| |
# Map ndarray.name_method -> np.name_func
# A value of None means the method and the function share the same name.
methods = dict.fromkeys(
    (
        "clip",
        "nonzero",
        "repeat",
        "round",
        "squeeze",
        "swapaxes",
        "ravel",
        # linalg
        "diagonal",
        "dot",
        "trace",
        # sorting
        "argsort",
        "searchsorted",
        # reductions
        "argmax",
        "argmin",
        "any",
        "all",
        "max",
        "min",
        "ptp",
        "sum",
        "prod",
        "mean",
        "var",
        "std",
        # scans
        "cumsum",
        "cumprod",
        # advanced indexing
        "take",
        "choose",
    )
)

# Unary and comparison dunders: "__<key>__" -> ufunc name (None: same name).
dunder = dict(
    abs="absolute",
    invert=None,
    pos="positive",
    neg="negative",
    gt="greater",
    lt="less",
    ge="greater_equal",
    le="less_equal",
)

# dunder methods with right-looking and in-place variants
# ("__<key>__", "__r<key>__", "__i<key>__") -> ufunc name (None: same name).
ri_dunder = {
    # arithmetic
    "add": None,
    "sub": "subtract",
    "mul": "multiply",
    "truediv": "divide",
    "floordiv": "floor_divide",
    "pow": "power",
    "mod": "remainder",
    # bitwise
    "and": "bitwise_and",
    "or": "bitwise_or",
    "xor": "bitwise_xor",
    "lshift": "left_shift",
    "rshift": "right_shift",
    # matrix product
    "matmul": None,
}
| |
| |
def _upcast_int_indices(index):
    """Promote narrow integer tensors used as indices to ``torch.int64``.

    Tensors of dtype int8/int16/int32/uint8 are converted; tuples are processed
    element by element (recursively); every other object is returned unchanged.
    """
    if isinstance(index, tuple):
        return tuple(map(_upcast_int_indices, index))

    narrow_ints = (torch.int8, torch.int16, torch.int32, torch.uint8)
    if isinstance(index, torch.Tensor) and index.dtype in narrow_ints:
        return index.to(torch.int64)
    return index
| |
| |
# Marker type for "this argument was not passed at all", which has to stay
# distinguishable from an explicit `None`.
class _Unspecified:
    """Type of the ``_Unspecified.unspecified`` sentinel."""


# The shared sentinel instance lives on the class itself.
_Unspecified.unspecified = _Unspecified()
| |
| ############################################################### |
| # ndarray class # |
| ############################################################### |
| |
| |
class ndarray:
    """NumPy-like array object that wraps a ``torch.Tensor`` in ``self.tensor``.

    Most methods and operator overloads are generated while the class body
    executes, from the module-level ``methods``, ``dunder`` and ``ri_dunder``
    tables: each entry is looked up in ``_funcs`` / ``_ufuncs`` and installed
    through ``create_method``. The remaining members are written out below.
    """

    def __init__(self, t=None):
        """Wrap tensor `t`; with no argument, wrap a fresh ``torch.Tensor()``.

        Raises:
            ValueError: if `t` is neither ``None`` nor a ``torch.Tensor``.
        """
        if t is None:
            self.tensor = torch.Tensor()
        elif isinstance(t, torch.Tensor):
            self.tensor = t
        else:
            raise ValueError(
                "ndarray constructor is not recommended; prefer "
                "either array(...) or zeros/empty(...)"
            )

    # Register NumPy functions as methods
    for method, name in methods.items():
        fn = getattr(_funcs, name or method)
        vars()[method] = create_method(fn, method)

    # Regular methods but coming from ufuncs
    conj = create_method(_ufuncs.conjugate, "conj")
    conjugate = create_method(_ufuncs.conjugate)

    for method, name in dunder.items():
        fn = getattr(_ufuncs, name or method)
        method = f"__{method}__"
        vars()[method] = create_method(fn, method)

    for method, name in ri_dunder.items():
        fn = getattr(_ufuncs, name or method)
        plain = f"__{method}__"
        vars()[plain] = create_method(fn, plain)
        # `fn=fn` binds the current ufunc as a default so that each lambda keeps
        # its own function instead of the last one of the loop.
        rvar = f"__r{method}__"
        vars()[rvar] = create_method(lambda self, other, fn=fn: fn(other, self), rvar)
        ivar = f"__i{method}__"
        vars()[ivar] = create_method(
            lambda self, other, fn=fn: fn(self, other, out=self), ivar
        )

    # There's no __idivmod__
    __divmod__ = create_method(_ufuncs.divmod, "__divmod__")
    __rdivmod__ = create_method(
        lambda self, other: _ufuncs.divmod(other, self), "__rdivmod__"
    )

    # prevent loop variables leaking into the ndarray class namespace
    del ivar, rvar, name, plain, fn, method

    @property
    def shape(self):
        """Shape of the wrapped tensor as a plain tuple."""
        return tuple(self.tensor.shape)

    @property
    def size(self):
        """Total number of elements."""
        return self.tensor.numel()

    @property
    def ndim(self):
        """Number of dimensions."""
        return self.tensor.ndim

    @property
    def dtype(self):
        """Dtype of the array, as a ``_dtypes.dtype`` built from the torch dtype."""
        return _dtypes.dtype(self.tensor.dtype)

    @property
    def strides(self):
        """Strides in bytes (torch reports strides in elements)."""
        elsize = self.tensor.element_size()
        return tuple(stride * elsize for stride in self.tensor.stride())

    @property
    def itemsize(self):
        """Size of one element in bytes."""
        return self.tensor.element_size()

    @property
    def flags(self):
        """A ``Flags`` object describing contiguity, ownership and writeability."""
        # Note contiguous in torch is assumed C-style
        return Flags(
            {
                "C_CONTIGUOUS": self.tensor.is_contiguous(),
                "F_CONTIGUOUS": self.T.tensor.is_contiguous(),
                "OWNDATA": self.tensor._base is None,
                "WRITEABLE": True,  # pytorch does not have readonly tensors
            }
        )

    @property
    def data(self):
        """Address of the first element, as returned by ``Tensor.data_ptr()``."""
        return self.tensor.data_ptr()

    @property
    def nbytes(self):
        """Bytes occupied by the elements of this array (``size * itemsize``)."""
        # Computed from the array itself rather than from the backing storage:
        # a view can share a storage that is larger than the view.
        return self.tensor.numel() * self.tensor.element_size()

    @property
    def T(self):
        """The array with its axes transposed (``self.transpose()``)."""
        return self.transpose()

    @property
    def real(self):
        return _funcs.real(self)

    @real.setter
    def real(self, value):
        self.tensor.real = asarray(value).tensor

    @property
    def imag(self):
        return _funcs.imag(self)

    @imag.setter
    def imag(self, value):
        self.tensor.imag = asarray(value).tensor

    # ctors
    def astype(self, dtype, order="K", casting="unsafe", subok=True, copy=True):
        """Return a new ndarray whose tensor is converted to `dtype`.

        Only the default values of `order`, `casting`, `subok` and `copy` are
        supported.

        Raises:
            NotImplementedError: if any of those arguments is not its default.
        """
        if order != "K":
            raise NotImplementedError(
                f"astype(..., order={order}) is not implemented."
            )
        if casting != "unsafe":
            raise NotImplementedError(
                f"astype(..., casting={casting}) is not implemented."
            )
        if not subok:
            raise NotImplementedError(
                f"astype(..., subok={subok}) is not implemented."
            )
        if not copy:
            raise NotImplementedError(f"astype(..., copy={copy}) is not implemented.")
        torch_dtype = _dtypes.dtype(dtype).torch_dtype
        t = self.tensor.to(torch_dtype)
        return ndarray(t)

    @normalizer
    def copy(self: ArrayLike, order: NotImplementedType = "C"):
        return self.clone()

    @normalizer
    def flatten(self: ArrayLike, order: NotImplementedType = "C"):
        return torch.flatten(self)

    def resize(self, *new_shape, refcheck=False):
        """Resize the array in place; newly added elements are set to zero.

        Accepts both ``x.resize((2, 2))`` and ``x.resize(2, 2)``. Calling it
        with no shape, or with ``None``, is a no-op.

        Raises:
            NotImplementedError: if `refcheck` is truthy.
            ValueError: if any requested dimension is negative.
        """
        # NB: differs from np.resize: fills with zeros instead of making repeated copies of input.
        if refcheck:
            raise NotImplementedError(
                f"resize(..., refcheck={refcheck}) is not implemented."
            )
        if new_shape in [(), (None,)]:
            return

        # support both x.resize((2, 2)) and x.resize(2, 2)
        if len(new_shape) == 1:
            new_shape = new_shape[0]
        if isinstance(new_shape, int):
            new_shape = (new_shape,)

        if builtins.any(x < 0 for x in new_shape):
            raise ValueError("all elements of `new_shape` must be non-negative")

        # Record the old element count before the tensor is resized.
        new_numel, old_numel = math.prod(new_shape), self.tensor.numel()

        self.tensor.resize_(new_shape)

        if new_numel >= old_numel:
            # zero-fill new elements
            assert self.tensor.is_contiguous()
            b = self.tensor.flatten()  # does not copy
            b[old_numel:].zero_()

    def view(self, dtype=_Unspecified.unspecified, type=_Unspecified.unspecified):
        """Return a new ndarray viewing the same tensor as `dtype`.

        When `dtype` is not given, the array's own dtype is used.

        Raises:
            NotImplementedError: if `type` is passed.
        """
        if dtype is _Unspecified.unspecified:
            dtype = self.dtype
        if type is not _Unspecified.unspecified:
            raise NotImplementedError(f"view(..., type={type}) is not implemented.")
        torch_dtype = _dtypes.dtype(dtype).torch_dtype
        tview = self.tensor.view(torch_dtype)
        return ndarray(tview)

    @normalizer
    def fill(self, value: ArrayLike):
        # Both Pytorch and NumPy accept 0D arrays/tensors and scalars, and
        # error out on D > 0 arrays
        self.tensor.fill_(value)

    def tolist(self):
        return self.tensor.tolist()

    def __iter__(self):
        return (ndarray(x) for x in self.tensor.__iter__())

    def __str__(self):
        # Reuse torch's formatting, rebranded to look like an ndarray.
        return (
            str(self.tensor)
            .replace("tensor", "torch.ndarray")
            .replace("dtype=torch.", "dtype=")
        )

    __repr__ = create_method(__str__)

    def __eq__(self, other):
        try:
            return _ufuncs.equal(self, other)
        except (RuntimeError, TypeError):
            # Failed to convert other to array: definitely not equal.
            falsy = torch.full(self.shape, fill_value=False, dtype=bool)
            return asarray(falsy)

    def __ne__(self, other):
        return ~(self == other)

    def __index__(self):
        try:
            return operator.index(self.tensor.item())
        except Exception as exc:
            raise TypeError(
                "only integer scalar arrays can be converted to a scalar index"
            ) from exc

    def __bool__(self):
        return bool(self.tensor)

    def __int__(self):
        return int(self.tensor)

    def __float__(self):
        return float(self.tensor)

    def __complex__(self):
        return complex(self.tensor)

    def is_integer(self):
        """Return True if the array holds a single value equal to its ``int()``.

        Any failure while extracting or converting the value yields False.
        """
        try:
            v = self.tensor.item()
            result = int(v) == v
        except Exception:
            result = False
        return result

    def __len__(self):
        return self.tensor.shape[0]

    def __contains__(self, x):
        return self.tensor.__contains__(x)

    def transpose(self, *axes):
        # np.transpose(arr, axis=None) but arr.transpose(*axes)
        return _funcs.transpose(self, axes)

    def reshape(self, *shape, order="C"):
        # arr.reshape(shape) and arr.reshape(*shape)
        return _funcs.reshape(self, shape, order=order)

    def sort(self, axis=-1, kind=None, order=None):
        # ndarray.sort works in-place
        _funcs.copyto(self, _funcs.sort(self, axis, kind, order))

    def item(self, *args):
        # Mimic NumPy's implementation with three special cases (no arguments,
        # a flat index and a multi-index):
        # https://github.com/numpy/numpy/blob/main/numpy/core/src/multiarray/methods.c#L702
        if args == ():
            return self.tensor.item()
        elif len(args) == 1:
            # int argument
            return self.ravel()[args[0]]
        else:
            return self.__getitem__(args)

    def __getitem__(self, index):
        """Index the array and return the result wrapped in a new ndarray.

        Slices with a negative step are emulated by flipping the tensor along
        the corresponding axis and applying an equivalent positive-step slice.
        """
        tensor = self.tensor

        def neg_step(i, s):
            # Anything that is not a negative-step slice is passed through.
            if not (isinstance(s, slice) and s.step is not None and s.step < 0):
                return s

            # NOTE(review): `i` is the entry's position inside the index and is
            # used directly as the axis to flip; this matches the tensor axis
            # only when no `None` / `Ellipsis` entry comes before it.
            nonlocal tensor
            tensor = torch.flip(tensor, (i,))

            assert isinstance(s.start, int) or s.start is None
            assert isinstance(s.stop, int) or s.stop is None
            # After the flip, the element originally at position j sits at
            # position n - 1 - j, which is what index ~j == -(j + 1) addresses
            # (and a negative j == -k maps to k - 1, again the same element).
            # So each bound is complemented, an omitted bound stays omitted,
            # and the step changes sign.
            start = None if s.start is None else ~s.start
            stop = None if s.stop is None else ~s.stop

            return slice(start, stop, -s.step)

        if isinstance(index, Sequence):
            index = type(index)(neg_step(i, s) for i, s in enumerate(index))
        else:
            index = neg_step(0, index)
        index = _util.ndarrays_to_tensors(index)
        index = _upcast_int_indices(index)
        return ndarray(tensor.__getitem__(index))

    def __setitem__(self, index, value):
        index = _util.ndarrays_to_tensors(index)
        index = _upcast_int_indices(index)

        # Non-scalar values are converted and cast to this array's dtype.
        if not _dtypes_impl.is_scalar(value):
            value = normalize_array_like(value)
            value = _util.cast_if_needed(value, self.tensor.dtype)

        return self.tensor.__setitem__(index, value)

    take = _funcs.take
    put = _funcs.put

    def __dlpack__(self, *, stream=None):
        return self.tensor.__dlpack__(stream=stream)

    def __dlpack_device__(self):
        return self.tensor.__dlpack_device__()
| |
def _tolist(obj):
    """Return `obj` as a list, recursing into nested lists/tuples.

    Nested lists and tuples become lists, ``ndarray`` elements are replaced by
    the ``tolist()`` of their tensor, and all other elements are kept as they are.
    """

    def convert(elem):
        if isinstance(elem, (list, tuple)):
            return _tolist(elem)
        if isinstance(elem, ndarray):
            return elem.tensor.tolist()
        return elem

    return [convert(elem) for elem in obj]
| |
| |
# Ideally, this is the only place that talks to ndarray directly.
# Everything else goes through asarray (preferred) or array.
| |
| |
def array(obj, dtype=None, *, copy=True, order="K", subok=False, ndmin=0, like=None):
    """Build an ndarray from `obj`.

    `obj` may already be an ndarray, a list/tuple (possibly containing tensors
    or ndarrays), or anything ``_util._coerce_to_tensor`` accepts.

    Raises:
        NotImplementedError: if `subok` is not False, `like` is not None, or
            `order` is not ``"K"``.
    """
    if subok is not False:
        raise NotImplementedError("'subok' parameter is not supported.")
    if like is not None:
        raise NotImplementedError("'like' parameter is not supported.")
    if order != "K":
        raise NotImplementedError

    given_ndarray = isinstance(obj, ndarray)

    # a happy path: nothing to convert, hand the very same object back
    if given_ndarray and copy is False and dtype is None and ndmin <= obj.ndim:
        return obj

    if given_ndarray:
        # unwrap to the underlying tensor
        obj = obj.tensor
    elif isinstance(obj, (list, tuple)):
        # FIXME and they have the same dtype, device, etc
        if obj and all(isinstance(x, torch.Tensor) for x in obj):
            # list of arrays: *under torch.Dynamo* these are FakeTensors
            obj = torch.stack(obj)
        else:
            # XXX: remove tolist
            # lists of ndarrays: [1, [2, 3], ndarray(4)] convert to lists of lists
            obj = _tolist(obj)

    # is a specific dtype requested?
    torch_dtype = None if dtype is None else _dtypes.dtype(dtype).torch_dtype

    return ndarray(_util._coerce_to_tensor(obj, torch_dtype, copy, ndmin))
| |
| |
def asarray(a, dtype=None, order="K", *, like=None):
    """Convert `a` to an ndarray via ``array`` with ``copy=False`` and ``ndmin=0``."""
    return array(
        a,
        dtype=dtype,
        copy=False,
        order=order,
        ndmin=0,
        like=like,
    )
| |
| |
def ascontiguousarray(a, dtype=None, *, like=None):
    """Return `a` as an ndarray whose tensor is contiguous.

    If the converted array is already contiguous it is returned as is;
    otherwise a new ndarray wrapping a contiguous copy of the data is returned.
    """
    arr = asarray(a, dtype=dtype, like=like)
    if arr.tensor.is_contiguous():
        return arr
    # `asarray` can hand back the caller's own ndarray object (the no-copy
    # path of `array`), so its `.tensor` must not be reassigned here: doing so
    # would silently modify the input. Wrap the contiguous copy instead.
    return ndarray(arr.tensor.contiguous())
| |
| |
def from_dlpack(x, /):
    """Wrap the tensor produced by ``torch.from_dlpack(x)`` in an ndarray."""
    return ndarray(torch.from_dlpack(x))
| |
| |
def _extract_dtype(entry):
    """Interpret `entry` as a dtype, or else return the dtype of ``asarray(entry)``."""
    try:
        return _dtypes.dtype(entry)
    except Exception:
        # Not a dtype specification: treat it as array-like data instead.
        return asarray(entry).dtype
| |
| |
def can_cast(from_, to, casting="safe"):
    """Report whether `from_` can be cast to `to` under the `casting` rule.

    Both arguments may be dtypes or array-likes; they are resolved with
    ``_extract_dtype`` and the decision is delegated to
    ``_dtypes_impl.can_cast_impl``.
    """
    source_dtype = _extract_dtype(from_)
    target_dtype = _extract_dtype(to)
    return _dtypes_impl.can_cast_impl(
        source_dtype.torch_dtype,
        target_dtype.torch_dtype,
        casting,
    )
| |
| |
def result_type(*arrays_and_dtypes):
    """Return the dtype obtained by combining the given arrays and dtypes.

    Each argument is first tried as array-like data; if that conversion fails
    it is interpreted as a dtype and represented by a one-element tensor of
    that dtype. The result comes from ``_dtypes_impl.result_type_impl``.
    """

    def as_tensor(entry):
        try:
            return asarray(entry).tensor
        except (RuntimeError, ValueError, TypeError):
            # Not array-like: stand in with a tensor of the requested dtype.
            return torch.empty(1, dtype=_dtypes.dtype(entry).torch_dtype)

    tensors = [as_tensor(entry) for entry in arrays_and_dtypes]
    return _dtypes.dtype(_dtypes_impl.result_type_impl(*tensors))