Spaces:
Sleeping
Sleeping
import platform
import queue
import threading
import time
from typing import Iterable, Callable, Optional, Any, Union

import torch
import torch.multiprocessing as tm

from ding.torch_utils import to_device
from ding.utils import LockContext, LockContextType

from .base_dataloader import IDataLoader
from .collate_fn import default_collate
class AsyncDataLoader(IDataLoader):
    """
    Overview:
        An asynchronous dataloader. Batches are prepared in background processes/threads and buffered in \
        bounded queues, so that the user (e.g. a learner) can fetch a ready batch with ``next()``.
    Interfaces:
        ``__init__``, ``__iter__``, ``__next__``, ``_get_data``, ``_async_loop``, ``_worker_loop``, ``_cuda_loop``, \
        ``close``
    """

    def __init__(
            self,
            data_source: Union[Callable, dict],
            batch_size: int,
            device: str,
            chunk_size: Optional[int] = None,
            collate_fn: Optional[Callable] = None,
            num_workers: int = 0
    ) -> None:
        """
        Overview:
            Init dataloader with input parameters.
            If ``data_source(batch_size)`` returns a list of ``dict``, data will only be processed in \
            ``get_data_thread`` and put into ``async_train_queue``.
            Otherwise the returned items are callables, which will be executed in one of two ways:
                - ``num_workers`` == 0 or 1: Only the main worker will execute them and put the collated batch \
                  into ``async_train_queue``.
                - ``num_workers`` > 1: The main worker divides a batch into pieces of ``chunk_size`` and pushes \
                  every piece into ``job_queue``; Then slave workers get and execute the pieces; Finally they \
                  push the collated batch into ``async_train_queue``.
            At the last step, if ``device`` contains "cuda", data in ``async_train_queue`` will be transferred to \
            ``cuda_queue`` for the user to access.
        Arguments:
            - data_source (:obj:`Union[Callable, dict]`): The data source. It is called as \
              ``data_source(batch_size)`` and should return a sequence of either dicts (already processed data, \
              e.g. replay buffer samples) or callables (jobs to be executed, e.g. functions reading from file).
            - batch_size (:obj:`int`): Batch size.
            - device (:obj:`str`): Device. CUDA transfer is enabled when it contains ``'cuda'``.
            - chunk_size (:obj:`int`): The size of a chunked piece in a batch, should exactly divide ``batch_size``, \
              only takes effect when there is more than 1 worker. Defaults to 1.
            - collate_fn (:obj:`Callable`): The function used to collate a list of samples into one batch. \
              Defaults to ``default_collate``.
            - num_workers (:obj:`int`): Number of extra workers. \
              0 or 1 means only 1 main worker and no extra ones, i.e. Multiprocessing is disabled. \
              More than 1 means multiple worker processes are used to process data.
        Raises:
            - AssertionError: If ``chunk_size`` does not exactly divide ``batch_size``.
            - ValueError: If ``num_workers`` is negative.
        """
        # Set first so that ``close`` (reached from ``__del__`` as well) can run safely even if this
        # ``__init__`` raises before every process/thread below has been created.
        self.end_flag = False
        # Whether ``__next__`` has already drained and closed the user-facing queue after shutdown.
        self._user_queue_drained = False
        self.data_source = data_source
        self.batch_size = batch_size
        self.device = device
        self.use_cuda = 'cuda' in self.device
        if self.use_cuda:
            self.stream = torch.cuda.Stream()
        if chunk_size is None:
            self.chunk_size = 1
        else:
            self.chunk_size = chunk_size
        assert self.batch_size >= self.chunk_size and self.batch_size % self.chunk_size == 0, '{}/{}'.format(
            self.batch_size, self.chunk_size
        )
        if collate_fn is None:
            self.collate_fn = default_collate
        else:
            self.collate_fn = collate_fn
        self.num_workers = num_workers
        if self.num_workers < 0:
            raise ValueError(
                '"num_workers" should be non-negative; '
                'Use num_workers = 0 or 1 to disable multiprocessing.'
            )
        # Up to "2 * num_workers" pieces of data will be stored in dataloader, waiting for learner to get.
        # Up to "2 * num_workers" jobs will be stored in dataloader, waiting for slave process to get and accomplish.
        queue_maxsize = max(1, self.num_workers) * 2
        self.queue_maxsize = queue_maxsize
        # For multiprocessing: Use ``spawn`` on Windows, ``fork`` on other platforms.
        context_str = 'spawn' if platform.system().lower() == 'windows' else 'fork'
        self.mp_context = tm.get_context(context_str)
        # The manager starts its own server process; ``close`` shuts it down.
        self.manager = self.mp_context.Manager()
        # ``async_train_queue`` is the queue to store processed data.
        # User can directly access data if don't use cuda; Otherwise, user will access data from ``cuda_queue``.
        self.async_train_queue = self.mp_context.Queue(maxsize=queue_maxsize)
        # Multiprocessing workers: If num_workers > 1, more than 1 worker are to process data.
        if self.num_workers > 1:
            self.batch_id = self.mp_context.Value('i', 0)
            self.cur_batch = self.mp_context.Value('i', 0)
            if self.batch_size != self.chunk_size:
                # job_result {batch_id: result_list} is used to store processed result in temporal.
                self.job_result = self.manager.dict()
                self.job_result_lock = LockContext(type_=LockContextType.PROCESS_LOCK)
            self.job_queue = self.mp_context.Queue(maxsize=queue_maxsize)
            self.worker = [
                self.mp_context.Process(
                    target=self._worker_loop, args=(), name='dataloader_worker{}_{}'.format(i, time.time())
                ) for i in range(self.num_workers)
            ]
            for w in self.worker:
                w.daemon = True
                w.start()
            print('Using {} workers to load data'.format(self.num_workers))
        # Parent and child pipes. Used by ``async_process`` and ``get_data_thread`` to coordinate.
        p, c = self.mp_context.Pipe()
        # Async process (Main worker): Process data if num_workers <= 1; Assign job to other workers if num_workers > 1.
        self.async_process = self.mp_context.Process(target=self._async_loop, args=(p, c))
        self.async_process.daemon = True
        self.async_process.start()
        # Get data thread: Get data from ``data_source`` and send it to ``async_process``.
        self.get_data_thread = threading.Thread(target=self._get_data, args=(p, c))
        self.get_data_thread.daemon = True
        self.get_data_thread.start()
        # Cuda thread: If use cuda, data in ``async_train_queue`` will be transferred to ``cuda_queue``;
        # Then user will access data from ``cuda_queue``.
        if self.use_cuda:
            self.cuda_queue = queue.Queue(maxsize=queue_maxsize)
            self.cuda_thread = threading.Thread(target=self._cuda_loop, args=(), name='dataloader_cuda')
            self.cuda_thread.daemon = True
            self.cuda_thread.start()

    def __iter__(self) -> Iterable:
        """
        Overview:
            Return the iterable self as an iterator.
        Returns:
            - self (:obj:`Iterable`): Self as an iterator.
        """
        return self

    def _get_data(self, p: tm.multiprocessing.connection, c: tm.multiprocessing.connection) -> None:
        """
        Overview:
            Serve data requests coming from the main worker process. Runs as the ``self.get_data_thread`` thread.
            On every ``'get_data'`` command received through the pipe, call ``data_source(batch_size)``:
                - If the items are dicts (already processed), collate them, put the batch directly into \
                  ``async_train_queue`` and reply ``'pass'``.
                - Otherwise, send the raw items (callables) back through the pipe to be executed.
            Stops when ``end_flag`` is set or the other end of the pipe is closed.
        Arguments:
            - p (:obj:`tm.multiprocessing.connection`): Parent connection.
            - c (:obj:`tm.multiprocessing.connection`): Child connection.
        """
        c.close()  # Close unused c, only use p
        while not self.end_flag:
            if not p.poll(timeout=0.2):
                time.sleep(0.01)
                continue
            try:
                cmd = p.recv()
            except EOFError:
                break
            if cmd == 'get_data':
                # Main worker asks for data.
                data = self.data_source(self.batch_size)
                # ``data`` can be callable, e.g. a function to read data from file, therefore we can divide
                # this job to pieces, assign to every slave worker and accomplish jobs asynchronously.
                # But if we get a list of dicts, which means the data has already been processed and
                # can be used directly, we can put it directly in async_train_queue and wait it
                # to be accessed by a user, e.g. learner.
                if isinstance(data[0], dict):
                    data = self.collate_fn(data)
                    self.async_train_queue.put(data)
                    p.send('pass')
                else:
                    p.send(data)
        p.close()

    def _async_loop(self, p: tm.multiprocessing.connection, c: tm.multiprocessing.connection) -> None:
        """
        Overview:
            Main worker process. Run through ``self.async_process``.
            Firstly, get data from ``self.get_data_thread``.
            If multiple workers, put data in ``self.job_queue`` for further multiprocessing operation;
            If only one worker, process data and put directly into ``self.async_train_queue``.
        Arguments:
            - p (:obj:`tm.multiprocessing.connection`): Parent connection.
            - c (:obj:`tm.multiprocessing.connection`): Child connection.
        """
        torch.set_num_threads(1)
        p.close()  # Close unused p, only use c
        while not self.end_flag:
            if self.num_workers > 1:
                # Multiple workers: Put jobs (chunked data) into job_queue
                if self.job_queue.full():
                    time.sleep(0.001)
                else:
                    # Get data from ``_get_data`` thread.
                    c.send('get_data')
                    data = c.recv()
                    if isinstance(data, str) and data == 'pass':
                        continue
                    # Get data to be processed, chunk it into pieces and put them into job_queue.
                    chunk_num = self.batch_size // self.chunk_size
                    with self.batch_id.get_lock():
                        for i in range(chunk_num):
                            start, end = i * self.chunk_size, (i + 1) * self.chunk_size
                            self.job_queue.put({'batch_id': self.batch_id.value, 'job': data[start:end]})
                        self.batch_id.value = (self.batch_id.value + 1) % self.queue_maxsize  # Increment batch_id
                    time.sleep(0.001)
            else:
                # Only one worker: Process data and directly put it into async_train_queue
                if self.async_train_queue.full():
                    time.sleep(0.001)
                else:
                    c.send('get_data')
                    data = c.recv()
                    if isinstance(data, str) and data == 'pass':
                        continue
                    data = [fn() for fn in data]  # Implement functions in list ``data``.
                    data = self.collate_fn(data)
                    self.async_train_queue.put(data)
        c.close()

    def _worker_loop(self) -> None:
        """
        Overview:
            Worker process. Run through each element in list ``self.worker``.
            Get data job from ``self.job_queue``, process it and then put into ``self.async_train_queue``.
            Only function when ``self.num_workers`` > 1, which means using multiprocessing.
            Finished batches are put in ``batch_id`` order, tracked by the shared ``self.cur_batch`` counter.
        """
        while not self.end_flag:
            if self.job_queue.empty() or self.async_train_queue.full():
                # No left job to be done, or finished job have no space to store.
                time.sleep(0.01)
                continue
            else:
                try:
                    element = self.job_queue.get()
                except (ConnectionResetError, ConnectionRefusedError):
                    break
                batch_id, job = element['batch_id'], element['job']
                # Process the assigned data.
                data = [fn() for fn in job]  # Only function-type job will arrive here, dict-type will not
                if len(data) == self.batch_size == self.chunk_size:
                    # Data not chunked: Finish the assigned one means finishing a whole batch.
                    data = self.collate_fn(data)
                    while batch_id != self.cur_batch.value:
                        time.sleep(0.01)
                    self.async_train_queue.put(data)
                    # Directly update cur_batch, since a whole batch is finished
                    with self.cur_batch.get_lock():
                        self.cur_batch.value = (self.cur_batch.value + 1) % self.queue_maxsize
                else:
                    # Data chunked: Must wait for all chunked pieces in a batch to be accomplished.
                    finish_flag = False  # indicate whether a whole batch is accomplished
                    with self.job_result_lock:
                        if batch_id not in self.job_result:
                            # The first one in a batch
                            self.job_result[batch_id] = data
                        elif len(self.job_result[batch_id]) + len(data) == self.batch_size:
                            # The last one in a batch
                            data += self.job_result.pop(batch_id)
                            assert batch_id not in self.job_result
                            finish_flag = True
                        else:
                            # Middle pieces in a batch
                            self.job_result[batch_id] += data
                    if finish_flag:
                        data = self.collate_fn(data)
                        while batch_id != self.cur_batch.value:
                            time.sleep(0.01)
                        self.async_train_queue.put(data)
                        with self.cur_batch.get_lock():
                            self.cur_batch.value = (self.cur_batch.value + 1) % self.queue_maxsize
        # If ``self.end_flag`` is True, clear and close job_queue, because _worker_loop gets jobs from job_queue.
        while not self.job_queue.empty():
            try:
                _ = self.job_queue.get()
            except Exception:
                break
        self.job_queue.close()
        self.job_queue.join_thread()

    def _cuda_loop(self) -> None:
        """
        Overview:
            Only when using cuda, would this be run as a thread through ``self.cuda_thread``.
            Get data from ``self.async_train_queue``, change its device and put it into ``self.cuda_queue``
        """
        with torch.cuda.stream(self.stream):
            while not self.end_flag:
                if self.async_train_queue.empty() or self.cuda_queue.full():
                    time.sleep(0.01)
                else:
                    data = self.async_train_queue.get()
                    data = to_device(data, self.device)
                    self.cuda_queue.put(data)
        # If ``self.end_flag`` is True, clear and close async_train_queue,
        # because _cuda_loop gets data from async_train_queue.
        while not self.async_train_queue.empty():
            _ = self.async_train_queue.get()
        self.async_train_queue.close()
        self.async_train_queue.join_thread()

    def __next__(self) -> Any:
        """
        Overview:
            Return next data in the iterator. If use cuda, get from ``self.cuda_queue``;
            Otherwise, get from ``self.async_train_queue``. Polls every 10ms until data is available or the \
            dataloader is closed.
            Once the dataloader is closed, the user-facing queue is drained and closed (only on the first such \
            call) and ``StopIteration`` is raised, as the iterator protocol requires, instead of implicitly \
            returning ``None``.
        Returns:
            - data (:obj:`Any`): Next collated batch (already moved to ``device`` when using cuda).
        Raises:
            - StopIteration: If the dataloader has been closed.
        """
        while not self.end_flag:
            if self.use_cuda:
                if self.cuda_queue.empty():
                    time.sleep(0.01)
                else:
                    data = self.cuda_queue.get(timeout=60)
                    self.cuda_queue.task_done()
                    return data
            else:
                if self.async_train_queue.empty():
                    time.sleep(0.01)
                else:
                    return self.async_train_queue.get()
        # If ``self.end_flag`` is True, clear and close either 1) or 2), only once, since a closed
        # ``multiprocessing.Queue`` must not be used again:
        # 1) cuda_queue. Because user get data from cuda_queue, and async_train_queue is closed by cuda_loop.
        # 2) async_train_queue. Because user get data from async_train_queue.
        if not self._user_queue_drained:
            self._user_queue_drained = True
            if self.use_cuda:
                while not self.cuda_queue.empty():
                    _ = self.cuda_queue.get()
                    self.cuda_queue.task_done()
                self.cuda_queue.join()
            else:
                while not self.async_train_queue.empty():
                    _ = self.async_train_queue.get()
                self.async_train_queue.close()
                self.async_train_queue.join_thread()
        raise StopIteration

    def __del__(self) -> None:
        """
        Overview:
            Delete this dataloader.
        """
        self.close()

    def close(self) -> None:
        """
        Overview:
            Delete this dataloader. First set ``end_flag`` to True. The threads of this process \
            (``get_data_thread``, ``cuda_thread``) and ``__next__`` observe it and clear/close their queues. \
            Child processes hold their own copy of the flag, so they are terminated and joined instead. \
            Finally the ``multiprocessing`` manager process is shut down.
            Calling it more than once is a no-op. It is also safe on an instance whose ``__init__`` raised.
        """
        if getattr(self, 'end_flag', True):
            # Already closed, or ``__init__`` failed before ``end_flag`` was even set.
            return
        self.end_flag = True
        # ``getattr`` guards: ``__init__`` may have raised before some of these were created.
        async_process = getattr(self, 'async_process', None)
        if async_process is not None:
            async_process.terminate()
            async_process.join()
        # ``self.worker`` only exists when ``num_workers > 1``.
        for w in getattr(self, 'worker', []):
            w.terminate()
            w.join()
        # Without this, the manager's server process created in ``__init__`` outlives the dataloader.
        manager = getattr(self, 'manager', None)
        if manager is not None:
            manager.shutdown()
        print('Del AsyncDataLoader')