from src.models.CNN.ColorVidNet import ColorVidNet
from src.models.vit.embed import SwinModel
from src.models.CNN.NonlocalNet import WarpNet
from src.models.CNN.FrameColor import frame_colorization
import torch
from src.models.vit.utils import load_params
import os
import cv2
from PIL import Image
from PIL import ImageEnhance as IE
import torchvision.transforms as T
from src.utils import (
    RGB2Lab,
    ToTensor,
    Normalize,
    uncenter_l,
    tensor_lab2rgb
)
import numpy as np


class SwinTExCo:
    def __init__(self, weights_path, swin_backbone='swinv2-cr-t-224', device=None):
        if device is None:
            self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        else:
            self.device = device

        # Sub-networks: Swin feature extractor, non-local warping network, and colorization network.
        self.embed_net = SwinModel(pretrained_model=swin_backbone, device=self.device).to(self.device)
        self.nonlocal_net = WarpNet(feature_channel=128).to(self.device)
        self.colornet = ColorVidNet(7).to(self.device)

        self.embed_net.eval()
        self.nonlocal_net.eval()
        self.colornet.eval()

        self.__load_models(self.embed_net, os.path.join(weights_path, "embed_net.pth"))
        self.__load_models(self.nonlocal_net, os.path.join(weights_path, "nonlocal_net.pth"))
        self.__load_models(self.colornet, os.path.join(weights_path, "colornet.pth"))

        # Preprocessing pipeline: resize to the backbone's 224x224 input, convert RGB to Lab,
        # then convert to a tensor and normalize.
        self.processor = T.Compose([
            T.Resize((224, 224)),
            RGB2Lab(),
            ToTensor(),
            Normalize()
        ])

    def __load_models(self, model, weight_path):
        params = load_params(weight_path, self.device)
        model.load_state_dict(params, strict=True)

    def __preprocess_reference(self, img):
        # Boost the color saturation of the reference image before feature extraction.
        color_enhancer = IE.Color(img)
        img = color_enhancer.enhance(1.5)
        return img

    def __upscale_image(self, large_IA_l, I_current_ab_predict):
        # Upsample the predicted ab channels to the original resolution, then convert Lab -> RGB.
        H, W = large_IA_l.shape[2:]
        large_current_ab_predict = torch.nn.functional.interpolate(I_current_ab_predict,
                                                                    size=(H, W),
                                                                    mode="bilinear",
                                                                    align_corners=False)
        large_IA_l = torch.cat((large_IA_l, large_current_ab_predict), dim=1)
        large_current_rgb_predict = tensor_lab2rgb(large_IA_l)
        return large_current_rgb_predict.cpu()

    def __process_sample(self, curr_frame, I_last_lab_predict, I_reference_lab, features_B):
        # Full-resolution luminance channel, kept for upscaling the prediction back to the
        # original frame size.
        large_IA_lab = ToTensor()(RGB2Lab()(curr_frame)).unsqueeze(0)
        large_IA_l = large_IA_lab[:, 0:1, :, :].to(self.device)

        # Network-resolution Lab input (224x224).
        IA_lab = self.processor(curr_frame)
        IA_lab = IA_lab.unsqueeze(0).to(self.device)
        IA_l = IA_lab[:, 0:1, :, :]

        if I_last_lab_predict is None:
            I_last_lab_predict = torch.zeros_like(IA_lab).to(self.device)

        with torch.no_grad():
            I_current_ab_predict, _ = frame_colorization(
                IA_l,
                I_reference_lab,
                I_last_lab_predict,
                features_B,
                self.embed_net,
                self.nonlocal_net,
                self.colornet,
                luminance_noise=0,
                temperature=1e-10,
                joint_training=False
            )
            I_last_lab_predict = torch.cat((IA_l, I_current_ab_predict), dim=1)

        IA_predict_rgb = self.__upscale_image(large_IA_l, I_current_ab_predict)
        IA_predict_rgb = (IA_predict_rgb.squeeze(0).cpu().numpy() * 255.)
        IA_predict_rgb = np.clip(IA_predict_rgb, 0, 255).astype(np.uint8)

        return I_last_lab_predict, IA_predict_rgb

    def predict_video(self, video, ref_image):
        """Colorize an opened cv2.VideoCapture frame by frame, yielding HxWx3 uint8 RGB frames."""
        ref_image = self.__preprocess_reference(ref_image)

        I_last_lab_predict = None

        IB_lab = self.processor(ref_image)
        IB_lab = IB_lab.unsqueeze(0).to(self.device)

        with torch.no_grad():
            I_reference_lab = IB_lab
            I_reference_l = I_reference_lab[:, 0:1, :, :]
            I_reference_ab = I_reference_lab[:, 1:3, :, :]
            I_reference_rgb = tensor_lab2rgb(torch.cat((uncenter_l(I_reference_l), I_reference_ab), dim=1)).to(self.device)
            features_B = self.embed_net(I_reference_rgb)

        while video.isOpened():
            ret, curr_frame = video.read()
            if not ret:
                break

            curr_frame = cv2.cvtColor(curr_frame, cv2.COLOR_BGR2RGB)
            curr_frame = Image.fromarray(curr_frame)

            I_last_lab_predict, IA_predict_rgb = self.__process_sample(curr_frame, I_last_lab_predict, I_reference_lab, features_B)
            IA_predict_rgb = IA_predict_rgb.transpose(1, 2, 0)

            yield IA_predict_rgb

        video.release()

    def predict_image(self, image, ref_image):
        """Colorize a single PIL RGB image, returning an HxWx3 uint8 RGB array."""
        ref_image = self.__preprocess_reference(ref_image)

        I_last_lab_predict = None

        IB_lab = self.processor(ref_image)
        IB_lab = IB_lab.unsqueeze(0).to(self.device)

        with torch.no_grad():
            I_reference_lab = IB_lab
            I_reference_l = I_reference_lab[:, 0:1, :, :]
            I_reference_ab = I_reference_lab[:, 1:3, :, :]
            I_reference_rgb = tensor_lab2rgb(torch.cat((uncenter_l(I_reference_l), I_reference_ab), dim=1)).to(self.device)
            features_B = self.embed_net(I_reference_rgb)

        curr_frame = image
        I_last_lab_predict, IA_predict_rgb = self.__process_sample(curr_frame, I_last_lab_predict, I_reference_lab, features_B)
        IA_predict_rgb = IA_predict_rgb.transpose(1, 2, 0)

        return IA_predict_rgb


if __name__ == "__main__":
    model = SwinTExCo('checkpoints/epoch_20/')

    # Initialize video reader and writer
    video = cv2.VideoCapture('sample_input/video_2.mp4')
    fps = video.get(cv2.CAP_PROP_FPS)
    width = int(video.get(cv2.CAP_PROP_FRAME_WIDTH))
    height = int(video.get(cv2.CAP_PROP_FRAME_HEIGHT))
    video_writer = cv2.VideoWriter('sample_output/video_2_ref_2.mp4', cv2.VideoWriter_fourcc(*'mp4v'), fps, (width, height))

    # Initialize reference image
    ref_image = Image.open('sample_input/refs_2/ref2.jpg').convert('RGB')

    for colorized_frame in model.predict_video(video, ref_image):
        # predict_video yields RGB frames, while VideoWriter expects BGR; convert before writing.
        video_writer.write(cv2.cvtColor(colorized_frame, cv2.COLOR_RGB2BGR))

    video_writer.release()
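
    # Example sketch (not part of the original script): single-image colorization with
    # predict_image. 'sample_input/image_1.jpg' and the output path are placeholder names;
    # predict_image takes PIL RGB images and returns an HxWx3 uint8 RGB array.
    # gray_image = Image.open('sample_input/image_1.jpg').convert('RGB')
    # colorized = model.predict_image(gray_image, ref_image)
    # cv2.imwrite('sample_output/image_1_colorized.png', cv2.cvtColor(colorized, cv2.COLOR_RGB2BGR))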