lunde's picture
Initial commit
bd65e34
from torchvision.transforms import Compose
import torch
from torch.utils.data import Dataset
from torchvision.datasets.folder import default_loader
import polars as pl
class FrameDataset(Dataset):
def __init__(
self,
df: pl.DataFrame,
augments: Compose,
frames_per_clip: int,
stride: int | None = None,
is_train: bool = True,
):
super().__init__()
self.paths = df["path"].to_list()
self.is_train = is_train
if is_train:
self.y = torch.tensor(df["label"])
self.frames_per_clip = frames_per_clip
self.augments = augments
self.stride = stride or frames_per_clip
def __len__(self):
return len(self.paths) // self.stride
def __getitem__(self, idx):
start = idx * self.stride
stop = start + self.frames_per_clip
if stop - start <= 1:
path = self.paths[start]
frames_tr = self._open_augment_img(path)
if self.is_train:
y = self.y[start]
else:
frames = [self._open_augment_img(path) for path in self.paths[start:stop]]
frames_tr = torch.stack(frames)
if self.is_train:
y = self.y[start:stop].max()
if self.is_train:
return frames_tr, y
else:
return frames_tr
def _open_augment_img(self, path):
img = default_loader(path)
img = self.augments(img)
return img