import os, sys
import json
import cv2
import math
import shutil
import numpy as np
import random
from PIL import Image
import torch.nn.functional as F
import torch
import os.path as osp
import time
from moviepy.editor import VideoFileClip
from torch.utils.data import Dataset

# Import files from the local folder
root_path = os.path.abspath('.')
sys.path.append(root_path)
from utils.img_utils import resize_with_antialiasing, numpy_to_pt
from utils.optical_flow_utils import flow_to_image, filter_uv, bivariate_Gaussian
from data_loader.video_dataset import tokenize_captions


# Kernel used for the 2D dilation of the point condition
blur_kernel = bivariate_Gaussian(99, 10, 10, 0, grid=None, isotropic=True)


def get_thisthat_sam(config, input_dir, store_dir=None, flip=False, verbose=False):
    '''
        Build the this-that point condition for one video folder.
        Args:
            config (dict): Dataset / model configuration.
            input_dir (str): Path to the video folder that contains data.txt and the extracted frames.
            store_dir (str): Optional folder used to store a visualization of the condition.
            flip (bool): Whether the condition should be horizontally flipped (to match flipped frames).
            verbose (bool): Whether to write debug visualizations.
        Returns:
            (thisthat_condition, reflected_motion_bucket_id, controlnet_image_index, coordinate_values)
    '''

    # Read the point annotation file
    file_path = os.path.join(input_dir, "data.txt")
    with open(file_path, 'r') as f:
        lines = f.readlines()

    # Initialize the condition tensor; frames without a point stay all-zero
    thisthat_condition = np.zeros((config["video_seq_length"], config["conditioning_channels"], config["height"], config["width"]), dtype=np.float32)

    # Read a sample image to get the original resolution
    sample_img = cv2.imread(os.path.join(input_dir, "im_0.jpg"))
    org_height, org_width, _ = sample_img.shape

    # Prepare masking
    controlnet_image_index = []
    coordinate_values = []

    # Iterate over all points in the txt file
    for idx in range(len(lines)):

        # Read one point: frame index, horizontal and vertical coordinates
        frame_idx, horizontal, vertical = lines[idx].split(' ')
        frame_idx, vertical, horizontal = int(frame_idx), int(float(vertical)), int(float(horizontal))

        # Record the frame index this point conditions on
        controlnet_image_index.append(frame_idx)
        coordinate_values.append((vertical, horizontal))

        # Initialize the base image at the original image size
        base_img = np.zeros((org_height, org_width, 3)).astype(np.float32)
        base_img.fill(255)

        # Draw a square around the target position
        dot_range = 10      # Radius in pixels; the drawn square is (2 * dot_range + 1) wide
        for i in range(-1 * dot_range, dot_range + 1):
            for j in range(-1 * dot_range, dot_range + 1):
                dil_vertical, dil_horizontal = vertical + i, horizontal + j
                if (0 <= dil_vertical < base_img.shape[0]) and (0 <= dil_horizontal < base_img.shape[1]):
                    if idx == 0:
                        base_img[dil_vertical][dil_horizontal] = [0, 0, 255]    # The first point is red (BGR)
                    else:
                        base_img[dil_vertical][dil_horizontal] = [0, 255, 0]    # Later points are green to distinguish them from the first one

        # Dilate
        if config["dilate"]:
            base_img = cv2.filter2D(base_img, -1, blur_kernel)

        ##############################################################################################################################
        ### The core processing pipeline is: Dilate -> Resize -> Range Shift -> Transpose Shape -> Store

        # Resize before the range shift, so interpolation runs on [0, 255] values rather than negative or [0, 1] ones
        base_img = cv2.resize(base_img, (config["width"], config["height"]), interpolation=cv2.INTER_CUBIC)

        # Flip the image for augmentation if needed
        if flip:
            base_img = np.fliplr(base_img)

        # Channel transform and range shift
        if config["conditioning_channels"] == 3:
            # Map to [0, 1] range
            if store_dir is not None and verbose:   # Visualization of the condition for this point
                cv2.imwrite(os.path.join(store_dir, "condition_TT" + str(idx) + ".png"), base_img)
            base_img = base_img / 255.0
        else:
            raise NotImplementedError()

        # Reorganize shape: HWC -> CHW
        base_img = base_img.transpose(2, 0, 1)

        # Check the min-max value range
        # if verbose:
        #     print("{} min, max range value is {} - {}".format(input_dir, np.min(base_img), np.max(base_img)))

        # Write the base image into the condition tensor at its frame index
        thisthat_condition[frame_idx] = base_img    # Only the frames listed in data.txt are filled; the rest stay zero-initialized
        ##############################################################################################################################

    if config["motion_bucket_id"] is None:
        # Fall back to a default value based on the motion stats collected beforehand
        reflected_motion_bucket_id = 200
    else:
        reflected_motion_bucket_id = config["motion_bucket_id"]

    # print("Motion Bucket ID is ", reflected_motion_bucket_id)

    return (thisthat_condition, reflected_motion_bucket_id, controlnet_image_index, coordinate_values)


class Video_ThisThat_Dataset(Dataset):
    '''
        Video dataset that loads sequential frames for training, together with the required
        pre-processing and the this-that point condition.
    '''

    def __init__(self, config, device, normalize=True, tokenizer=None):

        # Attribute variables
        self.config = config
        self.device = device
        self.normalize = normalize
        self.tokenizer = tokenizer

        # Obtain values
        self.video_seq_length = config["video_seq_length"]
        self.height = config["height"]
        self.width = config["width"]

        # Collect all video folders that have a data.txt annotation
        self.video_lists = []
        for dataset_path in config["dataset_path"]:
            for video_name in sorted(os.listdir(dataset_path)):
                if not os.path.exists(os.path.join(dataset_path, video_name, "data.txt")):
                    continue
                self.video_lists.append(os.path.join(dataset_path, video_name))
        print("Length of the dataset is ", len(self.video_lists))


    def __len__(self):
        return len(self.video_lists)


    def _extract_frame_bridge(self, idx, flip=False):
        '''
            Extract the frames of one video at the needed fps from the already extracted frames on disk.
            Args:
                idx (int): The index of the video folder in the directory list.
                flip (bool): Whether we flip the frames horizontally.
            Returns:
                video_frames (numpy): Extracted video frames in numpy format.
        '''

        # The Bridge dataset frames follow the naming pattern im_<x>.jpg, so we build the paths directly
        video_frame_path = self.video_lists[idx]

        # Collect the needed file paths (use a separate loop variable so we don't shadow the video index)
        needed_img_path = []
        for frame_idx in range(self.video_seq_length):
            img_path = os.path.join(video_frame_path, "im_" + str(frame_idx) + ".jpg")
            needed_img_path.append(img_path)

        # Read all image paths in order
        video_frames = []
        for img_path in needed_img_path:
            if not os.path.exists(img_path):
                print("We don't have ", img_path)

            frame = cv2.imread(img_path)
            try:
                # cv2.imread returns BGR, so convert to RGB here
                frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            except Exception:
                print("The exception place is ", img_path)

            # Resize frames
            frame = cv2.resize(frame, (self.width, self.height), interpolation=cv2.INTER_CUBIC)

            # Flip augmentation
            if flip:
                frame = np.fliplr(frame)

            # Collect frames (already RGB, so no further conversion is needed here)
            video_frames.append(np.expand_dims(frame, axis=0))
        # Concatenate
        video_frames = np.concatenate(video_frames, axis=0)
        assert len(video_frames) == self.video_seq_length

        # Returns
        return video_frames


    def __getitem__(self, idx):
        '''
            Get one item by idx and pre-process it: resize and normalize the frames to [-1, 1].
            Args:
                idx (int): The index of the video folder in the directory list.
            Returns:
                return_dict (dict): video_frames (torch.float32) in [-1, 1] and controlnet_condition (torch.float32) in [0, 1]
        '''

        # Prepare the text prompt if needed
        if self.config["use_text"]:
            # Read the language annotation file
            file_path = os.path.join(self.video_lists[idx], "lang.txt")
            with open(file_path, 'r') as f:
                prompt = f.readlines()[0]   # Only read the first line

            if self.config["mix_ambiguous"] and os.path.exists(os.path.join(self.video_lists[idx], "processed_text.txt")):  # If this txt file doesn't exist, we skip the mixing

                ######################################################## Mix up prompt ########################################################
                # Read the processed prompt components
                file_path = os.path.join(self.video_lists[idx], "processed_text.txt")
                with open(file_path, 'r') as f:
                    prompts = f.readlines()     # Read all lines: action / this / there

                # Get the components (strip the trailing newline)
                action = prompts[0][:-1]
                this = prompts[1][:-1]
                there = prompts[2][:-1]

                random_value = random.random()
                # If less than 0.4, we keep the most concrete prompt unchanged
                if 0.4 <= random_value < 0.6:
                    # Mask the picked object as "this"
                    prompt = action + " this to " + there
                elif 0.6 <= random_value < 0.8:
                    # Mask the place position as "there"
                    prompt = action + " " + this + " to there"
                elif 0.8 <= random_value < 1.0:
                    # Fully ambiguous: "this to there"
                    prompt = action + " this to there"

                # print("New prompt is ", prompt)
                ###################################################################################################################################################

            # else:
            #     print("We don't have llama processed prompt at ", self.video_lists[idx])

        else:
            prompt = ""

        # Tokenize the text prompt
        tokenized_prompt = tokenize_captions(prompt, self.tokenizer, self.config)

        # Flip augmentation by chance (we must check whether the prompt contains object position words like left / right)
        flip = False
        if random.random() < self.config["flip_aug_prob"]:
            if self.config["use_text"]:
                if prompt.find("left") == -1 and prompt.find("right") == -1:    # Cannot flip if the prompt has position words like left / right (up and down are ok)
                    flip = True
            else:
                flip = True

        # Read frames for the different datasets; currently we support WebVid / Bridge
        if self.config["dataset_name"] == "Bridge":
            video_frames_raw = self._extract_frame_bridge(idx, flip=flip)
        else:
            raise NotImplementedError("We don't support this dataset loader")

        # Scale [0, 255] -> [-1, 1] if needed (be careful to cast to float32)
        if self.normalize:
            video_frames = video_frames_raw.astype(np.float32) / 127.5 - 1
        else:
            video_frames = video_frames_raw.astype(np.float32)

        # Transform to a PyTorch tensor in the range [-1, 1]
        video_frames = numpy_to_pt(video_frames)

        # Generate the condition pairs we need
        input_dir = self.video_lists[idx]

        # Get the this-that point information
        controlnet_condition, reflected_motion_bucket_id, controlnet_image_index, coordinate_values = get_thisthat_sam(self.config, input_dir, flip=flip)
        controlnet_condition = torch.from_numpy(controlnet_condition)

        # Cast the other values to tensors
        reflected_motion_bucket_id = torch.tensor(reflected_motion_bucket_id, dtype=torch.float32)
        controlnet_image_index = torch.tensor(controlnet_image_index, dtype=torch.int32)
        coordinate_values = torch.tensor(coordinate_values, dtype=torch.int32)

        # The tensors returned are torch.float32;
        # we won't cast them here for mixed-precision training.
        return {
            "video_frames": video_frames,
            "controlnet_condition": controlnet_condition,
            "reflected_motion_bucket_id": reflected_motion_bucket_id,
            "controlnet_image_index": controlnet_image_index,
            "prompt": tokenized_prompt,
            "coordinate_values": coordinate_values,     # Not used at the moment, but still passed back
        }
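

# Minimal usage sketch (not part of the training pipeline): it shows how this dataset is expected to be
# constructed and iterated. The config keys mirror the ones referenced above, but the concrete values,
# the dataset path, and the CLIP tokenizer name are placeholders / assumptions, not the project's actual settings.
if __name__ == "__main__":
    from torch.utils.data import DataLoader
    from transformers import CLIPTokenizer

    # Assumed example configuration; each video folder needs im_<idx>.jpg frames, data.txt, and lang.txt.
    # data.txt is expected to contain one point per line: "<frame_idx> <horizontal> <vertical>"
    example_config = {
        "dataset_name": "Bridge",
        "dataset_path": ["datasets/bridge/train"],  # Placeholder path
        "video_seq_length": 14,
        "height": 256,
        "width": 384,
        "conditioning_channels": 3,
        "dilate": True,
        "motion_bucket_id": None,
        "use_text": True,
        "mix_ambiguous": True,
        "flip_aug_prob": 0.5,
    }

    # Assumed tokenizer; the real pipeline may load it from the diffusion model checkpoint instead
    tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")

    dataset = Video_ThisThat_Dataset(example_config, device="cpu", normalize=True, tokenizer=tokenizer)
    loader = DataLoader(dataset, batch_size=2, shuffle=True, num_workers=0)

    batch = next(iter(loader))
    print(batch["video_frames"].shape)          # Expected (B, video_seq_length, C, height, width), depending on numpy_to_pt
    print(batch["controlnet_condition"].shape)  # (B, video_seq_length, conditioning_channels, height, width)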