| import bisect |
| import warnings |
| import math |
| from typing import ( |
| Generic, |
| Iterable, |
| Iterator, |
| List, |
| Optional, |
| Sequence, |
| Tuple, |
| TypeVar, |
| Union, |
| Dict |
| ) |
| |
| # No 'default_generator' in torch/__init__.pyi |
| from torch import default_generator, randperm |
| from torch._utils import _accumulate |
| |
| from ... import Generator, Tensor |
| |
# Public names exported by a star-import of this module. Every entry is a
# class or function defined further down in this file.
__all__ = [
    "Dataset",
    "IterableDataset",
    "TensorDataset",
    "StackDataset",
    "ConcatDataset",
    "ChainDataset",
    "Subset",
    "random_split",
]
| |
# Covariant type variable: the sample type a dataset produces.
T_co = TypeVar('T_co', covariant=True)
# Plain (invariant) type variable; used in the signature of ``random_split``.
T = TypeVar('T')
# The two container shapes a ``StackDataset`` sample can take.
T_dict = Dict[str, T_co]
T_tuple = Tuple[T_co, ...]
# Constrained type variable (tuple-like or dict-like) parameterizing ``StackDataset``.
T_stack = TypeVar('T_stack', T_tuple, T_dict)
| |
| |
class Dataset(Generic[T_co]):
    r"""Abstract base class for map-style datasets.

    Any dataset that maps keys to data samples should derive from this class.
    Subclasses must override :meth:`__getitem__` so that a sample can be
    fetched for a given key. They may additionally override :meth:`__len__`,
    which many :class:`~torch.utils.data.Sampler` implementations and the
    default options of :class:`~torch.utils.data.DataLoader` expect to return
    the size of the dataset. They may also implement :meth:`__getitems__` to
    speed up batched sample loading; that method takes a list of sample
    indices for a batch and returns the list of samples.

    .. note::
      :class:`~torch.utils.data.DataLoader` by default constructs an index
      sampler that yields integral indices. To make it work with a map-style
      dataset with non-integral indices/keys, a custom sampler must be provided.
    """

    def __getitem__(self, index) -> T_co:
        message = "Subclasses of Dataset should implement __getitem__."
        raise NotImplementedError(message)

    # ``__getitems__(self, indices: List) -> List[T_co]`` is intentionally not
    # defined here: a default would cause false positives in the fetcher check
    # in torch.utils.data._utils.fetch._MapDatasetFetcher.

    def __add__(self, other: 'Dataset[T_co]') -> 'ConcatDataset[T_co]':
        # ``a + b`` is shorthand for concatenating the two datasets.
        combined = ConcatDataset([self, other])
        return combined

    # There is deliberately no default ``__len__``.
    # See NOTE [ Lack of Default `__len__` in Python Abstract Base Classes ]
    # in pytorch/torch/utils/data/sampler.py
| |
| |
class IterableDataset(Dataset[T_co]):
    r"""An iterable Dataset.

    All datasets that represent an iterable of data samples should subclass it.
    This form of dataset is particularly useful when data come from a stream.

    All subclasses should overwrite :meth:`__iter__`, which would return an
    iterator of samples in this dataset.

    When a subclass is used with :class:`~torch.utils.data.DataLoader`, each
    item in the dataset will be yielded from the :class:`~torch.utils.data.DataLoader`
    iterator. When :attr:`num_workers > 0`, each worker process will have a
    different copy of the dataset object, so it is often desired to configure
    each copy independently to avoid having duplicate data returned from the
    workers. :func:`~torch.utils.data.get_worker_info`, when called in a worker
    process, returns information about the worker. It can be used in either the
    dataset's :meth:`__iter__` method or the :class:`~torch.utils.data.DataLoader` 's
    :attr:`worker_init_fn` option to modify each copy's behavior.

    Example 1: splitting workload across all workers in :meth:`__iter__`::

        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_DATALOADER)
        >>> # xdoctest: +SKIP("Fails on MacOS12")
        >>> class MyIterableDataset(torch.utils.data.IterableDataset):
        ...     def __init__(self, start, end):
        ...         super().__init__()
        ...         assert end > start, "this example code only works with end > start"
        ...         self.start = start
        ...         self.end = end
        ...
        ...     def __iter__(self):
        ...         worker_info = torch.utils.data.get_worker_info()
        ...         if worker_info is None:  # single-process data loading, return the full iterator
        ...             iter_start = self.start
        ...             iter_end = self.end
        ...         else:  # in a worker process
        ...             # split workload
        ...             per_worker = int(math.ceil((self.end - self.start) / float(worker_info.num_workers)))
        ...             worker_id = worker_info.id
        ...             iter_start = self.start + worker_id * per_worker
        ...             iter_end = min(iter_start + per_worker, self.end)
        ...         return iter(range(iter_start, iter_end))
        ...
        >>> # should give same set of data as range(3, 7), i.e., [3, 4, 5, 6].
        >>> ds = MyIterableDataset(start=3, end=7)

        >>> # Single-process loading
        >>> print(list(torch.utils.data.DataLoader(ds, num_workers=0)))
        [tensor([3]), tensor([4]), tensor([5]), tensor([6])]

        >>> # xdoctest: +REQUIRES(POSIX)
        >>> # Multi-process loading with two worker processes
        >>> # Worker 0 fetched [3, 4].  Worker 1 fetched [5, 6].
        >>> # xdoctest: +IGNORE_WANT("non deterministic")
        >>> print(list(torch.utils.data.DataLoader(ds, num_workers=2)))
        [tensor([3]), tensor([5]), tensor([4]), tensor([6])]

        >>> # With even more workers
        >>> # xdoctest: +IGNORE_WANT("non deterministic")
        >>> print(list(torch.utils.data.DataLoader(ds, num_workers=12)))
        [tensor([3]), tensor([5]), tensor([4]), tensor([6])]

    Example 2: splitting workload across all workers using :attr:`worker_init_fn`::

        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_DATALOADER)
        >>> class MyIterableDataset(torch.utils.data.IterableDataset):
        ...     def __init__(self, start, end):
        ...         super().__init__()
        ...         assert end > start, "this example code only works with end > start"
        ...         self.start = start
        ...         self.end = end
        ...
        ...     def __iter__(self):
        ...         return iter(range(self.start, self.end))
        ...
        >>> # should give same set of data as range(3, 7), i.e., [3, 4, 5, 6].
        >>> ds = MyIterableDataset(start=3, end=7)

        >>> # Single-process loading
        >>> print(list(torch.utils.data.DataLoader(ds, num_workers=0)))
        [3, 4, 5, 6]
        >>>
        >>> # Directly doing multi-process loading yields duplicate data
        >>> print(list(torch.utils.data.DataLoader(ds, num_workers=2)))
        [3, 3, 4, 4, 5, 5, 6, 6]

        >>> # Define a `worker_init_fn` that configures each dataset copy differently
        >>> def worker_init_fn(worker_id):
        ...     worker_info = torch.utils.data.get_worker_info()
        ...     dataset = worker_info.dataset  # the dataset copy in this worker process
        ...     overall_start = dataset.start
        ...     overall_end = dataset.end
        ...     # configure the dataset to only process the split workload
        ...     per_worker = int(math.ceil((overall_end - overall_start) / float(worker_info.num_workers)))
        ...     worker_id = worker_info.id
        ...     dataset.start = overall_start + worker_id * per_worker
        ...     dataset.end = min(dataset.start + per_worker, overall_end)
        ...

        >>> # Multi-process loading with the custom `worker_init_fn`
        >>> # Worker 0 fetched [3, 4].  Worker 1 fetched [5, 6].
        >>> print(list(torch.utils.data.DataLoader(ds, num_workers=2, worker_init_fn=worker_init_fn)))
        [3, 5, 4, 6]

        >>> # With even more workers
        >>> print(list(torch.utils.data.DataLoader(ds, num_workers=12, worker_init_fn=worker_init_fn)))
        [3, 4, 5, 6]
    """

    def __iter__(self) -> Iterator[T_co]:
        message = "Subclasses of IterableDataset should implement __iter__."
        raise NotImplementedError(message)

    def __add__(self, other: Dataset[T_co]):
        # ``a + b`` on iterable datasets chains the two streams.
        chained = ChainDataset([self, other])
        return chained

    # There is deliberately no default ``__len__``; subclasses raise
    # `TypeError` when needed.
    # See NOTE [ Lack of Default `__len__` in Python Abstract Base Classes ]
| |
| |
class TensorDataset(Dataset[Tuple[Tensor, ...]]):
    r"""Dataset wrapping tensors.

    Each sample will be retrieved by indexing tensors along the first dimension.

    Args:
        *tensors (Tensor): tensors that have the same size of the first dimension.

    Raises:
        AssertionError: if the tensors do not all share the same size along
            the first dimension.
    """
    tensors: Tuple[Tensor, ...]

    def __init__(self, *tensors: Tensor) -> None:
        # Raised explicitly (rather than via ``assert``) so the validation is
        # not stripped when Python runs with ``-O``. ``AssertionError`` is kept
        # so existing callers that catch it continue to work.
        if any(tensors[0].size(0) != tensor.size(0) for tensor in tensors):
            raise AssertionError("Size mismatch between tensors")
        self.tensors = tensors

    def __getitem__(self, index):
        # One entry per wrapped tensor, each indexed along dimension 0.
        return tuple(tensor[index] for tensor in self.tensors)

    def __len__(self):
        # All tensors share the first-dimension size, so the first one decides.
        return self.tensors[0].size(0)
| |
| |
class StackDataset(Dataset[T_stack]):
    r"""Dataset as a stacking of multiple datasets.

    This class is useful to assemble different parts of complex input data, given as datasets.

    Example:
        >>> # xdoctest: +SKIP
        >>> images = ImageDataset()
        >>> texts = TextDataset()
        >>> tuple_stack = StackDataset(images, texts)
        >>> tuple_stack[0] == (images[0], texts[0])
        >>> dict_stack = StackDataset(image=images, text=texts)
        >>> dict_stack[0] == {'image': images[0], 'text': texts[0]}

    Args:
        *args (Dataset): Datasets for stacking returned as tuple.
        **kwargs (Dataset): Datasets for stacking returned as dict.

    Raises:
        ValueError: if both positional and keyword datasets are given, if no
            dataset is given, or if the datasets differ in length.
    """
    datasets: Union[tuple, dict]

    def __init__(self, *args: Dataset[T_co], **kwargs: Dataset[T_co]) -> None:
        if args and kwargs:
            # Note the trailing space after "or": the two literals are joined
            # by implicit concatenation.
            raise ValueError("Supported either ``tuple``- (via ``args``) or "
                             "``dict``- (via ``kwargs``) like input/output, but both types are given.")
        if not args and not kwargs:
            raise ValueError("At least one dataset should be passed")

        # The length check is identical for both input styles; only the
        # container stored on ``self.datasets`` differs.
        members = args if args else tuple(kwargs.values())
        self._length = len(members[0])  # type: ignore[arg-type]
        if any(self._length != len(dataset) for dataset in members):  # type: ignore[arg-type]
            raise ValueError("Size mismatch between datasets")
        self.datasets = args if args else kwargs

    def __getitem__(self, index):
        # The output container mirrors how the datasets were passed in.
        if isinstance(self.datasets, dict):
            return {k: dataset[index] for k, dataset in self.datasets.items()}
        return tuple(dataset[index] for dataset in self.datasets)

    def __len__(self):
        return self._length
| |
| |
class ConcatDataset(Dataset[T_co]):
    r"""Dataset as a concatenation of multiple datasets.

    This class is useful to assemble different existing datasets.

    Args:
        datasets (sequence): List of datasets to be concatenated

    Raises:
        AssertionError: if ``datasets`` is empty or contains an
            :class:`IterableDataset`.
    """
    datasets: List[Dataset[T_co]]
    cumulative_sizes: List[int]

    @staticmethod
    def cumsum(sequence):
        r"""Return the running totals of ``len(e)`` for every ``e`` in ``sequence``."""
        totals, running = [], 0
        for element in sequence:
            running += len(element)
            totals.append(running)
        return totals

    def __init__(self, datasets: Iterable[Dataset]) -> None:
        super().__init__()
        self.datasets = list(datasets)
        # The checks below are raised explicitly (rather than via ``assert``)
        # so they are not stripped under ``python -O``. ``AssertionError`` is
        # kept so callers that already catch it are unaffected.
        if len(self.datasets) == 0:  # type: ignore[arg-type]
            raise AssertionError('datasets should not be an empty iterable')
        for d in self.datasets:
            if isinstance(d, IterableDataset):
                raise AssertionError("ConcatDataset does not support IterableDataset")
        self.cumulative_sizes = self.cumsum(self.datasets)

    def __len__(self):
        # The last running total is the combined length.
        return self.cumulative_sizes[-1]

    def __getitem__(self, idx):
        if idx < 0:
            if -idx > len(self):
                raise ValueError("absolute value of index should not exceed dataset length")
            idx = len(self) + idx
        # The first dataset whose cumulative size exceeds ``idx`` owns the sample.
        dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx)
        if dataset_idx == 0:
            sample_idx = idx
        else:
            # Convert the global index to one local to the owning dataset.
            sample_idx = idx - self.cumulative_sizes[dataset_idx - 1]
        return self.datasets[dataset_idx][sample_idx]

    @property
    def cummulative_sizes(self):
        # Deprecated misspelled alias kept for backward compatibility.
        warnings.warn("cummulative_sizes attribute is renamed to "
                      "cumulative_sizes", DeprecationWarning, stacklevel=2)
        return self.cumulative_sizes
| |
| |
class ChainDataset(IterableDataset):
    r"""Dataset for chaining multiple :class:`IterableDataset` s.

    This class is useful to assemble different existing dataset streams. The
    chaining operation is done on-the-fly, so concatenating large-scale
    datasets with this class will be efficient.

    Args:
        datasets (iterable of IterableDataset): datasets to be chained together

    Raises:
        AssertionError: during iteration or ``len()``, when a member that is
            not an :class:`IterableDataset` is reached.

    .. note::
        ``datasets`` is stored as given and re-iterated on every pass, so a
        one-shot iterator (e.g. a generator) only supports a single pass.
    """
    def __init__(self, datasets: Iterable[Dataset]) -> None:
        super().__init__()
        self.datasets = datasets

    def _checked_datasets(self):
        # Lazily yield each member after validating its type. Raised
        # explicitly (rather than via ``assert``) so the check is not stripped
        # under ``python -O``; ``AssertionError`` is kept for compatibility.
        for d in self.datasets:
            if not isinstance(d, IterableDataset):
                raise AssertionError("ChainDataset only supports IterableDataset")
            yield d

    def __iter__(self):
        # Each member is validated right before its samples are streamed.
        for d in self._checked_datasets():
            yield from d

    def __len__(self):
        return sum(len(d) for d in self._checked_datasets())  # type: ignore[arg-type]
| |
| |
class Subset(Dataset[T_co]):
    r"""
    View of a dataset restricted to a chosen set of indices.

    Args:
        dataset (Dataset): The whole Dataset
        indices (sequence): Indices in the whole set selected for subset
    """
    dataset: Dataset[T_co]
    indices: Sequence[int]

    def __init__(self, dataset: Dataset[T_co], indices: Sequence[int]) -> None:
        self.indices = indices
        self.dataset = dataset

    def __getitem__(self, idx):
        # A scalar position maps to a single parent index ...
        if not isinstance(idx, list):
            return self.dataset[self.indices[idx]]
        # ... while a list of positions is translated element-wise and handed
        # to the parent dataset as one list.
        parent_indices = [self.indices[i] for i in idx]
        return self.dataset[parent_indices]

    def __getitems__(self, indices: List[int]) -> List[T_co]:
        # Use the parent's batched fetch when it offers one; otherwise fall
        # back to fetching the samples one at a time.
        # see torch.utils.data._utils.fetch._MapDatasetFetcher
        batched_fetch = getattr(self.dataset, "__getitems__", None)
        if not callable(batched_fetch):
            return [self.dataset[self.indices[idx]] for idx in indices]
        return self.dataset.__getitems__([self.indices[idx] for idx in indices])  # type: ignore[attr-defined]

    def __len__(self):
        return len(self.indices)
| |
| |
def random_split(dataset: Dataset[T], lengths: Sequence[Union[int, float]],
                 generator: Optional[Generator] = default_generator) -> List[Subset[T]]:
    r"""
    Partition a dataset at random into non-overlapping subsets of the given lengths.

    When ``lengths`` is a list of fractions summing to 1, each split size is
    first computed as floor(frac * len(dataset)). Any items left over after
    flooring are then handed out one at a time, in round-robin order starting
    with the first split, until none remain.

    Pass a seeded generator to make the split reproducible, e.g.:

    Example:
        >>> # xdoctest: +SKIP
        >>> generator1 = torch.Generator().manual_seed(42)
        >>> generator2 = torch.Generator().manual_seed(42)
        >>> random_split(range(10), [3, 7], generator=generator1)
        >>> random_split(range(30), [0.3, 0.3, 0.4], generator=generator2)

    Args:
        dataset (Dataset): Dataset to be split
        lengths (sequence): lengths or fractions of splits to be produced
        generator (Generator): Generator used for the random permutation.

    Raises:
        ValueError: if a fraction lies outside [0, 1], or if the (resolved)
            lengths do not sum to ``len(dataset)``.
    """
    requested_total = sum(lengths)
    if math.isclose(requested_total, 1) and requested_total <= 1:
        # ``lengths`` holds fractions: convert them to absolute sizes.
        floor_sizes: List[int] = []
        for position, fraction in enumerate(lengths):
            if fraction > 1 or fraction < 0:
                raise ValueError(f"Fraction at index {position} is not between 0 and 1")
            floored = int(math.floor(len(dataset) * fraction))  # type: ignore[arg-type]
            floor_sizes.append(floored)

        # Flooring may leave a few items unassigned; give them out one per
        # split, cycling from the first split.
        leftover = len(dataset) - sum(floor_sizes)  # type: ignore[arg-type]
        for extra in range(leftover):
            floor_sizes[extra % len(floor_sizes)] += 1

        lengths = floor_sizes
        for position, size in enumerate(lengths):
            if size == 0:
                warnings.warn(f"Length of split at index {position} is 0. "
                              f"This might result in an empty dataset.")

    # Cannot verify that dataset is Sized
    if sum(lengths) != len(dataset):    # type: ignore[arg-type]
        raise ValueError("Sum of input lengths does not equal the length of the input dataset!")

    shuffled = randperm(sum(lengths), generator=generator).tolist()  # type: ignore[arg-type, call-overload]

    # Carve consecutive, non-overlapping windows out of the permutation.
    splits = []
    start = 0
    for length in lengths:
        stop = start + length
        splits.append(Subset(dataset, shuffled[start:stop]))
        start = stop
    return splits