# '''
# https://github.com/One-sixth/ms_ssim_pytorch/blob/master/ssim.py
# '''
#
# import torch
# import torch.jit
# import torch.nn.functional as F
#
#
# @torch.jit.script
# def create_window(window_size: int, sigma: float, channel: int):
#     '''
#     Create 1-D gauss kernel
#     :param window_size: the size of gauss kernel
#     :param sigma: sigma of normal distribution
#     :param channel: input channel
#     :return: 1D kernel
#     '''
#     coords = torch.arange(window_size, dtype=torch.float)
#     coords -= window_size // 2
#
#     g = torch.exp(-(coords ** 2) / (2 * sigma ** 2))
#     g /= g.sum()
#
#     g = g.reshape(1, 1, 1, -1).repeat(channel, 1, 1, 1)
#     return g
#
#
# @torch.jit.script
# def _gaussian_filter(x, window_1d, use_padding: bool):
#     '''
#     Blur input with 1-D kernel
#     :param x: batch of tensors to be blurred
#     :param window_1d: 1-D gauss kernel
#     :param use_padding: padding image before conv
#     :return: blurred tensors
#     '''
#     C = x.shape[1]
#     padding = 0
#     if use_padding:
#         window_size = window_1d.shape[3]
#         padding = window_size // 2
#     out = F.conv2d(x, window_1d, stride=1, padding=(0, padding), groups=C)
#     out = F.conv2d(out, window_1d.transpose(2, 3), stride=1, padding=(padding, 0), groups=C)
#     return out
#
#
# @torch.jit.script
# def ssim(X, Y, window, data_range: float, use_padding: bool = False):
#     '''
#     Calculate ssim index for X and Y
#     :param X: images [B, C, H, N_bins]
#     :param Y: images [B, C, H, N_bins]
#     :param window: 1-D gauss kernel
#     :param data_range: value range of input images. (usually 1.0 or 255)
#     :param use_padding: padding image before conv
#     :return:
#     '''
#
#     K1 = 0.01
#     K2 = 0.03
#     compensation = 1.0
#
#     C1 = (K1 * data_range) ** 2
#     C2 = (K2 * data_range) ** 2
#
#     mu1 = _gaussian_filter(X, window, use_padding)
#     mu2 = _gaussian_filter(Y, window, use_padding)
#     sigma1_sq = _gaussian_filter(X * X, window, use_padding)
#     sigma2_sq = _gaussian_filter(Y * Y, window, use_padding)
#     sigma12 = _gaussian_filter(X * Y, window, use_padding)
#
#     mu1_sq = mu1.pow(2)
#     mu2_sq = mu2.pow(2)
#     mu1_mu2 = mu1 * mu2
#
#     sigma1_sq = compensation * (sigma1_sq - mu1_sq)
#     sigma2_sq = compensation * (sigma2_sq - mu2_sq)
#     sigma12 = compensation * (sigma12 - mu1_mu2)
#
#     cs_map = (2 * sigma12 + C2) / (sigma1_sq + sigma2_sq + C2)
#     # Fixed the issue that the negative value of cs_map caused ms_ssim to output Nan.
#     cs_map = cs_map.clamp_min(0.)
#     ssim_map = ((2 * mu1_mu2 + C1) / (mu1_sq + mu2_sq + C1)) * cs_map
#
#     ssim_val = ssim_map.mean(dim=(1, 2, 3))  # reduce along CHW
#     cs = cs_map.mean(dim=(1, 2, 3))
#
#     return ssim_val, cs
#
#
# @torch.jit.script
# def ms_ssim(X, Y, window, data_range: float, weights, use_padding: bool = False, eps: float = 1e-8):
#     '''
#     interface of ms-ssim
#     :param X: a batch of images, (N,C,H,W)
#     :param Y: a batch of images, (N,C,H,W)
#     :param window: 1-D gauss kernel
#     :param data_range: value range of input images. (usually 1.0 or 255)
#     :param weights: weights for different levels
#     :param use_padding: padding image before conv
#     :param eps: use for avoid grad nan.
#     :return:
#     '''
#     levels = weights.shape[0]
#     cs_vals = []
#     ssim_vals = []
#     for _ in range(levels):
#         ssim_val, cs = ssim(X, Y, window=window, data_range=data_range, use_padding=use_padding)
#         # Works around an issue: when c = a ** b and a is 0, c.backward() makes a.grad become inf.
#         ssim_val = ssim_val.clamp_min(eps)
#         cs = cs.clamp_min(eps)
#         cs_vals.append(cs)
#
#         ssim_vals.append(ssim_val)
#         padding = (X.shape[2] % 2, X.shape[3] % 2)
#         X = F.avg_pool2d(X, kernel_size=2, stride=2, padding=padding)
#         Y = F.avg_pool2d(Y, kernel_size=2, stride=2, padding=padding)
#
#     cs_vals = torch.stack(cs_vals, dim=0)
#     ms_ssim_val = torch.prod((cs_vals[:-1] ** weights[:-1].unsqueeze(1)) * (ssim_vals[-1] ** weights[-1]), dim=0)
#     return ms_ssim_val
#
#
# class SSIM(torch.jit.ScriptModule):
#     __constants__ = ['data_range', 'use_padding']
#
#     def __init__(self, window_size=11, window_sigma=1.5, data_range=255., channel=3, use_padding=False):
#         '''
#         :param window_size: the size of gauss kernel
#         :param window_sigma: sigma of normal distribution
#         :param data_range: value range of input images. (usually 1.0 or 255)
#         :param channel: input channels (default: 3)
#         :param use_padding: padding image before conv
#         '''
#         super().__init__()
#         assert window_size % 2 == 1, 'Window size must be odd.'
#         window = create_window(window_size, window_sigma, channel)
#         self.register_buffer('window', window)
#         self.data_range = data_range
#         self.use_padding = use_padding
#
#     @torch.jit.script_method
#     def forward(self, X, Y):
#         r = ssim(X, Y, window=self.window, data_range=self.data_range, use_padding=self.use_padding)
#         return r[0]
#
#
# class MS_SSIM(torch.jit.ScriptModule):
#     __constants__ = ['data_range', 'use_padding', 'eps']
#
#     def __init__(self, window_size=11, window_sigma=1.5, data_range=255., channel=3, use_padding=False, weights=None,
#                  levels=None, eps=1e-8):
#         '''
#         class for ms-ssim
#         :param window_size: the size of gauss kernel
#         :param window_sigma: sigma of normal distribution
#         :param data_range: value range of input images. (usually 1.0 or 255)
#         :param channel: input channels
#         :param use_padding: padding image before conv
#         :param weights: weights for different levels. (default [0.0448, 0.2856, 0.3001, 0.2363, 0.1333])
#         :param levels: number of downsampling
#         :param eps: works around an issue: when c = a ** b and a is 0, c.backward() makes a.grad become inf.
#         '''
#         super().__init__()
#         assert window_size % 2 == 1, 'Window size must be odd.'
#         self.data_range = data_range
#         self.use_padding = use_padding
#         self.eps = eps
#
#         window = create_window(window_size, window_sigma, channel)
#         self.register_buffer('window', window)
#
#         if weights is None:
#             weights = [0.0448, 0.2856, 0.3001, 0.2363, 0.1333]
#         weights = torch.tensor(weights, dtype=torch.float)
#
#         if levels is not None:
#             weights = weights[:levels]
#             weights = weights / weights.sum()
#
#         self.register_buffer('weights', weights)
#
#     @torch.jit.script_method
#     def forward(self, X, Y):
#         return ms_ssim(X, Y, window=self.window, data_range=self.data_range, weights=self.weights,
#                        use_padding=self.use_padding, eps=self.eps)
#
#
# if __name__ == '__main__':
#     print('Simple Test')
#     im = torch.randint(0, 255, (5, 3, 256, 256), dtype=torch.float, device='cuda')
#     img1 = im / 255
#     img2 = img1 * 0.5
#
#     losser = SSIM(data_range=1.).cuda()
#     loss = losser(img1, img2).mean()
#
#     losser2 = MS_SSIM(data_range=1.).cuda()
#     loss2 = losser2(img1, img2).mean()
#
#     print(loss.item())
#     print(loss2.item())
#
# if __name__ == '__main__':
#     print('Training Test')
#     import cv2
#     import torch.optim
#     import numpy as np
#     import imageio
#     import time
#
#     out_test_video = False
#     # Avoid writing a GIF directly (it gets very large); write an mkv file first and convert it to GIF with ffmpeg.
#     video_use_gif = False
#
#     im = cv2.imread('test_img1.jpg', 1)
#     t_im = torch.from_numpy(im).cuda().permute(2, 0, 1).float()[None] / 255.
#
#     if out_test_video:
#         if video_use_gif:
#             fps = 0.5
#             out_wh = (im.shape[1] // 2, im.shape[0] // 2)
#             suffix = '.gif'
#         else:
#             fps = 5
#             out_wh = (im.shape[1], im.shape[0])
#             suffix = '.mkv'
#         video_last_time = time.perf_counter()
#         video = imageio.get_writer('ssim_test' + suffix, fps=fps)
#
#     # Test ssim
#     print('Training SSIM')
#     rand_im = torch.randint_like(t_im, 0, 255, dtype=torch.float32) / 255.
#     rand_im.requires_grad = True
#     optim = torch.optim.Adam([rand_im], 0.003, eps=1e-8)
#     losser = SSIM(data_range=1., channel=t_im.shape[1]).cuda()
#     ssim_score = 0
#     while ssim_score < 0.999:
#         optim.zero_grad()
#         loss = losser(rand_im, t_im)
#         (-loss).sum().backward()
#         ssim_score = loss.item()
#         optim.step()
#         r_im = np.transpose(rand_im.detach().cpu().numpy().clip(0, 1) * 255, [0, 2, 3, 1]).astype(np.uint8)[0]
#         r_im = cv2.putText(r_im, 'ssim %f' % ssim_score, (10, 30), cv2.FONT_HERSHEY_PLAIN, 2, (255, 0, 0), 2)
#
#         if out_test_video:
#             if time.perf_counter() - video_last_time > 1. / fps:
#                 video_last_time = time.perf_counter()
#                 out_frame = cv2.cvtColor(r_im, cv2.COLOR_BGR2RGB)
#                 out_frame = cv2.resize(out_frame, out_wh, interpolation=cv2.INTER_AREA)
#                 if isinstance(out_frame, cv2.UMat):
#                     out_frame = out_frame.get()
#                 video.append_data(out_frame)
#
#         cv2.imshow('ssim', r_im)
#         cv2.setWindowTitle('ssim', 'ssim %f' % ssim_score)
#         cv2.waitKey(1)
#
#     if out_test_video:
#         video.close()
#
#     # Test ms_ssim
#     if out_test_video:
#         if video_use_gif:
#             fps = 0.5
#             out_wh = (im.shape[1] // 2, im.shape[0] // 2)
#             suffix = '.gif'
#         else:
#             fps = 5
#             out_wh = (im.shape[1], im.shape[0])
#             suffix = '.mkv'
#         video_last_time = time.perf_counter()
#         video = imageio.get_writer('ms_ssim_test' + suffix, fps=fps)
#
#     print('Training MS_SSIM')
#     rand_im = torch.randint_like(t_im, 0, 255, dtype=torch.float32) / 255.
#     rand_im.requires_grad = True
#     optim = torch.optim.Adam([rand_im], 0.003, eps=1e-8)
#     losser = MS_SSIM(data_range=1., channel=t_im.shape[1]).cuda()
#     ssim_score = 0
#     while ssim_score < 0.999:
#         optim.zero_grad()
#         loss = losser(rand_im, t_im)
#         (-loss).sum().backward()
#         ssim_score = loss.item()
#         optim.step()
#         r_im = np.transpose(rand_im.detach().cpu().numpy().clip(0, 1) * 255, [0, 2, 3, 1]).astype(np.uint8)[0]
#         r_im = cv2.putText(r_im, 'ms_ssim %f' % ssim_score, (10, 30), cv2.FONT_HERSHEY_PLAIN, 2, (255, 0, 0), 2)
#
#         if out_test_video:
#             if time.perf_counter() - video_last_time > 1. / fps:
#                 video_last_time = time.perf_counter()
#                 out_frame = cv2.cvtColor(r_im, cv2.COLOR_BGR2RGB)
#                 out_frame = cv2.resize(out_frame, out_wh, interpolation=cv2.INTER_AREA)
#                 if isinstance(out_frame, cv2.UMat):
#                     out_frame = out_frame.get()
#                 video.append_data(out_frame)
#
#         cv2.imshow('ms_ssim', r_im)
#         cv2.setWindowTitle('ms_ssim', 'ms_ssim %f' % ssim_score)
#         cv2.waitKey(1)
#
#     if out_test_video:
#         video.close()

"""
Adapted from https://github.com/Po-Hsun-Su/pytorch-ssim
"""

import torch
import torch.nn.functional as F
from torch.autograd import Variable
import numpy as np
from math import exp


def gaussian(window_size, sigma):
    """Return a normalized 1-D Gaussian kernel with ``window_size`` taps.

    Tap ``i`` is weighted by ``exp(-(i - window_size // 2) ** 2 / (2 * sigma ** 2))``.
    The weights are then divided by their sum so the kernel adds up to one.

    :param window_size: number of taps
    :param sigma: standard deviation of the Gaussian
    :return: 1-D float32 tensor of length ``window_size``
    """
    center = window_size // 2
    denom = float(2 * sigma ** 2)
    weights = [exp(-(idx - center) ** 2 / denom) for idx in range(window_size)]
    kernel = torch.Tensor(weights)
    return kernel / kernel.sum()


def create_window(window_size, channel, sigma=1.5):
    """Build a 2-D Gaussian kernel replicated for a depthwise (grouped) convolution.

    The 2-D kernel is the outer product of ``gaussian(window_size, sigma)`` with itself.

    :param window_size: side length of the square kernel
    :param channel: number of input channels; one copy of the kernel is made per channel
    :param sigma: standard deviation of the Gaussian (default 1.5, the value that was
        previously hard-coded here)
    :return: contiguous float32 tensor of shape (channel, 1, window_size, window_size),
        usable as the weight of ``F.conv2d(..., groups=channel)``
    """
    _1D_window = gaussian(window_size, sigma).unsqueeze(1)
    _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0)
    # ``Variable`` has been a no-op wrapper since PyTorch 0.4, so a plain tensor is returned.
    # ``.contiguous()`` materializes the per-channel copies made by ``expand``.
    return _2D_window.expand(channel, 1, window_size, window_size).contiguous()


def _ssim(img1, img2, window, window_size, channel, size_average=True):
    mu1 = F.conv2d(img1, window, padding=window_size // 2, groups=channel)
    mu2 = F.conv2d(img2, window, padding=window_size // 2, groups=channel)

    mu1_sq = mu1.pow(2)
    mu2_sq = mu2.pow(2)
    mu1_mu2 = mu1 * mu2

    sigma1_sq = F.conv2d(img1 * img1, window, padding=window_size // 2, groups=channel) - mu1_sq
    sigma2_sq = F.conv2d(img2 * img2, window, padding=window_size // 2, groups=channel) - mu2_sq
    sigma12 = F.conv2d(img1 * img2, window, padding=window_size // 2, groups=channel) - mu1_mu2

    C1 = 0.01 ** 2
    C2 = 0.03 ** 2

    ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2))

    if size_average:
        return ssim_map.mean()
    else:
        return ssim_map.mean(1)


class SSIM(torch.nn.Module):
    """Module computing SSIM between two image batches with a Gaussian window.

    The window is cached on the instance. It is rebuilt lazily whenever the input's
    channel count, dtype or device differs from the cached window's.
    """

    def __init__(self, window_size=11, size_average=True):
        """
        :param window_size: side length of the Gaussian window
        :param size_average: forwarded to ``_ssim``; True returns a scalar mean,
            False returns the channel-averaged SSIM map
        """
        super(SSIM, self).__init__()
        self.window_size = window_size
        self.size_average = size_average
        self.channel = 1
        # Plain attribute (not a registered buffer), so it is not part of state_dict and
        # is not moved by ``.to()``; forward() moves/rebuilds it on demand instead.
        self.window = create_window(window_size, self.channel)

    def forward(self, img1, img2):
        """Return SSIM of ``img1`` vs ``img2`` (both (N, C, H, W))."""
        (_, channel, _, _) = img1.size()

        window = self.window
        # Compare dtype and device explicitly: the legacy ``.data.type()`` string is the same
        # for tensors on different CUDA devices (e.g. cuda:0 vs cuda:1), which let a window
        # cached on one GPU be reused for inputs on another.
        if channel != self.channel or window.dtype != img1.dtype or window.device != img1.device:
            window = create_window(self.window_size, channel).to(device=img1.device, dtype=img1.dtype)
            self.window = window
            self.channel = channel

        return _ssim(img1, img2, window, self.window_size, channel, self.size_average)


# Module-level cache of the Gaussian window used by ``ssim()``; rebuilt on mismatch.
window = None


def ssim(img1, img2, window_size=11, size_average=True):
    """Functional SSIM between two image batches of shape (N, C, H, W).

    The Gaussian window is cached in the module-level ``window``. It is rebuilt whenever
    the cached window does not match the requested ``window_size``, the input's channel
    count, or the input's dtype/device.

    :param img1: first image batch
    :param img2: second image batch, same shape as ``img1``
    :param window_size: side length of the Gaussian window
    :param size_average: True returns a scalar mean, False the channel-averaged SSIM map
    :return: result of ``_ssim``
    """
    global window
    (_, channel, _, _) = img1.size()
    # Previously the first window was reused forever, so a later call with a different
    # channel count, device, dtype or window_size crashed in conv2d or used the wrong kernel.
    stale = (
        window is None
        or window.size(0) != channel
        or window.size(-1) != window_size
        or window.dtype != img1.dtype
        or window.device != img1.device
    )
    if stale:
        window = create_window(window_size, channel).to(device=img1.device, dtype=img1.dtype)
    return _ssim(img1, img2, window, window_size, channel, size_average)