|
from petrel_client.utils.profile import profileit, wrap_with_stat_qsize, WORKER_LOOP_PROFILE_COUNT |
|
from torch.utils.data._utils.worker import ( |
|
torch, random, queue, ExceptionWrapper, ManagerWatchdog, WorkerInfo, |
|
_IterableDatasetStopIteration, signal_handling, MP_STATUS_CHECK_INTERVAL) |
|
|
|
|
|
class _ResumeIteration(object): |
|
r"""Dummy class used to resume the fetching when worker reuse is enabled""" |
|
pass |
|
|
|
|
|
def _worker_loop(dataset_kind, dataset, index_queue, data_queue, done_event,
                 auto_collation, collate_fn, drop_last, seed, init_fn, worker_id,
                 num_workers, persistent_workers):
    """DataLoader worker-process entry point with queue profiling.

    Variant of ``torch.utils.data._utils.worker._worker_loop`` whose
    ``index_queue.get`` and ``data_queue.put`` are wrapped with
    ``wrap_with_stat_qsize`` so queue depth statistics are reported, and
    whose per-item loop body is wrapped with ``profileit``.  The worker
    repeatedly receives ``(idx, index)`` work items from ``index_queue``,
    fetches the corresponding batch via the dataset fetcher, and puts
    ``(idx, data)`` results (or an ``ExceptionWrapper``) on ``data_queue``.

    Parameters mirror the upstream PyTorch worker loop: ``done_event``
    signals shutdown, ``init_fn`` is an optional per-worker initializer,
    and ``persistent_workers`` enables the ``_ResumeIteration`` handshake.
    """
    # Instrument both queues: every WORKER_LOOP_PROFILE_COUNT calls the
    # wrapper logs the queue's size under the given stat name.
    index_queue_get = wrap_with_stat_qsize(
        index_queue, index_queue.get, '_worker_loop.index_queue.qsize', count=WORKER_LOOP_PROFILE_COUNT)
    data_queue_put = wrap_with_stat_qsize(
        data_queue, data_queue.put, '_worker_loop.data_queue.qsize', count=WORKER_LOOP_PROFILE_COUNT)
    index_queue.get = index_queue_get
    data_queue.put = data_queue_put

    try:
        # Install torch's worker signal handlers so the parent can detect
        # abnormal worker death.
        signal_handling._set_worker_signal_handlers()

        # Workers are single-threaded; seed both RNGs deterministically
        # from the per-worker seed supplied by the main process.
        torch.set_num_threads(1)
        random.seed(seed)
        torch.manual_seed(seed)

        # Publish this worker's info into torch's worker module so that
        # torch.utils.data.get_worker_info() works inside dataset code.
        _worker_info = WorkerInfo(id=worker_id, num_workers=num_workers,
                                  seed=seed, dataset=dataset)

        from torch.utils.data._utils import worker as pt_worker
        pt_worker._worker_info = _worker_info

        from torch.utils.data import _DatasetKind

        init_exception = None

        try:
            if init_fn is not None:
                init_fn(worker_id)

            fetcher = _DatasetKind.create_fetcher(
                dataset_kind, dataset, auto_collation, collate_fn, drop_last)
        except Exception:
            # Defer reporting: the wrapped traceback is sent back through
            # data_queue on the first work item instead of raising here.
            init_exception = ExceptionWrapper(
                where="in DataLoader worker process {}".format(worker_id))

        # True once an iterable dataset has raised StopIteration; fetching
        # is suspended until a _ResumeIteration message resets it.
        iteration_end = False

        # Watchdog detects death of the manager (main) process.
        watchdog = ManagerWatchdog()
        # Sentinel return values for loop(): cnt -> keep looping,
        # brk -> terminate the worker loop.
        cnt = 1
        brk = 0

        def loop():
            # Process one message from index_queue.  Returns cnt to
            # continue and brk to stop; the fall-through path returns None,
            # which also continues since None != brk.
            nonlocal iteration_end, init_exception, fetcher
            try:
                r = index_queue.get(timeout=MP_STATUS_CHECK_INTERVAL)
            except queue.Empty:
                # Timed out; go around again so watchdog/done_event are
                # re-checked.
                return cnt
            if isinstance(r, _ResumeIteration):
                # Persistent-workers handshake: acknowledge the resume and
                # rebuild the fetcher for the new epoch.
                data_queue.put(r)
                iteration_end = False

                fetcher = _DatasetKind.create_fetcher(
                    dataset_kind, dataset, auto_collation, collate_fn, drop_last)
                return cnt
            elif r is None:
                # Explicit shutdown request from the main process.
                assert done_event.is_set() or iteration_end
                return brk
            elif done_event.is_set() or iteration_end:
                # Stale work item received after shutdown / exhaustion;
                # drop it without fetching.
                return cnt
            idx, index = r
            if init_exception is not None:
                # Report the deferred init failure exactly once, in place
                # of real data.
                data = init_exception
                init_exception = None
            else:
                try:
                    data = fetcher.fetch(index)
                except Exception as e:
                    if isinstance(e, StopIteration) and dataset_kind == _DatasetKind.Iterable:
                        # Iterable dataset exhausted: notify the main
                        # process and stop fetching until resumed.
                        data = _IterableDatasetStopIteration(worker_id)
                        iteration_end = True
                    else:
                        # Any other failure: ship the traceback back to the
                        # main process rather than crashing the worker.
                        data = ExceptionWrapper(
                            where="in DataLoader worker process {}".format(worker_id))
            data_queue.put((idx, data))
            # Drop references promptly so large batches are freed before
            # the next get().
            del data, idx, index, r

        # Profile each iteration of the worker loop.
        loop = profileit(loop, name='_worker_loop.loop',
                         count=WORKER_LOOP_PROFILE_COUNT)
        while watchdog.is_alive():
            if loop() == brk:
                break
    except KeyboardInterrupt:
        # SIGINT is handled by the main process; workers exit quietly.
        pass
    if done_event.is_set():
        # Main process is shutting down: avoid blocking on the feeder
        # thread and close the result queue.
        data_queue.cancel_join_thread()
        data_queue.close()
|
|