import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Function

from .hourglass import HourGlass
from utils.dct import DCT_Lowfrequency
from utils.filters_tensor import bgr2gray

from collections import OrderedDict
import numpy as np


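class Quantize(Function):
    """Rounding with a straight-through gradient estimator.

    The forward pass rounds to the nearest integer. The backward pass treats
    rounding as the identity and returns the incoming gradient unchanged,
    because the true derivative of round() is zero almost everywhere.
    """

    @staticmethod
    def forward(ctx, x):
        return x.round()

    @staticmethod
    def backward(ctx, grad_output):
        # Straight-through: d(round(x))/dx is approximated by 1.
        return grad_output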
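

# Illustrative sketch only (hypothetical helper, not used by ResHalf): the
# same [-1, 1] -> {-1, 1} binarization as ResHalf.quantizer, written with the
# common "detach trick" instead of a custom autograd Function. The forward
# value is round(0.5 * (x + 1)) * 2 - 1. The backward pass sees only the
# identity path, so d(output)/dx == 1, the same gradient that Quantize's
# straight-through backward produces. torch.round rounds halves to even, so
# an input of exactly 0 maps to -1.
import torch


def ste_binarize(x: torch.Tensor) -> torch.Tensor:
    y = 0.5 * (x + 1.0)                 # [-1, 1] -> [0, 1]
    y_q = y + (y.round() - y).detach()  # value: round(y); gradient: identity
    return y_q * 2.0 - 1.0              # {0, 1} -> {-1, 1}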


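class ResHalf(nn.Module):
    def __init__(self, train=True, warm_stage=False):
        super(ResHalf, self).__init__()
        # Encoder: 3-channel image + 1 noise channel -> 1-channel halftone.
        self.encoder = HourGlass(inChannel=4, outChannel=1, resNum=4, convNum=4)
        # Decoder: 1-channel halftone -> restored 3-channel image.
        self.decoder = HourGlass(inChannel=1, outChannel=3, resNum=4, convNum=4)
        # Low-frequency DCT extractor, used only in training mode (see forward).
        self.dcter = DCT_Lowfrequency(size=256, fLimit=50)
        # Binarize [-1, 1] data to {-1, 1}: shift to [0, 1], round with a
        # straight-through gradient (Quantize), then shift back.
        self.quantizer = lambda x: Quantize.apply(0.5 * (x + 1.)) * 2. - 1.
        self.isTrain = train
        if warm_stage:
            # Warm-up stage: freeze all decoder parameters.
            for param in self.decoder.parameters():
                param.requires_grad = False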

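    def add_impluse_noise(self, input_halfs, p=0.0):
        """Corrupt halftones in {-1, 1} with salt-and-pepper (impulse) noise.

        Each pixel is kept with probability 1 - p and set to +1 or -1 with
        probability p / 2 each. All channels of a pixel share one draw.
        """
        N, C, H, W = input_halfs.shape
        keep_prob = 1.0 - p
        # copy(): for a CPU tensor, .numpy() shares memory, so the in-place
        # writes below would otherwise silently modify input_halfs.
        np_input_halfs = input_halfs.detach().to("cpu").numpy().copy()
        np_input_halfs = np.transpose(np_input_halfs, (0, 2, 3, 1))  # NCHW -> NHWC
        for i in range(N):
            # 0: keep, 1: set to +1, 2: set to -1. An (H, W) mask covers every channel.
            mask = np.random.choice((0, 1, 2), size=(H, W), p=[keep_prob, p / 2., p / 2.])
            np_input_halfs[i, mask == 1] = 1.0
            np_input_halfs[i, mask == 2] = -1.0
        return torch.from_numpy(np_input_halfs.transpose((0, 3, 1, 2))).to(input_halfs.device)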

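    def forward(self, input_img, decoding_only=False):
        if decoding_only:
            # input_img is a 1-channel halftone in [-1, 1]: binarize it, then restore.
            halfResQ = self.quantizer(input_img)
            restored = self.decoder(halfResQ)
            return restored

        # Append one channel of Gaussian noise (std 0.3) to the encoder input.
        noise = torch.randn_like(input_img) * 0.3
        halfRes = self.encoder(torch.cat((input_img, noise[:, :1, :, :]), dim=1))
        halfResQ = self.quantizer(halfRes)
        restored = self.decoder(halfResQ)
        if self.isTrain:
            # Low-frequency DCT of the continuous halftone and of the grayscale
            # input, both mapped from [-1, 1] to [0, 1] first.
            halfDCT = self.dcter(halfRes / 2. + 0.5)
            refDCT = self.dcter(bgr2gray(input_img / 2. + 0.5))
            return halfRes, halfDCT, refDCT, restored
        else:
            # halfRes is the continuous (pre-quantization) halftone.
            return halfRes, restored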
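

# Illustrative sketch only (hypothetical helper, not part of the original
# module): the same salt-and-pepper corruption as ResHalf.add_impluse_noise,
# written with torch ops so the tensor stays on its device and no NumPy
# round-trip is needed. Assumes halftones in {-1, 1} with shape (N, C, H, W).
# Each pixel is set to +1 with probability p / 2, to -1 with probability
# p / 2, and kept otherwise. One uniform draw per pixel is shared by all
# channels. It draws from torch's RNG, not NumPy's, so seeded runs will not
# reproduce the NumPy version's masks.
import torch


def impulse_noise_torch(halftone: torch.Tensor, p: float = 0.0) -> torch.Tensor:
    n, _, h, w = halftone.shape
    # One uniform sample per pixel; broadcasts across the channel dimension.
    u = torch.rand(n, 1, h, w, device=halftone.device)
    ones = torch.ones_like(halftone)
    out = torch.where(u < p / 2, ones, halftone)              # +1 with prob p/2
    out = torch.where((u >= p / 2) & (u < p), -ones, out)     # -1 with prob p/2
    return out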