|
import options as opt |
|
import matplotlib.pyplot as plt |
|
import torch.optim as optim |
|
import numpy as np |
|
import time |
|
|
|
from dataset import GridDataset |
|
from torch.utils.data import DataLoader |
|
|
|
|
|
def dataset2dataloader(
    dataset, num_workers=None, shuffle=True, batch_size=None, drop_last=False
):
    """Wrap *dataset* in a ``torch.utils.data.DataLoader``.

    Args:
        dataset: Dataset object passed straight through to ``DataLoader``.
        num_workers: Number of loader worker processes; ``None`` means
            ``opt.num_workers``.
        shuffle: Passed to ``DataLoader``; ``True`` reshuffles every epoch.
        batch_size: Samples per batch; ``None`` means ``opt.batch_size``.
        drop_last: When ``True``, the final incomplete batch is dropped.

    Returns:
        The configured ``DataLoader``.
    """
    # Resolve option defaults at call time (not at import time) so both
    # settings are read from ``opt`` the same way.
    if num_workers is None:
        num_workers = opt.num_workers
    if batch_size is None:
        batch_size = opt.batch_size
    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
        drop_last=drop_last,
    )
|
|
|
|
|
# Training-split dataset built from the paths and padding lengths in
# ``opt``, plus the module-level loader that ``fetch_samples`` reads from.
dataset = GridDataset(
    phase='train',
    file_list=opt.train_list,
    # Input locations.
    video_path=opt.video_path,
    image_dir=opt.images_dir,
    alignments_dir=opt.alignments_dir,
    # Padding lengths for the video and text sides.
    vid_pad=opt.vid_padding,
    txt_pad=opt.txt_padding,
)

loader = dataset2dataloader(dataset)
|
|
|
|
|
def fetch_samples(num_samples=10, source=None):
    """Collect the first *num_samples* batches yielded by a data loader.

    Args:
        num_samples: Maximum number of batches to collect. Values <= 0
            produce an empty list.
        source: Iterable to draw batches from; ``None`` means the
            module-level ``loader``.

    Returns:
        A list of at most ``num_samples`` batches, in iteration order.
        It is shorter if the source is exhausted first.
    """
    from itertools import islice

    iterable = loader if source is None else source
    # islice checks the count before pulling an item, so a count of zero
    # fetches nothing and no extra batch is consumed past the limit.
    return list(islice(iterable, max(0, num_samples)))
|
|
|
|
|
# Smoke test: pull a few batches from the training loader and show the first.
samples = fetch_samples()

if samples:
    print(samples[0])
else:
    # Fail with a readable message instead of an opaque IndexError.
    print('No samples were fetched: the training loader yielded nothing.')

print('END')