from skimage.metrics import structural_similarity, peak_signal_noise_ratio
import numpy as np
import lpips
import torch
import torch.nn as nn
from torch.nn.functional import adaptive_avg_pool2d
from pytorch_fid.fid_score import calculate_frechet_distance
from pytorch_fid.inception import InceptionV3
import cv2
from scipy import stats
import os

def calc_ssim(pred_image, gt_image):
    '''
    Structural Similarity Index (SSIM) is a perceptual metric that quantifies the image quality
    degradation caused by processing such as data compression or by losses in data transmission.

    # Arguments
        pred_image: PIL.Image
        gt_image: PIL.Image
    # Returns
        ssim: float in [-1.0, 1.0]
    '''
    pred_image = np.array(pred_image.convert('RGB')).astype(np.float32)
    gt_image = np.array(gt_image.convert('RGB')).astype(np.float32)
    ssim = structural_similarity(pred_image, gt_image, channel_axis=2, data_range=255.)
    return ssim

def calc_psnr(pred_image, gt_image):
    '''
    Peak Signal-to-Noise Ratio (PSNR) is the ratio between the maximum possible power of a signal
    and the power of the distorting noise that affects the quality of its representation.

    # Arguments
        pred_image: PIL.Image
        gt_image: PIL.Image
    # Returns
        psnr: float (in dB)
    '''
    pred_image = np.array(pred_image.convert('RGB')).astype(np.float32)
    gt_image = np.array(gt_image.convert('RGB')).astype(np.float32)

    psnr = peak_signal_noise_ratio(gt_image, pred_image, data_range=255.)
    return psnr
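
# Usage sketch for the two full-reference metrics above. The file names
# 'pred.png' and 'gt.png' are placeholders; any pair of same-sized images works:
#
#   from PIL import Image
#   pred = Image.open('pred.png')
#   gt = Image.open('gt.png')
#   print('SSIM:', calc_ssim(pred, gt))
#   print('PSNR:', calc_psnr(pred, gt))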

class LPIPS_utils:
    def __init__(self, device='cuda'):
        self.loss_fn = lpips.LPIPS(net='vgg', spatial=True)  # net can be 'squeeze', 'vgg' or 'alex'
        self.loss_fn = self.loss_fn.to(device)
        self.device = device

    def compare_lpips(self, img_fake, img_real, data_range=255.):
        # Accepts images as (h, w, c) arrays / PIL images or as (1, c, h, w) tensors.
        img_fake = torch.from_numpy(np.array(img_fake).astype(np.float32) / data_range)
        img_real = torch.from_numpy(np.array(img_real).astype(np.float32) / data_range)
        if img_fake.ndim == 3:
            img_fake = img_fake.permute(2, 0, 1).unsqueeze(0)
            img_real = img_real.permute(2, 0, 1).unsqueeze(0)
        img_fake = img_fake.to(self.device)
        img_real = img_real.to(self.device)

        # Inputs are in [0, 1] at this point, so let LPIPS rescale them to the
        # [-1, 1] range the network expects.
        dist = self.loss_fn.forward(img_fake, img_real, normalize=True)
        return dist.mean().item()
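
# Usage sketch (assumes a CUDA device and two same-sized PIL images `pred` and `gt`):
#
#   lpips_metric = LPIPS_utils(device='cuda')
#   print('LPIPS:', lpips_metric.compare_lpips(pred, gt))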

class FID_utils(nn.Module):
    """Class for computing the Fréchet Inception Distance (FID) metric score.
    It is implemented as a class in order to hold the inception model instance
    in its state.
    Parameters
    ----------
    resize_input : bool (optional)
        Whether or not to resize the input images to the size (299, 299) on
        which the inception model was trained. Since the model is fully
        convolutional, the score also works without resizing. In the literature
        and when working with GANs people tend to set this value to True;
        for internal evaluation this is not necessary.
    device : str or torch.device
        The device on which to run the inception model.
    """

    def __init__(self, resize_input=True, device="cuda"):
        super(FID_utils, self).__init__()
        self.device = device
        if self.device is None:
            self.device = "cuda" if torch.cuda.is_available() else "cpu"
        block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[2048]
        self.model = InceptionV3([block_idx], resize_input=resize_input).to(self.device)
        self.model = self.model.eval()

    def get_activations(self, batch):  # batch: (n, 3, h, w)
        with torch.no_grad():
            pred = self.model(batch)[0]
        # If the model output is not already pooled to 1x1 (this happens when a
        # dimensionality other than 2048 is chosen), apply global average pooling.
        if pred.size(2) != 1 or pred.size(3) != 1:
            pred = adaptive_avg_pool2d(pred, output_size=(1, 1))
        return pred


    def _get_mu_sigma(self, batch, data_range):
        """Compute the inception mu and sigma for a batch of images.
        Parameters
        ----------
        batch : torch.Tensor
            A batch of images with shape (n_images, 3, height, width) or a
            single image with shape (height, width, 3).
        Returns
        -------
        mu : np.ndarray
            The array of mean activations with shape (2048,).
        sigma : np.ndarray
            The covariance matrix of activations with shape (2048, 2048).
        """
        # forward pass
        if batch.ndim == 3 and batch.shape[2] == 3:
            batch = batch.permute(2, 0, 1).unsqueeze(0)
        batch /= data_range
        batch = batch.to(self.device, torch.float32)
        activations = self.get_activations(batch)
        activations = activations.detach().cpu().numpy().squeeze(3).squeeze(2)

        # compute statistics
        mu = np.mean(activations, axis=0)
        sigma = np.cov(activations, rowvar=False)

        return mu, sigma

    def score(self, images_1, images_2, data_range=255.):
        """Compute the FID score.
        The input batches should have the shape (n_images, 3, height, width)
        or (height, width, 3).
        Parameters
        ----------
        images_1 : np.ndarray
            First batch of images.
        images_2 : np.ndarray
            Second batch of images.
        Returns
        -------
        score : float
            The FID score.
        """
        images_1 = torch.from_numpy(np.array(images_1).astype(np.float32))
        images_2 = torch.from_numpy(np.array(images_2).astype(np.float32))
        images_1 = images_1.to(self.device)
        images_2 = images_2.to(self.device)

        mu_1, sigma_1 = self._get_mu_sigma(images_1, data_range)
        mu_2, sigma_2 = self._get_mu_sigma(images_2, data_range)
        score = calculate_frechet_distance(mu_1, sigma_1, mu_2, sigma_2)

        return score
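
# Usage sketch (assumes two batches of images as (n, 3, h, w) float arrays named
# `pred_batch` and `gt_batch`; meaningful covariance statistics require more than
# one image per batch):
#
#   fid_metric = FID_utils(device='cuda')
#   print('FID:', fid_metric.score(pred_batch, gt_batch))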

def JS_divergence(p, q):
    """Jensen-Shannon divergence between two discrete distributions p and q."""
    M = (p + q) / 2
    return 0.5 * stats.entropy(p, M) + 0.5 * stats.entropy(q, M)


def compute_JS_bgr(input_dir, dilation=1):
    input_img_list = os.listdir(input_dir)
    input_img_list.sort()

    hist_b_list = []   # [img1_histb, img2_histb, ...]
    hist_g_list = []
    hist_r_list = []

    for img_name in input_img_list:
        img_in = cv2.imread(os.path.join(input_dir, img_name))
        if img_in is None:  # skip files OpenCV cannot read
            continue
        H, W, C = img_in.shape

        hist_b = cv2.calcHist([img_in], [0], None, [256], [0, 256])  # B
        hist_g = cv2.calcHist([img_in], [1], None, [256], [0, 256])  # G
        hist_r = cv2.calcHist([img_in], [2], None, [256], [0, 256])  # R

        # normalize the per-channel histograms to probability distributions
        hist_b = hist_b / (H * W)
        hist_g = hist_g / (H * W)
        hist_r = hist_r / (H * W)

        hist_b_list.append(hist_b)
        hist_g_list.append(hist_g)
        hist_r_list.append(hist_r)
    
    # JS divergence between the histograms of frame i and frame i + dilation
    JS_b_list = []
    JS_g_list = []
    JS_r_list = []
    
    for i in range(len(hist_b_list)):
        if i + dilation > len(hist_b_list) - 1:
            break
        hist_b_img1 = hist_b_list[i]
        hist_b_img2 = hist_b_list[i + dilation]     
        JS_b = JS_divergence(hist_b_img1, hist_b_img2)
        JS_b_list.append(JS_b)
        
        hist_g_img1 = hist_g_list[i]
        hist_g_img2 = hist_g_list[i+dilation]     
        JS_g = JS_divergence(hist_g_img1, hist_g_img2)
        JS_g_list.append(JS_g)
        
        hist_r_img1 = hist_r_list[i]
        hist_r_img2 = hist_r_list[i+dilation]     
        JS_r = JS_divergence(hist_r_img1, hist_r_img2)
        JS_r_list.append(JS_r)
        
    return JS_b_list, JS_g_list, JS_r_list


def calc_cdc(vid_folder, dilation=[1, 2, 4], weight=[1/3, 1/3, 1/3]):
    """Color Distribution Consistency (CDC): the weighted average of the
    Jensen-Shannon divergences between the BGR histograms of frames that are
    `dilation` steps apart. Lower values indicate better temporal color
    consistency."""
    mean_b, mean_g, mean_r = 0, 0, 0
    for d, w in zip(dilation, weight):
        JS_b_list_one, JS_g_list_one, JS_r_list_one = compute_JS_bgr(vid_folder, d)
        mean_b += w * np.mean(JS_b_list_one)
        mean_g += w * np.mean(JS_g_list_one)
        mean_r += w * np.mean(JS_r_list_one)

    cdc = np.mean([mean_b, mean_g, mean_r])
    return cdc
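

# Minimal command-line sketch tying the metrics together. The expected folder
# layout (a directory holding the frames of one video) and the argument name
# `--frames` are assumptions, not part of the original interface.
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description="Compute CDC over a folder of video frames.")
    parser.add_argument("--frames", required=True, help="directory containing the frames of one video")
    args = parser.parse_args()

    print("CDC:", calc_cdc(args.frames))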