Spaces:
Runtime error
Runtime error
#!/usr/bin/env python3
# -*- coding:utf-8 -*-
# Copyright (c) Megvii, Inc. and its affiliates.
import torch | |
class DataPrefetcher:
    """
    Prefetch the next batch of a loader onto the GPU using a side CUDA stream.

    DataPrefetcher is inspired by code of following file:
    https://github.com/NVIDIA/apex/blob/master/examples/imagenet/main_amp.py
    It could speedup your pytorch dataloader. For more information, please check
    https://github.com/NVIDIA/apex/issues/304#issuecomment-493562789.

    Usage: construct it around a loader, then call ``next()`` repeatedly;
    ``(None, None)`` is returned once the loader is exhausted.
    """

    def __init__(self, loader):
        """Wrap ``loader`` and immediately start fetching its first batch.

        Args:
            loader: an iterable whose items unpack into four values; only the
                first two (input, target) are used. They are presumably
                tensors, since ``.cuda()`` is called on both -- confirm
                against the dataset in use.
        """
        self.loader = iter(loader)
        # Dedicated stream on which host-to-device copies are enqueued.
        self.stream = torch.cuda.Stream()
        # Indirection hooks: how to move the input to the GPU and how to
        # record its stream usage.
        self.input_cuda = self._input_cuda_for_image
        self.record_stream = DataPrefetcher._record_stream_for_image
        self.preload()

    def preload(self):
        """Fetch the next batch and enqueue its copy to the GPU.

        Sets ``next_input`` and ``next_target``; both become ``None`` when
        the underlying iterator is exhausted.
        """
        try:
            self.next_input, self.next_target, _, _ = next(self.loader)
        except StopIteration:
            self.next_input = None
            self.next_target = None
            return

        # Issue the copies on the side stream rather than the current one.
        with torch.cuda.stream(self.stream):
            self.input_cuda()
            self.next_target = self.next_target.cuda(non_blocking=True)

    def next(self):
        """Return the prefetched ``(input, target)`` and start the next fetch.

        Returns:
            A tuple ``(input, target)``; ``(None, None)`` once the loader is
            exhausted.
        """
        # The current stream must not consume the batch before the copies
        # enqueued on the side stream have finished.
        torch.cuda.current_stream().wait_stream(self.stream)
        batch_input = self.next_input
        target = self.next_target
        # The tensors were produced on the side stream but will be consumed
        # on the current stream, so register that usage on each of them.
        if batch_input is not None:
            self.record_stream(batch_input)
        if target is not None:
            target.record_stream(torch.cuda.current_stream())
        self.preload()
        return batch_input, target

    def _input_cuda_for_image(self):
        """Move ``self.next_input`` to the GPU with a non-blocking copy."""
        self.next_input = self.next_input.cuda(non_blocking=True)

    @staticmethod
    def _record_stream_for_image(input):
        """Record that ``input`` is used on the current CUDA stream.

        Declared static because it takes no ``self``; without the decorator
        a call through an instance would bind the instance as ``input``.
        """
        input.record_stream(torch.cuda.current_stream())