from torch.utils.data._utils.pin_memory import (
    torch, queue, pin_memory, MP_STATUS_CHECK_INTERVAL, ExceptionWrapper)
from .worker import _ResumeIteration

from petrel_client.utils.profile import profileit, wrap_with_stat_qsize
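
# For orientation, a minimal sketch of what ``wrap_with_stat_qsize`` is assumed
# to do here (the real implementation ships with petrel_client and may differ):
# return a drop-in replacement for a queue method that also samples ``qsize()``
# under the given stat label on every call, so queue depth shows up alongside
# the profiling output.
#
#     def wrap_with_stat_qsize(q, method, name):
#         def wrapped(*args, **kwargs):
#             record_stat(name, q.qsize())  # hypothetical stat sink
#             return method(*args, **kwargs)
#         return wrapped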

def _pin_memory_loop(in_queue, out_queue, device_id, done_event):
    # This setting is thread local, and prevents the copy in pin_memory from
    # consuming all CPU cores.
    torch.set_num_threads(1)

    torch.cuda.set_device(device_id)

    # Instrument the queues: every get/put also reports the current queue
    # depth under the given stat label.
    in_queue_get = wrap_with_stat_qsize(
        in_queue, in_queue.get, '_pin_memory_loop.in_queue.qsize:')
    out_queue_put = wrap_with_stat_qsize(
        out_queue, out_queue.put, '_pin_memory_loop.out_queue.qsize:')
    in_queue.get = in_queue_get
    out_queue.put = out_queue_put

    # Sentinel return values for loop(): `cnt` ("continue") is returned when
    # the input queue was empty, telling the outer loop to keep polling;
    # `brk` ("break") is reserved for signalling it to stop. In practice
    # shutdown is driven by `done_event`.
    cnt = 1
    brk = 0

    def loop():
        try:
            r = in_queue.get(timeout=MP_STATUS_CHECK_INTERVAL)
        except queue.Empty:
            return cnt
        # `_ResumeIteration` markers are forwarded untouched; everything else
        # is an (idx, data) pair whose data gets pinned.
        if not isinstance(r, _ResumeIteration):
            idx, data = r
            if not done_event.is_set() and not isinstance(data, ExceptionWrapper):
                try:
                    data = pin_memory(data)
                except Exception:
                    data = ExceptionWrapper(
                        where="in pin memory thread for device {}".format(device_id))
                r = (idx, data)
        while not done_event.is_set():
            try:
                out_queue.put(r, timeout=MP_STATUS_CHECK_INTERVAL)
                break
            except queue.Full:
                continue
        del r  # save memory

    loop = profileit(loop, name='_pin_memory_loop.loop')

    # See NOTE [ Data Loader Multiprocessing Shutdown Logic ] for details on the
    # logic of this function.
    while not done_event.is_set():
        if loop() == brk:
            break
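
# Usage sketch (an assumption, not part of this module): this profiled loop
# keeps the same signature as the stock ``_pin_memory_loop`` in the PyTorch
# versions this code targets, so it can be installed by monkey-patching the
# module attribute before any DataLoader with ``pin_memory=True`` spawns its
# pin-memory thread. The import path of this file is hypothetical:
#
#     import torch.utils.data._utils as _utils
#     from petrel_client.utils.data import pin_memory as patched  # hypothetical
#     _utils.pin_memory._pin_memory_loop = patched._pin_memory_loop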