Spaces:
Runtime error
Runtime error
File size: 2,298 Bytes
7734d5b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 |
#!/usr/bin/env python3
# -*- coding:utf-8 -*-
# Copyright (c) Megvii, Inc. and its affiliates.
import torch
import torch.distributed as dist
from yolox.utils import synchronize
import random
class DataPrefetcher:
    """
    Prefetch batches from a dataloader onto the GPU using a side CUDA stream.

    DataPrefetcher is inspired by the code in the following file:
    https://github.com/NVIDIA/apex/blob/master/examples/imagenet/main_amp.py
    It can speed up your PyTorch dataloader. For more information, please check
    https://github.com/NVIDIA/apex/issues/304#issuecomment-493562789.
    """

    def __init__(self, loader):
        """
        Wrap ``loader`` in an iterator, create the copy stream and preload
        the first batch.

        Args:
            loader: iterable whose items unpack into exactly four values;
                only the first two (input and target) are kept.
        """
        self.loader = iter(loader)
        self.stream = torch.cuda.Stream()
        # Indirection hooks for the input tensor: ``preload`` calls
        # ``input_cuda`` and ``next`` calls ``record_stream``.
        self.input_cuda = self._input_cuda_for_image
        self.record_stream = DataPrefetcher._record_stream_for_image
        self.preload()

    def preload(self):
        """
        Fetch the next batch and issue its transfer to the GPU on ``self.stream``.

        When the loader is exhausted, ``next_input`` and ``next_target`` are
        both set to ``None`` and no CUDA work is issued.
        """
        try:
            self.next_input, self.next_target, _, _ = next(self.loader)
        except StopIteration:
            self.next_input = None
            self.next_target = None
            return

        # Run the transfers on the side stream rather than on the current one.
        with torch.cuda.stream(self.stream):
            self.input_cuda()
            self.next_target = self.next_target.cuda(non_blocking=True)

    def next(self):
        """
        Return the preloaded batch and start loading the following one.

        Returns:
            tuple: ``(input, target)``. Both entries are ``None`` once the
            underlying loader is exhausted.
        """
        current = torch.cuda.current_stream()
        # The current stream must not touch the tensors before the side
        # stream has finished the work queued by ``preload``.
        current.wait_stream(self.stream)

        batch_input = self.next_input
        batch_target = self.next_target
        # Tensors were produced on ``self.stream`` but are consumed on the
        # current stream; register that use with ``record_stream``.
        if batch_input is not None:
            self.record_stream(batch_input)
        if batch_target is not None:
            batch_target.record_stream(current)

        self.preload()
        return batch_input, batch_target

    def _input_cuda_for_image(self):
        """Move ``self.next_input`` to the GPU with ``non_blocking=True``."""
        self.next_input = self.next_input.cuda(non_blocking=True)

    @staticmethod
    def _record_stream_for_image(input):
        """Mark ``input`` as in use by the current CUDA stream."""
        input.record_stream(torch.cuda.current_stream())
def random_resize(data_loader, exp, epoch, rank, is_distributed):
    """
    Choose a new input size on rank 0 and apply it to ``data_loader``.

    Args:
        data_loader: object exposing ``change_input_dim(multiple=..., random_range=...)``.
        exp: object providing ``max_epoch``, ``input_size`` and ``random_size``
            (``random_size`` is unpacked as the two bounds of ``random.randint``).
        epoch: current epoch, compared against ``exp.max_epoch - 10``.
        rank: rank of the calling process; only rank 0 picks the size.
        is_distributed: when true, the chosen size is broadcast from rank 0.

    Returns:
        Whatever ``data_loader.change_input_dim`` returns.
    """
    # One-element CUDA buffer that carries the size between processes.
    size_buf = torch.LongTensor(1).cuda()

    if is_distributed:
        synchronize()

    if rank == 0:
        # NOTE(review): indentation was missing in the source; the x32 scaling
        # is assumed to apply only to the random draw - confirm.
        chosen = (
            exp.input_size
            if epoch > exp.max_epoch - 10
            else int(32 * random.randint(*exp.random_size))
        )
        size_buf.fill_(chosen)

    if is_distributed:
        synchronize()
        # Every process ends up with the value written by rank 0.
        dist.broadcast(size_buf, 0)

    return data_loader.change_input_dim(multiple=size_buf.item(), random_range=None)
|