import torch
from torch import Tensor

from typing import Iterator, Iterable, Optional, Sequence, List, TypeVar, Generic, Sized, Union

__all__ = [
    "BatchSampler",
    "RandomSampler",
    "Sampler",
    "SequentialSampler",
    "SubsetRandomSampler",
    "WeightedRandomSampler",
]

T_co = TypeVar('T_co', covariant=True)


class Sampler(Generic[T_co]):
    r"""Base class for all Samplers.

    Every Sampler subclass has to provide an :meth:`__iter__` method, providing a
    way to iterate over indices of dataset elements, and may provide a
    :meth:`__len__` method that returns the length of the returned iterator.

    .. note:: The :meth:`__len__` method isn't strictly required by
              :class:`~torch.utils.data.DataLoader`, but is expected in any
              calculation involving the length of a :class:`~torch.utils.data.DataLoader`.
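
    Example:
        A minimal custom sampler (an illustrative sketch, not part of this
        module) that yields indices in reverse order::

            class ReversedSampler(Sampler[int]):
                def __init__(self, data_source: Sized) -> None:
                    self.data_source = data_source

                def __iter__(self) -> Iterator[int]:
                    return iter(range(len(self.data_source) - 1, -1, -1))

                def __len__(self) -> int:
                    return len(self.data_source)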
| """ |
| |
| def __init__(self, data_source: Optional[Sized]) -> None: |
| pass |
| |
| def __iter__(self) -> Iterator[T_co]: |
| raise NotImplementedError |
| |
| # NOTE [ Lack of Default `__len__` in Python Abstract Base Classes ] |
| # |
| # Many times we have an abstract class representing a collection/iterable of |
| # data, e.g., `torch.utils.data.Sampler`, with its subclasses optionally |
| # implementing a `__len__` method. In such cases, we must make sure to not |
| # provide a default implementation, because both straightforward default |
| # implementations have their issues: |
| # |
| # + `return NotImplemented`: |
| # Calling `len(subclass_instance)` raises: |
| # TypeError: 'NotImplementedType' object cannot be interpreted as an integer |
| # |
    #   + `raise NotImplementedError()`:
    #     This prevents triggering some fallback behavior. E.g., the built-in
    #     `list(X)` tries to call `len(X)` first, and executes a different code
    #     path if the method is not found or `NotImplemented` is returned, while
    #     raising a `NotImplementedError` will propagate and make the call fail
    #     where it could have used `__iter__` to complete the call.
    #
    # Thus, the only two sensible things to do are
    #
    #   + **not** provide a default `__len__`.
    #
    #   + raise a `TypeError` instead, which is what Python uses when users call
    #     a method that is not defined on an object.
    #     (@ssnl verifies that this works on at least Python 3.7.)
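    #
    # As a quick illustration of the fallback behavior described above (a
    # sketch, not part of this module's API): an iterable without `__len__`
    # still works with `list()`, while `len()` raises a `TypeError`:
    #
    #     class NoLen:
    #         def __iter__(self):
    #             return iter([1, 2, 3])
    #
    #     list(NoLen())  # -> [1, 2, 3]; falls back to `__iter__`
    #     len(NoLen())   # -> TypeError: object of type 'NoLen' has no len()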


class SequentialSampler(Sampler[int]):
    r"""Samples elements sequentially, always in the same order.

    Args:
        data_source (Dataset): dataset to sample from
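
    Example:
        >>> list(SequentialSampler(range(5)))
        [0, 1, 2, 3, 4]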
| """ |
| data_source: Sized |
| |
| def __init__(self, data_source: Sized) -> None: |
| self.data_source = data_source |
| |
| def __iter__(self) -> Iterator[int]: |
| return iter(range(len(self.data_source))) |
| |
| def __len__(self) -> int: |
| return len(self.data_source) |
| |
| |
class RandomSampler(Sampler[int]):
    r"""Samples elements randomly. If without replacement, then sample from a shuffled dataset.
    If with replacement, then the user can specify :attr:`num_samples` to draw.

    Args:
        data_source (Dataset): dataset to sample from
        replacement (bool): samples are drawn on-demand with replacement if ``True``, default=``False``
        num_samples (int): number of samples to draw, default=``len(dataset)``.
        generator (Generator): Generator used in sampling.
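
    Example:
        >>> # xdoctest: +IGNORE_WANT("non-deterministic")
        >>> list(RandomSampler(range(5)))
        [3, 0, 4, 1, 2]
        >>> list(RandomSampler(range(5), replacement=True, num_samples=8))
        [1, 4, 4, 0, 2, 3, 1, 4]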
| """ |
| data_source: Sized |
| replacement: bool |
| |
| def __init__(self, data_source: Sized, replacement: bool = False, |
| num_samples: Optional[int] = None, generator=None) -> None: |
| self.data_source = data_source |
| self.replacement = replacement |
| self._num_samples = num_samples |
| self.generator = generator |
| |
| if not isinstance(self.replacement, bool): |
| raise TypeError("replacement should be a boolean value, but got " |
| "replacement={}".format(self.replacement)) |
| |
        if not isinstance(self.num_samples, int) or isinstance(self.num_samples, bool) or \
                self.num_samples <= 0:
            raise ValueError("num_samples should be a positive integer "
                             "value, but got num_samples={}".format(self.num_samples))

    @property
    def num_samples(self) -> int:
        # dataset size might change at runtime
        if self._num_samples is None:
            return len(self.data_source)
        return self._num_samples

    def __iter__(self) -> Iterator[int]:
        n = len(self.data_source)
        if self.generator is None:
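            # When no generator is supplied, seed a fresh one from the global
            # RNG so that each call to `__iter__` (e.g., each epoch) uses new
            # randomness.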
            seed = int(torch.empty((), dtype=torch.int64).random_().item())
            generator = torch.Generator()
            generator.manual_seed(seed)
        else:
            generator = self.generator

        if self.replacement:
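            # Draw indices on demand in fixed chunks of 32, batching RNG calls
            # without materializing all `num_samples` indices at once.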
            for _ in range(self.num_samples // 32):
                yield from torch.randint(high=n, size=(32,), dtype=torch.int64, generator=generator).tolist()
            yield from torch.randint(high=n, size=(self.num_samples % 32,), dtype=torch.int64, generator=generator).tolist()
        else:
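            # Yield full permutations until fewer than `n` samples remain, then
            # a truncated permutation covers the remainder.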
            for _ in range(self.num_samples // n):
                yield from torch.randperm(n, generator=generator).tolist()
            yield from torch.randperm(n, generator=generator).tolist()[:self.num_samples % n]

    def __len__(self) -> int:
        return self.num_samples


class SubsetRandomSampler(Sampler[int]):
    r"""Samples elements randomly from a given list of indices, without replacement.

    Args:
        indices (sequence): a sequence of indices
        generator (Generator): Generator used in sampling.
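
    Example:
        >>> # xdoctest: +IGNORE_WANT("non-deterministic")
        >>> list(SubsetRandomSampler([3, 5, 7, 9]))
        [7, 3, 9, 5]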
| """ |
| indices: Sequence[int] |
| |
| def __init__(self, indices: Sequence[int], generator=None) -> None: |
| self.indices = indices |
| self.generator = generator |
| |
| def __iter__(self) -> Iterator[int]: |
| for i in torch.randperm(len(self.indices), generator=self.generator): |
| yield self.indices[i] |
| |
| def __len__(self) -> int: |
| return len(self.indices) |
| |
| |
| class WeightedRandomSampler(Sampler[int]): |
| r"""Samples elements from ``[0,..,len(weights)-1]`` with given probabilities (weights). |
| |
| Args: |
        weights (sequence): a sequence of weights, not necessarily summing up to one
        num_samples (int): number of samples to draw
        replacement (bool): if ``True``, samples are drawn with replacement.
            If not, they are drawn without replacement, which means that when a
            sample index is drawn for a row, it cannot be drawn again for that row.
        generator (Generator): Generator used in sampling.

    Example:
        >>> # xdoctest: +IGNORE_WANT("non-deterministic")
        >>> list(WeightedRandomSampler([0.1, 0.9, 0.4, 0.7, 3.0, 0.6], 5, replacement=True))
        [4, 4, 1, 4, 5]
        >>> list(WeightedRandomSampler([0.9, 0.4, 0.05, 0.2, 0.3, 0.1], 5, replacement=False))
        [0, 1, 4, 3, 2]
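
        A common pattern (an illustrative sketch, not part of this API) is to
        weight each sample inversely to its class frequency, so that a
        class-imbalanced dataset is rebalanced in expectation:

        >>> import torch
        >>> labels = torch.tensor([0, 0, 0, 1])
        >>> sample_weights = 1.0 / torch.bincount(labels)[labels]
        >>> sampler = WeightedRandomSampler(sample_weights, num_samples=len(labels))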
| """ |
| weights: Tensor |
| num_samples: int |
| replacement: bool |
| |
| def __init__(self, weights: Sequence[float], num_samples: int, |
| replacement: bool = True, generator=None) -> None: |
| if not isinstance(num_samples, int) or isinstance(num_samples, bool) or \ |
| num_samples <= 0: |
| raise ValueError("num_samples should be a positive integer " |
| "value, but got num_samples={}".format(num_samples)) |
| if not isinstance(replacement, bool): |
| raise ValueError("replacement should be a boolean value, but got " |
| "replacement={}".format(replacement)) |
| |
| weights_tensor = torch.as_tensor(weights, dtype=torch.double) |
| if len(weights_tensor.shape) != 1: |
| raise ValueError("weights should be a 1d sequence but given " |
| "weights have shape {}".format(tuple(weights_tensor.shape))) |
| |
| self.weights = weights_tensor |
| self.num_samples = num_samples |
| self.replacement = replacement |
| self.generator = generator |
| |
| def __iter__(self) -> Iterator[int]: |
| rand_tensor = torch.multinomial(self.weights, self.num_samples, self.replacement, generator=self.generator) |
| yield from iter(rand_tensor.tolist()) |
| |
| def __len__(self) -> int: |
| return self.num_samples |
| |
| |
| class BatchSampler(Sampler[List[int]]): |
| r"""Wraps another sampler to yield a mini-batch of indices. |
| |
| Args: |
        sampler (Sampler or Iterable): Base sampler. Can be any iterable object.
        batch_size (int): Size of mini-batch.
        drop_last (bool): If ``True``, the sampler will drop the last batch if
            its size would be less than ``batch_size``.

    Example:
        >>> list(BatchSampler(SequentialSampler(range(10)), batch_size=3, drop_last=False))
        [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]]
        >>> list(BatchSampler(SequentialSampler(range(10)), batch_size=3, drop_last=True))
        [[0, 1, 2], [3, 4, 5], [6, 7, 8]]
    """

    def __init__(self, sampler: Union[Sampler[int], Iterable[int]], batch_size: int, drop_last: bool) -> None:
        # Since collections.abc.Iterable does not check for `__getitem__`, which
        # is one way for an object to be an iterable, we don't do an `isinstance`
        # check here.
        if not isinstance(batch_size, int) or isinstance(batch_size, bool) or \
                batch_size <= 0:
            raise ValueError("batch_size should be a positive integer value, "
                             "but got batch_size={}".format(batch_size))
        if not isinstance(drop_last, bool):
            raise ValueError("drop_last should be a boolean value, but got "
                             "drop_last={}".format(drop_last))
        self.sampler = sampler
        self.batch_size = batch_size
        self.drop_last = drop_last

    def __iter__(self) -> Iterator[List[int]]:
        # Implemented based on the benchmarking in https://github.com/pytorch/pytorch/pull/76951
        if self.drop_last:
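            # Pull exactly `batch_size` items per batch; the StopIteration from
            # an exhausted iterator ends the while-loop, silently discarding
            # any final short batch.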
            sampler_iter = iter(self.sampler)
            while True:
                try:
                    batch = [next(sampler_iter) for _ in range(self.batch_size)]
                    yield batch
                except StopIteration:
                    break
        else:
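            # Fill a preallocated list and flush it each time it is full; any
            # final partial batch is yielded after the loop.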
            batch = [0] * self.batch_size
            idx_in_batch = 0
            for idx in self.sampler:
                batch[idx_in_batch] = idx
                idx_in_batch += 1
                if idx_in_batch == self.batch_size:
                    yield batch
                    idx_in_batch = 0
                    batch = [0] * self.batch_size
            if idx_in_batch > 0:
                yield batch[:idx_in_batch]

    def __len__(self) -> int:
        # Can only be called if self.sampler has __len__ implemented
        # We cannot enforce this condition, so we turn off typechecking for the
        # implementation below.
        # Somewhat related: see NOTE [ Lack of Default `__len__` in Python Abstract Base Classes ]
        if self.drop_last:
            return len(self.sampler) // self.batch_size  # type: ignore[arg-type]
        else:
            return (len(self.sampler) + self.batch_size - 1) // self.batch_size  # type: ignore[arg-type]