# SwinTExCo / src/models/CNN/FrameColor.py
# Source: Hugging Face file view (author: duongttr, commit 3d85088 "Update new app", 1.87 kB)
import torch
from src.utils import *
from src.models.vit.vit import FeatureTransform
def warp_color(
    IA_l,
    IB_lab,
    features_B,
    embed_net,
    nonlocal_net,
    temperature=0.01,
):
    """Warp the reference image's colors onto the target frame.

    Steps:
        1. Convert the target luminance ``IA_l`` to a pseudo-RGB batch with
           ``gray2rgb_batch``.
        2. Embed that batch with ``embed_net`` without tracking gradients.
        3. Pass each of the four target feature maps and the four
           precomputed reference feature maps (``features_B``) through
           ``feature_normalize``.
        4. Hand the normalized maps to ``nonlocal_net`` together with
           ``IB_lab``.

    Args:
        IA_l: Luminance of the target frame. Presumably the L channel of a
            Lab batch; confirm against callers.
        IB_lab: Reference image, forwarded unchanged to ``nonlocal_net``.
        features_B: Sequence of exactly four reference feature maps.
        embed_net: Callable that must return exactly four feature maps.
        nonlocal_net: Correspondence network, called as
            ``nonlocal_net(IB_lab, A0, A1, A2, A3, B0, B1, B2, B3,
            temperature=temperature)``.
        temperature: Forwarded to ``nonlocal_net`` as a keyword argument.

    Returns:
        Whatever ``nonlocal_net`` returns. ``frame_colorization`` unpacks it
        as ``(warped_lab, similarity_map)``.
    """
    gray_as_rgb = gray2rgb_batch(IA_l)

    # No autograd graph is recorded for the target embedding.
    with torch.no_grad():
        a0, a1, a2, a3 = embed_net(gray_as_rgb)
        b0, b1, b2, b3 = features_B

    # NOTE(review): confirm whether normalization belongs inside the no_grad
    # scope. It only matters if features_B arrives carrying an autograd graph.
    target_feats = [feature_normalize(feat) for feat in (a0, a1, a2, a3)]
    reference_feats = [feature_normalize(feat) for feat in (b0, b1, b2, b3)]

    return nonlocal_net(
        IB_lab,
        *target_feats,
        *reference_feats,
        temperature=temperature,
    )
def frame_colorization(
    IA_l,
    IB_lab,
    IA_last_lab,
    features_B,
    embed_net,
    nonlocal_net,
    colornet,
    joint_training=True,
    luminance_noise=0,
    temperature=0.01,
):
    """Predict the chroma of one frame from a warped reference and extra context.

    The (optionally noise-jittered) luminance is warped against the reference
    via ``warp_color``. Then ``colornet`` is fed the channel-wise
    concatenation (``dim=1``) of four tensors:
        - the luminance,
        - channels 1:3 of the warped result,
        - the similarity map,
        - ``IA_last_lab``.

    Args:
        IA_l: Target-frame luminance.
        IB_lab: Reference image, passed through to ``warp_color``.
        IA_last_lab: Extra conditioning concatenated into the ``colornet``
            input. Presumably the previous frame's Lab result; confirm
            against callers.
        features_B: Precomputed reference features, passed to ``warp_color``.
        embed_net: Embedding network, passed to ``warp_color``.
        nonlocal_net: Correspondence network, passed to ``warp_color``.
        colornet: Network producing the final prediction.
        joint_training: Value given to ``torch.autograd.set_grad_enabled``
            around the warp and colornet calls.
        luminance_noise: When truthy, scales standard-normal noise that is
            added to the luminance before any further use.
        temperature: Forwarded to ``warp_color``.

    Returns:
        Tuple ``(IA_ab_predict, nonlocal_BA_lab)``:
            - the ``colornet`` output,
            - the warped reference returned by ``warp_color``.
    """
    luminance = IA_l
    if luminance_noise:
        luminance = luminance + torch.randn_like(luminance, requires_grad=False) * luminance_noise

    # NOTE(review): confirm that colornet is meant to sit under the same
    # grad gate as warp_color. With joint_training=False, nothing below
    # builds an autograd graph.
    with torch.autograd.set_grad_enabled(joint_training):
        warped_lab, similarity = warp_color(
            luminance,
            IB_lab,
            features_B,
            embed_net,
            nonlocal_net,
            temperature=temperature,
        )
        # Only channels 1:3 of the warped Lab result feed colornet.
        warped_ab = warped_lab[:, 1:3, :, :]
        colornet_input = torch.cat(
            (luminance, warped_ab, similarity, IA_last_lab),
            dim=1,
        )
        ab_prediction = colornet(colornet_input)

    return ab_prediction, warped_lab