File size: 3,971 Bytes
ba9144c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
import cv2
import numpy as np

from .model import BiSeNet

# Region name -> class index, used to turn user-facing region names into the
# ids that get_mask compares against the parser output. init_parser builds
# BiSeNet with 19 classes, one per entry below, in this exact order.
# NOTE(review): "Eye-G", "Ear-R" and "Neck-L" presumably stand for
# eyeglasses, earring and necklace (CelebAMask-HQ-style labels) — confirm.
_REGION_NAMES = (
    "Background", "Skin", "L-Eyebrow", "R-Eyebrow", "L-Eye", "R-Eye",
    "Eye-G", "L-Ear", "R-Ear", "Ear-R", "Nose", "Mouth", "U-Lip", "L-Lip",
    "Neck", "Neck-L", "Cloth", "Hair", "Hat",
)
mask_regions = {name: index for index, name in enumerate(_REGION_NAMES)}

# Borrowed from simswap
# https://github.com/neuralchen/SimSwap/blob/26c84d2901bd56eda4d5e3c5ca6da16e65dc82a6/util/reverse2original.py#L30
class SoftErosion(nn.Module):
    """Soften a mask by convolving it with a normalised cone-shaped kernel.

    After convolution, values at or above ``threshold`` are clamped to 1.0
    and the remaining values are rescaled so that their maximum becomes 1.0.

    Args:
        kernel_size: side length of the square kernel. Padding is
            ``kernel_size // 2``, so odd sizes preserve the spatial size.
        threshold: convolved value from which a pixel counts as fully inside.
        iterations: number of convolution passes; every pass but the last
            keeps the element-wise minimum with the previous map.
    """

    def __init__(self, kernel_size=15, threshold=0.6, iterations=1):
        super().__init__()
        r = kernel_size // 2
        self.padding = r
        self.iterations = iterations
        self.threshold = threshold

        # Radial kernel: largest at the centre, 0 at the corners, summing to 1.
        # Built by broadcasting (the distance is symmetric in row/column), which
        # matches the old meshgrid construction without its indexing warning.
        coords = torch.arange(0., kernel_size) - r
        dist = torch.sqrt(coords.view(-1, 1) ** 2 + coords.view(1, -1) ** 2)
        kernel = dist.max() - dist
        kernel /= kernel.sum()
        kernel = kernel.view(1, 1, *kernel.shape)
        self.register_buffer('weight', kernel)

    def forward(self, x):
        """Return ``(soft_mask, hard_mask)`` for input mask ``x``.

        ``x`` is convolved with ``groups=x.shape[1]`` against a single-output-
        channel kernel, so it presumably is an (N, 1, H, W) tensor — confirm.
        ``hard_mask`` is the boolean map of values >= ``threshold``.
        """
        x = x.float()
        for _ in range(self.iterations - 1):
            x = torch.min(x, F.conv2d(x, weight=self.weight, groups=x.shape[1], padding=self.padding))
        x = F.conv2d(x, weight=self.weight, groups=x.shape[1], padding=self.padding)

        mask = x >= self.threshold
        x[mask] = 1.0
        rest = x[~mask]
        # Guard the rescale: max() of an empty tensor raises (everything was
        # above threshold), and an all-zero remainder would yield 0 / 0 = NaN.
        if rest.numel() > 0:
            peak = rest.max()
            if peak != 0:
                x[~mask] = rest / peak

        return x, mask

# Device string that image_to_parsing and swap_regions move tensors to;
# overwritten by init_parser.
device = "cpu"


def init_parser(pth_path, mode="cpu"):
    """Build the 19-class BiSeNet face parser and load its weights.

    Args:
        pth_path: path to a state_dict checkpoint readable by ``torch.load``.
        mode: torch device string (e.g. ``"cpu"``, ``"cuda"``, ``"cuda:1"``).
            It is also stored in the module-level ``device`` so later calls
            send their inputs to the same device as the network.

    Returns:
        The network in eval mode, placed on ``mode``.
    """
    global device
    device = mode
    n_classes = 19  # one class per entry of mask_regions
    net = BiSeNet(n_classes=n_classes)
    # Load on CPU first so a checkpoint saved on any device can be read, then
    # move the whole network. This works for every device string, not only
    # the literal "cuda" (anything else previously stayed on CPU).
    state_dict = torch.load(pth_path, map_location=torch.device("cpu"))
    net.load_state_dict(state_dict)
    net.to(device)
    net.eval()
    return net


def image_to_parsing(img, net):
    """Run the face parser on ``img`` and return a per-pixel class-id map.

    The image is resized to 512x512 and its channel order is reversed
    (presumably BGR as loaded by cv2 -> RGB; confirm at the call sites). It is
    then normalised with ImageNet mean/std, sent to the module-level
    ``device`` and passed through ``net``. The first network output is
    arg-maxed over its class axis (presumably yielding a 512x512 numpy
    array — confirm against the model).
    """
    to_input = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
    ])
    rgb = cv2.resize(img, (512, 512))[:, :, ::-1].copy()
    batch = to_input(rgb).unsqueeze(0)

    with torch.no_grad():
        logits = net(batch.to(device))[0]
        return logits.squeeze(0).cpu().numpy().argmax(0)


def get_mask(parsing, classes):
    """Return a boolean mask of the pixels whose class id is in ``classes``.

    Args:
        parsing: array of per-pixel class ids (e.g. from image_to_parsing).
        classes: iterable of class ids to select; may be empty.

    Returns:
        Boolean numpy array shaped like ``parsing``; all False when
        ``classes`` is empty.
    """
    return np.isin(parsing, list(classes))


def swap_regions(source, target, net, smooth_mask, includes=(1, 2, 3, 4, 5, 10, 11, 12, 13), blur=10):
    """Blend the parsed regions of ``source`` over ``target``.

    ``source`` is parsed with ``net``. Pixels whose class id is in ``includes``
    are taken from ``source``, and all other pixels from ``target``. Blending
    happens at 512x512, and the result is resized back to ``source``'s width
    and height.

    Args:
        source: image that is parsed and whose selected regions are kept.
        target: image supplying every non-selected pixel.
        net: face parser returned by init_parser.
        smooth_mask: optional callable (e.g. a SoftErosion instance) applied
            to a (1, 1, 512, 512) mask tensor on ``device``; it must return a
            ``(soft_mask, _)`` pair. ``None`` disables smoothing.
        includes: class ids to take from ``source`` (see mask_regions).
        blur: Gaussian sigma applied to the mask; 0 disables blurring.

    Returns:
        uint8 image with ``source``'s height and width. When ``includes`` is
        empty, ``source`` is returned unchanged and the parser is not run.
    """
    if len(includes) == 0:
        return source

    parsing = image_to_parsing(source, net)
    include_mask = get_mask(parsing, includes)
    mask = np.repeat(include_mask[:, :, np.newaxis], 3, axis=2).astype("float32")

    # Skip smoothing when nothing matched: normalising an all-zero map (as
    # SoftErosion does) divides by a zero maximum and NaN-poisons the blend.
    if smooth_mask is not None and include_mask.any():
        mask_tensor = torch.from_numpy(mask.copy().transpose((2, 0, 1))).float().to(device)
        # NOTE(review): the three channels are identical copies, so this sum
        # is 2 * include_mask (ported from SimSwap, where the two channels
        # held different regions). Kept as-is to preserve the output.
        face_mask_tensor = mask_tensor[0] + mask_tensor[1]
        soft_face_mask_tensor, _ = smooth_mask(face_mask_tensor.unsqueeze_(0).unsqueeze_(0))
        soft_face_mask_tensor.squeeze_()
        mask = np.repeat(soft_face_mask_tensor.cpu().numpy()[:, :, np.newaxis], 3, axis=2)

    if blur > 0:
        mask = cv2.GaussianBlur(mask, (0, 0), blur)

    resized_source = cv2.resize(source.astype("float32"), (512, 512))
    resized_target = cv2.resize(target.astype("float32"), (512, 512))
    result = mask * resized_source + (1 - mask) * resized_target
    return cv2.resize(result.astype("uint8"), (source.shape[1], source.shape[0]))

def mask_regions_to_list(values):
    """Translate region names into parser class ids via ``mask_regions``.

    Unknown names are skipped silently. The output preserves the order of
    ``values``, including any repeated names.
    """
    return [mask_regions[name] for name in values if name in mask_regions]