File size: 1,871 Bytes
3d85088
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
import torch
from src.utils import *
from src.models.vit.vit import FeatureTransform


def warp_color(
    IA_l,
    IB_lab,
    features_B,
    embed_net,
    nonlocal_net,
    temperature=0.01,
):
    """Warp the reference colours onto the target frame with ``nonlocal_net``.

    The target luminance ``IA_l`` is expanded with ``gray2rgb_batch`` and passed
    through ``embed_net`` with autograd disabled, yielding four feature maps.
    The four reference feature maps in ``features_B`` are taken as given. All
    eight maps are normalised with ``feature_normalize`` and forwarded, after
    ``IB_lab``, to ``nonlocal_net``.

    Args:
        IA_l: Target luminance tensor (presumably (N, 1, H, W) — confirm).
        IB_lab: Reference image in Lab; forwarded unchanged to ``nonlocal_net``.
        features_B: Sequence of exactly four precomputed reference feature maps.
        embed_net: Feature extractor that must return exactly four feature maps.
        nonlocal_net: Network receiving ``IB_lab`` plus the eight normalised maps.
        temperature: Forwarded to ``nonlocal_net`` as the ``temperature`` keyword.

    Returns:
        Whatever ``nonlocal_net`` returns. ``frame_colorization`` unpacks it as
        ``(warped_lab, similarity_map)``.

    Raises:
        ValueError: If ``embed_net`` or ``features_B`` does not yield exactly
            four items (from tuple unpacking).
    """
    gray_as_rgb = gray2rgb_batch(IA_l)

    # embed_net is used purely as a frozen feature extractor: no graph is built.
    with torch.no_grad():
        tgt0, tgt1, tgt2, tgt3 = embed_net(gray_as_rgb)
        # Unpacked under no_grad as well; this only matters if features_B is a
        # tensor, whose slices would then not track gradients.
        ref0, ref1, ref2, ref3 = features_B

    # Normalisation runs under the caller's grad mode (outside no_grad),
    # target levels first, then reference levels.
    target_feats = [feature_normalize(level) for level in (tgt0, tgt1, tgt2, tgt3)]
    reference_feats = [feature_normalize(level) for level in (ref0, ref1, ref2, ref3)]

    return nonlocal_net(
        IB_lab,
        *target_feats,
        *reference_feats,
        temperature=temperature,
    )


def frame_colorization(
    IA_l,
    IB_lab,
    IA_last_lab,
    features_B,
    embed_net,
    nonlocal_net,
    colornet,
    joint_training=True,
    luminance_noise=0,
    temperature=0.01,
):
    """Colorize one frame from its luminance and an exemplar reference.

    Pipeline: optionally add Gaussian noise to ``IA_l``, warp the reference
    colours onto the frame with ``warp_color``, then feed colornet the
    channel-wise (dim 1) concatenation of the luminance, the warped ab
    channels, the similarity map and ``IA_last_lab``.

    Args:
        IA_l: Target luminance tensor (presumably (N, 1, H, W) — confirm).
        IB_lab: Reference image in Lab, forwarded to ``warp_color``.
        IA_last_lab: Extra conditioning tensor concatenated into colornet's
            input (presumably the previous frame in Lab — confirm).
        features_B: Four precomputed reference feature maps for ``warp_color``.
        embed_net: Feature extractor forwarded to ``warp_color``.
        nonlocal_net: Warping network forwarded to ``warp_color``.
        colornet: Network producing the returned prediction.
        joint_training: Grad mode used for both warping and colornet. When
            False, both run with autograd disabled.
        luminance_noise: Scale of additive standard-normal noise on ``IA_l``.
            A falsy value (e.g. 0) disables it. The noisy luminance is used
            both for warping and as colornet input.
        temperature: Forwarded to ``warp_color``.

    Returns:
        Tuple ``(IA_ab_predict, nonlocal_BA_lab)``: colornet's output and the
        warped reference tensor returned by ``warp_color``.
    """
    if luminance_noise:
        # Augmentation: perturb the luminance the whole pipeline sees.
        jitter = torch.randn_like(IA_l, requires_grad=False) * luminance_noise
        IA_l = IA_l + jitter

    with torch.autograd.set_grad_enabled(joint_training):
        warped_lab, similarity_map = warp_color(
            IA_l,
            IB_lab,
            features_B,
            embed_net,
            nonlocal_net,
            temperature=temperature,
        )
        # Channels 1:3 of the warped Lab tensor are the ab chroma channels.
        warped_ab = warped_lab[:, 1:3, :, :]
        colornet_input = torch.cat(
            [IA_l, warped_ab, similarity_map, IA_last_lab],
            dim=1,
        )
        ab_prediction = colornet(colornet_input)

    return ab_prediction, warped_lab