import time
from queue import Queue
from threading import Event, Thread
from typing import TYPE_CHECKING

import numpy as np
import torch
from easydict import EasyDict

from ding.data import DataLoader, Dataset
from ding.framework import task
from ding.utils import get_rank, get_world_size

if TYPE_CHECKING:
    from ding.framework import OfflineRLContext
class OfflineMemoryDataFetcher:
    """Middleware that prefetches offline-RL training batches on a background thread.

    A producer thread draws shuffled dataset indices, stacks ``cfg.policy.batch_size``
    samples into a batch of tensors (moved to GPU via a dedicated CUDA stream when
    ``cfg.policy.cuda`` is set) and pushes it onto a bounded queue. Each ``__call__``
    pops one ready batch into ``ctx.train_data``.
    """

    def __new__(cls, *args, **kwargs):
        # In a distributed task graph, only processes holding the FETCHER role
        # instantiate this middleware; every other process gets a no-op placeholder.
        if task.router.is_active and not task.has_role(task.role.FETCHER):
            return task.void()
        return super(OfflineMemoryDataFetcher, cls).__new__(cls)

    def __init__(self, cfg: EasyDict, dataset: Dataset):
        """Build (but do not start) the producer thread.

        Arguments:
            - cfg: config; reads ``cfg.policy.cuda`` and ``cfg.policy.batch_size``.
            - dataset: map-style dataset whose ``__getitem__`` returns a tuple/list
              of tensors with a uniform layout across samples.
        """
        device = 'cuda:{}'.format(get_rank() % torch.cuda.device_count()) if cfg.policy.cuda else 'cpu'
        if device != 'cpu':
            # Dedicated stream so host-to-device copies can overlap with training.
            stream = torch.cuda.Stream()

        def producer(queue, dataset, batch_size, device, event):
            torch.set_num_threads(4)
            sbatch_size = batch_size * get_world_size()
            rank = get_rank()
            idx_list = np.random.permutation(len(dataset))
            temp_idx_list = []
            # Partition each global batch of ``sbatch_size`` indices among ranks:
            # rank r takes the r-th window of ``batch_size`` indices per stride.
            # BUGFIX: the slice start previously advanced by ``i`` per iteration
            # instead of ``i * sbatch_size``, yielding overlapping batches.
            for i in range(len(dataset) // sbatch_size):
                start = i * sbatch_size + rank * batch_size
                temp_idx_list.extend(idx_list[start:start + batch_size])
            idx_iter = iter(temp_idx_list)

            def next_batch():
                # Assemble one batch, reshuffling the whole dataset when the
                # current index stream runs dry.
                nonlocal idx_iter
                data = []
                for _ in range(batch_size):
                    try:
                        data.append(dataset.__getitem__(next(idx_iter)))
                    except StopIteration:
                        idx_iter = iter(np.random.permutation(len(dataset)))
                        data.append(dataset.__getitem__(next(idx_iter)))
                # Transpose list-of-samples into list-of-fields, then stack each field.
                data = [[sample[j] for sample in data] for j in range(len(data[0]))]
                if device != 'cpu':
                    return [torch.stack(x).to(device) for x in data]
                return [torch.stack(x) for x in data]

            def run():
                # BUGFIX: check ``event`` on every pass (not only after a put),
                # so shutdown is not blocked by a consumer that stopped draining.
                while not event.is_set():
                    if queue.full():
                        time.sleep(0.1)
                    else:
                        queue.put(next_batch())

            if device != 'cpu':
                with torch.cuda.stream(stream):
                    run()
            else:
                run()

        self.queue = Queue(maxsize=50)
        self.event = Event()
        self.producer_thread = Thread(
            target=producer,
            args=(self.queue, dataset, cfg.policy.batch_size, device, self.event),
            name='cuda_fetcher_producer',
            daemon=True  # never keep the interpreter alive because of the fetcher
        )

    def __call__(self, ctx: "OfflineRLContext"):
        """Pop one prefetched batch into ``ctx.train_data``.

        Lazily starts the producer thread on first use, then blocks (in 1 ms
        polls) until at least one batch is available.
        """
        if not self.producer_thread.is_alive():
            time.sleep(5)
            self.producer_thread.start()
        while self.queue.empty():
            time.sleep(0.001)
        ctx.train_data = self.queue.get()

    def __del__(self):
        # Signal the producer to exit; it re-checks the event on every loop pass.
        if self.producer_thread.is_alive():
            self.event.set()
            del self.queue