from torch.utils.data.dataloader import (
    torch,
    python_multiprocessing,
    multiprocessing,
    ExceptionWrapper,
    threading,
    itertools,
    queue,
    string_classes,
    _utils,
    _DatasetKind,
    _InfiniteConstantSampler,
    IterableDataset,
    SequentialSampler,
    RandomSampler,
    BatchSampler)

from ._utils.pin_memory import _pin_memory_loop
from ._utils.worker import _worker_loop, _ResumeIteration

from petrel_client.utils.profile import profileit


class DataLoader(object):
    """Iterable over a dataset, driven by a sampler or batch sampler.

    Validates the option combination in ``__init__``, picks default samplers
    and a default ``collate_fn`` when none are given, and hands out a
    single-process or multi-process iterator from ``__iter__`` depending on
    ``num_workers``.
    """

    # Flipped to True at the end of `__init__`; `__setattr__` uses it to
    # freeze the attributes that must not change once the loader is built.
    __initialized = False

    def __init__(self, dataset, batch_size=1, shuffle=False, sampler=None,
                 batch_sampler=None, num_workers=0, collate_fn=None,
                 pin_memory=False, drop_last=False, timeout=0,
                 worker_init_fn=None, multiprocessing_context=None,
                 *, prefetch_factor=2, persistent_workers=False):
        """Validate the arguments and set up samplers and collation.

        Raises:
            ValueError: if `num_workers` or `timeout` is negative, if
                `prefetch_factor` / `persistent_workers` /
                `multiprocessing_context` is used with `num_workers == 0`,
                or if mutually exclusive sampler/batching options are
                combined.
            AssertionError: if `prefetch_factor` is not positive.
        """
        torch._C._log_api_usage_once("python.data_loader")

        if num_workers < 0:
            raise ValueError('num_workers option should be non-negative; '
                             'use num_workers=0 to disable multiprocessing.')

        if timeout < 0:
            raise ValueError('timeout option should be non-negative')

        if num_workers == 0 and prefetch_factor != 2:
            raise ValueError('prefetch_factor option could only be specified in multiprocessing.'
                             'let num_workers > 0 to enable multiprocessing.')
        assert prefetch_factor > 0

        if persistent_workers and num_workers == 0:
            raise ValueError('persistent_workers option needs num_workers > 0')

        self.dataset = dataset
        self.num_workers = num_workers
        self.prefetch_factor = prefetch_factor
        self.pin_memory = pin_memory
        self.timeout = timeout
        self.worker_init_fn = worker_init_fn
        # Goes through the property setter below, which reads
        # `self.num_workers`; keep this assignment after `num_workers`.
        self.multiprocessing_context = multiprocessing_context

        # Arg-check dataset related before checking samplers because we want to
        # tell users that iterable-style datasets are incompatible with custom
        # samplers first, so that they don't learn that this combo doesn't work
        # after spending time fixing the custom sampler errors.
        if isinstance(dataset, IterableDataset):
            self._dataset_kind = _DatasetKind.Iterable
            # NOTE [ Custom Samplers and `IterableDataset` ]
            #
            # `IterableDataset` does not support custom `batch_sampler` or
            # `sampler` since the key is irrelevant (unless we support
            # generator-style dataset one day...).
            #
            # For `sampler`, we always create a dummy sampler. This is an
            # infinite sampler even when the dataset may have an implemented
            # finite `__len__` because in multi-process data loading, naive
            # settings will return duplicated data (which may be desired), and
            # thus using a sampler with length matching that of dataset will
            # cause data lost (you may have duplicates of the first couple
            # batches, but never see anything afterwards). Therefore,
            # `IterableDataset` always uses an infinite sampler, an instance of
            # `_InfiniteConstantSampler` imported above.
            #
            # A custom `batch_sampler` essentially only controls the batch size.
            # However, it is unclear how useful it would be since an iterable-style
            # dataset can handle that within itself. Moreover, it is pointless
            # in multi-process data loading as the assignment order of batches
            # to workers is an implementation detail so users can not control
            # how to batchify each worker's iterable. Thus, we disable this
            # option. If this turns out to be useful in future, we can re-enable
            # this, and support custom samplers that specify the assignments to
            # specific workers.
            if shuffle is not False:
                raise ValueError(
                    "DataLoader with IterableDataset: expected unspecified "
                    "shuffle option, but got shuffle={}".format(shuffle))
            elif sampler is not None:
                # See NOTE [ Custom Samplers and IterableDataset ]
                raise ValueError(
                    "DataLoader with IterableDataset: expected unspecified "
                    "sampler option, but got sampler={}".format(sampler))
            elif batch_sampler is not None:
                # See NOTE [ Custom Samplers and IterableDataset ]
                raise ValueError(
                    "DataLoader with IterableDataset: expected unspecified "
                    "batch_sampler option, but got batch_sampler={}".format(batch_sampler))
        else:
            self._dataset_kind = _DatasetKind.Map

        if sampler is not None and shuffle:
            raise ValueError('sampler option is mutually exclusive with '
                             'shuffle')

        if batch_sampler is not None:
            # auto_collation with custom batch_sampler
            if batch_size != 1 or shuffle or sampler is not None or drop_last:
                raise ValueError('batch_sampler option is mutually exclusive '
                                 'with batch_size, shuffle, sampler, and '
                                 'drop_last')
            batch_size = None
            drop_last = False
        elif batch_size is None:
            # no auto_collation
            if shuffle or drop_last:
                raise ValueError('batch_size=None option disables auto-batching '
                                 'and is mutually exclusive with '
                                 'shuffle, and drop_last')

        if sampler is None:  # give default samplers
            if self._dataset_kind == _DatasetKind.Iterable:
                # See NOTE [ Custom Samplers and IterableDataset ]
                sampler = _InfiniteConstantSampler()
            else:  # map-style
                if shuffle:
                    sampler = RandomSampler(dataset)
                else:
                    sampler = SequentialSampler(dataset)

        if batch_size is not None and batch_sampler is None:
            # auto_collation without custom batch_sampler
            batch_sampler = BatchSampler(sampler, batch_size, drop_last)

        self.batch_size = batch_size
        self.drop_last = drop_last
        self.sampler = sampler
        self.batch_sampler = batch_sampler

        if collate_fn is None:
            if self._auto_collation:
                collate_fn = _utils.collate.default_collate
            else:
                collate_fn = _utils.collate.default_convert

        self.collate_fn = collate_fn
        self.persistent_workers = persistent_workers

        self.__initialized = True

        self._iterator = None

    def _get_iterator(self):
        """Build a fresh iterator matching the configured `num_workers`."""
        if self.num_workers == 0:
            return _SingleProcessDataLoaderIter(self)
        else:
            return _MultiProcessingDataLoaderIter(self)

    @property
    def multiprocessing_context(self):
        """The validated multiprocessing context, or None if unspecified."""
        return self.__multiprocessing_context

    @multiprocessing_context.setter
    def multiprocessing_context(self, multiprocessing_context):
        """Validate and store the context; a string is resolved to a context.

        Raises:
            ValueError: if a context is given while `num_workers == 0`, if
                the start-method string is unknown, or if the value is not
                a multiprocessing context object.
        """
        if multiprocessing_context is not None:
            if self.num_workers > 0:
                if not multiprocessing._supports_context:
                    raise ValueError('multiprocessing_context relies on Python >= 3.4, with '
                                     'support for different start methods')

                if isinstance(multiprocessing_context, string_classes):
                    valid_start_methods = multiprocessing.get_all_start_methods()
                    if multiprocessing_context not in valid_start_methods:
                        raise ValueError(
                            ('multiprocessing_context option '
                             'should specify a valid start method in {}, but got '
                             'multiprocessing_context={}').format(valid_start_methods, multiprocessing_context))
                    multiprocessing_context = multiprocessing.get_context(
                        multiprocessing_context)

                if not isinstance(multiprocessing_context, python_multiprocessing.context.BaseContext):
                    raise ValueError(('multiprocessing_context option should be a valid context '
                                      'object or a string specifying the start method, but got '
                                      'multiprocessing_context={}').format(multiprocessing_context))
            else:
                raise ValueError(('multiprocessing_context can only be used with '
                                  'multi-process loading (num_workers > 0), but got '
                                  'num_workers={}').format(self.num_workers))

        self.__multiprocessing_context = multiprocessing_context

    def __setattr__(self, attr, val):
        """Reject writes to the frozen attributes once `__init__` has finished.

        Raises:
            ValueError: if one of the frozen attributes is assigned after
                initialization.
        """
        if self.__initialized and attr in (
                'batch_size', 'batch_sampler', 'sampler', 'drop_last',
                'dataset', 'persistent_workers'):
            raise ValueError('{} attribute should not be set after {} is '
                             'initialized'.format(attr, self.__class__.__name__))

        super(DataLoader, self).__setattr__(attr, val)

    def __iter__(self):
        """Return an iterator over the batches.

        With persistent workers the same iterator object is kept and reset
        on each call; otherwise a new iterator is created every time.
        """
        # When using a single worker the returned iterator should be
        # created every time to avoid resetting its state.
        # However, in the case of a multiple workers iterator
        # the iterator is only created once in the lifetime of the
        # DataLoader object so that workers can be reused.
        if self.persistent_workers and self.num_workers > 0:
            if self._iterator is None:
                self._iterator = self._get_iterator()
            else:
                self._iterator._reset(self)
            return self._iterator
        else:
            return self._get_iterator()

    @property
    def _auto_collation(self):
        """True when a batch sampler is in use (auto-collation mode)."""
        return self.batch_sampler is not None

    @property
    def _index_sampler(self):
        """The sampler that actually produces indices for fetching."""
        # The actual sampler used for generating indices for `_DatasetFetcher`
        # (see _utils/fetch.py) to read data at each time. This would be
        # `.batch_sampler` if in auto-collation mode, and `.sampler` otherwise.
        # We can't change `.sampler` and `.batch_sampler` attributes for BC
        # reasons.
        if self._auto_collation:
            return self.batch_sampler
        else:
            return self.sampler

    def __len__(self):
        """Return the length of the index sampler."""
        # with iterable-style dataset, this will error
        return len(self._index_sampler)


class _BaseDataLoaderIter(object):
    """State shared by the single-process and multi-process iterators.

    Copies the loader's configuration at construction time and owns the
    iterator over the index sampler. Subclasses implement `__next__`.
    """

    def __init__(self, loader):
        """Snapshot the configuration of `loader` and start the sampler."""
        self._dataset = loader.dataset
        self._dataset_kind = loader._dataset_kind
        self._auto_collation = loader._auto_collation
        self._drop_last = loader.drop_last
        self._index_sampler = loader._index_sampler
        self._num_workers = loader.num_workers
        self._prefetch_factor = loader.prefetch_factor
        # Pinning is only enabled when CUDA is actually available.
        self._pin_memory = loader.pin_memory and torch.cuda.is_available()
        self._timeout = loader.timeout
        self._collate_fn = loader.collate_fn
        self._sampler_iter = iter(self._index_sampler)
        self._base_seed = torch.empty((), dtype=torch.int64).random_().item()
        self._persistent_workers = loader.persistent_workers

    def __iter__(self):
        return self

    def _reset(self, loader, first_iter=False):
        """Restart iteration over the index sampler."""
        self._sampler_iter = iter(self._index_sampler)

    def _next_index(self):
        """Return the next index (or batch of indices) from the sampler."""
        return next(self._sampler_iter)  # may raise StopIteration

    def __next__(self):
        raise NotImplementedError

    def __len__(self):
        return len(self._index_sampler)

    def __getstate__(self):
        """Refuse pickling.

        Raises:
            NotImplementedError: always.
        """
        # TODO: add limited pickling support for sharing an iterator
        # across multiple threads for HOGWILD.
        # Probably the best way to do this is by moving the sample pushing
        # to a separate thread and then just sharing the data queue
        # but signalling the end is tricky without a non-blocking API
        #
        # The class name is formatted into the message; previously the
        # template and the name were passed as two separate exception args,
        # so the "{}" placeholder was never filled in.
        raise NotImplementedError(
            "{} cannot be pickled".format(self.__class__.__name__))


class _SingleProcessDataLoaderIter(_BaseDataLoaderIter):
    """Iterator that fetches every batch in the calling process."""

    def __init__(self, loader):
        super(_SingleProcessDataLoaderIter, self).__init__(loader)
        assert self._timeout == 0
        assert self._num_workers == 0

        self._dataset_fetcher = _DatasetKind.create_fetcher(
            self._dataset_kind, self._dataset, self._auto_collation,
            self._collate_fn, self._drop_last)

    def __next__(self):
        """Fetch the next batch, pinning it when pinning is enabled."""
        index = self._next_index()  # may raise StopIteration
        data = self._dataset_fetcher.fetch(index)  # may raise StopIteration
        if self._pin_memory:
            data = _utils.pin_memory.pin_memory(data)
        return data

    next = __next__  # Python 2 compatibility


class _MultiProcessingDataLoaderIter(_BaseDataLoaderIter):
    """Iterator that fetches batches through worker processes.

    Each worker has its own index queue; results come back through one
    shared result queue (optionally relayed by a pin-memory thread) and are
    returned in the order their indices were sent.
    """

    def __init__(self, loader):
        """Start the workers (and the pin-memory thread), then prime prefetch."""
        super(_MultiProcessingDataLoaderIter, self).__init__(loader)

        assert self._num_workers > 0
        assert self._prefetch_factor > 0

        if loader.multiprocessing_context is None:
            multiprocessing_context = multiprocessing
        else:
            multiprocessing_context = loader.multiprocessing_context

        self._worker_init_fn = loader.worker_init_fn
        self._worker_queue_idx_cycle = itertools.cycle(
            range(self._num_workers))
        self._worker_result_queue = multiprocessing_context.Queue()
        self._worker_pids_set = False
        self._shutdown = False
        self._workers_done_event = multiprocessing_context.Event()

        self._index_queues = []
        self._workers = []
        for i in range(self._num_workers):
            index_queue = multiprocessing_context.Queue()
            index_queue.cancel_join_thread()
            w = multiprocessing_context.Process(
                target=_worker_loop,
                args=(self._dataset_kind, self._dataset, index_queue,
                      self._worker_result_queue, self._workers_done_event,
                      self._auto_collation, self._collate_fn, self._drop_last,
                      self._base_seed + i, self._worker_init_fn, i,
                      self._num_workers, self._persistent_workers))
            w.daemon = True
            # NB: Process.start() actually take some time as it needs to
            #     start a process and pass the arguments over via a pipe.
            #     Therefore, we only add a worker to self._workers list after
            #     it started, so that we do not call .join() if program dies
            #     before it starts, and __del__ tries to join but will get:
            #     AssertionError: can only join a started process.
            w.start()
            self._index_queues.append(index_queue)
            self._workers.append(w)

        if self._pin_memory:
            self._pin_memory_thread_done_event = threading.Event()
            self._data_queue = queue.Queue()
            pin_memory_thread = threading.Thread(
                target=_pin_memory_loop,
                args=(self._worker_result_queue, self._data_queue,
                      torch.cuda.current_device(),
                      self._pin_memory_thread_done_event))
            pin_memory_thread.daemon = True
            pin_memory_thread.start()
            # Similar to workers (see comment above), we only register
            # pin_memory_thread once it is started.
            self._pin_memory_thread = pin_memory_thread
        else:
            self._data_queue = self._worker_result_queue

        _utils.signal_handling._set_worker_pids(
            id(self), tuple(w.pid for w in self._workers))
        _utils.signal_handling._set_SIGCHLD_handler()
        self._worker_pids_set = True
        self._reset(loader, first_iter=True)

    def _reset(self, loader, first_iter=False):
        """Reset the per-epoch bookkeeping and refill the prefetch pipeline.

        When `first_iter` is False, every worker is first sent a
        `_ResumeIteration` message and this method waits until it has seen
        one `_ResumeIteration` per worker on the data queue.
        """
        super()._reset(loader, first_iter)
        self._send_idx = 0  # idx of the next task to be sent to workers
        self._rcvd_idx = 0  # idx of the next task to be returned in __next__
        # information about data not yet yielded, i.e., tasks w/ indices in range [rcvd_idx, send_idx).
        # map: task idx => - (worker_id,)        if data isn't fetched (outstanding)
        #                  \ (worker_id, data)   if data is already fetched (out-of-order)
        self._task_info = {}
        # always equal to count(v for v in task_info.values() if len(v) == 1)
        self._tasks_outstanding = 0
        # A list of booleans representing whether each worker still has work to
        # do, i.e., not having exhausted its iterable dataset object. It always
        # contains all `True`s if not using an iterable-style dataset
        # (i.e., if kind != Iterable).
        # Note that this indicates that a worker still has work to do *for this epoch*.
        # It does not mean that a worker is dead. In case of `_persistent_workers`,
        # the worker will be reset to available in the next epoch.
        self._workers_status = [True] * self._num_workers
        # We resume the prefetching in case it was enabled
        if not first_iter:
            for idx in range(self._num_workers):
                self._index_queues[idx].put(_ResumeIteration())
            resume_iteration_cnt = self._num_workers
            while resume_iteration_cnt > 0:
                data = self._get_data()
                if isinstance(data, _ResumeIteration):
                    resume_iteration_cnt -= 1
        # prime the prefetch loop
        for _ in range(self._prefetch_factor * self._num_workers):
            self._try_put_index()

    def _try_get_data(self, timeout=_utils.MP_STATUS_CHECK_INTERVAL):
        """Try once to read from the data queue, waiting at most `timeout`.

        Returns:
            A 2-tuple `(success, data)`; `data` is None when `success` is
            False.

        Raises:
            RuntimeError: if the read failed and a worker that was still
                marked active is no longer alive.
        """
        # Tries to fetch data from `self._data_queue` once for a given timeout.
        # This can also be used as inner loop of fetching without timeout, with
        # the sender status as the loop condition.
        #
        # This raises a `RuntimeError` if any worker died unexpectedly. This error
        # can come from either the SIGCHLD handler in `_utils/signal_handling.py`
        # (only for non-Windows platforms), or the manual check below on errors
        # and timeouts.
        try:
            data = self._data_queue.get(timeout=timeout)
            return (True, data)
        except Exception as e:
            # At timeout and error, we manually check whether any worker has
            # failed. Note that this is the only mechanism for Windows to detect
            # worker failures.
            failed_workers = []
            for worker_id, w in enumerate(self._workers):
                if self._workers_status[worker_id] and not w.is_alive():
                    failed_workers.append(w)
                    self._mark_worker_as_unavailable(worker_id)
            if failed_workers:
                pids_str = ', '.join(str(w.pid) for w in failed_workers)
                raise RuntimeError(
                    'DataLoader worker (pid(s) {}) exited unexpectedly'.format(pids_str)) from e
            if isinstance(e, queue.Empty):
                return (False, None)
            # Anything other than a plain timeout is propagated unchanged.
            raise

    def _get_data(self):
        """Block until an item is available on the data queue and return it.

        Raises:
            RuntimeError: if `timeout > 0` and nothing arrives in time, or
                if the pin-memory thread is no longer alive.
        """
        # Fetches data from `self._data_queue`.
        #
        # We check workers' status every `MP_STATUS_CHECK_INTERVAL` seconds,
        # which we achieve by running `self._try_get_data(timeout=MP_STATUS_CHECK_INTERVAL)`
        # in a loop. This is the only mechanism to detect worker failures for
        # Windows. For other platforms, a SIGCHLD handler is also used for
        # worker failure detection.
        #
        # If `pin_memory=True`, we also need check if `pin_memory_thread` had
        # died at timeouts.
        if self._timeout > 0:
            success, data = self._try_get_data(self._timeout)
            if success:
                return data
            else:
                raise RuntimeError(
                    'DataLoader timed out after {} seconds'.format(self._timeout))
        elif self._pin_memory:
            while self._pin_memory_thread.is_alive():
                success, data = self._try_get_data()
                if success:
                    return data
            else:
                # while condition is false, i.e., pin_memory_thread died.
                raise RuntimeError('Pin memory thread exited unexpectedly')
            # In this case, `self._data_queue` is a `queue.Queue`. But we don't
            # need to call `.task_done()` because we don't use `.join()`.
        else:
            while True:
                success, data = self._try_get_data()
                if success:
                    return data

    # NOTE(review): `profileit` comes from `petrel_client.utils.profile`;
    # presumably it records profiling data for each call — confirm.
    @profileit
    def __next__(self):
        """Return the next batch in send order.

        Raises:
            StopIteration: when no task with available or pending data
                remains; non-persistent workers are shut down first.
        """
        while True:
            # If the worker responsible for `self._rcvd_idx` has already ended
            # and was unable to fulfill this task (due to exhausting an `IterableDataset`),
            # we try to advance `self._rcvd_idx` to find the next valid index.
            #
            # This part needs to run in the loop because both the `self._get_data()`
            # call and `_IterableDatasetStopIteration` check below can mark
            # extra worker(s) as dead.
            while self._rcvd_idx < self._send_idx:
                info = self._task_info[self._rcvd_idx]
                worker_id = info[0]
                # has data or is still active
                if len(info) == 2 or self._workers_status[worker_id]:
                    break
                del self._task_info[self._rcvd_idx]
                self._rcvd_idx += 1
            else:
                # no valid `self._rcvd_idx` is found (i.e., didn't break)
                if not self._persistent_workers:
                    self._shutdown_workers()
                raise StopIteration

            # Now `self._rcvd_idx` is the batch index we want to fetch

            # Check if the next sample has already been generated
            if len(self._task_info[self._rcvd_idx]) == 2:
                data = self._task_info.pop(self._rcvd_idx)[1]
                return self._process_data(data)

            assert not self._shutdown and self._tasks_outstanding > 0
            idx, data = self._get_data()
            self._tasks_outstanding -= 1

            if self._dataset_kind == _DatasetKind.Iterable:
                # Check for _IterableDatasetStopIteration
                if isinstance(data, _utils.worker._IterableDatasetStopIteration):
                    if self._persistent_workers:
                        self._workers_status[data.worker_id] = False
                    else:
                        self._mark_worker_as_unavailable(data.worker_id)
                    self._try_put_index()
                    continue

            if idx != self._rcvd_idx:
                # store out-of-order samples
                self._task_info[idx] += (data,)
            else:
                del self._task_info[idx]
                return self._process_data(data)

    next = __next__  # Python 2 compatibility

    def _try_put_index(self):
        """Send the next sampler index to the next active worker, if any.

        Does nothing when the sampler is exhausted or no worker is active.
        """
        assert self._tasks_outstanding < self._prefetch_factor * self._num_workers
        try:
            index = self._next_index()
        except StopIteration:
            return
        for _ in range(self._num_workers):  # find the next active worker, if any
            worker_queue_idx = next(self._worker_queue_idx_cycle)
            if self._workers_status[worker_queue_idx]:
                break
        else:
            # not found (i.e., didn't break)
            return

        self._index_queues[worker_queue_idx].put((self._send_idx, index))
        self._task_info[self._send_idx] = (worker_queue_idx,)
        self._tasks_outstanding += 1
        self._send_idx += 1

    def _process_data(self, data):
        """Advance the receive index, queue another task, and return `data`.

        An `ExceptionWrapper` received from a worker is re-raised here.
        """
        self._rcvd_idx += 1
        self._try_put_index()
        if isinstance(data, ExceptionWrapper):
            data.reraise()
        return data

    def _mark_worker_as_unavailable(self, worker_id, shutdown=False):
        """Tell one worker to stop and mark it as having no more work."""
        # Mark a worker as having finished its work e.g., due to
        # exhausting an `IterableDataset`. This should be used only when this
        # `_MultiProcessingDataLoaderIter` is going to continue running.

        assert self._workers_status[worker_id] or (
            self._persistent_workers and shutdown)

        # Signal termination to that specific worker.
        q = self._index_queues[worker_id]
        # Indicate that no more data will be put on this queue by the current
        # process.
        q.put(None)

        # Note that we don't actually join the worker here, nor do we remove the
        # worker's pid from C side struct because (1) joining may be slow, and
        # (2) since we don't join, the worker may still raise error, and we
        # prefer capturing those, rather than ignoring them, even though they
        # are raised after the worker has finished its job.
        # Joining is deferred to `_shutdown_workers`, which is called when
        # all workers finish their jobs (e.g., `IterableDataset` replicas) or
        # when this iterator is garbage collected.

        self._workers_status[worker_id] = False

        assert self._workers_done_event.is_set() == shutdown

    def _shutdown_workers(self):
        """Stop the pin-memory thread and all workers and release the queues.

        A no-op while Python is exiting or if shutdown has already run.
        """
        # Called when shutting down this `_MultiProcessingDataLoaderIter`.
        # See NOTE [ Data Loader Multiprocessing Shutdown Logic ] (in upstream
        # `torch/utils/data/dataloader.py`) for details on the logic of this
        # function.
        python_exit_status = _utils.python_exit_status
        if python_exit_status is True or python_exit_status is None:
            # See (2) of the note. If Python is shutting down, do no-op.
            return
        # Normal exit when last reference is gone / iterator is depleted.
        # See (1) and the second half of the note.
        if not self._shutdown:
            self._shutdown = True
            try:
                # Exit `pin_memory_thread` first because exiting workers may leave
                # corrupted data in `worker_result_queue` which `pin_memory_thread`
                # reads from.
                if hasattr(self, '_pin_memory_thread'):
                    # Use hasattr in case error happens before we set the attribute.
                    self._pin_memory_thread_done_event.set()
                    # Send something to pin_memory_thread in case it is waiting
                    # so that it can wake up and check `pin_memory_thread_done_event`
                    self._worker_result_queue.put((None, None))
                    self._pin_memory_thread.join()
                    self._worker_result_queue.cancel_join_thread()
                    self._worker_result_queue.close()

                # Exit workers now.
                self._workers_done_event.set()
                for worker_id in range(len(self._workers)):
                    # Get number of workers from `len(self._workers)` instead of
                    # `self._num_workers` in case we error before starting all
                    # workers.
                    # If we are using workers_status with persistent_workers
                    # we have to shut it down because the worker is paused
                    if self._persistent_workers or self._workers_status[worker_id]:
                        self._mark_worker_as_unavailable(
                            worker_id, shutdown=True)
                for w in self._workers:
                    w.join()
                for q in self._index_queues:
                    q.cancel_join_thread()
                    q.close()
            finally:
                # Even though all this function does is putting into queues that
                # we have called `cancel_join_thread` on, weird things can
                # happen when a worker is killed by a signal, e.g., hanging in
                # `Event.set()`. So we need to guard this with SIGCHLD handler,
                # and remove pids from the C side data structure only at the
                # end.
                #
                # FIXME: Unfortunately, for Windows, we are missing a worker
                #        error detection mechanism here in this function, as it
                #        doesn't provide a SIGCHLD handler.
                if self._worker_pids_set:
                    _utils.signal_handling._remove_worker_pids(id(self))
                    self._worker_pids_set = False
                for w in self._workers:
                    if w.is_alive():
                        # Existing mechanisms try to make the workers exit
                        # peacefully, but in case that we unfortunately reach
                        # here, which we shouldn't, (e.g., pytorch/pytorch#39570),
                        # we kill the worker.
                        w.terminate()

    def __del__(self):
        self._shutdown_workers()