from src.models.CNN.ColorVidNet import ColorVidNet
from src.models.vit.embed import SwinModel
from src.models.CNN.NonlocalNet import WarpNet
from src.models.CNN.FrameColor import frame_colorization
import torch
from src.models.vit.utils import load_params
import os
import cv2
from PIL import Image
from PIL import ImageEnhance as IE
import torchvision.transforms as T
from src.utils import (
RGB2Lab,
ToTensor,
Normalize,
uncenter_l,
tensor_lab2rgb
)
import numpy as np
class SwinTExCo:
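    """Exemplar-based image and video colorization.

    Combines a Swin transformer encoder (embed_net) for feature extraction,
    a non-local warping network (nonlocal_net) that aligns the reference to
    each frame, and a CNN (colornet) that predicts the ab chrominance
    channels from the warped reference colors.
    """
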
def __init__(self, weights_path, swin_backbone='swinv2-cr-t-224', device=None):
        if device is None:
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
else:
self.device = device
self.embed_net = SwinModel(pretrained_model=swin_backbone, device=self.device).to(self.device)
self.nonlocal_net = WarpNet(feature_channel=128).to(self.device)
self.colornet = ColorVidNet(7).to(self.device)
self.embed_net.eval()
self.nonlocal_net.eval()
self.colornet.eval()
self.__load_models(self.embed_net, os.path.join(weights_path, "embed_net.pth"))
self.__load_models(self.nonlocal_net, os.path.join(weights_path, "nonlocal_net.pth"))
self.__load_models(self.colornet, os.path.join(weights_path, "colornet.pth"))
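        # Preprocessing for network input: resize to the 224x224 input size,
        # convert RGB to Lab, convert to tensor, and normalize.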
self.processor = T.Compose([
T.Resize((224,224)),
RGB2Lab(),
ToTensor(),
Normalize()
])
def __load_models(self, model, weight_path):
params = load_params(weight_path, self.device)
model.load_state_dict(params, strict=True)
def __preprocess_reference(self, img):
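        # Boost the reference's color saturation (1.5x) so the exemplar
        # supplies stronger chrominance cues.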
color_enhancer = IE.Color(img)
img = color_enhancer.enhance(1.5)
return img
def __upscale_image(self, large_IA_l, I_current_ab_predict):
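        # Upsample the predicted ab channels to the original resolution, merge
        # them with the full-resolution L channel, and convert Lab to RGB.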
H, W = large_IA_l.shape[2:]
large_current_ab_predict = torch.nn.functional.interpolate(I_current_ab_predict,
size=(H,W),
mode="bilinear",
align_corners=False)
large_IA_l = torch.cat((large_IA_l, large_current_ab_predict), dim=1)
large_current_rgb_predict = tensor_lab2rgb(large_IA_l)
return large_current_rgb_predict.cpu()
    def __process_sample(self, curr_frame, I_last_lab_predict, I_reference_lab, features_B):
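        # Colorize a single frame: run the networks at 224x224, then
        # reconstruct a full-resolution RGB image from the original L channel.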
large_IA_lab = ToTensor()(RGB2Lab()(curr_frame)).unsqueeze(0)
large_IA_l = large_IA_lab[:, 0:1, :, :].to(self.device)
IA_lab = self.processor(curr_frame)
IA_lab = IA_lab.unsqueeze(0).to(self.device)
IA_l = IA_lab[:, 0:1, :, :]
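        # First frame: no previous prediction exists, so use zeros as the
        # temporal prior.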
if I_last_lab_predict is None:
I_last_lab_predict = torch.zeros_like(IA_lab).to(self.device)
with torch.no_grad():
I_current_ab_predict, _ = frame_colorization(
IA_l,
I_reference_lab,
I_last_lab_predict,
features_B,
self.embed_net,
self.nonlocal_net,
self.colornet,
luminance_noise=0,
temperature=1e-10,
joint_training=False
)
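            # Carry the current Lab result forward as the temporal prior for
            # the next frame.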
I_last_lab_predict = torch.cat((IA_l, I_current_ab_predict), dim=1)
IA_predict_rgb = self.__upscale_image(large_IA_l, I_current_ab_predict)
IA_predict_rgb = (IA_predict_rgb.squeeze(0).cpu().numpy() * 255.)
IA_predict_rgb = np.clip(IA_predict_rgb, 0, 255).astype(np.uint8)
return I_last_lab_predict, IA_predict_rgb
def predict_video(self, video, ref_image):
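        """Colorize a cv2.VideoCapture stream frame by frame.

        Yields colorized frames as RGB uint8 arrays of shape (H, W, 3), so
        callers can stream results without buffering the whole video.
        """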
ref_image = self.__preprocess_reference(ref_image)
I_last_lab_predict = None
IB_lab = self.processor(ref_image)
IB_lab = IB_lab.unsqueeze(0).to(self.device)
with torch.no_grad():
I_reference_lab = IB_lab
I_reference_l = I_reference_lab[:, 0:1, :, :]
I_reference_ab = I_reference_lab[:, 1:3, :, :]
I_reference_rgb = tensor_lab2rgb(torch.cat((uncenter_l(I_reference_l), I_reference_ab), dim=1)).to(self.device)
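            # Encode the reference once; its features are reused for every frame.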
features_B = self.embed_net(I_reference_rgb)
while video.isOpened():
ret, curr_frame = video.read()
if not ret:
break
curr_frame = cv2.cvtColor(curr_frame, cv2.COLOR_BGR2RGB)
curr_frame = Image.fromarray(curr_frame)
            I_last_lab_predict, IA_predict_rgb = self.__process_sample(curr_frame, I_last_lab_predict, I_reference_lab, features_B)
IA_predict_rgb = IA_predict_rgb.transpose(1,2,0)
yield IA_predict_rgb
video.release()
def predict_image(self, image, ref_image):
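        """Colorize a single PIL image against a reference; returns an RGB uint8 array of shape (H, W, 3)."""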
ref_image = self.__preprocess_reference(ref_image)
I_last_lab_predict = None
IB_lab = self.processor(ref_image)
IB_lab = IB_lab.unsqueeze(0).to(self.device)
with torch.no_grad():
I_reference_lab = IB_lab
I_reference_l = I_reference_lab[:, 0:1, :, :]
I_reference_ab = I_reference_lab[:, 1:3, :, :]
I_reference_rgb = tensor_lab2rgb(torch.cat((uncenter_l(I_reference_l), I_reference_ab), dim=1)).to(self.device)
features_B = self.embed_net(I_reference_rgb)
curr_frame = image
        I_last_lab_predict, IA_predict_rgb = self.__process_sample(curr_frame, I_last_lab_predict, I_reference_lab, features_B)
IA_predict_rgb = IA_predict_rgb.transpose(1,2,0)
return IA_predict_rgb
if __name__ == "__main__":
model = SwinTExCo('checkpoints/epoch_20/')
# Initialize video reader and writer
video = cv2.VideoCapture('sample_input/video_2.mp4')
fps = video.get(cv2.CAP_PROP_FPS)
width = int(video.get(cv2.CAP_PROP_FRAME_WIDTH))
height = int(video.get(cv2.CAP_PROP_FRAME_HEIGHT))
video_writer = cv2.VideoWriter('sample_output/video_2_ref_2.mp4', cv2.VideoWriter_fourcc(*'mp4v'), fps, (width, height))
# Initialize reference image
ref_image = Image.open('sample_input/refs_2/ref2.jpg').convert('RGB')
    for colorized_frame in model.predict_video(video, ref_image):
        # predict_video yields RGB frames; OpenCV's VideoWriter expects BGR.
        video_writer.write(cv2.cvtColor(colorized_frame, cv2.COLOR_RGB2BGR))
    video_writer.release()
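
    # Single-image usage sketch (hypothetical paths):
    # image = Image.open('sample_input/frame.jpg').convert('RGB')
    # colorized = model.predict_image(image, ref_image)
    # cv2.imwrite('sample_output/frame_colorized.jpg',
    #             cv2.cvtColor(colorized, cv2.COLOR_RGB2BGR))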