from pathlib import Path

import numpy as np
import pandas as pd
import torch
from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms
from torchvision.transforms.functional import InterpolationMode

from src.data.utils import pre_caption
from src.tools.files import read_txt
# CLIP image normalization statistics (per-channel RGB means and stds).
normalize = transforms.Normalize(
    (0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)
)

transform = transforms.Compose(
    [
        transforms.Resize((384, 384), interpolation=InterpolationMode.BICUBIC),
        transforms.ToTensor(),
        normalize,
    ]
)
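
# Quick sanity check of the preprocessing pipeline (the file name here is
# purely illustrative):
#
#     img = Image.open("example.png").convert("RGB")
#     x = transform(img)  # -> torch.Size([3, 384, 384])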
class ImageDataset(Dataset):
    """Dataset of single images, indexed by file stem (used as the video id)."""

    def __init__(
        self,
        image_dir,
        img_ext: str = "png",
        save_dir=None,
    ):
        self.image_dir = Path(image_dir)
        self.img_pths = sorted(self.image_dir.glob(f"*.{img_ext}"))
        self.id2pth = {img_pth.stem: img_pth for img_pth in self.img_pths}
        self.video_ids = list(self.id2pth.keys())

        # Skip images whose features were already saved as <save_dir>/<id>.pth.
        if save_dir is not None:
            save_dir = Path(save_dir)
            done_ids = {p.stem for p in save_dir.glob("*.pth")}
            print(f"video_ids: {len(self.video_ids)} - {len(done_ids)} = ", end="")
            self.video_ids = list(set(self.video_ids) - done_ids)
            print(len(self.video_ids))
            self.video_ids.sort()
            if len(self.video_ids) == 0:
                print("All videos are done")
                exit()
    def __len__(self):
        return len(self.video_ids)

    def __getitem__(self, index):
        video_id = self.video_ids[index]
        img_pth = self.id2pth[video_id]
        img = Image.open(img_pth).convert("RGB")
        img = transform(img)
        return img, video_id
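
# A minimal usage sketch for ImageDataset, assuming a directory of
# pre-extracted keyframes (both paths below are illustrative):
#
#     from torch.utils.data import DataLoader
#
#     ds = ImageDataset("keyframes/", img_ext="png", save_dir="feats/")
#     loader = DataLoader(ds, batch_size=32, num_workers=4)
#     for imgs, video_ids in loader:
#         ...  # imgs: (B, 3, 384, 384)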
class VideoDataset(Dataset):
    """Dataset of videos laid out as <video_dir>/<subdir>/<name>.<ext>;
    video ids are "<subdir>/<name>"."""

    def __init__(
        self,
        video_dir,
        todo_ids=None,
        shard_id=0,
        num_shards=1,
        frames_video=15,
        extension="mp4",
        save_dir=None,
    ):
        self.video_dir = Path(video_dir)

        # todo_ids can be given as a path to a text file with one id per line.
        if isinstance(todo_ids, (str, Path)):
            todo_ids = read_txt(todo_ids)

        found_paths = list(self.video_dir.glob(f"*/*.{extension}"))
        if todo_ids is not None:
            video_paths = [self.video_dir / f"{v}.{extension}" for v in todo_ids]
            video_paths = list(set(video_paths) & set(found_paths))
        else:
            video_paths = found_paths
        video_paths.sort()

        self.id2path = {pth.parent.name + "/" + pth.stem: pth for pth in video_paths}
        self.video_ids = list(self.id2path.keys())
        self.video_ids.sort()

        # Skip videos whose features were already saved as <save_dir>/<id>.pth.
        if save_dir is not None:
            save_dir = Path(save_dir)
            done_ids = {p.parent.name + "/" + p.stem for p in save_dir.glob("*/*.pth")}
            print(f"video_ids: {len(self.video_ids)} - {len(done_ids)} = ", end="")
            self.video_ids = list(set(self.video_ids) - done_ids)
            print(len(self.video_ids))
            self.video_ids.sort()
            if len(self.video_ids) == 0:
                print("All videos are done")
                exit()
        assert len(self.video_ids) > 0, "video_ids is empty"

        # Shard the id list so independent workers can split the extraction.
        n_videos = len(self.video_ids)
        self.video_ids = self.video_ids[
            shard_id * n_videos // num_shards : (shard_id + 1) * n_videos // num_shards
        ]
        self.frames_video = frames_video
    def __len__(self):
        return len(self.video_ids)

    def __getitem__(self, index):
        video_id = self.video_ids[index]
        video_path = self.id2path[video_id]
        frames, f_idxs = get_video_frames(video_path, self.frames_video)
        frames = [transform(frame) for frame in frames]
        frames = torch.stack(frames, dim=0)  # (frames_video, 3, 384, 384)
        f_idxs = torch.tensor(f_idxs)
        return video_id, f_idxs, frames
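
# Sharding sketch: with num_shards=4, each process takes a contiguous quarter
# of the sorted id list, so four jobs cover the corpus exactly once
# (the paths below are illustrative):
#
#     ds = VideoDataset("videos/", shard_id=0, num_shards=4, save_dir="feats/")
#     video_id, f_idxs, frames = ds[0]
#     # frames: (15, 3, 384, 384); f_idxs[i] == -1 marks a padded frame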
class TextDataset(Dataset):
    """Dataset of the unique edit captions from a CSV file."""

    def __init__(
        self,
        csv_path,
        max_words=30,
    ):
        self.df = pd.read_csv(csv_path)
        # .unique() already deduplicates; sort for a deterministic order.
        self.texts = sorted(self.df["edit"].unique().tolist())
        self.max_words = max_words

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, index):
        txt = self.texts[index]
        txt = pre_caption(txt, self.max_words)
        return txt
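
# Usage sketch, assuming a CSV with an "edit" column as read above
# (the file name is illustrative):
#
#     ds = TextDataset("annotations.csv", max_words=30)
#     caption = ds[0]  # a cleaned, truncated caption string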
def get_video_frames(video_pth, frames_video=15):
    # cv2 is imported lazily so image/text-only use does not require OpenCV.
    import cv2

    video_pth = str(video_pth)
    # Use OpenCV to read the video.
    cap = cv2.VideoCapture(video_pth)
    # Total number of frames in the video.
    total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    frame_idxs = sample_frames(total_frames, frames_video)
    frames = []
    f_idxs = []
    for frame_idx in frame_idxs:
        cap.set(cv2.CAP_PROP_POS_FRAMES, frame_idx)
        ret, frame = cap.read()
        if not ret or frame is None:
            # Corrupted video: return all-black frames and -1 indices.
            print(f"Video {video_pth} is corrupted")
            frames = [
                Image.fromarray(np.zeros((384, 384, 3), dtype=np.uint8))
            ] * frames_video
            f_idxs = [-1] * frames_video
            cap.release()
            return frames, f_idxs
        # OpenCV decodes to BGR; convert to RGB before building the PIL image.
        frames.append(Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)))
        f_idxs.append(frame_idx)
    cap.release()
    # Pad short videos with black frames so every clip has frames_video frames,
    # marking padded positions with -1 in f_idxs.
    n_frames = len(frames)
    if n_frames < frames_video:
        frames += [Image.fromarray(np.zeros((384, 384, 3), dtype=np.uint8))] * (
            frames_video - n_frames
        )
        f_idxs += [-1] * (frames_video - len(f_idxs))
    return frames, f_idxs
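
# Example (hypothetical file): read 15 evenly spaced frames; a clip with fewer
# than 15 frames comes back padded with black images and f_idx == -1:
#
#     frames, f_idxs = get_video_frames("clip.mp4", frames_video=15)
#     assert len(frames) == 15 and len(f_idxs) == 15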
def sample_frames(vlen, frames_per_video=15):
    # Split [0, vlen) into equal intervals and take the midpoint of each,
    # so the sampled frames are spread evenly over the whole video.
    acc_samples = min(vlen, frames_per_video)
    intervals = np.linspace(start=0, stop=vlen, num=acc_samples + 1).astype(int)
    ranges = []
    for idx, interv in enumerate(intervals[:-1]):
        ranges.append((interv, intervals[idx + 1] - 1))
    frame_idxs = [(x[0] + x[1]) // 2 for x in ranges]
    return frame_idxs
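
# Worked example: sample_frames(100, 5) splits [0, 100) into five intervals
# [0, 19], [20, 39], ..., [80, 99] and returns their midpoints:
#
#     >>> sample_frames(100, 5)
#     [9, 29, 49, 69, 89]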