import os
import subprocess

import gradio as gr
import torch
import torch.nn.functional as F
from PIL import Image
from torch.cuda.amp import autocast
from torchvision import transforms

from model import UNet
from frames import extract_frames

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def save_frames(tensor, out_path) -> None:
    """Save a single frame tensor to disk as a PNG."""
    image = normalize_frames(tensor)
    image = Image.fromarray(image)
    image.save(out_path)


def normalize_frames(tensor):
    """Convert a [1, C, H, W] float tensor in [0, 1] to an [H, W, C] uint8 array."""
    tensor = tensor.squeeze(0).detach().cpu()
    tensor = torch.clamp(tensor, 0.0, 1.0)  # Ensure values are in [0, 1]
    tensor = (tensor * 255).byte()  # Scale to [0, 255]
    tensor = tensor.permute(1, 2, 0).numpy()  # [C, H, W] -> [H, W, C]
    return tensor


def load_allframes(frame_dir):
    """Yield frames from frame_dir as tensors, sorted by frame index."""
    frames_path = sorted(
        [
            os.path.join(frame_dir, f)
            for f in os.listdir(frame_dir)
            if f.endswith(".png")
        ],
        key=lambda x: int(os.path.splitext(os.path.basename(x))[0].split("_")[-1]),
    )
    for frame_path in frames_path:
        yield load_frames(frame_path)


def load_frames(image_path) -> torch.Tensor:
    """
    Convert a PIL image (RGB) to a PyTorch tensor and move it to the device.

    :param image_path: path to the image file
    :return: [1, 3, 720, 1280] tensor on the target device
    """
    transform = transforms.Compose(
        [transforms.Resize((720, 1280)), transforms.ToTensor()]
    )
    img = Image.open(image_path).convert("RGB")
    tensor = transform(img).unsqueeze(0).to(device)
    return tensor


def time_steps(input_fps, output_fps) -> list[float]:
    """
    Generate the intermediate time values at which to interpolate between
    two consecutive frames A and B.

    Assumes output_fps is an integer multiple of input_fps; for example,
    30 -> 120 FPS yields [0.25, 0.5, 0.75].

    :param input_fps: FPS of the original video
    :param output_fps: target FPS of the output video
    :return: list of t values in (0, 1) between frames A and B
    """
    if output_fps <= input_fps:
        return []
    k = output_fps // input_fps
    n = k - 1  # number of frames to insert between each pair
    return [i / (n + 1) for i in range(1, n + 1)]


def interpolate_video(frames_dir, model_FC, input_fps, output_fps, output_dir):
    """Write the original and interpolated frames to output_dir in playback order."""
    os.makedirs(output_dir, exist_ok=True)
    count = 0
    iterator = load_allframes(frames_dir)
    try:
        prev_frame = next(iterator)
        for curr_frame in iterator:
            interpolated_frames = interpolate(
                model_FC, prev_frame, curr_frame, input_fps, output_fps
            )
            save_frames(
                prev_frame, os.path.join(output_dir, "frame_{}.png".format(count))
            )
            count += 1
            for frame in interpolated_frames:
                save_frames(
                    frame[:, :3, :, :],  # defensive slice to the RGB channels
                    os.path.join(output_dir, "frame_{}.png".format(count)),
                )
                count += 1
            prev_frame = curr_frame
        save_frames(prev_frame, os.path.join(output_dir, "frame_{}.png".format(count)))
    except StopIteration:
        print("No frames found in", frames_dir)


def interpolate(model_FC, A, B, input_fps, output_fps) -> list[torch.Tensor]:
    """Generate the intermediate frames between A and B predicted by the flow model."""
    interval = time_steps(input_fps, output_fps)
    # Stack frames A and B along the channel axis (6 channels total)
    input_tensor = torch.cat((A, B), dim=1)
    generated_frames = []
    with torch.no_grad():
        flow_output = model_FC(input_tensor)
        flow_forward = flow_output[:, :2, :, :]  # Forward flow (A -> B)
        flow_backward = flow_output[:, 2:4, :, :]  # Backward flow (B -> A)
        for t in interval:
            t_tensor = (
                torch.tensor([t], dtype=torch.float32).view(1, 1, 1, 1).to(device)
            )
            with autocast():
                # Warp each endpoint frame toward time t, then blend the warps
                warped_A = warp_frames(A, flow_forward * t_tensor)
                warped_B = warp_frames(B, flow_backward * (1 - t_tensor))
                interpolated_frame = warped_A * (1 - t_tensor) + warped_B * t_tensor
            generated_frames.append(interpolated_frame)
    return generated_frames


def warp_frames(frame, flow):
    """Warp a frame along the given optical-flow field using bilinear sampling."""
    b, c, h, w = frame.size()
    _, _, flow_h, flow_w = flow.size()
    if h != flow_h or w != flow_w:
        frame = F.interpolate(
            frame, size=(flow_h, flow_w), mode="bilinear", align_corners=True
        )
    # Build a pixel-coordinate grid and displace it by the flow
    grid_y, grid_x = torch.meshgrid(
        torch.arange(0, flow_h), torch.arange(0, flow_w), indexing="ij"
    )
    grid_x = grid_x.float().to(device)
    grid_y = grid_y.float().to(device)
    flow_x = flow[:, 0, :, :]
    flow_y = flow[:, 1, :, :]
    x = grid_x.unsqueeze(0) + flow_x
    y = grid_y.unsqueeze(0) + flow_y
    # Normalize coordinates to [-1, 1], as expected by grid_sample
    x = 2.0 * x / (flow_w - 1) - 1.0
    y = 2.0 * y / (flow_h - 1) - 1.0
    grid = torch.stack((x, y), dim=-1)
    warped_frame = F.grid_sample(
        frame, grid, align_corners=True, mode="bilinear", padding_mode="border"
    )
    return warped_frame

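
# A minimal sanity check for warp_frames, added for illustration (not part of
# the original pipeline): with an all-zero flow field the sampling grid is the
# identity, so grid_sample should reproduce the input frame almost exactly.
def _check_warp_identity() -> None:
    frame = torch.rand(1, 3, 8, 8, device=device)  # random test frame
    zero_flow = torch.zeros(1, 2, 8, 8, device=device)  # zero displacement
    warped = warp_frames(frame, zero_flow)
    assert torch.allclose(warped, frame, atol=1e-5), "zero flow should be a no-op"
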
def frames_to_video(frame_dir, output_video, fps):
    """Renumber the frames into a contiguous sequence and encode them with ffmpeg."""
    frame_files = sorted(
        [f for f in os.listdir(frame_dir) if f.endswith(".png")],
        key=lambda x: int(os.path.splitext(x)[0].split("_")[-1]),
    )
    # Renumber frames contiguously so ffmpeg's %d pattern finds every file
    for i, frame in enumerate(frame_files):
        os.rename(
            os.path.join(frame_dir, frame), os.path.join(frame_dir, f"frame_{i}.png")
        )
    frame_pattern = os.path.join(frame_dir, "frame_%d.png")
    subprocess.run(
        [
            "ffmpeg",
            "-framerate",
            str(fps),
            "-i",
            frame_pattern,
            "-c:v",
            "libx264",
            "-pix_fmt",
            "yuv420p",
            "-y",
            output_video,
        ],
        check=True,
    )


def process_video(video_path, output_fps):
    # Extract the source frames and recover the input video's FPS
    input_fps = extract_frames(video_path, "output_frames")

    # Load the flow-computation model from the SuperSloMo checkpoint.
    # (The checkpoint also contains weights for an auxiliary UNet(20, 5)
    # model under "state_dictAT", which this pipeline does not use.)
    model_FC = UNet(6, 4).to(device)
    checkpoint = torch.load("SuperSloMo.ckpt", map_location=device)
    model_FC.load_state_dict(checkpoint["state_dictFC"])
    model_FC.eval()

    # Interpolate up to the requested FPS
    output_dir = "interpolated_frames"
    interpolate_video("output_frames", model_FC, input_fps, output_fps, output_dir)

    # Encode the interpolated frames into the output video
    final_video_path = "result.mp4"
    frames_to_video(output_dir, final_video_path, output_fps)

    return final_video_path  # Path handed back to the Gradio File output


interface = gr.Interface(
    fn=process_video,
    inputs=[
        gr.Video(label="Upload Input Video"),
        gr.Slider(minimum=30, maximum=120, step=1, value=60, label="Desired Output FPS"),
    ],
    outputs=gr.File(label="Download Enhanced Video"),
    title="Video Frame Interpolation with SuperSloMo",
    description="Upload a video and increase its frame rate by interpolating intermediate frames with a deep learning model.",
)

if __name__ == "__main__":
    interface.launch()
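
# Headless usage (a sketch, not part of the original interface): the same
# pipeline can be driven without Gradio by calling process_video directly.
# "input.mp4" below is a hypothetical path used only for illustration:
#
#     result = process_video("input.mp4", output_fps=120)
#     print(f"Interpolated video written to {result}")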