import io

import torch
from ._utils import _type, _cuda
from torch.types import Storage
from typing import Any, TypeVar, Type, Union, cast
import copy
import collections
from functools import lru_cache
try:
    import numpy as np
    HAS_NUMPY = True
except ModuleNotFoundError:
    np = None  # type: ignore[assignment]
    HAS_NUMPY = False

T = TypeVar('T', bound='Union[_StorageBase, _TypedStorage]')
class _StorageBase(object):
    _cdata: Any
    is_cuda: bool = False
    is_sparse: bool = False
    is_sparse_csr: bool = False
    device: torch.device

    def __init__(self, *args, **kwargs): ...  # noqa: E704
    def __len__(self) -> int: ...  # noqa: E704
    def __getitem__(self, idx): ...  # noqa: E704
    def copy_(self, source: T, non_blocking: bool = None) -> T: ...  # noqa: E704
    def nbytes(self) -> int: ...  # noqa: E704

    def size(self) -> int:
        return self.nbytes()

    def type(self, dtype: str = None, non_blocking: bool = False) -> T: ...  # noqa: E704
    def cuda(self, device=None, non_blocking=False, **kwargs) -> T: ...  # noqa: E704
    def element_size(self) -> int: ...  # noqa: E704
    def get_device(self) -> int: ...  # noqa: E704
    def data_ptr(self) -> int: ...  # noqa: E704

    # Defined in torch/csrc/generic/StorageSharing.cpp
    def _share_filename_cpu_(self, *args, **kwargs): ...  # noqa: E704
    def _share_fd_cpu_(self, *args, **kwargs): ...  # noqa: E704
    @classmethod
    def _new_using_filename_cpu(cls: Type[T], size: int) -> T: ...  # noqa: E704
    @classmethod
    def _new_using_fd_cpu(cls: Type[T], size: int) -> T: ...  # noqa: E704
    @classmethod
    def from_buffer(cls, *args, **kwargs) -> T: ...  # noqa: E704
    @classmethod
    def _new_shared_filename_cpu(cls, manager, obj, size, *, device=None, dtype=None) -> T: ...  # noqa: E704
    @classmethod
    def _release_ipc_counter_cuda(cls, *args, **kwargs) -> T: ...  # noqa: E704
    @classmethod
    def _new_with_weak_ptr(cls, *args, **kwargs) -> T: ...  # noqa: E704
    def _shared_decref(self) -> T: ...  # noqa: E704
    def _write_file(self, *args, **kwargs): ...  # noqa: E704
    def resize_(self, size: int): ...  # noqa: E704
    def _weak_ref(self, *args, **kwargs) -> T: ...  # noqa: E704
    def is_pinned(self) -> bool: ...  # noqa: E704
    def _set_from_file(self, *args, **kwargs): ...  # noqa: E704
    def _set_cdata(self, *args, **kwargs): ...  # noqa: E704
    def _share_cuda_(self, *args, **kwargs): ...  # noqa: E704
    def is_shared(self) -> bool: ...  # noqa: E704
    @classmethod
    def _new_shared_cuda(cls, *args, **kwargs) -> T: ...  # noqa: E704
    def _shared_incref(self, *args, **kwargs): ...  # noqa: E704
    @classmethod
    def _free_weak_ref(cls, *args, **kwargs): ...  # noqa: E704

    def __str__(self):
        info_str = (
            f'[{torch.typename(self)}(device={self.device}) '
            f'of size {len(self)}]')
        if self.device.type == 'meta':
            return '...\n' + info_str
        else:
            data_str = ' ' + '\n '.join(str(self[i]) for i in range(self.size()))
            return data_str + '\n' + info_str

    def __repr__(self):
        return str(self)

    def __iter__(self):
        return iter(map(lambda i: self[i], range(self.size())))

    def __copy__(self):
        return self.clone()

    def __deepcopy__(self, memo):
        memo = memo.setdefault('torch', {})
        if self._cdata in memo:
            return memo[self._cdata]
        new_storage = self.clone()
        memo[self._cdata] = new_storage
        return new_storage

    def __reduce__(self):
        b = io.BytesIO()
        torch.save(self, b, _use_new_zipfile_serialization=False)
        return (_load_from_bytes, (b.getvalue(),))

    def __sizeof__(self):
        return super(_StorageBase, self).__sizeof__() + self.size()

    def clone(self):
        """Returns a copy of this storage"""
        return type(self)(self.nbytes(), device=self.device).copy_(self)

    def tolist(self):
        """Returns a list containing the elements of this storage"""
        return list(self)

    def cpu(self):
        """Returns a CPU copy of this storage if it's not already on the CPU"""
        if self.device.type != 'cpu':
            return torch._UntypedStorage(self.size()).copy_(self, False)
        else:
            return self

    def _to(self, dtype):
        if not isinstance(dtype, torch.dtype):
            raise TypeError(f"Argument 'dtype' must be torch.dtype, not {type(dtype)}")
        storage = torch.tensor([], dtype=torch.uint8, device=self.device).set_(cast(Storage, self)).to(dtype).storage()
        if storage.data_ptr() == self.data_ptr():
            storage = storage.clone()
        return storage

    def double(self):
        """Casts this storage to double type"""
        return self._to(torch.double)

    def float(self):
        """Casts this storage to float type"""
        return self._to(torch.float)

    def half(self):
        """Casts this storage to half type"""
        return self._to(torch.half)

    def long(self):
        """Casts this storage to long type"""
        return self._to(torch.long)

    def int(self):
        """Casts this storage to int type"""
        return self._to(torch.int)

    def short(self):
        """Casts this storage to short type"""
        return self._to(torch.short)

    def char(self):
        """Casts this storage to char type"""
        return self._to(torch.int8)

    def byte(self):
        """Casts this storage to byte type"""
        return self._to(torch.uint8)

    def bool(self):
        """Casts this storage to bool type"""
        return self._to(torch.bool)

    def bfloat16(self):
        """Casts this storage to bfloat16 type"""
        return self._to(torch.bfloat16)

    def complex_double(self):
        """Casts this storage to complex double type"""
        return self._to(torch.cdouble)

    def complex_float(self):
        """Casts this storage to complex float type"""
        return self._to(torch.cfloat)

    def pin_memory(self):
        """Copies the storage to pinned memory, if it's not already pinned."""
        if self.is_cuda:
            raise TypeError(f"cannot pin '{self.type()}'; only CPU memory can be pinned")
        import torch.cuda
        allocator = torch.cuda.memory._host_allocator()  # type: ignore[attr-defined]
        return type(self)(self.size(), allocator=allocator).copy_(self)

    def share_memory_(self):
        """Moves the storage to shared memory.

        This is a no-op for storages already in shared memory and for CUDA
        storages, which do not need to be moved for sharing across processes.
        Storages in shared memory cannot be resized.

        Returns: self
        """
        from torch.multiprocessing import get_sharing_strategy
        if self.is_cuda:
            pass  # CUDA doesn't use POSIX shared memory
        elif get_sharing_strategy() == 'file_system':
            self._share_filename_cpu_()
        else:
            self._share_fd_cpu_()
        return self
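
    # Illustrative sketch (not part of the original source): on CPU the effect
    # can be checked with is_shared(); depending on get_sharing_strategy() the
    # data is backed by a shared file descriptor or a named file. The use of
    # FloatStorage below is an assumption for the example.
    #
    #   >>> s = torch.FloatStorage(8)
    #   >>> s.share_memory_()
    #   >>> s.is_shared()
    #   True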

    @classmethod
    def _new_shared(cls, size, *, device='cpu'):
        """Creates a new storage in shared memory with the same data type"""
        from torch.multiprocessing import get_sharing_strategy
        device = torch.device(device)
        if device.type == 'cuda':
            return cls(size, device=device)
        elif get_sharing_strategy() == 'file_system':
            return cls._new_using_filename_cpu(size)
        else:
            return cls._new_using_fd_cpu(size)

    def _untyped(self):
        return self


class _UntypedStorage(torch._C.StorageBase, _StorageBase):
    def __getitem__(self, *args, **kwargs):
        if self.device.type == 'meta':
            raise NotImplementedError("Not available for 'meta' device type")
        return super().__getitem__(*args, **kwargs)


def _load_from_bytes(b):
    return torch.load(io.BytesIO(b))


_StorageBase.type = _type  # type: ignore[assignment]
_StorageBase.cuda = _cuda  # type: ignore[assignment]


@lru_cache(maxsize=None)
def _dtype_to_storage_type_map():
    # NOTE: We should no longer add dtypes to this map. This map
    # is only used for BC/FC with older PyTorch versions. Going forward,
    # new dtypes of _TypedStorage should not translate to a legacy
    # <type>Storage class. Instead, new dtypes of _TypedStorage should
    # be serialized as an _UntypedStorage paired with a torch.dtype
    return {
        torch.double: 'DoubleStorage',
        torch.float: 'FloatStorage',
        torch.half: 'HalfStorage',
        torch.long: 'LongStorage',
        torch.int: 'IntStorage',
        torch.int16: 'ShortStorage',
        torch.int8: 'CharStorage',
        torch.uint8: 'ByteStorage',
        torch.bool: 'BoolStorage',
        torch.bfloat16: 'BFloat16Storage',
        torch.cdouble: 'ComplexDoubleStorage',
        torch.cfloat: 'ComplexFloatStorage',
        torch.qint8: 'QInt8Storage',
        torch.qint32: 'QInt32Storage',
        torch.quint8: 'QUInt8Storage',
        torch.quint4x2: 'QUInt4x2Storage',
        torch.quint2x4: 'QUInt2x4Storage',
    }
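
# Illustrative note (not in the original file): this legacy mapping is the
# table that _TypedStorage.pickle_storage_type() and
# _get_dtype_from_pickle_storage_type() translate through, e.g.
#
#   >>> _dtype_to_storage_type_map()[torch.float]
#   'FloatStorage'
#   >>> _get_dtype_from_pickle_storage_type('FloatStorage')
#   torch.float32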

@lru_cache(maxsize=None)
def _storage_type_to_dtype_map():
    dtype_map = {
        val: key for key, val in _dtype_to_storage_type_map().items()}
    return dtype_map

def _get_storage_from_sequence(sequence, dtype, device):
    if dtype in [torch.quint8, torch.quint4x2, torch.quint2x4, torch.qint32, torch.qint8]:
        interpret_dtypes = {
            torch.quint8: torch.uint8,
            torch.quint4x2: torch.uint8,
            torch.quint2x4: torch.uint8,
            torch.qint32: torch.int32,
            torch.qint8: torch.int8
        }
        tmp_tensor = torch.tensor(
            sequence,
            dtype=interpret_dtypes[dtype],
            device=device)

    else:
        tmp_tensor = torch.tensor(
            sequence,
            dtype=dtype,
            device=device)

    return tmp_tensor.storage()._untyped()

def _isint(x):
    if HAS_NUMPY:
        return isinstance(x, (int, np.integer))
    else:
        return isinstance(x, int)

class _TypedStorage:
    is_sparse = False

    dtype: torch.dtype

    def fill_(self, value):
        self[0:len(self)] = value
        return self

    def __new__(cls, *args, wrap_storage=None, dtype=None, device=None):
        if cls == torch.storage._LegacyStorage:
            raise RuntimeError("Only child classes of _LegacyStorage can be instantiated")

        if cls == _TypedStorage:
            return super().__new__(cls)

        else:
            arg_error_msg = (
                f'{cls}.__new__ received an invalid combination '
                f'of arguments. Expected one of:\n'
                ' * no arguments\n'
                ' * (int size)\n'
                ' * (Sequence data)\n'
                ' * (*, _UntypedStorage wrap_storage)')

            if device is not None:
                raise RuntimeError(
                    arg_error_msg +
                    "\nKeyword argument 'device' cannot be specified")

            if dtype is not None:
                raise RuntimeError(
                    arg_error_msg +
                    "\nKeyword argument 'dtype' cannot be specified")

            if wrap_storage is None:
                if len(args) > 1:
                    raise RuntimeError(
                        arg_error_msg +
                        "\nToo many positional arguments")

                if len(args) == 1 and not _isint(args[0]) and not isinstance(args[0], collections.abc.Sequence):
                    raise TypeError(
                        arg_error_msg +
                        f"\nArgument type not recognized: {type(args[0])}")

                return _TypedStorage(
                    *args,
                    dtype=cls.dtype,
                    device='cuda' if eval(cls.__module__) is torch.cuda else 'cpu')

            else:
                if len(args) != 0:
                    raise RuntimeError(
                        arg_error_msg +
                        "\nNo positional arguments should be given when using "
                        "'wrap_storage'")

                if not isinstance(wrap_storage, torch._UntypedStorage):
                    raise TypeError(
                        arg_error_msg +
                        f"\nArgument 'wrap_storage' must be _UntypedStorage, but got {type(wrap_storage)}")

                cls_device = 'cuda' if cls.__module__ == 'torch.cuda' else 'cpu'

                if wrap_storage.device.type != cls_device:
                    raise RuntimeError(
                        arg_error_msg +
                        f"\nDevice of 'wrap_storage' must be {cls_device}"
                        f", but got {wrap_storage.device.type}")

                return _TypedStorage(
                    *args,
                    wrap_storage=wrap_storage,
                    dtype=cls.dtype)

    def __init__(self, *args, device=None, dtype=None, wrap_storage=None):
        arg_error_msg = (
            '_TypedStorage.__init__ received an invalid combination '
            'of arguments. Expected one of:\n'
            ' * (*, torch.device device, torch.dtype dtype)\n'
            ' * (int size, *, torch.device device, torch.dtype dtype)\n'
            ' * (Sequence data, *, torch.device device, torch.dtype dtype)\n'
            ' * (*, _UntypedStorage wrap_storage, torch.dtype dtype)')

        if wrap_storage is not None:
            if len(args) != 0:
                raise RuntimeError(
                    arg_error_msg +
                    "\nNo positional arguments should be given when using "
                    "'wrap_storage'")

            if dtype is None:
                raise RuntimeError(
                    arg_error_msg +
                    "\nArgument 'dtype' must be specified")

            if not isinstance(dtype, torch.dtype):
                raise TypeError(
                    arg_error_msg +
                    f"\nArgument 'dtype' must be torch.dtype, not {type(dtype)}")

            if device is not None:
                raise RuntimeError(
                    arg_error_msg +
                    "\nArgument 'device' should not be specified when 'wrap_storage' is given")

            self.dtype = dtype

            if not isinstance(wrap_storage, torch._UntypedStorage):
                raise TypeError(
                    arg_error_msg +
                    f"\nArgument 'wrap_storage' must be _UntypedStorage, but got {type(wrap_storage)}")

            self._storage = wrap_storage

        else:
            self.dtype = torch.get_default_dtype() if dtype is None else dtype
            device = torch.device('cpu' if device is None else device)

            if self.dtype in [torch.quint8, torch.quint4x2, torch.quint2x4, torch.qint32, torch.qint8]:
                if device.type == 'cuda':
                    raise RuntimeError("Cannot create CUDA storage with quantized dtype")

            if len(args) == 0:
                self._storage = torch._UntypedStorage(device=device)

            elif len(args) == 1:
                if _isint(args[0]):
                    self._storage = torch._UntypedStorage(int(args[0]) * self.element_size(), device=device)
                elif isinstance(args[0], collections.abc.Sequence):
                    self._storage = _get_storage_from_sequence(args[0], self.dtype, device)
                else:
                    raise TypeError(
                        arg_error_msg +
                        f"\nArgument type not recognized: {type(args[0])}")

            else:
                raise RuntimeError(
                    arg_error_msg +
                    "\nToo many positional arguments")


    @property
    def is_cuda(self):
        return self._storage.device.type == 'cuda'

    def _untyped(self):
        return self._storage

    def _new_wrapped_storage(self, untyped_storage):
        assert type(untyped_storage) == torch._UntypedStorage

        if type(self) == _TypedStorage:
            return _TypedStorage(wrap_storage=untyped_storage, dtype=self.dtype)
        else:
            return type(self)(wrap_storage=untyped_storage)

    def __len__(self):
        return self._storage.nbytes() // self.element_size()
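
    # Worked example (illustrative, not from the original source): the length
    # of a typed storage is its byte count divided by the element size, so a
    # torch.float32 storage backed by 16 bytes has len() == 16 // 4 == 4.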

    def _maybe_wrap_index(self, idx, is_stop=False):
        if idx is None:
            if is_stop:
                return self.size()
            else:
                return 0

        else:
            if type(idx) != int:
                raise TypeError(
                    f"can't index a {type(self)} with {type(idx)}")
            if is_stop:
                if (idx > self.size()) or (idx < -self.size()):
                    raise IndexError(
                        f'index {idx} out of range for storage of size {self.size()}')
                if idx > 0:
                    return idx
                else:
                    return idx % self.size()
            else:
                if (idx >= self.size()) or (idx < -self.size()):
                    raise IndexError(
                        f'index {idx} out of range for storage of size {self.size()}')
                return idx % self.size()

    def __setitem__(self, idx, value):
        if not isinstance(idx, (int, slice)):
            raise RuntimeError(f"can't index a {type(self)} with {type(idx)}")
        if torch.is_storage(value):
            raise RuntimeError(f'cannot set item with value type {type(value)}')
        if self.dtype in [torch.quint8, torch.quint4x2, torch.quint2x4, torch.qint32, torch.qint8]:
            interpret_dtypes = {
                torch.quint8: torch.uint8,
                torch.quint4x2: torch.uint8,
                torch.quint2x4: torch.uint8,
                torch.qint32: torch.int32,
                torch.qint8: torch.int8
            }
            tmp_dtype = interpret_dtypes[self.dtype]
            tmp_tensor = torch.tensor([], dtype=tmp_dtype, device=self.device).set_(_TypedStorage(
                wrap_storage=self._storage,
                dtype=tmp_dtype))
        else:
            tmp_tensor = torch.tensor([], dtype=self.dtype, device=self.device).set_(self)

        tmp_tensor[idx] = value

    def __getitem__(self, idx):
        if self.device.type == 'meta':
            raise NotImplementedError("Not available for 'meta' device type")

        # NOTE: Before _TypedStorage existed, indexing with a slice used to be
        # possible for <type>Storage objects. However, it would return
        # a storage view, which would be a hassle to implement in _TypedStorage,
        # so it was disabled
        if isinstance(idx, slice):
            raise RuntimeError('slices are only supported in _UntypedStorage.__getitem__')
        elif not isinstance(idx, int):
            raise RuntimeError(f"can't index a {type(self)} with {type(idx)}")

        if self.dtype in [torch.quint8, torch.quint4x2, torch.quint2x4, torch.qint32, torch.qint8]:
            interpret_dtypes = {
                torch.quint8: torch.uint8,
                torch.quint4x2: torch.uint8,
                torch.quint2x4: torch.uint8,
                torch.qint32: torch.int32,
                torch.qint8: torch.int8
            }
            return _TypedStorage(
                wrap_storage=self._storage,
                dtype=interpret_dtypes[self.dtype])[idx]

        idx_wrapped = self._maybe_wrap_index(idx)
        tmp_tensor = torch.tensor([], dtype=self.dtype, device=self.device).set_(self)
        return tmp_tensor[idx_wrapped].item()
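
    # Illustrative sketch (not part of the original file): indexing reads one
    # element through a temporary tensor that aliases this storage, with
    # negative indices wrapped by _maybe_wrap_index, e.g.
    #
    #   >>> ts = torch.FloatStorage([1.0, 2.0, 3.0])
    #   >>> ts[-1]      # -1 wraps to index 2
    #   3.0
    #
    # Quantized dtypes are read through their integer interpretation instead.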

    def copy_(self, source: T, non_blocking: bool = None):
        self._storage.copy_(source._untyped(), non_blocking)
        return self

    def nbytes(self):
        return self._storage.nbytes()

    def type(self, dtype: str = None, non_blocking: bool = False) -> Union[T, str]:
        if dtype is None:
            legacy_class = self._get_legacy_storage_class()

            if legacy_class is not None:
                return legacy_class.__module__ + '.' + legacy_class.__name__

            return '.'.join([self.__module__, type(self).__name__])

        else:
            return self._storage.type(dtype, non_blocking)

    def cuda(self, device=None, non_blocking=False, **kwargs) -> T:
        if self.dtype in [torch.quint8, torch.quint4x2, torch.quint2x4, torch.qint32, torch.qint8]:
            raise RuntimeError("Cannot create CUDA storage with quantized dtype")
        cuda_storage: torch._UntypedStorage = self._storage.cuda(device, non_blocking, **kwargs)
        return self._new_wrapped_storage(cuda_storage)

    def element_size(self):
        return torch._utils._element_size(self.dtype)

    def get_device(self) -> int:
        return self._storage.get_device()

    def __str__(self):
        info_str = (
            f'[{torch.typename(self)}(dtype={self.dtype}, '
            f'device={self.device}) of size {len(self)}]')
        if self.device.type == 'meta':
            return '...\n' + info_str
        else:
            data_str = ' ' + '\n '.join(str(self[i]) for i in range(self.size()))
            return data_str + '\n' + info_str

    def __repr__(self):
        return str(self)

    def __iter__(self):
        return iter(map(lambda i: self[i], range(self.size())))

    def __copy__(self):
        return self._new_wrapped_storage(copy.copy(self._storage))

    def __deepcopy__(self, memo):
        return self._new_wrapped_storage(copy.deepcopy(self._storage, memo))

    def __sizeof__(self):
        return super(_TypedStorage, self).__sizeof__() + self.nbytes()

    def clone(self):
        """Returns a copy of this storage"""
        return self._new_wrapped_storage(self._storage.clone())

    def tolist(self):
        """Returns a list containing the elements of this storage"""
        return list(self)

    def cpu(self):
        """Returns a CPU copy of this storage if it's not already on the CPU"""
        return self._new_wrapped_storage(self._storage.cpu())

    def pin_memory(self):
| """Coppies the storage to pinned memory, if it's not already pinned.""" |
        return self._new_wrapped_storage(self._storage.pin_memory())

    def share_memory_(self):
        """Moves the storage to shared memory.

        This is a no-op for storages already in shared memory and for CUDA
        storages, which do not need to be moved for sharing across processes.
        Storages in shared memory cannot be resized.

        Returns: self
        """
        self._storage.share_memory_()
        return self

    def _new_shared(self, size, *, device=None):
        """Creates a new storage in shared memory with the same data type"""
        if device is None:
            device = 'cpu'
        device = torch.device(device)
        untyped_storage = torch._UntypedStorage._new_shared(size * self.element_size(), device=device)
        return _TypedStorage(
            wrap_storage=untyped_storage,
            dtype=self.dtype)

    @property
    def _cdata(self):
        return self._storage._cdata

    @property
    def device(self):
        return self._storage.device

    def size(self):
        return len(self)

    def pickle_storage_type(self):
        try:
            return _dtype_to_storage_type_map()[self.dtype]
        except KeyError:
            raise KeyError(f'dtype {self.dtype} is not recognized')

    def __reduce__(self):
        b = io.BytesIO()
        torch.save(self, b, _use_new_zipfile_serialization=False)
        return (_load_from_bytes, (b.getvalue(),))

    def data_ptr(self):
        return self._storage.data_ptr()

    def resize_(self, size):
        self._storage.resize_(size * self.element_size())

    @classmethod
    def _free_weak_ref(cls, *args, **kwargs):
        return _UntypedStorage._free_weak_ref(*args, **kwargs)

    def _weak_ref(self, *args, **kwargs):
        return self._storage._weak_ref(*args, **kwargs)

    @classmethod
    def from_buffer(cls, *args, dtype=None, device=None, **kwargs):
        if cls == _TypedStorage:
            dtype = torch.get_default_dtype() if dtype is None else dtype
            device = torch.device('cpu' if device is None else device)
            if device.type != 'cpu':
                raise RuntimeError(f'_TypedStorage.from_buffer: Not available for device {device.type}')
            untyped_storage: torch._UntypedStorage = torch._UntypedStorage.from_buffer(*args, dtype=dtype, **kwargs)

        else:
            if dtype is not None or len(args) == 5:
                raise RuntimeError((
                    "from_buffer: 'dtype' can only be specified in "
                    "_UntypedStorage.from_buffer and _TypedStorage.from_buffer"))
            if device is not None:
                raise RuntimeError((
                    "from_buffer: 'device' can only be specified in "
                    "_UntypedStorage.from_buffer and _TypedStorage.from_buffer"))

            dtype = cls.dtype
            untyped_storage = torch._UntypedStorage.from_buffer(*args, dtype=dtype, **kwargs)

        return _TypedStorage(wrap_storage=untyped_storage, dtype=dtype)

    def _to(self, dtype):
        if not isinstance(dtype, torch.dtype):
            raise TypeError(f"Argument 'dtype' must be torch.dtype, not {type(dtype)}")
        storage = torch.tensor([], dtype=self.dtype, device=self.device).set_(self).to(dtype).storage()
        if storage.data_ptr() == self.data_ptr():
            storage = storage.clone()
        return storage

    def double(self):
        """Casts this storage to double type"""
        return self._to(torch.double)

    def float(self):
        """Casts this storage to float type"""
        return self._to(torch.float)

    def half(self):
        """Casts this storage to half type"""
        return self._to(torch.half)

    def long(self):
        """Casts this storage to long type"""
        return self._to(torch.long)

    def int(self):
        """Casts this storage to int type"""
        return self._to(torch.int)

    def short(self):
        """Casts this storage to short type"""
        return self._to(torch.short)

    def char(self):
        """Casts this storage to char type"""
        return self._to(torch.int8)

    def byte(self):
        """Casts this storage to byte type"""
        return self._to(torch.uint8)

    def bool(self):
        """Casts this storage to bool type"""
        return self._to(torch.bool)

    def bfloat16(self):
        """Casts this storage to bfloat16 type"""
        return self._to(torch.bfloat16)

    def complex_double(self):
        """Casts this storage to complex double type"""
        return self._to(torch.cdouble)

    def complex_float(self):
        """Casts this storage to complex float type"""
        return self._to(torch.cfloat)

    @classmethod
    def from_file(cls, filename, shared, size):
        """
        from_file(filename, shared=False, size=0) -> Storage

        If `shared` is `True`, then memory is shared between all processes.
        All changes are written to the file. If `shared` is `False`, then the changes on
        the storage do not affect the file.

        `size` is the number of elements in the storage. If `shared` is `False`,
        then the file must contain at least `size * sizeof(Type)` bytes
        (`Type` is the type of storage). If `shared` is `True` the file will be
        created if needed.

        Args:
            filename (str): file name to map
            shared (bool): whether to share memory
            size (int): number of elements in the storage
        """
        if cls == _TypedStorage:
            raise RuntimeError('from_file can only be called on derived classes')
        untyped_storage = eval(cls.__module__)._UntypedStorage.from_file(
            filename,
            shared,
            size * torch._utils._element_size(cls.dtype))
        storage = cls(wrap_storage=untyped_storage)
        return storage
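
    # Illustrative sketch (not in the original source; the file name and the
    # use of FloatStorage are assumptions): mapping a file as a shared CPU
    # storage of 100 float32 elements, i.e. 400 bytes, might look like
    #
    #   >>> s = torch.FloatStorage.from_file('tensors.bin', shared=True, size=100)
    #   >>> len(s)
    #   100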

    @classmethod
    def _expired(cls, *args, **kwargs):
        return eval(cls.__module__)._UntypedStorage._expired(*args, **kwargs)

    def is_pinned(self):
        return self._storage.is_pinned()

    def _write_file(self, *args, **kwargs):
        return self._storage._write_file(*args, **kwargs)

    def _set_from_file(self, *args, **kwargs):
        return self._storage._set_from_file(*args, **kwargs)

    def _set_cdata(self, *args, **kwargs):
        return self._storage._set_cdata(*args, **kwargs)

    def _share_cuda_(self, *args, **kwargs):
        return self._storage._share_cuda_(*args, **kwargs)

    def is_shared(self):
        return self._storage.is_shared()

    @classmethod
    def _new_shared_cuda(cls, *args, **kwargs):
        return torch._UntypedStorage._new_shared_cuda(*args, **kwargs)

    def _share_filename_cpu_(self, *args, **kwargs):
        manager_handle, storage_handle, size = self._storage._share_filename_cpu_(*args, **kwargs)
        return manager_handle, storage_handle, size // self.element_size()

    def _shared_decref(self):
        self._storage._shared_decref()
        return self

    @classmethod
    def _release_ipc_counter(cls, *args, device=None, **kwargs):
        return torch._UntypedStorage._release_ipc_counter_cuda(*args, **kwargs)

    def _shared_incref(self, *args, **kwargs):
        return self._storage._shared_incref(*args, **kwargs)

    def _share_fd_cpu_(self, *args, **kwargs):
        fd, size = self._storage._share_fd_cpu_(*args, **kwargs)
        return fd, size // self.element_size()

    def _get_legacy_storage_class(self):
        if self.dtype not in _dtype_to_storage_type_map():
            return None

        storage_name = _dtype_to_storage_type_map()[self.dtype]

        if self.device.type not in ['cpu', 'cuda']:
            return None

        module = 'torch.' if self.device.type == 'cpu' else 'torch.cuda.'

        try:
            return eval(module + storage_name)
        except AttributeError:
            return None

_TypedStorage.type.__doc__ = _type.__doc__
_TypedStorage.cuda.__doc__ = _cuda.__doc__

class _LegacyStorageMeta(type):
    dtype: torch.dtype

    def __instancecheck__(cls, instance):
        if type(instance) == _TypedStorage:
            cls_device = 'cuda' if cls.__module__ == 'torch.cuda' else 'cpu'
            return (cls_device == instance.device.type) and (cls.dtype == instance.dtype)
        return False
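
    # Illustrative note (not in the original file): because the legacy classes
    # construct plain _TypedStorage objects in __new__, this metaclass is what
    # keeps isinstance() working, e.g. isinstance(torch.FloatStorage(4),
    # torch.FloatStorage) is True even though type(torch.FloatStorage(4))
    # is _TypedStorage.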

class _LegacyStorage(_TypedStorage, metaclass=_LegacyStorageMeta):
    @classmethod
    def _new_shared(cls, size):
        """Creates a new storage in shared memory with the same data type"""
        untyped_storage = torch._UntypedStorage._new_shared(size * cls().element_size())
        return cls(wrap_storage=untyped_storage)

    @classmethod
    def _release_ipc_counter(cls, *args, **kwargs):
        return torch._UntypedStorage._release_ipc_counter_cuda(*args, **kwargs)

    @classmethod
    def _new_shared_filename(cls, manager, obj, size):
        bytes_size = size * torch._utils._element_size(cls.dtype)
        return cls(wrap_storage=torch._UntypedStorage._new_shared_filename_cpu(manager, obj, bytes_size))

def _get_dtype_from_pickle_storage_type(pickle_storage_type: str):
    try:
        return _storage_type_to_dtype_map()[pickle_storage_type]
    except KeyError:
        raise KeyError(
            f'pickle storage type "{pickle_storage_type}" is not recognized')