blob: 8a8d536b79857c7c289a72a809b3434ed8a6e174 [file] [log] [blame]
import time
from typing import Any, List
import torch.utils.data.backward_compatibility
import torch.utils.data.graph_settings
from torch.utils.data import DataLoader, IterDataPipe, communication
from torch.utils.data.datapipes.iter import IterableWrapper
# Public API of this module; the threading helper class stays private.
__all__ = ["DataLoader2"]
class _ThreadingDataLoader2:
    """Loader that runs a DataPipe graph in ``num_workers`` threads and polls them in turn.

    For every worker, ``communication.eventloop.SpawnThreadForDataPipeline`` returns
    a thread plus a request/response queue pair and a thread-local datapipe, which
    is sharded with ``graph_settings.apply_sharding(num_workers, worker_id)``. The
    main thread reads from each worker through a ``communication.iter.QueueWrapper``.
    """

    def __init__(self, datapipe, num_workers=0, collate_fn=None):
        # (thread, req_queue, res_queue) for every spawned worker.
        self.threads = []
        # Main-thread proxies (one QueueWrapper per worker) polled by ``__iter__``.
        self.datapipes = []
        # Stored for interface compatibility; not used by this class.
        self.collate_fn = collate_fn
        # Iterated directly by ``__iter__`` when no worker threads were spawned.
        self._datapipe = datapipe
        for worker_id in range(num_workers):
            (thread, req_queue, res_queue, thread_localdatapipe) = \
                communication.eventloop.SpawnThreadForDataPipeline(datapipe)
            torch.utils.data.graph_settings.apply_sharding(thread_localdatapipe, num_workers, worker_id)
            thread.start()
            self.threads.append((thread, req_queue, res_queue))  # These queues are independent
            local_datapipe = communication.iter.QueueWrapper(
                communication.protocol.IterDataPipeQueueProtocolClient(req_queue, res_queue))
            self.datapipes.append(local_datapipe)

    def __iter__(self):
        """Yield items from the workers round-robin until every worker is exhausted.

        Each sweep asks every still-active worker for one item with
        ``nonblocking_next``. A worker that raises ``StopIteration`` is dropped;
        if any worker reported ``NotAvailable`` during a sweep, the loop sleeps
        1 ms before the next sweep. Without worker threads (``num_workers <= 0``)
        the wrapped datapipe is iterated directly in the calling thread.
        """
        if not self.datapipes:
            yield from self._datapipe
            return
        active = list(self.datapipes)
        while active:
            # Reset every sweep so the back-off only happens when this sweep
            # actually hit a worker with nothing ready.
            not_available = False
            for dp in list(active):  # iterate a copy: finished workers are removed below
                try:
                    value = dp.nonblocking_next()
                except StopIteration:
                    active.remove(dp)
                except communication.iter.NotAvailable:
                    not_available = True
                else:
                    yield value
            if not_available:
                time.sleep(0.001)

    def __del__(self):
        # ``getattr`` guard: ``__init__`` may not have run to completion.
        if getattr(self, 'threads', None):
            self._cleanup_all_threads()

    def _cleanup_all_threads(self):
        """Send a ``TerminateRequest`` to every worker, read its response and join it.

        The thread list is detached before shutting workers down, so a second call
        (e.g. an explicit call followed by ``__del__``) is a no-op. Otherwise it would
        message already-terminated workers and wait on ``res_queue.get()`` for a
        response that presumably never arrives.
        """
        threads, self.threads = self.threads, []
        for thread, req_queue, res_queue in threads:
            req_queue.put(communication.messages.TerminateRequest())
            _ = res_queue.get()  # response to the TerminateRequest; value unused
            thread.join()
class DataLoader2:
    """Experimental ``DataLoader`` entry point with ``IterDataPipe`` support.

    ``DataLoader2(...)`` never produces a ``DataLoader2`` instance; ``__new__``
    returns the object that should be iterated instead:

    * ``dataset`` is not an ``IterDataPipe``: a plain ``DataLoader`` built from the
      same arguments (``parallelism_mode='thread'`` is rejected).
    * ``dataset`` is an ``IterDataPipe``: a ``DataLoader`` (``'mp'``, or ``'thread'``
      without worker threads) or a ``_ThreadingDataLoader2`` (``'thread'`` with
      ``num_workers > 0``). With ``batch_outside_worker=True`` that loader is wrapped
      in an ``IterableWrapper`` datapipe that batches and collates its output in the
      consuming process.
    """

    def __new__(cls,
                dataset,
                batch_size=1,
                shuffle=None,
                sampler=None,
                batch_sampler=None,
                num_workers=0,
                collate_fn=None,
                pin_memory=False,
                drop_last=False,
                timeout=0,
                worker_init_fn=None,
                *,
                prefetch_factor=2,
                persistent_workers=False,
                batch_outside_worker=False,
                parallelism_mode='mp'):
        """Build and return the loader for ``dataset``.

        Arguments mirror ``torch.utils.data.DataLoader``, plus:
            batch_outside_worker: for ``IterDataPipe`` datasets, batch and collate in
                the consuming process instead of inside the loader/workers.
            parallelism_mode: ``'mp'`` (``DataLoader`` workers) or ``'thread'``
                (``_ThreadingDataLoader2``).

        Raises:
            Exception: for ``'thread'`` with a non-``IterDataPipe`` dataset; for
                ``sampler``/``batch_sampler`` with DataPipes; for ``pin_memory``
                combined with ``batch_outside_worker`` or with thread workers; for
                ``worker_init_fn`` with thread workers; and, for DataPipe datasets,
                for an unknown ``parallelism_mode``.
        """
        if not isinstance(dataset, IterDataPipe):
            # Non-DataPipe datasets go straight to the regular DataLoader.
            if parallelism_mode == 'thread':
                raise Exception(
                    'thread parallelism mode is not supported for old DataSets')
            return DataLoader(dataset,
                              batch_size=batch_size,
                              shuffle=shuffle,
                              sampler=sampler,
                              batch_sampler=batch_sampler,
                              num_workers=num_workers,
                              collate_fn=collate_fn,
                              pin_memory=pin_memory,
                              drop_last=drop_last,
                              timeout=timeout,
                              worker_init_fn=worker_init_fn,
                              prefetch_factor=prefetch_factor,
                              persistent_workers=persistent_workers)

        if batch_sampler is not None:
            raise Exception(
                'batch_sampler is not yet supported by DataPipes')
        if sampler is not None:
            raise Exception(
                'sampler is not yet supported by DataPipes')

        datapipe = dataset
        datapipe = torch.utils.data.graph_settings.apply_shuffle_settings(datapipe, shuffle=shuffle)  # type: ignore[assignment]

        if batch_outside_worker and pin_memory:
            raise Exception(
                'pin_memory is not yet compatible with batch_outside_worker')

        # Batch inside the pipeline; the loader itself then runs with batch_size=None.
        if not batch_outside_worker and batch_size is not None:
            datapipe = datapipe.batch(batch_size, drop_last=drop_last)
            if collate_fn is None:
                collate_fn = torch.utils.data._utils.collate.default_collate

        # With batch_outside_worker, collate_fn is applied to whole batches after the
        # loader (end of this method); it must not also run on every single sample.
        worker_collate_fn = None if batch_outside_worker else collate_fn

        data_loader: Any = None
        if parallelism_mode == 'thread' and num_workers > 0:
            if worker_collate_fn is not None:
                datapipe = datapipe.map(worker_collate_fn)
            if pin_memory:
                raise Exception(
                    'pin_memory is not yet supported by DataPipes with Threading')
            if worker_init_fn is not None:
                raise Exception(
                    'worker_init_fn is not yet supported by DataPipes with Threading')
            data_loader = _ThreadingDataLoader2(datapipe,
                                                num_workers=num_workers,
                                                collate_fn=collate_fn)
        elif parallelism_mode in ('mp', 'thread'):
            # 'mp', or 'thread' without worker threads: the DataLoader loads the data
            # itself, so num_workers == 0 still yields items.
            # Note: It is safe to pass shuffle=True to the old DataLoader, as shuffle does nothing
            # for Iterable, but required to set Pipes correctly.
            data_loader = DataLoader(datapipe,
                                     batch_size=None,  # Replaced by .batch DataPipe
                                     shuffle=shuffle,
                                     sampler=None,
                                     batch_sampler=None,
                                     num_workers=num_workers,
                                     collate_fn=worker_collate_fn,
                                     pin_memory=pin_memory,
                                     drop_last=False,  # Replaced by .batch DataPipe
                                     timeout=timeout,
                                     worker_init_fn=worker_init_fn,
                                     prefetch_factor=prefetch_factor,
                                     persistent_workers=persistent_workers)
        else:
            raise Exception('Unsupported parallelism mode', parallelism_mode)

        if not batch_outside_worker:
            return data_loader

        # Batch and collate the loader's per-sample output in the consuming process.
        if collate_fn is None:
            collate_fn = torch.utils.data._utils.collate.default_collate
        return IterableWrapper(data_loader).batch(
            batch_size, drop_last=drop_last).map(collate_fn)