|
from torch.utils.data._utils.pin_memory import ( |
|
torch, queue, pin_memory, MP_STATUS_CHECK_INTERVAL, ExceptionWrapper) |
|
|
|
from .worker import _ResumeIteration |
|
|
|
from petrel_client.utils.profile import profileit, wrap_with_stat_qsize |
|
|
|
|
|
def _pin_memory_loop(in_queue, out_queue, device_id, done_event):
    """Pin items read from ``in_queue`` and forward them to ``out_queue``.

    Each item taken from ``in_queue`` is handled in one of two ways:

    * A ``_ResumeIteration`` marker is forwarded unchanged.
    * Anything else is unpacked as ``(idx, data)``. ``data`` is passed through
      ``pin_memory`` unless it is already an ``ExceptionWrapper`` or
      ``done_event`` is set. The result is then forwarded as ``(idx, data)``.

    If ``pin_memory`` raises, the exception is captured in an
    ``ExceptionWrapper``, which is forwarded in place of the data rather than
    raised here.

    Every blocking queue operation uses ``MP_STATUS_CHECK_INTERVAL`` as its
    timeout, so ``done_event`` is re-checked periodically. The function
    returns once that event is set. An item still waiting to be put at that
    point is dropped.

    Args:
        in_queue: Queue to read items from. Its ``get`` must accept
            ``timeout=`` and raise ``queue.Empty`` on timeout.
        out_queue: Queue to put items on. Its ``put`` must accept
            ``timeout=`` and raise ``queue.Full`` on timeout.
        device_id: Device passed to ``torch.cuda.set_device`` before any
            pinning happens.
        done_event: Event-like object (``is_set()``); once set, the loop stops.
    """
    # Restrict torch to a single intra-op thread for the pinning copies.
    torch.set_num_threads(1)

    torch.cuda.set_device(device_id)

    # Instrument the two queue operations this loop performs with queue-size
    # statistics. The wrapped callables are kept local instead of being
    # assigned onto the queue objects. This leaves the caller's queues
    # unmodified: no instance-attribute patches outlive this loop, and
    # reusing a queue cannot stack one stats wrapper on top of another.
    in_queue_get = wrap_with_stat_qsize(
        in_queue, in_queue.get, '_pin_memory_loop.in_queue.qsize:')
    out_queue_put = wrap_with_stat_qsize(
        out_queue, out_queue.put, '_pin_memory_loop.out_queue.qsize:')

    # Return-value protocol between ``loop`` and the driver below.
    # ``loop`` returns ``cnt`` when nothing arrived before the timeout and
    # ``None`` after handling an item; it never returns ``brk``.
    # NOTE(review): the ``== brk`` check below is kept in case ``profileit``'s
    # wrapper can return something other than the wrapped function's result.
    # Confirm this and drop the sentinels if it simply passes values through.
    cnt = 1
    brk = 0

    def loop():
        """Handle at most one item from ``in_queue``.

        Returns ``cnt`` if the get timed out, otherwise ``None``.
        """
        try:
            r = in_queue_get(timeout=MP_STATUS_CHECK_INTERVAL)
        except queue.Empty:
            return cnt
        if not isinstance(r, _ResumeIteration):
            idx, data = r
            if not done_event.is_set() and not isinstance(data, ExceptionWrapper):
                try:
                    data = pin_memory(data)
                except Exception:
                    # Forward the failure to the consumer instead of letting
                    # it terminate this loop.
                    data = ExceptionWrapper(
                        where="in pin memory thread for device {}".format(device_id))
            r = (idx, data)
        # Retry while the output queue is full, but stop trying (dropping
        # ``r``) as soon as shutdown is requested.
        while not done_event.is_set():
            try:
                out_queue_put(r, timeout=MP_STATUS_CHECK_INTERVAL)
                break
            except queue.Full:
                continue
        return None

    loop = profileit(loop, name='_pin_memory_loop.loop')

    while not done_event.is_set():
        if loop() == brk:
            break
|
|