import argparse
import os
import shutil
from collections import OrderedDict

import numpy as np
import torch
from PIL import Image
from skimage import io
from tqdm import tqdm

from src.models.CNN.ColorVidNet import GeneralColorVidNet
from src.models.CNN.FrameColor import frame_colorization
from src.models.CNN.NonlocalNet import GeneralWarpNet
from src.models.vit.embed import GeneralEmbedModel
from src.utils import (
    RGB2Lab,
    ToTensor,
    Normalize,
    uncenter_l,
    tensor_lab2rgb,
    SquaredPadding,
)


def load_params(ckpt_file):
    """Load a checkpoint and return its state dict as an OrderedDict."""
    # map_location keeps CUDA-saved checkpoints loadable on CPU-only machines.
    params = torch.load(ckpt_file, map_location=device)
    return OrderedDict(params.items())


def custom_transform(transform_list, img):
    """Apply the preprocessing pipeline, capturing the paddings added by SquaredPadding."""
    padding = None
    for transform in transform_list:
        if isinstance(transform, SquaredPadding):
            img, padding = transform(img, return_paddings=True)
        else:
            img = transform(img)
    return img.to(device), padding


def save_frames(predicted_rgb, video_name, frame_name):
    """Write one predicted RGB frame to <output_video_path>/<video_name>/<frame_name>."""
    if predicted_rgb is None:
        return
    # tensor_lab2rgb is assumed to return a (1, 3, H, W) tensor with values in [0, 1].
    frame = predicted_rgb.squeeze(0).permute(1, 2, 0).cpu().numpy()
    frame = np.clip(frame * 255, 0, 255).astype(np.uint8)
    out_dir = os.path.join(args.output_video_path, video_name)
    os.makedirs(out_dir, exist_ok=True)
    io.imsave(os.path.join(out_dir, frame_name), frame)


def colorize_video(video_name):
    frames_list = sorted(os.listdir(os.path.join(args.input_videos_path, video_name)))
    refs_list = sorted(os.listdir(os.path.join(args.reference_images_path, video_name)))

    # Each reference image drives a full pass over the video; later passes
    # overwrite the frames written by earlier ones.
    for ref_path in refs_list:
        frame_ref = Image.open(os.path.join(args.reference_images_path, video_name, ref_path)).convert("RGB")
        I_last_lab_predict = None

        # Paddings are returned but not undone on the output frames here.
        IB_lab, IB_paddings = custom_transform(transform_list, frame_ref)
        IB_lab = IB_lab.unsqueeze(0).to(device)

        with torch.no_grad():
            # Precompute the reference features once per reference image.
            I_reference_lab = IB_lab
            I_reference_l = I_reference_lab[:, 0:1, :, :]
            I_reference_ab = I_reference_lab[:, 1:3, :, :]
            I_reference_rgb = tensor_lab2rgb(torch.cat((uncenter_l(I_reference_l), I_reference_ab), dim=1)).to(device)
            features_B = embed_net(I_reference_rgb)

        for frame_name in frames_list:
            curr_frame = Image.open(os.path.join(args.input_videos_path, video_name, frame_name)).convert("RGB")
            IA_lab, IA_paddings = custom_transform(transform_list, curr_frame)
            IA_lab = IA_lab.unsqueeze(0).to(device)
            IA_l = IA_lab[:, 0:1, :, :]

            # The first frame has no previous prediction; start from zeros.
            if I_last_lab_predict is None:
                I_last_lab_predict = torch.zeros_like(IA_lab).to(device)

            with torch.no_grad():
                I_current_lab = IA_lab
                I_current_ab_predict, _, _ = frame_colorization(
                    I_current_lab,
                    I_reference_lab,
                    I_last_lab_predict,
                    features_B,
                    embed_net,
                    nonlocal_net,
                    colornet,
                    luminance_noise=0,
                    temperature=1e-10,
                    joint_training=False,
                )
                # Propagate the prediction so the next frame is conditioned on it.
                I_last_lab_predict = torch.cat((IA_l, I_current_ab_predict), dim=1)

            IA_predict_rgb = tensor_lab2rgb(torch.cat((uncenter_l(IA_l), I_current_ab_predict), dim=1))
            save_frames(IA_predict_rgb, video_name, frame_name)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Video Colorization")
    parser.add_argument("--input_videos_path", type=str, help="path to input videos (one directory of frames per video)")
    parser.add_argument("--reference_images_path", type=str, help="path to reference images (one directory per video)")
    parser.add_argument("--output_video_path", type=str, help="path to output frames")
    parser.add_argument("--weight_path", type=str, default="checkpoints/epoch_5/", help="path to weights")
    parser.add_argument("--device", type=str, default="cpu", help="device to run the model on")
    parser.add_argument("--high_resolution", action="store_true", help="use high resolution (currently unused here)")
    parser.add_argument("--wls_filter_on", action="store_true", help="use WLS filter (currently unused here)")
    args = parser.parse_args()
    device = torch.device(args.device)

    # Start from a clean output directory.
    if os.path.exists(args.output_video_path):
        shutil.rmtree(args.output_video_path)
    os.makedirs(args.output_video_path, exist_ok=True)

    videos_list = os.listdir(args.input_videos_path)

    # Build the three networks: feature embedding, non-local warping, and colorization.
    embed_net = GeneralEmbedModel(pretrained_model="swin-tiny", device=device).to(device)
    nonlocal_net = GeneralWarpNet(feature_channel=128).to(device)
    colornet = GeneralColorVidNet(7).to(device)
    embed_net.eval()
    nonlocal_net.eval()
    colornet.eval()

    # Load weights
    embed_net_params = load_params(os.path.join(args.weight_path, "embed_net.pth"))
    nonlocal_net_params = load_params(os.path.join(args.weight_path, "nonlocal_net.pth"))
    colornet_params = load_params(os.path.join(args.weight_path, "colornet.pth"))
    embed_net.load_state_dict(embed_net_params, strict=True)
    nonlocal_net.load_state_dict(nonlocal_net_params, strict=True)
    colornet.load_state_dict(colornet_params, strict=True)

    # Preprocessing: pad to a 224x224 square, convert RGB to Lab, then normalize.
    transform_list = [SquaredPadding(target_size=224), RGB2Lab(), ToTensor(), Normalize()]

    with torch.no_grad():
        for video_name in tqdm(videos_list):
            colorize_video(video_name)
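
# Example invocation (a sketch; the script filename and data layout shown are
# assumptions, not fixed by this repository):
#
#   python test.py \
#       --input_videos_path data/input_frames \
#       --reference_images_path data/references \
#       --output_video_path results \
#       --weight_path checkpoints/epoch_5/ \
#       --device cuda
#
# Each video is expected as a directory of frames under --input_videos_path,
# with a same-named directory of reference images under --reference_images_path;
# colorized frames are written to <output_video_path>/<video_name>/.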