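"""Exemplar-based video colorization inference.

For each video folder in --input_videos_path, every frame is colorized by
propagating color from the reference image(s) found in the matching folder of
--reference_images_path; the colorized frames are written to
--output_video_path.

Example invocation (script and path names are illustrative):

    python inference.py \
        --input_videos_path data/videos \
        --reference_images_path data/refs \
        --output_video_path results \
        --device cuda

The script writes individual frames; if a video file is needed, assemble the
frames afterwards with a tool such as ffmpeg.
"""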
import numpy as np
import shutil
import os
import argparse
import torch
from tqdm import tqdm
from PIL import Image
from collections import OrderedDict
from skimage import io
from src.models.CNN.ColorVidNet import GeneralColorVidNet
from src.models.vit.embed import GeneralEmbedModel
from src.models.CNN.NonlocalNet import GeneralWarpNet
from src.models.CNN.FrameColor import frame_colorization
from src.utils import (
    RGB2Lab,
    ToTensor,
    Normalize,
    uncenter_l,
    tensor_lab2rgb,
    SquaredPadding,
    UnpaddingSquare
)

def load_params(ckpt_file):
    """Load a checkpoint as an OrderedDict, mapped onto the global `device`."""
    params = torch.load(ckpt_file, map_location=device)
    return OrderedDict(params.items())

def custom_transform(transforms, img):
    """Apply the preprocessing pipeline to `img`, keeping the paddings that
    SquaredPadding produces so they could be undone later."""
    paddings = None
    for transform in transforms:
        if isinstance(transform, SquaredPadding):
            img, paddings = transform(img, return_paddings=True)
        else:
            img = transform(img)
    return img.to(device), paddings

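# Save one colorized frame under <output_video_path>/<video_name>/<frame_name>.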
def save_frames(predicted_rgb, video_name, frame_name):
    if predicted_rgb is not None:
        # tensor_lab2rgb output is assumed to lie in [0, 1]; scale to uint8.
        predicted_rgb = predicted_rgb.squeeze(0).permute(1, 2, 0).cpu().numpy() * 255
        predicted_rgb = np.clip(predicted_rgb, 0, 255).astype(np.uint8)
        os.makedirs(os.path.join(args.output_video_path, video_name), exist_ok=True)
        io.imsave(os.path.join(args.output_video_path, video_name, frame_name), predicted_rgb)

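# Colorize every frame of one video. The whole frame sequence is re-colorized
# once per reference image, so with several references the frames written by
# the last pass overwrite the earlier ones.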
def colorize_video(video_name):
    frames_list = sorted(os.listdir(os.path.join(args.input_videos_path, video_name)))
    refs_list = sorted(os.listdir(os.path.join(args.reference_images_path, video_name)))

    for ref_path in refs_list:
        frame_ref = Image.open(os.path.join(args.reference_images_path, video_name, ref_path)).convert("RGB")
        I_last_lab_predict = None

        IB_lab, IB_paddings = custom_transform(transforms, frame_ref)
        IB_lab = IB_lab.unsqueeze(0).to(device)
        with torch.no_grad():
            I_reference_lab = IB_lab
            I_reference_l = I_reference_lab[:, 0:1, :, :]
            I_reference_ab = I_reference_lab[:, 1:3, :, :]
            I_reference_rgb = tensor_lab2rgb(torch.cat((uncenter_l(I_reference_l), I_reference_ab), dim=1)).to(device)
            # Reference features are extracted once and reused for every frame.
            features_B = embed_net(I_reference_rgb)
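
        # Colorize frames in order: each prediction is conditioned on the
        # reference features and on the previous frame's result, propagating
        # color temporally through the sequence.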
        for frame_name in frames_list:
            curr_frame = Image.open(os.path.join(args.input_videos_path, video_name, frame_name)).convert("RGB")
            IA_lab, IA_paddings = custom_transform(transforms, curr_frame)
            IA_lab = IA_lab.unsqueeze(0).to(device)
            IA_l = IA_lab[:, 0:1, :, :]

            if I_last_lab_predict is None:
                # First frame: start from an all-zero previous prediction.
                I_last_lab_predict = torch.zeros_like(IA_lab).to(device)
            with torch.no_grad():
                I_current_lab = IA_lab
                I_current_ab_predict, _, _ = frame_colorization(
                    I_current_lab,
                    I_reference_lab,
                    I_last_lab_predict,
                    features_B,
                    embed_net,
                    nonlocal_net,
                    colornet,
                    luminance_noise=0,
                    temperature=1e-10,
                    joint_training=False,
                )
            # Keep this frame's prediction to condition the next frame; note the
            # paddings from custom_transform are never undone before saving.
            I_last_lab_predict = torch.cat((IA_l, I_current_ab_predict), dim=1)
            IA_predict_rgb = tensor_lab2rgb(torch.cat((uncenter_l(IA_l), I_current_ab_predict), dim=1))
            save_frames(IA_predict_rgb, video_name, frame_name)


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Video Colorization')
    parser.add_argument("--input_videos_path", type=str, help="directory with one sub-folder of frames per video")
    parser.add_argument("--reference_images_path", type=str, help="directory with one sub-folder of reference images per video")
    parser.add_argument("--output_video_path", type=str, help="directory where colorized frames are written")
    parser.add_argument("--weight_path", type=str, default="checkpoints/epoch_5/", help="path to the model weights")
    parser.add_argument("--device", type=str, default="cpu", help="device to run the model on, e.g. cpu or cuda")
    parser.add_argument("--high_resolution", action="store_true", help="use high resolution")
    parser.add_argument("--wls_filter_on", action="store_true", help="use WLS filter")
    args = parser.parse_args()

    device = torch.device(args.device)

    # Start from a clean output directory.
    if os.path.exists(args.output_video_path):
        shutil.rmtree(args.output_video_path)
    os.makedirs(args.output_video_path, exist_ok=True)

    videos_list = os.listdir(args.input_videos_path)

    # Build the three sub-networks and put them in inference mode.
    embed_net = GeneralEmbedModel(pretrained_model="swin-tiny", device=device).to(device)
    nonlocal_net = GeneralWarpNet(feature_channel=128).to(device)
    colornet = GeneralColorVidNet(7).to(device)
    embed_net.eval()
    nonlocal_net.eval()
    colornet.eval()

    # Load weights
    embed_net.load_state_dict(load_params(os.path.join(args.weight_path, "embed_net.pth")), strict=True)
    nonlocal_net.load_state_dict(load_params(os.path.join(args.weight_path, "nonlocal_net.pth")), strict=True)
    colornet.load_state_dict(load_params(os.path.join(args.weight_path, "colornet.pth")), strict=True)
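
    # Pre-processing pipeline: pad each image to a 224x224 square, convert
    # RGB to Lab, turn it into a tensor, and normalize.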
    transforms = [SquaredPadding(target_size=224),
                  RGB2Lab(),
                  ToTensor(),
                  Normalize()]

    with torch.no_grad():
        for video_name in tqdm(videos_list):
            colorize_video(video_name)