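"""Gradio demo for exemplar-based video colorization.

The app extracts frames from an input video, colorizes each frame in Lab
space using a single reference image, and re-encodes the result into a
playable video.
"""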
import glob
import os
import shutil
from collections import OrderedDict

import cv2
import gradio as gr
import numpy as np
import torch
from PIL import Image
from tqdm import tqdm

from src.models.CNN.ColorVidNet import GeneralColorVidNet
from src.models.CNN.FrameColor import frame_colorization
from src.models.CNN.NonlocalNet import GeneralWarpNet
from src.models.vit.embed import GeneralEmbedModel
from src.utils import (
    RGB2Lab,
    ToTensor,
    Normalize,
    uncenter_l,
    tensor_lab2rgb,
    SquaredPadding,
    UnpaddingSquare
)
def load_params(ckpt_file):
    """Load a checkpoint into an OrderedDict of parameters on the target device."""
    params = torch.load(ckpt_file, map_location=device)
    return OrderedDict(params.items())
def custom_transform(transforms, img):
    """Apply a list of transforms to img, capturing the paddings produced by SquaredPadding."""
    padding = None
    for transform in transforms:
        if isinstance(transform, SquaredPadding):
            img, padding = transform(img, return_paddings=True)
        else:
            img = transform(img)
    return img.to(device), padding
def save_frames(predicted_rgb, video_name, frame_name):
    """Save a predicted RGB frame, given as a (C, H, W) float array in [0, 255]."""
    if predicted_rgb is not None:
        predicted_rgb = np.clip(predicted_rgb, 0, 255).astype(np.uint8)
        # (C, H, W) -> (H, W, C) for PIL
        predicted_rgb = np.transpose(predicted_rgb, (1, 2, 0))
        pil_img = Image.fromarray(predicted_rgb)
        pil_img.save(os.path.join(OUTPUT_RESULT_PATH, video_name, frame_name))
def extract_frames_from_video(video_path):
    """Split a video into JPEG frames; return the frame paths and the source FPS."""
    cap = cv2.VideoCapture(video_path)
    fps = cap.get(cv2.CAP_PROP_FPS)

    # Recreate the frames folder for this video
    output_frames_path = os.path.join(INPUT_VIDEO_FRAMES_PATH, os.path.basename(video_path))
    if os.path.exists(output_frames_path):
        shutil.rmtree(output_frames_path)
    os.makedirs(output_frames_path)

    currentframe = 0
    frame_path_list = []
    while True:
        ret, frame = cap.read()
        if not ret:
            break
        # Zero-padded names keep frames in order under a lexicographic sort
        name = os.path.join(output_frames_path, f'{currentframe:09d}.jpg')
        frame_path_list.append(name)
        cv2.imwrite(name, frame)
        currentframe += 1
    cap.release()
    return frame_path_list, fps
def combine_frames_from_folder(frames_list_path, fps=30):
    """Re-encode the JPEG frames in a folder into an mp4 video."""
    frames_list = glob.glob(f'{frames_list_path}/*.jpg')
    frames_list.sort()
    sample_shape = cv2.imread(frames_list[0]).shape

    output_video_path = os.path.join(frames_list_path, 'output_video.mp4')
    # cv2.VideoWriter expects the frame size as (width, height)
    out = cv2.VideoWriter(output_video_path, cv2.VideoWriter_fourcc(*'mp4v'), fps, (sample_shape[1], sample_shape[0]))
    for filename in frames_list:
        out.write(cv2.imread(filename))
    out.release()
    return output_video_path
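
# Note: the 'mp4v' fourcc is a broadly available default for cv2.VideoWriter,
# but some browsers only play H.264; if the Gradio player rejects the output,
# re-encoding with ffmpeg (or an 'avc1' fourcc where available) is a common fix.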
def upscale_image(I_current_rgb, I_current_ab_predict):
    """Merge the low-resolution ab prediction with the full-resolution L channel."""
    # PIL's .size is (width, height)
    W, H = I_current_rgb.size
    high_lab_transforms = [
        SquaredPadding(target_size=max(H, W)),
        RGB2Lab(),
        ToTensor(),
        Normalize()
    ]
    high_lab_current, paddings = custom_transform(high_lab_transforms, I_current_rgb)
    high_lab_current = torch.unsqueeze(high_lab_current, dim=0).to(device)
    high_l_current = high_lab_current[:, 0:1, :, :]

    # Upsample to the exact padded resolution so the ab channels match high_l_current
    upsampler = torch.nn.Upsample(size=(max(H, W), max(H, W)), mode="bilinear")
    high_ab_predict = upsampler(I_current_ab_predict)
    I_predict_rgb = tensor_lab2rgb(torch.cat((uncenter_l(high_l_current), high_ab_predict), dim=1))

    # Crop the square padding back off
    unpadding = UnpaddingSquare()
    I_predict_rgb = unpadding(I_predict_rgb, paddings)
    return I_predict_rgb
def colorize_video(video_path, ref_np):
    """Colorize every frame of video_path using ref_np (an RGB array) as the reference."""
    frames_list, fps = extract_frames_from_video(video_path)
    frame_ref = Image.fromarray(ref_np).convert("RGB")
    I_last_lab_predict = None

    # Encode the reference image once; its features condition every frame
    IB_lab, IB_paddings = custom_transform(transforms, frame_ref)
    IB_lab = IB_lab.unsqueeze(0).to(device)
    with torch.no_grad():
        I_reference_lab = IB_lab
        I_reference_l = I_reference_lab[:, 0:1, :, :]
        I_reference_ab = I_reference_lab[:, 1:3, :, :]
        I_reference_rgb = tensor_lab2rgb(torch.cat((uncenter_l(I_reference_l), I_reference_ab), dim=1)).to(device)
        features_B = embed_net(I_reference_rgb)

    # Recreate the per-video output folder
    video_path_parts = frames_list[0].split(os.sep)
    if os.path.exists(os.path.join(OUTPUT_RESULT_PATH, video_path_parts[-2])):
        shutil.rmtree(os.path.join(OUTPUT_RESULT_PATH, video_path_parts[-2]))
    os.makedirs(os.path.join(OUTPUT_RESULT_PATH, video_path_parts[-2]), exist_ok=True)
    for frame_path in tqdm(frames_list):
        curr_frame = Image.open(frame_path).convert("RGB")
        IA_lab, IA_paddings = custom_transform(transforms, curr_frame)
        IA_lab = IA_lab.unsqueeze(0).to(device)
        IA_l = IA_lab[:, 0:1, :, :]

        # The first frame has no previous prediction to propagate
        if I_last_lab_predict is None:
            I_last_lab_predict = torch.zeros_like(IA_lab).to(device)

        with torch.no_grad():
            I_current_ab_predict, _ = frame_colorization(
                IA_l,
                I_reference_lab,
                I_last_lab_predict,
                features_B,
                embed_net,
                nonlocal_net,
                colornet,
                luminance_noise=0,
                temperature=1e-10,
                joint_training=False
            )
            I_last_lab_predict = torch.cat((IA_l, I_current_ab_predict), dim=1)

        # Restore the original resolution before saving
        IA_predict_rgb = upscale_image(curr_frame, I_current_ab_predict)
        save_frames(IA_predict_rgb.squeeze(0).cpu().numpy() * 255, video_path_parts[-2], os.path.basename(frame_path))

    return combine_frames_from_folder(os.path.join(OUTPUT_RESULT_PATH, video_path_parts[-2]), fps)
if __name__ == '__main__':
    # Global runtime configuration
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    INPUT_VIDEO_FRAMES_PATH = 'inputs'
    OUTPUT_RESULT_PATH = 'outputs'
    weight_path = 'checkpoints'

    # Build the three networks and switch them to inference mode
    embed_net = GeneralEmbedModel(pretrained_model="swin-tiny", device=device).to(device)
    nonlocal_net = GeneralWarpNet(feature_channel=128).to(device)
    colornet = GeneralColorVidNet(7).to(device)
    embed_net.eval()
    nonlocal_net.eval()
    colornet.eval()

    # Load weights; embed_net keeps its pretrained swin-tiny backbone,
    # so no separate checkpoint is loaded for it
    nonlocal_net_params = load_params(os.path.join(weight_path, "nonlocal_net.pth"))
    colornet_params = load_params(os.path.join(weight_path, "colornet.pth"))
    nonlocal_net.load_state_dict(nonlocal_net_params, strict=True)
    colornet.load_state_dict(colornet_params, strict=True)

    # Preprocessing pipeline: pad to a 224x224 square, convert to Lab, normalize
    transforms = [SquaredPadding(target_size=224),
                  RGB2Lab(),
                  ToTensor(),
                  Normalize()]

    # Optional cached examples for the demo:
    # examples = [[vid, ref] for vid, ref in zip(sorted(glob.glob('examples/*/*.mp4')), sorted(glob.glob('examples/*/*.jpg')))]
    # (pass examples=examples, cache_examples=True to gr.Interface)
    demo = gr.Interface(colorize_video,
                        inputs=[gr.Video(), gr.Image()],
                        outputs="playable_video")
    demo.launch()
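
    # The pipeline can also be invoked directly, bypassing the UI
    # (the paths below are hypothetical):
    #   ref = np.array(Image.open('examples/ref.jpg').convert('RGB'))
    #   output_video_path = colorize_video('examples/video.mp4', ref)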