import torch
import torchvision
import torch.nn as nn

from .conv_autoencoder import ConvEncoder, DeconvDecoder, INRDecoder

from .ops import ScaleLayer


class IHModelWithBackbone(nn.Module):
    """Wraps a harmonization model with an auxiliary backbone.

    The backbone receives a resized copy of the composite image together with a
    two-channel (foreground / background) mask and produces features that are
    passed on to the wrapped model.
    """

    def __init__(
            self,
            model, backbone,
            downsize_backbone_input=False,
            mask_fusion='sum',
            backbone_conv1_channels=64, opt=None
    ):
        super(IHModelWithBackbone, self).__init__()
        self.downsize_backbone_input = downsize_backbone_input
        self.mask_fusion = mask_fusion

        self.backbone = backbone
        self.model = model
        self.opt = opt

        # Embed the single-channel mask into `backbone_conv1_channels` feature
        # maps, scaled by a learnable factor.
        self.mask_conv = nn.Sequential(
            nn.Conv2d(1, backbone_conv1_channels, kernel_size=3, stride=2, padding=1, bias=True),
            ScaleLayer(init_value=0.1, lr_mult=1)
        )

    def forward(self, image, mask, coord=None, start_proportion=None):
        resize = torchvision.transforms.Resize([self.opt.base_size, self.opt.base_size])
        if self.opt.INRDecode and self.opt.hr_train and (self.training or hasattr(self.opt, 'split_num') or hasattr(self.opt, 'split_resolution')):
            # High-resolution mode: `image` and `mask` are sequences whose first
            # element feeds the backbone branch.
            backbone_image = resize(image[0])
            backbone_mask = torch.cat((resize(mask[0]), 1.0 - resize(mask[0])), dim=1)
        else:
            backbone_image = resize(image)
            backbone_mask = torch.cat((resize(mask), 1.0 - resize(mask)), dim=1)

        # Mask embedding and backbone features conditioned on the mask.
        backbone_mask_features = self.mask_conv(backbone_mask[:, :1])
        backbone_features = self.backbone(backbone_image, backbone_mask, backbone_mask_features)

        output = self.model(image, mask, backbone_features, coord=coord, start_proportion=start_proportion)
        return output


class DeepImageHarmonization(nn.Module):
    """Encoder-decoder harmonization network with an optional INR decoder."""

    def __init__(
            self,
            depth,
            norm_layer=nn.BatchNorm2d, batchnorm_from=0,
            attend_from=-1,
            image_fusion=False,
            ch=64, max_channels=512,
            backbone_from=-1, backbone_channels=None, backbone_mode='', opt=None
    ):
        super(DeepImageHarmonization, self).__init__()
        self.depth = depth
        self.encoder = ConvEncoder(
            depth, ch,
            norm_layer, batchnorm_from, max_channels,
            backbone_from, backbone_channels, backbone_mode, INRDecode=opt.INRDecode
        )
        self.opt = opt
        if opt.INRDecode:
            # See Table 2 in the paper to test different INR decoder structures.
            self.decoder = INRDecoder(depth, self.encoder.blocks_channels, norm_layer, opt, backbone_from)
        else:
            # Baseline decoder: https://github.com/SamsungLabs/image_harmonization
            self.decoder = DeconvDecoder(depth, self.encoder.blocks_channels, norm_layer, attend_from, image_fusion)

    def forward(self, image, mask, backbone_features=None, coord=None, start_proportion=None):
        resize = torchvision.transforms.Resize([self.opt.base_size, self.opt.base_size])
        hr_mode = self.opt.INRDecode and self.opt.hr_train and (self.training or hasattr(self.opt, 'split_num') or hasattr(self.opt, 'split_resolution'))

        if hr_mode:
            # High-resolution mode: encode the resized first element of the
            # (image, mask) sequences.
            x = torch.cat((resize(image[0]), resize(mask[0])), dim=1)
        else:
            x = torch.cat((resize(image), resize(mask)), dim=1)

        intermediates = self.encoder(x, backbone_features)

        if hr_mode:
            # Decode using the second element of the (image, mask) sequences.
            output = self.decoder(intermediates, image[1], mask[1], coord_samples=coord, start_proportion=start_proportion)
        else:
            output = self.decoder(intermediates, image, mask)
        return output
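

# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the original module): how the
# baseline DeconvDecoder path of DeepImageHarmonization might be driven.
# The option attributes read above (`INRDecode`, `hr_train`, `base_size`) are
# taken from this file; the SimpleNamespace options object, the depth value,
# and the tensor shapes are assumptions made for the sake of the example.
#
#     from types import SimpleNamespace
#     import torch
#
#     opt = SimpleNamespace(INRDecode=False, hr_train=False, base_size=256)
#     net = DeepImageHarmonization(depth=4, opt=opt)
#
#     image = torch.rand(1, 3, 256, 256)  # composite image
#     mask = torch.rand(1, 1, 256, 256)   # foreground mask
#     with torch.no_grad():
#         output = net(image, mask)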