torchnet / dataset_test.py
milselarch's picture
push to main
df07554
raw
history blame contribute delete
991 Bytes
import options as opt
import matplotlib.pyplot as plt
import torch.optim as optim
import numpy as np
import time
from dataset import GridDataset
from torch.utils.data import DataLoader
def dataset2dataloader(
    dataset, num_workers=opt.num_workers, shuffle=True
):
    """Wrap *dataset* in a ``DataLoader`` configured from ``opt``.

    Args:
        dataset: The dataset object handed to ``DataLoader``.
        num_workers: Number of loader worker processes; defaults to
            ``opt.num_workers`` (read once, when this function is defined).
        shuffle: Forwarded to ``DataLoader`` as its ``shuffle`` flag.

    Returns:
        A ``DataLoader`` using ``opt.batch_size`` with ``drop_last=False``.
    """
    loader_settings = {
        'batch_size': opt.batch_size,
        'shuffle': shuffle,
        'num_workers': num_workers,
        # Keep the final batch even when it is smaller than batch_size.
        'drop_last': False,
    }
    return DataLoader(dataset, **loader_settings)
# Build the 'train'-phase dataset; every path and padding value comes from
# the project's options module.
dataset = GridDataset(
    video_path=opt.video_path,
    alignments_dir=opt.alignments_dir,
    file_list=opt.train_list,
    vid_pad=opt.vid_padding,
    image_dir=opt.images_dir,
    txt_pad=opt.txt_padding,
    phase='train',
)

# Module-level loader over that dataset, using dataset2dataloader's defaults
# (opt.num_workers workers, shuffling enabled).
loader = dataset2dataloader(dataset)
def fetch_samples(num_samples=10, data_loader=None):
    """Collect up to *num_samples* batches from a loader.

    Args:
        num_samples: Maximum number of batches to collect. Zero or a
            negative value yields an empty list.
        data_loader: Iterable to draw batches from. When ``None`` (the
            default), the module-level ``loader`` is used.

    Returns:
        A list of the first batches yielded, in iteration order. It is
        shorter than *num_samples* if the iterable is exhausted first.
    """
    from itertools import islice

    source = loader if data_loader is None else data_loader
    # islice stops right after the requested count without pulling an extra
    # item, and it yields nothing for a count of 0 (a check made only after
    # appending would always return at least one batch). Negative counts
    # are clamped because islice rejects them.
    return list(islice(source, max(num_samples, 0)))
# Smoke test: pull the default number of batches, then show the first batch
# followed by an end marker, each on its own line.
samples = fetch_samples()
print(samples[0], 'END', sep='\n')