File size: 6,257 Bytes
29f689c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
"""This code is adapted from:
https://github.com/Canjie-Luo/MORAN_v2
"""

import numpy as np
import torch
import torch.nn as nn
from torch.nn import functional as F


class MORN(nn.Module):
    """Multi-Object Rectification Network (MORN).

    Predicts a vertical offset map for the input image and resamples the
    image onto a ``target_shape`` grid displaced by those offsets.  The
    offsets are then refined ``enhance`` more times, each time predicting a
    correction from the current rectified image and resampling the ORIGINAL
    input with the accumulated offsets.

    Args:
        in_channels: Number of channels of the input image.
        target_shape: ``(height, width)`` of the rectified output.  Any
            indexable pair works (tuple or list).
        enhance: Number of extra offset-refinement iterations.
    """

    def __init__(self, in_channels, target_shape=(32, 100), enhance=1):
        super().__init__()
        self.targetH = target_shape[0]
        self.targetW = target_shape[1]
        self.enhance = enhance
        # Rectification only resamples the image, so channels are unchanged.
        self.out_channels = in_channels
        # Layer order (and therefore state_dict indices) must stay fixed so
        # existing checkpoints keep loading.
        self.cnn = nn.Sequential(
            nn.MaxPool2d(2, 2),
            nn.Conv2d(in_channels, 64, 3, 1, 1), nn.BatchNorm2d(64), nn.ReLU(True),
            nn.MaxPool2d(2, 2),
            nn.Conv2d(64, 128, 3, 1, 1), nn.BatchNorm2d(128), nn.ReLU(True),
            nn.MaxPool2d(2, 2),
            nn.Conv2d(128, 64, 3, 1, 1), nn.BatchNorm2d(64), nn.ReLU(True),
            nn.Conv2d(64, 16, 3, 1, 1), nn.BatchNorm2d(16), nn.ReLU(True),
            nn.Conv2d(16, 1, 3, 1, 1), nn.BatchNorm2d(1),
        )
        self.pool = nn.MaxPool2d(2, 1)

        # Normalised sampling coordinates in [-1, 1], laid out as
        # (1, targetH, targetW, 2) with the last axis = (x, y), which is the
        # layout F.grid_sample expects.  x varies along the width, y along
        # the height.
        h_list = np.arange(self.targetH) * 2. / (self.targetH - 1) - 1
        w_list = np.arange(self.targetW) * 2. / (self.targetW - 1) - 1
        grid = np.stack(np.meshgrid(w_list, h_list), axis=-1)
        grid = np.expand_dims(grid, 0)
        # A frozen Parameter follows the module across .to()/.cuda() calls
        # and is saved in the state dict.
        self.grid = nn.Parameter(
            torch.from_numpy(grid).float(),
            requires_grad=False,
        )

    def _offset_map(self, img, grid):
        """Predict vertical offsets for ``img`` sampled at ``grid`` positions.

        Args:
            img: Image batch fed to ``self.cnn``; must have ``in_channels``
                channels.
            grid: Sampling grid of shape (N, targetH, targetW, 2).

        Returns:
            Offset tensor of shape (N, targetH, targetW, 1).
        """
        offsets = self.cnn(img)
        # Pool the positive and negative parts separately and subtract: each
        # output value is the largest positive offset minus the
        # largest-magnitude negative offset in its 2x2 window.
        offsets_posi = F.relu(offsets, inplace=False)
        offsets_nega = F.relu(-offsets, inplace=False)
        offsets_pool = self.pool(offsets_posi) - self.pool(offsets_nega)
        # Resample the coarse map to the target resolution and move the
        # channel axis last so it lines up with the (x, y) grid layout.
        offsets_grid = F.grid_sample(offsets_pool, grid, align_corners=False)
        return offsets_grid.permute(0, 2, 3, 1).contiguous()

    def forward(self, x):
        """Rectify a batch of images.

        Args:
            x: Image batch of shape (N, in_channels, H, W).

        Returns:
            Rectified batch of shape (N, in_channels, targetH, targetW).
        """
        # align_corners=False is PyTorch's default since 1.3; it is passed
        # explicitly to avoid the UserWarning and make the convention visible.
        # NOTE(review): the [-1, 1] endpoint grid matches align_corners=True
        # semantics (presumably what the upstream MORAN_v2 code relied on);
        # switching would change outputs of trained weights — confirm first.
        bs = x.shape[0]
        grid = self.grid.tile([bs, 1, 1, 1])
        grid_x = grid[..., :1]
        grid_y = grid[..., 1:]
        x_small = F.interpolate(x,
                                size=(self.targetH, self.targetW),
                                mode='bilinear',
                                align_corners=False)

        # Only the y (vertical) coordinate is displaced by the offsets.
        offsets_grid = self._offset_map(x_small, grid)
        sample_grid = torch.cat([grid_x, grid_y + offsets_grid], 3)
        x_rectified = F.grid_sample(x, sample_grid, align_corners=False)

        for _ in range(self.enhance):
            # Accumulate a correction predicted from the current result, but
            # always resample the original input to avoid compounding blur.
            offsets_grid = offsets_grid + self._offset_map(x_rectified, grid)
            sample_grid = torch.cat([grid_x, grid_y + offsets_grid], 3)
            x_rectified = F.grid_sample(x, sample_grid, align_corners=False)

        return x_rectified