"""Face-region masking criterion built on a frozen DeepLab face-parsing model."""
import torch
import torchvision.transforms as transforms
import criteria.deeplab as deeplab
import PIL.Image as Image
import torch.nn as nn
import torch.nn.functional as F
from configs import paths_config, global_config
import numpy as np


class Mask(nn.Module):
    def __init__(self):
        """

        |  Class     | Number | Class | Number |
        |------------|--------|-------|--------|
        | background |  0     | mouth |  10    |
        | skin       |  1     | u_lip |  11    |
        | nose       |  2     | l_lip |  12    |
        | eye_g      |  3     | hair  |  13    |
        | l_eye      |  4     | hat   |  14    |
        | r_eye      |  5     | ear_r |  15    |
        | l_brow     |  6     | neck_l|  16    |
        | r_brow     |  7     | neck  |  17    |
        | l_ear      |  8     | cloth |  18    |
        | r_ear      |  9     |

        """
        super().__init__()

        # Frozen DeepLab segmentation network (ResNet-101 backbone) that predicts
        # the 19 face-parsing classes listed above; used for inference only.
        self.seg_model = (
            getattr(deeplab, "resnet101")(
                path=paths_config.deeplab,
                pretrained=True,
                num_classes=19,
                num_groups=32,
                weight_std=True,
                beta=False,
            )
            .eval()
            .requires_grad_(False)
        )

        ckpt = torch.load(paths_config.deeplab, map_location=global_config.device)
        # Strip the "module." prefix left by DataParallel training and drop the
        # num_batches_tracked buffers before loading the weights.
        state_dict = {
            k[7:]: v for k, v in ckpt["state_dict"].items() if "tracked" not in k
        }
        self.seg_model.load_state_dict(state_dict)
        self.seg_model = self.seg_model.to(global_config.device)

        # Face regions kept in the mask; background (0), hair (13), hat (14) and
        # cloth (18) are excluded (hair is later down-weighted to 0.1).
        self.labels = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 15, 16, 17]
        # 25x25 box kernel for the (currently disabled) mask-smoothing convolution.
        self.kernel = torch.ones((1, 1, 25, 25), device=global_config.device)

    def get_labels(self, img):
        """Returns a mask from an input image"""
        # DeepLab expects 513x513 inputs normalized with ImageNet statistics.
        data_transforms = transforms.Compose(
            [
                transforms.Resize((513, 513)),
                transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
            ]
        )
        img = data_transforms(img)
        with torch.no_grad():
            out = self.seg_model(img)
        # Per-pixel argmax over the 19 class logits (assumes a single-image batch).
        _, label = torch.max(out, 1)
        label = label.unsqueeze(0).type(torch.float32)

        # Resize to the working resolution and keep the label map on the same
        # device, so the boolean indexing in get_mask does not mix devices.
        label = (
            F.interpolate(label, size=(256, 256), mode="nearest")
            .squeeze()
            .long()
        )
        return label

    def get_mask(self, label):
        """Turn the label map into a float mask over the selected face regions."""
        mask = torch.zeros_like(label, device=global_config.device, dtype=torch.float)
        for idx in self.labels:
            mask[label == idx] = 1

        # Optional smoothing of the mask with the 25x25 box (mean) convolution;
        # both variants are kept here but currently disabled.
        # mask = 1 - torch.clamp(
        #     F.conv2d(1 - mask[None, None, :, :], self.kernel, padding="same"),
        #     0, 1,
        # ).squeeze()
        # mask = torch.clamp(
        #     F.conv2d(mask[None, None, :, :], self.kernel, padding="same"),
        #     0, 1,
        # ).squeeze()

        # Hair (label 13) is down-weighted rather than fully kept or removed.
        mask[label == 13] = 0.1
        return mask

    def forward(self, real_imgs, generated_imgs):
        """Mask both image batches with the face mask computed from the real images."""
        # return real_imgs, generated_imgs  # debug bypass: skip masking entirely
        label = self.get_labels(real_imgs)
        mask = self.get_mask(label)
        real_imgs = real_imgs * mask
        generated_imgs = generated_imgs * mask

        """out = (real_imgs * mask).squeeze().detach()

        out = (out.permute(1, 2, 0) * 127.5 + 127.5).clamp(0, 255).to(torch.uint8)
        Image.fromarray(out.cpu().numpy()).save("real_mask.png")

        out = (generated_imgs).squeeze().detach()

        out = (out.permute(1, 2, 0) * 127.5 + 127.5).clamp(0, 255).to(torch.uint8)
        Image.fromarray(out.cpu().numpy()).save("generated_mask.png")

        mask = (mask).squeeze().detach()
        mask = mask.repeat(3, 1, 1)
        mask = (mask.permute(1, 2, 0) * 127.5 + 127.5).clamp(0, 255).to(torch.uint8)
        Image.fromarray(mask.cpu().numpy()).save("mask.png")"""

        return real_imgs, generated_imgs
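

# Hedged usage sketch (not part of the original module): shows how Mask could be
# plugged into a masked reconstruction loss, e.g. during StyleGAN inversion or
# tuning. The image path, the [-1, 1] value range and the plain L2 loss below
# are illustrative assumptions, not project conventions.
if __name__ == "__main__":
    import torchvision.transforms.functional as TF

    device = global_config.device
    masker = Mask()  # loads the DeepLab checkpoint from paths_config.deeplab

    # Load a face image and bring it to a 1x3x256x256 tensor in [-1, 1].
    img = Image.open("example_face.png").convert("RGB")  # hypothetical file
    real = TF.to_tensor(TF.resize(img, [256, 256])).unsqueeze(0).to(device)
    real = real * 2 - 1

    # Stand-in for a generator output of the same shape.
    generated = torch.randn_like(real).clamp(-1, 1)

    # Mask both images with the face regions of the real image, then compare.
    masked_real, masked_generated = masker(real, generated)
    loss = F.mse_loss(masked_generated, masked_real)
    print(f"masked L2 loss: {loss.item():.4f}")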