import re
from typing import Union

import torch


def pre_caption(caption, max_words=50):
    """Normalize a caption string.

    Lower-cases the text, replaces each of the characters ``. ! " ( ) * # : ; ~``
    with a space, collapses every run of two or more whitespace characters into a
    single space, strips surrounding newlines/spaces and keeps at most
    ``max_words`` space-separated words.

    Args:
        caption: Raw caption string.
        max_words: Maximum number of words to keep.

    Returns:
        The cleaned caption.
    """
    caption = re.sub(
        r"([.!\"()*#:;~])",
        " ",
        caption.lower(),
    )
    caption = re.sub(
        r"\s{2,}",
        " ",
        caption,
    )
    caption = caption.rstrip("\n")
    caption = caption.strip(" ")

    # truncate caption
    caption_words = caption.split(" ")
    if len(caption_words) > max_words:
        caption = " ".join(caption_words[:max_words])

    return caption


def id2int(data, sub=""):
    """Convert an id string, or a list of id strings, to int(s).

    Every non-digit character is replaced by ``sub`` before parsing, e.g.
    ``id2int("video_0012") == 12``.
    """
    if isinstance(data, list):
        return [remove_non_digits(d, sub) for d in data]
    return remove_non_digits(data, sub)


def remove_non_digits(string, sub: str = ""):
    """Replace every non-digit character of ``string`` with ``sub`` and parse as int.

    Raises:
        ValueError: if the resulting string is not a valid int (e.g. no digits).
    """
    return int(re.sub(r"\D", sub, string))


# Side length of the black placeholder image returned for missing/unreadable videos.
_BLANK_FRAME_SIZE = 384


def _blank_image():
    """Return a black RGB PIL image used as a placeholder for unreadable videos."""
    import numpy as np
    from PIL import Image

    return Image.fromarray(
        np.zeros((_BLANK_FRAME_SIZE, _BLANK_FRAME_SIZE, 3), dtype=np.uint8)
    )


def _read_single_frame(reference_vid_pth, choose_index):
    """Decode one frame of a video as an RGB PIL image.

    Args:
        reference_vid_pth: Path to the video (str or path-like).
        choose_index: Callable mapping the frame count reported by OpenCV to the
            0-based index of the frame to decode.

    Returns:
        The decoded frame, or a black placeholder image when the file does not
        exist or the frame cannot be read.
    """
    from pathlib import Path

    import cv2
    from PIL import Image

    reference_vid_pth = str(reference_vid_pth)
    if not Path(reference_vid_pth).exists():
        print(f"Video {reference_vid_pth} does not exist")
        return _blank_image()

    cap = cv2.VideoCapture(reference_vid_pth)
    try:
        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        cap.set(cv2.CAP_PROP_POS_FRAMES, choose_index(total_frames))
        ret, frame = cap.read()
    finally:
        # always release the capture handle, even if seeking/decoding raises
        cap.release()

    if not ret or frame is None:
        print(f"Video {reference_vid_pth} is corrupted")
        return _blank_image()

    # OpenCV decodes frames as BGR; PIL expects RGB
    return Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))


def get_middle_frame(reference_vid_pth):
    """Return the middle frame of a video as an RGB PIL image.

    A black 384x384 placeholder is returned if the video is missing or the
    frame cannot be decoded.
    """
    return _read_single_frame(reference_vid_pth, lambda total: total // 2)


def get_random_frame(reference_vid_pth):
    """Return a uniformly random frame of a video as an RGB PIL image.

    A black 384x384 placeholder is returned if the video is missing or the
    frame cannot be decoded.
    """
    import numpy as np

    # randint(0, 0) raises for videos reporting no frames; seek to 0 instead so
    # the failed read yields the placeholder rather than an exception.
    return _read_single_frame(
        reference_vid_pth, lambda total: np.random.randint(0, max(total, 1))
    )


def sample_frames(frames_videos, vlen):
    """Pick up to ``frames_videos`` evenly spaced 0-based frame indices out of ``vlen``.

    ``[0, vlen)`` is split into ``min(frames_videos, vlen)`` equal intervals and
    the midpoint of each interval is returned.

    Returns:
        List of ints; empty when ``vlen`` is 0, shorter than ``frames_videos``
        when the video has fewer frames than requested.

    Raises:
        ValueError: (from ``numpy.linspace``) if ``vlen`` is negative.
    """
    import numpy as np

    acc_samples = min(frames_videos, vlen)
    intervals = np.linspace(start=0, stop=vlen, num=acc_samples + 1).astype(int)
    # each interval covers [start, stop - 1]; take its (floored) midpoint
    return [
        int((start + stop - 1) // 2)
        for start, stop in zip(intervals[:-1], intervals[1:])
    ]


def _repeat_to_length(items, length):
    """Cycle ``items`` until exactly ``length`` entries remain; empty stays empty."""
    return (list(items) * length)[:length]


class FrameLoader:
    """Load frame(s) from a video file and apply ``transform`` to them.

    Methods:
        ``"middle"``: the middle frame (``frames_video`` must be 1).
        ``"random"``: a uniformly random frame (``frames_video`` must be 1).
        ``"sample"``: ``frames_video`` (> 1) evenly spaced frames, each
        transformed and then combined with ``torch.stack``.
    """

    def __init__(self, transform, frames_video=1, method="middle"):
        self.transform = transform
        self.method = method
        if method == "middle":
            self.get_frame = get_middle_frame
            assert frames_video == 1, "frames_video must be 1 for middle frame method"
        elif method == "random":
            self.get_frame = get_random_frame
            assert frames_video == 1, "frames_video must be 1 for random frame method"
        elif method == "sample":
            assert frames_video > 1, "frames_video must be > 1 for sample frame method"
            self.frames_video = frames_video
        else:
            raise ValueError(f"Invalid method: {method}")

    def __call__(self, video_pth: str):
        """Load and transform frame(s) from ``video_pth``.

        For ``"sample"`` the transformed frames are stacked with ``torch.stack``
        (presumably ``transform`` returns equally-shaped tensors — confirm with
        callers). Otherwise the single transformed frame is returned.
        """
        if self.method == "sample":
            frames = self.get_video_frames(video_pth, 0.0, None)
            return torch.stack(frames)
        return self.transform(self.get_frame(video_pth))

    def get_video_frames(
        self,
        video_pth: str,
        start_time: float = 0.0,
        end_time: Union[float, None] = None,
    ) -> list:
        """Decode ``self.frames_video`` evenly spaced frames and transform each.

        Args:
            video_pth: Path to the video file.
            start_time: Start of the sampling window, in seconds.
            end_time: End of the sampling window in seconds; ``None`` means the
                end of the video.

        Returns:
            A list of exactly ``self.frames_video`` transformed frames. When the
            window holds fewer frames, or some frames fail to decode, the frames
            that were read are cycled to fill the list.

        Raises:
            ValueError: if no frame could be decoded.
        """
        import cv2
        from PIL import Image

        cap = cv2.VideoCapture(video_pth)
        frames = []
        try:
            fps = cap.get(cv2.CAP_PROP_FPS)
            # honour start_time in both branches (only touch fps when needed)
            start_frame = int(fps * start_time) if start_time else 0
            if end_time is not None:
                end_frame = int(fps * end_time)
            else:
                end_frame = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
            # an inverted/out-of-range window yields no indices -> ValueError below
            vlen = max(end_frame - start_frame, 0)

            frame_idxs = [
                start_frame + idx for idx in sample_frames(self.frames_video, vlen)
            ]
            if self.frames_video != len(frame_idxs):
                frame_idxs = _repeat_to_length(frame_idxs, self.frames_video)
                print(f"Video {video_pth} has less than {self.frames_video} frames")

            for index in frame_idxs:
                # CAP_PROP_POS_FRAMES is the 0-based index of the next frame to decode
                cap.set(cv2.CAP_PROP_POS_FRAMES, index)
                ret, frame = cap.read()
                if not ret or frame is None:
                    break
                frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                frames.append(Image.fromarray(frame_rgb).convert("RGB"))
        finally:
            cap.release()

        if not frames:
            raise ValueError(f"video path: {video_pth} error.")

        if len(frames) != self.frames_video:
            # keep the output length fixed so torch.stack / batching stays consistent
            print(
                f"Video {video_pth}: only {len(frames)} of "
                f"{self.frames_video} frames decoded"
            )
            frames = _repeat_to_length(frames, self.frames_video)

        return [self.transform(frame) for frame in frames]