import os
import random
from pathlib import Path

import cv2
import numpy as np
from PIL import Image

import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Dataset

class VideoDataset(Dataset):
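    """Folder-per-class video dataset.

    Expects each directory in ``directory_list`` to contain one subfolder per
    label, each holding video files. ``__getitem__`` returns a float32 clip of
    shape ``(3, clip_len, crop_size, crop_size)`` plus an integer label index.
    """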
    def __init__(self, directory_list, local_rank=0, enable_GPUs_num=1,
                 distributed_load=False, resize_shape=[224, 224],
                 mode='train', clip_len=32, crop_size=168):

        self.clip_len, self.crop_size, self.resize_shape = clip_len, crop_size, resize_shape
        self.mode = mode

        self.fnames, labels = [], []

        for directory in directory_list:
            folder = Path(directory)
            print("Loading dataset from folder:", folder)
            for label in sorted(os.listdir(folder)):
                # Outside of training, keep only the first ten clips per class.
                fnames = sorted(os.listdir(os.path.join(folder, label)))
                if mode != 'train':
                    fnames = fnames[:10]
                for fname in fnames:
                    self.fnames.append(os.path.join(folder, label, fname))
                    labels.append(label)

        # Shuffle file names and labels together so the pairs stay aligned.
        random_list = list(zip(self.fnames, labels))
        random.shuffle(random_list)
        self.fnames[:], labels[:] = zip(*random_list)

        # Under distributed training, each rank keeps its own contiguous shard.
        if mode == 'train' and distributed_load:
            single_num_ = len(self.fnames) // enable_GPUs_num
            self.fnames = self.fnames[local_rank * single_num_:(local_rank + 1) * single_num_]
            labels = labels[local_rank * single_num_:(local_rank + 1) * single_num_]

        self.labels = labels
        self.label2index = {label: index for index, label in enumerate(sorted(set(labels)))}
        self.label_array = np.array([self.label2index[label] for label in labels], dtype=int)

    def __getitem__(self, index):
        buffer = self.loadvideo(self.fnames[index])

        # Random spatial crop; buffer is (C, T, H, W) at this point.
        height_index = np.random.randint(buffer.shape[2] - self.crop_size)
        width_index = np.random.randint(buffer.shape[3] - self.crop_size)

        return buffer[:, :, height_index:height_index + self.crop_size,
                      width_index:width_index + self.crop_size], self.label_array[index]

    def __len__(self):
        return len(self.fnames)

    def loadvideo(self, fname):
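        """Decode ``clip_len`` frames from ``fname``, resampling a random
        other video if the file is unreadable or too short, and return a
        normalized float32 array of shape ``(3, clip_len, H, W)``."""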
        self.transform = transforms.Compose([
            transforms.Resize([self.resize_shape[0], self.resize_shape[1]]),
            transforms.ToTensor(),
            transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
        ])

        # Flip half of the training clips. cv2.flip treats flipCode=0 as a
        # vertical flip, so a separate boolean is needed to mean "no flip".
        if self.mode == 'train' and np.random.random() < 0.5:
            flip, flipCode = True, random.choice([-1, 0, 1])
        else:
            flip, flipCode = False, 0

        # cv2.VideoCapture does not raise on unreadable files; it just reports
        # zero frames, so broken or too-short videos are resampled here.
        video_stream = cv2.VideoCapture(fname)
        frame_count = int(video_stream.get(cv2.CAP_PROP_FRAME_COUNT))

        while frame_count < self.clip_len + 2:
            video_stream.release()
            index = np.random.randint(self.__len__())
            video_stream = cv2.VideoCapture(self.fnames[index])
            frame_count = int(video_stream.get(cv2.CAP_PROP_FRAME_COUNT))

        # Randomly read every frame or every second frame when there is room.
        speed_rate = np.random.randint(1, 3) if frame_count > self.clip_len * 2 + 2 else 1
        time_index = np.random.randint(frame_count - self.clip_len * speed_rate)

        start_idx, end_idx = time_index, time_index + self.clip_len * speed_rate
        count, sample_count, retaining = 0, 0, True

        buffer = np.empty((self.clip_len, 3, self.resize_shape[0], self.resize_shape[1]), np.dtype('float32'))

        while count <= end_idx and retaining:
            retaining, frame = video_stream.read()
            if frame is None:
                break
            if count < start_idx:
                count += 1
                continue
            if count % speed_rate == speed_rate - 1 and sample_count < self.clip_len:
                if flip:
                    frame = cv2.flip(frame, flipCode=flipCode)
                buffer[sample_count] = self.transform(Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)))
                sample_count += 1
            count += 1
        video_stream.release()

        # (T, C, H, W) -> (C, T, H, W)
        return buffer.transpose((1, 0, 2, 3))

if __name__ == '__main__':

    datapath = ['/data/datasets/ucf101/videos']

    dataset = VideoDataset(datapath,
                           resize_shape=[224, 224],
                           mode='validation')
    x, y = dataset[0]
    print(x.shape, y)
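
    # A minimal batch-loading sketch using the DataLoader imported above;
    # batch_size and num_workers are illustrative values, not tuned settings.
    loader = DataLoader(dataset, batch_size=8, shuffle=True, num_workers=4)
    for clips, targets in loader:
        # clips: (batch, 3, clip_len, crop_size, crop_size), targets: (batch,)
        print(clips.shape, targets.shape)
        break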