Spaces:
Sleeping
Sleeping
from torchvision.transforms import Compose | |
import torch | |
from torch.utils.data import Dataset | |
from torchvision.datasets.folder import default_loader | |
import polars as pl | |
class FrameDataset(Dataset): | |
def __init__( | |
self, | |
df: pl.DataFrame, | |
augments: Compose, | |
frames_per_clip: int, | |
stride: int | None = None, | |
is_train: bool = True, | |
): | |
super().__init__() | |
self.paths = df["path"].to_list() | |
self.is_train = is_train | |
if is_train: | |
self.y = torch.tensor(df["label"]) | |
self.frames_per_clip = frames_per_clip | |
self.augments = augments | |
self.stride = stride or frames_per_clip | |
def __len__(self): | |
return len(self.paths) // self.stride | |
def __getitem__(self, idx): | |
start = idx * self.stride | |
stop = start + self.frames_per_clip | |
if stop - start <= 1: | |
path = self.paths[start] | |
frames_tr = self._open_augment_img(path) | |
if self.is_train: | |
y = self.y[start] | |
else: | |
frames = [self._open_augment_img(path) for path in self.paths[start:stop]] | |
frames_tr = torch.stack(frames) | |
if self.is_train: | |
y = self.y[start:stop].max() | |
if self.is_train: | |
return frames_tr, y | |
else: | |
return frames_tr | |
def _open_augment_img(self, path): | |
img = default_loader(path) | |
img = self.augments(img) | |
return img | |