File size: 4,906 Bytes
4730cdc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
#!/usr/bin/env python
# -*- coding:utf-8 -*-
# Power by Zongsheng Yue 2022-06-09 14:59:55

import torch
import random
import numpy as np
from einops import rearrange

def batch_inpainging_from_grad(im_in, mask, gradx, grady):
    '''
    Recovering from gradient for batch data (torch tensro).
    Input:
        im_in: N x c x h x w, torch tensor, masked image
        mask: N x 1 x  h x w, torch tensor
        gradx, grady: N x c x h x w, torch tensor, image gradient
    '''
    im_out = torch.zeros_like(im_in.data)
    for ii in range(im_in.shape[0]):
        im_current, gradx_current, grady_current = [rearrange(x[ii,].cpu().numpy(), 'c h w -> h w c')
                                                    for x in [im_in, gradx, grady]]
        mask_current = mask[ii, 0,].cpu().numpy()
        out_current = inpainting_from_grad(im_current, mask_current, gradx_current, grady_current)
        im_out[ii,] = torch.from_numpy(rearrange(out_current, 'h w c -> c h w')).to(
                device=im_in.device,
                dtype=im_in.dtype
                )
    return im_out

def inpainting_from_grad(im_in, mask, gradx, grady):
    '''
    Input:
        im_in: h x w x c, masked image, numpy array
        mask: h x w, image mask, 1 represents missing value
        gradx: h x w x c, gradient along x-axis, numpy array
        grady: h x w x c, gradient along y-axis, numpy array
    Output:
        im_out: recoverd image
    '''
    h, w = im_in.shape[:2]
    counts_h = np.sum(1-mask, axis=0, keepdims=False)
    counts_w = np.sum(1-mask, axis=1, keepdims=False)
    if np.any(counts_h[1:-1,] == h):
        idx = find_first_index(counts_h[1:-1,], h) + 1
        im_out = fill_image_from_gradx(im_in, mask, gradx, idx)
    elif np.any(counts_w[1:-1,] == w):
        idx = find_first_index(counts_w[1:-1,], w) + 1
        im_out = inpainting_from_grad(im_in.T, mask.T, gradx.T, idx)
    else:
        idx = random.choices(list(range(1,w-1)), k=1, weights=counts_h[1:-1])[0]
        line = fill_line(im_in[:, idx, ], mask[:, idx,], grady[:, idx,])
        im_in[:, idx,] = line
        im_out = fill_image_from_gradx(im_in, mask, gradx, idx)
    if im_in.ndim > mask.ndim:
        mask = mask[:, :, None]
    im_out = im_in + im_out * mask
    return im_out

def fill_image_from_gradx(im_in, mask, gradx, idx):
    init = np.zeros_like(im_in)
    init[:, idx,] = im_in[:, idx,]
    right = np.cumsum(init[:, idx:-1, ] + gradx[:, idx+1:, ], axis=1)
    left = np.cumsum(
        init[:, idx:0:-1, ] - gradx[:, idx:0:-1, ],
        axis=1
        )[:, ::-1]
    center = im_in[:, idx, ][:, None]     # h x 1 x 3
    im_out = np.concatenate((left, center, right), axis=1)
    return im_out

def fill_line(xx, mm, grad):
    '''
    Fill one line from grad.
    Input:
        xx: n x c array, masked vector
        mm: (n,) array, mask, 1 represent missing value
        grad: (n,) array
    '''
    n = xx.shape[0]
    assert mm.sum() < n
    if mm.sum() == 0:
        return xx
    else:
        idx1 = find_first_index(mm, 1)
        if idx1 == 0:
            idx2 = find_first_index(mm, 0)
            subx = xx[idx2::-1,].copy()
            subgrad = grad[idx2::-1, ].copy()
            subx -= subgrad
            xx[:idx2,] = np.cumsum(subx, axis=0)[idx2-1::-1,]
            mm[idx1:idx2,] = 0
        else:
            idx2 = find_first_index(mm[idx1:,], 0) + idx1
            subx = xx[idx1-1:idx2-1,].copy()
            subgrad = grad[idx1:idx2,].copy()
            subx += subgrad
            xx[idx1:idx2,] = np.cumsum(subx, axis=0)
            mm[idx1:idx2,] = 0
        return fill_line(xx, mm, grad)

def find_first_index(mm, value):
    '''
    Input:
        mm: (n, ) array
        value: scalar
    '''
    try:
        out = next((idx for idx, val in np.ndenumerate(mm) if val == value))[0]
    except StopIteration:
        out = mm.shape[0]
    return out

if __name__ == '__main__':
    import sys
    from pathlib import Path
    sys.path.append(str(Path(__file__).resolve().parents[1]))
    from utils import util_image
    from datapipe.masks.train import process_mask

    # mask_file_names = [x for x in Path('../lama/LaMa_test_images').glob('*mask*.png')]
    mask_file_names = [x for x in Path('./testdata/inpainting/val/places/').glob('*mask*.png')]
    file_names = [x.parents[0]/(x.stem.rsplit('_mask',1)[0]+'.png') for x in  mask_file_names]

    for im_path, mask_path in zip(file_names, mask_file_names):
        im = util_image.imread(im_path, chn='rgb', dtype='float32')
        mask = process_mask(util_image.imread(mask_path, chn='rgb', dtype='float32')[:, :, 0])
        grad_dict = util_image.imgrad(im)

        im_masked = im * (1 - mask[:, :, None])
        im_recover = inpainting_from_grad(im_masked, mask, grad_dict['gradx'], grad_dict['grady'])
        error_max = np.abs(im_recover -im).max()
        print('Error Max: {:.2e}'.format(error_max))